"""Module providing some standard plots for visualizing data collected during a psychological protocol."""
import re
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
import matplotlib.patches as mpatch
import matplotlib.pyplot as plt
import matplotlib.ticker as mticks
import numpy as np
import pandas as pd
import seaborn as sns
from biopsykit.plotting import feature_boxplot, lineplot, multi_feature_boxplot
from biopsykit.protocols._utils import _get_sample_times
from biopsykit.saliva.utils import _remove_s0
from biopsykit.utils.data_processing import get_subphase_durations
from biopsykit.utils.datatype_helper import (
MeanSeDataFrame,
MergedStudyDataDict,
SalivaFeatureDataFrame,
SalivaMeanSeDataFrame,
SalivaRawDataFrame,
is_mean_se_dataframe,
is_saliva_feature_dataframe,
is_saliva_mean_se_dataframe,
is_saliva_raw_dataframe,
)
from biopsykit.utils.exceptions import ValidationError
from fau_colors import cmaps, colors_all
from matplotlib.legend_handler import HandlerTuple
# Default style parameters of the heart rate ensemble plot. Each value is the fallback that is used when the
# keyword argument of the same name is not passed to ``hr_ensemble_plot`` (or one of its helper functions).
_hr_ensemble_plot_params = {
    # line styles of the mean curves, selected by phase position
    "linestyle": ["solid", "dashed", "dotted", "dashdot"],
    # transparency of the error band around the mean curve
    "ensemble_alpha": 0.4,
    # base color used to derive the background colors if "background_color" is None
    "background_base_color": "#e0e0e0",
    # explicit background colors (None: derive them from "background_base_color")
    "background_color": None,
    # transparency of the background spans
    "background_alpha": 0.1,
    "xlabel": r"Time [s]",
    # locator for the minor ticks of the x axis
    "xaxis_minor_tick_locator": mticks.MultipleLocator(60),
    # y axis label for absolute heart rate data
    "ylabel": "Heart Rate [bpm]",
    "legend_loc": "lower right",
    "legend_bbox_to_anchor": (0.99, 0.01),
    # format patterns; "{}" is replaced by the phase name
    "phase_text": "{}",
    "end_phase_text": "End {}",
    # style of the vertical line drawn at the end of each phase
    "end_phase_line_color": "#e0e0e0",
    "end_phase_line_style": "--",
    "end_phase_line_width": 2.0,
}
# Default style parameters of the mean heart rate plot (``hr_mean_plot`` and its helper functions).
_hr_mean_plot_params = {
    # line styles and markers; if given as lists, one entry is selected per plotted line
    "linestyle": ["solid", "dashed", "dotted", "dashdot"],
    "marker": ["o", "P", "*", "X"],
    # background settings (same values as the defaults of the ensemble plot)
    "background_base_color": "#e0e0e0",
    "background_color": None,
    "background_alpha": 0.1,
    # horizontal shift applied per line so that error bars of different lines do not overlap
    "x_offset": 0.1,
    # y axis label for absolute heart rate data
    "ylabel": r"Heart Rate [bpm]",
    # format pattern for the phase names; "{}" is replaced by the phase name
    "phase_text": "{}",
}
_saliva_feature_params: Dict[str, Dict[str, Any]] = {
"ylabel": {
"cortisol": {
"auc": r"Cortisol AUC $\left[\frac{nmol \cdot min}{l} \right]$",
"auc_g": r"Cortisol AUC $\left[\frac{nmol \cdot min}{l} \right]$",
"auc_i": r"Cortisol AUC $\left[\frac{nmol \cdot min}{l} \right]$",
"auc_i_post": r"Cortisol AUC $\left[\frac{nmol \cdot min}{l} \right]$",
"slope": r"Cortisol Change $\left[\frac{nmol}{l \cdot min} \right]$",
"max": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"argmax": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"max_inc": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"mean": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"std": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"kurt": r"Cortisol $\left[\frac{nmol}{l} \right]$",
"skew": r"Cortisol $\left[\frac{nmol}{l} \right]$",
},
"amylase": {
"auc": r"Amylase AUC $\left[\frac{U \cdot min}{l} \right]$",
"auc_g": r"Amylase AUC $\left[\frac{U \cdot min}{l} \right]$",
"auc_i": r"Amylase AUC $\left[\frac{U \cdot min}{l} \right]$",
"auc_i_post": r"Amylase AUC $\left[\frac{U \cdot min}{l} \right]$",
"slope": r"Amylase Change $\left[\frac{U}{l \cdot min} \right]$",
"max": r"Amylase $\left[\frac{U}{l} \right]$",
"max_inc": r"Amylase $\left[\frac{U}{l} \right]$",
"mean": r"Amylase $\left[\frac{U}{l} \right]$",
"std": r"Amylase $\left[\frac{U}{l} \right]$",
"kurt": r"Amylase $\left[\frac{U}{l} \right]$",
"skew": r"Amylase $\left[\frac{U}{l} \right]$",
},
},
"xticklabels": {
"auc": r"$AUC_{$}$",
"auc_g": r"$AUC_G$",
"auc_i": r"$AUC_I$",
"auc_i_post": r"$AUC_{I}^{Post}$",
"slope": r"$a_{§}$",
"max_inc": r"$\Delta c_{max}$",
"cmax": r"$c_{max}$",
"argmax": r"$argmax(c)$",
"mean": r"$\mu(c)$",
"std": r"$\sigma(c)$",
"skew": r"$skew(c)$",
"kurt": r"$kurt(c)$",
},
}
# Default style parameters of the saliva plot (``saliva_plot`` and its helper functions).
_saliva_plot_params: Dict = {
    # NOTE(review): "palette", "linestyle", "marker" and "multi_x_offset" are not referenced in the visible
    # part of this module - confirm where they are consumed
    "palette": None,
    "linestyle": ["-", "--"],
    "marker": ["o", "P"],
    # title, font size, color and transparency of the vertical span that highlights the psychological test
    "test_title": "",
    "test_fontsize": "medium",
    "test_color": "#9e9e9e",
    "test_alpha": 0.2,
    "multi_x_offset": 0.01,
    "xlabel": "Time [min]",
    # y axis label per saliva type
    "ylabel": {
        "cortisol": "Cortisol [nmol/l]",
        "amylase": "sAA [U/l]",
        "il6": "IL-6 [pg/ml]",
    },
    # legend label per saliva type
    "legend_title": {"cortisol": "Cortisol", "amylase": "sAA", "il6": "IL-6"},
}
def _get_palette(color: Optional[Union[str, Sequence[str]]] = None, num_colors: Optional[int] = 3):
if isinstance(color, list):
return color
if color is None:
color = "fau"
color_val = getattr(colors_all, color, None)
if color_val is None:
return color
return sns.light_palette(color_val, num_colors + 1, reverse=True)[:-1]
def hr_ensemble_plot(
    data: MergedStudyDataDict,
    subphases: Optional[Dict[str, Dict[str, int]]] = None,
    **kwargs,
) -> Optional[Tuple[plt.Figure, plt.Axes]]:
    r"""Draw a heart rate ensemble plot.

    This function plots time-series heart rate continuously as ensemble plot (mean ± standard error).
    If the data consist of multiple phases, data from each phase are overlaid in the same plot.
    If each phase additionally consists of subphases, the single subphases are highlighted in the plot.

    The input data is expected to be a :obj:`~biopsykit.utils.datatype_helper.MergedStudyDataDict`, i.e.,
    a dictionary with merged time-series heart rate data, of multiple subjects, split into individual phases.
    Per phase, the data of each subjects have same length and are combined into one common dataframe.

    Parameters
    ----------
    data : :obj:`~biopsykit.utils.datatype_helper.MergedStudyDataDict`
        dict with heart rate data to plot
    subphases : dict, optional
        dictionary with phases (keys) and subphases (values - dict with subphase names and subphase durations) or
        ``None`` if no subphases are present. Default: ``None``
    **kwargs : dict, optional
        optional arguments for plot configuration.

        To style general plot appearance:

        * ``ax``: pre-existing axes for the plot. Otherwise, a new figure and axes object is created and returned.
        * ``figsize``: tuple specifying figure dimensions (only used if a new figure is created)
        * ``palette``: color palette to plot data from different phases
        * ``ensemble_alpha``: transparency value for ensemble plot errorband (around mean). Default: 0.4
        * ``background_color``: list of background colors for the subphase spans. Default: ``None`` to derive
          the colors from ``background_base_color``
        * ``background_base_color``: base color of the subphase spans. Default: "#e0e0e0"
        * ``background_alpha``: transparency value for background spans (if subphases are present). Default: 0.1
        * ``linestyle``: line style or list of line styles for ensemble plots. If fewer styles than phases are
          provided the styles are repeated cyclically
        * ``phase_text``: string pattern to customize phase name shown in legend with placeholder for phase name.
          Default: "{}"

        To style axes:

        * ``is_relative``: boolean indicating whether heart rate data is relative (in % relative to baseline)
          or absolute (in bpm). Default: ``True``
        * ``xlabel``: label of x axis. Default: ":math:`Time [s]`"
        * ``xaxis_minor_tick_locator``: locator object to style x axis minor ticks. Default: 60 sec
        * ``ylabel``: label of y axis. Default: ":math:`\Delta HR [\%]`" if ``is_relative`` is ``True``,
          "Heart Rate [bpm]" otherwise
        * ``ylims``: y axis limits. Default: ``None`` to automatically infer limits

        To style the annotations at the end of each phase:

        * ``end_phase_text``: string pattern to customize text at the end of phase with placeholder for phase name.
          Default: "End {}"
        * ``end_phase_line_color``: line color of vertical lines used to indicate end of phase. Default: "#e0e0e0"
        * ``end_phase_line_style``: line style of vertical lines used to indicate end of phase. Default: "--"
        * ``end_phase_line_width``: line width of vertical lines used to indicate end of phase. Default: 2.0

        To style legend:

        * ``legend_loc``: location of legend. Default: "lower right"
        * ``legend_bbox_to_anchor``: box that is used to position the legend in conjunction with ``legend_loc``.
          Default: ``(0.99, 0.01)``

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    See Also
    --------
    :obj:`~biopsykit.utils.datatype_helper.MergedStudyDataDict`
        dictionary format
    :func:`~biopsykit.utils.data_processing.merge_study_data_dict`
        function to build ``MergedStudyDataDict``

    Examples
    --------
    >>> from biopsykit.protocols.plotting import hr_ensemble_plot
    >>> # Example with subphases
    >>> subphase_dict = {
    >>>     "Phase1": {"Baseline": 60, "Stress": 120, "Recovery": 60},
    >>>     "Phase2": {"Baseline": 60, "Stress": 120, "Recovery": 60},
    >>>     "Phase3": {"Baseline": 60, "Stress": 120, "Recovery": 60}
    >>> }
    >>> fig, ax = hr_ensemble_plot(data=data, subphases=subphase_dict)

    """
    # "ax" is removed from kwargs because the helper functions receive the axes as positional argument
    ax: plt.Axes = kwargs.pop("ax", None)
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get("figsize"))
    else:
        fig = ax.get_figure()

    palette = kwargs.get("palette")
    palette = _get_palette(palette, len(data))
    sns.set_palette(palette)

    linestyle = kwargs.get("linestyle", _hr_ensemble_plot_params.get("linestyle"))
    if isinstance(linestyle, str):
        # a single style applies to all phases (indexing the string would pick single characters)
        linestyle = [linestyle]
    xlabel = kwargs.get("xlabel", _hr_ensemble_plot_params.get("xlabel"))
    ylabel_default = _hr_ensemble_plot_params.get("ylabel")
    if kwargs.get("is_relative", True):
        ylabel_default = r"$\Delta$ HR [%]"
    ylabel = kwargs.get("ylabel", ylabel_default)
    ylims = kwargs.get("ylims", _hr_ensemble_plot_params.get("ylims"))
    xaxis_minor_tick_locator = kwargs.get(
        "xaxis_minor_tick_locator", _hr_ensemble_plot_params.get("xaxis_minor_tick_locator")
    )
    ensemble_alpha = kwargs.get("ensemble_alpha", _hr_ensemble_plot_params.get("ensemble_alpha"))
    phase_text = kwargs.get("phase_text", _hr_ensemble_plot_params.get("phase_text"))
    legend_loc = kwargs.get("legend_loc", _hr_ensemble_plot_params.get("legend_loc"))
    legend_bbox_to_anchor = kwargs.get("legend_bbox_to_anchor", _hr_ensemble_plot_params.get("legend_bbox_to_anchor"))

    for i, phase in enumerate(data):
        df_hr_phase = data[phase]
        x = df_hr_phase.index
        # mean and standard error over the columns of the phase dataframe
        hr_mean = df_hr_phase.mean(axis=1)
        hr_stderr = df_hr_phase.std(axis=1) / np.sqrt(df_hr_phase.shape[1])
        ax.plot(
            x,
            hr_mean,
            zorder=2,
            label=phase_text.format(phase),
            # cycle through the styles so that more phases than styles do not raise an IndexError
            linestyle=linestyle[i % len(linestyle)],
        )
        ax.fill_between(x, hr_mean - hr_stderr, hr_mean + hr_stderr, zorder=1, alpha=ensemble_alpha)
        _hr_ensemble_plot_end_phase_annotation(ax, df_hr_phase, phase, i, **kwargs)

    if subphases is not None:
        _hr_ensemble_plot_subphase_vspans(ax, data, subphases, **kwargs)

    ax.set_xlabel(xlabel)
    ax.xaxis.set_minor_locator(xaxis_minor_tick_locator)
    ax.tick_params(axis="x", which="both", bottom=True)

    ax.set_ylabel(ylabel)
    ax.tick_params(axis="y", which="major", left=True)

    if ylims is not None:
        ax.margins(x=0)
        ax.set_ylim(ylims)
    else:
        ax.margins(0, 0.1)

    ax.legend(loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor)
    fig.tight_layout()
    return fig, ax
def _hr_ensemble_plot_end_phase_annotation(ax: plt.Axes, data: pd.DataFrame, phase: str, i: int, **kwargs):
    """Mark the end of a phase in the heart rate ensemble plot.

    A vertical line is drawn at the x position ``len(data)`` and a text box with the (formatted) phase name
    is placed to the left of it. The text boxes of consecutive phases are stacked downwards.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        axes object
    data : :class:`~pandas.DataFrame`
        data belonging to ``phase``
    phase : str
        phase to add annotations
    i : int
        counter of phase
    **kwargs
        optional overrides of ``end_phase_text``, ``end_phase_line_color``, ``end_phase_line_style`` and
        ``end_phase_line_width``

    """
    defaults = _hr_ensemble_plot_params
    text_pattern = kwargs.get("end_phase_text", defaults.get("end_phase_text"))
    line_color = kwargs.get("end_phase_line_color", defaults.get("end_phase_line_color"))
    line_style = kwargs.get("end_phase_line_style", defaults.get("end_phase_line_style"))
    line_width = kwargs.get("end_phase_line_width", defaults.get("end_phase_line_width"))

    phase_end = len(data)

    # x in data coordinates, y in axes coordinates => the line always spans the full height
    ax.vlines(
        x=phase_end,
        ymin=0,
        ymax=1,
        transform=ax.get_xaxis_transform(),
        ls=line_style,
        lw=line_width,
        colors=line_color,
        zorder=3,
    )

    # each further phase moves the text box down by 0.075 (axes coordinates)
    text_height = 0.85 - 0.075 * i
    text_box_style = {"facecolor": "#e0e0e0", "alpha": 0.7, "boxstyle": "round"}
    ax.annotate(
        text=text_pattern.format(phase),
        xy=(phase_end, text_height),
        xytext=(-5, 0),
        xycoords=ax.get_xaxis_transform(),
        textcoords="offset points",
        ha="right",
        fontsize="small",
        bbox=text_box_style,
        zorder=5,
    )
def _hr_ensemble_plot_subphase_vspans(
    ax: plt.Axes, data: Dict[str, pd.DataFrame], subphases: Dict[str, Dict[str, int]], **kwargs
):
    """Add subphase vertical spans (vspans) to heart rate ensemble plot.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        axes object
    data : dict
        dictionary with phase names (keys) and the dataframes belonging to the phases (values)
    subphases : dict
        dictionary with phases (keys) and subphases (values - dict with subphase names and subphase durations)
    **kwargs
        optional overrides of ``background_color``, ``background_base_color`` and ``background_alpha``

    Raises
    ------
    ValueError
        if the subphase names are not the same for all phases

    """
    # validate the names with plain lists *before* anything is converted to a numpy array: phases with a
    # different number of subphases yield ragged lists, which numpy cannot turn into a regular array
    names_per_phase = [list(subphase_dict.keys()) for subphase_dict in subphases.values()]
    subphase_names = names_per_phase[0] if names_per_phase else []
    if any(names != subphase_names for names in names_per_phase):
        raise ValueError("Subphases must be the same for all phases!")

    # per subphase, use the largest (start, end) values over all phases
    subphase_times = [get_subphase_durations(df, subphases[phase]) for phase, df in data.items()]
    subphase_times = np.array(subphase_times)
    subphase_times = np.max(subphase_times, axis=0)

    bg_colors = kwargs.get("background_color", _hr_ensemble_plot_params.get("background_color"))
    if bg_colors is None:
        bg_color_base = kwargs.get("background_base_color", _hr_ensemble_plot_params.get("background_base_color"))
        bg_colors = list(sns.dark_palette(bg_color_base, n_colors=len(subphase_names), reverse=True))
    bg_alphas = kwargs.get("background_alpha", _hr_ensemble_plot_params.get("background_alpha"))
    bg_alphas = [bg_alphas] * len(subphase_names)

    for i, subphase in enumerate(subphase_names):
        start, end = subphase_times[i]
        color = bg_colors[i]
        alpha = bg_alphas[i]
        ax.axvspan(start, end, color=color, alpha=alpha, zorder=0, lw=0)
        # subphase name, horizontally centered within the span
        ax.text(
            x=start + 0.5 * (end - start),
            y=0.95,
            transform=ax.get_xaxis_transform(),
            zorder=3,
            s=subphase,
            ha="center",
            va="center",
        )

    # semi-transparent white band behind the subphase names at the top of the axes
    p = mpatch.Rectangle(
        xy=(0, 0.9),
        width=1,
        height=0.1,
        transform=ax.transAxes,
        color="white",
        alpha=0.4,
        zorder=1,
        lw=0,
    )
    ax.add_patch(p)
    ax.set_xticks([start for (start, end) in subphase_times])
def hr_mean_plot(  # pylint:disable=too-many-branches
    data: MeanSeDataFrame,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    r"""Plot course of heart rate as mean ± standard error over phases (and subphases) of a psychological protocol.

    The correct plot is automatically inferred from the provided data:

    * only ``phase`` index level: plot phases over x axis
    * ``phase`` and ``subphase`` index levels: plot subphases over x axis, highlight phases as vertical spans
    * additionally: ``condition`` level: plot data of different conditions individually
      (corresponds to ``hue`` parameter in :func:`~biopsykit.plotting.lineplot`)

    Parameters
    ----------
    data : :class:`~biopsykit.utils.datatype_helper.MeanSeDataFrame`
        Heart rate data to plot. Must be provided as ``MeanSeDataFrame`` with columns ``mean`` and ``se``
        computed over phases (and, if available, subphases)
    **kwargs
        additional parameters to be passed to the plot, such as:

        * ``ax``: pre-existing axes for the plot. Otherwise, a new figure and axes object is created and returned.
        * ``figsize``: tuple specifying figure dimensions
        * ``palette``: color palette to plot data from different conditions. If ``palette`` is a str then it is
          assumed to be the name of a ``fau_colors`` palette (``fau_colors.cmaps._fields``).
        * ``is_relative``: boolean indicating whether heart rate data is relative (in % relative to baseline)
          or absolute (in bpm). Default: ``False``
        * ``order``: list specifying the order in which the conditions are plotted.
        * ``x_offset``: offset value to move different groups along the x axis for better visualization.
          Default: 0.1
        * ``xlabel``: label of x axis. Default: "Subphases" (if subphases are present)
          or "Phases" (if only phases are present).
        * ``ylabel``: label of y axis. Default: "Heart Rate [bpm]" or, if ``is_relative`` is ``True``,
          ":math:`\Delta HR [%]`"
        * ``ylims``: list to manually specify y axis limits, float to specify y axis margin
          (see :meth:`~matplotlib.axes.Axes.margins()` for further information), or ``None`` to automatically infer
          y axis limits.
        * ``marker``: string or list of strings to specify marker style.
          If ``marker`` is a string, then marker of each line will have the same style.
          If ``marker`` is a list, then marker of each line will have a different style.
        * ``linestyle``: string or list of strings to specify line style.
          If ``linestyle`` is a string, then each line will have the same style.
          If ``linestyle`` is a list, then each line will have a different style.
        * ``legend_loc``: location of the legend (only drawn if conditions are present). Default: "upper left"

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    See Also
    --------
    :func:`~biopsykit.plotting.lineplot`
        Plot data as lineplot with mean and standard error

    """
    fig, ax = _plot_get_fig_ax(**kwargs)
    # the helper functions fetch the axes from kwargs
    kwargs.update({"ax": ax})

    has_conditions = "condition" in data.index.names
    num_conditions = 1
    if has_conditions:
        # number of distinct conditions (not the number of index levels)
        num_conditions = len(data.index.get_level_values("condition").unique())

    # get all plot parameter
    palette = kwargs.get("palette", cmaps.faculties)
    palette = _get_palette(palette, num_conditions)
    sns.set_palette(palette)

    ylabel_default = _hr_mean_plot_params.get("ylabel")
    if kwargs.get("is_relative", False):
        ylabel_default = r"$\Delta$ HR [%]"
    ylabel = kwargs.get("ylabel", ylabel_default)
    ylims = kwargs.get("ylims", None)

    phase_dict = _hr_mean_get_phases_subphases(data)
    num_phases = len(phase_dict)
    num_subphases = [len(arr) for arr in phase_dict.values()]
    x_vals = _hr_mean_get_x_vals(num_phases, num_subphases)

    # build x axis, axis limits and limits for phase spans:
    # x_lims holds the boundaries between neighboring x values (one entry more than x_vals)
    dist = np.mean(np.ediff1d(x_vals))
    x_lims = np.append(x_vals, x_vals[-1] + dist)
    x_lims = x_lims - 0.5 * np.ediff1d(x_lims, to_end=dist)

    if has_conditions:
        data_grp = dict(tuple(data.groupby("condition")))
        order = kwargs.get("order", list(data_grp.keys()))
        data_grp = {key: data_grp[key] for key in order}
        for i, (key, df) in enumerate(data_grp.items()):
            _hr_mean_plot(df, x_vals, key, index=i, **kwargs)
    else:
        _hr_mean_plot(data, x_vals, "Data", index=0, **kwargs)

    # add decorators to phases if subphases are present
    if sum(num_subphases) > 0:
        _hr_mean_plot_subphase_annotations(phase_dict, x_lims, **kwargs)

    # customize x axis
    ax.tick_params(axis="x", bottom=True)
    ax.set_xticks(x_vals)
    ax.set_xlim(np.min(x_lims), np.max(x_lims))
    # forward a user-defined x label; kwargs cannot be forwarded as a whole because it also contains "ax",
    # which is already passed positionally
    xlabel_kwargs = {"xlabel": kwargs["xlabel"]} if "xlabel" in kwargs else {}
    _hr_mean_style_x_axis(ax, phase_dict, num_subphases, **xlabel_kwargs)

    # customize y axis
    ax.tick_params(axis="y", which="major", left=True)
    ax.set_ylabel(ylabel)
    _hr_mean_plot_set_axis_lims(ylims, ax)

    # customize legend
    if has_conditions:
        _hr_mean_add_legend(**kwargs)

    fig.tight_layout()
    return fig, ax
def _hr_mean_plot_set_axis_lims(ylims: Union[Sequence[float], float], ax: plt.Axes):
    """Set axis limits and margins of the mean HR plot.

    Parameters
    ----------
    ylims : list, tuple, array, float, int, or ``None``
        * sequence: used as y axis limits
        * number: used as y axis margin
        * anything else (e.g. ``None``): a y axis margin of 0.15 is used
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    """
    if isinstance(ylims, (tuple, list, np.ndarray)):
        ax.set_ylim(ylims)
    else:
        ymargin = 0.15
        # integers (e.g. ``0`` for "no margin") are valid margins as well; bools are not meant as numbers here
        if isinstance(ylims, (int, float)) and not isinstance(ylims, bool):
            ymargin = ylims
        ax.margins(y=ymargin)

    ax.margins(x=0)
    ax.relim()
def _hr_mean_plot(data: MeanSeDataFrame, x_vals: np.array, key: str, index: int, **kwargs):
    """Draw the mean ± standard error curve of one group into the mean HR plot.

    Parameters
    ----------
    data : :class:`~biopsykit.utils.datatype_helper.MeanSeDataFrame`
        data to plot; the columns ``mean`` and ``se`` are used
    x_vals : array
        x positions of the data points
    key : str
        legend label of the curve
    index : int
        counter of the curve. It selects marker and line style (if those are lists) and determines the shift
        of the curve along the x axis (``index * x_offset``)
    **kwargs
        plot arguments; ``ax`` (axes to draw into) is required, ``x_offset``, ``marker`` and ``linestyle``
        are optional

    """
    ax: plt.Axes = kwargs.get("ax")
    x_offset = kwargs.get("x_offset", _hr_mean_plot_params.get("x_offset"))
    marker = kwargs.get("marker", _hr_mean_plot_params.get("marker"))
    linestyle = kwargs.get("linestyle", _hr_mean_plot_params.get("linestyle"))

    # cycle through the styles so that more curves than styles do not raise an IndexError
    if isinstance(marker, list):
        marker = marker[index % len(marker)]
    if isinstance(linestyle, list):
        linestyle = linestyle[index % len(linestyle)]

    is_mean_se_dataframe(data)

    if isinstance(data.columns, pd.MultiIndex):
        # if data has multiindex columns: drop all levels except the last one
        # (which is expected to contain ["mean", "se"]).
        # Work on a copy so that the dataframe of the caller keeps its column index.
        data = data.copy()
        data.columns = data.columns.droplevel(list(range(0, data.columns.nlevels - 1)))

    ax.errorbar(
        x=x_vals + index * x_offset,
        y=data["mean"],
        label=key,
        yerr=data["se"],
        capsize=3,
        marker=marker,
        linestyle=linestyle,
    )
def _hr_mean_add_legend(**kwargs):
"""Add legend to mean HR plot."""
ax: plt.Axes = kwargs.get("ax")
legend_loc = kwargs.get("legend_loc", "upper left")
# get handles
handles, labels = ax.get_legend_handles_labels()
# remove the errorbars
handles = [h[0] for h in handles]
# use them in the legend
if legend_loc == "upper left":
bbox_to_anchor = (0.01, 0.90)
elif legend_loc == "upper right":
bbox_to_anchor = (0.99, 0.90)
else:
bbox_to_anchor = None
ax.legend(
handles,
labels,
loc=legend_loc,
bbox_to_anchor=bbox_to_anchor,
numpoints=1,
)
def _hr_mean_style_x_axis(ax: plt.Axes, phase_dict: Dict[str, Sequence[str]], num_subphases: Sequence[int], **kwargs):
    """Set tick labels and label of the x axis of the mean HR plot.

    If subphases are present, the ticks are labeled with the subphase names (of all phases, in order),
    otherwise with the phase names.

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        axes object
    phase_dict : dict
        dictionary with phase names (keys) and list of subphase names (values)
    num_subphases : list
        list with number of subphases for each phase
    **kwargs
        ``xlabel``: label of x axis. Default: "Subphases" if subphases are present, "Phases" otherwise

    """
    has_subphases = sum(num_subphases) != 0
    if has_subphases:
        tick_labels = [name for names in phase_dict.values() for name in names]
        default_xlabel = "Subphases"
    else:
        tick_labels = phase_dict.keys()
        default_xlabel = "Phases"

    ax.set_xticklabels(tick_labels)
    ax.set_xlabel(kwargs.get("xlabel", default_xlabel))
def _hr_mean_plot_subphase_annotations(phase_dict: Dict[str, Sequence[str]], xlims: Sequence[float], **kwargs):
    """Add subphase annotations to mean HR plot.

    Each phase is highlighted as vertical span covering its subphases, and the (formatted) phase name is
    written at the top of the span.

    Parameters
    ----------
    phase_dict : dict
        dictionary with phase names (keys) and list of subphase names (values)
    xlims : list
        x axis boundaries; the span of a phase reaches from ``xlims[left]`` to ``xlims[right]`` where
        ``(left, right)`` are the indices computed by ``_hr_mean_get_x_spans``
    **kwargs
        ``ax`` (axes object, required) and optional overrides of ``background_color``,
        ``background_base_color``, ``background_alpha`` and ``phase_text``

    """
    ax: plt.Axes = kwargs.get("ax")
    num_phases = len(phase_dict)
    num_subphases = [len(arr) for arr in phase_dict.values()]

    # all defaults of this plot come from the parameter dict of the mean HR plot
    bg_colors = kwargs.get("background_color", _hr_mean_plot_params.get("background_color"))
    if bg_colors is None:
        bg_color_base = kwargs.get("background_base_color", _hr_mean_plot_params.get("background_base_color"))
        bg_colors = list(sns.dark_palette(bg_color_base, n_colors=num_phases, reverse=True))
    bg_alphas = kwargs.get("background_alpha", _hr_mean_plot_params.get("background_alpha"))
    bg_alphas = [bg_alphas] * num_phases
    phase_text = kwargs.get("phase_text", _hr_mean_plot_params.get("phase_text"))

    x_spans = _hr_mean_get_x_spans(num_phases, num_subphases)

    for i, phase in enumerate(phase_dict):
        left, right = x_spans[i]
        bg_color = bg_colors[i]
        bg_alpha = bg_alphas[i]
        ax.axvspan(xlims[left], xlims[right], color=bg_color, alpha=bg_alpha, zorder=0, lw=0)
        name = phase_text.format(phase)
        # phase name, horizontally centered within the span
        ax.text(
            x=xlims[left] + 0.5 * (xlims[right] - xlims[left]),
            y=0.95,
            s=name,
            transform=ax.get_xaxis_transform(),
            horizontalalignment="center",
            verticalalignment="center",
            zorder=3,
        )

    # semi-transparent white band behind the phase names at the top of the axes
    p = mpatch.Rectangle(
        xy=(0, 0.9),
        width=1,
        height=0.1,
        transform=ax.transAxes,
        color="white",
        alpha=0.4,
        zorder=1,
        lw=0,
    )
    ax.add_patch(p)
def _hr_mean_get_x_spans(num_phases: int, num_subphases: Sequence[int]):
if sum(num_subphases) == 0:
x_spans = list(zip([0, *list(range(0, num_phases))], list(range(0, num_phases))))
else:
x_spans = list(zip([0, *list(np.cumsum(num_subphases))], list(np.cumsum(num_subphases))))
return x_spans
def _hr_mean_get_x_vals(num_phases: int, num_subphases: Sequence[int]):
x_vals = np.linspace(0, 10, num_phases) if sum(num_subphases) == 0 else np.linspace(0, 10, sum(num_subphases))
return x_vals
def _hr_mean_get_phases_subphases(data: pd.DataFrame) -> Dict[str, Sequence[str]]:
if "condition" in data.index.names:
data = [value for key, value in data.groupby("condition")][0]
phases = data.index.get_level_values("phase").unique()
if "subphase" in data.index.names:
phase_dict = {phase: list(df.index.get_level_values("subphase")) for phase, df in data.groupby("phase")}
else:
phase_dict = {phase: [] for phase in phases}
return phase_dict
def saliva_plot(  # pylint:disable=too-many-branches
    data: Union[
        SalivaRawDataFrame, SalivaMeanSeDataFrame, Dict[str, SalivaRawDataFrame], Dict[str, SalivaMeanSeDataFrame]
    ],
    saliva_type: Optional[str] = None,
    sample_times: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None,
    test_times: Optional[Sequence[int]] = None,
    sample_times_absolute: Optional[bool] = False,
    remove_s0: Optional[bool] = False,
    **kwargs,
) -> Optional[Tuple[plt.Figure, plt.Axes]]:
    r"""Plot saliva data during psychological protocol as mean ± standard error.

    The function accepts raw saliva data per subject (:obj:`~biopsykit.utils.datatype_helper.SalivaRawDataFrame`)
    as well as pre-computed mean and standard error values of saliva samples
    ( :obj:`~biopsykit.utils.datatype_helper.SalivaMeanSeDataFrame`). To combine data from multiple saliva
    types (maximum: 2) into one plot a dict can be passed to ``data``.

    If a psychological test (e.g., TSST, MIST, or Stroop) was performed, the test time is highlighted as vertical span
    within the plot.

    .. note::
        If no sample times are provided (neither via ``time`` column in ``data`` nor via ``sample_times`` parameter)
        then ``samples`` will be used as x axis

    Parameters
    ----------
    data : :obj:`~biopsykit.utils.datatype_helper.SalivaRawDataFrame`, \
            :obj:`~biopsykit.utils.datatype_helper.SalivaMeanSeDataFrame`, or dict of such
        Saliva data to plot. Must either be provided as ``SalivaRawDataFrame`` with raw saliva data per subject or
        as ``SalivaMeanSeDataFrame`` with columns ``mean`` and ``se`` computed per saliva sample. To plot data from
        multiple saliva types (maximum: 2) a dict can be passed (keys: saliva types, values: saliva data).
    saliva_type : {"cortisol", "amylase", "il6"}, optional
        saliva type to be plotted. If a dict is passed and ``saliva_type`` is ``None``
        the saliva types are inferred from dict keys.
    sample_times : list or dict of lists
        sample times in minutes relative to psychological test or a dict of such if sample times are different for
        the individual saliva types. A list is used for all saliva types.
    test_times : list of int, optional
        start and end times of psychological test (in minutes) or ``None`` if no test was performed
    sample_times_absolute : bool, optional
        ``True`` if absolute sample times were provided (i.e., the duration of the psychological test was already
        added to the sample times), ``False`` if relative sample times were provided and absolute times should be
        computed based on test times specified by ``test_times``. Default: ``False``
    remove_s0 : bool, optional
        whether to remove the first saliva sample for plotting or not. Default: ``False``
    **kwargs
        additional parameters to be passed to the plot. Parameters that are not listed here are forwarded to
        :func:`~biopsykit.plotting.lineplot`.

        To style general plot appearance:

        * ``ax``: pre-existing axes for the plot. Otherwise, a new figure and axes object is created and returned.
        * ``palette``: color palette to plot data from different groups, or dict of such (keys: saliva types)
        * ``figsize``: tuple specifying figure dimensions
        * ``marker``: string or list of strings to specify marker style, or dict of such (keys: saliva types).
          If ``marker`` is a string, then the markers of each line will have the same style.
          If ``marker`` is a list, then the markers of each line will have a different style.
        * ``linestyle``: string or list of strings to specify line style, or dict of such (keys: saliva types).
          If ``linestyle`` is a string, then each line will have the same style.
          If ``linestyle`` is a list, then each line will have a different style.

        To style axes:

        * ``x_offset``: offset value to move different groups along the x axis for better visualization.
        * ``xlabel``: label of x axis. Default: "Time [min]" (if sample times are available) or "Sample"
        * ``ylabel``: label of y axis, or dict of such (keys: saliva types). Default: depends on the saliva type,
          e.g. "Cortisol [nmol/l]"
        * ``xticks``: list of x tick positions
        * ``xaxis_tick_locator``: locator object for the major ticks of the x axis
        * ``ylims``: list to manually specify y axis limits, float to specify y axis margin
          (see :meth:`~matplotlib.axes.Axes.margins()` for further information), or ``None`` to automatically infer
          y axis limits.

        To style the vertical span highlighting the psychological test in the plot:

        * ``test_title``: title of test. Default: "" (no title)
        * ``test_fontsize``: fontsize of the test title. Default: "medium"
        * ``test_color``: color of vspan. Default: #9e9e9e
        * ``test_alpha``: transparency value of vspan: Default: 0.2

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    Raises
    ------
    ValueError
        if ``saliva_type`` is ``None`` and ``data`` is not a dict

    See Also
    --------
    :func:`~biopsykit.plotting.lineplot`
        Plot data as lineplot with mean and standard error

    """
    fig, ax = _plot_get_fig_ax(**kwargs)
    kwargs.update({"ax": ax})

    if saliva_type is None and not isinstance(data, dict):
        raise ValueError("If 'saliva_type' is None, you must pass a dict!")

    if isinstance(data, pd.DataFrame):
        # a single dataframe was passed => wrap data and sample times into dicts with the saliva type as key
        data = {saliva_type: data}
        sample_times = {saliva_type: sample_times}

    # these parameters may be dicts with one entry per saliva type => they are resolved per saliva type below
    linestyle = kwargs.pop("linestyle", None)
    marker = kwargs.pop("marker", "o")
    palette = kwargs.pop("palette", None)
    if isinstance(palette, str) and getattr(colors_all, palette, None):
        palette = _get_palette(palette, len(data))

    for i, key in enumerate(data):
        df = data[key]
        if remove_s0:
            df = _remove_s0(df)

        # a plain sequence of sample times applies to all saliva types
        st = sample_times[key] if isinstance(sample_times, dict) else sample_times
        if remove_s0 and st is not None:
            # the first sample was removed from the data => also remove its sample time
            st = st[1:]

        kwargs_copy = _saliva_plot_extract_style_params(key, linestyle, marker, palette, **kwargs)
        _saliva_plot(
            data=df,
            saliva_type=key,
            counter=i,
            sample_times=st,
            test_times=test_times,
            sample_times_absolute=sample_times_absolute,
            **kwargs_copy,
        )

    test_times = test_times or [0, 0]
    test_title = kwargs.get("test_title", _saliva_plot_params.get("test_title"))
    test_color = kwargs.get("test_color", _saliva_plot_params.get("test_color"))
    test_alpha = kwargs.get("test_alpha", _saliva_plot_params.get("test_alpha"))
    test_fontsize = kwargs.get("test_fontsize", _saliva_plot_params.get("test_fontsize"))

    # [0, 0] means "no test". The single times are checked instead of their sum because the times of a test
    # window that is symmetric around zero (e.g., [-5, 5]) sum up to zero as well.
    if any(time != 0 for time in test_times):
        ax.axvspan(*test_times, color=test_color, alpha=test_alpha, zorder=1, lw=0)
        ax.text(
            x=test_times[0] + 0.5 * (test_times[1] - test_times[0]),
            y=0.95,
            transform=ax.get_xaxis_transform(),
            s=test_title,
            fontsize=test_fontsize,
            horizontalalignment="center",
            verticalalignment="top",
        )

    if len(data) > 1:
        # the axes object is handed over via kwargs["ax"]
        saliva_plot_combine_legend(fig, saliva_types=list(data.keys()), **kwargs)
    else:
        fig.tight_layout()

    return fig, ax
def _saliva_plot_extract_style_params(
key: str,
linestyle: Union[Dict[str, str], str],
marker: Union[Dict[str, str], str],
palette: Union[Dict[str, str], str],
**kwargs,
):
ls = _saliva_plot_get_plot_param(linestyle, key)
if linestyle is not None:
kwargs.setdefault("linestyle", ls)
m = _saliva_plot_get_plot_param(marker, key)
if marker is not None:
kwargs.setdefault("marker", m)
cmap = _saliva_plot_get_plot_param(palette, key)
if palette is not None:
kwargs.setdefault("palette", cmap)
return kwargs
def _saliva_plot_sanitize_dicts(
data: Union[Dict[str, pd.DataFrame], pd.DataFrame], ylabel: Union[Dict[str, str], str], saliva_type: str
):
if isinstance(ylabel, dict):
ylabel = ylabel[saliva_type]
if isinstance(data, dict):
# multiple saliva data were passed in a dict => get the selected saliva type
data = data[saliva_type]
return data, ylabel
def _saliva_plot(
    data: Union[SalivaRawDataFrame, SalivaMeanSeDataFrame],
    saliva_type: str,
    counter: int,
    sample_times: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None,
    test_times: Optional[Sequence[int]] = None,
    sample_times_absolute: Optional[bool] = False,
    **kwargs,
):
    """Plot the data of one saliva type into the saliva plot.

    Parameters
    ----------
    data : :obj:`~biopsykit.utils.datatype_helper.SalivaRawDataFrame` or \
            :obj:`~biopsykit.utils.datatype_helper.SalivaMeanSeDataFrame`
        saliva data to plot
    saliva_type : str
        saliva type to plot; used as y variable of the lineplot
    counter : int
        counter of the saliva type. Data with ``counter != 0`` (or data drawn into axes that already contain
        lines) are plotted on a twin y axis, without legend, and with the "tech" palette as default
    sample_times : list, optional
        sample times or ``None`` to use the ``time`` information of ``data`` (if available)
    test_times : list, optional
        start and end time of the psychological test. Default: ``None`` (treated as ``[0, 0]``)
    sample_times_absolute : bool, optional
        passed on to ``_get_sample_times``. Default: ``False``
    **kwargs
        plot arguments; ``ax`` (axes object) is required. ``xlabel``, ``ylabel``, ``xticks`` and
        ``xaxis_tick_locator`` are evaluated here, all arguments are forwarded to ``lineplot``

    Raises
    ------
    :exc:`~biopsykit.utils.exceptions.ValidationError`
        if ``data`` is neither a ``SalivaRawDataFrame`` nor a ``SalivaMeanSeDataFrame``

    """
    ax: plt.Axes = kwargs.get("ax")
    test_times = test_times or [0, 0]
    xlabel = kwargs.get("xlabel", _saliva_plot_params.get("xlabel"))
    ylabel = kwargs.get("ylabel", _saliva_plot_params.get("ylabel"))
    xticks = kwargs.get("xticks")
    xaxis_tick_locator = kwargs.get("xaxis_tick_locator")

    data, ylabel = _saliva_plot_sanitize_dicts(data, ylabel, saliva_type)
    _assert_saliva_data_input(data, saliva_type)
    # work on a copy because a "time" column may be added below
    data = data.copy()

    if sample_times is None and "time" not in data.reset_index().columns:
        # no time information available at all => plot over the samples
        x = "sample"
        xlabel = "Sample"
    else:
        sample_times = _get_sample_times(data, sample_times, test_times, sample_times_absolute)
        if "time" in data.index.names:
            data.index = data.index.droplevel("time")
        # repeat the sample times so that every row gets a time value
        # NOTE(review): assumes ``_get_sample_times`` returns a list (``*`` repeats a list) - confirm
        data["time"] = sample_times * int(len(data) / len(sample_times))
        x = "time"

    kwargs.setdefault("hue", "condition" if "condition" in data.index.names else None)
    kwargs.setdefault("style", kwargs.get("hue"))
    kwargs.setdefault("marker", "o")

    # number of groups determines the number of colors of the default palette
    groups = kwargs.get("hue", None)
    if groups is not None and groups in data.index.names:
        num_groups = len(data.index.get_level_values(groups).unique())
    else:
        num_groups = 2

    if counter == 0 and len(ax.lines) == 0:
        kwargs.setdefault("palette", _get_palette("fau", num_groups))
    else:
        kwargs.setdefault("palette", _get_palette("tech", num_groups))
        # there was already something drawn into the axis => we are using the same axis to add another feature
        ax_twin = ax.twinx()
        kwargs.update({"ax": ax_twin, "show_legend": False})

    kwargs.update({"xlabel": xlabel, "ylabel": ylabel})
    lineplot(data=data, x=x, y=saliva_type, **kwargs)
    # the x axis is styled on the original axes (not on the twin axes)
    _saliva_plot_style_xaxis(xticks, xaxis_tick_locator, ax)
def _assert_saliva_data_input(data: pd.DataFrame, saliva_type: str):
    """Check that ``data`` is a ``SalivaRawDataFrame`` or a ``SalivaMeanSeDataFrame``.

    Raises
    ------
    :exc:`~biopsykit.utils.exceptions.ValidationError`
        if ``data`` passes neither of the two checks

    """
    if is_saliva_raw_dataframe(data, saliva_type, raise_exception=False):
        return
    if is_saliva_mean_se_dataframe(data, raise_exception=False):
        return
    raise ValidationError("'data' is expected to be either a SalivaRawDataFrame or a SalivaMeanSeDataFrame!")
def _saliva_plot_get_plot_param(param: Union[Dict[str, str], str], key: str):
p = param[key] if isinstance(param, dict) else param
return p
def _saliva_plot_style_xaxis(xticks: Sequence[str], xaxis_tick_locator: mticks.Locator, ax: plt.Axes):
    """Apply a custom tick locator and custom tick positions to the x axis of ``ax``.

    The axis is only touched if *both* ``xticks`` and ``xaxis_tick_locator`` are provided;
    if either of them is ``None`` the function returns without changing anything.

    Parameters
    ----------
    xticks : list
        tick positions passed to ``ax.xaxis.set_ticks``
    xaxis_tick_locator : :class:`~matplotlib.ticker.Locator`
        locator passed to ``ax.xaxis.set_major_locator``
    ax : :class:`~matplotlib.axes.Axes`
        axes object whose x axis is styled

    """
    if xticks is None or xaxis_tick_locator is None:
        return
    x_axis = ax.xaxis
    # the locator is installed first, the explicit tick positions afterwards
    x_axis.set_major_locator(xaxis_tick_locator)
    x_axis.set_ticks(xticks)
def saliva_plot_combine_legend(fig: plt.Figure, ax: plt.Axes, saliva_types: Sequence[str], **kwargs):
    """Combine the legends of several ``saliva_plot`` calls into one joint figure legend.

    If data from multiple saliva types are drawn into one plot (e.g., by calling
    :func:`~biopsykit.protocols.plotting.saliva_plot` on the same plot twice), two separate legends are created.
    This function collects the legend entries of all axes of ``fig``, draws them as one figure-level legend,
    and removes the legend of ``ax``.

    Parameters
    ----------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object
    saliva_types : list
        list of saliva types in plot. The entries are used to look up the legend titles and are matched
        with the axes of ``fig`` by position.
    **kwargs
        additional arguments to customize plot, such as:

        * ``legend_loc``: Location of legend. Default: ``upper center``
        * ``legend_size``: Legend size. Default: ``small``
        * ``rect``: Rectangle in normalized figure coordinates into which the whole subplots area
          (including labels) will fit. It is passed to ``fig.tight_layout`` and can be used to conveniently
          place the legend outside of the plot. Default: ``(0, 0, 1.0, 0.95)``

    """
    legend_loc = kwargs.get("legend_loc", "upper center")
    legend_size = kwargs.get("legend_size", "small")
    rect = kwargs.get("rect", (0, 0, 1.0, 0.95))

    # one (handles, labels) pair per axes of the figure
    entries_per_axes = [axes.get_legend_handles_labels() for axes in fig.get_axes()]
    labels_per_axes = [axes_labels for _, axes_labels in entries_per_axes]
    # every legend handle is reduced to its first element
    # NOTE(review): presumably the handles are container artists (e.g. from errorbar plots) - confirm
    handles = [entry[0] for axes_handles, _ in entries_per_axes for entry in axes_handles]
    legend_titles = _saliva_plot_params.get("legend_title")

    # arguments shared by both legend variants
    shared_legend_args = {"loc": legend_loc, "ncol": len(handles), "prop": {"size": legend_size}}

    if all(len(axes_labels) == 1 for axes_labels in labels_per_axes):
        # only one group per axes: label each entry with the title of its saliva type
        labels = [legend_titles[saliva_type] for saliva_type in saliva_types]
        fig.legend(handles, labels, **shared_legend_args)
    else:
        # several groups per axes: one legend entry per saliva type that lists all of its group labels
        labels = [
            "{}: {}".format(legend_titles[saliva_type], " - ".join(axes_labels))
            for saliva_type, axes_labels in zip(saliva_types, labels_per_axes)
        ]
        # consecutive handles are paired into tuples so that ``HandlerTuple`` draws them in one entry
        paired_handles = list(zip(handles[::2], handles[1::2]))
        fig.legend(
            paired_handles,
            labels,
            numpoints=1,
            handler_map={tuple: HandlerTuple(ndivide=None)},
            **shared_legend_args,
        )

    ax.legend().remove()
    fig.tight_layout(pad=1.0, rect=rect)
def saliva_feature_boxplot(
    data: SalivaFeatureDataFrame,
    x: str,
    saliva_type: str,
    hue: Optional[str] = None,
    feature: Optional[str] = None,
    stats_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """Draw a boxplot with significance brackets, specifically designed for saliva features.

    This is a wrapper of :func:`~biopsykit.plotting.feature_boxplot` that can be used to plot saliva features and
    allows to easily add significance brackets that indicate statistical significance.

    .. note::
        The input data is assumed to be in long-format.

    Parameters
    ----------
    data : :obj:`~biopsykit.utils.datatype_helper.SalivaFeatureDataFrame`
        data to plot
    x : str
        column of x axis in ``data``
    saliva_type : str
        type of saliva data to plot. Also used as the ``y`` column of the boxplot.
    hue : str, optional
        column name of grouping variable. It is forwarded to :func:`~biopsykit.plotting.feature_boxplot`.
        Default: ``None``
    feature : str, optional
        name of feature (or list of feature names) to plot or ``None``. If given, it is used to derive a default
        ``ylabel`` (only if all features share the same y label) and, if ``hue`` is set, default ``xticklabels``.
    stats_kwargs : dict, optional
        dictionary with arguments for significance brackets
    **kwargs
        additional arguments that are passed to :func:`~biopsykit.plotting.feature_boxplot` and :func:`~seaborn.boxplot`

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    See Also
    --------
    :func:`~biopsykit.plotting.feature_boxplot`
        plot features as boxplot
    :class:`~biopsykit.stats.stats.StatsPipeline`
        class to create statistical analysis pipelines and get parameter for plotting significance brackets

    """
    # called with default arguments; presumably raises a ValidationError for invalid data - confirm
    is_saliva_feature_dataframe(data, saliva_type)

    if feature is not None:
        if isinstance(feature, str):
            feature = [feature]

        ylabel_lookup = _saliva_feature_boxplot_get_ylabels(saliva_type, feature)
        ylabels = [ylabel_lookup[name] for name in feature]
        if len(set(ylabels)) == 1:
            # a common y label only makes sense if all plotted features share it
            kwargs.setdefault("ylabel", ylabels[0])

        if hue is not None:
            xticklabel_lookup = _saliva_feature_boxplot_get_xticklabels({name: name for name in feature})
            # each feature maps to a one-element list of labels => unpack it
            kwargs.setdefault("xticklabels", [labels[0] for labels in xticklabel_lookup.values()])

    # ``hue`` must be forwarded explicitly: it is a named parameter and therefore never part of ``kwargs``
    return feature_boxplot(data=data, x=x, y=saliva_type, hue=hue, stats_kwargs=stats_kwargs, **kwargs)
def saliva_multi_feature_boxplot(
    data: SalivaFeatureDataFrame,
    saliva_type: str,
    features: Union[Sequence[str], Dict[str, Union[str, Sequence[str]]]],
    hue: Optional[str] = None,
    stats_kwargs: Optional[Dict] = None,
    **kwargs,
) -> Tuple[plt.Figure, Iterable[plt.Axes]]:
    """Draw multiple features as boxplots with significance brackets, specifically designed for saliva features.

    This is a wrapper of :func:`~biopsykit.plotting.multi_feature_boxplot` that can be used to plot saliva features
    and allows to easily add significance brackets that indicate statistical significance.

    .. note::
        The input data is assumed to be in long-format.

    Parameters
    ----------
    data : :obj:`~biopsykit.utils.datatype_helper.SalivaFeatureDataFrame`
        data to plot
    saliva_type : str
        type of saliva data to plot. Also used as the ``y`` column of the boxplots.
    features : list of str or dict of str
        features to plot. If ``features`` is a list, each entry must correspond to one feature category in
        ``saliva_feature`` (which is passed as ``group`` to :func:`~biopsykit.plotting.multi_feature_boxplot`).
        A separate subplot will be created for each feature.
        If similar features (i.e., different `slope` or `AUC` parameters) should be combined into one subplot,
        ``features`` can be provided as dictionary.
        Then, the dict keys specify the feature category (a separate subplot will be created for each category)
        and the dict values specify the feature (or list of features) that are combined into the subplots.
        A single string is treated like a list with one entry.
    hue : str, optional
        column name of grouping variable. Default: ``None``
    stats_kwargs : dict, optional
        nested dictionary with arguments for significance brackets.
        See :func:`~biopsykit.plotting.feature_boxplot` for further information
    **kwargs
        additional arguments that are passed to :func:`~biopsykit.plotting.multi_feature_boxplot`, such as:

        * ``x``: column of x axis. Default: ``saliva_feature``
        * ``xticklabels``: x tick labels per subplot. Default: labels derived from ``features``
        * ``ylabels``: y labels per subplot. Default: labels derived from ``saliva_type`` and ``features``

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    axs : list of :class:`matplotlib.axes.Axes`
        list of subplot axes objects

    See Also
    --------
    :func:`~biopsykit.plotting.multi_feature_boxplot`
        plot multiple features as boxplots
    :class:`~biopsykit.stats.stats.StatsPipeline`
        class to create statistical analysis pipelines and get parameter for plotting significance brackets

    """
    x_column = kwargs.pop("x", "saliva_feature")

    if isinstance(features, str):
        # a single feature name is handled like a one-element list
        features = [features]
    if isinstance(features, list):
        # one subplot per list entry: every entry becomes its own category
        feature_dict = {}
        for entry in features:
            feature_dict[entry] = [entry] if isinstance(entry, str) else entry
        features = feature_dict

    # the defaults are always computed, but only applied if the caller did not provide own labels
    kwargs.setdefault("xticklabels", _saliva_feature_boxplot_get_xticklabels(features))
    kwargs.setdefault("ylabels", _saliva_feature_boxplot_get_ylabels(saliva_type, features))

    return multi_feature_boxplot(
        data,
        x=x_column,
        y=saliva_type,
        group="saliva_feature",
        features=features,
        hue=hue,
        stats_kwargs=stats_kwargs,
        **kwargs,
    )
def _saliva_feature_boxplot_get_xticklabels(features: Dict[str, str]) -> Dict[str, Sequence[str]]:
    """Look up the x tick labels for the saliva features of each feature category.

    Labels are taken from ``_saliva_feature_params["xticklabels"]``. For feature names containing ``"slope"`` the
    ``"slope"`` label template is used and its ``§`` placeholder is replaced by the text that follows ``slope``
    in the feature name.

    Parameters
    ----------
    features : dict
        dictionary with feature categories as keys and a feature name (or a list of feature names) as values

    Returns
    -------
    dict
        dictionary with the same keys as ``features`` and the list of x tick labels of each category as values

    Raises
    ------
    KeyError
        if no label is defined for a (non-slope) feature name
    IndexError
        if a feature name contains ``"slope"`` but no word characters follow it

    """

    def _label_for(column: str) -> str:
        """Return the x tick label of one single feature."""
        if "slope" in column:
            template = _saliva_feature_params["xticklabels"]["slope"]
            return template.replace("§", re.findall(r"slope(\w+)", column)[0])
        return _saliva_feature_params["xticklabels"][column]

    return {
        category: [_label_for(column) for column in ([columns] if isinstance(columns, str) else columns)]
        for category, columns in features.items()
    }
def _saliva_feature_boxplot_get_ylabels(saliva_type: str, features: Union[str, Sequence[str]]) -> Dict[str, str]:
    """Look up the y labels for the saliva features of one saliva type.

    The labels are taken from ``_saliva_feature_params["ylabel"][saliva_type]``. Every entry of ``features`` that
    contains ``"slope"`` additionally gets the generic ``"slope"`` label assigned under its own name.

    Parameters
    ----------
    saliva_type : str
        saliva type whose y labels are looked up
    features : str or list of str
        feature name(s); iterating over a dictionary (i.e., over its keys) works as well

    Returns
    -------
    dict
        dictionary with feature names as keys and y labels as values. It is a new dictionary, so changing it
        does not affect the module-level label definitions.

    Raises
    ------
    KeyError
        if no y labels are defined for ``saliva_type``, or if a slope feature is requested but no ``"slope"``
        label is defined

    """
    # work on a copy: the slope aliases added below must not leak into the shared module-level parameters
    ylabels = dict(_saliva_feature_params["ylabel"][saliva_type])
    if isinstance(features, str):
        features = [features]
    for feature in features:
        if "slope" in feature:
            ylabels[feature] = ylabels["slope"]
    return ylabels
def _plot_get_fig_ax(**kwargs):
    """Return the figure and axes to draw into.

    Parameters
    ----------
    **kwargs
        keyword arguments of the calling plot function. The following entries are used:

        * ``ax``: existing axes object. If it is provided (and not ``None``), it is used together with the figure
          it reports via ``get_figure()``.
        * ``figsize``: figure size passed to ``plt.subplots`` if a new figure has to be created

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        figure object
    ax : :class:`~matplotlib.axes.Axes`
        axes object

    """
    axes = kwargs.get("ax")
    if axes is not None:
        # reuse the axes handed in by the caller
        return axes.get_figure(), axes
    figure, axes = plt.subplots(figsize=kwargs.get("figsize"))
    return figure, axes