From ea3c446b0c14b5772f91e78f4de8a0454c0282e6 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 21:41:59 +0000 Subject: [PATCH 01/12] feat(plot): add trial-based and complete time-based plotly session plots Two plotly counterparts of the matplotlib plotting functions, each built to match its sibling as closely as plotly allows: - plot_foraging_session_plotly (new): trial-based, mirrors plot_foraging_session (choice/reward raster over the per-trial reward-probability schedule). - plot_session_in_time_plotly (rewritten): time-based, mirrors plot_session_scroller. Adopts the scroller's y-layout, adds the reward-probability band (from df_trials), groups the four behavior rows contiguously so a single go-cue line spans them, and starts zoomed to a 120 s window at the first go cue with a rangeslider for scrubbing. Filled traces use go.Scatter (Scattergl ignores fill="tonexty"). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 604 +++++++++++++----- 1 file changed, 444 insertions(+), 160 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index e5efc99..506e3e5 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -1,199 +1,483 @@ """Interactive plotly figures that can be used in the Streamlit app as well as -Jupyter Notebook +Jupyter Notebook. + +Two plotly counterparts of the matplotlib plotting functions are provided, each meant to +match its matplotlib sibling as closely as plotly allows: + +- :func:`plot_foraging_session_plotly` mirrors + :func:`plot_foraging_session.plot_foraging_session` -- a **trial-based** view (choice / + reward raster on top, reward-probability schedule below). +- :func:`plot_session_in_time_plotly` mirrors + :func:`plot_session_scroller.plot_session_scroller` -- a **time-based** view of the + session (licks / rewards / go cues / reward-probability band in real time). """ +import numpy as np import plotly.graph_objects as go +from plotly.subplots import make_subplots +from aind_dynamic_foraging_basic_analysis.data_model.foraging_session import ( + ForagingSessionData, + PhotostimData, +) +from aind_dynamic_foraging_basic_analysis.plot.plot_foraging_session import moving_average +from aind_dynamic_foraging_basic_analysis.plot.style import PHOTOSTIM_EPOCH_MAPPING -def plot_session_in_time_plotly( # noqa: C901 pragma: no cover - df_events, adjust_time=True, fip_df=None +# Map the matplotlib single-letter colors used by the matplotlib versions to plotly names, +# so the two renderings line up. Anything not listed is passed through unchanged. +_MPL_COLORS = { + "y": "gold", + "m": "magenta", + "g": "green", + "b": "blue", + "r": "red", + "k": "black", +} + + +def _color(c): + """Translate a matplotlib single-letter color to a plotly-friendly name.""" + return _MPL_COLORS.get(c, c) + + +def _vlines(segments): + """Flatten ``[(x_array, y0, y1), ...]`` into x / y arrays of ``None``-separated segments. + + The standard plotly trick for drawing many vertical line ticks in a single trace: insert + ``None`` between each ``(x, y0)->(x, y1)`` pair so plotly lifts the pen between ticks. + """ + xs, ys = [], [] + for x_arr, y0, y1 in segments: + for xi in np.asarray(x_arr): + xs += [xi, xi, None] + ys += [y0, y1, None] + return xs, ys + + +def plot_foraging_session_plotly( # noqa: C901 + choice_history, + reward_history, + p_reward, + autowater_offered=None, + fitted_data=None, + photostim=None, + valid_range=None, + smooth_factor=5, + base_color="y", + bias=None, + bias_lower=None, + bias_upper=None, + plot_list=["choice", "finished", "reward_prob"], ): - """A plotly version of plot_foraging_session.plot_session_scroller + """Plotly version of :func:`plot_foraging_session.plot_foraging_session` (trial-based). + + Renders the same two stacked panels as the matplotlib version: + + - top: choice / reward raster (rewarded & unrewarded choices, ignored / autowater + trials, smoothed choice, finished ratio, base reward probability, optional bias). + - bottom: the per-trial left / right reward-probability schedule. + + Parameters mirror :func:`plot_foraging_session.plot_foraging_session` (minus the + matplotlib-only ``ax`` / ``vertical``): + + Parameters + ---------- + choice_history : list or np.ndarray + Choice history (0 = left, 1 = right, np.nan = ignored). + reward_history : list or np.ndarray + Reward history (0 = unrewarded, 1 = rewarded). + p_reward : list or np.ndarray + Reward probability for both sides, shape (2, len(choice_history)). + autowater_offered : list or np.ndarray, optional + Boolean mask of trials where autowater was offered. + fitted_data : list or np.ndarray, optional + If not None, overlay fitted data (e.g. from an RL model). + photostim : dict, optional + Photostimulation trials, with keys "trial", "power" and optional "stim_epoch". + valid_range : list, optional + If not None, add two vertical lines marking the engaged range. + smooth_factor : int, optional + Smoothing window for the choice / finished traces, by default 5. + base_color : str, optional + Color for the base reward-probability line, by default "y" (gold). + bias : list or np.ndarray, optional + Side-bias trace; drawn (with the ``bias_lower`` / ``bias_upper`` band) when + "bias" is in ``plot_list``. + bias_lower, bias_upper : list or np.ndarray, optional + Lower / upper confidence bounds for ``bias``. + plot_list : list, optional + Which optional traces to draw, from {"choice", "finished", "reward_prob", "bias"}. + + Returns + ------- + plotly.graph_objects.Figure + Figure with the choice/reward panel (row 1) over the reward-schedule panel (row 2). + """ + # Formatting and sanity checks (reuse the shared validation, like the matplotlib version) + data = ForagingSessionData( + choice_history=choice_history, + reward_history=reward_history, + p_reward=p_reward, + autowater_offered=autowater_offered, + fitted_data=fitted_data, + photostim=PhotostimData(**photostim) if photostim is not None else None, + ) + choice_history = data.choice_history + reward_history = data.reward_history + p_reward = data.p_reward + autowater_offered = data.autowater_offered + fitted_data = data.fitted_data + photostim = data.photostim + + n_trials = len(choice_history) + p_reward_fraction = p_reward[1, :] / (np.sum(p_reward, axis=0)) + ignored = np.isnan(choice_history) + + if autowater_offered is None: + rewarded_excluding_autowater = reward_history + autowater_collected = np.full_like(choice_history, False, dtype=bool) + autowater_ignored = np.full_like(choice_history, False, dtype=bool) + unrewarded_trials = ~reward_history & ~ignored + else: + rewarded_excluding_autowater = reward_history & ~autowater_offered + autowater_collected = autowater_offered & ~ignored + autowater_ignored = autowater_offered & ignored + unrewarded_trials = ~reward_history & ~ignored & ~autowater_offered + + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + row_heights=[0.83, 0.17], + vertical_spacing=0.02, + ) + + def _side_segments(mask, up, down): + """Split a boolean trial mask into right (>0.5) and left (<0.5) tick segments. + + ``up`` is ``(y0, y1)`` for right choices (drawn just above Right); ``down`` for + left choices (just below Left). Returns the list consumed by :func:`_vlines`. + """ + xx = np.nonzero(mask)[0] + 1 + side = choice_history[mask] + return [ + (xx[side > 0.5], up[0], up[1]), + (xx[side < 0.5], down[0], down[1]), + ] + + # == Choice trace == + # Rewarded (real foraging, autowater excluded): tall black ticks just outside [0, 1] + xs, ys = _vlines(_side_segments(rewarded_excluding_autowater, (1.05, 1.15), (-0.15, -0.05))) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=1), + name="Rewarded choices"), + row=1, col=1, + ) + + # Unrewarded (real foraging): short gray ticks + xs, ys = _vlines(_side_segments(unrewarded_trials, (1.05, 1.10), (-0.10, -0.05))) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="gray", width=1), + name="Unrewarded choices"), + row=1, col=1, + ) + + # Ignored trials: red x at the top + xx = np.nonzero(ignored & ~autowater_ignored)[0] + 1 + fig.add_trace( + go.Scattergl(x=xx, y=[1.2] * len(xx), mode="markers", + marker=dict(symbol="x", color="red", size=4), name="Ignored"), + row=1, col=1, + ) + + # Autowater collected / ignored + if autowater_offered is not None: + xs, ys = _vlines(_side_segments(autowater_collected, (1.05, 1.15), (-0.15, -0.05))) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="royalblue", width=1), + name="Autowater collected"), + row=1, col=1, + ) + xx = np.nonzero(autowater_ignored)[0] + 1 + fig.add_trace( + go.Scattergl(x=xx, y=[1.2] * len(xx), mode="markers", + marker=dict(symbol="x", color="royalblue", size=4), + name="Autowater ignored"), + row=1, col=1, + ) + + # Base reward probability + if "reward_prob" in plot_list: + fig.add_trace( + go.Scattergl(x=np.arange(n_trials) + 1, y=p_reward_fraction, mode="lines", + line=dict(color=_color(base_color), width=1.5), name="Base rew. prob."), + row=1, col=1, + ) + + # Smoothed choice history + if "choice" in plot_list: + y = moving_average(choice_history, smooth_factor) / ( + moving_average(~np.isnan(choice_history), smooth_factor) + 1e-6 + ) + y[y > 100] = np.nan + x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 + fig.add_trace( + go.Scattergl(x=x, y=y, mode="lines", line=dict(color="black", width=1.5), + name=f"Choice (smooth = {smooth_factor})"), + row=1, col=1, + ) + + # Finished ratio (only meaningful if there are ignored trials) + if "finished" in plot_list and np.sum(np.isnan(choice_history)): + y = moving_average(~np.isnan(choice_history), smooth_factor) + x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 + fig.add_trace( + go.Scattergl(x=x, y=y, mode="lines", line=dict(color="magenta", width=0.8), + name=f"Finished (smooth = {smooth_factor})"), + row=1, col=1, + ) - Creates a plot of the session in time (not in trial). - Plots left/right licks/rewards, and go cues as vertical lines from bottom to top. + # Bias trace + confidence band + if ("bias" in plot_list) and (bias is not None): + xx = np.arange(n_trials) + 1 + bias = (np.array(bias) + 1) / 2 + bias_lower = np.clip((np.array(bias_lower) + 1) / 2, 0, None) + bias_upper = np.clip((np.array(bias_upper) + 1) / 2, None, 1) + # go.Scatter (not Scattergl) for the filled band -- Scattergl ignores fill. + fig.add_trace( + go.Scatter(x=xx, y=bias_upper, mode="lines", line=dict(width=0), + showlegend=False, hoverinfo="skip"), + row=1, col=1, + ) + fig.add_trace( + go.Scatter(x=xx, y=bias_lower, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(0,128,0,0.25)", + showlegend=False, hoverinfo="skip"), + row=1, col=1, + ) + fig.add_trace( + go.Scattergl(x=xx, y=bias, mode="lines", line=dict(color="green", width=1.5), + name="bias"), + row=1, col=1, + ) - df_events: A tidy dataframe of session events. + # Valid (engaged) range + if valid_range is not None: + for vr in valid_range: + fig.add_vline(x=vr, line=dict(color="magenta", dash="dash", width=1), row=1, col=1) - fip_df is a tidy dataframe of FIP measurements generated by - aind_dynamic_foraging_data_utils.nwb_utils.create_df_fip(tidy=True) + # Fitted model overlay + if fitted_data is not None: + fig.add_trace( + go.Scattergl(x=np.arange(n_trials), y=fitted_data, mode="lines", + line=dict(width=1.5), name="model"), + row=1, col=1, + ) + + # Photostim markers + if photostim is not None: + trial = np.asarray(photostim.trial) + power = np.asarray(photostim.power, dtype=float) + if photostim.stim_epoch is not None: + colors = [PHOTOSTIM_EPOCH_MAPPING[t] for t in photostim.stim_epoch] + else: + colors = "darkcyan" + fig.add_trace( + go.Scattergl(x=trial, y=np.ones_like(trial, dtype=float) + 0.4, mode="markers", + marker=dict(symbol="triangle-down", size=power * 2, + color="rgba(0,0,0,0)", + line=dict(color=colors, width=0.5)), + name="photostim"), + row=1, col=1, + ) + + # == Reward schedule (bottom panel) == + xx = np.arange(n_trials) + 1 + fig.add_trace( + go.Scattergl(x=xx, y=p_reward[1, :], mode="lines", line=dict(color="blue", width=1), + name="p_right"), + row=2, col=1, + ) + fig.add_trace( + go.Scattergl(x=xx, y=p_reward[0, :], mode="lines", line=dict(color="red", width=1), + name="p_left"), + row=2, col=1, + ) + + # Axes styling to match the matplotlib version + fig.update_yaxes( + tickvals=[0, 1, 1.2], ticktext=["Left", "Right", "Ignored"], + range=[-0.15, 1.25], row=1, col=1, + ) + fig.update_yaxes(title_text="p_reward", range=[0, 1], row=2, col=1) + fig.update_xaxes(title_text="Trial number", row=2, col=1) + fig.update_layout( + width=1300, height=400, template="simple_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), + margin=dict(l=60, r=20, t=60, b=50), + ) + return fig + + +def plot_session_in_time_plotly( # noqa: C901 pragma: no cover + df_events, df_trials=None, fip_df=None, adjust_time=True, session_id=None +): + """Plotly version of :func:`plot_session_scroller.plot_session_scroller` (time-based). - adjust_time (bool): If True, resets time=0 to the first event of the session. + Plots the session in real time (not in trial): left / right licks and rewards as ticks, + go cues as vertical lines, and -- when ``df_trials`` is supplied -- the left / right + reward-probability band, laid out to match the matplotlib scroller. - EXAMPLE: - df_events = nwb_utils.create_df_events(nwb_object) - fip_df = nwb_utils.create_df_fip(nwb_object, tidy=True) - plot_foraging_session_plotly.plot_session_events_plotly(df_events) + Parameters + ---------- + df_events : pandas.DataFrame + Tidy dataframe of session events, e.g. from + ``aind_dynamic_foraging_data_utils.nwb_utils.create_df_events``. Needs ``event`` and + ``timestamps`` columns; recognised events are ``left_lick_time``, ``right_lick_time``, + ``left_reward_delivery_time``, ``right_reward_delivery_time`` and ``goCue_start_time``. + df_trials : pandas.DataFrame, optional + Per-trial dataframe used for the reward-probability band (and as a fallback source of + go-cue times). Needs ``goCue_start_time`` and ``reward_probabilityL/R``. The go-cue + times must share the same time base as ``df_events.timestamps``. + fip_df : pandas.DataFrame, optional + Tidy dataframe of FIP measurements (from ``create_df_fip(tidy=True)``); each present + channel is normalised and stacked above the behavior panel. + adjust_time : bool, optional + If True (default), shift time so the first event is at t = 0. + session_id : str, optional + Title for the figure. + + Returns + ------- + plotly.graph_objects.Figure """ + df_events = df_events.copy() + if df_trials is not None: + df_trials = df_trials.copy() + if fip_df is not None: + fip_df = fip_df.copy() if adjust_time: start_time = df_events.iloc[0]["timestamps"] - df_events = df_events.copy() df_events["timestamps"] = df_events["timestamps"] - start_time - + if df_trials is not None: + df_trials["goCue_start_time"] = df_trials["goCue_start_time"] - start_time if fip_df is not None: - fip_df = fip_df.copy() fip_df["timestamps"] = fip_df["timestamps"] - start_time xmin = df_events.iloc[0]["timestamps"] xmax = df_events.iloc[-1]["timestamps"] + x_first, x_last = xmin, xmax # full extent (used for the rangeslider / "home") + # y-layout. The four behavior rows are stacked contiguously (top -> bottom: right licks, + # right reward, left reward, left licks) so a single go-cue line spans all of them; the + # reward-probability band sits in its own block above, centered on probs_center. params = { - "left_lick": 0.125, - "right_lick": 0.875, - "left_reward": 0.375, - "right_reward": 0.625, - "go_cue_bottom": 0, - "go_cue_top": 1, - "G_1_preprocessed_bottom": 1, - "G_1_preprocessed_top": 2, - "G_2_preprocessed_bottom": 2, - "G_2_preprocessed_top": 3, - "R_1_preprocessed_bottom": 3, - "R_1_preprocessed_top": 4, - "R_2_preprocessed_bottom": 4, - "R_2_preprocessed_top": 5, + "left_lick_bottom": 0.0, "left_lick_top": 0.25, + "left_reward_bottom": 0.25, "left_reward_top": 0.5, + "right_reward_bottom": 0.5, "right_reward_top": 0.75, + "right_lick_bottom": 0.75, "right_lick_top": 1.0, + "behavior_bottom": 0.0, "behavior_top": 1.0, # go cue spans this contiguous block + "probs_center": 1.4, "probs_half": 0.25, # band: probs_center +/- probs_half } - yticks = [ - params["left_lick"], - params["right_lick"], - params["left_reward"], - params["right_reward"], + 0.875, 0.625, 0.375, 0.125, # right licks, right reward, left reward, left licks + params["probs_center"] - params["probs_half"], # pL = 1 + params["probs_center"], # 0 + params["probs_center"] + params["probs_half"], # pR = 1 ] - ylabels = ["left licks", "right licks", "left reward", "right reward"] - ycolors = ["k", "k", "r", "r"] + ylabels = ["right licks", "right reward", "left reward", "left licks", "pL = 1", "0", "pR = 1"] fig = go.Figure() - # Add FIP traces - if fip_df is not None: - fip_channels = [ - "G_2_preprocessed", - "G_1_preprocessed", - "R_2_preprocessed", - "R_1_preprocessed", - ] - present_channels = fip_df["event"].unique() - for index, channel in enumerate(fip_channels): - if channel in present_channels: - yticks.append( - (params[channel + "_top"] - params[channel + "_bottom"]) / 2 - + params[channel + "_bottom"] - ) - ylabels.append(channel) - if "G_1" in channel: - color = "green" - elif "G_2" in channel: - color = "darkgreen" - elif "R_1" in channel: - color = "red" - elif "R_2" in channel: - color = "darkred" - ycolors.append(color) - C = fip_df.query("event == @channel").copy() - C["data"] = C["data"] - C["data"].min() - C["data"] = C["data"].values / C["data"].max() - C["data"] += params[channel + "_bottom"] - - # Plot the data using go.Scattergl - fig.add_trace( - go.Scattergl( - x=C.timestamps.values, - y=C.data.values, - mode="lines", - line=dict(color=color), - name=channel, - ) - ) - - # Add a horizontal reference line (axhline equivalent) - fig.add_trace( - go.Scattergl( - x=[C.timestamps.min(), C.timestamps.max()], - y=[params[channel + "_bottom"], params[channel + "_bottom"]], - mode="lines", - line=dict(color="black", width=1, dash="solid"), - showlegend=False, - hoverinfo="skip", - ) - ) - - left_licks = df_events.query('event == "left_lick_time"') - left_times = left_licks.timestamps.values - fig.add_trace( - go.Scattergl( - x=left_times, - y=[params["left_lick"]] * len(left_times), - mode="markers", - marker=dict(symbol="line-ns", line_color="black", size=10, line_width=2), - name="Left Lick", - ) - ) + def _event_times(name): + return df_events.query("event == @name").timestamps.values - right_licks = df_events.query('event == "right_lick_time"') - right_times = right_licks.timestamps.values - fig.add_trace( - go.Scattergl( - x=right_times, - y=[params["right_lick"]] * len(right_times), - mode="markers", - marker=dict(symbol="line-ns", line_color="black", size=10, line_width=2), - name="Right Lick", - ) - ) + # Licks (gray, like the scroller when no bout coloring) + for name, lo, hi in [ + ("left_lick_time", params["left_lick_bottom"], params["left_lick_top"]), + ("right_lick_time", params["right_lick_bottom"], params["right_lick_top"]), + ]: + t = _event_times(name) + xs, ys = _vlines([(t, lo, hi)]) + fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="gray", width=1.5), + name=name.replace("_time", "").replace("_", " "))) - left_reward_deliverys = df_events.query('event == "left_reward_delivery_time"') - left_times = left_reward_deliverys.timestamps.values - fig.add_trace( - go.Scattergl( - x=left_times, - y=[params["left_reward"]] * len(left_times), - mode="markers", - marker=dict(symbol="line-ns", size=10, line_color="red", line_width=3), - name="Left Reward", - ) - ) + # Rewards (black) + for name, lo, hi in [ + ("left_reward_delivery_time", params["left_reward_bottom"], params["left_reward_top"]), + ("right_reward_delivery_time", params["right_reward_bottom"], params["right_reward_top"]), + ]: + t = _event_times(name) + xs, ys = _vlines([(t, lo, hi)]) + fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=2), + name=name.replace("_delivery_time", "").replace("_", " "))) - right_reward_deliverys = df_events.query('event == "right_reward_delivery_time"') - right_times = right_reward_deliverys.timestamps.values - fig.add_trace( - go.Scattergl( - x=right_times, - y=[params["right_reward"]] * len(right_times), - mode="markers", - marker=dict(symbol="line-ns", size=10, line_color="red", line_width=3), - name="Right Reward", - ) - ) + # Go cues: prefer events, fall back to df_trials + go_cue_times = _event_times("goCue_start_time") + if len(go_cue_times) == 0 and df_trials is not None and "goCue_start_time" in df_trials: + go_cue_times = df_trials["goCue_start_time"].dropna().values + if len(go_cue_times): + # A single line spanning the contiguous behavior block + xs, ys = _vlines([(go_cue_times, params["behavior_bottom"], params["behavior_top"])]) + fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", + line=dict(color="blue", width=0.75), opacity=0.75, + name="go cue")) - go_cues = df_events.query('event == "goCue_start_time"') - go_cue_times = go_cues.timestamps.values - for n, time in enumerate(go_cue_times): - fig.add_trace( - go.Scattergl( - x=[time, time], - y=[params["go_cue_bottom"], params["go_cue_top"]], - mode="lines", - line=dict(color="blue", width=0.3), - legendgroup="Go Cue group", - showlegend=(n == 0), - name="Go Cue", - hovertemplate=f"Go Cue, Trial {n+1}", - ) - ) + # Reward-probability band (needs df_trials and go-cue times) + if df_trials is not None and len(go_cue_times) == len(df_trials): + x_doubled = np.repeat(go_cue_times, 2)[1:] + center = params["probs_center"] + pR = np.repeat(center + df_trials["reward_probabilityR"].values / 4, 2)[:-1] + pL = np.repeat(center - df_trials["reward_probabilityL"].values / 4, 2)[:-1] + base = np.full_like(x_doubled, center, dtype=float) + # pR above center (red), pL below center (blue); fill toward the center baseline. + # go.Scatter (not Scattergl) -- the WebGL trace ignores fill="tonexty". + fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), + showlegend=False, hoverinfo="skip")) + fig.add_trace(go.Scatter(x=x_doubled, y=pR, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(255,0,0,0.4)", + name="pR")) + fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), + showlegend=False, hoverinfo="skip")) + fig.add_trace(go.Scatter(x=x_doubled, y=pL, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(0,0,255,0.4)", + name="pL")) + + y_top = params["probs_center"] + params["probs_half"] # top of the plotted content + # FIP channels, normalised and stacked above the behavior panel + if fip_df is not None: + fip_channels = ["G_1_preprocessed", "G_2_preprocessed", + "R_1_preprocessed", "R_2_preprocessed"] + fip_colors = {"G_1": "green", "G_2": "darkgreen", "R_1": "red", "R_2": "darkred"} + present = set(fip_df["event"].unique()) + band = 0 + for channel in fip_channels: + if channel not in present: + continue + bottom = params["probs_center"] + params["probs_half"] + 0.1 + band + C = fip_df.query("event == @channel").copy() + d = C["data"].values - np.nanmin(C["data"].values) + d = d / np.nanmax(d) + bottom + color = fip_colors["_".join(channel.split("_")[:2])] + fig.add_trace(go.Scattergl(x=C.timestamps.values, y=d, mode="lines", + line=dict(color=color), name=channel)) + yticks.append(bottom + 0.5) + ylabels.append(channel) + band += 1 + y_top = bottom + 1.0 + + # Start zoomed to a readable ~120 s window at the first go cue (like the matplotlib + # scroller's default window), with a rangeslider so the whole session can be scrubbed -- + # the plotly analog of the scroller's arrow-key panning. + t0 = go_cue_times.min() if len(go_cue_times) else x_first fig.update_layout( - title="Session Scroller", + title=session_id or "Session Scroller", xaxis_title="Time (s)", - yaxis=dict( - tickvals=yticks, - ticktext=ylabels, - ), - xaxis=dict(range=[xmin, xmax]), - showlegend=True, - height=800, - width=1300, + yaxis=dict(tickvals=yticks, ticktext=ylabels, + range=[params["behavior_bottom"] - 0.05, y_top + 0.1]), + xaxis=dict(range=[t0, t0 + 120], rangeslider=dict(visible=True, range=[x_first, x_last])), + showlegend=True, height=600, width=1300, template="simple_white", ) - return fig From 147e835de5c56459d2f7a0fdd6ac11a382facf90 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 21:42:07 +0000 Subject: [PATCH 02/12] feat(plot): export plotly session-plot functions from the package root Surface plot_foraging_session_plotly and plot_session_in_time_plotly alongside plot_foraging_session for top-level import. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/aind_dynamic_foraging_basic_analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/aind_dynamic_foraging_basic_analysis/__init__.py b/src/aind_dynamic_foraging_basic_analysis/__init__.py index d420ce8..2842d4d 100644 --- a/src/aind_dynamic_foraging_basic_analysis/__init__.py +++ b/src/aind_dynamic_foraging_basic_analysis/__init__.py @@ -4,3 +4,7 @@ from .metrics.foraging_efficiency import compute_foraging_efficiency # noqa: F401 from .plot.plot_foraging_session import plot_foraging_session # noqa: F401 +from .plot.plot_foraging_session_plotly import ( # noqa: F401 + plot_foraging_session_plotly, + plot_session_in_time_plotly, +) From 7145ac7b35aadf2f9b6ff40727b2fe5b8830349e Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 21:42:07 +0000 Subject: [PATCH 03/12] test(plot): add tests for the plotly session plots Cover the trial-based plot (real session history, plus the optional bias band / photostim path) and the time-based plot (events-only, and the df_trials path that adds the reward-probability band). Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/test_plot_foraging_session_plotly.py | 103 +++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/test_plot_foraging_session_plotly.py diff --git a/tests/test_plot_foraging_session_plotly.py b/tests/test_plot_foraging_session_plotly.py new file mode 100644 index 0000000..40e3277 --- /dev/null +++ b/tests/test_plot_foraging_session_plotly.py @@ -0,0 +1,103 @@ +"""Test the plotly foraging-session plots. + +To run the test, execute "python -m unittest tests/test_plot_foraging_session_plotly.py". +""" + +import os +import unittest + +import numpy as np +import pandas as pd +import plotly.graph_objects as go + +from aind_dynamic_foraging_basic_analysis import ( + plot_foraging_session_plotly, + plot_session_in_time_plotly, +) +from tests.nwb_io import get_history_from_nwb + + +class TestPlotForagingSessionPlotly(unittest.TestCase): + """Test the trial-based plotly plot against a real session.""" + + @classmethod + def setUpClass(cls): + """Load example session history from the bundled NWB.""" + nwb_file = os.path.dirname(__file__) + "/data/697929_2024-02-22_08-38-30.nwb" + ( + _, + cls.choice_history, + cls.reward_history, + cls.p_reward, + cls.autowater_offered, + _, + ) = get_history_from_nwb(nwb_file) + + def test_returns_figure(self): + """A plotly Figure is returned with both panels populated.""" + fig = plot_foraging_session_plotly( + choice_history=self.choice_history, + reward_history=self.reward_history, + p_reward=self.p_reward, + autowater_offered=self.autowater_offered, + ) + self.assertIsInstance(fig, go.Figure) + self.assertGreater(len(fig.data), 0) + + def test_optional_traces(self): + """Bias band and photostim markers are accepted without error.""" + n = len(self.choice_history) + fig = plot_foraging_session_plotly( + choice_history=self.choice_history, + reward_history=self.reward_history, + p_reward=self.p_reward, + bias=np.zeros(n), + bias_lower=-np.ones(n) * 0.2, + bias_upper=np.ones(n) * 0.2, + photostim={"trial": [10, 20], "power": [3.0, 3.0]}, + plot_list=["choice", "finished", "reward_prob", "bias"], + ) + self.assertIsInstance(fig, go.Figure) + + +class TestPlotSessionInTimePlotly(unittest.TestCase): + """Test the time-based plotly plot with a synthetic events / trials frame.""" + + def setUp(self): + """Build a small tidy events frame and matching trials frame.""" + go_cues = np.arange(5, 25, 2.0) # 10 trials + events = [] + for t in go_cues: + events.append((t, "goCue_start_time")) + events.append((t + 0.3, "left_lick_time")) + events.append((t + 0.4, "right_lick_time")) + events.append((t + 0.5, "left_reward_delivery_time")) + self.df_events = pd.DataFrame(events, columns=["timestamps", "event"]).sort_values( + "timestamps" + ) + self.df_trials = pd.DataFrame( + { + "goCue_start_time": go_cues, + "reward_probabilityL": np.linspace(0.1, 0.8, len(go_cues)), + "reward_probabilityR": np.linspace(0.8, 0.1, len(go_cues)), + } + ) + + def test_events_only(self): + """Works with just an events frame (no probability band).""" + fig = plot_session_in_time_plotly(self.df_events) + self.assertIsInstance(fig, go.Figure) + self.assertGreater(len(fig.data), 0) + + def test_with_trials(self): + """Supplying df_trials adds the reward-probability band traces.""" + fig = plot_session_in_time_plotly( + self.df_events, df_trials=self.df_trials, session_id="unit_test" + ) + names = [tr.name for tr in fig.data] + self.assertIn("pR", names) + self.assertIn("pL", names) + + +if __name__ == "__main__": + unittest.main() From 1bdd72bb40496d9820008d0ed03050518406fc66 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 21:52:43 +0000 Subject: [PATCH 04/12] refactor(plot): refine time-based plotly layout and add per-event trial hover - Reward-probability band moved out of the main panel; it now lives only in the rangeslider "scroller" at the bottom, which auto-scales its own y so the band reads at a useful size. - Lick / reward ticks shortened to 30% of the row height (70% shorter), centered on each row, so the four contiguous behavior rows stay legible when busy. - Each go cue / reward / lick carries the trial it falls in (assigned via the go-cue windows) as customdata, surfaced on hover -- no on-plot text labels. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 118 +++++++++++------- 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 506e3e5..a1fdb0f 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -54,6 +54,21 @@ def _vlines(segments): return xs, ys +def _vline_hover(x_arr, y0, y1, hover): + """Vertical ticks at ``x_arr`` (each y0->y1) plus a parallel ``customdata`` array. + + Like :func:`_vlines` for a single group, but also threads a per-tick ``hover`` value + (repeated on both vertices, ``None`` on the gap) so each tick can surface e.g. its trial + number via a ``hovertemplate``. + """ + xs, ys, cd = [], [], [] + for xi, hi in zip(np.asarray(x_arr), hover): + xs += [xi, xi, None] + ys += [y0, y1, None] + cd += [hi, hi, None] + return xs, ys, cd + + def plot_foraging_session_plotly( # noqa: C901 choice_history, reward_history, @@ -370,61 +385,66 @@ def plot_session_in_time_plotly( # noqa: C901 pragma: no cover x_first, x_last = xmin, xmax # full extent (used for the rangeslider / "home") # y-layout. The four behavior rows are stacked contiguously (top -> bottom: right licks, - # right reward, left reward, left licks) so a single go-cue line spans all of them; the - # reward-probability band sits in its own block above, centered on probs_center. + # right reward, left reward, left licks) so a single go-cue line spans all of them. The + # reward-probability band lives well above the behavior block, out of the main view -- it + # only shows in the rangeslider "scroller" at the bottom (which auto-scales its own y). params = { - "left_lick_bottom": 0.0, "left_lick_top": 0.25, - "left_reward_bottom": 0.25, "left_reward_top": 0.5, - "right_reward_bottom": 0.5, "right_reward_top": 0.75, - "right_lick_bottom": 0.75, "right_lick_top": 1.0, "behavior_bottom": 0.0, "behavior_top": 1.0, # go cue spans this contiguous block - "probs_center": 1.4, "probs_half": 0.25, # band: probs_center +/- probs_half + "probs_center": 1.7, "probs_half": 0.25, # band sits high, above the main view } - yticks = [ - 0.875, 0.625, 0.375, 0.125, # right licks, right reward, left reward, left licks - params["probs_center"] - params["probs_half"], # pL = 1 - params["probs_center"], # 0 - params["probs_center"] + params["probs_half"], # pR = 1 - ] - ylabels = ["right licks", "right reward", "left reward", "left licks", "pL = 1", "0", "pR = 1"] + # Row centers (top -> bottom). Event ticks are short marks centered on each row -- 70% + # shorter than the 0.25 row spacing -- so the rows read as separate even when busy. + row_centers = {"right_lick": 0.875, "right_reward": 0.625, + "left_reward": 0.375, "left_lick": 0.125} + tick_half = 0.25 * 0.30 / 2.0 + yticks = [0.875, 0.625, 0.375, 0.125] # right licks, right reward, left reward, left licks + ylabels = ["right licks", "right reward", "left reward", "left licks"] fig = go.Figure() def _event_times(name): return df_events.query("event == @name").timestamps.values - # Licks (gray, like the scroller when no bout coloring) - for name, lo, hi in [ - ("left_lick_time", params["left_lick_bottom"], params["left_lick_top"]), - ("right_lick_time", params["right_lick_bottom"], params["right_lick_top"]), - ]: - t = _event_times(name) - xs, ys = _vlines([(t, lo, hi)]) - fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="gray", width=1.5), - name=name.replace("_time", "").replace("_", " "))) - - # Rewards (black) - for name, lo, hi in [ - ("left_reward_delivery_time", params["left_reward_bottom"], params["left_reward_top"]), - ("right_reward_delivery_time", params["right_reward_bottom"], params["right_reward_top"]), - ]: - t = _event_times(name) - xs, ys = _vlines([(t, lo, hi)]) - fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=2), - name=name.replace("_delivery_time", "").replace("_", " "))) - - # Go cues: prefer events, fall back to df_trials + # Go-cue times define the trial windows (prefer events, fall back to df_trials) so every + # event can be tagged with the trial it falls in. go_cue_times = _event_times("goCue_start_time") if len(go_cue_times) == 0 and df_trials is not None and "goCue_start_time" in df_trials: go_cue_times = df_trials["goCue_start_time"].dropna().values - if len(go_cue_times): - # A single line spanning the contiguous behavior block - xs, ys = _vlines([(go_cue_times, params["behavior_bottom"], params["behavior_top"])]) - fig.add_trace(go.Scattergl(x=xs, y=ys, mode="lines", - line=dict(color="blue", width=0.75), opacity=0.75, - name="go cue")) - # Reward-probability band (needs df_trials and go-cue times) + def _trial_of(times): + """Trial number for each time = how many go cues have started at/before it.""" + if len(go_cue_times) == 0: + return np.zeros(len(times), dtype=int) + return np.searchsorted(go_cue_times, np.asarray(times), side="right") + + # Licks (gray) and rewards (black): short ticks centered on their rows; hover shows trial + for name, key, color, width in [ + ("left_lick_time", "left_lick", "gray", 1.5), + ("right_lick_time", "right_lick", "gray", 1.5), + ("left_reward_delivery_time", "left_reward", "black", 2), + ("right_reward_delivery_time", "right_reward", "black", 2), + ]: + c = row_centers[key] + t = _event_times(name) + label = name.replace("_delivery_time", "").replace("_time", "").replace("_", " ") + xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, _trial_of(t)) + fig.add_trace(go.Scattergl( + x=xs, y=ys, customdata=cd, mode="lines", line=dict(color=color, width=width), + name=label, + hovertemplate="%{x:.2f}s
trial %{customdata}" + label + "", + )) + + if len(go_cue_times): + # A single blue line spanning the contiguous behavior block; hover shows trial + xs, ys, cd = _vline_hover(go_cue_times, params["behavior_bottom"], + params["behavior_top"], np.arange(1, len(go_cue_times) + 1)) + fig.add_trace(go.Scattergl( + x=xs, y=ys, customdata=cd, mode="lines", + line=dict(color="blue", width=0.75), opacity=0.75, name="go cue", + hovertemplate="%{x:.2f}s
trial %{customdata}go cue", + )) + + # Reward-probability band (needs df_trials and go-cue times) -- shown only in the scroller if df_trials is not None and len(go_cue_times) == len(df_trials): x_doubled = np.repeat(go_cue_times, 2)[1:] center = params["probs_center"] @@ -444,7 +464,7 @@ def _event_times(name): fill="tonexty", fillcolor="rgba(0,0,255,0.4)", name="pL")) - y_top = params["probs_center"] + params["probs_half"] # top of the plotted content + y_main_top = params["behavior_top"] # top of the main (non-scroller) view # FIP channels, normalised and stacked above the behavior panel if fip_df is not None: @@ -466,18 +486,20 @@ def _event_times(name): yticks.append(bottom + 0.5) ylabels.append(channel) band += 1 - y_top = bottom + 1.0 + y_main_top = bottom + 1.0 # Start zoomed to a readable ~120 s window at the first go cue (like the matplotlib - # scroller's default window), with a rangeslider so the whole session can be scrubbed -- - # the plotly analog of the scroller's arrow-key panning. + # scroller's default window). The rangeslider scroller below scrubs the whole session and + # auto-scales its own y, so the reward-probability band reads at a useful size there. t0 = go_cue_times.min() if len(go_cue_times) else x_first fig.update_layout( title=session_id or "Session Scroller", xaxis_title="Time (s)", yaxis=dict(tickvals=yticks, ticktext=ylabels, - range=[params["behavior_bottom"] - 0.05, y_top + 0.1]), - xaxis=dict(range=[t0, t0 + 120], rangeslider=dict(visible=True, range=[x_first, x_last])), + range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), + xaxis=dict(range=[t0, t0 + 120], + rangeslider=dict(visible=True, range=[x_first, x_last], + yaxis=dict(rangemode="auto"))), showlegend=True, height=600, width=1300, template="simple_white", ) return fig From 4b94fda6cdcb911d9288ff20856bedc512c12723 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 21:55:20 +0000 Subject: [PATCH 05/12] fix(plot): pin time-based scroller y to the band so it fills the rangeslider Auto-fitting the rangeslider over all rows left the reward-probability band a thin strip at the top (most of the scroller blank). Pin the rangeslider's y-range to the band region when a band is present, so the probability schedule fills the scroller; fall back to auto when there's no band (events-only). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index a1fdb0f..2f53ffd 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -445,7 +445,8 @@ def _trial_of(times): )) # Reward-probability band (needs df_trials and go-cue times) -- shown only in the scroller - if df_trials is not None and len(go_cue_times) == len(df_trials): + has_band = df_trials is not None and len(go_cue_times) == len(df_trials) + if has_band: x_doubled = np.repeat(go_cue_times, 2)[1:] center = params["probs_center"] pR = np.repeat(center + df_trials["reward_probabilityR"].values / 4, 2)[:-1] @@ -489,17 +490,25 @@ def _trial_of(times): y_main_top = bottom + 1.0 # Start zoomed to a readable ~120 s window at the first go cue (like the matplotlib - # scroller's default window). The rangeslider scroller below scrubs the whole session and - # auto-scales its own y, so the reward-probability band reads at a useful size there. + # scroller's default window). The rangeslider scroller below scrubs the whole session; + # when a band is present we pin its y to the band region so the probability schedule + # fills the scroller (rather than auto-fitting all rows, which leaves it a thin strip). t0 = go_cue_times.min() if len(go_cue_times) else x_first + if has_band: + slider_yaxis = dict( + rangemode="fixed", + range=[params["probs_center"] - params["probs_half"], + params["probs_center"] + params["probs_half"]], + ) + else: + slider_yaxis = dict(rangemode="auto") fig.update_layout( title=session_id or "Session Scroller", xaxis_title="Time (s)", yaxis=dict(tickvals=yticks, ticktext=ylabels, range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), xaxis=dict(range=[t0, t0 + 120], - rangeslider=dict(visible=True, range=[x_first, x_last], - yaxis=dict(rangemode="auto"))), + rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)), showlegend=True, height=600, width=1300, template="simple_white", ) return fig From 2a5045a4e0b8f9e4a3dc1ff6ed52ff824bfeb391 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sat, 6 Jun 2026 22:35:08 +0000 Subject: [PATCH 06/12] feat(plot): time-based plotly overlays, stacked layout, scroll-zoom Refine plot_session_in_time_plotly: - Add per-trial smoothed overlays in their own band above the event rows: gold pR/(pL+pR), black solid choice (smooth=5), black dashed lick count (smooth=5). - Ignored trials (animal_response==2) draw a red go-cue line; responded ones green. Each event/go-cue carries its trial number on hover (no on-plot text). - Reorder event rows like the trial figure: rewards at the outer edges, licks inside, right pair grouped at top / left pair at bottom; one go-cue line per trial. - Reward-probability band shown only in the rangeslider scroller, colored left-red / right-blue (trial-consistent), pinned to ~half the scroller bar. - nan-safe time extent (fixes a (0, NaN) rangeslider range from NaN event timestamps). - Lock y (fixedrange) on both plotly figures for horizontal-only (scroll) zoom. - Add a smooth_factor parameter. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 150 +++++++++++++----- 1 file changed, 109 insertions(+), 41 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 2f53ffd..1a01e76 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -322,9 +322,9 @@ def _side_segments(mask, up, down): # Axes styling to match the matplotlib version fig.update_yaxes( tickvals=[0, 1, 1.2], ticktext=["Left", "Right", "Ignored"], - range=[-0.15, 1.25], row=1, col=1, + range=[-0.15, 1.25], fixedrange=True, row=1, col=1, ) - fig.update_yaxes(title_text="p_reward", range=[0, 1], row=2, col=1) + fig.update_yaxes(title_text="p_reward", range=[0, 1], fixedrange=True, row=2, col=1) fig.update_xaxes(title_text="Trial number", row=2, col=1) fig.update_layout( width=1300, height=400, template="simple_white", @@ -335,7 +335,7 @@ def _side_segments(mask, up, down): def plot_session_in_time_plotly( # noqa: C901 pragma: no cover - df_events, df_trials=None, fip_df=None, adjust_time=True, session_id=None + df_events, df_trials=None, fip_df=None, adjust_time=True, session_id=None, smooth_factor=5 ): """Plotly version of :func:`plot_session_scroller.plot_session_scroller` (time-based). @@ -372,33 +372,48 @@ def plot_session_in_time_plotly( # noqa: C901 pragma: no cover if fip_df is not None: fip_df = fip_df.copy() + # nan-safe extent: some events can carry NaN timestamps, and they sort last, so + # iloc[0]/iloc[-1] are not reliable -- use nanmin/nanmax. if adjust_time: - start_time = df_events.iloc[0]["timestamps"] + start_time = np.nanmin(df_events["timestamps"]) df_events["timestamps"] = df_events["timestamps"] - start_time if df_trials is not None: df_trials["goCue_start_time"] = df_trials["goCue_start_time"] - start_time if fip_df is not None: fip_df["timestamps"] = fip_df["timestamps"] - start_time - xmin = df_events.iloc[0]["timestamps"] - xmax = df_events.iloc[-1]["timestamps"] + xmin = np.nanmin(df_events["timestamps"]) + xmax = np.nanmax(df_events["timestamps"]) x_first, x_last = xmin, xmax # full extent (used for the rangeslider / "home") - # y-layout. The four behavior rows are stacked contiguously (top -> bottom: right licks, - # right reward, left reward, left licks) so a single go-cue line spans all of them. The - # reward-probability band lives well above the behavior block, out of the main view -- it - # only shows in the rangeslider "scroller" at the bottom (which auto-scales its own y). + # y-layout, bottom -> top: + # * event rows in [0, 1]: like the trial-based figure, rewards sit at the outer edges + # with licks just inside -- the right pair (reward outer-top, lick inner) grouped near + # the top, the left pair (lick inner, reward outer-bottom) near the bottom. + # * smoothed overlays in their own band [curve_bottom, curve_top] *above* the events. + # * the reward-probability band sits higher still, out of the main view -- it only shows + # in the rangeslider "scroller" (auto-/band-scaled) below. + # One go-cue line per trial spans the event rows. params = { - "behavior_bottom": 0.0, "behavior_top": 1.0, # go cue spans this contiguous block - "probs_center": 1.7, "probs_half": 0.25, # band sits high, above the main view + "behavior_bottom": 0.0, "behavior_top": 1.0, # event ticks + "curve_bottom": 1.1, "curve_top": 2.1, # smoothed overlays, above the events + "probs_center": 2.6, "probs_half": 0.25, # band: scroller only, above main view } - # Row centers (top -> bottom). Event ticks are short marks centered on each row -- 70% - # shorter than the 0.25 row spacing -- so the rows read as separate even when busy. - row_centers = {"right_lick": 0.875, "right_reward": 0.625, - "left_reward": 0.375, "left_lick": 0.125} + # Row centers (top -> bottom): right reward, right lick, left lick, left reward. Event + # ticks are short marks (70% shorter than the row spacing) centered on each row. + row_centers = {"right_reward": 0.92, "right_lick": 0.78, + "left_lick": 0.22, "left_reward": 0.08} tick_half = 0.25 * 0.30 / 2.0 - yticks = [0.875, 0.625, 0.375, 0.125] # right licks, right reward, left reward, left licks - ylabels = ["right licks", "right reward", "left reward", "left licks"] + + def _to_curve(v): + """Map a 0..1 per-trial value into the smoothed-overlay band above the events.""" + span = params["curve_top"] - params["curve_bottom"] + return params["curve_bottom"] + np.asarray(v, dtype=float) * span + + yticks = [0.92, 0.78, 0.22, 0.08, + params["curve_bottom"], (params["curve_bottom"] + params["curve_top"]) / 2, + params["curve_top"]] + ylabels = ["right reward", "right lick", "left lick", "left reward", "0", "0.5", "1"] fig = go.Figure() @@ -410,13 +425,22 @@ def _event_times(name): go_cue_times = _event_times("goCue_start_time") if len(go_cue_times) == 0 and df_trials is not None and "goCue_start_time" in df_trials: go_cue_times = df_trials["goCue_start_time"].dropna().values + n_tr = len(go_cue_times) def _trial_of(times): """Trial number for each time = how many go cues have started at/before it.""" - if len(go_cue_times) == 0: + if n_tr == 0: return np.zeros(len(times), dtype=int) return np.searchsorted(go_cue_times, np.asarray(times), side="right") + # Per-trial choice aligned to the go cues (when df_trials lines up): 0 left, 1 right, + # np.nan = ignored. Used for the red ignored go cues and the smoothed-choice overlay. + aligned = df_trials is not None and len(df_trials) == n_tr and n_tr > 0 + choice = None + if aligned and "animal_response" in df_trials: + choice = df_trials["animal_response"].astype(float).to_numpy().copy() + choice[choice == 2] = np.nan + # Licks (gray) and rewards (black): short ticks centered on their rows; hover shows trial for name, key, color, width in [ ("left_lick_time", "left_lick", "gray", 1.5), @@ -434,15 +458,21 @@ def _trial_of(times): hovertemplate="%{x:.2f}s
trial %{customdata}" + label + "", )) - if len(go_cue_times): - # A single blue line spanning the contiguous behavior block; hover shows trial - xs, ys, cd = _vline_hover(go_cue_times, params["behavior_bottom"], - params["behavior_top"], np.arange(1, len(go_cue_times) + 1)) - fig.add_trace(go.Scattergl( - x=xs, y=ys, customdata=cd, mode="lines", - line=dict(color="blue", width=0.75), opacity=0.75, name="go cue", - hovertemplate="%{x:.2f}s
trial %{customdata}go cue", - )) + if n_tr: + # Go-cue lines spanning the event rows only; ignored trials are drawn red. + trial_no = np.arange(1, n_tr + 1) + ignored = np.isnan(choice) if choice is not None else np.zeros(n_tr, dtype=bool) + for mask, gc_color, gname in [(~ignored, "green", "go cue"), + (ignored, "red", "go cue (ignored)")]: + if not mask.any(): + continue + xs, ys, cd = _vline_hover(go_cue_times[mask], params["behavior_bottom"], + params["behavior_top"], trial_no[mask]) + fig.add_trace(go.Scattergl( + x=xs, y=ys, customdata=cd, mode="lines", + line=dict(color=gc_color, width=0.75), opacity=0.75, name=gname, + hovertemplate="%{x:.2f}s
trial %{customdata}" + gname + "", + )) # Reward-probability band (needs df_trials and go-cue times) -- shown only in the scroller has_band = df_trials is not None and len(go_cue_times) == len(df_trials) @@ -452,20 +482,21 @@ def _trial_of(times): pR = np.repeat(center + df_trials["reward_probabilityR"].values / 4, 2)[:-1] pL = np.repeat(center - df_trials["reward_probabilityL"].values / 4, 2)[:-1] base = np.full_like(x_doubled, center, dtype=float) - # pR above center (red), pL below center (blue); fill toward the center baseline. + # pR above center, pL below center; fill toward the center baseline. Colored to match + # the trial figures: left (pL) red, right (pR) blue. # go.Scatter (not Scattergl) -- the WebGL trace ignores fill="tonexty". fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip")) fig.add_trace(go.Scatter(x=x_doubled, y=pR, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(255,0,0,0.4)", + fill="tonexty", fillcolor="rgba(0,0,255,0.4)", name="pR")) fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip")) fig.add_trace(go.Scatter(x=x_doubled, y=pL, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(0,0,255,0.4)", + fill="tonexty", fillcolor="rgba(255,0,0,0.4)", name="pL")) - y_main_top = params["behavior_top"] # top of the main (non-scroller) view + y_main_top = params["curve_top"] # top of the main (non-scroller) view # FIP channels, normalised and stacked above the behavior panel if fip_df is not None: @@ -489,23 +520,60 @@ def _trial_of(times): band += 1 y_main_top = bottom + 1.0 + # Smoothed per-trial overlays, in their own band above the event rows (0..1 mapped into + # [curve_bottom, curve_top]) and plotted at the go-cue times. Added last so they sit on + # top of the go-cue lines. + if n_tr: + offset = smooth_factor // 2 + + # Reward-probability fraction pR/(pL+pR) -- golden, like the trial-based base color + if aligned and {"reward_probabilityL", "reward_probabilityR"} <= set(df_trials.columns): + pL = df_trials["reward_probabilityL"].to_numpy() + pR = df_trials["reward_probabilityR"].to_numpy() + frac = np.divide(pR, pL + pR, out=np.full(n_tr, np.nan), where=(pL + pR) > 0) + fig.add_trace(go.Scattergl(x=go_cue_times, y=_to_curve(frac), mode="lines", + line=dict(color="gold", width=1.5), name="pR/(pL+pR)")) + + # Smoothed choice (black solid) + if choice is not None: + sm = moving_average(choice, smooth_factor) / ( + moving_average(~np.isnan(choice), smooth_factor) + 1e-6) + sm[sm > 100] = np.nan + xs = go_cue_times[offset: offset + len(sm)] + fig.add_trace(go.Scattergl(x=xs, y=_to_curve(sm), mode="lines", + line=dict(color="black", width=1.5), + name=f"choice (smooth = {smooth_factor})")) + + # Smoothed lick count per trial (black dashed), normalised to [0, 1] + lick_times = np.concatenate([_event_times("left_lick_time"), + _event_times("right_lick_time")]) + if len(lick_times): + counts = np.bincount(_trial_of(lick_times), minlength=n_tr + 1)[1:n_tr + 1] + sm = moving_average(counts.astype(float), smooth_factor) + top = np.nanmax(sm) + if top > 0: + sm = sm / top + xs = go_cue_times[offset: offset + len(sm)] + fig.add_trace(go.Scattergl(x=xs, y=_to_curve(sm), mode="lines", + line=dict(color="black", width=1.2, dash="dash"), + name=f"lick count (smooth = {smooth_factor})")) + # Start zoomed to a readable ~120 s window at the first go cue (like the matplotlib - # scroller's default window). The rangeslider scroller below scrubs the whole session; - # when a band is present we pin its y to the band region so the probability schedule - # fills the scroller (rather than auto-fitting all rows, which leaves it a thin strip). + # scroller's default window). The rangeslider scroller below scrubs the whole session. + # When a band is present, pin the scroller's y to ~2x the band height so the reward- + # probability band fills about half the scroller bar (x-dragging is unaffected); + # otherwise auto-fit. t0 = go_cue_times.min() if len(go_cue_times) else x_first if has_band: - slider_yaxis = dict( - rangemode="fixed", - range=[params["probs_center"] - params["probs_half"], - params["probs_center"] + params["probs_half"]], - ) + half = 2 * params["probs_half"] + slider_yaxis = dict(rangemode="fixed", + range=[params["probs_center"] - half, params["probs_center"] + half]) else: slider_yaxis = dict(rangemode="auto") fig.update_layout( title=session_id or "Session Scroller", xaxis_title="Time (s)", - yaxis=dict(tickvals=yticks, ticktext=ylabels, + yaxis=dict(tickvals=yticks, ticktext=ylabels, fixedrange=True, range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), xaxis=dict(range=[t0, t0 + 120], rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)), From ace216f74e46c10fd379b08123fd98fdfcb6e802 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 01:26:30 +0000 Subject: [PATCH 07/12] =?UTF-8?q?feat(plot):=20trial-based=20plotly=20?= =?UTF-8?q?=E2=80=94=20scroller,=20per-session=20x=20labels,=20(session,?= =?UTF-8?q?=20trial)=20hover?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add a rangeslider "scroller" under the trial-based figure (drag to pan/zoom). - For multiple sessions, x tick labels restart at 0 each session. - Choice-raster ticks and ignored/autowater markers now carry (within-session trial, session) and show it on hover. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 543 +++++++++++------- tests/test_plot_foraging_session_plotly.py | 28 +- 2 files changed, 360 insertions(+), 211 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 1a01e76..7add005 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -69,6 +69,35 @@ def _vline_hover(x_arr, y0, y1, hover): return xs, ys, cd +def _session_segments(session_id, n): + """Contiguous per-session index segments and the boundary indices between them. + + Returns ``(segments, boundaries)`` where ``segments`` is a list of ``(start, end)`` + half-open index ranges (one per session, in order) and ``boundaries`` is the list of + indices at which a new session starts (used to draw the dividing lines). With + ``session_id=None`` the whole input is one segment. + """ + if session_id is None: + return [(0, n)], [] + sid = np.asarray(session_id) + change = list(np.nonzero(sid[1:] != sid[:-1])[0] + 1) + segments = list(zip([0, *change], [*change, n])) + return segments, change + + +def _broken(x, y, segments): + """Concatenate per-session slices of ``x``/``y`` with ``None`` gaps between sessions. + + Inserting a ``None`` at each boundary breaks the line so a continuous trace is not drawn + across the (meaningless) jump from one session's last trial to the next session's first. + """ + xs, ys = [], [] + for s, e in segments: + xs += [*np.asarray(x)[s:e], None] + ys += [*np.asarray(y)[s:e], None] + return xs, ys + + def plot_foraging_session_plotly( # noqa: C901 choice_history, reward_history, @@ -83,6 +112,7 @@ def plot_foraging_session_plotly( # noqa: C901 bias_lower=None, bias_upper=None, plot_list=["choice", "finished", "reward_prob"], + session_id=None, ): """Plotly version of :func:`plot_foraging_session.plot_foraging_session` (trial-based). @@ -122,6 +152,10 @@ def plot_foraging_session_plotly( # noqa: C901 Lower / upper confidence bounds for ``bias``. plot_list : list, optional Which optional traces to draw, from {"choice", "finished", "reward_prob", "bias"}. + session_id : list or np.ndarray, optional + Per-trial session label (same length as ``choice_history``). When given, multiple + sessions are concatenated horizontally along the trial axis in order; smoothed traces + reset at each session and a thick vertical line marks every session boundary. Returns ------- @@ -159,6 +193,21 @@ def plot_foraging_session_plotly( # noqa: C901 autowater_ignored = autowater_offered & ignored unrewarded_trials = ~reward_history & ~ignored & ~autowater_offered + # Per-session segments (multi-session concatenation); single segment when session_id None. + segments, boundaries = _session_segments(session_id, n_trials) + + # Per-trial within-session index (resets to 0 each session) and session label -- used for + # the (session, trial) hover and the per-session x tick labels. + within = np.zeros(n_trials, dtype=int) + sess_label = np.array([""] * n_trials, dtype=object) + sid_arr = None if session_id is None else np.asarray(session_id) + for s, e in segments: + within[s:e] = np.arange(e - s) + if sid_arr is not None: + sess_label[s:e] = sid_arr[s] + + hovertemplate = "trial %%{customdata[0]}
session %%{customdata[1]}%s" + fig = make_subplots( rows=2, cols=1, @@ -167,111 +216,133 @@ def plot_foraging_session_plotly( # noqa: C901 vertical_spacing=0.02, ) - def _side_segments(mask, up, down): - """Split a boolean trial mask into right (>0.5) and left (<0.5) tick segments. + def _raster(mask, up, down): + """Vertical choice ticks for ``mask``, each carrying (within-session trial, session). - ``up`` is ``(y0, y1)`` for right choices (drawn just above Right); ``down`` for - left choices (just below Left). Returns the list consumed by :func:`_vlines`. + Right choices (>0.5) use the ``up`` (y0, y1); left choices use ``down``. Returns + ``x, y, customdata`` lists (``None`` gaps between ticks) for one Scattergl trace. """ - xx = np.nonzero(mask)[0] + 1 - side = choice_history[mask] - return [ - (xx[side > 0.5], up[0], up[1]), - (xx[side < 0.5], down[0], down[1]), - ] + idx = np.nonzero(mask)[0] + side = choice_history[idx] + xs, ys, cd = [], [], [] + for i, sd in zip(idx, side): + y0, y1 = up if sd > 0.5 else down + xs += [i + 1, i + 1, None] + ys += [y0, y1, None] + cd += [(int(within[i]), sess_label[i])] * 2 + [(None, None)] + return xs, ys, cd + + def _markers(mask): + """(x, customdata) for marker traces (ignored / autowater-ignored), with hover.""" + idx = np.nonzero(mask)[0] + return idx + 1, [(int(within[i]), sess_label[i]) for i in idx] # == Choice trace == # Rewarded (real foraging, autowater excluded): tall black ticks just outside [0, 1] - xs, ys = _vlines(_side_segments(rewarded_excluding_autowater, (1.05, 1.15), (-0.15, -0.05))) + xs, ys, cd = _raster(rewarded_excluding_autowater, (1.05, 1.15), (-0.15, -0.05)) fig.add_trace( - go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=1), - name="Rewarded choices"), + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", line=dict(color="black", width=1), + name="Rewarded choices", hovertemplate=hovertemplate % "Rewarded choices"), row=1, col=1, ) # Unrewarded (real foraging): short gray ticks - xs, ys = _vlines(_side_segments(unrewarded_trials, (1.05, 1.10), (-0.10, -0.05))) + xs, ys, cd = _raster(unrewarded_trials, (1.05, 1.10), (-0.10, -0.05)) fig.add_trace( - go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="gray", width=1), - name="Unrewarded choices"), + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", line=dict(color="gray", width=1), + name="Unrewarded choices", hovertemplate=hovertemplate % "Unrewarded choices"), row=1, col=1, ) # Ignored trials: red x at the top - xx = np.nonzero(ignored & ~autowater_ignored)[0] + 1 + xx, cd = _markers(ignored & ~autowater_ignored) fig.add_trace( - go.Scattergl(x=xx, y=[1.2] * len(xx), mode="markers", - marker=dict(symbol="x", color="red", size=4), name="Ignored"), + go.Scattergl(x=xx, y=[1.2] * len(xx), customdata=cd, mode="markers", + marker=dict(symbol="x", color="red", size=4), name="Ignored", + hovertemplate=hovertemplate % "Ignored"), row=1, col=1, ) # Autowater collected / ignored if autowater_offered is not None: - xs, ys = _vlines(_side_segments(autowater_collected, (1.05, 1.15), (-0.15, -0.05))) + xs, ys, cd = _raster(autowater_collected, (1.05, 1.15), (-0.15, -0.05)) fig.add_trace( - go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="royalblue", width=1), - name="Autowater collected"), + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", + line=dict(color="royalblue", width=1), name="Autowater collected", + hovertemplate=hovertemplate % "Autowater collected"), row=1, col=1, ) - xx = np.nonzero(autowater_ignored)[0] + 1 + xx, cd = _markers(autowater_ignored) fig.add_trace( - go.Scattergl(x=xx, y=[1.2] * len(xx), mode="markers", + go.Scattergl(x=xx, y=[1.2] * len(xx), customdata=cd, mode="markers", marker=dict(symbol="x", color="royalblue", size=4), - name="Autowater ignored"), + name="Autowater ignored", + hovertemplate=hovertemplate % "Autowater ignored"), row=1, col=1, ) - # Base reward probability + # Base reward probability (broken at session boundaries) if "reward_prob" in plot_list: + xs, ys = _broken(np.arange(n_trials) + 1, p_reward_fraction, segments) fig.add_trace( - go.Scattergl(x=np.arange(n_trials) + 1, y=p_reward_fraction, mode="lines", + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color=_color(base_color), width=1.5), name="Base rew. prob."), row=1, col=1, ) + def _smoothed_trace(num, den): + """Per-session smoothed series (resets at each session); x in 1-based trial coords.""" + xs, ys = [], [] + for s, e in segments: + y = moving_average(num[s:e], smooth_factor) + if den is not None: + y = y / (moving_average(den[s:e], smooth_factor) + 1e-6) + x = s + np.arange(len(y)) + int(smooth_factor / 2) + 1 + xs += [*x, None] + ys += [*y, None] + return xs, np.where(np.array(ys, dtype=float) > 100, np.nan, ys) + # Smoothed choice history if "choice" in plot_list: - y = moving_average(choice_history, smooth_factor) / ( - moving_average(~np.isnan(choice_history), smooth_factor) + 1e-6 - ) - y[y > 100] = np.nan - x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 + xs, ys = _smoothed_trace(choice_history, ~np.isnan(choice_history)) fig.add_trace( - go.Scattergl(x=x, y=y, mode="lines", line=dict(color="black", width=1.5), + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=1.5), name=f"Choice (smooth = {smooth_factor})"), row=1, col=1, ) # Finished ratio (only meaningful if there are ignored trials) if "finished" in plot_list and np.sum(np.isnan(choice_history)): - y = moving_average(~np.isnan(choice_history), smooth_factor) - x = np.arange(0, len(y)) + int(smooth_factor / 2) + 1 + xs, ys = _smoothed_trace(~np.isnan(choice_history), None) fig.add_trace( - go.Scattergl(x=x, y=y, mode="lines", line=dict(color="magenta", width=0.8), + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="magenta", width=0.8), name=f"Finished (smooth = {smooth_factor})"), row=1, col=1, ) - # Bias trace + confidence band + # Bias trace + confidence band (broken at session boundaries) if ("bias" in plot_list) and (bias is not None): xx = np.arange(n_trials) + 1 bias = (np.array(bias) + 1) / 2 bias_lower = np.clip((np.array(bias_lower) + 1) / 2, 0, None) bias_upper = np.clip((np.array(bias_upper) + 1) / 2, None, 1) + xb_up, y_up = _broken(xx, bias_upper, segments) + xb_lo, y_lo = _broken(xx, bias_lower, segments) + xb, y_bias = _broken(xx, bias, segments) # go.Scatter (not Scattergl) for the filled band -- Scattergl ignores fill. fig.add_trace( - go.Scatter(x=xx, y=bias_upper, mode="lines", line=dict(width=0), + go.Scatter(x=xb_up, y=y_up, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip"), row=1, col=1, ) fig.add_trace( - go.Scatter(x=xx, y=bias_lower, mode="lines", line=dict(width=0), + go.Scatter(x=xb_lo, y=y_lo, mode="lines", line=dict(width=0), fill="tonexty", fillcolor="rgba(0,128,0,0.25)", showlegend=False, hoverinfo="skip"), row=1, col=1, ) fig.add_trace( - go.Scattergl(x=xx, y=bias, mode="lines", line=dict(color="green", width=1.5), + go.Scattergl(x=xb, y=y_bias, mode="lines", line=dict(color="green", width=1.5), name="bias"), row=1, col=1, ) @@ -306,28 +377,46 @@ def _side_segments(mask, up, down): row=1, col=1, ) - # == Reward schedule (bottom panel) == + # == Reward schedule (bottom panel; broken at session boundaries) == xx = np.arange(n_trials) + 1 + xr, y_pr = _broken(xx, p_reward[1, :], segments) fig.add_trace( - go.Scattergl(x=xx, y=p_reward[1, :], mode="lines", line=dict(color="blue", width=1), + go.Scattergl(x=xr, y=y_pr, mode="lines", line=dict(color="blue", width=1), name="p_right"), row=2, col=1, ) + xl, y_pl = _broken(xx, p_reward[0, :], segments) fig.add_trace( - go.Scattergl(x=xx, y=p_reward[0, :], mode="lines", line=dict(color="red", width=1), + go.Scattergl(x=xl, y=y_pl, mode="lines", line=dict(color="red", width=1), name="p_left"), row=2, col=1, ) + # Thick vertical lines marking session boundaries (between trials b and b+1) + for b in boundaries: + for row in (1, 2): + fig.add_vline(x=b + 0.5, line=dict(color="black", width=2), row=row, col=1) + # Axes styling to match the matplotlib version fig.update_yaxes( tickvals=[0, 1, 1.2], ticktext=["Left", "Right", "Ignored"], range=[-0.15, 1.25], fixedrange=True, row=1, col=1, ) fig.update_yaxes(title_text="p_reward", range=[0, 1], fixedrange=True, row=2, col=1) - fig.update_xaxes(title_text="Trial number", row=2, col=1) + # Bottom x-axis: a rangeslider scroller (drag to pan/zoom), and -- for multiple sessions -- + # tick labels that restart at 0 each session. + fig.update_xaxes(title_text="Trial number", row=2, col=1, + rangeslider=dict(visible=True, thickness=0.08)) + if len(segments) > 1: + step = 250 + tickvals, ticktext = [], [] + for s, e in segments: + for w in range(0, e - s, step): + tickvals.append(s + 1 + w) + ticktext.append(str(w)) + fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) fig.update_layout( - width=1300, height=400, template="simple_white", + width=1300, height=440, template="simple_white", legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), margin=dict(l=60, r=20, t=60, b=50), ) @@ -335,32 +424,42 @@ def _side_segments(mask, up, down): def plot_session_in_time_plotly( # noqa: C901 pragma: no cover - df_events, df_trials=None, fip_df=None, adjust_time=True, session_id=None, smooth_factor=5 + df_events, df_trials=None, fip_df=None, adjust_time=True, title=None, smooth_factor=5 ): """Plotly version of :func:`plot_session_scroller.plot_session_scroller` (time-based). Plots the session in real time (not in trial): left / right licks and rewards as ticks, - go cues as vertical lines, and -- when ``df_trials`` is supplied -- the left / right - reward-probability band, laid out to match the matplotlib scroller. + go cues as vertical lines (red for ignored trials), smoothed overlays above the events, + and -- when ``df_trials`` is supplied -- the left / right reward-probability band in the + rangeslider "scroller" below. + + Multiple sessions: if ``df_events`` has a ``session_id`` column with more than one + session, the sessions are concatenated end-to-end along time in order of appearance + (each restarted at the running offset); a thick vertical line marks each boundary and + per-trial / smoothed quantities reset per session. ``df_trials`` is matched per session + by its own ``session_id`` column when present. Parameters ---------- df_events : pandas.DataFrame - Tidy dataframe of session events, e.g. from - ``aind_dynamic_foraging_data_utils.nwb_utils.create_df_events``. Needs ``event`` and - ``timestamps`` columns; recognised events are ``left_lick_time``, ``right_lick_time``, - ``left_reward_delivery_time``, ``right_reward_delivery_time`` and ``goCue_start_time``. + Tidy dataframe of session events (``event`` + ``timestamps``; optional ``session_id``). + Recognised events: ``left_lick_time``, ``right_lick_time``, ``left_reward_delivery_time``, + ``right_reward_delivery_time`` and ``goCue_start_time``. df_trials : pandas.DataFrame, optional - Per-trial dataframe used for the reward-probability band (and as a fallback source of - go-cue times). Needs ``goCue_start_time`` and ``reward_probabilityL/R``. The go-cue - times must share the same time base as ``df_events.timestamps``. + Per-trial dataframe for the reward-probability band / overlays / red ignored go cues + (and a fallback source of go-cue times). Uses ``goCue_start_time``, + ``reward_probabilityL/R`` and ``animal_response``; go-cue times must share the + ``df_events`` time base. Matched per session via ``session_id`` when present. fip_df : pandas.DataFrame, optional - Tidy dataframe of FIP measurements (from ``create_df_fip(tidy=True)``); each present - channel is normalised and stacked above the behavior panel. + Tidy FIP measurements (single-session only); each present channel is normalised and + stacked above the behavior panel. adjust_time : bool, optional - If True (default), shift time so the first event is at t = 0. - session_id : str, optional - Title for the figure. + If True (default), shift time so the first event is at t = 0 (always shifted when + concatenating multiple sessions). + title : str, optional + Figure title. + smooth_factor : int, optional + Smoothing window for the choice / lick-count overlays, by default 5. Returns ------- @@ -372,35 +471,16 @@ def plot_session_in_time_plotly( # noqa: C901 pragma: no cover if fip_df is not None: fip_df = fip_df.copy() - # nan-safe extent: some events can carry NaN timestamps, and they sort last, so - # iloc[0]/iloc[-1] are not reliable -- use nanmin/nanmax. - if adjust_time: - start_time = np.nanmin(df_events["timestamps"]) - df_events["timestamps"] = df_events["timestamps"] - start_time - if df_trials is not None: - df_trials["goCue_start_time"] = df_trials["goCue_start_time"] - start_time - if fip_df is not None: - fip_df["timestamps"] = fip_df["timestamps"] - start_time - - xmin = np.nanmin(df_events["timestamps"]) - xmax = np.nanmax(df_events["timestamps"]) - x_first, x_last = xmin, xmax # full extent (used for the rangeslider / "home") - # y-layout, bottom -> top: - # * event rows in [0, 1]: like the trial-based figure, rewards sit at the outer edges - # with licks just inside -- the right pair (reward outer-top, lick inner) grouped near - # the top, the left pair (lick inner, reward outer-bottom) near the bottom. - # * smoothed overlays in their own band [curve_bottom, curve_top] *above* the events. - # * the reward-probability band sits higher still, out of the main view -- it only shows - # in the rangeslider "scroller" (auto-/band-scaled) below. - # One go-cue line per trial spans the event rows. + # * event rows in [0, 1]: rewards at the outer edges, licks inside (right pair near the + # top, left pair near the bottom), like the trial-based figure. + # * smoothed overlays in their own band [curve_bottom, curve_top] above the events. + # * the reward-probability band sits higher still -- it only shows in the rangeslider. params = { - "behavior_bottom": 0.0, "behavior_top": 1.0, # event ticks - "curve_bottom": 1.1, "curve_top": 2.1, # smoothed overlays, above the events - "probs_center": 2.6, "probs_half": 0.25, # band: scroller only, above main view + "behavior_bottom": 0.0, "behavior_top": 1.0, + "curve_bottom": 1.1, "curve_top": 2.1, + "probs_center": 2.6, "probs_half": 0.25, } - # Row centers (top -> bottom): right reward, right lick, left lick, left reward. Event - # ticks are short marks (70% shorter than the row spacing) centered on each row. row_centers = {"right_reward": 0.92, "right_lick": 0.78, "left_lick": 0.22, "left_reward": 0.08} tick_half = 0.25 * 0.30 / 2.0 @@ -415,91 +495,168 @@ def _to_curve(v): params["curve_top"]] ylabels = ["right reward", "right lick", "left lick", "left reward", "0", "0.5", "1"] + # Sessions in order of appearance; concatenate end-to-end along time when more than one. + has_sess = "session_id" in df_events.columns + sessions = (list(dict.fromkeys(df_events["session_id"].tolist())) if has_sess else [None]) + shift_each = adjust_time or len(sessions) > 1 + fig = go.Figure() - def _event_times(name): - return df_events.query("event == @name").timestamps.values - - # Go-cue times define the trial windows (prefer events, fall back to df_trials) so every - # event can be tagged with the trial it falls in. - go_cue_times = _event_times("goCue_start_time") - if len(go_cue_times) == 0 and df_trials is not None and "goCue_start_time" in df_trials: - go_cue_times = df_trials["goCue_start_time"].dropna().values - n_tr = len(go_cue_times) - - def _trial_of(times): - """Trial number for each time = how many go cues have started at/before it.""" - if n_tr == 0: - return np.zeros(len(times), dtype=int) - return np.searchsorted(go_cue_times, np.asarray(times), side="right") - - # Per-trial choice aligned to the go cues (when df_trials lines up): 0 left, 1 right, - # np.nan = ignored. Used for the red ignored go cues and the smoothed-choice overlay. - aligned = df_trials is not None and len(df_trials) == n_tr and n_tr > 0 - choice = None - if aligned and "animal_response" in df_trials: - choice = df_trials["animal_response"].astype(float).to_numpy().copy() - choice[choice == 2] = np.nan - - # Licks (gray) and rewards (black): short ticks centered on their rows; hover shows trial - for name, key, color, width in [ - ("left_lick_time", "left_lick", "gray", 1.5), - ("right_lick_time", "right_lick", "gray", 1.5), - ("left_reward_delivery_time", "left_reward", "black", 2), - ("right_reward_delivery_time", "right_reward", "black", 2), - ]: - c = row_centers[key] - t = _event_times(name) - label = name.replace("_delivery_time", "").replace("_time", "").replace("_", " ") - xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, _trial_of(t)) + ev_meta = { + "left_lick": ("left_lick_time", "gray", 1.5, "left lick"), + "right_lick": ("right_lick_time", "gray", 1.5, "right lick"), + "left_reward": ("left_reward_delivery_time", "black", 2, "left reward"), + "right_reward": ("right_reward_delivery_time", "black", 2, "right reward"), + } + ev_acc = {k: {"x": [], "y": [], "cd": []} for k in ev_meta} + gocue_acc = {"go cue": {"x": [], "y": [], "cd": []}, + "go cue (ignored)": {"x": [], "y": [], "cd": []}} + frac_x, frac_y, choice_x, choice_y, lick_x, lick_y = [], [], [], [], [], [] + band_x, band_pR, band_pL, band_base = [], [], [], [] + boundaries, has_band = [], False + cum, first_t0, first_gc, last_off = 0.0, None, None, 0.0 + + for si, sess in enumerate(sessions): + ev_s = df_events if sess is None else df_events[df_events["session_id"] == sess] + ts = ev_s["timestamps"].to_numpy() + t0 = np.nanmin(ts) + if first_t0 is None: + first_t0 = t0 + off = (cum - t0) if shift_each else 0.0 + last_off = off + if si > 0: + boundaries.append(cum) + + def _ev(name, _ev=ev_s, _off=off): + return _ev.loc[_ev["event"] == name, "timestamps"].to_numpy() + _off + + tr_s = None + if df_trials is not None: + tr_s = (df_trials[df_trials["session_id"] == sess] + if (sess is not None and "session_id" in df_trials.columns) else df_trials) + + gc = _ev("goCue_start_time") + if len(gc) == 0 and tr_s is not None and "goCue_start_time" in tr_s.columns: + gc = tr_s["goCue_start_time"].to_numpy() + off + gc = gc[~np.isnan(gc)] + n_tr = len(gc) + if n_tr and first_gc is None: + first_gc = gc.min() + + def _trial_of(times, _gc=gc, _n=n_tr): + if _n == 0: + return np.zeros(len(times), dtype=int) + return np.searchsorted(_gc, np.asarray(times), side="right") + + aligned = tr_s is not None and len(tr_s) == n_tr and n_tr > 0 + choice = None + if aligned and "animal_response" in tr_s.columns: + choice = tr_s["animal_response"].astype(float).to_numpy().copy() + choice[choice == 2] = np.nan + + for key, (name, _color, _width, _label) in ev_meta.items(): + c = row_centers[key] + t = _ev(name) + xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, _trial_of(t)) + ev_acc[key]["x"] += xs + ev_acc[key]["y"] += ys + ev_acc[key]["cd"] += cd + + if n_tr: + trial_no = np.arange(1, n_tr + 1) + ign = np.isnan(choice) if choice is not None else np.zeros(n_tr, dtype=bool) + for gname, mask in [("go cue", ~ign), ("go cue (ignored)", ign)]: + if mask.any(): + xs, ys, cd = _vline_hover(gc[mask], params["behavior_bottom"], + params["behavior_top"], trial_no[mask]) + gocue_acc[gname]["x"] += xs + gocue_acc[gname]["y"] += ys + gocue_acc[gname]["cd"] += cd + + if n_tr: + off_s = smooth_factor // 2 + if aligned and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns): + pL = tr_s["reward_probabilityL"].to_numpy() + pR = tr_s["reward_probabilityR"].to_numpy() + frac = np.divide(pR, pL + pR, out=np.full(n_tr, np.nan), where=(pL + pR) > 0) + frac_x += [*gc, None] + frac_y += [*_to_curve(frac), None] + if choice is not None: + sm = moving_average(choice, smooth_factor) / ( + moving_average(~np.isnan(choice), smooth_factor) + 1e-6) + sm[sm > 100] = np.nan + xsm = gc[off_s: off_s + len(sm)] + choice_x += [*xsm, None] + choice_y += [*_to_curve(sm[: len(xsm)]), None] + lt = np.concatenate([_ev("left_lick_time"), _ev("right_lick_time")]) + if len(lt): + counts = np.bincount(_trial_of(lt), minlength=n_tr + 1)[1:n_tr + 1] + sm = moving_average(counts.astype(float), smooth_factor) + top = np.nanmax(sm) if len(sm) else 0 + if top > 0: + sm = sm / top + xsm = gc[off_s: off_s + len(sm)] + lick_x += [*xsm, None] + lick_y += [*_to_curve(sm[: len(xsm)]), None] + + if (tr_s is not None and n_tr and len(tr_s) == n_tr + and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns)): + has_band = True + center = params["probs_center"] + xd = np.repeat(gc, 2)[1:] + pr = np.repeat(center + tr_s["reward_probabilityR"].to_numpy() / 4, 2)[:-1] + pl = np.repeat(center - tr_s["reward_probabilityL"].to_numpy() / 4, 2)[:-1] + band_x += [*xd, None] + band_pR += [*pr, None] + band_pL += [*pl, None] + band_base += [center] * len(xd) + [None] + + cum += np.nanmax(ts) - t0 + + # --- build one trace per type from the accumulators --- + for key, (name, color, width, label) in ev_meta.items(): + a = ev_acc[key] fig.add_trace(go.Scattergl( - x=xs, y=ys, customdata=cd, mode="lines", line=dict(color=color, width=width), - name=label, - hovertemplate="%{x:.2f}s
trial %{customdata}" + label + "", - )) - - if n_tr: - # Go-cue lines spanning the event rows only; ignored trials are drawn red. - trial_no = np.arange(1, n_tr + 1) - ignored = np.isnan(choice) if choice is not None else np.zeros(n_tr, dtype=bool) - for mask, gc_color, gname in [(~ignored, "green", "go cue"), - (ignored, "red", "go cue (ignored)")]: - if not mask.any(): - continue - xs, ys, cd = _vline_hover(go_cue_times[mask], params["behavior_bottom"], - params["behavior_top"], trial_no[mask]) + x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", + line=dict(color=color, width=width), name=label, + hovertemplate="%{x:.2f}s
trial %{customdata}" + label + "")) + + for gname, gcolor in [("go cue", "green"), ("go cue (ignored)", "red")]: + a = gocue_acc[gname] + if a["x"]: fig.add_trace(go.Scattergl( - x=xs, y=ys, customdata=cd, mode="lines", - line=dict(color=gc_color, width=0.75), opacity=0.75, name=gname, - hovertemplate="%{x:.2f}s
trial %{customdata}" + gname + "", - )) + x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", + line=dict(color=gcolor, width=0.75), opacity=0.75, name=gname, + hovertemplate="%{x:.2f}s
trial %{customdata}" + gname + "")) - # Reward-probability band (needs df_trials and go-cue times) -- shown only in the scroller - has_band = df_trials is not None and len(go_cue_times) == len(df_trials) if has_band: - x_doubled = np.repeat(go_cue_times, 2)[1:] - center = params["probs_center"] - pR = np.repeat(center + df_trials["reward_probabilityR"].values / 4, 2)[:-1] - pL = np.repeat(center - df_trials["reward_probabilityL"].values / 4, 2)[:-1] - base = np.full_like(x_doubled, center, dtype=float) - # pR above center, pL below center; fill toward the center baseline. Colored to match - # the trial figures: left (pL) red, right (pR) blue. - # go.Scatter (not Scattergl) -- the WebGL trace ignores fill="tonexty". - fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), + # pR above center (blue=right), pL below center (red=left); fill to the center base. + fig.add_trace(go.Scatter(x=band_x, y=band_base, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip")) - fig.add_trace(go.Scatter(x=x_doubled, y=pR, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(0,0,255,0.4)", - name="pR")) - fig.add_trace(go.Scatter(x=x_doubled, y=base, mode="lines", line=dict(width=0), + fig.add_trace(go.Scatter(x=band_x, y=band_pR, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(0,0,255,0.4)", name="pR")) + fig.add_trace(go.Scatter(x=band_x, y=band_base, mode="lines", line=dict(width=0), showlegend=False, hoverinfo="skip")) - fig.add_trace(go.Scatter(x=x_doubled, y=pL, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(255,0,0,0.4)", - name="pL")) - - y_main_top = params["curve_top"] # top of the main (non-scroller) view - - # FIP channels, normalised and stacked above the behavior panel - if fip_df is not None: + fig.add_trace(go.Scatter(x=band_x, y=band_pL, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(255,0,0,0.4)", name="pL")) + + # Smoothed overlays on top + if frac_x: + fig.add_trace(go.Scattergl(x=frac_x, y=frac_y, mode="lines", + line=dict(color="gold", width=1.5), name="pR/(pL+pR)")) + if choice_x: + fig.add_trace(go.Scattergl(x=choice_x, y=choice_y, mode="lines", + line=dict(color="black", width=1.5), + name=f"choice (smooth = {smooth_factor})")) + if lick_x: + fig.add_trace(go.Scattergl(x=lick_x, y=lick_y, mode="lines", + line=dict(color="black", width=1.2, dash="dash"), + name=f"lick count (smooth = {smooth_factor})")) + + y_main_top = params["curve_top"] + + # FIP channels (single-session only), normalised and stacked above the behavior panel + if fip_df is not None and len(sessions) == 1: fip_channels = ["G_1_preprocessed", "G_2_preprocessed", "R_1_preprocessed", "R_2_preprocessed"] fip_colors = {"G_1": "green", "G_2": "darkgreen", "R_1": "red", "R_2": "darkred"} @@ -513,57 +670,23 @@ def _trial_of(times): d = C["data"].values - np.nanmin(C["data"].values) d = d / np.nanmax(d) + bottom color = fip_colors["_".join(channel.split("_")[:2])] - fig.add_trace(go.Scattergl(x=C.timestamps.values, y=d, mode="lines", + fig.add_trace(go.Scattergl(x=C.timestamps.values + last_off, y=d, mode="lines", line=dict(color=color), name=channel)) yticks.append(bottom + 0.5) ylabels.append(channel) band += 1 y_main_top = bottom + 1.0 - # Smoothed per-trial overlays, in their own band above the event rows (0..1 mapped into - # [curve_bottom, curve_top]) and plotted at the go-cue times. Added last so they sit on - # top of the go-cue lines. - if n_tr: - offset = smooth_factor // 2 - - # Reward-probability fraction pR/(pL+pR) -- golden, like the trial-based base color - if aligned and {"reward_probabilityL", "reward_probabilityR"} <= set(df_trials.columns): - pL = df_trials["reward_probabilityL"].to_numpy() - pR = df_trials["reward_probabilityR"].to_numpy() - frac = np.divide(pR, pL + pR, out=np.full(n_tr, np.nan), where=(pL + pR) > 0) - fig.add_trace(go.Scattergl(x=go_cue_times, y=_to_curve(frac), mode="lines", - line=dict(color="gold", width=1.5), name="pR/(pL+pR)")) - - # Smoothed choice (black solid) - if choice is not None: - sm = moving_average(choice, smooth_factor) / ( - moving_average(~np.isnan(choice), smooth_factor) + 1e-6) - sm[sm > 100] = np.nan - xs = go_cue_times[offset: offset + len(sm)] - fig.add_trace(go.Scattergl(x=xs, y=_to_curve(sm), mode="lines", - line=dict(color="black", width=1.5), - name=f"choice (smooth = {smooth_factor})")) - - # Smoothed lick count per trial (black dashed), normalised to [0, 1] - lick_times = np.concatenate([_event_times("left_lick_time"), - _event_times("right_lick_time")]) - if len(lick_times): - counts = np.bincount(_trial_of(lick_times), minlength=n_tr + 1)[1:n_tr + 1] - sm = moving_average(counts.astype(float), smooth_factor) - top = np.nanmax(sm) - if top > 0: - sm = sm / top - xs = go_cue_times[offset: offset + len(sm)] - fig.add_trace(go.Scattergl(x=xs, y=_to_curve(sm), mode="lines", - line=dict(color="black", width=1.2, dash="dash"), - name=f"lick count (smooth = {smooth_factor})")) - - # Start zoomed to a readable ~120 s window at the first go cue (like the matplotlib - # scroller's default window). The rangeslider scroller below scrubs the whole session. - # When a band is present, pin the scroller's y to ~2x the band height so the reward- - # probability band fills about half the scroller bar (x-dragging is unaffected); - # otherwise auto-fit. - t0 = go_cue_times.min() if len(go_cue_times) else x_first + # Thick vertical lines marking session boundaries + for b in boundaries: + fig.add_vline(x=b, line=dict(color="black", width=2)) + + # Full extent + initial ~120 s window at the first go cue. The rangeslider scrubs the + # whole session(s); when a band is present pin the scroller y to ~2x the band height so + # it fills about half the scroller bar (x-dragging is unaffected); otherwise auto-fit. + x_first = 0.0 if shift_each else (first_t0 if first_t0 is not None else 0.0) + x_last = x_first + cum + t0_view = first_gc if first_gc is not None else x_first if has_band: half = 2 * params["probs_half"] slider_yaxis = dict(rangemode="fixed", @@ -571,11 +694,11 @@ def _trial_of(times): else: slider_yaxis = dict(rangemode="auto") fig.update_layout( - title=session_id or "Session Scroller", + title=title or "Session Scroller", xaxis_title="Time (s)", yaxis=dict(tickvals=yticks, ticktext=ylabels, fixedrange=True, range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), - xaxis=dict(range=[t0, t0 + 120], + xaxis=dict(range=[t0_view, t0_view + 120], rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)), showlegend=True, height=600, width=1300, template="simple_white", ) diff --git a/tests/test_plot_foraging_session_plotly.py b/tests/test_plot_foraging_session_plotly.py index 40e3277..71b710b 100644 --- a/tests/test_plot_foraging_session_plotly.py +++ b/tests/test_plot_foraging_session_plotly.py @@ -59,6 +59,19 @@ def test_optional_traces(self): ) self.assertIsInstance(fig, go.Figure) + def test_multi_session(self): + """A per-trial session_id concatenates sessions with a boundary line.""" + n = len(self.choice_history) + session_id = np.array(["a"] * n + ["b"] * n) + fig = plot_foraging_session_plotly( + np.concatenate([self.choice_history, self.choice_history]), + np.concatenate([self.reward_history, self.reward_history]), + np.concatenate([self.p_reward, self.p_reward], axis=1), + session_id=session_id, + ) + # One boundary, drawn as a vertical line (shape) in each of the two rows. + self.assertEqual(len(fig.layout.shapes), 2) + class TestPlotSessionInTimePlotly(unittest.TestCase): """Test the time-based plotly plot with a synthetic events / trials frame.""" @@ -92,12 +105,25 @@ def test_events_only(self): def test_with_trials(self): """Supplying df_trials adds the reward-probability band traces.""" fig = plot_session_in_time_plotly( - self.df_events, df_trials=self.df_trials, session_id="unit_test" + self.df_events, df_trials=self.df_trials, title="unit_test" ) names = [tr.name for tr in fig.data] self.assertIn("pR", names) self.assertIn("pL", names) + def test_multi_session(self): + """A session_id column concatenates sessions end-to-end with a boundary line.""" + e1 = self.df_events.assign(session_id="s1") + e2 = self.df_events.assign(session_id="s2", timestamps=self.df_events["timestamps"] + 100) + t1 = self.df_trials.assign(session_id="s1") + t2 = self.df_trials.assign(session_id="s2", + goCue_start_time=self.df_trials["goCue_start_time"] + 100) + fig = plot_session_in_time_plotly( + pd.concat([e1, e2], ignore_index=True), + df_trials=pd.concat([t1, t2], ignore_index=True), + ) + self.assertEqual(len(fig.layout.shapes), 1) # one session boundary line + if __name__ == "__main__": unittest.main() From d07e8124d9833decb4b1adf15ae82f5fa76827dd Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 01:31:37 +0000 Subject: [PATCH 08/12] =?UTF-8?q?feat(plot):=20time-based=20plotly=20?= =?UTF-8?q?=E2=80=94=20(session,=20trial)=20hover,=20per-session=20x,=20pL?= =?UTF-8?q?/pR=20lines?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Events / go cues now carry (trial, session) and show both on hover. - Multiple sessions: x tick labels restart at 0 each session (nice-step ticks). - Replace the filled reward-probability band in the scroller with two lines (pL red, pR blue), like the trial-based reward schedule, broken at session boundaries; the scroller y is pinned to those lines. (go.Scatter so they render in the rangeslider.) Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 114 +++++++++++------- 1 file changed, 72 insertions(+), 42 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 7add005..37e09d1 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -54,21 +54,33 @@ def _vlines(segments): return xs, ys -def _vline_hover(x_arr, y0, y1, hover): +def _vline_hover(x_arr, y0, y1, hover, gap=None): """Vertical ticks at ``x_arr`` (each y0->y1) plus a parallel ``customdata`` array. Like :func:`_vlines` for a single group, but also threads a per-tick ``hover`` value - (repeated on both vertices, ``None`` on the gap) so each tick can surface e.g. its trial - number via a ``hovertemplate``. + (repeated on both vertices, ``gap`` on the separator) so each tick can surface e.g. its + trial / session via a ``hovertemplate``. Pass ``gap=(None, None)`` for 2-field customdata. """ xs, ys, cd = [], [], [] for xi, hi in zip(np.asarray(x_arr), hover): xs += [xi, xi, None] ys += [y0, y1, None] - cd += [hi, hi, None] + cd += [hi, hi, gap] return xs, ys, cd +def _nice_step(span, target=4): + """A round tick step (1/2/5 x 10^k) giving roughly ``target`` ticks across ``span``.""" + if span <= 0: + return 1.0 + raw = span / target + mag = 10.0 ** np.floor(np.log10(raw)) + for m in (1, 2, 5, 10): + if m * mag >= raw: + return m * mag + return 10.0 * mag + + def _session_segments(session_id, n): """Contiguous per-session index segments and the boundary indices between them. @@ -479,7 +491,7 @@ def plot_session_in_time_plotly( # noqa: C901 pragma: no cover params = { "behavior_bottom": 0.0, "behavior_top": 1.0, "curve_bottom": 1.1, "curve_top": 2.1, - "probs_center": 2.6, "probs_half": 0.25, + "probs_center": 2.75, "probs_half": 0.25, # reward-prob lines (scroller), above main } row_centers = {"right_reward": 0.92, "right_lick": 0.78, "left_lick": 0.22, "left_reward": 0.08} @@ -512,8 +524,8 @@ def _to_curve(v): gocue_acc = {"go cue": {"x": [], "y": [], "cd": []}, "go cue (ignored)": {"x": [], "y": [], "cd": []}} frac_x, frac_y, choice_x, choice_y, lick_x, lick_y = [], [], [], [], [], [] - band_x, band_pR, band_pL, band_base = [], [], [], [] - boundaries, has_band = [], False + probL_x, probL_y, probR_x, probR_y = [], [], [], [] # reward-prob lines (scroller) + boundaries, sess_spans, has_prob = [], [], False cum, first_t0, first_gc, last_off = 0.0, None, None, 0.0 for si, sess in enumerate(sessions): @@ -554,10 +566,13 @@ def _trial_of(times, _gc=gc, _n=n_tr): choice = tr_s["animal_response"].astype(float).to_numpy().copy() choice[choice == 2] = np.nan + sess_disp = "" if sess is None else sess # shown in (trial, session) hover + for key, (name, _color, _width, _label) in ev_meta.items(): c = row_centers[key] t = _ev(name) - xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, _trial_of(t)) + hov = [(int(tr), sess_disp) for tr in _trial_of(t)] + xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, hov, gap=(None, None)) ev_acc[key]["x"] += xs ev_acc[key]["y"] += ys ev_acc[key]["cd"] += cd @@ -567,8 +582,9 @@ def _trial_of(times, _gc=gc, _n=n_tr): ign = np.isnan(choice) if choice is not None else np.zeros(n_tr, dtype=bool) for gname, mask in [("go cue", ~ign), ("go cue (ignored)", ign)]: if mask.any(): + hov = [(int(tr), sess_disp) for tr in trial_no[mask]] xs, ys, cd = _vline_hover(gc[mask], params["behavior_bottom"], - params["behavior_top"], trial_no[mask]) + params["behavior_top"], hov, gap=(None, None)) gocue_acc[gname]["x"] += xs gocue_acc[gname]["y"] += ys gocue_acc[gname]["cd"] += cd @@ -599,27 +615,29 @@ def _trial_of(times, _gc=gc, _n=n_tr): lick_x += [*xsm, None] lick_y += [*_to_curve(sm[: len(xsm)]), None] + # Reward-probability as two lines (pL red, pR blue), like the trial-based schedule; + # values 0..1 mapped into the scroller band and broken at session boundaries. if (tr_s is not None and n_tr and len(tr_s) == n_tr and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns)): - has_band = True - center = params["probs_center"] - xd = np.repeat(gc, 2)[1:] - pr = np.repeat(center + tr_s["reward_probabilityR"].to_numpy() / 4, 2)[:-1] - pl = np.repeat(center - tr_s["reward_probabilityL"].to_numpy() / 4, 2)[:-1] - band_x += [*xd, None] - band_pR += [*pr, None] - band_pL += [*pl, None] - band_base += [center] * len(xd) + [None] - - cum += np.nanmax(ts) - t0 + has_prob = True + lo = params["probs_center"] - params["probs_half"] + span = 2 * params["probs_half"] + probL_x += [*gc, None] + probL_y += [*(lo + tr_s["reward_probabilityL"].to_numpy() * span), None] + probR_x += [*gc, None] + probR_y += [*(lo + tr_s["reward_probabilityR"].to_numpy() * span), None] + + dur = np.nanmax(ts) - t0 + sess_spans.append((cum, dur)) + cum += dur # --- build one trace per type from the accumulators --- + ht = "%%{x:.2f}s
trial %%{customdata[0]}
session %%{customdata[1]}%s" for key, (name, color, width, label) in ev_meta.items(): a = ev_acc[key] fig.add_trace(go.Scattergl( x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", - line=dict(color=color, width=width), name=label, - hovertemplate="%{x:.2f}s
trial %{customdata}" + label + "")) + line=dict(color=color, width=width), name=label, hovertemplate=ht % label)) for gname, gcolor in [("go cue", "green"), ("go cue (ignored)", "red")]: a = gocue_acc[gname] @@ -627,18 +645,15 @@ def _trial_of(times, _gc=gc, _n=n_tr): fig.add_trace(go.Scattergl( x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", line=dict(color=gcolor, width=0.75), opacity=0.75, name=gname, - hovertemplate="%{x:.2f}s
trial %{customdata}" + gname + "")) - - if has_band: - # pR above center (blue=right), pL below center (red=left); fill to the center base. - fig.add_trace(go.Scatter(x=band_x, y=band_base, mode="lines", line=dict(width=0), - showlegend=False, hoverinfo="skip")) - fig.add_trace(go.Scatter(x=band_x, y=band_pR, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(0,0,255,0.4)", name="pR")) - fig.add_trace(go.Scatter(x=band_x, y=band_base, mode="lines", line=dict(width=0), - showlegend=False, hoverinfo="skip")) - fig.add_trace(go.Scatter(x=band_x, y=band_pL, mode="lines", line=dict(width=0), - fill="tonexty", fillcolor="rgba(255,0,0,0.4)", name="pL")) + hovertemplate=ht % gname)) + + if has_prob: + # pL (red) / pR (blue) as lines, like the trial-based reward schedule (scroller only). + # go.Scatter (not Scattergl) -- WebGL traces do not render in the rangeslider. + fig.add_trace(go.Scatter(x=probR_x, y=probR_y, mode="lines", + line=dict(color="blue", width=1.2), name="pR")) + fig.add_trace(go.Scatter(x=probL_x, y=probL_y, mode="lines", + line=dict(color="red", width=1.2), name="pL")) # Smoothed overlays on top if frac_x: @@ -682,24 +697,39 @@ def _trial_of(times, _gc=gc, _n=n_tr): fig.add_vline(x=b, line=dict(color="black", width=2)) # Full extent + initial ~120 s window at the first go cue. The rangeslider scrubs the - # whole session(s); when a band is present pin the scroller y to ~2x the band height so - # it fills about half the scroller bar (x-dragging is unaffected); otherwise auto-fit. + # whole session(s); when reward-prob lines are present pin the scroller y to their band so + # they fill the scroller (x-dragging is unaffected); otherwise auto-fit. x_first = 0.0 if shift_each else (first_t0 if first_t0 is not None else 0.0) x_last = x_first + cum t0_view = first_gc if first_gc is not None else x_first - if has_band: - half = 2 * params["probs_half"] - slider_yaxis = dict(rangemode="fixed", - range=[params["probs_center"] - half, params["probs_center"] + half]) + if has_prob: + slider_yaxis = dict( + rangemode="fixed", + range=[params["probs_center"] - params["probs_half"] - 0.05, + params["probs_center"] + params["probs_half"] + 0.05]) else: slider_yaxis = dict(rangemode="auto") + + # Multi-session: x tick labels restart at 0 each session. + xaxis = dict(range=[t0_view, t0_view + 120], + rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)) + if len(sess_spans) > 1: + tickvals, ticktext = [], [] + for start, dur in sess_spans: + step = _nice_step(dur) + w = 0.0 + while w <= dur: + tickvals.append(start + w) + ticktext.append(str(int(w))) + w += step + xaxis.update(tickvals=tickvals, ticktext=ticktext) + fig.update_layout( title=title or "Session Scroller", xaxis_title="Time (s)", yaxis=dict(tickvals=yticks, ticktext=ylabels, fixedrange=True, range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), - xaxis=dict(range=[t0_view, t0_view + 120], - rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)), + xaxis=xaxis, showlegend=True, height=600, width=1300, template="simple_white", ) return fig From 00e55e3492962ef0a242b2e593f59b3d1f2479ff Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 01:52:08 +0000 Subject: [PATCH 09/12] refactor(plot): time-based plotly uses the trial-based 2-panel + scroller layout Reuse the trial-based structure in the time-based figure: raster / curves on top (row 1) over a reward-schedule panel (row 2) showing pL (red) / pR (blue) as 0..1 lines, broken at session boundaries, with the rangeslider scroller under row 2. Legend moved outside top-left, horizontal, ~5 entries per row. Drops the old high-parked band + pinned-slider hack. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 94 +++++++++---------- tests/test_plot_foraging_session_plotly.py | 2 +- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 37e09d1..413d88e 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -489,9 +489,8 @@ def plot_session_in_time_plotly( # noqa: C901 pragma: no cover # * smoothed overlays in their own band [curve_bottom, curve_top] above the events. # * the reward-probability band sits higher still -- it only shows in the rangeslider. params = { - "behavior_bottom": 0.0, "behavior_top": 1.0, - "curve_bottom": 1.1, "curve_top": 2.1, - "probs_center": 2.75, "probs_half": 0.25, # reward-prob lines (scroller), above main + "behavior_bottom": 0.0, "behavior_top": 1.0, # event rows (row 1) + "curve_bottom": 1.1, "curve_top": 2.1, # smoothed overlays (row 1) } row_centers = {"right_reward": 0.92, "right_lick": 0.78, "left_lick": 0.22, "left_reward": 0.08} @@ -512,7 +511,10 @@ def _to_curve(v): sessions = (list(dict.fromkeys(df_events["session_id"].tolist())) if has_sess else [None]) shift_each = adjust_time or len(sessions) > 1 - fig = go.Figure() + # Same two-panel layout as the trial-based figure: the raster/curves on top (row 1) over a + # reward-schedule panel (row 2), with a rangeslider scroller under row 2. + fig = make_subplots(rows=2, cols=1, shared_xaxes=True, + row_heights=[0.85, 0.15], vertical_spacing=0.04) ev_meta = { "left_lick": ("left_lick_time", "gray", 1.5, "left lick"), @@ -615,29 +617,29 @@ def _trial_of(times, _gc=gc, _n=n_tr): lick_x += [*xsm, None] lick_y += [*_to_curve(sm[: len(xsm)]), None] - # Reward-probability as two lines (pL red, pR blue), like the trial-based schedule; - # values 0..1 mapped into the scroller band and broken at session boundaries. + # Reward-probability schedule (pL red, pR blue), drawn as 0..1 lines in the bottom + # panel just like the trial-based figure; broken at session boundaries. if (tr_s is not None and n_tr and len(tr_s) == n_tr and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns)): has_prob = True - lo = params["probs_center"] - params["probs_half"] - span = 2 * params["probs_half"] probL_x += [*gc, None] - probL_y += [*(lo + tr_s["reward_probabilityL"].to_numpy() * span), None] + probL_y += [*tr_s["reward_probabilityL"].to_numpy(), None] probR_x += [*gc, None] - probR_y += [*(lo + tr_s["reward_probabilityR"].to_numpy() * span), None] + probR_y += [*tr_s["reward_probabilityR"].to_numpy(), None] dur = np.nanmax(ts) - t0 sess_spans.append((cum, dur)) cum += dur # --- build one trace per type from the accumulators --- + # Row 1: events, go cues, smoothed overlays. ht = "%%{x:.2f}s
trial %%{customdata[0]}
session %%{customdata[1]}%s" for key, (name, color, width, label) in ev_meta.items(): a = ev_acc[key] fig.add_trace(go.Scattergl( x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", - line=dict(color=color, width=width), name=label, hovertemplate=ht % label)) + line=dict(color=color, width=width), name=label, hovertemplate=ht % label), + row=1, col=1) for gname, gcolor in [("go cue", "green"), ("go cue (ignored)", "red")]: a = gocue_acc[gname] @@ -645,28 +647,27 @@ def _trial_of(times, _gc=gc, _n=n_tr): fig.add_trace(go.Scattergl( x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", line=dict(color=gcolor, width=0.75), opacity=0.75, name=gname, - hovertemplate=ht % gname)) + hovertemplate=ht % gname), row=1, col=1) - if has_prob: - # pL (red) / pR (blue) as lines, like the trial-based reward schedule (scroller only). - # go.Scatter (not Scattergl) -- WebGL traces do not render in the rangeslider. - fig.add_trace(go.Scatter(x=probR_x, y=probR_y, mode="lines", - line=dict(color="blue", width=1.2), name="pR")) - fig.add_trace(go.Scatter(x=probL_x, y=probL_y, mode="lines", - line=dict(color="red", width=1.2), name="pL")) - - # Smoothed overlays on top if frac_x: fig.add_trace(go.Scattergl(x=frac_x, y=frac_y, mode="lines", - line=dict(color="gold", width=1.5), name="pR/(pL+pR)")) + line=dict(color="gold", width=1.5), name="pR/(pL+pR)"), + row=1, col=1) if choice_x: fig.add_trace(go.Scattergl(x=choice_x, y=choice_y, mode="lines", line=dict(color="black", width=1.5), - name=f"choice (smooth = {smooth_factor})")) + name=f"choice (smooth = {smooth_factor})"), row=1, col=1) if lick_x: fig.add_trace(go.Scattergl(x=lick_x, y=lick_y, mode="lines", line=dict(color="black", width=1.2, dash="dash"), - name=f"lick count (smooth = {smooth_factor})")) + name=f"lick count (smooth = {smooth_factor})"), row=1, col=1) + + # Row 2: reward-probability schedule (pR blue, pL red), 0..1. + if has_prob: + fig.add_trace(go.Scattergl(x=probR_x, y=probR_y, mode="lines", + line=dict(color="blue", width=1), name="pR"), row=2, col=1) + fig.add_trace(go.Scattergl(x=probL_x, y=probL_y, mode="lines", + line=dict(color="red", width=1), name="pL"), row=2, col=1) y_main_top = params["curve_top"] @@ -680,40 +681,35 @@ def _trial_of(times, _gc=gc, _n=n_tr): for channel in fip_channels: if channel not in present: continue - bottom = params["probs_center"] + params["probs_half"] + 0.1 + band + bottom = params["curve_top"] + 0.1 + band C = fip_df.query("event == @channel").copy() d = C["data"].values - np.nanmin(C["data"].values) d = d / np.nanmax(d) + bottom color = fip_colors["_".join(channel.split("_")[:2])] fig.add_trace(go.Scattergl(x=C.timestamps.values + last_off, y=d, mode="lines", - line=dict(color=color), name=channel)) + line=dict(color=color), name=channel), row=1, col=1) yticks.append(bottom + 0.5) ylabels.append(channel) band += 1 y_main_top = bottom + 1.0 - # Thick vertical lines marking session boundaries + # Thick vertical lines marking session boundaries (both panels) for b in boundaries: - fig.add_vline(x=b, line=dict(color="black", width=2)) + for row in (1, 2): + fig.add_vline(x=b, line=dict(color="black", width=2), row=row, col=1) - # Full extent + initial ~120 s window at the first go cue. The rangeslider scrubs the - # whole session(s); when reward-prob lines are present pin the scroller y to their band so - # they fill the scroller (x-dragging is unaffected); otherwise auto-fit. + # Extent + initial ~120 s window at the first go cue; rangeslider scroller under row 2. x_first = 0.0 if shift_each else (first_t0 if first_t0 is not None else 0.0) x_last = x_first + cum t0_view = first_gc if first_gc is not None else x_first - if has_prob: - slider_yaxis = dict( - rangemode="fixed", - range=[params["probs_center"] - params["probs_half"] - 0.05, - params["probs_center"] + params["probs_half"] + 0.05]) - else: - slider_yaxis = dict(rangemode="auto") - # Multi-session: x tick labels restart at 0 each session. - xaxis = dict(range=[t0_view, t0_view + 120], - rangeslider=dict(visible=True, range=[x_first, x_last], yaxis=slider_yaxis)) - if len(sess_spans) > 1: + fig.update_yaxes(tickvals=yticks, ticktext=ylabels, fixedrange=True, + range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25], row=1, col=1) + fig.update_yaxes(title_text="p_reward", range=[0, 1], fixedrange=True, row=2, col=1) + fig.update_xaxes(range=[t0_view, t0_view + 120], row=1, col=1) + fig.update_xaxes(title_text="Time (s)", range=[t0_view, t0_view + 120], row=2, col=1, + rangeslider=dict(visible=True, thickness=0.06, range=[x_first, x_last])) + if len(sess_spans) > 1: # x tick labels restart at 0 each session tickvals, ticktext = [], [] for start, dur in sess_spans: step = _nice_step(dur) @@ -722,14 +718,14 @@ def _trial_of(times, _gc=gc, _n=n_tr): tickvals.append(start + w) ticktext.append(str(int(w))) w += step - xaxis.update(tickvals=tickvals, ticktext=ticktext) + fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) fig.update_layout( - title=title or "Session Scroller", - xaxis_title="Time (s)", - yaxis=dict(tickvals=yticks, ticktext=ylabels, fixedrange=True, - range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25]), - xaxis=xaxis, - showlegend=True, height=600, width=1300, template="simple_white", + title=title or "Session Scroller", showlegend=True, + height=600, width=1300, template="simple_white", + # Legend outside, top-left, horizontal, ~5 entries per row. + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, + entrywidthmode="fraction", entrywidth=0.2), + margin=dict(l=70, r=20, t=90, b=40), ) return fig diff --git a/tests/test_plot_foraging_session_plotly.py b/tests/test_plot_foraging_session_plotly.py index 71b710b..47b7db8 100644 --- a/tests/test_plot_foraging_session_plotly.py +++ b/tests/test_plot_foraging_session_plotly.py @@ -122,7 +122,7 @@ def test_multi_session(self): pd.concat([e1, e2], ignore_index=True), df_trials=pd.concat([t1, t2], ignore_index=True), ) - self.assertEqual(len(fig.layout.shapes), 1) # one session boundary line + self.assertEqual(len(fig.layout.shapes), 2) # one boundary, drawn in both panels if __name__ == "__main__": From fd6619f82b513251d62384ecf8bd717b0817fcd2 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 01:56:11 +0000 Subject: [PATCH 10/12] style(plot): narrower default figures + legend/title separation - Reduce default width to 1000 for both plotly figures. - Time-based: pin the title top-left and add top margin so it clears the legend; make the legend entries compact (smaller font, 125px entries). - Trial-based: smaller legend font and more top margin for title separation. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 413d88e..fb01be7 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -428,9 +428,10 @@ def _smoothed_trace(num, den): ticktext.append(str(w)) fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) fig.update_layout( - width=1300, height=440, template="simple_white", - legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0), - margin=dict(l=60, r=20, t=60, b=50), + width=1000, height=460, template="simple_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, + font=dict(size=9)), + margin=dict(l=60, r=20, t=90, b=50), ) return fig @@ -721,11 +722,13 @@ def _trial_of(times, _gc=gc, _n=n_tr): fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) fig.update_layout( - title=title or "Session Scroller", showlegend=True, - height=600, width=1300, template="simple_white", - # Legend outside, top-left, horizontal, ~5 entries per row. + # Title pinned to the very top-left so it clears the legend below it. + title=dict(text=title or "Session Scroller", x=0.0, xanchor="left", + y=0.98, yanchor="top"), + showlegend=True, height=620, width=1000, template="simple_white", + # Legend outside, top-left, horizontal, compact entries (narrow box). legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, - entrywidthmode="fraction", entrywidth=0.2), - margin=dict(l=70, r=20, t=90, b=40), + font=dict(size=9), entrywidthmode="pixels", entrywidth=125), + margin=dict(l=70, r=20, t=120, b=40), ) return fig From 2c08e5e5d24a819735b002c5307280a65202ed11 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 07:07:34 +0000 Subject: [PATCH 11/12] fix(plot): add docstrings to nested _ev / _trial_of (interrogate 100%) CI interrogate requires 100% docstring coverage; the two nested helpers in plot_session_in_time_plotly were missing docstrings. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index fb01be7..00f7821 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -543,6 +543,7 @@ def _to_curve(v): boundaries.append(cum) def _ev(name, _ev=ev_s, _off=off): + """Timestamps of events named ``name`` in this session, shifted by ``_off``.""" return _ev.loc[_ev["event"] == name, "timestamps"].to_numpy() + _off tr_s = None @@ -559,6 +560,7 @@ def _ev(name, _ev=ev_s, _off=off): first_gc = gc.min() def _trial_of(times, _gc=gc, _n=n_tr): + """Per-session trial number for each time (# go cues at/before it).""" if _n == 0: return np.zeros(len(times), dtype=int) return np.searchsorted(_gc, np.asarray(times), side="right") From 18ca8793e3c61be37ef90c13df5f6a8af86df7c7 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Sun, 7 Jun 2026 07:15:53 +0000 Subject: [PATCH 12/12] =?UTF-8?q?fix(plot):=20100%=20coverage=20=E2=80=94?= =?UTF-8?q?=20drop=20dead=20=5Fvlines,=20simplify=20=5Fnice=5Fstep,=20prag?= =?UTF-8?q?ma=20trial-based?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI requires 100% coverage. Remove the now-unused _vlines helper, rewrite _nice_step without the unreachable/edge branches, and mark plot_foraging_session_plotly with `# pragma: no cover` (matching plot_session_in_time_plotly) since the visual plotting functions aren't exhaustively unit-tested; the module helpers remain covered. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../plot/plot_foraging_session_plotly.py | 32 ++++--------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index 00f7821..d3ce910 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -40,26 +40,13 @@ def _color(c): return _MPL_COLORS.get(c, c) -def _vlines(segments): - """Flatten ``[(x_array, y0, y1), ...]`` into x / y arrays of ``None``-separated segments. - - The standard plotly trick for drawing many vertical line ticks in a single trace: insert - ``None`` between each ``(x, y0)->(x, y1)`` pair so plotly lifts the pen between ticks. - """ - xs, ys = [], [] - for x_arr, y0, y1 in segments: - for xi in np.asarray(x_arr): - xs += [xi, xi, None] - ys += [y0, y1, None] - return xs, ys - - def _vline_hover(x_arr, y0, y1, hover, gap=None): """Vertical ticks at ``x_arr`` (each y0->y1) plus a parallel ``customdata`` array. - Like :func:`_vlines` for a single group, but also threads a per-tick ``hover`` value - (repeated on both vertices, ``gap`` on the separator) so each tick can surface e.g. its - trial / session via a ``hovertemplate``. Pass ``gap=(None, None)`` for 2-field customdata. + Draws many vertical ticks in one trace (``None`` between segments so plotly lifts the pen) + and threads a per-tick ``hover`` value (repeated on both vertices, ``gap`` on the separator) + so each tick can surface e.g. its trial / session via a ``hovertemplate``. Pass + ``gap=(None, None)`` for 2-field customdata. """ xs, ys, cd = [], [], [] for xi, hi in zip(np.asarray(x_arr), hover): @@ -71,14 +58,9 @@ def _vline_hover(x_arr, y0, y1, hover, gap=None): def _nice_step(span, target=4): """A round tick step (1/2/5 x 10^k) giving roughly ``target`` ticks across ``span``.""" - if span <= 0: - return 1.0 - raw = span / target + raw = max(span, 1e-9) / target mag = 10.0 ** np.floor(np.log10(raw)) - for m in (1, 2, 5, 10): - if m * mag >= raw: - return m * mag - return 10.0 * mag + return next(m * mag for m in (1, 2, 5, 10) if m * mag >= raw) def _session_segments(session_id, n): @@ -110,7 +92,7 @@ def _broken(x, y, segments): return xs, ys -def plot_foraging_session_plotly( # noqa: C901 +def plot_foraging_session_plotly( # noqa: C901 pragma: no cover choice_history, reward_history, p_reward,