Source code for bycycle.plts.burst

"""Plot burst detection parameters."""

from itertools import cycle

import numpy as np
from scipy.stats import zscore
import matplotlib.pyplot as plt

from neurodsp.plts import plot_time_series, plot_bursts
from neurodsp.plts.utils import savefig

from bycycle.plts.cyclepoints import plot_cyclepoints_df
from bycycle.utils import limit_df, limit_signal, get_extrema_df
from bycycle.utils.checks import check_param_range

###################################################################################################
###################################################################################################

@savefig
def plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs, xlim=None,
                              figsize=(15, 3), plot_only_result=False, interp=True):
    """Plot the cycle-by-cycle burst detection parameters and burst detection summary.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The df must contain sample indices (i.e.
        when ``return_samples = True``).
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'threshold_kwargs'
        argument of :func:`.compute_features`.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - The signal is z-scored before plotting.
    - If ``plot_only_result`` is True, a single axis is drawn showing the signal, the detected
      bursts and the cycle extrema.
    - If ``plot_only_result`` is False, one additional axis is drawn per burst threshold, each
      showing that cycle-by-cycle parameter against its threshold. ``min_n_cycles`` is skipped
      because it is not stored cycle-by-cycle in ``df_features``.
    - In the top plot, highlighted regions mark cycles where a burst parameter is at or below
      its threshold, i.e. where that burst requirement was violated. Highlight colors are
      assigned in the order of the keys in ``threshold_kwargs``. They cycle through blue, red,
      yellow, green, cyan, magenta and orange, and match the fill colors of the parameter plots
      below. With the key order used in the example below, this gives:

        - blue: amp_fraction_threshold
        - red: amp_consistency_threshold
        - yellow: period_consistency_threshold
        - green: monotonicity_threshold

    - The figure and axes are created internally; this function does not return them.

    Examples
    --------
    Plot the burst detection summary of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Normalize signal. Build times from integer sample indices so there is exactly one time
    # point per sample (a float-step np.arange can produce one extra element).
    sig_full = zscore(sig)
    times_full = np.arange(len(sig_full)) / fs

    # Limit arrays and dataframe. Sample indices in the dataframe stay absolute
    # (reset_indices=False), so they are shifted by `offset` when indexing into `times`.
    if xlim is not None:
        sig, times = limit_signal(times_full, sig_full, start=xlim[0], stop=xlim[1])
        df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1],
                               reset_indices=False)
    else:
        sig, times = sig_full, times_full

    start = 0 if xlim is None else xlim[0]
    offset = int(fs * start)

    # Determine whether peaks or troughs are the sides of each cycle
    _, side_e = get_extrema_df(df_features)
    last_key = 'sample_last_' + side_e
    next_key = 'sample_next_' + side_e

    # Drop min_n_cycles: it isn't stored cycle by cycle in the df (nothing to plot)
    thresholds = {key: value for key, value in threshold_kwargs.items()
                  if key != 'min_n_cycles'}
    n_kwargs = len(thresholds)

    # Create figure and subplots
    if plot_only_result:
        _, axes = plt.subplots(figsize=figsize, nrows=1)
        axes = [axes]
    else:
        _, axes = plt.subplots(figsize=(figsize[0], figsize[1] * (n_kwargs + 1)),
                               nrows=n_kwargs + 1, sharex=True)

    # Determine which samples are defined as bursting (the next-side sample is included)
    is_osc = np.zeros(len(sig), dtype=bool)

    for cyc in df_features.loc[df_features['is_burst']].to_dict('records'):
        samp_start_burst = int(cyc[last_key]) - offset
        samp_end_burst = int(cyc[next_key] + 1) - offset
        is_osc[samp_start_burst:samp_end_burst] = True

    # Plot bursts with extrema points
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    plot_bursts(times, sig, is_osc, ax=axes[0], lw=2, labels=['Signal', 'Bursts'],
                xlabel='', ylabel='')

    plot_cyclepoints_df(df_features, sig_full, fs, xlim=xlim, ax=axes[0], plot_zerox=False,
                        plot_sig=False, xlabel=xlabel, ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Plot each burst param
    colors = cycle(['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    # Converted once; reused for every threshold below
    records = df_features.to_dict('records')

    for idx, (osc_key, thresh) in enumerate(thresholds.items()):

        column = osc_key.replace('_threshold', '')
        color = next(colors)

        # Highlight cycles where the burst param does not exceed its threshold. `<=` matches
        # plot_burst_detect_param (values must be strictly greater to count as bursting).
        # Cycles whose sides fall outside the plotted window are skipped.
        for cyc in records:

            last_cyc = int(cyc[last_key]) - offset
            next_cyc = int(cyc[next_key]) - offset

            if cyc[column] <= thresh and 0 <= last_cyc and next_cyc < len(times):
                axes[0].axvspan(times[last_cyc], times[next_cyc], alpha=0.5, color=color, lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df_features, sig_full, fs, column, thresh,
                                    xlim=xlim, figsize=figsize, ax=axes[idx + 1], xlabel=xlabel,
                                    ylabel=ylabel, color=color, interp=interp)
@savefig
def plot_burst_detect_param(df_features, sig, fs, burst_param, thresh,
                            xlim=None, interp=True, ax=None, **kwargs):
    """Plot a burst detection parameter and threshold.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    burst_param : str
        Column name of the parameter of interest in ``df``.
    thresh : float
        The burst parameter threshold. Parameter values greater than ``thresh`` are considered
        bursts.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    interp : bool, optional, default: True
        If True, draws one point per cycle (at its center extremum) connected by lines.
        Otherwise, draws a step from the last side to the next side of each cycle.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: the value of ``burst_param``
    - ``color``: str, default: 'r'.

      - Note: ``color`` here is the fill color, rather than line color.

    The y-axis label is ``ylabel`` followed by ``threshold=<thresh>`` on a second line. Cycles
    whose parameter value is less than or equal to ``thresh`` are shaded with ``color``.

    Examples
    --------
    Plot the monotonicity of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12),
    ...                                threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_param(df_features, sig, fs, 'monotonicity', .8)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Set default kwargs; anything left over is forwarded to plot_time_series
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', burst_param)
    color = kwargs.pop('color', 'r')

    # One time point per sample. An integer-index arange avoids the extra element that a
    # float-step np.arange can produce.
    times = np.arange(len(sig)) / fs

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Determine extrema strings
    center_e, side_e = get_extrema_df(df_features)
    last_key = 'sample_last_' + side_e
    next_key = 'sample_next_' + side_e

    if xlim is not None:

        # Limit dataframe, sig and times. limit_df's default reset_indices shifts the sample
        # indices so they index directly into the limited `times` array.
        df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])
        sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

        # Remove start / end cycles that xlim falls between. Indices are now relative to the
        # limited window, so bound them by its length, not by the absolute xlim[1]*fs.
        df_features = df_features[(df_features[last_key] >= 0) &
                                  (df_features[next_key] < len(times))]

    records = df_features.to_dict('records')

    # Plot burst param
    if interp:
        # One point per cycle, located at the cycle's center extremum
        param_times = times[df_features['sample_' + center_e]]
        param_values = df_features[burst_param]
    else:
        # Create steps, from side to side of each cycle, and set the y-value
        # to the burst parameter value for that cycle
        param_times, param_values = [], []
        for cyc in records:
            param_times.extend((times[int(cyc[last_key])], times[int(cyc[next_key])]))
            param_values.extend([cyc[burst_param]] * 2)
        param_times = np.array(param_times, dtype=float)
        param_values = np.array(param_values, dtype=float)

    plot_time_series([param_times, (times[0], times[-1])], [param_values, [thresh] * 2],
                     ax=ax, colors=['k', 'k'], ls=['-', '--'], marker=["o", None],
                     xlabel=xlabel, ylabel="{0:s}\nthreshold={1:.2f}".format(ylabel, thresh),
                     **kwargs)

    # Highlight cycles where the param does not exceed the threshold
    for cyc in records:
        if cyc[burst_param] <= thresh:
            ax.axvspan(times[int(cyc[last_key])], times[int(cyc[next_key])],
                       alpha=0.5, color=color, lw=0)