import inspect
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.collections import PatchCollection
import pandas as pd
import numpy as np
import seaborn as sns
import seaborn.objects as sno
from sklearn.preprocessing import minmax_scale
from nilearn.plotting import view_surf as view_surf_nilearn
from neuromaps import images
from . import lgr
from .utils.utils import vect_to_vol_arr
from .datasets import fetch_parcellation, fetch_template, parcellation_lib, template_lib
[docs]def nice_stats_labels(string, add_dollars=True):
replace_dict = {
"r2": "R^2",
"R2": "R^2",
"beta": "Beta",
"spearman": "Spearman's Rho",
"pearson": "Pearson's Rho",
"partialpearson": "Partial Pearson's Rho",
"partialspearman": "Partial Spearman's Rho",
"mi": "Mutual Information",
"mlr": "MLR",
"slr": "SLR",
"pls": "PLS",
"pcr": "PCR",
"dominance": "Dominance Analysis",
"individual": "Individual R^2",
"total": "Total R^2",
"ridge": "Ridge",
"lasso": "Lasso",
"elasticnet": "ElasticNet",
"meandiff": "Mean Difference",
"zscore": "Z score",
"pairedcohen": "Paired Cohen's d",
"pairedhedges": "Paired Hedges' g",
"cohen": "Cohen's d",
"hedges": "Hedges' g",
"psc": "PSC",
"md": "MD",
"rho": "Rho",
"ci": "95% CI",
"mean": "Mean",
"median": "Median",
"sd": "SD",
"std": "STD",
"groups": "Groups",
"sets": "Sets",
"xmaps": "X maps",
"ymaps": "X maps",
"xymaps": "X and Y maps"
}
for k in replace_dict:
if add_dollars:
k_replace = "$" + replace_dict[k].replace(' ', '\ ') + "$"
else:
k_replace = replace_dict[k]
string = string.replace(k, k_replace)
return string
[docs]def hide_empty_axes(axes):
[ax.axis("off") for ax in axes.ravel() if ax.axis() == (0.0, 1.0, 0.0, 1.0)]
[docs]def colors_from_values(values, palette_name):
# normalize the values to range [0, 1]
normalized = (values - min(values)) / (max(values) - min(values))
# convert to indices
indices = np.round(normalized * (len(values) - 1)).astype(np.int32)
# use the indices to get the colors
palette = sns.color_palette(palette_name, len(values))
return np.array(palette).take(indices, axis=0)
[docs]def move_legend_fig_to_ax(fig, ax, loc, bbox_to_anchor=None, no_legend_error=False, **kwargs):
# copied from GitHub user thuiop
# https://github.com/mwaskom/seaborn/issues/3247#issuecomment-1420731692
if hasattr(fig, "legends"):
if len(fig.legends) > 0:
old_fig_legend = fig.legends[-1]
old_fig_boxes = old_fig_legend.get_children()[0].get_children()
if ax.legend_:
old_ax_legend = ax.legend_
old_ax_boxes = old_ax_legend.get_children()[0].get_children()
legend_kws = inspect.signature(mpl.legend.Legend).parameters
props = {
k: v for k, v in old_fig_legend.properties().items() if k in legend_kws
}
props.pop("bbox_to_anchor")
title = props.pop("title")
if "title" in kwargs:
title.set_text(kwargs.pop("title"))
title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
for key, val in title_kwargs.items():
title.set(**{key[6:]: val})
kwargs.pop(key)
kwargs.setdefault("frameon", old_fig_legend.legendPatch.get_visible())
# Remove the old legend and create the new one
props.update(kwargs)
fig.legends = []
new_legend = ax.legend(
[], [], loc=loc, bbox_to_anchor=bbox_to_anchor, **props
)
new_legend.get_children()[0].get_children().extend(old_fig_boxes)
if "old_ax_legend" in locals():
new_legend.get_children()[0].get_children().extend(old_ax_boxes)
else:
if no_legend_error:
raise ValueError("Figure has no legend attached.")
else:
pass
[docs]def linewidth_from_data_units(linewidth, axis, reference='x'):
"""
Convert a linewidth in data units to linewidth in points.
Parameters
----------
linewidth: float
Linewidth in data units of the respective reference-axis
axis: matplotlib axis
The axis which is used to extract the relevant transformation
data (data limits and size must not change afterwards)
reference: string
The axis that is taken as a reference for the data width.
Possible values: 'x' and 'y'. Defaults to 'y'.
Returns
-------
linewidth: float
Linewidth in points
"""
fig = axis.get_figure()
if reference == 'x':
length = fig.bbox_inches.width * axis.get_position().width
value_range = np.diff(axis.get_xlim())
elif reference == 'y':
length = fig.bbox_inches.height * axis.get_position().height
value_range = np.diff(axis.get_ylim())
# Convert length to points
length *= 72
# Scale linewidth to value range
return linewidth * (length / value_range)
[docs]def print_significance(ax, p_values, q_values=None, coloc_values=None,
bold_labels=True,
pq_symbols=["☆", "★"],
pq_size=12,
pq_positions=None,
pq_positions_pad=0,
categorical_axis="y"):
p_values = np.array(p_values).flatten()
if q_values is not None:
q_values = np.array(q_values).flatten()
else:
q_values = p_values
if coloc_values is None and pq_positions is None:
pq_symbols = None
if pq_symbols is not None:
if pq_positions is None:
pq_pos_min = np.array(coloc_values).min(axis=0)
pq_pos_max = np.array(coloc_values).max(axis=0)
pq_pos_mean = np.array(coloc_values).mean(axis=0)
pq_positions = [mi - pq_positions_pad if mu < 0 else ma + pq_positions_pad
for mu, mi, ma in zip(pq_pos_mean, pq_pos_min, pq_pos_max)]
else:
pq_positions += pq_positions_pad
else:
pq_positions = [None] * len(p_values)
if categorical_axis=="y":
labs = ax.get_yticklabels()
else:
labs = ax.get_xticklabels()
lab_positions = np.array([lab.get_position() for lab in labs])
for l, l_pos, pq_pos, p, q in zip(labs, lab_positions, pq_positions, p_values, q_values):
if bold_labels:
if q < 0.05:
l.set_weight("bold")
elif p < 0.05:
l.set_weight("semibold")
else:
l.set_weight("normal")
if pq_pos is not None:
kwargs = {"x": l_pos[0], "y": pq_pos} if categorical_axis=="x" else {"x": pq_pos, "y": l_pos[1]}
kwargs |= {"ha": "center", "va": "center", "size": pq_size}
if q < 0.05:
ax.text(s=pq_symbols[1], **kwargs)
elif p < 0.05:
ax.text(s=pq_symbols[0], **kwargs)
[docs]def catplot(fig, ax, data_long, categorical_var="variable", continuous_var="value", group_var=None,
categorical_axis="x", sort_categories=False, category_order=None,
color_how="continuous", color_which="auto", color_center=None,
labels={},
limits={},
bars={},
violins={},
scatters={},
dots={},
errorbars={},
hline={},
vline={},
legend={}
):
# defaults, overwrite with user input
bars = dict(
plot=False, label=True,
width=0.5, agg_method="mean", dodge_width=0.5,
kwargs={"zorder": 10, "ec": "k", "lw": 0.7}
) | bars
violins = dict(
plot=False, label=False,
kwargs={"zorder": 20, "density_norm": "width", "cut": 0, "inner": "quart",
"fill": False, "edgecolor": "k", "linewidth": 0.7}
) | violins
scatters = dict(
plot=True, label=False,
size="auto", jitter_width=0.5, dodge_width=0.5,
kwargs={"zorder": 30, "linewidth": 0.2, "alpha": 0.2}
) | scatters
dots = dict(
plot=True, label=True,
agg_method="mean", size=7, color="k", dodge_width=0.5,
kwargs={"zorder": 90, "facecolor": (1,1,1,0.8), "lw": 1}
) | dots
errorbars = dict(
plot=True, label=True,
agg_method="ci", dodge_width=0.5, color="k",
kwargs={"zorder": 100}
) | errorbars
hline = dict(
plot=False, y=[0], color="k", linewidth=1, linestyle="--", zorder=-100, kwargs={}
) | hline
vline = dict(
plot=False, x=[0], color="k", linewidth=1, linestyle="--", zorder=-100, kwargs={}
) | vline
legend = dict(
plot=True,
loc="center left",
bbox_to_anchor=(1, 0.5),
nice_labels=True,
kwargs={}
) | legend
# orientation
if categorical_axis == "x":
xy = dict(x=categorical_var, y=continuous_var)
elif categorical_axis == "y":
xy = dict(y=categorical_var, x=continuous_var)
else:
raise ValueError("categorical_axis must be 'x' or 'y'!")
# sort
if sort_categories and (len(data_long[categorical_var].unique()) > 1) and not category_order:
if sort_categories not in ["mean", "median"]:
sort_categories = "mean"
category_order = data_long[[categorical_var, continuous_var]] \
.groupby(categorical_var).apply(sort_categories) \
.sort_values(continuous_var, ascending=False) \
.index.to_list()
if category_order:
data_long[categorical_var] = pd.Categorical(data_long[categorical_var], category_order, ordered=True)
# color by continuous variable
if "color" in limits.keys():
data_range_lim = limits.pop("color")
if color_how:
if color_how.lower().startswith("cont") and data_long.shape[0] > 1:
color_var = continuous_var
data_min, data_max = data_long[continuous_var].min(), data_long[continuous_var].max()
if color_which == "auto":
if (data_min > 0 and data_max < 0) or (data_min < 0 and data_max > 0):
color_which = "icefire"
color_center = True
else:
color_which = "inferno"
if color_center:
data_range = np.max(np.abs([data_min, data_max]))
data_range = (-data_range, data_range)
else:
data_range = (data_min, data_max)
if "data_range_lim" in locals():
data_range = list(data_range)
for i, lim in enumerate(data_range_lim):
data_range[i] = lim if lim is not None else data_range[i]
data_range = tuple(data_range)
# color by categorical variable
elif color_how.lower().startswith("cat"):
color_var = categorical_var
color_which = "Spectral"
# else no coloring
else:
color_var = None
else:
color_var = None
# label handler
def handle_label(label, agg_method=None):
if not label:
return None
elif label == True:
return agg_method if not legend["nice_labels"] else nice_stats_labels(agg_method, False)
else:
return label
## PLOT OBJECT
plot = sno.Plot(data=data_long, **xy, fill=group_var)
## SCATTER
if scatters["plot"] and data_long.shape[0] > 1:
if scatters["size"] == "auto":
n_max = data_long.groupby(categorical_var, observed=False).count().max().values[0]
n_cat = len(data_long[categorical_var].unique())
scatters["size"] = 5 / (0.1 * n_max**0.5) / ((0.3 * n_cat**0.5) if n_cat > 1 else 1)
plot = plot.add(
sno.Dots(pointsize=scatters["size"], artist_kws=scatters["kwargs"]),
sno.Jitter(x=scatters["jitter_width"] if categorical_axis == "x" else None,
y=scatters["jitter_width"] if categorical_axis == "y" else None),
*(sno.Dodge(gap=scatters["dodge_width"])) if group_var else (),
color=color_var,
label=handle_label(scatters["label"]),
legend=False if color_var==categorical_var else True
)
## VIOLINS
if violins["plot"] and data_long.shape[0] > 1:
sns.violinplot(
data=data_long, **xy, hue=group_var,
label=handle_label(violins["label"]),
**violins["kwargs"]
)
## BARS
if bars["plot"]:
if not color_var:
plot = plot.add(
sno.Bar(width=bars["width"], artist_kws=bars["kwargs"]),
sno.Agg(func=bars["agg_method"]),
*(sno.Dodge(gap=bars["dodge_width"])) if group_var else (),
label=handle_label(bars["label"], bars["agg_method"])
)
else:
tmp = data_long[[categorical_var, continuous_var]] \
.groupby(categorical_var, observed=False).apply(bars["agg_method"])
plot = plot.add(
sno.Bar(width=bars["width"], artist_kws=bars["kwargs"]),
data=tmp,
**xy,
color=color_var,
legend=False if color_var==categorical_var else True,
label=None
)
## DOTS
if dots["plot"]:
plot = plot.add(
sno.Dots(pointsize=dots["size"], color=dots["color"], artist_kws=dots["kwargs"]),
sno.Agg(func=dots["agg_method"]),
*(sno.Dodge(gap=dots["dodge_width"])) if group_var else (),
label=handle_label(dots["label"], dots["agg_method"])
)
## ERRORBARS
if errorbars["plot"] and data_long.shape[0] > 1:
plot = plot.add(
sno.Range(color=errorbars["color"], artist_kws=errorbars["kwargs"]),
sno.Est(errorbar=errorbars["agg_method"]),
*(sno.Dodge(gap=errorbars["dodge_width"])) if group_var else (),
label=handle_label(errorbars["label"], errorbars["agg_method"])
)
## HORIZONTAL AND VERTICAL LINES
for hvline, xy in zip([hline, vline], ["y", "x"]):
if hvline["plot"]:
if isinstance(hvline[xy], (int, float)):
hvline[xy] = [hvline[xy]]
for hvline_xy in hvline[xy]:
kws = dict(c=hvline["color"], lw=hvline["linewidth"], ls=hvline["linestyle"],
zorder=hvline["zorder"],
**hvline["kwargs"])
if xy == "y":
ax.axhline(hvline_xy, **kws)
else:
ax.axvline(hvline_xy, **kws)
## LIMITS
plot = plot.limit(**limits)
## COLOR
if color_var == continuous_var:
plot = plot.scale(
color=sno.Continuous(color_which, norm=data_range)
)
elif color_var == categorical_var:
plot = plot.scale(
color=sno.Nominal(color_which)
)
## FINALIZE
plot = plot.label(**labels).on(ax).plot()
## LEGEND
if legend["plot"]:
move_legend_fig_to_ax(fig, ax, loc=legend["loc"], bbox_to_anchor=legend["bbox_to_anchor"],
**legend["kwargs"])
else:
fig.legends[-1].set_visible(False)
return plot
[docs]def nullplot(fig, ax, data_long, categorical_var="variable", continuous_var="value",
categorical_axis="x", category_order=None,
color_which="viridis_r",
quantiles_below_median=[0.01, 0.05, 0.25],
bands={},
median_line={},
violins={},
labels={},
limits={},
legend={}
):
bands = dict(
plot=True, label=True, label_prefix="Null percentile ",
alpha=0.1,
edgewidth=1,
edgestyle="-",
edgealpha=0.3,
kwargs={"zorder": -200}
) | bands
median_line = dict(
plot=True, label="Null percentile 50",
alpha=0.6,
kwargs={"zorder": -150}
) | median_line
violins = dict(
plot=False, label="Null distribution",
legend=None,
kwargs={"zorder": 100, "density_norm": "width", "cut": 0, "inner": "quart",
"fill": False, "edgecolor": "k", "linewidth": 1}
) | violins
legend = dict(
plot=True,
loc="center left",
bbox_to_anchor=(1, 0.5),
nice_labels=True,
kwargs={}
) | legend
# orientation
if categorical_axis not in ["x", "y"]:
raise ValueError("categorical_axis must be 'x' or 'y'!")
continuous_axis = "y" if categorical_axis == "x" else "x"
# correct label
labels = {
continuous_axis: continuous_var
} | labels
# data
#data_long[categorical_var] = pd.Categorical([str(c) for c in data_long[categorical_var]])
# percentiles
quantiles_below_median = np.array(quantiles_below_median)
quantiles = np.concatenate([quantiles_below_median, np.array([0.5]), 1 - quantiles_below_median])
quantiles.sort()
data_quant = pd.concat(
[data_long.groupby(categorical_var).quantile(q).rename(columns={continuous_var: q}) \
for q in quantiles],
axis=1
).reset_index()
# color
if color_which:
try:
colors = sns.color_palette(color_which, len(quantiles_below_median) + 1)# [1:-1]
except KeyError:
colors = [color_which] * len(quantiles_below_median)
else:
colors = ["0.3"] * len(quantiles_below_median)
## PLOT OBJECT
plot = sno.Plot() # data=data_long, **xy
## AREAS
for i_q, q in enumerate(quantiles[quantiles < 0.5]):
if bands["plot"]:
xy = {
categorical_axis: categorical_var,
f"{continuous_axis}min": q,
f"{continuous_axis}max": 1 - q
}
plot = plot.add(
sno.Band(color=colors[i_q], alpha=bands["alpha"], edgewidth=bands["edgewidth"],
edgestyle=bands["edgestyle"], edgealpha=bands["edgealpha"],
artist_kws=bands["kwargs"]),
data=data_quant,
**xy,
label=f"{bands['label_prefix']}{q*100:.0f}/{(1-q)*100:.0f}" if bands["label"] else None
)
# MEDIAN
if median_line["plot"]:
xy = {
categorical_axis: categorical_var,
continuous_axis: 0.5,
}
plot = plot.add(
sno.Line(color=colors[-1], alpha=median_line["alpha"], ), #artist_kws=lines["kwargs"]
data=data_quant,
**xy,
label=median_line['label'] if median_line["label"] else None
)
# VIOLINS
if violins["plot"]:
xy = {
categorical_axis: categorical_var,
continuous_axis: continuous_var,
}
sns.violinplot(
data=data_long, **xy,
label=violins["label"],
legend=violins["legend"],
**violins["kwargs"]
)
# FINALIZE
plot = plot.limit(**limits).label(**labels).on(ax).plot()
## LEGEND
if legend["plot"]:
move_legend_fig_to_ax(fig, ax, loc=legend["loc"], bbox_to_anchor=legend["bbox_to_anchor"],
**legend["kwargs"])
else:
fig.legends[-1].set_visible(False)
return plot
[docs]def heatmap(ax,
data_colors=None,
data_sizes=None,
data_shapes=None,
mapping_shapes=None,
annotation=None,
mask=None,
cmap="auto",
symmetric_cmap=True,
size_scale=(0.2, 1),
shape="square",
color="tab:blue",
edgecolor="k",
linewidth=0.075,
square=True,
spines=False,
spinewidth=0.1,
spinecolor="k",
xy_pad=0.1,
xtick_labels=None,
ytick_labels=None,
legend_orientation="vertical",
legend_colors=True,
legend_colors_kwargs={},
legend_sizes=True,
legend_sizes_kwargs={},
legend_shapes=True,
legend_shapes_kwargs={},
):
# input arrays
arrays = [data_colors, data_sizes, data_shapes, annotation, mask]
arrays = [arr for arr in arrays if arr is not None]
if len(arrays) == 0:
raise ValueError("No input arrays provided.")
if not all([isinstance(arr, (np.ndarray, pd.DataFrame)) for arr in arrays]):
raise ValueError("All input arrays must be 2d arrays or dataframes.")
shapes = [arr.shape for arr in arrays]
for i in range(len(shapes)):
for j in range(i + 1, len(shapes)):
if shapes[i] != shapes[j]:
raise ValueError("All input arrays must have the same shape.")
# x labels
autolabels = {"x": False, "y": False}
if xtick_labels is None:
autolabels["x"] = True
for arr in arrays:
if isinstance(arr, pd.DataFrame):
xtick_labels = arr.columns.to_list()
break
if xtick_labels is None:
xtick_labels = list(range(arrays[0].shape[1]))
if len(xtick_labels) != arrays[0].shape[1]:
raise ValueError("xtick_labels must have the same length as the number of columns in the input array(s).")
xtick_labels = [str(x) for x in xtick_labels]
# y labels
if ytick_labels is None:
autolabels["y"] = True
for arr in arrays:
if isinstance(arr, pd.DataFrame):
ytick_labels = arr.index.to_list()
break
if ytick_labels is None:
ytick_labels = list(range(arrays[0].shape[0]))
if len(ytick_labels) != arrays[0].shape[0]:
raise ValueError("ytick_labels must have the same length as the number of rows in the input array(s).")
ytick_labels = [str(y) for y in ytick_labels]
# make indices
x_idc = np.arange(len(xtick_labels))
y_idc = np.arange(len(ytick_labels))
x_idc_2d, y_idc_2d = np.meshgrid(x_idc, y_idc)
data_x = x_idc_2d.flatten()
data_y = y_idc_2d.flatten()
# plotting sizes
if data_sizes is not None:
data_sizes = np.array(data_sizes).flatten("C")
if not np.issubdtype(data_sizes.dtype, np.number):
raise ValueError("data_sizes array must be numeric.")
data_sizes_input = data_sizes.copy()
data_sizes = minmax_scale(data_sizes, size_scale)
else:
data_sizes = np.ones(len(data_x)) * size_scale[1]
# plotting colors
if data_colors is not None:
data_colors = np.array(data_colors).flatten("C")
if not np.issubdtype(data_colors.dtype, np.number):
raise ValueError("data_colors array must be numeric.")
if symmetric_cmap:
vmax = np.nanmax(np.abs(data_colors))
vmin = -vmax
if cmap in ["auto", None, False]:
cmap = "icefire"
else:
vmin, vmax = np.nanmin(data_colors), np.nanmax(data_colors)
if cmap in ["auto", None, False]:
cmap = "magma"
if isinstance(cmap, str):
cmap = mpl.colormaps[cmap]
else:
data_colors = np.ones((len(data_x)))
vmin, vmax = 1, 1
cmap = mpl.colors.ListedColormap([color])
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
# plotting shapes
if data_shapes is not None:
data_shapes = np.array(data_shapes).flatten("C")
data_shapes_unique = np.unique(data_shapes)
if data_shapes_unique[~pd.isnull(data_shapes_unique)].shape[0] > 3:
raise ValueError("For data_shapes, maximally 3 unique values are supported.")
if mapping_shapes is None:
mapping_shapes = {val: shape for val, shape in zip(data_shapes_unique, ["s", "o", "D"])}
if not all([shape in ["circle", "o", "square", "s", "diamond", "D", ""] for shape in mapping_shapes.values()]):
raise ValueError("mapping_shapes values must be 'circle'/'o', 'square'/'s', or 'diamond'/'D'")
data_shapes = [mapping_shapes[val] if ~pd.isnull(val) else "" for val in data_shapes]
else:
data_shapes = [shape] * len(data_x)
# mask
if mask is not None:
mask = np.array(mask).flatten("C")
if not np.issubdtype(mask.dtype, bool):
raise ValueError("mask array must be boolean.")
else:
mask = [True] * len(data_x)
# plot
elements = []
if linewidth in [0, None, False]:
linewidth = 0
if spinewidth in [0, None, False]:
spinewidth = 0
lw = linewidth_from_data_units(linewidth / len(xtick_labels), ax)
sw = linewidth_from_data_units(spinewidth / len(xtick_labels), ax)
kwargs_base = {
"lw": lw,
"ec": edgecolor,
"joinstyle": "bevel"
}
for mask, x, y, color, shape, size in zip(mask, data_x, data_y, data_colors, data_shapes, data_sizes):
if mask and ~np.isnan(color) and ~np.isnan(size) and shape!="":
if shape in ["o", "circle"]:
fun = plt.Circle
kwargs = {
"xy": (x, y),
"radius": size / 2,
**kwargs_base
}
elif shape in ["s", "square"]:
fun = plt.Rectangle
kwargs = {
"xy": (x - size / 2, y - size / 2),
"width": size,
"height": size,
**kwargs_base
}
elif shape in ["diamond", "D"]:
fun = plt.Rectangle
size *= 0.7
kwargs = {
"xy": (x - size / 2, y - size / 2),
"width": size,
"height": size,
"angle": 45,
"rotation_point": "center",
**kwargs_base
}
elements.append(fun(**kwargs))
collection = PatchCollection(elements, cmap=cmap, norm=norm, match_original=True)
if data_colors is not None:
collection.set_array(data_colors)
ax.add_collection(collection)
# padding at the outside of x and y axes
xy_lims = 0.5
ax.set_xlim(-xy_lims - xy_pad, len(xtick_labels) - xy_lims + xy_pad)
ax.set_ylim(len(ytick_labels) - xy_lims + xy_pad, -xy_lims - xy_pad)
# spines
for spine in ax.spines.values():
if spines:
spine.set_linewidth(sw)
spine.set_color(spinecolor)
else:
spine.set_visible(False)
# labels
ax.set_xticks(x_idc, xtick_labels, rotation=45, ha="right", va="center", rotation_mode="anchor")
ax.set_yticks(y_idc, ytick_labels)
if autolabels["x"]:
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
if autolabels["y"] and len(ytick_labels) > 5:
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
# layout
if square:
ax.set_aspect("equal")
# legends
# colors
if legend_colors and np.unique(data_colors).shape[0] > 1:
if "cax" in legend_colors_kwargs:
cax = legend_colors_kwargs.pop("cax")
else:
cax = ax.inset_axes((0, 1.1, 1/3, 0.05) if legend_orientation == "horizontal" else (1.05, 2/3, 0.05, 1/3))
legend_colors_kwargs = {
"label": "Colors",
"orientation": legend_orientation,
"cax": cax
} | legend_colors_kwargs
plt.colorbar(collection, **legend_colors_kwargs)
# sizes
if legend_sizes and np.unique(data_sizes).shape[0] > 1:
legend_sizes_kwargs = {
"title": "Sizes",
"labelspacing": 1,
"bbox_to_anchor": (0.5, 1.05) if legend_orientation == "horizontal" else (1.03, 0.45),
"ncol": 2 if legend_orientation == "horizontal" else 1,
"loc": "lower center" if legend_orientation == "horizontal" else "center left",
"fmt": ".2f"
} | legend_sizes_kwargs
fmt = legend_sizes_kwargs.pop('fmt')
handles = []
for size, size_label in zip(np.linspace(data_sizes.min(), data_sizes.max(), 5),
np.linspace(data_sizes_input.min(), data_sizes_input.max(), 5)):
handles.append(
mpl.lines.Line2D(
[0], [0],
color=(0,0,0,0),
marker="s",
markerfacecolor="k",
markeredgewidth=0.5,
markeredgecolor="w",
markersize=linewidth_from_data_units(size, ax),
label=f"{size_label:{fmt}}"
)
)
lax = ax.inset_axes((0,0,1,1))
lax.axis("off")
lax.legend(handles=handles, **legend_sizes_kwargs)
# shapes
if legend_shapes and np.unique(data_shapes).shape[0] > 1:
legend_shapes_kwargs = {
"title": "Shapes",
"labelspacing": 1,
"bbox_to_anchor": (0.5 + 0.33, 1.05) if legend_orientation == "horizontal" else (1.03, 0),
"loc": "lower center" if legend_orientation == "horizontal" else "lower left",
} | legend_shapes_kwargs
handles = []
for val, shape in mapping_shapes.items():
handles.append(
mpl.lines.Line2D(
[0], [0],
color=(0,0,0,0),
marker=shape,
markerfacecolor="k",
markeredgewidth=0.5,
markeredgecolor="w",
markersize=linewidth_from_data_units(0.5 if shape in ["D", "diamond"] else 0.7, ax),
label=val
)
)
lax = ax.inset_axes((0,0,1,1))
lax.axis("off")
lax.legend(handles=handles, **legend_shapes_kwargs)
return ax, collection
[docs]def view_surf(data=None, parcellation=None, hemi="L", template="fsaverage", replace_nan=0,
template_kwargs={}, parcellation_kwargs={},
verbose=False, **kwargs):
lgr.setLevel(verbose)
# parcellation and data
if data is None and parcellation is None:
raise ValueError("Either data or parcellation must be provided")
# template
if not isinstance(template, str):
raise NotImplementedError(f"For now, template must be a string: {list(template_lib.keys())}")
else:
space = template
template = fetch_template(template, hemi=hemi, **template_kwargs, verbose=verbose)
template = images.load_gifti(template)
template_arr = template.agg_data()
# parcellation
if parcellation is not None:
if not isinstance(parcellation, str):
raise NotImplementedError(f"For now, parcellation must be a string: {list(parcellation_lib.keys())}")
else:
parc, labels = fetch_parcellation(parcellation, space=space, return_loaded=True, **parcellation_kwargs)
parc_arr = parc[0 if hemi == "L" else 1].agg_data()
labels = [l for l in labels if f"hemi-{hemi}" in l]
if data is None:
data = np.trim_zeros(np.unique(parc_arr))
# data
if not isinstance(data, (list, pd.Series, pd.DataFrame, np.ndarray)):
raise ValueError(f"Data must be a list, pd.Series, pd.DataFrame or np.array, not {type(data)}")
else:
data = np.squeeze(np.array(data))
if data.ndim > 1:
raise ValueError(f"Data must be a 1D array, not {data.ndim}D")
# check dimensions
if len(data) == len(template_arr):
data_arr = data
elif "parc_arr" in locals():
if len(data) == len(labels):
data_arr = vect_to_vol_arr(data, parc_arr, np.trim_zeros(np.unique(parc_arr)))
# replace nan
if replace_nan is not False:
data_arr = np.nan_to_num(data_arr, replace_nan)
else:
raise ValueError(f"Data length ({len(data)}) must match number of parcels ({len(labels)})")
else:
raise ValueError("Data must match either the shape of the template or the number of parcels"
f"template: {template_arr.shape}, data: {data_arr.shape}")
# plot
return view_surf_nilearn(surf_map=data_arr, surf_mesh=template_arr, **{"cmap": "RdBu_r"} | kwargs)