"""Bycycle class objects."""
import warnings
import numpy as np
from neurodsp.plts.utils import savefig
import pandas as pd
from bycycle.features import compute_features
from bycycle.group import compute_features_2d, compute_features_3d
from bycycle.plts import plot_burst_detect_summary
from bycycle.burst.utils import recompute_edges as rc_edges
###################################################################################################
###################################################################################################
class BycycleBase:
    """Shared base sub-class.

    Stores the compute-features settings shared by :class:`Bycycle` and
    :class:`BycycleGroup`, normalizes the burst detection thresholds, and initializes the
    signal / result attributes that the subclasses' ``fit`` methods populate.
    """

    def __init__(self, center_extrema='peak', burst_method='cycles', burst_kwargs=None,
                 thresholds=None, find_extrema_kwargs=None, return_samples=True):

        # Compute features settings
        self.center_extrema = center_extrema
        self.burst_method = burst_method
        self.burst_kwargs = {} if burst_kwargs is None else burst_kwargs

        # Thresholds: warn whenever no explicit dict of thresholds was given
        if not isinstance(thresholds, dict):
            warnings.warn("""
No burst detection thresholds are provided. This is not recommended. Please
inspect your data and choose appropriate parameters for 'thresholds'.
Default burst detection parameters are likely not well suited for your
desired application.
""")

        if thresholds is None and burst_method == 'cycles':
            self.thresholds = {
                'amp_fraction_threshold': 0.,
                'amp_consistency_threshold': .5,
                'period_consistency_threshold': .5,
                'monotonicity_threshold': .8,
                'min_n_cycles': 3
            }
        elif thresholds is None and burst_method == 'amp':
            self.thresholds = {
                'burst_fraction_threshold': 1,
                'min_n_cycles': 3
            }
        elif isinstance(thresholds, dict):
            # Copy so the key normalization below never rewrites the caller's dict
            # (which may also be shared between several objects)
            self.thresholds = dict(thresholds)
        else:
            self.thresholds = thresholds

        # Allow shorthand keys (e.g. 'monotonicity' instead of 'monotonicity_threshold');
        # 'min_n_cycles' is the only key that is kept without the '_threshold' suffix
        if isinstance(self.thresholds, dict):
            for key in list(self.thresholds):
                if not key.endswith('_threshold') and key != 'min_n_cycles':
                    self.thresholds[key + '_threshold'] = self.thresholds.pop(key)

        if find_extrema_kwargs is None:
            self.find_extrema_kwargs = {'filter_kwargs': {'n_cycles': 3}}
        else:
            self.find_extrema_kwargs = find_extrema_kwargs

        self.return_samples = return_samples

        # Compute features args (populated by fit / load in subclasses)
        self.sig = None
        self.fs = None
        self.f_range = None

        # Results
        self.df_features = None

    def reduce_thresholds(self, reduction=None):
        """Adjust thresholds by a given amount.

        Parameters
        ----------
        reduction : float, optional, default: None
            Amount subtracted from every threshold whose key ends with 'threshold'.
            None is treated as 0 (no reduction). Other entries, such as
            'min_n_cycles', are copied unchanged.

        Returns
        -------
        reduced_thresholds : dict
            Copy of thresholds with reduction applied. ``self.thresholds`` is not modified.
        """
        reduction = 0 if reduction is None else reduction

        return {key: value - reduction if key.endswith('threshold') else value
                for key, value in self.thresholds.items()}
class Bycycle(BycycleBase):
    """Compute bycycle features from a signal.

    Attributes
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    sig : 1d array
        Voltage time series.
    fs : float
        Sampling rate, in Hz.
    f_range : tuple of (float, float)
        Frequency range for narrowband signal of interest (Hz).
    center_extrema : {'peak', 'trough'}
        The center extrema in the cycle.

        - 'peak' : cycles are defined trough-to-trough
        - 'trough' : cycles are defined peak-to-peak

    burst_method : {'cycles', 'amp'}
        Method for detecting bursts.

        - 'cycles': detect bursts based on the consistency of consecutive periods & amplitudes
        - 'amp': detect bursts using an amplitude threshold

    burst_kwargs : dict, optional, default: None
        Additional keyword arguments defined in :func:`~.compute_burst_fraction` for dual
        amplitude threshold burst detection (i.e. when burst_method='amp').
    thresholds : dict, optional, default: None
        Feature thresholds for cycles to be considered bursts, matching keyword arguments for:

        - :func:`~.detect_bursts_cycles` for consistency burst detection
          (i.e. when burst_method='cycles')
        - :func:`~.detect_bursts_amp` for amplitude threshold burst detection
          (i.e. when burst_method='amp').

    find_extrema_kwargs : dict, optional, default: None
        Keyword arguments for function to find peaks and troughs (:func:`~.find_extrema`)
        to change filter parameters or boundary. By default, the filter length is set to three
        cycles of the low cutoff frequency (``f_range[0]``).
    return_samples : bool, optional, default: True
        Returns samples indices of cyclepoints used for determining features if True.

    Notes
    -----
    Columns of ``df_features`` are also readable as attributes, e.g. ``bm.period``.
    """

    def __init__(self, center_extrema='peak', burst_method='cycles', burst_kwargs=None,
                 thresholds=None, find_extrema_kwargs=None, return_samples=True):
        """Initialize object settings."""

        super().__init__(center_extrema, burst_method, burst_kwargs, thresholds,
                         find_extrema_kwargs, return_samples)

    def __getattr__(self, key):
        """Access df_features columns as class attributes.

        Only called when normal attribute lookup fails.

        Parameters
        ----------
        key : str
            Column name.

        Returns
        -------
        1d-array
            Column values.

        Raises
        ------
        AttributeError
            If no features are loaded or ``key`` is not a column.
        """
        # Read df_features straight from the instance dict. Going through self.df_features
        # would re-enter __getattr__ (and recurse until RecursionError) on instances whose
        # __init__ did not run, e.g. during copy/unpickling. Protocol probes such as
        # __getstate__ / __setstate__ then fail with a plain AttributeError, as they should.
        df_features = self.__dict__.get('df_features')

        if df_features is not None and key in df_features.keys():
            return df_features[key].values

        raise AttributeError(f'\'{self.__class__.__name__}\' object has no attribute \'{key}\'')

    def fit(self, sig, fs, f_range):
        """Run the bycycle algorithm on a signal.

        Parameters
        ----------
        sig : 1d array
            Time series.
        fs : float
            Sampling rate, in Hz.
        f_range : tuple of (float, float)
            Frequency range for narrowband signal of interest (Hz).

        Raises
        ------
        ValueError
            If ``sig`` is not 1-dimensional.
        """
        if sig.ndim != 1:
            raise ValueError('Signal must be 1-dimensional.')

        # Add settings as attributes
        self.sig = sig
        self.fs = fs
        self.f_range = f_range

        self.df_features = compute_features(
            self.sig, self.fs, self.f_range, self.center_extrema,
            self.burst_method, self.burst_kwargs, self.thresholds,
            self.find_extrema_kwargs, self.return_samples
        )

    def recompute_edges(self, reduction=None):
        """Recomputes features for cycles on the edge of bursts.

        Parameters
        ----------
        reduction : float, optional, default: None
            Reduces all float thresholds by given amount.
        """
        reduced_thresholds = self.reduce_thresholds(reduction)

        self.df_features = rc_edges(self.df_features, reduced_thresholds)

    @savefig
    def plot(self, xlim=None, figsize=(15, 3), plot_only_results=False, interp=True):
        """Plot burst detection results.

        Parameters
        ----------
        xlim : tuple of (float, float), optional, default: None
            Start and stop times for plot.
        figsize : tuple of (float, float), optional, default: (15, 3)
            Size of each plot.
        plot_only_results : bool, optional, default: False
            Plot only the signal and bursts, excluding burst parameter plots.
        interp : bool, optional, default: True
            If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

        Raises
        ------
        ValueError
            If features, signal or sampling rate have not been set by ``fit`` or ``load``.
        """
        if self.df_features is None or self.sig is None or self.fs is None:
            raise ValueError('The fit method must be successfully called prior to plotting.')

        plot_burst_detect_summary(self.df_features, self.sig, self.fs, self.thresholds,
                                  xlim, figsize, plot_only_results, interp)

    def load(self, df_features, sig, fs, f_range):
        """Load external results.

        Parameters
        ----------
        df_features : pandas.DataFrame
            Previously computed features for ``sig``.
        sig : 1d array
            Time series the features were computed from.
        fs : float
            Sampling rate, in Hz.
        f_range : tuple of (float, float)
            Frequency range for narrowband signal of interest (Hz).
        """
        self.sig = sig
        self.fs = fs
        self.f_range = f_range
        self.df_features = df_features
class BycycleGroup(BycycleBase):
    """Compute bycycle features for a 2d or 3d signal.

    Attributes
    ----------
    models : list of Bycycle
        Fit Bycycle objects. A flat list for 2d input; a list of lists
        (``models[dim0][dim1]``) for 3d input.
    sigs : 2d or 3d array
        Voltage time series. None until ``fit`` is called.
    df_features : list
        Features returned by the group compute function, indexed as
        ``df_features[dim0]`` (2d) or ``df_features[dim0][dim1]`` (3d).
    fs : float
        Sampling rate, in Hz.
    f_range : tuple of (float, float)
        Frequency range for narrowband signal of interest (Hz).
    center_extrema : {'peak', 'trough'}
        The center extrema in the cycle.

        - 'peak' : cycles are defined trough-to-trough
        - 'trough' : cycles are defined peak-to-peak

    burst_method : {'cycles', 'amp'}
        Method for detecting bursts.

        - 'cycles': detect bursts based on the consistency of consecutive periods & amplitudes
        - 'amp': detect bursts using an amplitude threshold

    burst_kwargs : dict, optional, default: None
        Additional keyword arguments defined in :func:`~.compute_burst_fraction` for dual
        amplitude threshold burst detection (i.e. when burst_method='amp').
    thresholds : dict, optional, default: None
        Feature thresholds for cycles to be considered bursts, matching keyword arguments for:

        - :func:`~.detect_bursts_cycles` for consistency burst detection
          (i.e. when burst_method='cycles')
        - :func:`~.detect_bursts_amp` for amplitude threshold burst detection
          (i.e. when burst_method='amp').

    find_extrema_kwargs : dict, optional, default: None
        Keyword arguments for function to find peaks and troughs (:func:`~.find_extrema`)
        to change filter parameters or boundary. By default, the filter length is set to three
        cycles of the low cutoff frequency (``f_range[0]``).
    axis : {0, 1, (0, 1), None}
        For 2d arrays:

        - ``axis=0`` : Iterates over each row/signal in an array independently (i.e. for each
          channel in (n_channels, n_timepoints)).
        - ``axis=None`` : Flattens rows/signals prior to computing features (i.e. across flatten
          epochs in (n_epochs, n_timepoints)).

        For 3d arrays:

        - ``axis=0`` : Iterates over 2D slices along the zeroth dimension, (i.e. for each
          channel in (n_channels, n_epochs, n_timepoints)).
        - ``axis=1`` : Iterates over 2D slices along the first dimension (i.e. across flatten
          epochs in (n_epochs, n_channels, n_timepoints)).
        - ``axis=(0, 1)`` : Iterates over 1D slices along the zeroth and first dimensions (i.e
          across each signal independently in (n_participants, n_channels, n_timepoints)).

    n_jobs : int
        The number of jobs used by the last ``fit`` call.
    n_dims : {2, 3}
        Dimensionality of ``sigs``. None until ``fit`` is called.
    return_samples : bool, optional, default: True
        Returns samples indices of cyclepoints used for determining features if True.
    """

    def __init__(self, center_extrema='peak', burst_method='cycles', burst_kwargs=None,
                 thresholds=None, find_extrema_kwargs=None, return_samples=True):
        """Initialize object settings."""

        super().__init__(center_extrema, burst_method, burst_kwargs, thresholds,
                         find_extrema_kwargs, return_samples)

        # Input signals; set by fit
        self.sigs = None

        # 2d settings
        self.axis = None
        self.n_jobs = None
        self.n_dims = None

        # Results
        self.models = []

    def __len__(self):
        """Define the length of the object."""
        return len(self.models)

    def __iter__(self):
        """Allow for iterating across the object."""
        yield from self.models

    def __getitem__(self, index):
        """Allow for indexing into the object."""
        return self.models[index]

    def fit(self, sigs, fs, f_range, axis=0, n_jobs=-1, progress=None):
        """Run the bycycle algorithm on a 2D or 3D array of signals.

        Parameters
        ----------
        sigs : 2d or 3d array
            Voltage time series.
        fs : float
            Sampling rate, in Hz.
        f_range : tuple of (float, float)
            Frequency range for narrowband signal of interest, in Hz.
        axis : {0, 1, (0, 1), None}
            For 2d arrays:

            - ``axis=0`` : Iterates over each row/signal in an array independently (i.e. for
              each channel in (n_channels, n_timepoints)).
            - ``axis=None`` : Flattens rows/signals prior to computing features (i.e. across
              flatten epochs in (n_epochs, n_timepoints)).

            For 3d arrays:

            - ``axis=0`` : Iterates over 2D slices along the zeroth dimension, (i.e. for each
              channel in (n_channels, n_epochs, n_timepoints)).
            - ``axis=1`` : Iterates over 2D slices along the first dimension (i.e. across
              flatten epochs in (n_epochs, n_channels, n_timepoints)).
            - ``axis=(0, 1)`` : Iterates over 1D slices along the zeroth and first dimensions
              (i.e across each signal independently in (n_participants, n_channels,
              n_timepoints)).

        n_jobs : int, optional, default: -1
            The number of jobs to compute features in parallel.
        progress : {None, 'tqdm', 'tqdm.notebook'}
            Specify whether to display a progress bar. Uses 'tqdm', if installed.

        Raises
        ------
        ValueError
            If ``sigs`` is not 2 or 3-dimensional.
        """
        if sigs.ndim not in (2, 3):
            raise ValueError('Signal must be 2 or 3-dimensional.')

        self.sigs = sigs
        self.fs = fs
        self.f_range = f_range
        self.axis = axis
        self.n_jobs = n_jobs

        compute_features_kwargs = {
            'center_extrema': self.center_extrema,
            'burst_method': self.burst_method,
            'burst_kwargs': self.burst_kwargs,
            'threshold_kwargs': self.thresholds,
            'find_extrema_kwargs': self.find_extrema_kwargs
        }

        compute_func = compute_features_2d if self.sigs.ndim == 2 else compute_features_3d

        self.df_features = compute_func(
            self.sigs, self.fs, self.f_range, compute_features_kwargs,
            self.axis, self.return_samples, self.n_jobs, progress
        )

        # Wrap each dataframe in a Bycycle object, mirroring the layout of sigs
        self.n_dims = self.sigs.ndim

        self.models = []
        for dim0, sig in enumerate(self.sigs):
            if self.n_dims == 3:
                self.models.append([self._create_model(self.df_features[dim0][dim1], sig_)
                                    for dim1, sig_ in enumerate(sig)])
            else:
                self.models.append(self._create_model(self.df_features[dim0], sig))

    def _create_model(self, df_features, sig):
        """Build a Bycycle object with this group's settings and load precomputed results."""
        model = Bycycle(self.center_extrema, self.burst_method, self.burst_kwargs,
                        self.thresholds, self.find_extrema_kwargs, self.return_samples)
        model.load(df_features, sig, self.fs, self.f_range)
        return model

    def recompute_edges(self, reduction=None):
        """Recomputes features for cycles on the edge of bursts.

        Parameters
        ----------
        reduction : float, optional, default: None
            Reduces all float thresholds by given amount.

        Raises
        ------
        ValueError
            If ``fit`` has not been successfully called.
        """
        if self.sigs is None or self.n_dims is None:
            raise ValueError('The fit method must be successfully called prior to '
                             'recomputing edges.')

        for model in self.models:
            if self.n_dims == 3:
                for sub_model in model:
                    sub_model.recompute_edges(reduction)
            else:
                model.recompute_edges(reduction)