"""finds bad channels."""
from copy import copy
import mne
import numpy as np
from mne.utils import check_random_state, logger
from scipy import signal
from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend
from pyprep.utils import _filter_design, _mad, _mat_iqr, _mat_quantile
class NoisyChannels:
    """Detect bad channels in an EEG recording using a range of methods.

    This class provides a number of methods for detecting bad channels across a
    full-session EEG recording. Specifically, this class implements all of the
    noisy channel detection methods used in the PREP pipeline, as described
    in [1]_. The detection methods in this class can be run independently, or
    can be run all at once using the :meth:`~.find_all_bads` method.

    At present, only EEG channels are supported and any non-EEG channels in the
    provided data will be ignored.

    Parameters
    ----------
    raw : mne.io.Raw
        An MNE Raw object to check for bad EEG channels.
    do_detrend : bool, optional
        Whether or not low-frequency (<1.0 Hz) trends should be removed from the
        EEG signal prior to bad channel detection. This should always be set to
        ``True`` unless the signal has already had low-frequency trends removed.
        Defaults to ``True``.
    random_state : {int, None, np.random.RandomState}, optional
        The seed to use for random number generation within RANSAC. This can be
        ``None``, an integer, or a :class:`~numpy.random.RandomState` object.
        If ``None``, a random seed will be obtained from the operating system.
        Defaults to ``None``.
    matlab_strict : bool, optional
        Whether or not PyPREP should strictly follow MATLAB PREP's internal
        math, ignoring any improvements made in PyPREP over the original code
        (see :ref:`matlab-diffs` for more details). Defaults to ``False``.

    References
    ----------
    .. [1] Bigdely-Shamlo, N., Mullen, T., Kothe, C., Su, K. M., Robbins, K. A.
        (2015). The PREP pipeline: standardized preprocessing for large-scale
        EEG analysis. Frontiers in Neuroinformatics, 9, 16.
    """
def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False):
    """Initialize the object and flag any bad-by-NaN / bad-by-flat channels.

    See the class docstring for parameter descriptions.
    """
    # Validate explicitly rather than with `assert`, which is silently
    # stripped when Python runs with optimizations (-O) enabled.
    if not isinstance(raw, mne.io.BaseRaw):
        raise TypeError(
            f"'raw' must be an instance of mne.io.BaseRaw, got {type(raw)}"
        )
    raw.load_data()
    # Work on an EEG-only copy so the caller's Raw keeps its non-EEG
    # channels and (if detrending) its original data.
    self.raw_mne = raw.copy()
    self.raw_mne.pick_types(eeg=True)
    self.sample_rate = raw.info["sfreq"]
    if do_detrend:
        # Remove low-frequency (< 1.0 Hz) trends prior to detection
        self.raw_mne._data = removeTrend(
            self.raw_mne.get_data(), self.sample_rate, matlab_strict=matlab_strict
        )
    self.matlab_strict = matlab_strict

    # Extra per-method data for debugging (filled in by the find_bad_by_*
    # methods as they run)
    self._extra_info = {
        "bad_by_deviation": {},
        "bad_by_hf_noise": {},
        "bad_by_correlation": {},
        "bad_by_dropout": {},
        "bad_by_ransac": {},
    }

    # Seedable RNG, used by RANSAC
    self.random_state = check_random_state(random_state)

    # The channels identified as bad by each detection method
    self.bad_by_nan = []
    self.bad_by_flat = []
    self.bad_by_deviation = []
    self.bad_by_hf_noise = []
    self.bad_by_correlation = []
    self.bad_by_SNR = []
    self.bad_by_dropout = []
    self.bad_by_ransac = []

    # Get original EEG channel names, channel count & samples
    ch_names = np.asarray(self.raw_mne.info["ch_names"])
    self.ch_names_original = ch_names
    self.n_chans_original = len(ch_names)
    # Use `n_times` rather than `get_data().shape[1]`: the latter copies
    # the entire data array just to read its shape.
    self.n_samples = raw.n_times

    # Before anything else, flag bad-by-NaNs and bad-by-flats
    self.find_bad_by_nan_flat()
    bads_by_nan_flat = self.bad_by_nan + self.bad_by_flat

    # Make a subset of the data containing only usable EEG channels
    self.usable_idx = np.isin(ch_names, bads_by_nan_flat, invert=True)
    self.EEGData = self.raw_mne.get_data(picks=ch_names[self.usable_idx])
    self.EEGFiltered = None

    # Get usable EEG channel names & channel counts
    self.ch_names_new = np.asarray(ch_names[self.usable_idx])
    self.n_chans_new = len(self.ch_names_new)
def _get_filtered_data(self):
"""Apply a [1 Hz - 50 Hz] bandpass filter to the EEG signal.
Only applied if the sample rate is above 100 Hz to avoid violating the
Nyquist theorem.
"""
if self.sample_rate <= 100:
return self.EEGData.copy()
bandpass_filter = _filter_design(
N_order=100,
amp=np.array([1, 1, 0, 0]),
freq=np.array([0, 90 / self.sample_rate, 100 / self.sample_rate, 1]),
)
EEG_filt = np.zeros_like(self.EEGData)
for i in range(EEG_filt.shape[0]):
EEG_filt[i, :] = signal.filtfilt(bandpass_filter, 1, self.EEGData[i, :])
return EEG_filt
def get_bads(self, verbose=False, as_dict=False):
"""Get the names of all channels currently flagged as bad.
Note that this method does not perform any bad channel detection itself,
and only reports channels already detected as bad by other methods.
Parameters
----------
verbose : bool, optional
If ``True``, a summary of the channels currently flagged as by bad per
category is printed. Defaults to ``False``.
as_dict: bool, optional
If ``True``, this method will return a dict of the channels currently
flagged as bad by each individual bad channel type. If ``False``, this
method will return a list of all unique bad channels detected so far.
Defaults to ``False``.
Returns
-------
bads : list or dict
The names of all bad channels detected so far, either as a combined
list or a dict indicating the channels flagged bad by each type.
"""
bads = {
"bad_by_nan": self.bad_by_nan,
"bad_by_flat": self.bad_by_flat,
"bad_by_deviation": self.bad_by_deviation,
"bad_by_hf_noise": self.bad_by_hf_noise,
"bad_by_correlation": self.bad_by_correlation,
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
}
all_bads = set()
for bad_chs in bads.values():
all_bads.update(bad_chs)
name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
if verbose:
out = f"Found {len(all_bads)} uniquely bad channels:\n"
for bad_type, bad_chs in bads.items():
bad_type = bad_type.replace("bad_by_", "")
if bad_type in name_map.keys():
bad_type = name_map[bad_type]
out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
logger.info(out)
if as_dict:
bads["bad_all"] = list(all_bads)
else:
bads = list(all_bads)
return bads
def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
    """Call all the functions to detect bad channels.

    This function calls all the bad-channel detecting functions.

    Parameters
    ----------
    ransac : bool, optional
        Whether RANSAC should be used for bad channel detection, in addition
        to the other methods. RANSAC can detect bad channels that other
        methods are unable to catch, but also slows down noisy channel
        detection considerably. Defaults to ``True``.
    channel_wise : bool, optional
        Whether RANSAC should predict signals for chunks of channels over the
        entire signal length ("channel-wise RANSAC", see `max_chunk_size`
        parameter). If ``False``, RANSAC will instead predict signals for all
        channels at once but over a number of smaller time windows instead of
        over the entire signal length ("window-wise RANSAC"). Channel-wise
        RANSAC generally has higher RAM demands than window-wise RANSAC
        (especially if `max_chunk_size` is ``None``), but can be faster on
        systems with lots of RAM to spare. Has no effect if not using RANSAC.
        Defaults to ``False``.
    max_chunk_size : {int, None}, optional
        The maximum number of channels to predict at once during
        channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
        size that will fit into the available RAM, which may slow down
        other programs on the host system. If using window-wise RANSAC
        (the default) or not using RANSAC at all, this parameter has no
        effect. Defaults to ``None``.
    """
    # NOTE: Bad-by-NaN/flat is already run during init, no need to re-run here
    # Order matters below: find_bad_by_SNR reuses the HF-noise and
    # correlation results, and find_bad_by_ransac excludes channels already
    # flagged by deviation, correlation, and dropout.
    self.find_bad_by_deviation()
    self.find_bad_by_hfnoise()
    self.find_bad_by_correlation()
    self.find_bad_by_SNR()
    if ransac:
        self.find_bad_by_ransac(
            channel_wise=channel_wise, max_chunk_size=max_chunk_size
        )
def find_bad_by_nan_flat(self, flat_threshold=1e-15):
    """Detect channels that contain NaN values or have near-flat signals.

    A channel is flagged as flat when either its standard deviation or its
    median absolute deviation from the median (MAD) falls below the given
    flat threshold.

    This method runs automatically when a ``NoisyChannels`` object is
    initialized, so that flat or NaN-containing channels cannot interfere
    with the detection of other types of bad channels.

    Parameters
    ----------
    flat_threshold : float, optional
        The lowest standard deviation or MAD value for a channel to be
        considered bad-by-flat. Defaults to ``1e-15`` volts (corresponds to
        10e-10 µV in MATLAB PREP).
    """
    # Check every EEG channel in the original copy of the data
    data = self.raw_mne.get_data()
    # Any NaN in a channel makes that channel's sum NaN
    nan_mask = np.isnan(np.sum(data, axis=1))
    # Flat when either the robust (MAD) or the ordinary standard deviation
    # is effectively zero
    flat_mask = (_mad(data, axis=1) < flat_threshold) | (
        np.std(data, axis=1) < flat_threshold
    )
    # Record the names of channels flagged by each criterion
    self.bad_by_nan = self.ch_names_original[nan_mask].tolist()
    self.bad_by_flat = self.ch_names_original[flat_mask].tolist()
def find_bad_by_deviation(self, deviation_threshold=5.0):
    """Detect channels with abnormally high or low overall amplitudes.

    A channel is "bad-by-deviation" when its amplitude deviates strongly
    from the median channel amplitude, as measured by a robust Z-score:
    ``(channel_amplitude - median_amplitude) / amplitude_sd``. Channel
    amplitudes are outlier-resistant estimates of each signal's standard
    deviation (IQR scaled to units of SD), and the amplitude SD is the
    IQR-based SD of those amplitudes.

    Parameters
    ----------
    deviation_threshold : float, optional
        The minimum absolute z-score of a channel for it to be considered
        bad-by-deviation. Defaults to ``5.0``.
    """
    # IQR -> SD conversion factor, assuming normality
    # (see https://stat.ethz.ch/R-manual/R-devel/library/stats/html/IQR.html)
    IQR_TO_SD = 0.7413
    # Robust per-channel amplitudes, their median, and their robust spread
    amplitudes = _mat_iqr(self.EEGData, axis=1) * IQR_TO_SD
    median_amp = np.nanmedian(amplitudes)
    sd_amp = _mat_iqr(amplitudes) * IQR_TO_SD
    # Robust Z-score of each usable channel's amplitude (unusable channels
    # keep a Z-score of zero)
    zscores = np.zeros(self.n_chans_original)
    zscores[self.usable_idx] = (amplitudes - median_amp) / sd_amp
    # Bad when the Z-score is extreme or undefined (NaN)
    bad_mask = np.isnan(zscores) | (np.abs(zscores) > deviation_threshold)
    self.bad_by_deviation = self.ch_names_original[bad_mask].tolist()
    # Save intermediate values for debugging
    self._extra_info["bad_by_deviation"].update(
        {
            "median_channel_amplitude": median_amp,
            "channel_amplitude_sd": sd_amp,
            "robust_channel_deviations": zscores,
        }
    )
def find_bad_by_hfnoise(self, HF_zscore_threshold=5.0):
"""Detect channels with abnormally high amounts of high-frequency noise.
The noisiness of a channel is defined as the amplitude of its
high-frequency (>50 Hz) components divided by its overall amplitude.
A channel is considered "bad-by-high-frequency-noise" if its noisiness
is considerably higher than the median channel noisiness, as determined
by a robust Z-scoring method and the given Z-score threshold.
Due to the Nyquist theorem, this method will only attempt bad channel
detection if the sample rate of the given signal is above 100 Hz.
Parameters
----------
HF_zscore_threshold : float, optional
The minimum noisiness z-score of a channel for it to be considered
bad-by-high-frequency-noise. Defaults to ``5.0``.
"""
MAD_TO_SD = 1.4826 # Scales units of MAD to units of SD, assuming normality
# Reference: https://stat.ethz.ch/R-manual/R-devel/library/stats/html/mad.html
if self.EEGFiltered is None:
self.EEGFiltered = self._get_filtered_data()
# Set default values for noise parameters
noise_median, noise_sd = (0, 1)
noise_zscore = np.zeros(self.n_chans_original)
# If sample rate is high enough, calculate ratio of > 50 Hz amplitude to
# < 50 Hz amplitude for each channel and get robust z-scores of values
if self.sample_rate > 100:
noisiness = np.divide(
_mad(self.EEGData - self.EEGFiltered, axis=1),
_mad(self.EEGFiltered, axis=1),
)
noise_median = np.nanmedian(noisiness)
noise_sd = np.median(np.abs(noisiness - noise_median)) * MAD_TO_SD
noise_zscore[self.usable_idx] = (noisiness - noise_median) / noise_sd
# Flag channels with much more high-frequency noise than the median channel
hf_mask = np.isnan(noise_zscore) | (noise_zscore > HF_zscore_threshold)
hf_noise_channels = self.ch_names_original[hf_mask]
# Update names of high-frequency noise channels & save additional info
self.bad_by_hf_noise = hf_noise_channels.tolist()
self._extra_info["bad_by_hf_noise"].update(
{
"median_channel_noisiness": noise_median,
"channel_noisiness_sd": noise_sd,
"hf_noise_zscores": noise_zscore,
}
)
def find_bad_by_correlation(
    self, correlation_secs=1.0, correlation_threshold=0.4, frac_bad=0.01
):
    """Detect channels that sometimes don't correlate with any other channels.

    Channel correlations are calculated by splitting the recording into
    non-overlapping windows of time (default: 1 second), getting the absolute
    correlations of each usable channel with every other usable channel for
    each window, and then finding the highest correlation each channel has
    with another channel for each window (by taking the 98th percentile of
    the absolute correlations).

    A correlation window is considered "bad" for a channel if its maximum
    correlation with another channel is below the provided correlation
    threshold (default: ``0.4``). A channel is considered "bad-by-correlation"
    if its fraction of bad correlation windows is above the bad fraction
    threshold (default: ``0.01``).

    This method also detects channels with intermittent dropouts (i.e.,
    regions of flat signal). A channel is considered "bad-by-dropout" if
    its fraction of correlation windows with a completely flat signal is
    above the bad fraction threshold (default: ``0.01``).

    Parameters
    ----------
    correlation_secs : float, optional
        The length (in seconds) of each correlation window. Defaults to ``1.0``.
    correlation_threshold : float, optional
        The lowest maximum inter-channel correlation for a channel to be
        considered "bad" within a given window. Defaults to ``0.4``.
    frac_bad : float, optional
        The minimum proportion of bad windows for a channel to be considered
        "bad-by-correlation" or "bad-by-dropout". Defaults to ``0.01`` (1% of
        all windows).
    """
    IQR_TO_SD = 0.7413  # Scales units of IQR to units of SD, assuming normality
    # Reference: https://stat.ethz.ch/R-manual/R-devel/library/stats/html/IQR.html
    # Correlations are computed on the low-passed signal; compute/cache it once
    if self.EEGFiltered is None:
        self.EEGFiltered = self._get_filtered_data()
    # Determine the number and size (in frames) of correlation windows
    win_size = int(correlation_secs * self.sample_rate)
    win_offsets = np.arange(1, (self.n_samples - win_size), win_size)
    win_count = len(win_offsets)
    # NOTE(review): `win_offsets` is only used for its length; the actual
    # window bounds below are (w * win_size, (w + 1) * win_size), so any
    # trailing partial window is dropped — presumably intentional to match
    # MATLAB PREP; confirm.
    # Initialize per-window arrays for each type of noise info calculated below
    max_correlations = np.ones((win_count, self.n_chans_original))
    dropout = np.zeros((win_count, self.n_chans_original), dtype=bool)
    noiselevels = np.zeros((win_count, self.n_chans_original))
    channel_amplitudes = np.zeros((win_count, self.n_chans_original))
    for w in range(win_count):
        # Get both filtered and unfiltered data for the current window
        start, end = (w * win_size, (w + 1) * win_size)
        eeg_filtered = self.EEGFiltered[:, start:end]
        eeg_raw = self.EEGData[:, start:end]
        # Get channel amplitude info for the window
        # (`usable` is re-copied each window because dropouts are per-window)
        usable = self.usable_idx.copy()
        channel_amplitudes[w, usable] = _mat_iqr(eeg_raw, axis=1) * IQR_TO_SD
        # Check for any channel dropouts (flat signal) within the window
        eeg_amplitude = _mad(eeg_filtered, axis=1)
        dropout[w, usable] = eeg_amplitude == 0
        # Exclude any dropout chans from further calculations (avoids div-by-zero)
        # The in-place mask update narrows `usable` to non-dropout channels,
        # keeping its indexing aligned with the filtered arrays below.
        usable[usable] = eeg_amplitude > 0
        eeg_raw = eeg_raw[eeg_amplitude > 0, :]
        eeg_filtered = eeg_filtered[eeg_amplitude > 0, :]
        eeg_amplitude = eeg_amplitude[eeg_amplitude > 0]
        # Get high-frequency noise ratios for the window
        high_freq_amplitude = _mad(eeg_raw - eeg_filtered, axis=1)
        noiselevels[w, usable] = high_freq_amplitude / eeg_amplitude
        # Get inter-channel correlations for the window
        # (zero out the diagonal so self-correlation can't be the maximum)
        win_correlations = np.corrcoef(eeg_filtered)
        abs_corr = np.abs(win_correlations - np.diag(np.diag(win_correlations)))
        max_correlations[w, usable] = _mat_quantile(abs_corr, 0.98, axis=0)
        max_correlations[w, dropout[w, :]] = 0  # Set dropout correlations to 0
    # Flag channels with above-threshold fractions of bad correlation windows
    thresholded_correlations = max_correlations < correlation_threshold
    fraction_bad_corr_windows = np.mean(thresholded_correlations, axis=0)
    bad_correlation_mask = fraction_bad_corr_windows > frac_bad
    bad_correlation_channels = self.ch_names_original[bad_correlation_mask]
    # Flag channels with above-threshold fractions of drop-out windows
    fraction_dropout_windows = np.mean(dropout, axis=0)
    dropout_mask = fraction_dropout_windows > frac_bad
    dropout_channels = self.ch_names_original[dropout_mask]
    # Update names of low-correlation/dropout channels & save additional info
    self.bad_by_correlation = bad_correlation_channels.tolist()
    self.bad_by_dropout = dropout_channels.tolist()
    self._extra_info["bad_by_correlation"] = {
        "max_correlations": np.transpose(max_correlations),
        "median_max_correlations": np.median(max_correlations, axis=0),
        "bad_window_fractions": fraction_bad_corr_windows,
    }
    self._extra_info["bad_by_dropout"] = {
        "dropouts": np.transpose(dropout.astype(np.int8)),
        "bad_window_fractions": fraction_dropout_windows,
    }
    # Per-window amplitude/noise info is also stored for the other methods
    self._extra_info["bad_by_deviation"]["channel_amplitudes"] = channel_amplitudes
    self._extra_info["bad_by_hf_noise"]["noise_levels"] = noiselevels
def find_bad_by_SNR(self):
"""Detect channels that have a low signal-to-noise ratio.
Channels are considered "bad-by-SNR" if they are bad by both high-frequency
noise and bad by low correlation.
"""
# Get names of bad-by-HF-noise and bad-by-correlation channels
if not len(self._extra_info["bad_by_hf_noise"]) > 1:
self.find_bad_by_hfnoise()
if not len(self._extra_info["bad_by_correlation"]):
self.find_bad_by_correlation()
bad_by_hf = set(self.bad_by_hf_noise)
bad_by_corr = set(self.bad_by_correlation)
# Flag channels bad by both HF noise and low correlation as bad by low SNR
self.bad_by_SNR = list(bad_by_corr.intersection(bad_by_hf))
def find_bad_by_ransac(
    self,
    n_samples=50,
    sample_prop=0.25,
    corr_thresh=0.75,
    frac_bad=0.4,
    corr_window_secs=5.0,
    channel_wise=False,
    max_chunk_size=None,
):
    """Detect channels that are predicted poorly by other channels.

    This method uses a random sample consensus approach (RANSAC, see [1]_,
    and a short discussion in [2]_) to try and predict what the signal should
    be for each channel based on the signals and spatial locations of other
    currently-good channels. RANSAC correlations are calculated by splitting
    the recording into non-overlapping windows of time (default: 5 seconds)
    and correlating each channel's RANSAC-predicted signal with its actual
    signal within each window.

    A RANSAC window is considered "bad" for a channel if its predicted signal
    vs. actual signal correlation falls below the given correlation threshold
    (default: ``0.75``). A channel is considered "bad-by-RANSAC" if its fraction
    of bad RANSAC windows is above the given threshold (default: ``0.4``).

    Due to its random sampling component, the channels identified as
    "bad-by-RANSAC" may vary slightly between calls of this method.
    Additionally, bad channels may vary between different montages given that
    RANSAC's signal predictions are based on the spatial coordinates of each
    electrode.

    This method is a wrapper for the :func:`~ransac.find_bad_by_ransac`
    function.

    .. warning:: For optimal performance, RANSAC requires that channels bad by
                 deviation, correlation, and/or dropout have already been
                 flagged. Otherwise RANSAC will attempt to use those channels
                 when making signal predictions, decreasing accuracy and thus
                 increasing the likelihood of false positives.

    Parameters
    ----------
    n_samples : int, optional
        Number of random channel samples to use for RANSAC. Defaults
        to ``50``.
    sample_prop : float, optional
        Proportion of total channels to use for signal prediction per RANSAC
        sample. This needs to be in the range [0, 1], where 0 would mean no
        channels would be used and 1 would mean all channels would be used
        (neither of which would be useful values). Defaults to ``0.25``
        (e.g., 16 channels per sample for a 64-channel dataset).
    corr_thresh : float, optional
        The minimum predicted vs. actual signal correlation for a channel to
        be considered good within a given RANSAC window. Defaults
        to ``0.75``.
    frac_bad : float, optional
        The minimum fraction of bad (i.e., below-threshold) RANSAC windows
        for a channel to be considered bad-by-RANSAC. Defaults to ``0.4``.
    corr_window_secs : float, optional
        The duration (in seconds) of each RANSAC correlation window. Defaults
        to 5 seconds.
    channel_wise : bool, optional
        Whether RANSAC should predict signals for chunks of channels over the
        entire signal length ("channel-wise RANSAC", see `max_chunk_size`
        parameter). If ``False``, RANSAC will instead predict signals for all
        channels at once but over a number of smaller time windows instead of
        over the entire signal length ("window-wise RANSAC"). Channel-wise
        RANSAC generally has higher RAM demands than window-wise RANSAC
        (especially if `max_chunk_size` is ``None``), but can be faster on
        systems with lots of RAM to spare. Defaults to ``False``.
    max_chunk_size : {int, None}, optional
        The maximum number of channels to predict at once during
        channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk
        size that will fit into the available RAM, which may slow down
        other programs on the host system. If using window-wise RANSAC
        (the default), this parameter has no effect. Defaults to ``None``.

    References
    ----------
    .. [1] Fischler, M.A., Bolles, R.C. (1981). Random sample consensus: A
        Paradigm for Model Fitting with Applications to Image Analysis and
        Automated Cartography. Communications of the ACM, 24, 381-395
    .. [2] Jas, M., Engemann, D.A., Bekhti, Y., Raimondo, F., Gramfort, A.
        (2017). Autoreject: Automated Artifact Rejection for MEG and EEG
        Data. NeuroImage, 159, 417-429
    """
    # RANSAC operates on the low-passed signal; compute/cache it once
    if self.EEGFiltered is None:
        self.EEGFiltered = self._get_filtered_data()
    # Channels already flagged bad are excluded from signal prediction
    exclude_from_ransac = (
        self.bad_by_correlation + self.bad_by_deviation + self.bad_by_dropout
    )
    # In MATLAB-strict mode, pass a *copy* of the RNG so that running RANSAC
    # does not advance the state of ``self.random_state``
    rng = copy(self.random_state) if self.matlab_strict else self.random_state
    self.bad_by_ransac, ch_correlations_usable = find_bad_by_ransac(
        self.EEGFiltered,
        self.sample_rate,
        self.ch_names_new,
        self.raw_mne._get_channel_positions()[self.usable_idx, :],
        exclude_from_ransac,
        n_samples,
        sample_prop,
        corr_thresh,
        frac_bad,
        corr_window_secs,
        channel_wise,
        max_chunk_size,
        rng,
        self.matlab_strict,
    )
    # Reshape correlation matrix to match original channel count
    # (channels excluded as bad-by-NaN/flat get a placeholder correlation of 1)
    n_ransac_windows = ch_correlations_usable.shape[0]
    ch_correlations = np.ones((n_ransac_windows, self.n_chans_original))
    ch_correlations[:, self.usable_idx] = ch_correlations_usable
    self._extra_info["bad_by_ransac"] = {
        "ransac_correlations": ch_correlations,
        "bad_window_fractions": np.mean(ch_correlations < corr_thresh, axis=0),
    }