Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,6 @@ def get_average_signal_window(
avg_activity = avg_activity.rename(columns={data_column: output_col})

# Merge on 'trial'
df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')
df_trials = nwb.df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want this, because above on 364/365 we do some filtering on nwb.df_trials

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the filtering is so that if we are timelocking to an event that doesn't occur in all trials, it will not cause an error.

however, if we timelock to choice, this would mean all ignore trials are automatically dropped.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we might be timelocking to other events. This line would then make all the previous filtering get ignored

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previous filtering is just for getting the averages. by merging onto the original nwb, we make sure we don't discard any data, just that the average activity will be set to nan.


return df_trials
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot_foraging_session_nwb(nwb, **kwargs):

if "side_bias" not in nwb.df_trials:
fig, axes = plot_foraging_session(
[np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values],

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by this. The column name that is in the NWB is "animal_response". I agree choice would be a better name, but thats not how its named in the file.

Are you sure you aren't renaming animal_response somewhere?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me double check

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

found the rename. it was when i added the foraging model. i'll fix that.

[np.nan if x == 2 else x for x in nwb.df_trials["choice"].values],
nwb.df_trials["earned_reward"].values,
[nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]],
**kwargs,
Expand All @@ -44,7 +44,7 @@ def plot_foraging_session_nwb(nwb, **kwargs):
if "plot_list" not in kwargs:
kwargs["plot_list"] = ["choice", "reward_prob", "bias"]
fig, axes = plot_foraging_session(
[np.nan if x == 2 else x for x in nwb.df_trials["animal_response"].values],
[np.nan if x == 2 else x for x in nwb.df_trials["choice"].values],
nwb.df_trials["earned_reward"].values,
[nwb.df_trials["reward_probabilityL"], nwb.df_trials["reward_probabilityR"]],
bias=nwb.df_trials["side_bias"].values,
Expand All @@ -66,12 +66,13 @@ def plot_foraging_session_nwb(nwb, **kwargs):
0,
1.05,
f"{nwb.session_id}\n"
f'Total trials {len(nwb.df_trials)}, ignored {np.sum(nwb.df_trials["animal_response"]==2)},'
f' left {np.sum(nwb.df_trials["animal_response"] == 0)},'
f' right {np.sum(nwb.df_trials["animal_response"] == 1)}',
f'Total trials {len(nwb.df_trials)}, ignored {np.sum(nwb.df_trials["choice"]==2)},'
f' left {np.sum(nwb.df_trials["choice"] == 0)},'
f' right {np.sum(nwb.df_trials["choice"] == 1)}',
fontsize=8,
transform=axes[0].transAxes,
)
return fig, axes


def plot_foraging_session( # noqa: C901
Expand Down
Loading