diff --git a/src/cedalion/sigproc/epochs.py b/src/cedalion/sigproc/epochs.py index c4bed40a..d55475b7 100644 --- a/src/cedalion/sigproc/epochs.py +++ b/src/cedalion/sigproc/epochs.py @@ -23,6 +23,7 @@ def to_epochs( trial_types: list[str], before: cdt.QTime, after: cdt.QTime, + exclude_trial_types: list[str] | None = None, ): """Extract epochs from the time series based on stimulus events. @@ -32,6 +33,7 @@ def to_epochs( trial_types: List of trial types to include in the epochs. before: Time before stimulus event to include in epoch. after: Time after stimulus event to include in epoch. + exclude_trial_types: Exclude epochs containing any of the events in this list. Returns: xarray.DataArray: Array containing the extracted epochs. @@ -48,6 +50,14 @@ def to_epochs( if trial_type not in available_trial_types: raise ValueError(f"df_stim does not contain trial_type '{trial_type}'") + before = before.to("s").magnitude.item() + after = after.to("s").magnitude.item() + fs = sampling_rate(ts).to("Hz") + + # exclude events if necessary + if exclude_trial_types: + df_stim = exclude_events(df_stim, exclude_trial_types, before, after) + # reduce df_stim to only the selected trial types df_stim = df_stim[df_stim.trial_type.isin(trial_types)] @@ -58,10 +68,6 @@ def to_epochs( # assume time coords are already in seconds time = ts.time.values - before = before.to("s").magnitude.item() - after = after.to("s").magnitude.item() - fs = sampling_rate(ts).to("Hz") - # the time stamps of the sampled time series and the events can have different # precision. Be explicit about how timestamps are assigned to samples in ts. # For samples i-1, i , i+1 in ts with timestamps t[i-1], t[i], t[i+1] we say @@ -152,3 +158,39 @@ def to_epochs( epochs = epochs.pint.quantify(units) return epochs + + +def exclude_events( + df_stim: pd.DataFrame, exclude: list[str], before: float, after: float +) -> pd.DataFrame: + """Exclude marked events or events that contain marked events within their epoch. + + An event is excluded if: + 1. It's 'trial_type' is in the `exclude` list. + 2. Contains an event inside its time window that is marked for exclusion. + + Args: + df_stim: DataFrame containing stimulus events. + exclude: List of trial type labels to mark for exclusion. + before: Time duration before the stimulus onset to include in the window. + after: Time duration after the stimulus onset to include in the window. + + Returns: + Updated dataframe with only included events. + """ + exc_idx = [] + for idx, onset, *_, trial_type in df_stim.itertuples(): + # if event is marked for exclusion, add to list and go to next iteration + if trial_type in exclude: + exc_idx.append(idx) + continue + + # get events whose onset is included in the event's time span + times = onset - before, onset + after + next_events = df_stim[df_stim.onset.between(*times)] + + # if any of next_events is marked for exclusion, mark this even for exclusion + if any(ne in exclude for ne in next_events.trial_type): + exc_idx.append(idx) + + return df_stim[~df_stim.index.isin(exc_idx)] diff --git a/tests/test_sigproc_epochs.py b/tests/test_sigproc_epochs.py index fe49de01..340d6ed0 100644 --- a/tests/test_sigproc_epochs.py +++ b/tests/test_sigproc_epochs.py @@ -6,7 +6,7 @@ import cedalion.dataclasses as cdc import cedalion.datasets -from cedalion.sigproc.epochs import to_epochs +from cedalion.sigproc.epochs import to_epochs, exclude_events from cedalion import units @@ -341,3 +341,99 @@ def test_to_epochs_dimension_independence(): ts_vertex = ts_vertex_chromo[:, :, 0] to_epochs(ts_vertex, **kwargs) + + +def test_exclude_events(): + df_stim = pd.DataFrame( + { + "onset": [0.5, 1.3, 2.6, 4.1, 4.5, 4.7], + "duration": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "value": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "trial_type": ["A", "E1", "A", "E2", "B", "E1"], + } + ) + + df_new = exclude_events(df_stim, ["E1"], 0.3, 1.0) + assert df_new.shape[0] == 1 + assert all(df_new.trial_type == "A") + + df_new = exclude_events(df_stim, ["E2"], 0.3, 1.0) + assert df_new.shape[0] == 5 + assert all(df_new.trial_type == ["A", "E1", "A", "B", "E1"]) + + df_new = exclude_events(df_stim, ["E1", "E2"], 0.3, 1.0) + assert df_new.shape[0] == 1 + assert all(df_new.trial_type == ["A"]) + + +def test_to_epochs_exclusion(timeseries): + """Trial types marked for exclusion are excluded in stimulus dataframe.""" + + df_stim = pd.DataFrame( + { + "onset": [0.5, 1.3, 2.6, 4.1, 4.5, 4.7], + "duration": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "value": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "trial_type": ["A", "E1", "A", "E2", "B", "E1"], + } + ) + + # no exclusion + epochs = to_epochs( + timeseries, df_stim, ["A", "B"], before=0.3 * units.s, after=1 * units.s + ) + + assert epochs.sizes["epoch"] == 3 + assert all(epochs.trial_type == ["A", "A", "B"]) + + # empty exclusion + epochs = to_epochs( + timeseries, + df_stim, + ["A", "B"], + before=0.3 * units.s, + after=1 * units.s, + exclude_trial_types=[], + ) + + assert epochs.sizes["epoch"] == 3 + assert all(epochs.trial_type == ["A", "A", "B"]) + + # exclude E1 + epochs = to_epochs( + timeseries, + df_stim, + ["A", "B"], + before=0.3 * units.s, + after=1 * units.s, + exclude_trial_types=["E1"], + ) + + assert epochs.sizes["epoch"] == 1 + assert all(epochs.trial_type == ["A"]) + + # exclude E2 + epochs = to_epochs( + timeseries, + df_stim, + ["A", "B"], + before=0.3 * units.s, + after=1 * units.s, + exclude_trial_types=["E2"], + ) + + assert epochs.sizes["epoch"] == 3 + assert all(epochs.trial_type == ["A", "A", "B"]) + + # exclude E1 and E2 + epochs = to_epochs( + timeseries, + df_stim, + ["A", "B"], + before=0.3 * units.s, + after=1 * units.s, + exclude_trial_types=["E1", "E2"], + ) + + assert epochs.sizes["epoch"] == 1 + assert all(epochs.trial_type == ["A"])