Source code for bycycle.plts.cyclepoints

"""Plot extrema and zero-crossings."""

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts import plot_time_series
from neurodsp.plts.utils import savefig

from bycycle.utils.checks import check_param_range
from bycycle.utils import limit_signal, get_extrema_df

###################################################################################################
###################################################################################################

@savefig
def plot_cyclepoints_df(df_samples, sig, fs, plot_sig=True, plot_extrema=True,
                        plot_zerox=True, xlim=None, ax=None, **kwargs):
    """Plot extrema and/or zero-crossings from a DataFrame.

    Parameters
    ----------
    df_samples : pandas.DataFrame
        Dataframe output of :func:`~.compute_cyclepoints`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    plot_extrema : bool, optional, default: True
        Whether to plot the peaks and troughs.
    plot_zerox : bool, optional, default: True
        Whether to plot the zero-crossings.
    xlim : tuple of (float, float), optional
        Start and stop times.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'

    Extrema are forwarded to :func:`plot_cyclepoints_array` under their true labels: for
    trough-centered dataframes the center samples are passed as troughs and the side samples
    as peaks, so peak/trough colors are the same for both cycle centerings.

    Examples
    --------
    Plot cyclepoints using a dataframe from :func:`~.compute_cyclepoints`:

    >>> from bycycle.features import compute_cyclepoints
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> df_samples = compute_cyclepoints(sig, fs, f_range=(8, 12))
    >>> plot_cyclepoints_df(df_samples, sig, fs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Determine which extremum the cycles are centered on; the returned strings are used
    #   as column-name suffixes below
    center_e, side_e = get_extrema_df(df_samples)

    peaks, troughs, rises, decays = [None]*4

    if plot_extrema:

        center_samples = df_samples['sample_' + center_e].values

        # Consecutive cycles share their side extrema, so merge and de-duplicate them
        side_samples = np.unique(np.append(df_samples['sample_last_' + side_e].values,
                                           df_samples['sample_next_' + side_e].values))

        # Route samples by their actual extremum type rather than by center/side position,
        #   otherwise trough-centered dataframes are plotted with swapped peak/trough colors
        if center_e == 'trough':
            peaks, troughs = side_samples, center_samples
        else:
            peaks, troughs = center_samples, side_samples

    if plot_zerox:

        rises = df_samples['sample_zerox_rise'].values
        decays = df_samples['sample_zerox_decay'].values

    plot_cyclepoints_array(sig, fs, peaks=peaks, troughs=troughs, rises=rises, decays=decays,
                           plot_sig=plot_sig, xlim=xlim, ax=ax, **kwargs)
@savefig
def plot_cyclepoints_array(sig, fs, peaks=None, troughs=None, rises=None, decays=None,
                           plot_sig=True, xlim=None, ax=None, **kwargs):
    """Plot extrema and/or zero-crossings from arrays.

    Parameters
    ----------
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    peaks : 1d array, optional
        Peak signal indices from :func:`.find_extrema`.
    troughs : 1d array, optional
        Trough signal indices from :func:`.find_extrema`.
    rises : 1d array, optional
        Zero-crossing rise indices from :func:`~.find_zerox`.
    decays : 1d array, optional
        Zero-crossing decay indices from :func:`~.find_zerox`.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    xlim : tuple of (float, float), optional
        Start and stop times.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'
    - ``colors``: list, default: ['k', 'b', 'r', 'g', 'm']

    When ``plot_sig`` is True, the first entry of ``colors`` is used for the signal and the
    remaining entries for the cyclepoints. When ``plot_sig`` is False, custom ``colors`` are
    applied to the cyclepoints as given, and the default signal color is not used.

    Examples
    --------
    Plot cyclepoints using arrays from :func:`.find_extrema` and :func:`~.find_zerox`:

    >>> from bycycle.cyclepoints import find_extrema, find_zerox
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> peaks, troughs = find_extrema(sig, fs, f_range=(8, 12), boundary=0)
    >>> rises, decays = find_zerox(sig, peaks, troughs)
    >>> plot_cyclepoints_array(sig, fs, peaks=peaks, troughs=troughs, rises=rises, decays=decays)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Build times from integer sample indices: a float-step arange can return one sample
    #   more or fewer than len(sig)
    times = np.arange(len(sig)) / fs

    # Restrict sig and times to xlim
    if xlim is not None:
        sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Set default kwargs
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', 'Voltage (uV)')
    default_colors = ['b', 'r', 'g', 'm']

    # Sample indices, relative to the full signal, of the first and last retained samples.
    #   Rounding avoids off-by-one errors when times[0] * fs lands slightly off an integer.
    n_samples = len(times)
    first_idx = int(round(times[0] * fs)) if n_samples else 0
    last_idx = first_idx + n_samples - 1

    # Extend plotting based on given arguments
    x_values = []
    y_values = []
    point_colors = []

    for color, points in zip(default_colors, (peaks, troughs, rises, decays)):

        if points is None:
            continue

        # Keep only the cyclepoints inside the window, then shift them to window-relative indices
        points = np.asarray(points)
        in_window = (points >= first_idx) & (points <= last_idx)
        cps = (points[in_window] - first_idx).astype(int)

        y_values.append(sig[cps])
        x_values.append(times[cps])
        point_colors.append(color)

    # Allow custom colors to overwrite default
    sig_color = 'k'
    custom_colors = kwargs.pop('colors', None)

    if custom_colors is not None:
        if plot_sig:
            # The first custom color belongs to the signal, the rest to the cyclepoints
            sig_color, point_colors = custom_colors[0], custom_colors[1:]
        else:
            point_colors = custom_colors

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if plot_sig:
        plot_time_series(times, sig, colors=sig_color, ax=ax)

    plot_time_series(x_values, y_values, ax=ax, xlabel=xlabel, ylabel=ylabel,
                     colors=point_colors, marker='o', ls='', **kwargs)