Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/ndv/controllers/_array_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,35 @@
import numpy.typing as npt
from typing_extensions import Unpack

from ndv._types import AxisKey, ChannelKey, KeyPressEvent, MouseMoveEvent
from ndv._types import (
AxisKey,
ChannelKey,
KeyPressEvent,
MouseMoveEvent,
MousePressEvent,
)
from ndv.models._array_display_model import ArrayDisplayModelKwargs
from ndv.models._viewer_model import ArrayViewerModelKwargs
from ndv.views.bases import HistogramCanvas
from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle


def _apply_white_balance(
data: np.ndarray, gains: tuple[float, float, float]
) -> np.ndarray:
"""Apply per-channel RGB gains to an array with a trailing color dimension."""
result = data.astype(np.float32, copy=True)
n_channels = min(3, result.shape[-1])
for i in range(n_channels):
result[..., i] *= gains[i]
if data.dtype.kind in ("u", "i"):
info = np.iinfo(data.dtype)
np.clip(result, info.min, info.max, out=result)
return result.astype(data.dtype)
np.clip(result, 0, None, out=result)
return result


class ArrayViewer:
"""Viewer dedicated to displaying a single n-dimensional array.

Expand Down Expand Up @@ -133,10 +155,12 @@ def __init__(
self._view.histogramRequested.connect(self._add_histogram)
self._view.channelModeChanged.connect(self._on_view_channel_mode_changed)
self._view.ndimToggleRequested.connect(self._on_view_ndim_toggle_requested)
self._view.resetWhiteBalance.connect(self._on_reset_white_balance)

self._highlight_pos: tuple[float, float] | None = None
self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved)
self._canvas.mouseLeft.connect(self._on_canvas_mouse_left)
self._canvas.mousePressed.connect(self._on_canvas_mouse_pressed)

self._focused_slider_axis: AxisKey | None = None
self._disconnect_key_events = _app.filter_key_events(
Expand Down Expand Up @@ -336,6 +360,7 @@ def _set_model_connected(
(model.events.channel_mode, self._re_resolve),
(model.scales.value_changed, self._re_resolve),
(model.luts.value_changed, self._re_resolve),
(model.events.white_balance_gains, self._re_resolve),
]:
getattr(obj, _connect)(callback)

Expand Down Expand Up @@ -394,6 +419,7 @@ def _apply_changes(
or old.channel_axis != new.channel_axis
or old.channel_mode != new.channel_mode
or old.current_index != new.current_index
or old.white_balance_gains != new.white_balance_gains
)
if needs_data:
self._request_data()
Expand All @@ -407,6 +433,12 @@ def _apply_changes(
if old.summary_info != new.summary_info:
self._view.set_data_info(new.summary_info)

if old.channel_mode != new.channel_mode:
is_rgba = new.channel_mode == ChannelMode.RGBA
self._viewer_model.show_white_balance_button = is_rgba
if not is_rgba:
self._display_model.white_balance_gains = None

def _fallback_channel_name(self, key: ChannelKey) -> str:
"""Compute the data-derived fallback name for a channel key."""
if self._data_wrapper is not None and isinstance(key, int):
Expand Down Expand Up @@ -570,12 +602,60 @@ def _on_canvas_mouse_left(self) -> None:
self._highlight_pos = None
self._highlight_values({}, self._highlight_pos)

def _on_canvas_mouse_pressed(self, event: MousePressEvent) -> None:
"""Handle mouse press for eyedropper white balance pick."""
if self._viewer_model.interaction_mode != InteractionMode.PICK_COLOR:
return

x, y, _z = self._canvas.canvas_to_world((event.x, event.y))
pixel = self._sample_original_rgb(x, y)
if pixel is not None:
max_val = pixel.max()
if max_val > 1e-10:
gains = tuple(float(max_val / max(ch, 1e-10)) for ch in pixel)
self._display_model.white_balance_gains = gains # type: ignore[assignment]

self._viewer_model.interaction_mode = InteractionMode.PAN_ZOOM

def _sample_original_rgb(self, x: float, y: float) -> np.ndarray | None:
"""Sample the original (un-gained) RGB pixel at world coordinates."""
wrapper = self._data_wrapper
if wrapper is None:
return None

row, col = self._world_to_data(x, y)
vis = self._resolved.visible_axes
if len(vis) < 2:
return None

# Build index: current slice position + the clicked pixel coords
idx: dict[int, int | slice] = dict(self._resolved.current_index)
idx[vis[-2]] = row
idx[vis[-1]] = col
# Channel axis gets all channels
ch_ax = self._resolved.channel_axis
if ch_ax is not None:
idx[ch_ax] = slice(None)

try:
data = wrapper.isel(idx)
except (IndexError, KeyError):
return None

pixel = np.asarray(data, dtype=np.float64).ravel()
if pixel.size < 3:
return None
return pixel[:3]

def _on_key_pressed(self, event: KeyPressEvent) -> None:
handle_key_press(event, self)

def _on_view_channel_mode_changed(self, mode: ChannelMode) -> None:
self._display_model.channel_mode = mode

def _on_reset_white_balance(self) -> None:
self._display_model.white_balance_gains = None

# ------------------ Helper methods ------------------

def _highlight_values(
Expand Down Expand Up @@ -670,9 +750,12 @@ def _on_data_response_ready(self, future: Future[DataResponse]) -> None:
warnings.warn(f"Error fetching data: {e}", stacklevel=1)
return

wb_gains = self._resolved.white_balance_gains
for key, data in response.data.items():
if data.size == 0:
continue
if wb_gains is not None and data.ndim >= 3 and data.shape[-1] >= 3:
data = _apply_white_balance(data, wb_gains)
if (lut_ctrl := self._lut_controllers.get(key)) is None:
if key is None:
model = self._display_model.default_lut
Expand Down
3 changes: 3 additions & 0 deletions src/ndv/models/_array_display_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ class ArrayDisplayModel(NDVModel):
# per-axis scale factors (e.g. physical pixel size)
scales: ScalesMap = Field(default_factory=ScalesMap, frozen=True)

# per-channel (R, G, B) multiplicative gains for white balance correction
white_balance_gains: tuple[float, float, float] | None = None

@computed_field # type: ignore [prop-decorator]
@property
def n_visible_axes(self) -> Literal[2, 3]:
Expand Down
30 changes: 25 additions & 5 deletions src/ndv/models/_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ResolvedDisplayState:
current_index: dict[int, int | slice]
data_coords: dict[int, tuple]
hidden_sliders: frozenset[Hashable]
white_balance_gains: tuple[float, float, float] | None
summary_info: str
visible_scales: tuple[float, ...]

Expand All @@ -93,6 +94,7 @@ def __eq__(self, other: object) -> bool:
and self.data_coords == other.data_coords
and self.hidden_sliders == other.hidden_sliders
and self.visible_scales == other.visible_scales
and self.white_balance_gains == other.white_balance_gains
# summary_info excluded: metadata-only, should not trigger data fetch
)

Expand Down Expand Up @@ -121,6 +123,7 @@ def __rich_repr__(self) -> Iterable[tuple[str, object]]:
current_index={},
data_coords={},
hidden_sliders=frozenset(),
white_balance_gains=None,
summary_info="",
visible_scales=(),
)
Expand Down Expand Up @@ -164,11 +167,6 @@ def _norm_channel_axis(model: ArrayDisplayModel, wrapper: DataWrapper) -> int |
except Exception:
return None

# don't use a visible axis as the channel axis
normed_vis = _norm_visible_axes(model, wrapper)
if normed_guess in normed_vis:
return None

return normed_guess


Expand Down Expand Up @@ -289,6 +287,22 @@ def resolve(model: ArrayDisplayModel, wrapper: DataWrapper) -> ResolvedDisplaySt
"""
visible_axes = _norm_visible_axes(model, wrapper)
channel_axis = _norm_channel_axis(model, wrapper)

# If the guessed channel axis collides with a visible axis (e.g. a 3D
# array where the last dim is size 3 and visible_axes defaults to (-2, -1)):
# - In RGBA mode, shift visible axes to make room so the channel axis is
# usable (otherwise explicit channel_mode="rgba" breaks on (M,N,3) data).
# - In other modes, discard the guess: treat the ambiguous axis as spatial
# so the user gets a normal Z slider instead of unwanted channel splitting.
if channel_axis is not None and channel_axis in visible_axes:
if model.channel_mode == ChannelMode.RGBA:
ndim = len(wrapper.dims)
n_vis = len(visible_axes)
candidates = [i for i in range(ndim) if i != channel_axis]
visible_axes = tuple(candidates[-n_vis:])
else:
channel_axis = None

current_index = _norm_current_index(model, wrapper)
data_coords = _norm_data_coords(wrapper)

Expand All @@ -297,13 +311,19 @@ def resolve(model: ArrayDisplayModel, wrapper: DataWrapper) -> ResolvedDisplaySt
)
visible_scales = _resolve_visible_scales(model, wrapper, visible_axes)

# white balance only applies in RGBA mode
wb_gains = (
model.white_balance_gains if model.channel_mode == ChannelMode.RGBA else None
)

return ResolvedDisplayState(
visible_axes=visible_axes,
channel_axis=channel_axis,
channel_mode=model.channel_mode,
current_index=current_index,
data_coords=data_coords,
hidden_sliders=hidden_sliders,
white_balance_gains=wb_gains,
summary_info=wrapper.summary_info(),
visible_scales=visible_scales,
)
Expand Down
2 changes: 2 additions & 0 deletions src/ndv/models/_viewer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class InteractionMode(str, Enum):

PAN_ZOOM = "pan_zoom"
CREATE_ROI = "create_roi"
PICK_COLOR = "pick_color"

def __str__(self) -> str:
"""Return the string representation of the enum value."""
Expand Down Expand Up @@ -86,6 +87,7 @@ class ArrayViewerModel(NDVModel):
show_histogram_button: bool = True
show_reset_zoom_button: bool = True
show_roi_button: bool = False
show_white_balance_button: bool = False
show_channel_mode_selector: bool = True
show_play_button: bool = True
show_data_info: bool = True
Expand Down
23 changes: 21 additions & 2 deletions src/ndv/views/_jupyter/_array_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,17 @@ def __init__(
if not viewer_model.show_roi_button:
self._add_roi_btn.layout.display = "none"

# White balance eyedropper button
self._wb_btn = widgets.ToggleButton(
value=False,
description="WB",
tooltip="Pick white balance point",
icon="eyedropper",
)
self._wb_btn.observe(self._on_wb_button_toggle, names="value")
if not viewer_model.show_white_balance_button:
self._wb_btn.layout.display = "none"

# LAYOUT

top_row = widgets.HBox(
Expand All @@ -470,6 +481,7 @@ def __init__(
self._channel_mode_combo,
self._ndims_btn,
self._add_roi_btn,
self._wb_btn,
self._reset_zoom_btn,
],
layout=widgets.Layout(justify_content="flex-end"),
Expand Down Expand Up @@ -595,11 +607,15 @@ def add_histogram(self, channel: ChannelKey, histogram: HistogramCanvas) -> None
lut.add_histogram(histogram)

def _on_add_roi_button_toggle(self, change: dict[str, Any]) -> None:
"""Emit signal when the channel mode changes."""
self._viewer_model.interaction_mode = (
InteractionMode.CREATE_ROI if change["new"] else InteractionMode.PAN_ZOOM
)

def _on_wb_button_toggle(self, change: dict[str, Any]) -> None:
self._viewer_model.interaction_mode = (
InteractionMode.PICK_COLOR if change["new"] else InteractionMode.PAN_ZOOM
)

def remove_histogram(self, widget: Any) -> None:
"""Remove a histogram widget from the viewer."""

Expand Down Expand Up @@ -638,16 +654,19 @@ def _on_viewer_model_event(self, info: EmissionInfo) -> None:
if sig_name == "show_progress_spinner":
self._progress_spinner.layout.display = "flex" if value else "none"
elif sig_name == "interaction_mode":
# If leaving CanvasMode.CREATE_ROI, uncheck the ROI button
_new, old = info.args
if old == InteractionMode.CREATE_ROI:
self._add_roi_btn.value = False
if old == InteractionMode.PICK_COLOR:
self._wb_btn.value = False
elif sig_name == "show_histogram_button":
# Note that "block" displays the icon better than "flex"
for lut in self._luts.values():
lut._histogram_btn.layout.display = "block" if value else "none"
elif sig_name == "show_roi_button":
self._add_roi_btn.layout.display = "flex" if value else "none"
elif sig_name == "show_white_balance_button":
self._wb_btn.layout.display = "flex" if value else "none"
elif sig_name == "show_channel_mode_selector":
self._channel_mode_combo.layout.display = "flex" if value else "none"
elif sig_name == "show_reset_zoom_button":
Expand Down
10 changes: 9 additions & 1 deletion src/ndv/views/_pygfx/_array_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,11 @@ def on_mouse_press(self, event: MousePressEvent) -> bool:
canvas_pos = (event.x, event.y)
world_pos = self.canvas_to_world(canvas_pos)[:2]

# In PICK_COLOR mode, consume the press to prevent camera pan.
# The mousePressed signal still fires so the controller can handle it.
if self._viewer.interaction_mode == InteractionMode.PICK_COLOR:
return True

# If in CREATE_ROI mode, the new ROI should "start" here.
if self._viewer.interaction_mode == InteractionMode.CREATE_ROI:
if self._last_roi_created is None:
Expand Down Expand Up @@ -733,7 +738,10 @@ def on_mouse_release(self, event: MouseReleaseEvent) -> bool:
return False

def get_cursor(self, event: MouseMoveEvent) -> CursorType:
if self._viewer.interaction_mode == InteractionMode.CREATE_ROI:
if self._viewer.interaction_mode in (
InteractionMode.CREATE_ROI,
InteractionMode.PICK_COLOR,
):
return CursorType.CROSS
for vis in self.elements_at((event.x, event.y)):
if cursor := vis.get_cursor(event):
Expand Down
Loading
Loading