"""The dual threshold algorithm for detecting oscillatory bursts in a neural signal."""
import numpy as np
from neurodsp.utils.core import get_avg_func
from neurodsp.utils.checks import check_param_options
from neurodsp.utils.decorators import multidim
from neurodsp.timefrequency.hilbert import amp_by_time
###################################################################################################
###################################################################################################
@multidim()
def detect_bursts_dual_threshold(sig, fs, dual_thresh, f_range=None,
                                 min_n_cycles=3, min_burst_duration=None,
                                 avg_type='median', magnitude_type='amplitude',
                                 **filter_kwargs):
    """Detect bursts in a signal using the dual threshold algorithm.

    Parameters
    ----------
    sig : 1d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    dual_thresh : tuple of (float, float)
        Low and high threshold values for burst detection.
        Units are normalized by the average signal magnitude.
    f_range : tuple of (float, float), optional
        Frequency range, to filter signal to, before running burst detection.
        If f_range is None, then no filtering is applied prior to running burst detection.
    min_n_cycles : float, optional, default: 3
        Minimum burst duration to keep, in number of cycles.
        Only used if `f_range` is defined, and is used as the number of cycles at f_range[0].
    min_burst_duration : float, optional, default: None
        Minimum length of a burst, in seconds. Must be defined if not filtering.
        Only used if `f_range` is not defined, or if `min_n_cycles` is set to None.
    avg_type : {'median', 'mean'}, optional
        Averaging method to use to normalize the magnitude that is used for thresholding.
    magnitude_type : {'amplitude', 'power'}, optional
        Metric of magnitude used for thresholding.
    **filter_kwargs
        Keyword parameters to pass to `filter_signal`.

    Returns
    -------
    is_burst : 1d array
        Boolean indication of where bursts are present in the input signal.
        True indicates that a burst was detected at that sample, otherwise False.

    Raises
    ------
    ValueError
        If `dual_thresh` does not contain exactly two elements, or if a minimum burst
        duration in seconds is required (no `f_range`, or `min_n_cycles` is None)
        but `min_burst_duration` is not defined.

    Notes
    -----
    The dual-threshold burst detection algorithm was originally proposed in [1]_.

    References
    ----------
    .. [1] Feingold, J., Gibson, D. J., DePasquale, B., & Graybiel, A. M. (2015).
           Bursts of beta oscillation differentiate postperformance activity in
           the striatum and motor cortex of monkeys performing movement tasks.
           Proceedings of the National Academy of Sciences, 112(44), 13687–13692.
           DOI: https://doi.org/10.1073/pnas.1517629112

    Examples
    --------
    Detect bursts using the dual threshold algorithm:

    >>> from neurodsp.sim import sim_combined
    >>> sig = sim_combined(n_seconds=10, fs=500,
    ...                    components={'sim_synaptic_current': {},
    ...                                'sim_bursty_oscillation' : {'freq': 10}},
    ...                    component_variances=[0.1, 0.9])
    >>> is_burst = detect_bursts_dual_threshold(sig, fs=500, dual_thresh=(1, 2), f_range=(8, 12))
    """

    # Validate every parameter up front, so bad input fails before any filtering is done
    if len(dual_thresh) != 2:
        raise ValueError("Invalid number of elements in 'dual_thresh' parameter")
    thresh_lo, thresh_hi = dual_thresh

    check_param_options(magnitude_type, 'magnitude_type', ['amplitude', 'power'])
    avg_func = get_avg_func(avg_type)

    # The minimum length is given in cycles if a frequency range is available,
    #   otherwise it has to be given explicitly, in seconds
    use_n_cycles = f_range is not None and min_n_cycles is not None
    if not use_n_cycles and min_burst_duration is None:
        raise ValueError("Minimum burst duration must be defined if not filtering "
                         "and using a number of cycles threshold.")

    if use_n_cycles:
        min_burst_samples = int(np.ceil(min_n_cycles * fs / f_range[0]))
    else:
        min_burst_samples = int(np.ceil(min_burst_duration * fs))

    # Compute amplitude time series
    sig_magnitude = amp_by_time(sig, fs, f_range, remove_edges=False, **filter_kwargs)

    # Set magnitude as power or amplitude: square if power, leave as is if amplitude
    if magnitude_type == 'power':
        sig_magnitude = sig_magnitude**2

    # Calculate normalized magnitude
    sig_magnitude = sig_magnitude / avg_func(sig_magnitude)

    # Identify time periods of bursting using the 2 thresholds
    is_burst = _dual_threshold_split(sig_magnitude, thresh_hi, thresh_lo)

    # Remove bursts detected that are too short
    is_burst = _rmv_short_periods(is_burst, min_burst_samples)

    return is_burst.astype(bool)
def _dual_threshold_split(sig, thresh_hi, thresh_lo):
    """Identify periods that are above thresh_lo and have at least one value above thresh_hi.

    Parameters
    ----------
    sig : 1d array
        Values to threshold. The input is not modified.
    thresh_hi : float
        Threshold that at least one sample of a period has to reach.
    thresh_lo : float
        Threshold that the samples neighbouring a high sample have to reach, to be
        added to the same period.

    Returns
    -------
    positive : 1d array
        Array of the same length as `sig`, with 1 inside identified periods, and 0 elsewhere.

    Notes
    -----
    The first and last sample are treated as having a value of zero, so that, for
    positive thresholds, a period never touches either edge of the signal.
    """

    # Work on a copy, so that zeroing the edge samples never alters the caller's array
    sig = np.array(sig, copy=True)
    n_samples = len(sig)

    positive = np.zeros(n_samples)
    if n_samples == 0:
        return positive

    # Do not allow the first or last sample to start off inside a period
    sig[[0, -1]] = 0

    # Seed the periods with all values above thresh_hi
    idx_over_hi = np.where(sig >= thresh_hi)[0]
    positive[idx_over_hi] = 1

    # Grow each seed outwards, for as long as neighbouring values stay above thresh_lo
    #   The index bounds are checked explicitly, so that a walk never wraps around
    #   to the other end of the array, or runs past the last sample
    for ind in idx_over_hi:

        j_down = ind - 1
        if j_down >= 0 and positive[j_down] == 0:
            while j_down >= 0 and sig[j_down] >= thresh_lo:
                positive[j_down] = 1
                j_down -= 1

        j_up = ind + 1
        if j_up < n_samples and positive[j_up] == 0:
            while j_up < n_samples and sig[j_up] >= thresh_lo:
                positive[j_up] = 1
                j_up += 1

    return positive
def _rmv_short_periods(sig, n_samples):
    """Remove periods that are equal to 1 for less than n_samples.

    Parameters
    ----------
    sig : 1d array
        Binary signal, in which non-zero values mark the samples that are inside a period.
    n_samples : int
        Minimum number of consecutive samples for a period to be kept.

    Returns
    -------
    is_osc : 1d array
        Array of the same length as `sig`, with 1 at the samples of every period that is
        at least `n_samples` long, and 0 elsewhere. Kept periods stay at their original
        sample positions.
    """

    active = np.asarray(sig).astype(bool)
    is_osc = np.zeros(len(active))

    if not active.any():
        return is_osc

    # Pad with a zero on both sides, so that periods touching either edge of the signal
    #   still produce both an onset and an offset
    changes = np.diff(np.concatenate(([0], active.astype(int), [0])))

    # Because of the padding, onsets index the first sample of each period,
    #   and offsets index the first sample after it
    osc_starts = np.flatnonzero(changes == 1)
    osc_ends = np.flatnonzero(changes == -1)

    for start, end in zip(osc_starts, osc_ends):
        if end - start >= n_samples:
            is_osc[start:end] = 1

    return is_osc