From 7d184825d863b409c2f6dfe5fe5bb6d38a4c6656 Mon Sep 17 00:00:00 2001 From: rachelstephlee Date: Mon, 30 Mar 2026 22:19:34 +0000 Subject: [PATCH] fixed bug that accidentally filtered out ignore trials --- .../metrics/trial_metrics.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py index 5e04497..e4cd53b 100644 --- a/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py +++ b/src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py @@ -1,7 +1,7 @@ """ - Tools for computing trial by trial metrics - df_trials = compute_trial_metrics(nwb) - df_trials = compute_bias(nwb) +Tools for computing trial by trial metrics +df_trials = compute_trial_metrics(nwb) +df_trials = compute_bias(nwb) """ @@ -239,13 +239,7 @@ def add_intertrial_licking(df_trials, df_licks): def get_average_signal_window_multi( - nwbs, - alignment_event, - offsets, - channel, - data_column='data_z', - censor=True, - output_col=None + nwbs, alignment_event, offsets, channel, data_column="data_z", censor=True, output_col=None ): """ Wrapper for get_average_signal_window to process a @@ -280,7 +274,7 @@ def get_average_signal_window_multi( channel=channel, data_column=data_column, censor=censor, - output_col=output_col + output_col=output_col, ) nwb.df_trials = df_trials return nwbs @@ -291,7 +285,7 @@ def get_average_signal_window( alignment_event, offsets, channel, - data_column='data_z', + data_column="data_z", censor=True, output_col=None, ): @@ -331,7 +325,7 @@ def get_average_signal_window( """ # Check alignment_event ends with 'in_session' - if not alignment_event.endswith('in_session'): + if not alignment_event.endswith("in_session"): raise ValueError(f"alignment_event '{alignment_event}' must end with 'in_session'.") if not hasattr(nwb, "df_trials"): @@ -380,10 +374,10 @@ def get_average_signal_window( ) avg_activity = etr.groupby("event_number").mean() - avg_activity['trial'] = df_trials.trial.values + avg_activity["trial"] = df_trials.trial.values avg_activity = avg_activity.rename(columns={data_column: output_col}) # Merge on 'trial' - df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left') + df_trials = nwb.df_trials.merge(avg_activity[["trial", output_col]], on="trial", how="left") return df_trials