diff --git a/src/aind_dynamic_foraging_basic_analysis/__init__.py b/src/aind_dynamic_foraging_basic_analysis/__init__.py index d420ce8..2842d4d 100644 --- a/src/aind_dynamic_foraging_basic_analysis/__init__.py +++ b/src/aind_dynamic_foraging_basic_analysis/__init__.py @@ -4,3 +4,7 @@ from .metrics.foraging_efficiency import compute_foraging_efficiency # noqa: F401 from .plot.plot_foraging_session import plot_foraging_session # noqa: F401 +from .plot.plot_foraging_session_plotly import ( # noqa: F401 + plot_foraging_session_plotly, + plot_session_in_time_plotly, +) diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py index e5efc99..d3ce910 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_foraging_session_plotly.py @@ -1,199 +1,718 @@ """Interactive plotly figures that can be used in the Streamlit app as well as -Jupyter Notebook +Jupyter Notebook. + +Two plotly counterparts of the matplotlib plotting functions are provided, each meant to +match its matplotlib sibling as closely as plotly allows: + +- :func:`plot_foraging_session_plotly` mirrors + :func:`plot_foraging_session.plot_foraging_session` -- a **trial-based** view (choice / + reward raster on top, reward-probability schedule below). +- :func:`plot_session_in_time_plotly` mirrors + :func:`plot_session_scroller.plot_session_scroller` -- a **time-based** view of the + session (licks / rewards / go cues / reward-probability band in real time). """ +import numpy as np import plotly.graph_objects as go +from plotly.subplots import make_subplots +from aind_dynamic_foraging_basic_analysis.data_model.foraging_session import ( + ForagingSessionData, + PhotostimData, +) +from aind_dynamic_foraging_basic_analysis.plot.plot_foraging_session import moving_average +from aind_dynamic_foraging_basic_analysis.plot.style import PHOTOSTIM_EPOCH_MAPPING -def plot_session_in_time_plotly( # noqa: C901 pragma: no cover - df_events, adjust_time=True, fip_df=None +# Map the matplotlib single-letter colors used by the matplotlib versions to plotly names, +# so the two renderings line up. Anything not listed is passed through unchanged. +_MPL_COLORS = { + "y": "gold", + "m": "magenta", + "g": "green", + "b": "blue", + "r": "red", + "k": "black", +} + + +def _color(c): + """Translate a matplotlib single-letter color to a plotly-friendly name.""" + return _MPL_COLORS.get(c, c) + + +def _vline_hover(x_arr, y0, y1, hover, gap=None): + """Vertical ticks at ``x_arr`` (each y0->y1) plus a parallel ``customdata`` array. + + Draws many vertical ticks in one trace (``None`` between segments so plotly lifts the pen) + and threads a per-tick ``hover`` value (repeated on both vertices, ``gap`` on the separator) + so each tick can surface e.g. its trial / session via a ``hovertemplate``. Pass + ``gap=(None, None)`` for 2-field customdata. + """ + xs, ys, cd = [], [], [] + for xi, hi in zip(np.asarray(x_arr), hover): + xs += [xi, xi, None] + ys += [y0, y1, None] + cd += [hi, hi, gap] + return xs, ys, cd + + +def _nice_step(span, target=4): + """A round tick step (1/2/5 x 10^k) giving roughly ``target`` ticks across ``span``.""" + raw = max(span, 1e-9) / target + mag = 10.0 ** np.floor(np.log10(raw)) + return next(m * mag for m in (1, 2, 5, 10) if m * mag >= raw) + + +def _session_segments(session_id, n): + """Contiguous per-session index segments and the boundary indices between them. + + Returns ``(segments, boundaries)`` where ``segments`` is a list of ``(start, end)`` + half-open index ranges (one per session, in order) and ``boundaries`` is the list of + indices at which a new session starts (used to draw the dividing lines). With + ``session_id=None`` the whole input is one segment. + """ + if session_id is None: + return [(0, n)], [] + sid = np.asarray(session_id) + change = list(np.nonzero(sid[1:] != sid[:-1])[0] + 1) + segments = list(zip([0, *change], [*change, n])) + return segments, change + + +def _broken(x, y, segments): + """Concatenate per-session slices of ``x``/``y`` with ``None`` gaps between sessions. + + Inserting a ``None`` at each boundary breaks the line so a continuous trace is not drawn + across the (meaningless) jump from one session's last trial to the next session's first. + """ + xs, ys = [], [] + for s, e in segments: + xs += [*np.asarray(x)[s:e], None] + ys += [*np.asarray(y)[s:e], None] + return xs, ys + + +def plot_foraging_session_plotly( # noqa: C901 pragma: no cover + choice_history, + reward_history, + p_reward, + autowater_offered=None, + fitted_data=None, + photostim=None, + valid_range=None, + smooth_factor=5, + base_color="y", + bias=None, + bias_lower=None, + bias_upper=None, + plot_list=["choice", "finished", "reward_prob"], + session_id=None, ): - """A plotly version of plot_foraging_session.plot_session_scroller + """Plotly version of :func:`plot_foraging_session.plot_foraging_session` (trial-based). - Creates a plot of the session in time (not in trial). - Plots left/right licks/rewards, and go cues as vertical lines from bottom to top. + Renders the same two stacked panels as the matplotlib version: - df_events: A tidy dataframe of session events. + - top: choice / reward raster (rewarded & unrewarded choices, ignored / autowater + trials, smoothed choice, finished ratio, base reward probability, optional bias). + - bottom: the per-trial left / right reward-probability schedule. - fip_df is a tidy dataframe of FIP measurements generated by - aind_dynamic_foraging_data_utils.nwb_utils.create_df_fip(tidy=True) + Parameters mirror :func:`plot_foraging_session.plot_foraging_session` (minus the + matplotlib-only ``ax`` / ``vertical``): - adjust_time (bool): If True, resets time=0 to the first event of the session. + Parameters + ---------- + choice_history : list or np.ndarray + Choice history (0 = left, 1 = right, np.nan = ignored). + reward_history : list or np.ndarray + Reward history (0 = unrewarded, 1 = rewarded). + p_reward : list or np.ndarray + Reward probability for both sides, shape (2, len(choice_history)). + autowater_offered : list or np.ndarray, optional + Boolean mask of trials where autowater was offered. + fitted_data : list or np.ndarray, optional + If not None, overlay fitted data (e.g. from an RL model). + photostim : dict, optional + Photostimulation trials, with keys "trial", "power" and optional "stim_epoch". + valid_range : list, optional + If not None, add two vertical lines marking the engaged range. + smooth_factor : int, optional + Smoothing window for the choice / finished traces, by default 5. + base_color : str, optional + Color for the base reward-probability line, by default "y" (gold). + bias : list or np.ndarray, optional + Side-bias trace; drawn (with the ``bias_lower`` / ``bias_upper`` band) when + "bias" is in ``plot_list``. + bias_lower, bias_upper : list or np.ndarray, optional + Lower / upper confidence bounds for ``bias``. + plot_list : list, optional + Which optional traces to draw, from {"choice", "finished", "reward_prob", "bias"}. + session_id : list or np.ndarray, optional + Per-trial session label (same length as ``choice_history``). When given, multiple + sessions are concatenated horizontally along the trial axis in order; smoothed traces + reset at each session and a thick vertical line marks every session boundary. - EXAMPLE: - df_events = nwb_utils.create_df_events(nwb_object) - fip_df = nwb_utils.create_df_fip(nwb_object, tidy=True) - plot_foraging_session_plotly.plot_session_events_plotly(df_events) + Returns + ------- + plotly.graph_objects.Figure + Figure with the choice/reward panel (row 1) over the reward-schedule panel (row 2). """ + # Formatting and sanity checks (reuse the shared validation, like the matplotlib version) + data = ForagingSessionData( + choice_history=choice_history, + reward_history=reward_history, + p_reward=p_reward, + autowater_offered=autowater_offered, + fitted_data=fitted_data, + photostim=PhotostimData(**photostim) if photostim is not None else None, + ) + choice_history = data.choice_history + reward_history = data.reward_history + p_reward = data.p_reward + autowater_offered = data.autowater_offered + fitted_data = data.fitted_data + photostim = data.photostim - if adjust_time: - start_time = df_events.iloc[0]["timestamps"] - df_events = df_events.copy() - df_events["timestamps"] = df_events["timestamps"] - start_time + n_trials = len(choice_history) + p_reward_fraction = p_reward[1, :] / (np.sum(p_reward, axis=0)) + ignored = np.isnan(choice_history) - if fip_df is not None: - fip_df = fip_df.copy() - fip_df["timestamps"] = fip_df["timestamps"] - start_time + if autowater_offered is None: + rewarded_excluding_autowater = reward_history + autowater_collected = np.full_like(choice_history, False, dtype=bool) + autowater_ignored = np.full_like(choice_history, False, dtype=bool) + unrewarded_trials = ~reward_history & ~ignored + else: + rewarded_excluding_autowater = reward_history & ~autowater_offered + autowater_collected = autowater_offered & ~ignored + autowater_ignored = autowater_offered & ignored + unrewarded_trials = ~reward_history & ~ignored & ~autowater_offered - xmin = df_events.iloc[0]["timestamps"] - xmax = df_events.iloc[-1]["timestamps"] + # Per-session segments (multi-session concatenation); single segment when session_id None. + segments, boundaries = _session_segments(session_id, n_trials) - params = { - "left_lick": 0.125, - "right_lick": 0.875, - "left_reward": 0.375, - "right_reward": 0.625, - "go_cue_bottom": 0, - "go_cue_top": 1, - "G_1_preprocessed_bottom": 1, - "G_1_preprocessed_top": 2, - "G_2_preprocessed_bottom": 2, - "G_2_preprocessed_top": 3, - "R_1_preprocessed_bottom": 3, - "R_1_preprocessed_top": 4, - "R_2_preprocessed_bottom": 4, - "R_2_preprocessed_top": 5, - } + # Per-trial within-session index (resets to 0 each session) and session label -- used for + # the (session, trial) hover and the per-session x tick labels. + within = np.zeros(n_trials, dtype=int) + sess_label = np.array([""] * n_trials, dtype=object) + sid_arr = None if session_id is None else np.asarray(session_id) + for s, e in segments: + within[s:e] = np.arange(e - s) + if sid_arr is not None: + sess_label[s:e] = sid_arr[s] - yticks = [ - params["left_lick"], - params["right_lick"], - params["left_reward"], - params["right_reward"], - ] - ylabels = ["left licks", "right licks", "left reward", "right reward"] - ycolors = ["k", "k", "r", "r"] + hovertemplate = "trial %%{customdata[0]}
session %%{customdata[1]}%s" - fig = go.Figure() + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + row_heights=[0.83, 0.17], + vertical_spacing=0.02, + ) - # Add FIP traces - if fip_df is not None: - fip_channels = [ - "G_2_preprocessed", - "G_1_preprocessed", - "R_2_preprocessed", - "R_1_preprocessed", - ] - present_channels = fip_df["event"].unique() - for index, channel in enumerate(fip_channels): - if channel in present_channels: - yticks.append( - (params[channel + "_top"] - params[channel + "_bottom"]) / 2 - + params[channel + "_bottom"] - ) - ylabels.append(channel) - if "G_1" in channel: - color = "green" - elif "G_2" in channel: - color = "darkgreen" - elif "R_1" in channel: - color = "red" - elif "R_2" in channel: - color = "darkred" - ycolors.append(color) - C = fip_df.query("event == @channel").copy() - C["data"] = C["data"] - C["data"].min() - C["data"] = C["data"].values / C["data"].max() - C["data"] += params[channel + "_bottom"] - - # Plot the data using go.Scattergl - fig.add_trace( - go.Scattergl( - x=C.timestamps.values, - y=C.data.values, - mode="lines", - line=dict(color=color), - name=channel, - ) - ) - - # Add a horizontal reference line (axhline equivalent) - fig.add_trace( - go.Scattergl( - x=[C.timestamps.min(), C.timestamps.max()], - y=[params[channel + "_bottom"], params[channel + "_bottom"]], - mode="lines", - line=dict(color="black", width=1, dash="solid"), - showlegend=False, - hoverinfo="skip", - ) - ) - - left_licks = df_events.query('event == "left_lick_time"') - left_times = left_licks.timestamps.values + def _raster(mask, up, down): + """Vertical choice ticks for ``mask``, each carrying (within-session trial, session). + + Right choices (>0.5) use the ``up`` (y0, y1); left choices use ``down``. Returns + ``x, y, customdata`` lists (``None`` gaps between ticks) for one Scattergl trace. + """ + idx = np.nonzero(mask)[0] + side = choice_history[idx] + xs, ys, cd = [], [], [] + for i, sd in zip(idx, side): + y0, y1 = up if sd > 0.5 else down + xs += [i + 1, i + 1, None] + ys += [y0, y1, None] + cd += [(int(within[i]), sess_label[i])] * 2 + [(None, None)] + return xs, ys, cd + + def _markers(mask): + """(x, customdata) for marker traces (ignored / autowater-ignored), with hover.""" + idx = np.nonzero(mask)[0] + return idx + 1, [(int(within[i]), sess_label[i]) for i in idx] + + # == Choice trace == + # Rewarded (real foraging, autowater excluded): tall black ticks just outside [0, 1] + xs, ys, cd = _raster(rewarded_excluding_autowater, (1.05, 1.15), (-0.15, -0.05)) fig.add_trace( - go.Scattergl( - x=left_times, - y=[params["left_lick"]] * len(left_times), - mode="markers", - marker=dict(symbol="line-ns", line_color="black", size=10, line_width=2), - name="Left Lick", - ) + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", line=dict(color="black", width=1), + name="Rewarded choices", hovertemplate=hovertemplate % "Rewarded choices"), + row=1, col=1, ) - right_licks = df_events.query('event == "right_lick_time"') - right_times = right_licks.timestamps.values + # Unrewarded (real foraging): short gray ticks + xs, ys, cd = _raster(unrewarded_trials, (1.05, 1.10), (-0.10, -0.05)) fig.add_trace( - go.Scattergl( - x=right_times, - y=[params["right_lick"]] * len(right_times), - mode="markers", - marker=dict(symbol="line-ns", line_color="black", size=10, line_width=2), - name="Right Lick", - ) + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", line=dict(color="gray", width=1), + name="Unrewarded choices", hovertemplate=hovertemplate % "Unrewarded choices"), + row=1, col=1, ) - left_reward_deliverys = df_events.query('event == "left_reward_delivery_time"') - left_times = left_reward_deliverys.timestamps.values + # Ignored trials: red x at the top + xx, cd = _markers(ignored & ~autowater_ignored) fig.add_trace( - go.Scattergl( - x=left_times, - y=[params["left_reward"]] * len(left_times), - mode="markers", - marker=dict(symbol="line-ns", size=10, line_color="red", line_width=3), - name="Left Reward", - ) + go.Scattergl(x=xx, y=[1.2] * len(xx), customdata=cd, mode="markers", + marker=dict(symbol="x", color="red", size=4), name="Ignored", + hovertemplate=hovertemplate % "Ignored"), + row=1, col=1, ) - right_reward_deliverys = df_events.query('event == "right_reward_delivery_time"') - right_times = right_reward_deliverys.timestamps.values - fig.add_trace( - go.Scattergl( - x=right_times, - y=[params["right_reward"]] * len(right_times), - mode="markers", - marker=dict(symbol="line-ns", size=10, line_color="red", line_width=3), - name="Right Reward", + # Autowater collected / ignored + if autowater_offered is not None: + xs, ys, cd = _raster(autowater_collected, (1.05, 1.15), (-0.15, -0.05)) + fig.add_trace( + go.Scattergl(x=xs, y=ys, customdata=cd, mode="lines", + line=dict(color="royalblue", width=1), name="Autowater collected", + hovertemplate=hovertemplate % "Autowater collected"), + row=1, col=1, + ) + xx, cd = _markers(autowater_ignored) + fig.add_trace( + go.Scattergl(x=xx, y=[1.2] * len(xx), customdata=cd, mode="markers", + marker=dict(symbol="x", color="royalblue", size=4), + name="Autowater ignored", + hovertemplate=hovertemplate % "Autowater ignored"), + row=1, col=1, + ) + + # Base reward probability (broken at session boundaries) + if "reward_prob" in plot_list: + xs, ys = _broken(np.arange(n_trials) + 1, p_reward_fraction, segments) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", + line=dict(color=_color(base_color), width=1.5), name="Base rew. prob."), + row=1, col=1, + ) + + def _smoothed_trace(num, den): + """Per-session smoothed series (resets at each session); x in 1-based trial coords.""" + xs, ys = [], [] + for s, e in segments: + y = moving_average(num[s:e], smooth_factor) + if den is not None: + y = y / (moving_average(den[s:e], smooth_factor) + 1e-6) + x = s + np.arange(len(y)) + int(smooth_factor / 2) + 1 + xs += [*x, None] + ys += [*y, None] + return xs, np.where(np.array(ys, dtype=float) > 100, np.nan, ys) + + # Smoothed choice history + if "choice" in plot_list: + xs, ys = _smoothed_trace(choice_history, ~np.isnan(choice_history)) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="black", width=1.5), + name=f"Choice (smooth = {smooth_factor})"), + row=1, col=1, + ) + + # Finished ratio (only meaningful if there are ignored trials) + if "finished" in plot_list and np.sum(np.isnan(choice_history)): + xs, ys = _smoothed_trace(~np.isnan(choice_history), None) + fig.add_trace( + go.Scattergl(x=xs, y=ys, mode="lines", line=dict(color="magenta", width=0.8), + name=f"Finished (smooth = {smooth_factor})"), + row=1, col=1, + ) + + # Bias trace + confidence band (broken at session boundaries) + if ("bias" in plot_list) and (bias is not None): + xx = np.arange(n_trials) + 1 + bias = (np.array(bias) + 1) / 2 + bias_lower = np.clip((np.array(bias_lower) + 1) / 2, 0, None) + bias_upper = np.clip((np.array(bias_upper) + 1) / 2, None, 1) + xb_up, y_up = _broken(xx, bias_upper, segments) + xb_lo, y_lo = _broken(xx, bias_lower, segments) + xb, y_bias = _broken(xx, bias, segments) + # go.Scatter (not Scattergl) for the filled band -- Scattergl ignores fill. + fig.add_trace( + go.Scatter(x=xb_up, y=y_up, mode="lines", line=dict(width=0), + showlegend=False, hoverinfo="skip"), + row=1, col=1, + ) + fig.add_trace( + go.Scatter(x=xb_lo, y=y_lo, mode="lines", line=dict(width=0), + fill="tonexty", fillcolor="rgba(0,128,0,0.25)", + showlegend=False, hoverinfo="skip"), + row=1, col=1, + ) + fig.add_trace( + go.Scattergl(x=xb, y=y_bias, mode="lines", line=dict(color="green", width=1.5), + name="bias"), + row=1, col=1, ) - ) - go_cues = df_events.query('event == "goCue_start_time"') - go_cue_times = go_cues.timestamps.values - for n, time in enumerate(go_cue_times): + # Valid (engaged) range + if valid_range is not None: + for vr in valid_range: + fig.add_vline(x=vr, line=dict(color="magenta", dash="dash", width=1), row=1, col=1) + + # Fitted model overlay + if fitted_data is not None: + fig.add_trace( + go.Scattergl(x=np.arange(n_trials), y=fitted_data, mode="lines", + line=dict(width=1.5), name="model"), + row=1, col=1, + ) + + # Photostim markers + if photostim is not None: + trial = np.asarray(photostim.trial) + power = np.asarray(photostim.power, dtype=float) + if photostim.stim_epoch is not None: + colors = [PHOTOSTIM_EPOCH_MAPPING[t] for t in photostim.stim_epoch] + else: + colors = "darkcyan" fig.add_trace( - go.Scattergl( - x=[time, time], - y=[params["go_cue_bottom"], params["go_cue_top"]], - mode="lines", - line=dict(color="blue", width=0.3), - legendgroup="Go Cue group", - showlegend=(n == 0), - name="Go Cue", - hovertemplate=f"Go Cue, Trial {n+1}", - ) + go.Scattergl(x=trial, y=np.ones_like(trial, dtype=float) + 0.4, mode="markers", + marker=dict(symbol="triangle-down", size=power * 2, + color="rgba(0,0,0,0)", + line=dict(color=colors, width=0.5)), + name="photostim"), + row=1, col=1, ) + # == Reward schedule (bottom panel; broken at session boundaries) == + xx = np.arange(n_trials) + 1 + xr, y_pr = _broken(xx, p_reward[1, :], segments) + fig.add_trace( + go.Scattergl(x=xr, y=y_pr, mode="lines", line=dict(color="blue", width=1), + name="p_right"), + row=2, col=1, + ) + xl, y_pl = _broken(xx, p_reward[0, :], segments) + fig.add_trace( + go.Scattergl(x=xl, y=y_pl, mode="lines", line=dict(color="red", width=1), + name="p_left"), + row=2, col=1, + ) + + # Thick vertical lines marking session boundaries (between trials b and b+1) + for b in boundaries: + for row in (1, 2): + fig.add_vline(x=b + 0.5, line=dict(color="black", width=2), row=row, col=1) + + # Axes styling to match the matplotlib version + fig.update_yaxes( + tickvals=[0, 1, 1.2], ticktext=["Left", "Right", "Ignored"], + range=[-0.15, 1.25], fixedrange=True, row=1, col=1, + ) + fig.update_yaxes(title_text="p_reward", range=[0, 1], fixedrange=True, row=2, col=1) + # Bottom x-axis: a rangeslider scroller (drag to pan/zoom), and -- for multiple sessions -- + # tick labels that restart at 0 each session. + fig.update_xaxes(title_text="Trial number", row=2, col=1, + rangeslider=dict(visible=True, thickness=0.08)) + if len(segments) > 1: + step = 250 + tickvals, ticktext = [], [] + for s, e in segments: + for w in range(0, e - s, step): + tickvals.append(s + 1 + w) + ticktext.append(str(w)) + fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) fig.update_layout( - title="Session Scroller", - xaxis_title="Time (s)", - yaxis=dict( - tickvals=yticks, - ticktext=ylabels, - ), - xaxis=dict(range=[xmin, xmax]), - showlegend=True, - height=800, - width=1300, + width=1000, height=460, template="simple_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, + font=dict(size=9)), + margin=dict(l=60, r=20, t=90, b=50), ) + return fig + + +def plot_session_in_time_plotly( # noqa: C901 pragma: no cover + df_events, df_trials=None, fip_df=None, adjust_time=True, title=None, smooth_factor=5 +): + """Plotly version of :func:`plot_session_scroller.plot_session_scroller` (time-based). + + Plots the session in real time (not in trial): left / right licks and rewards as ticks, + go cues as vertical lines (red for ignored trials), smoothed overlays above the events, + and -- when ``df_trials`` is supplied -- the left / right reward-probability band in the + rangeslider "scroller" below. + + Multiple sessions: if ``df_events`` has a ``session_id`` column with more than one + session, the sessions are concatenated end-to-end along time in order of appearance + (each restarted at the running offset); a thick vertical line marks each boundary and + per-trial / smoothed quantities reset per session. ``df_trials`` is matched per session + by its own ``session_id`` column when present. + + Parameters + ---------- + df_events : pandas.DataFrame + Tidy dataframe of session events (``event`` + ``timestamps``; optional ``session_id``). + Recognised events: ``left_lick_time``, ``right_lick_time``, ``left_reward_delivery_time``, + ``right_reward_delivery_time`` and ``goCue_start_time``. + df_trials : pandas.DataFrame, optional + Per-trial dataframe for the reward-probability band / overlays / red ignored go cues + (and a fallback source of go-cue times). Uses ``goCue_start_time``, + ``reward_probabilityL/R`` and ``animal_response``; go-cue times must share the + ``df_events`` time base. Matched per session via ``session_id`` when present. + fip_df : pandas.DataFrame, optional + Tidy FIP measurements (single-session only); each present channel is normalised and + stacked above the behavior panel. + adjust_time : bool, optional + If True (default), shift time so the first event is at t = 0 (always shifted when + concatenating multiple sessions). + title : str, optional + Figure title. + smooth_factor : int, optional + Smoothing window for the choice / lick-count overlays, by default 5. + + Returns + ------- + plotly.graph_objects.Figure + """ + df_events = df_events.copy() + if df_trials is not None: + df_trials = df_trials.copy() + if fip_df is not None: + fip_df = fip_df.copy() + + # y-layout, bottom -> top: + # * event rows in [0, 1]: rewards at the outer edges, licks inside (right pair near the + # top, left pair near the bottom), like the trial-based figure. + # * smoothed overlays in their own band [curve_bottom, curve_top] above the events. + # * the reward-probability band sits higher still -- it only shows in the rangeslider. + params = { + "behavior_bottom": 0.0, "behavior_top": 1.0, # event rows (row 1) + "curve_bottom": 1.1, "curve_top": 2.1, # smoothed overlays (row 1) + } + row_centers = {"right_reward": 0.92, "right_lick": 0.78, + "left_lick": 0.22, "left_reward": 0.08} + tick_half = 0.25 * 0.30 / 2.0 + + def _to_curve(v): + """Map a 0..1 per-trial value into the smoothed-overlay band above the events.""" + span = params["curve_top"] - params["curve_bottom"] + return params["curve_bottom"] + np.asarray(v, dtype=float) * span + + yticks = [0.92, 0.78, 0.22, 0.08, + params["curve_bottom"], (params["curve_bottom"] + params["curve_top"]) / 2, + params["curve_top"]] + ylabels = ["right reward", "right lick", "left lick", "left reward", "0", "0.5", "1"] + # Sessions in order of appearance; concatenate end-to-end along time when more than one. + has_sess = "session_id" in df_events.columns + sessions = (list(dict.fromkeys(df_events["session_id"].tolist())) if has_sess else [None]) + shift_each = adjust_time or len(sessions) > 1 + + # Same two-panel layout as the trial-based figure: the raster/curves on top (row 1) over a + # reward-schedule panel (row 2), with a rangeslider scroller under row 2. + fig = make_subplots(rows=2, cols=1, shared_xaxes=True, + row_heights=[0.85, 0.15], vertical_spacing=0.04) + + ev_meta = { + "left_lick": ("left_lick_time", "gray", 1.5, "left lick"), + "right_lick": ("right_lick_time", "gray", 1.5, "right lick"), + "left_reward": ("left_reward_delivery_time", "black", 2, "left reward"), + "right_reward": ("right_reward_delivery_time", "black", 2, "right reward"), + } + ev_acc = {k: {"x": [], "y": [], "cd": []} for k in ev_meta} + gocue_acc = {"go cue": {"x": [], "y": [], "cd": []}, + "go cue (ignored)": {"x": [], "y": [], "cd": []}} + frac_x, frac_y, choice_x, choice_y, lick_x, lick_y = [], [], [], [], [], [] + probL_x, probL_y, probR_x, probR_y = [], [], [], [] # reward-prob lines (scroller) + boundaries, sess_spans, has_prob = [], [], False + cum, first_t0, first_gc, last_off = 0.0, None, None, 0.0 + + for si, sess in enumerate(sessions): + ev_s = df_events if sess is None else df_events[df_events["session_id"] == sess] + ts = ev_s["timestamps"].to_numpy() + t0 = np.nanmin(ts) + if first_t0 is None: + first_t0 = t0 + off = (cum - t0) if shift_each else 0.0 + last_off = off + if si > 0: + boundaries.append(cum) + + def _ev(name, _ev=ev_s, _off=off): + """Timestamps of events named ``name`` in this session, shifted by ``_off``.""" + return _ev.loc[_ev["event"] == name, "timestamps"].to_numpy() + _off + + tr_s = None + if df_trials is not None: + tr_s = (df_trials[df_trials["session_id"] == sess] + if (sess is not None and "session_id" in df_trials.columns) else df_trials) + + gc = _ev("goCue_start_time") + if len(gc) == 0 and tr_s is not None and "goCue_start_time" in tr_s.columns: + gc = tr_s["goCue_start_time"].to_numpy() + off + gc = gc[~np.isnan(gc)] + n_tr = len(gc) + if n_tr and first_gc is None: + first_gc = gc.min() + + def _trial_of(times, _gc=gc, _n=n_tr): + """Per-session trial number for each time (# go cues at/before it).""" + if _n == 0: + return np.zeros(len(times), dtype=int) + return np.searchsorted(_gc, np.asarray(times), side="right") + + aligned = tr_s is not None and len(tr_s) == n_tr and n_tr > 0 + choice = None + if aligned and "animal_response" in tr_s.columns: + choice = tr_s["animal_response"].astype(float).to_numpy().copy() + choice[choice == 2] = np.nan + + sess_disp = "" if sess is None else sess # shown in (trial, session) hover + + for key, (name, _color, _width, _label) in ev_meta.items(): + c = row_centers[key] + t = _ev(name) + hov = [(int(tr), sess_disp) for tr in _trial_of(t)] + xs, ys, cd = _vline_hover(t, c - tick_half, c + tick_half, hov, gap=(None, None)) + ev_acc[key]["x"] += xs + ev_acc[key]["y"] += ys + ev_acc[key]["cd"] += cd + + if n_tr: + trial_no = np.arange(1, n_tr + 1) + ign = np.isnan(choice) if choice is not None else np.zeros(n_tr, dtype=bool) + for gname, mask in [("go cue", ~ign), ("go cue (ignored)", ign)]: + if mask.any(): + hov = [(int(tr), sess_disp) for tr in trial_no[mask]] + xs, ys, cd = _vline_hover(gc[mask], params["behavior_bottom"], + params["behavior_top"], hov, gap=(None, None)) + gocue_acc[gname]["x"] += xs + gocue_acc[gname]["y"] += ys + gocue_acc[gname]["cd"] += cd + + if n_tr: + off_s = smooth_factor // 2 + if aligned and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns): + pL = tr_s["reward_probabilityL"].to_numpy() + pR = tr_s["reward_probabilityR"].to_numpy() + frac = np.divide(pR, pL + pR, out=np.full(n_tr, np.nan), where=(pL + pR) > 0) + frac_x += [*gc, None] + frac_y += [*_to_curve(frac), None] + if choice is not None: + sm = moving_average(choice, smooth_factor) / ( + moving_average(~np.isnan(choice), smooth_factor) + 1e-6) + sm[sm > 100] = np.nan + xsm = gc[off_s: off_s + len(sm)] + choice_x += [*xsm, None] + choice_y += [*_to_curve(sm[: len(xsm)]), None] + lt = np.concatenate([_ev("left_lick_time"), _ev("right_lick_time")]) + if len(lt): + counts = np.bincount(_trial_of(lt), minlength=n_tr + 1)[1:n_tr + 1] + sm = moving_average(counts.astype(float), smooth_factor) + top = np.nanmax(sm) if len(sm) else 0 + if top > 0: + sm = sm / top + xsm = gc[off_s: off_s + len(sm)] + lick_x += [*xsm, None] + lick_y += [*_to_curve(sm[: len(xsm)]), None] + + # Reward-probability schedule (pL red, pR blue), drawn as 0..1 lines in the bottom + # panel just like the trial-based figure; broken at session boundaries. + if (tr_s is not None and n_tr and len(tr_s) == n_tr + and {"reward_probabilityL", "reward_probabilityR"} <= set(tr_s.columns)): + has_prob = True + probL_x += [*gc, None] + probL_y += [*tr_s["reward_probabilityL"].to_numpy(), None] + probR_x += [*gc, None] + probR_y += [*tr_s["reward_probabilityR"].to_numpy(), None] + + dur = np.nanmax(ts) - t0 + sess_spans.append((cum, dur)) + cum += dur + + # --- build one trace per type from the accumulators --- + # Row 1: events, go cues, smoothed overlays. + ht = "%%{x:.2f}s
trial %%{customdata[0]}
session %%{customdata[1]}%s" + for key, (name, color, width, label) in ev_meta.items(): + a = ev_acc[key] + fig.add_trace(go.Scattergl( + x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", + line=dict(color=color, width=width), name=label, hovertemplate=ht % label), + row=1, col=1) + + for gname, gcolor in [("go cue", "green"), ("go cue (ignored)", "red")]: + a = gocue_acc[gname] + if a["x"]: + fig.add_trace(go.Scattergl( + x=a["x"], y=a["y"], customdata=a["cd"], mode="lines", + line=dict(color=gcolor, width=0.75), opacity=0.75, name=gname, + hovertemplate=ht % gname), row=1, col=1) + + if frac_x: + fig.add_trace(go.Scattergl(x=frac_x, y=frac_y, mode="lines", + line=dict(color="gold", width=1.5), name="pR/(pL+pR)"), + row=1, col=1) + if choice_x: + fig.add_trace(go.Scattergl(x=choice_x, y=choice_y, mode="lines", + line=dict(color="black", width=1.5), + name=f"choice (smooth = {smooth_factor})"), row=1, col=1) + if lick_x: + fig.add_trace(go.Scattergl(x=lick_x, y=lick_y, mode="lines", + line=dict(color="black", width=1.2, dash="dash"), + name=f"lick count (smooth = {smooth_factor})"), row=1, col=1) + + # Row 2: reward-probability schedule (pR blue, pL red), 0..1. + if has_prob: + fig.add_trace(go.Scattergl(x=probR_x, y=probR_y, mode="lines", + line=dict(color="blue", width=1), name="pR"), row=2, col=1) + fig.add_trace(go.Scattergl(x=probL_x, y=probL_y, mode="lines", + line=dict(color="red", width=1), name="pL"), row=2, col=1) + + y_main_top = params["curve_top"] + + # FIP channels (single-session only), normalised and stacked above the behavior panel + if fip_df is not None and len(sessions) == 1: + fip_channels = ["G_1_preprocessed", "G_2_preprocessed", + "R_1_preprocessed", "R_2_preprocessed"] + fip_colors = {"G_1": "green", "G_2": "darkgreen", "R_1": "red", "R_2": "darkred"} + present = set(fip_df["event"].unique()) + band = 0 + for channel in fip_channels: + if channel not in present: + continue + bottom = params["curve_top"] + 0.1 + band + C = fip_df.query("event == @channel").copy() + d = C["data"].values - np.nanmin(C["data"].values) + d = d / np.nanmax(d) + bottom + color = fip_colors["_".join(channel.split("_")[:2])] + fig.add_trace(go.Scattergl(x=C.timestamps.values + last_off, y=d, mode="lines", + line=dict(color=color), name=channel), row=1, col=1) + yticks.append(bottom + 0.5) + ylabels.append(channel) + band += 1 + y_main_top = bottom + 1.0 + + # Thick vertical lines marking session boundaries (both panels) + for b in boundaries: + for row in (1, 2): + fig.add_vline(x=b, line=dict(color="black", width=2), row=row, col=1) + + # Extent + initial ~120 s window at the first go cue; rangeslider scroller under row 2. + x_first = 0.0 if shift_each else (first_t0 if first_t0 is not None else 0.0) + x_last = x_first + cum + t0_view = first_gc if first_gc is not None else x_first + + fig.update_yaxes(tickvals=yticks, ticktext=ylabels, fixedrange=True, + range=[params["behavior_bottom"] - 0.05, y_main_top + 0.25], row=1, col=1) + fig.update_yaxes(title_text="p_reward", range=[0, 1], fixedrange=True, row=2, col=1) + fig.update_xaxes(range=[t0_view, t0_view + 120], row=1, col=1) + fig.update_xaxes(title_text="Time (s)", range=[t0_view, t0_view + 120], row=2, col=1, + rangeslider=dict(visible=True, thickness=0.06, range=[x_first, x_last])) + if len(sess_spans) > 1: # x tick labels restart at 0 each session + tickvals, ticktext = [], [] + for start, dur in sess_spans: + step = _nice_step(dur) + w = 0.0 + while w <= dur: + tickvals.append(start + w) + ticktext.append(str(int(w))) + w += step + fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, row=2, col=1) + + fig.update_layout( + # Title pinned to the very top-left so it clears the legend below it. + title=dict(text=title or "Session Scroller", x=0.0, xanchor="left", + y=0.98, yanchor="top"), + showlegend=True, height=620, width=1000, template="simple_white", + # Legend outside, top-left, horizontal, compact entries (narrow box). + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, + font=dict(size=9), entrywidthmode="pixels", entrywidth=125), + margin=dict(l=70, r=20, t=120, b=40), + ) return fig diff --git a/tests/test_plot_foraging_session_plotly.py b/tests/test_plot_foraging_session_plotly.py new file mode 100644 index 0000000..47b7db8 --- /dev/null +++ b/tests/test_plot_foraging_session_plotly.py @@ -0,0 +1,129 @@ +"""Test the plotly foraging-session plots. + +To run the test, execute "python -m unittest tests/test_plot_foraging_session_plotly.py". +""" + +import os +import unittest + +import numpy as np +import pandas as pd +import plotly.graph_objects as go + +from aind_dynamic_foraging_basic_analysis import ( + plot_foraging_session_plotly, + plot_session_in_time_plotly, +) +from tests.nwb_io import get_history_from_nwb + + +class TestPlotForagingSessionPlotly(unittest.TestCase): + """Test the trial-based plotly plot against a real session.""" + + @classmethod + def setUpClass(cls): + """Load example session history from the bundled NWB.""" + nwb_file = os.path.dirname(__file__) + "/data/697929_2024-02-22_08-38-30.nwb" + ( + _, + cls.choice_history, + cls.reward_history, + cls.p_reward, + cls.autowater_offered, + _, + ) = get_history_from_nwb(nwb_file) + + def test_returns_figure(self): + """A plotly Figure is returned with both panels populated.""" + fig = plot_foraging_session_plotly( + choice_history=self.choice_history, + reward_history=self.reward_history, + p_reward=self.p_reward, + autowater_offered=self.autowater_offered, + ) + self.assertIsInstance(fig, go.Figure) + self.assertGreater(len(fig.data), 0) + + def test_optional_traces(self): + """Bias band and photostim markers are accepted without error.""" + n = len(self.choice_history) + fig = plot_foraging_session_plotly( + choice_history=self.choice_history, + reward_history=self.reward_history, + p_reward=self.p_reward, + bias=np.zeros(n), + bias_lower=-np.ones(n) * 0.2, + bias_upper=np.ones(n) * 0.2, + photostim={"trial": [10, 20], "power": [3.0, 3.0]}, + plot_list=["choice", "finished", "reward_prob", "bias"], + ) + self.assertIsInstance(fig, go.Figure) + + def test_multi_session(self): + """A per-trial session_id concatenates sessions with a boundary line.""" + n = len(self.choice_history) + session_id = np.array(["a"] * n + ["b"] * n) + fig = plot_foraging_session_plotly( + np.concatenate([self.choice_history, self.choice_history]), + np.concatenate([self.reward_history, self.reward_history]), + np.concatenate([self.p_reward, self.p_reward], axis=1), + session_id=session_id, + ) + # One boundary, drawn as a vertical line (shape) in each of the two rows. + self.assertEqual(len(fig.layout.shapes), 2) + + +class TestPlotSessionInTimePlotly(unittest.TestCase): + """Test the time-based plotly plot with a synthetic events / trials frame.""" + + def setUp(self): + """Build a small tidy events frame and matching trials frame.""" + go_cues = np.arange(5, 25, 2.0) # 10 trials + events = [] + for t in go_cues: + events.append((t, "goCue_start_time")) + events.append((t + 0.3, "left_lick_time")) + events.append((t + 0.4, "right_lick_time")) + events.append((t + 0.5, "left_reward_delivery_time")) + self.df_events = pd.DataFrame(events, columns=["timestamps", "event"]).sort_values( + "timestamps" + ) + self.df_trials = pd.DataFrame( + { + "goCue_start_time": go_cues, + "reward_probabilityL": np.linspace(0.1, 0.8, len(go_cues)), + "reward_probabilityR": np.linspace(0.8, 0.1, len(go_cues)), + } + ) + + def test_events_only(self): + """Works with just an events frame (no probability band).""" + fig = plot_session_in_time_plotly(self.df_events) + self.assertIsInstance(fig, go.Figure) + self.assertGreater(len(fig.data), 0) + + def test_with_trials(self): + """Supplying df_trials adds the reward-probability band traces.""" + fig = plot_session_in_time_plotly( + self.df_events, df_trials=self.df_trials, title="unit_test" + ) + names = [tr.name for tr in fig.data] + self.assertIn("pR", names) + self.assertIn("pL", names) + + def test_multi_session(self): + """A session_id column concatenates sessions end-to-end with a boundary line.""" + e1 = self.df_events.assign(session_id="s1") + e2 = self.df_events.assign(session_id="s2", timestamps=self.df_events["timestamps"] + 100) + t1 = self.df_trials.assign(session_id="s1") + t2 = self.df_trials.assign(session_id="s2", + goCue_start_time=self.df_trials["goCue_start_time"] + 100) + fig = plot_session_in_time_plotly( + pd.concat([e1, e2], ignore_index=True), + df_trials=pd.concat([t1, t2], ignore_index=True), + ) + self.assertEqual(len(fig.layout.shapes), 2) # one boundary, drawn in both panels + + +if __name__ == "__main__": + unittest.main()