# Source code for ransac

"""RANSAC bad channel identification."""
import mne
import numpy as np
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import ProgressBar, check_random_state, logger

from pyprep.utils import (
    _correlate_arrays,
    _get_random_subset,
    _mat_round,
    _split_list,
    _verify_free_ram,
)


def find_bad_by_ransac(
    data,
    sample_rate,
    complete_chn_labs,
    chn_pos,
    exclude,
    n_samples=50,
    sample_prop=0.25,
    corr_thresh=0.75,
    frac_bad=0.4,
    corr_window_secs=5.0,
    channel_wise=False,
    max_chunk_size=None,
    random_state=None,
    matlab_strict=False,
):
    """Detect channels that are not predicted well by other channels.

    Here, a RANSAC approach (see [1]_, and a short discussion in [2]_) is
    adopted to predict a "clean EEG" dataset. After identifying clean EEG
    channels through the other methods, the clean EEG dataset is constructed
    by repeatedly sampling a small subset of clean EEG channels and
    interpolating the complete data. The median of all those repetitions
    forms the clean EEG dataset. In a second step, the original and the
    RANSAC-predicted data are correlated, and channels that do not correlate
    well with themselves across the two datasets are considered
    `bad_by_ransac`.

    Parameters
    ----------
    data : np.ndarray
        A 2-D array of detrended EEG data, with bad-by-flat and bad-by-NaN
        channels removed.
    sample_rate : float
        The sample rate (in Hz) of the EEG data.
    complete_chn_labs : array_like
        Labels for all channels in `data`, in the same order as they appear
        in `data`.
    chn_pos : np.ndarray
        3-D electrode coordinates for all channels in `data`, in the same order
        as they appear in `data`.
    exclude : list
        Labels of channels to exclude as signal predictors during RANSAC
        (i.e., channels already flagged as bad by metrics other than HF noise).
    n_samples : int, optional
        Number of random channel samples to use for RANSAC. Defaults to ``50``.
    sample_prop : float, optional
        Proportion of total channels to use for signal prediction per RANSAC
        sample. Must be greater than 0. The resulting number of predictor
        channels (``sample_prop`` times the channel count, rounded) must be at
        least 4 and smaller than the number of good channels, so values of 1
        or more can never succeed. Defaults to ``0.25`` (e.g., 16 channels per
        sample for a 64-channel dataset).
    corr_thresh : float, optional
        The minimum predicted vs. actual signal correlation for a channel to
        be considered good within a given RANSAC window. Defaults to ``0.75``.
    frac_bad : float, optional
        The minimum fraction of bad (i.e., below-threshold) RANSAC windows for
        a channel to be considered bad-by-RANSAC. Defaults to ``0.4``.
    corr_window_secs : float, optional
        The duration (in seconds) of each RANSAC correlation window. Defaults
        to 5 seconds.
    channel_wise : bool, optional
        Whether RANSAC should predict signals for chunks of channels over the
        entire signal length ("channel-wise RANSAC", see `max_chunk_size`
        parameter). If ``False``, RANSAC will instead predict signals for all
        channels at once but over a number of smaller time windows instead of
        over the entire signal length ("window-wise RANSAC"). Channel-wise
        RANSAC generally has higher RAM demands than window-wise RANSAC
        (especially if `max_chunk_size` is ``None``), but can be faster on
        systems with lots of RAM to spare. Defaults to ``False``.
    max_chunk_size : {int, None}, optional
        The maximum number of channels to predict at once during channel-wise
        RANSAC. If ``None``, RANSAC will use the largest chunk size that will
        fit into the available RAM, which may slow down other programs on the
        host system. If using window-wise RANSAC (the default), this parameter
        has no effect. Defaults to ``None``.
    random_state : {int, None, np.random.RandomState}, optional
        The random seed with which to generate random samples of channels during
        RANSAC. If random_state is an int, it will be used as a seed for
        RandomState. If ``None``, the seed will be obtained from the operating
        system (see RandomState for details). Defaults to ``None``.
    matlab_strict : bool, optional
        Whether or not RANSAC should strictly follow MATLAB PREP's internal
        math, ignoring any improvements made in PyPREP over the original code
        (see :ref:`matlab-diffs` for more details). Defaults to ``False``.

    Returns
    -------
    bad_by_ransac : list
        List containing the labels of all channels flagged as bad by RANSAC.
    channel_correlations : np.ndarray
        Array of shape (windows, channels) containing the correlations of
        the channels with their predicted RANSAC values for each window.
        Channels listed in `exclude` keep a correlation of 1.

    Raises
    ------
    ValueError
        If `chn_pos` contains NaN, or if `sample_prop` is not greater than 0.
    TypeError
        If `n_samples` is not an int.
    IOError
        If there are too few channels overall, or too few good channels, to
        perform RANSAC with the requested `sample_prop`.
    MemoryError
        If channel-wise RANSAC cannot fit even a single channel in RAM.
        ``pyprep.utils._verify_free_ram`` may also raise when RAM is
        insufficient (see that function for the exact exception).

    References
    ----------
    .. [1] Fischler, M.A., Bolles, R.C. (1981). Random sample consensus: A
        Paradigm for Model Fitting with Applications to Image Analysis and
        Automated Cartography. Communications of the ACM, 24, 381-395
    .. [2] Jas, M., Engemann, D.A., Bekhti, Y., Raimondo, F., Gramfort, A.
        (2017). Autoreject: Automated Artifact Rejection for MEG and EEG
        Data. NeuroImage, 159, 417-429
    """
    # Check we find channel positions
    if np.isnan(chn_pos).any():
        raise ValueError(
            "Found NaN in channel positions. Did you supply a montage for the raw data?"
        )

    # First, check that the argument types/values are valid
    if not isinstance(n_samples, int):
        err = "Argument 'n_samples' must be an int (got {0})"
        raise TypeError(err.format(type(n_samples).__name__))
    if not sample_prop > 0:
        # Written as ``not > 0`` so NaN is rejected as well. Without this check,
        # sample_prop == 0 crashed with ZeroDivisionError while formatting the
        # "too few channels" message below (4.0 / sample_prop).
        err = "Argument 'sample_prop' must be greater than 0 (got {0})"
        raise ValueError(err.format(sample_prop))

    complete_chn_labs = np.asarray(complete_chn_labs)

    # Indices and positions of the channels usable as RANSAC predictors, i.e.
    # every channel not already flagged as bad by other methods
    good_idx = mne.pick_channels(list(complete_chn_labs), include=[], exclude=exclude)
    n_chans_good = good_idx.shape[0]
    chn_pos_good = chn_pos[good_idx, :]

    # Check if we have enough remaining channels after excluding bad channels
    n_chans = data.shape[0]
    n_pred_chns = int(np.around(sample_prop * n_chans))
    if n_pred_chns <= 3:
        sample_pct = int(sample_prop * 100)
        e = "Too few channels in the original data to reliably perform RANSAC "
        e += "(minimum {0} for a sample size of {1}%)."
        raise IOError(e.format(int(np.floor(4.0 / sample_prop)), sample_pct))
    elif n_chans_good < (n_pred_chns + 1):
        e = "Too many noisy channels in the data to reliably perform RANSAC "
        e += "(only {0} good channels remaining, need at least {1})."
        raise IOError(e.format(n_chans_good, n_pred_chns + 1))

    # Before running, make sure we have enough memory when using the
    # smallest possible chunk size
    if channel_wise:
        _verify_free_ram(data, n_samples, 1)
    else:
        window_size = int(sample_rate * corr_window_secs)
        _verify_free_ram(data[:, :window_size], n_samples, n_chans_good)

    # Random predictor picks (indices into the good channels) for each RANSAC
    # sample, plus the matching interpolation matrices
    good_chans = np.arange(chn_pos_good.shape[0])
    rng = check_random_state(random_state)
    random_ch_picks = [
        _get_random_subset(good_chans, n_pred_chns, rng) for _ in range(n_samples)
    ]
    interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)

    # Calculate the size (in frames) and count of correlation windows
    correlation_frames = corr_window_secs * sample_rate
    signal_frames = data.shape[1]
    correlation_offsets = np.arange(
        0, (signal_frames - correlation_frames), correlation_frames
    )
    win_size = int(correlation_frames)
    win_count = correlation_offsets.shape[0]

    # Preallocate RANSAC correlation matrix over ALL channels; excluded
    # channels are never predicted and keep a correlation of 1
    n_chans_complete = len(complete_chn_labs)
    channel_correlations = np.ones((win_count, n_chans_complete))

    # Select the good channels once instead of re-copying them for every chunk
    data_good = data[good_idx, :]

    logger.info("Executing RANSAC\nThis may take a while, so be patient...")

    if not channel_wise:
        # Window-wise RANSAC: all channels at once, one time window at a time
        channel_correlations[:, good_idx] = _ransac_by_window(
            data_good, interp_mats, win_size, win_count, matlab_strict
        )
    else:
        # Channel-wise RANSAC: try the largest allowed chunk size first and
        # fall back to smaller chunks whenever a chunk runs out of memory
        chunk_sizes = _get_ransac_chunk_sizes(n_chans_good, max_chunk_size)
        chunk_size = chunk_sizes.pop()
        job = list(range(n_chans_good))
        while True:
            try:
                channel_chunks = _split_list(job, chunk_size)
                total_chunks = len(channel_chunks)
                for current, chunk in enumerate(channel_chunks, start=1):
                    interp_mats_for_chunk = [mat[chunk, :] for mat in interp_mats]
                    channel_correlations[:, good_idx[chunk]] = _ransac_by_channel(
                        data_good,
                        interp_mats_for_chunk,
                        win_size,
                        win_count,
                        chunk,
                        random_ch_picks,
                        matlab_strict,
                    )
                    if current == 1:
                        # The first chunk fit in RAM, so this chunk size is used
                        logger.info("Finding optimal chunk size : %s", chunk_size)
                        logger.info("Total # of chunks: %s", total_chunks)
                        logger.info("Current chunk:")
                    logger.info(current)
                break  # All chunks processed, hurray!
            except MemoryError:
                if not chunk_sizes:  # pragma: no cover
                    raise MemoryError(
                        "Not even doing 1 channel at a time the data fits in ram..."
                        "You could downsample the data or reduce the number of requ"
                        "ested samples."
                    )
                chunk_size = chunk_sizes.pop()

    # Fraction of below-threshold RANSAC windows for each channel
    thresholded_correlations = channel_correlations < corr_thresh
    frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

    # Find the corresponding channel names and return
    bad_ransac_channels_idx = np.flatnonzero(frac_bad_corr_windows > frac_bad)
    bad_by_ransac = list(complete_chn_labs[bad_ransac_channels_idx])

    logger.info("\nRANSAC done!")
    return bad_by_ransac, channel_correlations


def _get_ransac_chunk_sizes(n_chans, max_chunk_size):
    """Return candidate channel-wise RANSAC chunk sizes in ascending order.

    For each distinct number of chunks needed to cover `n_chans` channels,
    only the smallest chunk size yielding that count is kept. Sizes above
    `max_chunk_size` are skipped unless `max_chunk_size` is falsy (``None``
    or ``0``), in which case every size is allowed.
    """
    chunk_sizes = []
    chunk_count = 0
    for size in range(1, n_chans + 1):
        n_chunks = int(np.ceil(n_chans / size))
        if n_chunks != chunk_count:
            chunk_count = n_chunks
            if not max_chunk_size or size <= max_chunk_size:
                chunk_sizes.append(size)
    return chunk_sizes
def _make_interpolation_matrices(random_ch_picks, chn_pos_good): """Create an interpolation matrix for each RANSAC sample of channels. This function takes the spatial coordinates of random subsets of currently-good channels and uses them to predict what the signal will be at the spatial coordinates of all other currently-good channels. The results of this process are returned as matrices that can be multiplied with EEG data to generate predicted signals. Parameters ---------- random_ch_picks : list of list of int A list containing multiple random subsets of currently-good channels. chn_pos_good : np.ndarray 3-D spatial coordinates of all currently-good channels. Returns ------- interpolation_mats : list of np.ndarray A list of interpolation matrices, one for each random subset of channels. Each matrix has the shape `[num_good_channels, num_good_channels]`, with the number of good channels being inferred from the size of `ch_pos_good`. Notes ----- This function currently makes use of a private MNE function, ``mne.channels.interpolation._make_interpolation_matrix``, to generate matrices. """ n_chans_good = chn_pos_good.shape[0] interpolation_mats = [] for sample in random_ch_picks: mat = np.zeros((n_chans_good, n_chans_good)) subset_pos = chn_pos_good[sample, :] mat[:, sample] = _make_interpolation_matrix(subset_pos, chn_pos_good) interpolation_mats.append(mat) return interpolation_mats def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_strict): """Calculate correlations of channels with their RANSAC-predicted values. This function calculates RANSAC correlations for each RANSAC window individually, requiring RAM equivalent to [channels * sample rate * seconds per RANSAC window] to run. Generally, this method will use less RAM than :func:`_ransac_by_channel`, with the exception of short recordings with high electrode counts. Parameters ---------- data : np.ndarray A 2-D array containing the EEG signals from all currently-good channels. 
interpolation_mats : list of np.ndarray A list of interpolation matrices, one for each RANSAC sample of channels. win_size : int Number of frames/samples of EEG data in each RANSAC correlation window. win_count: int Number of RANSAC correlation windows. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. Returns ------- correlations : np.ndarray Correlations of the given channels to their predicted values within each RANSAC window. """ ch_count = data.shape[0] correlations = np.ones((win_count, ch_count)) pb = ProgressBar(range(win_count)) for window in pb: # Get the current window of EEG data start = window * win_size end = (window + 1) * win_size actual = data[:, start:end] # Get the median RANSAC-predicted signal for each channel predicted = _predict_median_signals(actual, interpolation_mats, matlab_strict) # Calculate the actual vs predicted signal correlation for each channel correlations[window, :] = _correlate_arrays(actual, predicted, matlab_strict) return correlations def _predict_median_signals(window, interpolation_mats, matlab_strict=False): """Calculate the median RANSAC-predicted signal for a given window of data. Parameters ---------- window : np.ndarray A 2-D window of EEG data with the shape `[channels, samples]`. interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. Returns ------- predicted : np.ndarray The median RANSAC-predicted EEG signal for the given window of data. Notes ----- In MATLAB PREP, the median signal is calculated by sorting the different predictions for each EEG sample/channel from low to high and then taking the value at the middle index (as calculated by ``int(n_ransac_samples / 2.0)``) for each. 
Because this logic only returns the correct result for odd numbers of samples, the current function will instead return the true median signal across predictions unless strict MATLAB equivalence is requested. """ ransac_samples = len(interpolation_mats) merged_mats = np.concatenate(interpolation_mats, axis=0) predictions_per_sample = np.reshape( np.matmul(merged_mats, window), (ransac_samples, window.shape[0], window.shape[1]), ) if matlab_strict: # Match MATLAB's rounding logic (.5 always rounded up) median_idx = int(_mat_round(ransac_samples / 2.0) - 1) predictions_per_sample.sort(axis=0) return predictions_per_sample[median_idx, :, :] else: return np.median(predictions_per_sample, axis=0) def _ransac_by_channel( data, interpolation_mats, win_size, win_count, chans_to_predict, random_ch_picks, matlab_strict, ): """Calculate correlations of channels with their RANSAC-predicted values. This function calculates RANSAC correlations on one (or more) full channels at once, requiring RAM equivalent to [channels per chunk * sample rate * length of recording in seconds] to run. Generally, this method will use more RAM than :func:`_ransac_by_window`, but may be faster for systems with large amounts of RAM. Parameters ---------- data : np.ndarray A 2-D array containing the EEG signals from all currently-good channels. interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. win_size : int Number of frames/samples of EEG data in each RANSAC correlation window. win_count: int Number of RANSAC correlation windows. chans_to_predict : list of int Indices of the channels to predict (as they appear in `data`) within the current chunk. random_ch_picks : list of list of int A list containing multiple random subsets of currently-good channels. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. 
Returns ------- correlations : np.ndarray Correlations of the given channels to their predicted values within each RANSAC window. """ # Preallocate RANSAC correlation matrix for current chunk chunk_size = len(chans_to_predict) correlations = np.ones((win_count, chunk_size)) # Get median RANSAC predictions for each channel in the current chunk predicted_chans = _predict_median_signals_channelwise( data=data, interpolation_mats=interpolation_mats, random_ch_picks=random_ch_picks, chunk_size=len(chans_to_predict), matlab_strict=matlab_strict, ) # Correlate ransac prediction and eeg data # For the actual data data_window = data[chans_to_predict, : win_size * win_count] data_window = data_window.reshape(chunk_size, win_count, win_size) data_window = data_window.swapaxes(1, 0) # For the ransac predicted eeg pred_window = predicted_chans[:chunk_size, : win_size * win_count] pred_window = pred_window.reshape(chunk_size, win_count, win_size) pred_window = pred_window.swapaxes(1, 0) # Perform correlations for k in range(win_count): data_portion = data_window[k, :, :] pred_portion = pred_window[k, :, :] R = _correlate_arrays(data_portion, pred_portion, matlab_strict) correlations[k, :] = R return correlations def _predict_median_signals_channelwise( data, interpolation_mats, random_ch_picks, chunk_size, matlab_strict, ): """Calculate the median RANSAC-predicted signal for a given chunk of channels. Parameters ---------- data : np.ndarray A 2-D array containing the EEG signals from all currently-good channels. interpolation_mats : list of np.ndarray A set of channel interpolation matrices, one for each RANSAC sample of channels. random_ch_picks : list of list of int A list containing multiple random subsets of currently-good channels. chunk_size : int The number of channels to predict in the current chunk. matlab_strict : bool Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. 
Returns ------- predicted_chans : np.ndarray The median RANSAC-predicted EEG signals for the given chunk of channels. """ # n_chns, n_timepts = data.shape # 2 next lines should be equivalent but support single channel processing ransac_samples = len(interpolation_mats) n_timepts = data.shape[1] # Before running, make sure we have enough memory _verify_free_ram(data, ransac_samples, chunk_size) # Memory seems to be fine ... # Make the predictions eeg_predictions = np.zeros((chunk_size, n_timepts, ransac_samples)) for sample in range(ransac_samples): # Get the random channels & interpolation matrix for the current sample reconstr_idx = random_ch_picks[sample] interp_mat = interpolation_mats[sample][:, reconstr_idx] # Predict the EEG signals for the current RANSAC sample / channel chunk eeg_predictions[..., sample] = np.matmul(interp_mat, data[reconstr_idx, :]) # Form median from all predictions if matlab_strict: # Match MATLAB's rounding logic (.5 always rounded up) median_idx = int(_mat_round(ransac_samples / 2.0) - 1) eeg_predictions.sort(axis=-1) return eeg_predictions[:, :, median_idx] else: return np.median(eeg_predictions, axis=-1, overwrite_input=True)