diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py index 737039cf..130ee80a 100644 --- a/src/ndv/controllers/_array_viewer.py +++ b/src/ndv/controllers/_array_viewer.py @@ -31,13 +31,35 @@ import numpy.typing as npt from typing_extensions import Unpack - from ndv._types import AxisKey, ChannelKey, KeyPressEvent, MouseMoveEvent + from ndv._types import ( + AxisKey, + ChannelKey, + KeyPressEvent, + MouseMoveEvent, + MousePressEvent, + ) from ndv.models._array_display_model import ArrayDisplayModelKwargs from ndv.models._viewer_model import ArrayViewerModelKwargs from ndv.views.bases import HistogramCanvas from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle +def _apply_white_balance( + data: np.ndarray, gains: tuple[float, float, float] +) -> np.ndarray: + """Apply per-channel RGB gains to an array with a trailing color dimension.""" + result = data.astype(np.float32, copy=True) + n_channels = min(3, result.shape[-1]) + for i in range(n_channels): + result[..., i] *= gains[i] + if data.dtype.kind in ("u", "i"): + info = np.iinfo(data.dtype) + np.clip(result, info.min, info.max, out=result) + return result.astype(data.dtype) + np.clip(result, 0, None, out=result) + return result + + class ArrayViewer: """Viewer dedicated to displaying a single n-dimensional array. @@ -133,10 +155,12 @@ def __init__( self._view.histogramRequested.connect(self._add_histogram) self._view.channelModeChanged.connect(self._on_view_channel_mode_changed) self._view.ndimToggleRequested.connect(self._on_view_ndim_toggle_requested) + self._view.resetWhiteBalance.connect(self._on_reset_white_balance) self._highlight_pos: tuple[float, float] | None = None self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved) self._canvas.mouseLeft.connect(self._on_canvas_mouse_left) + self._canvas.mousePressed.connect(self._on_canvas_mouse_pressed) self._focused_slider_axis: AxisKey | None = None self._disconnect_key_events = _app.filter_key_events( @@ -336,6 +360,7 @@ def _set_model_connected( (model.events.channel_mode, self._re_resolve), (model.scales.value_changed, self._re_resolve), (model.luts.value_changed, self._re_resolve), + (model.events.white_balance_gains, self._re_resolve), ]: getattr(obj, _connect)(callback) @@ -394,6 +419,7 @@ def _apply_changes( or old.channel_axis != new.channel_axis or old.channel_mode != new.channel_mode or old.current_index != new.current_index + or old.white_balance_gains != new.white_balance_gains ) if needs_data: self._request_data() @@ -407,6 +433,12 @@ def _apply_changes( if old.summary_info != new.summary_info: self._view.set_data_info(new.summary_info) + if old.channel_mode != new.channel_mode: + is_rgba = new.channel_mode == ChannelMode.RGBA + self._viewer_model.show_white_balance_button = is_rgba + if not is_rgba: + self._display_model.white_balance_gains = None + def _fallback_channel_name(self, key: ChannelKey) -> str: """Compute the data-derived fallback name for a channel key.""" if self._data_wrapper is not None and isinstance(key, int): @@ -570,12 +602,60 @@ def _on_canvas_mouse_left(self) -> None: self._highlight_pos = None self._highlight_values({}, self._highlight_pos) + def _on_canvas_mouse_pressed(self, event: MousePressEvent) -> None: + """Handle mouse press for eyedropper white balance pick.""" + if self._viewer_model.interaction_mode != InteractionMode.PICK_COLOR: + return + + x, y, _z = self._canvas.canvas_to_world((event.x, event.y)) + pixel = self._sample_original_rgb(x, y) + if pixel is not None: + max_val = pixel.max() + if max_val > 1e-10: + gains = tuple(float(max_val / max(ch, 1e-10)) for ch in pixel) + self._display_model.white_balance_gains = gains # type: ignore[assignment] + + self._viewer_model.interaction_mode = InteractionMode.PAN_ZOOM + + def _sample_original_rgb(self, x: float, y: float) -> np.ndarray | None: + """Sample the original (un-gained) RGB pixel at world coordinates.""" + wrapper = self._data_wrapper + if wrapper is None: + return None + + row, col = self._world_to_data(x, y) + vis = self._resolved.visible_axes + if len(vis) < 2: + return None + + # Build index: current slice position + the clicked pixel coords + idx: dict[int, int | slice] = dict(self._resolved.current_index) + idx[vis[-2]] = row + idx[vis[-1]] = col + # Channel axis gets all channels + ch_ax = self._resolved.channel_axis + if ch_ax is not None: + idx[ch_ax] = slice(None) + + try: + data = wrapper.isel(idx) + except (IndexError, KeyError): + return None + + pixel = np.asarray(data, dtype=np.float64).ravel() + if pixel.size < 3: + return None + return pixel[:3] + def _on_key_pressed(self, event: KeyPressEvent) -> None: handle_key_press(event, self) def _on_view_channel_mode_changed(self, mode: ChannelMode) -> None: self._display_model.channel_mode = mode + def _on_reset_white_balance(self) -> None: + self._display_model.white_balance_gains = None + # ------------------ Helper methods ------------------ def _highlight_values( @@ -670,9 +750,12 @@ def _on_data_response_ready(self, future: Future[DataResponse]) -> None: warnings.warn(f"Error fetching data: {e}", stacklevel=1) return + wb_gains = self._resolved.white_balance_gains for key, data in response.data.items(): if data.size == 0: continue + if wb_gains is not None and data.ndim >= 3 and data.shape[-1] >= 3: + data = _apply_white_balance(data, wb_gains) if (lut_ctrl := self._lut_controllers.get(key)) is None: if key is None: model = self._display_model.default_lut diff --git a/src/ndv/models/_array_display_model.py b/src/ndv/models/_array_display_model.py index 3f713796..cf171e71 100644 --- a/src/ndv/models/_array_display_model.py +++ b/src/ndv/models/_array_display_model.py @@ -207,6 +207,9 @@ class ArrayDisplayModel(NDVModel): # per-axis scale factors (e.g. physical pixel size) scales: ScalesMap = Field(default_factory=ScalesMap, frozen=True) + # per-channel (R, G, B) multiplicative gains for white balance correction + white_balance_gains: tuple[float, float, float] | None = None + @computed_field # type: ignore [prop-decorator] @property def n_visible_axes(self) -> Literal[2, 3]: diff --git a/src/ndv/models/_resolve.py b/src/ndv/models/_resolve.py index 9cf6b4e1..babe6650 100644 --- a/src/ndv/models/_resolve.py +++ b/src/ndv/models/_resolve.py @@ -79,6 +79,7 @@ class ResolvedDisplayState: current_index: dict[int, int | slice] data_coords: dict[int, tuple] hidden_sliders: frozenset[Hashable] + white_balance_gains: tuple[float, float, float] | None summary_info: str visible_scales: tuple[float, ...] @@ -93,6 +94,7 @@ def __eq__(self, other: object) -> bool: and self.data_coords == other.data_coords and self.hidden_sliders == other.hidden_sliders and self.visible_scales == other.visible_scales + and self.white_balance_gains == other.white_balance_gains # summary_info excluded: metadata-only, should not trigger data fetch ) @@ -121,6 +123,7 @@ def __rich_repr__(self) -> Iterable[tuple[str, object]]: current_index={}, data_coords={}, hidden_sliders=frozenset(), + white_balance_gains=None, summary_info="", visible_scales=(), ) @@ -164,11 +167,6 @@ def _norm_channel_axis(model: ArrayDisplayModel, wrapper: DataWrapper) -> int | except Exception: return None - # don't use a visible axis as the channel axis - normed_vis = _norm_visible_axes(model, wrapper) - if normed_guess in normed_vis: - return None - return normed_guess @@ -289,6 +287,22 @@ def resolve(model: ArrayDisplayModel, wrapper: DataWrapper) -> ResolvedDisplaySt """ visible_axes = _norm_visible_axes(model, wrapper) channel_axis = _norm_channel_axis(model, wrapper) + + # If the guessed channel axis collides with a visible axis (e.g. a 3D + # array where the last dim is size 3 and visible_axes defaults to (-2, -1)): + # - In RGBA mode, shift visible axes to make room so the channel axis is + # usable (otherwise explicit channel_mode="rgba" breaks on (M,N,3) data). + # - In other modes, discard the guess: treat the ambiguous axis as spatial + # so the user gets a normal Z slider instead of unwanted channel splitting. + if channel_axis is not None and channel_axis in visible_axes: + if model.channel_mode == ChannelMode.RGBA: + ndim = len(wrapper.dims) + n_vis = len(visible_axes) + candidates = [i for i in range(ndim) if i != channel_axis] + visible_axes = tuple(candidates[-n_vis:]) + else: + channel_axis = None + current_index = _norm_current_index(model, wrapper) data_coords = _norm_data_coords(wrapper) @@ -297,6 +311,11 @@ def resolve(model: ArrayDisplayModel, wrapper: DataWrapper) -> ResolvedDisplaySt ) visible_scales = _resolve_visible_scales(model, wrapper, visible_axes) + # white balance only applies in RGBA mode + wb_gains = ( + model.white_balance_gains if model.channel_mode == ChannelMode.RGBA else None + ) + return ResolvedDisplayState( visible_axes=visible_axes, channel_axis=channel_axis, @@ -304,6 +323,7 @@ def resolve(model: ArrayDisplayModel, wrapper: DataWrapper) -> ResolvedDisplaySt current_index=current_index, data_coords=data_coords, hidden_sliders=hidden_sliders, + white_balance_gains=wb_gains, summary_info=wrapper.summary_info(), visible_scales=visible_scales, ) diff --git a/src/ndv/models/_viewer_model.py b/src/ndv/models/_viewer_model.py index 04134cd0..94b4b4bc 100644 --- a/src/ndv/models/_viewer_model.py +++ b/src/ndv/models/_viewer_model.py @@ -34,6 +34,7 @@ class InteractionMode(str, Enum): PAN_ZOOM = "pan_zoom" CREATE_ROI = "create_roi" + PICK_COLOR = "pick_color" def __str__(self) -> str: """Return the string representation of the enum value.""" @@ -86,6 +87,7 @@ class ArrayViewerModel(NDVModel): show_histogram_button: bool = True show_reset_zoom_button: bool = True show_roi_button: bool = False + show_white_balance_button: bool = False show_channel_mode_selector: bool = True show_play_button: bool = True show_data_info: bool = True diff --git a/src/ndv/views/_jupyter/_array_view.py b/src/ndv/views/_jupyter/_array_view.py index 9d3020fb..4bfd9dde 100644 --- a/src/ndv/views/_jupyter/_array_view.py +++ b/src/ndv/views/_jupyter/_array_view.py @@ -448,6 +448,17 @@ def __init__( if not viewer_model.show_roi_button: self._add_roi_btn.layout.display = "none" + # White balance eyedropper button + self._wb_btn = widgets.ToggleButton( + value=False, + description="WB", + tooltip="Pick white balance point", + icon="eyedropper", + ) + self._wb_btn.observe(self._on_wb_button_toggle, names="value") + if not viewer_model.show_white_balance_button: + self._wb_btn.layout.display = "none" + # LAYOUT top_row = widgets.HBox( @@ -470,6 +481,7 @@ def __init__( self._channel_mode_combo, self._ndims_btn, self._add_roi_btn, + self._wb_btn, self._reset_zoom_btn, ], layout=widgets.Layout(justify_content="flex-end"), @@ -595,11 +607,15 @@ def add_histogram(self, channel: ChannelKey, histogram: HistogramCanvas) -> None lut.add_histogram(histogram) def _on_add_roi_button_toggle(self, change: dict[str, Any]) -> None: - """Emit signal when the channel mode changes.""" self._viewer_model.interaction_mode = ( InteractionMode.CREATE_ROI if change["new"] else InteractionMode.PAN_ZOOM ) + def _on_wb_button_toggle(self, change: dict[str, Any]) -> None: + self._viewer_model.interaction_mode = ( + InteractionMode.PICK_COLOR if change["new"] else InteractionMode.PAN_ZOOM + ) + def remove_histogram(self, widget: Any) -> None: """Remove a histogram widget from the viewer.""" @@ -638,16 +654,19 @@ def _on_viewer_model_event(self, info: EmissionInfo) -> None: if sig_name == "show_progress_spinner": self._progress_spinner.layout.display = "flex" if value else "none" elif sig_name == "interaction_mode": - # If leaving CanvasMode.CREATE_ROI, uncheck the ROI button _new, old = info.args if old == InteractionMode.CREATE_ROI: self._add_roi_btn.value = False + if old == InteractionMode.PICK_COLOR: + self._wb_btn.value = False elif sig_name == "show_histogram_button": # Note that "block" displays the icon better than "flex" for lut in self._luts.values(): lut._histogram_btn.layout.display = "block" if value else "none" elif sig_name == "show_roi_button": self._add_roi_btn.layout.display = "flex" if value else "none" + elif sig_name == "show_white_balance_button": + self._wb_btn.layout.display = "flex" if value else "none" elif sig_name == "show_channel_mode_selector": self._channel_mode_combo.layout.display = "flex" if value else "none" elif sig_name == "show_reset_zoom_button": diff --git a/src/ndv/views/_pygfx/_array_canvas.py b/src/ndv/views/_pygfx/_array_canvas.py index 7d727293..1c1bd5d1 100755 --- a/src/ndv/views/_pygfx/_array_canvas.py +++ b/src/ndv/views/_pygfx/_array_canvas.py @@ -689,6 +689,11 @@ def on_mouse_press(self, event: MousePressEvent) -> bool: canvas_pos = (event.x, event.y) world_pos = self.canvas_to_world(canvas_pos)[:2] + # In PICK_COLOR mode, consume the press to prevent camera pan. + # The mousePressed signal still fires so the controller can handle it. + if self._viewer.interaction_mode == InteractionMode.PICK_COLOR: + return True + # If in CREATE_ROI mode, the new ROI should "start" here. if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: if self._last_roi_created is None: @@ -733,7 +738,10 @@ def on_mouse_release(self, event: MouseReleaseEvent) -> bool: return False def get_cursor(self, event: MouseMoveEvent) -> CursorType: - if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + if self._viewer.interaction_mode in ( + InteractionMode.CREATE_ROI, + InteractionMode.PICK_COLOR, + ): return CursorType.CROSS for vis in self.elements_at((event.x, event.y)): if cursor := vis.get_cursor(event): diff --git a/src/ndv/views/_qt/_array_view.py b/src/ndv/views/_qt/_array_view.py index 877ce52f..1732a617 100644 --- a/src/ndv/views/_qt/_array_view.py +++ b/src/ndv/views/_qt/_array_view.py @@ -438,6 +438,15 @@ def __init__(self, parent: QWidget | None = None): self.setIcon(QIconifyIcon("mdi:vector-rectangle")) +class WhiteBalanceButton(QPushButton): + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self.setCheckable(True) + self.setToolTip("Pick white balance point (right-click to reset)") + self.setIcon(QIconifyIcon("mdi:eyedropper")) + self.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + + class DimRow(QObject): def __init__( self, axis: AxisKey, _coords: Sequence, parent: QObject | None @@ -713,6 +722,7 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): self._roi_handle: RectangularROIHandle | None = None self._selection: CanvasElement | None = None self.add_roi_btn = ROIButton() + self.wb_btn = WhiteBalanceButton() self.luts = _UpCollapsible( "LUTs", @@ -730,6 +740,7 @@ def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): self._btn_layout.addWidget(self.channel_mode_combo) self._btn_layout.addWidget(self.ndims_btn) self._btn_layout.addWidget(self.add_roi_btn) + self._btn_layout.addWidget(self.wb_btn) self._btn_layout.addWidget(self.set_range_btn) self._btns = QWidget() @@ -797,9 +808,12 @@ def __init__( # Mapping of channel key to LUTViews self._luts: dict[ChannelKey, QLUTView] = {} qwdg.add_roi_btn.toggled.connect(self._on_add_roi_clicked) + qwdg.wb_btn.toggled.connect(self._on_wb_clicked) + qwdg.wb_btn.customContextMenuRequested.connect(self._on_wb_context_menu) self._viewer_model.events.connect(self._on_viewer_model_event) qwdg.add_roi_btn.setVisible(viewer_model.show_roi_button) + qwdg.wb_btn.setVisible(viewer_model.show_white_balance_button) # TODO: use emit_fast qwdg.dims_sliders.currentIndexChanged.connect(self.currentIndexChanged.emit) @@ -897,21 +911,38 @@ def _on_add_roi_clicked(self, checked: bool) -> None: InteractionMode.CREATE_ROI if checked else InteractionMode.PAN_ZOOM ) + def _on_wb_clicked(self, checked: bool) -> None: + self._viewer_model.interaction_mode = ( + InteractionMode.PICK_COLOR if checked else InteractionMode.PAN_ZOOM + ) + + def _on_wb_context_menu(self) -> None: + from qtpy.QtWidgets import QMenu + + menu = QMenu(self._qwidget.wb_btn) + menu.addAction("Reset White Balance", self.resetWhiteBalance.emit) + menu.exec( + self._qwidget.wb_btn.mapToGlobal(self._qwidget.wb_btn.rect().center()) + ) + def _on_viewer_model_event(self, info: EmissionInfo) -> None: sig_name = info.signal.name value = info.args[0] if sig_name == "show_progress_spinner": self._qwidget._progress_spinner.setVisible(value) if sig_name == "interaction_mode": - # If leaving CanvasMode.CREATE_ROI, uncheck the ROI button _new, old = info.args if old == InteractionMode.CREATE_ROI: self._qwidget.add_roi_btn.setChecked(False) + if old == InteractionMode.PICK_COLOR: + self._qwidget.wb_btn.setChecked(False) elif sig_name == "show_histogram_button": for lut in self._luts.values(): lut._qwidget.histogram_btn.setVisible(value) elif sig_name == "show_roi_button": self._qwidget.add_roi_btn.setVisible(value) + elif sig_name == "show_white_balance_button": + self._qwidget.wb_btn.setVisible(value) elif sig_name == "show_channel_mode_selector": self._qwidget.channel_mode_combo.setVisible(value) elif sig_name == "show_reset_zoom_button": diff --git a/src/ndv/views/_vispy/_array_canvas.py b/src/ndv/views/_vispy/_array_canvas.py index c47c1da6..a6e56f15 100755 --- a/src/ndv/views/_vispy/_array_canvas.py +++ b/src/ndv/views/_vispy/_array_canvas.py @@ -534,6 +534,11 @@ def on_mouse_press(self, event: MousePressEvent) -> bool: canvas_pos = (event.x, event.y) world_pos = self.canvas_to_world(canvas_pos)[:2] + # In PICK_COLOR mode, consume the press to prevent camera pan. + # The mousePressed signal still fires so the controller can handle it. + if self._viewer.interaction_mode == InteractionMode.PICK_COLOR: + return True + # If in CREATE_ROI mode, the new ROI should "start" here. if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: if self._last_roi_created is None: @@ -579,7 +584,10 @@ def on_mouse_release(self, event: MouseReleaseEvent) -> bool: return False def get_cursor(self, event: MouseMoveEvent) -> CursorType: - if self._viewer.interaction_mode == InteractionMode.CREATE_ROI: + if self._viewer.interaction_mode in ( + InteractionMode.CREATE_ROI, + InteractionMode.PICK_COLOR, + ): return CursorType.CROSS for vis in self.elements_at((event.x, event.y)): if cursor := vis.get_cursor(event): diff --git a/src/ndv/views/_wx/_array_view.py b/src/ndv/views/_wx/_array_view.py index c6cba03a..e98f390b 100644 --- a/src/ndv/views/_wx/_array_view.py +++ b/src/ndv/views/_wx/_array_view.py @@ -623,6 +623,10 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window | None = None): self.add_roi_btn = wx.ToggleButton(self, label="ROI", size=(40, -1)) _add_icon(self.add_roi_btn, "mdi:vector-rectangle") + # White balance eyedropper button + self.wb_btn = wx.ToggleButton(self, label="WB", size=(40, -1)) + _add_icon(self.wb_btn, "mdi:eyedropper") + # how many luts need to be present before lut toolbar appears self._toolbar_display_thresh = 7 @@ -654,6 +658,7 @@ def __init__(self, canvas_widget: wx.Window, parent: wx.Window | None = None): self._btns.Add(self.set_range_btn, 0, wx.ALL, 4) self._btns.Add(self.ndims_btn, 0, wx.ALL, 4) self._btns.Add(self.add_roi_btn, 0, wx.ALL, 4) + self._btns.Add(self.wb_btn, 0, wx.ALL, 4) self._top_info = top_info = wx.BoxSizer(wx.HORIZONTAL) top_info.Add(self._data_info_label, 0, wx.EXPAND | wx.BOTTOM, 0) @@ -720,6 +725,9 @@ def __init__( wdg.ndims_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_ndims_toggled) wdg.add_roi_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_add_roi_toggled) wdg.add_roi_btn.Show(viewer_model.show_roi_button) + wdg.wb_btn.Bind(wx.EVT_TOGGLEBUTTON, self._on_wb_toggled) + wdg.wb_btn.Bind(wx.EVT_RIGHT_DOWN, self._on_wb_right_click) + wdg.wb_btn.Show(viewer_model.show_white_balance_button) def _on_channel_mode_changed(self, event: wx.CommandEvent) -> None: mode = self._wxwidget.channel_mode_combo.GetValue() @@ -738,6 +746,19 @@ def _on_add_roi_toggled(self, event: wx.CommandEvent) -> None: InteractionMode.CREATE_ROI if create_roi else InteractionMode.PAN_ZOOM ) + def _on_wb_toggled(self, event: wx.CommandEvent) -> None: + pick = self._wxwidget.wb_btn.GetValue() + self._viewer_model.interaction_mode = ( + InteractionMode.PICK_COLOR if pick else InteractionMode.PAN_ZOOM + ) + + def _on_wb_right_click(self, event: wx.MouseEvent) -> None: + menu = wx.Menu() + item = menu.Append(wx.ID_ANY, "Reset White Balance") + self._wxwidget.Bind(wx.EVT_MENU, lambda _: self.resetWhiteBalance.emit(), item) + self._wxwidget.PopupMenu(menu) + menu.Destroy() + def visible_axes(self) -> Sequence[AxisKey]: return self._visible_axes # no widget to control this yet @@ -853,10 +874,11 @@ def _on_viewer_model_event(self, info: EmissionInfo) -> None: self._wxwidget._progress_spinner.Show(value) self._wxwidget._top_info.Layout() elif sig_name == "interaction_mode": - # If leaving CanvasMode.CREATE_ROI, uncheck the ROI button _new, old = info.args if old == InteractionMode.CREATE_ROI: self._wxwidget.add_roi_btn.SetValue(False) + if old == InteractionMode.PICK_COLOR: + self._wxwidget.wb_btn.SetValue(False) elif sig_name == "show_histogram_button": for lut in self._luts.values(): lut._wxwidget.histogram_btn.Show(value) @@ -864,6 +886,9 @@ def _on_viewer_model_event(self, info: EmissionInfo) -> None: elif sig_name == "show_roi_button": self._wxwidget.add_roi_btn.Show(value) self._wxwidget._btns.Layout() + elif sig_name == "show_white_balance_button": + self._wxwidget.wb_btn.Show(value) + self._wxwidget._btns.Layout() elif sig_name == "show_channel_mode_selector": self._wxwidget.channel_mode_combo.Show(value) self._wxwidget._btns.Layout() diff --git a/src/ndv/views/bases/_array_view.py b/src/ndv/views/bases/_array_view.py index 7800dad3..26d99561 100644 --- a/src/ndv/views/bases/_array_view.py +++ b/src/ndv/views/bases/_array_view.py @@ -33,6 +33,7 @@ class ArrayView(Viewable): ndimToggleRequested = Signal(bool) channelModeChanged = Signal(ChannelMode) keyPressed = Signal(KeyPressEvent) + resetWhiteBalance = Signal() @abstractmethod def __init__( diff --git a/tests/test_resolve.py b/tests/test_resolve.py index 6a55adf3..8c9bb679 100644 --- a/tests/test_resolve.py +++ b/tests/test_resolve.py @@ -66,6 +66,46 @@ def test_negative_index_produces_nonempty_data() -> None: ) +@pytest.mark.parametrize( + "shape, model_kwargs, expect_channel_axis, expect_visible_axes", + [ + pytest.param( + (100, 200, 3), + {"channel_mode": "rgba"}, + 2, + (0, 1), + id="MN3_rgba_gets_channel_axis", + ), + pytest.param( + (100, 200, 3), + {}, + None, + (1, 2), + id="MN3_default_mode_no_channel_split", + ), + pytest.param( + (100, 200, 3), + {"visible_axes": (0, 1)}, + None, + (0, 1), + id="MN3_explicit_visible_axes_respected", + ), + ], +) +def test_channel_axis_visible_axis_collision( + shape: tuple[int, ...], + model_kwargs: dict, + expect_channel_axis: int | None, + expect_visible_axes: tuple[int, ...], +) -> None: + """Channel/visible axis collision: RGBA shifts visible, others discard guess.""" + wrapper = DataWrapper.create(np.ones(shape, dtype=np.uint8)) + model = ArrayDisplayModel(**model_kwargs) + resolved = resolve(model, wrapper) + assert resolved.channel_axis == expect_channel_axis + assert resolved.visible_axes == expect_visible_axes + + def test_resolved_index_clamps_negative_to_valid_range() -> None: """Negative index values should be clamped to [0, max_val].""" data = np.ones((5, 8, 10), dtype=np.uint8)