diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 549f43ca..e32056d7 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -1,3 +1,11 @@ +"""Widgets for displaying and interacting with SpatialData objects in napari. + +This module provides a set of Qt widgets for visualizing and interacting with +SpatialData objects within the napari viewer. It includes a ListWidget for selecting +coordinate systems, browsing elements within SpatialData objects, and handling +channel selection for multidimensional image data. +""" + from __future__ import annotations import platform @@ -5,7 +13,7 @@ from importlib.metadata import version from operator import itemgetter from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import shapely @@ -30,7 +38,11 @@ from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants import config from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD -from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping +from napari_spatialdata.utils._utils import ( + _get_sdata_key, + get_duplicate_element_names, + get_elements_meta_mapping, +) if TYPE_CHECKING: from napari import Viewer @@ -55,21 +67,76 @@ PROBLEMATIC_NUMPY_MACOS = False -class ElementWidget(QListWidget): - def __init__(self, sdata: EventedList): +class ListWidget(QListWidget): + """Widget for displaying and selecting coordinate systems or elements from SpatialData objects. + + This widget can show a list of coordinate systems or available elements (images, labels, points, shapes) + from the SpatialData objects, with warnings for elements that might be slow to render. + + Attributes + ---------- + _icon : QIcon + Icon used for warning indicators for elements that might be slow to render. + _sdata : EventedList + List of SpatialData objects. + _duplicate_element_names : dict + Dictionary of duplicate element names across SpatialData objects. + _elements : dict or None + Dictionary with metadata of the currently visible elements. + _system : str or None + Currently selected coordinate system. + """ + + def __init__(self, sdata: EventedList, coordinate_system: bool = False): + """Initialize the Widget. + + Parameters + ---------- + sdata : EventedList + List of SpatialData objects to display elements from. + coordinate_system : bool + If True, populate the widget with coordinate systems instead of elements. + """ super().__init__() self._icon = QIcon(str(icon_path)) self._sdata = sdata self._duplicate_element_names, _ = get_duplicate_element_names(self._sdata) - self._elements: None | dict[str, dict[str, str | int]] = None + self._elements: dict[str, dict[str, str | int]] | None = None + self._system: None | str = None - def _onItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + if coordinate_system: + # Sort alphabetically, but keep default "global" at the top. + coordinate_systems = sorted({cs for sdata in self._sdata for cs in sdata.coordinate_systems}) + if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: + coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) + coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) + self.addItems(coordinate_systems) + + def _onCsItemChange(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Update the element list of an element widget when the coordinate system selection changes. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The newly selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self.clear() elements, _ = get_elements_meta_mapping(self._sdata, selected_coordinate_system, self._duplicate_element_names) self._set_element_widget_items(elements) self._elements = elements def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) -> None: + """Populate an element widget with element items. + + Adds each element as an item in the list widget, with warning icons for elements + that might be slow to render (e.g., many circles or shapes). + + Parameters + ---------- + elements : dict[str, dict[str, str | int]] + Dictionary mapping element names to their metadata. + """ for key, dict_val in sorted(elements.items(), key=itemgetter(0)): sdata = self._sdata[dict_val["sdata_index"]] element_type = dict_val["element_type"] @@ -98,29 +165,59 @@ def _set_element_widget_items(self, elements: dict[str, dict[str, str | int]]) - ) self.addItem(item) - -class CoordinateSystemWidget(QListWidget): - def __init__(self, sdata: EventedList): - super().__init__() - - self._sdata = sdata - self._system: None | str = None - - # Sort alphabetically, but keep default "global" at the top. - coordinate_systems = sorted(cs for sdata in self._sdata for cs in sdata.coordinate_systems) - if DEFAULT_COORDINATE_SYSTEM in coordinate_systems: - coordinate_systems.remove(DEFAULT_COORDINATE_SYSTEM) - coordinate_systems.insert(0, DEFAULT_COORDINATE_SYSTEM) - self.addItems(coordinate_systems) - def _select_coord_sys(self, selected_coordinate_system: QListWidgetItem | int | Iterable[str]) -> None: + """Store the currently selected coordinate system. + + Parameters + ---------- + selected_coordinate_system : QListWidgetItem or int or Iterable[str] + The selected coordinate system. + Can be a QListWidgetItem, an index, or an iterable of strings. + """ self._system = str(selected_coordinate_system) class DataLoadThread(QThread): + """Thread for asynchronously loading SpatialData elements. + + This thread handles loading different types of data (images, labels, points, shapes) + from SpatialData objects without blocking the UI. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + + Attributes + ---------- + returned : Signal + Signal emitted when data loading is complete, carrying the created layer. + sdata_widget : SdataWidget + Parent SdataWidget that owns this thread. + _data_type : str + Type of data to load (images, labels, points, shapes). + _text : str + Name of the element to load. + _sdata : SpatialData + SpatialData object containing the element. + _selected_cs : str + Selected coordinate system. + _multi : bool + Boolean indicating if multiple SpatialData objects are present. + _channel_name : str, optional + Optional channel name for image data. + """ + returned = Signal(object) def __init__(self, parent: SdataWidget): + """Initialize the DataLoadThread. + + Parameters + ---------- + parent : SdataWidget + Parent SdataWidget that owns this thread. + """ super().__init__(parent=parent) self.sdata_widget = parent self._data_type = "" @@ -129,21 +226,58 @@ def __init__(self, parent: SdataWidget): self._selected_cs: str = "" self._multi: bool = False - def load_data(self, data_type: str, text: str, sdata: SpatialData, selected_cs: str, multi: bool) -> None: + def load_data( + self, + data_type: str, + text: str, + sdata: SpatialData, + selected_cs: str, + multi: bool, + channel_name: str | None = None, + ) -> None: + """Set up data loading parameters and start the thread. + + Parameters + ---------- + data_type : str + Type of data to load (images, labels, points, shapes). + text : str + Name of the element to load. + sdata : SpatialData + SpatialData object containing the element. + selected_cs : str + Selected coordinate system. + multi : bool + Boolean indicating if multiple SpatialData objects are present. + channel_name : str, optional + Optional channel name for image data. + + Raises + ------ + RuntimeError + If the thread is already running. + """ if self.isRunning(): raise RuntimeError("Thread is already running.") self._data_type = data_type self._text = text + self._channel_name = channel_name self._sdata = sdata self._selected_cs = selected_cs self._multi = multi if PROBLEMATIC_NUMPY_MACOS: self.run() + self.finished.emit() else: self.start() def run(self) -> None: + """Execute the data loading operation. + + Loads the specified data element based on its type and emits the + returned layer through the 'returned' signal. + """ if not self._data_type: return if self._data_type == "labels": @@ -152,7 +286,7 @@ def run(self) -> None: ) elif self._data_type == "images": layer = self.sdata_widget.viewer_model.get_sdata_image( - self._sdata, self._text, self._selected_cs, self._multi + self._sdata, self._text, self._selected_cs, self._multi, self._channel_name ) elif self._data_type == "points": layer = self.sdata_widget.viewer_model.get_sdata_points( @@ -167,18 +301,51 @@ def run(self) -> None: class SdataWidget(QWidget): + """Main widget for interacting with SpatialData objects in napari. + + This widget combines coordinate system selection and element browsing into a + unified interface for visualizing SpatialData objects in napari. + It manages the loading and display of different data types and handles coordinate + system transformations. + + Attributes + ---------- + _sdata + List of SpatialData objects. + viewer_model + SpatialDataViewer instance for interacting with napari. + worker_thread + Thread for asynchronous data loading. + coordinate_system_widget + Widget for selecting coordinate systems. + elements_widget + Widget for browsing and selecting elements. + slider + Progress bar shown during data loading. + """ + def __init__(self, viewer: Viewer, sdata: EventedList): + """Initialize the SdataWidget. + + Parameters + ---------- + viewer : Viewer + napari Viewer instance. + sdata : EventedList + List of SpatialData objects to visualize. + """ super().__init__() self._sdata = sdata self.viewer_model = SpatialDataViewer(viewer, self._sdata) + self._load_queue: list[tuple[Any, ...]] = [] self.worker_thread = DataLoadThread(parent=self) self.worker_thread.returned.connect(self.viewer_model.viewer.add_layer) - self.worker_thread.finished.connect(self._hide_slider) + self.worker_thread.finished.connect(self._on_load_finished) self.setLayout(QVBoxLayout()) - self.coordinate_system_widget = CoordinateSystemWidget(self._sdata) - self.elements_widget = ElementWidget(self._sdata) + self.coordinate_system_widget = ListWidget(self._sdata, coordinate_system=True) + self.elements_widget = ListWidget(self._sdata) self.slider = QProgressBar(self) self.slider.setRange(0, 0) self.slider.setVisible(False) @@ -216,14 +383,14 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.layout().addWidget(self._three_d_settings_label) self.layout().addWidget(self.discard_z_points) self.layout().addWidget(self.discard_z_shapes) - self.elements_widget.itemDoubleClicked.connect(self._on_click_item) + self.elements_widget.itemDoubleClicked.connect(self._on_doubleclick_element_item) self.coordinate_system_widget.currentItemChanged.connect( - lambda item: self.elements_widget._onItemChange(item.text()) + lambda item: self.elements_widget._onCsItemChange(item.text()) ) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.coordinate_system_widget._select_coord_sys(item.text()) ) - self.viewer_model.layer_saved.connect(self.elements_widget._onItemChange) + self.viewer_model.layer_saved.connect(self.elements_widget._onCsItemChange) self.coordinate_system_widget.currentItemChanged.connect(self._update_layers_visibility) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.viewer_model._affine_transform_layers(item.text()) @@ -231,24 +398,59 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.viewer_model.viewer.layers.events.inserted.connect(self._on_insert_layer) def _on_insert_layer(self, event: Event) -> None: + """Connect visibility events for newly inserted layers. + + Parameters + ---------- + event : Event + Event containing the newly inserted layer. + """ layer = event.value layer.events.visible.connect(self._update_visible_in_coordinate_system) - def _on_click_item(self, item: QListWidgetItem) -> None: + def _on_doubleclick_element_item(self, item: QListWidgetItem) -> None: + """Handle double-click events on element items in the element widget. + + Loads and displays the selected element. + + Parameters + ---------- + item : QListWidgetItem + The double-clicked element item. + """ self._onClick(item.text()) def _hide_slider(self) -> None: + """Hide the progress slider when data loading is complete.""" self.slider.setVisible(False) - def _onClick(self, text: str) -> None: + def _on_load_finished(self) -> None: + self._hide_slider() + self._process_queue() + + def _process_queue(self) -> None: + if self._load_queue: + task = self._load_queue.pop(0) + self._start_load(*task) + + def _onClick(self, element_name: str, channel_name: str | None = None) -> None: + """Handle click events to load and display data elements. + + Parameters + ---------- + element_name : str + Name of the element to load. + channel_name : str, optional + Name of the channel to load for image elements. + """ selected_cs = self.coordinate_system_widget._system if self.worker_thread.isRunning(): show_info("Please wait for the current operation to finish.") return if selected_cs and self.elements_widget._elements: - sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, text) - if (type_ := self.elements_widget._elements[text]["element_type"]) not in { + sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, element_name) + if (type_ := self.elements_widget._elements[element_name]["element_type"]) not in { "labels", "images", "shapes", @@ -256,12 +458,40 @@ def _onClick(self, text: str) -> None: }: return - self.worker_thread.load_data(type_, text, sdata, selected_cs, multi) - if not PROBLEMATIC_NUMPY_MACOS: - self.slider.setVisible(True) + self._start_load(type_, element_name, sdata, selected_cs, multi, channel_name) + + def _enqueue_channel(self, element_name: str, channel_name: str) -> None: + """Queue a channel load, starting immediately if the thread is idle.""" + selected_cs = self.coordinate_system_widget._system + if not selected_cs or not self.elements_widget._elements: + return + if element_name not in self.elements_widget._elements: + return + sdata, multi = _get_sdata_key(self._sdata, self.elements_widget._elements, element_name) + type_ = self.elements_widget._elements[element_name]["element_type"] + if type_ != "images": + return + task = (type_, element_name, sdata, selected_cs, multi, channel_name) + if self.worker_thread.isRunning() or self._load_queue: + self._load_queue.append(task) + else: + self._start_load(*task) + + def _start_load( + self, type_: str, element_name: str, sdata: SpatialData, selected_cs: str, multi: bool, channel_name: str | None + ) -> None: + if not PROBLEMATIC_NUMPY_MACOS: + self.slider.setVisible(True) + self.worker_thread.load_data(type_, element_name, sdata, selected_cs, multi, channel_name) def _update_visible_in_coordinate_system(self, event: Event) -> None: - """Toggle active in the coordinate system metadata when changing visibility of layer.""" + """Toggle active status in the coordinate system metadata when changing layer visibility. + + Parameters + ---------- + event : Event + Event triggered by changing layer visibility. + """ metadata = event.source.metadata layer_active = metadata.get("_active_in_cs") selected_coordinate_system = self.coordinate_system_widget._system @@ -275,7 +505,12 @@ def _update_visible_in_coordinate_system(self, event: Event) -> None: layer_active.remove(selected_coordinate_system) def _update_layers_visibility(self) -> None: - """Toggle layer visibility dependent on presence in currently selected coordinate system.""" + """Toggle layer visibility based on presence in the currently selected coordinate system. + + Updates the visibility of all layers based on whether they are active in the + currently selected coordinate system. Also updates layer metadata to track + coordinate system information. + """ elements = self.elements_widget._elements coordinate_system = self.coordinate_system_widget._system # No layer selected on first time coordinate system selection @@ -312,6 +547,33 @@ def _sdatas_have_z_axis(sdatas: EventedList) -> bool: return False def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: + """Load and create appropriate layer for shape data. + + Determines the geometry type of the shapes element and calls the appropriate + method to create either a Points layer (for Point geometries) or a Shapes + layer (for Polygon or MultiPolygon geometries). + + Parameters + ---------- + sdata : SpatialData + SpatialData object containing the shapes element. + key : str + Name of the shapes element to load. + selected_cs : str + Selected coordinate system. + multi : bool + Whether multiple SpatialData objects are present. + + Returns + ------- + Shapes or Points + The created napari layer. + + Raises + ------ + TypeError + If the geometry type is not Point, Polygon, or MultiPolygon. + """ original_name = key[: key.rfind("_")] if multi else key if type(sdata.shapes[original_name].iloc[0].geometry) is shapely.geometry.point.Point: diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index 9b2de745..cce03991 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -21,9 +21,11 @@ from qtpy import QtWidgets from qtpy.QtCore import Qt from qtpy.QtWidgets import ( + QCheckBox, QComboBox, QDialog, QGridLayout, + QHBoxLayout, QInputDialog, QLabel, QLineEdit, @@ -41,9 +43,7 @@ from napari_spatialdata._widgets import ( AListWidget, AnnDataSaveDialog, - CBarWidget, ComponentWidget, - RangeSliderWidget, SaveDialog, ScatterAnnotationDialog, ) @@ -453,8 +453,19 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non # Vars var_label = QLabel("Vars:") var_label.setToolTip("Names from `adata.var_names` or `adata.raw.var_names`.") + self.add_in_new_layer_checkbox = QCheckBox("add in new layer") + self.add_in_new_layer_checkbox.setChecked(False) + var_header = QWidget() + var_header_layout = QHBoxLayout(var_header) + var_header_layout.setContentsMargins(0, 0, 0, 0) + var_header_layout.addWidget(var_label) + var_header_layout.addWidget(self.add_in_new_layer_checkbox) + var_header_layout.addStretch() + self.var_widget = AListWidget(self.viewer, self.model, attr="var") self.var_widget.setAdataLayer("X") + self.add_in_new_layer_checkbox.toggled.connect(self._on_add_in_new_layer_toggled) + self.var_widget.load_channels.connect(self._load_channels_in_new_layer) self.viewer.dims.events.current_step.connect(self._channel_changed) @@ -470,7 +481,7 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non self.layout().addWidget(adata_layer_label) self.layout().addWidget(self.adata_layer_widget) - self.layout().addWidget(var_label) + self.layout().addWidget(var_header) self.layout().addWidget(self.var_widget) # obsm @@ -497,19 +508,25 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non self.color_by = QLabel("Colored by:") self.layout().addWidget(self.color_by) - # scalebar - colorbar = CBarWidget(model=self.model) - self.slider = RangeSliderWidget(self.viewer, self.model, colorbar=colorbar) - self._viewer.window.add_dock_widget(self.slider, area="left", name="slider") - self._viewer.window.add_dock_widget(colorbar, area="left", name="colorbar") - self.viewer.layers.selection.events.active.connect(self.slider._onLayerChange) - if (layer := self.viewer.layers.selection.active) is not None and layer.metadata.get("adata") is not None: self._on_layer_update() self.model.events.adata.connect(self._on_layer_update) self.model.events.color_by.connect(self._change_color_by) + def _on_add_in_new_layer_toggled(self, checked: bool) -> None: + self.var_widget.add_in_new_layer = checked + + def _load_channels_in_new_layer(self, channel_names: tuple[str, ...]) -> None: + layer = self.model.layer + if layer is None: + return + layer_name = layer.name + element_name = layer_name[: layer_name.rfind("_ch:")] if "_ch:" in layer_name else layer_name + sdata_widget = self._viewer.window._dock_widgets["SpatialData"].widget() + for channel_name in channel_names: + sdata_widget._enqueue_channel(element_name, channel_name) + def _channel_changed(self, event: Event) -> None: layer = self.model.layer is_image = isinstance(layer, Image) @@ -522,6 +539,14 @@ def _channel_changed(self, event: Event) -> None: return current_point = list(event.value) + data = layer.data[-1] if isinstance(layer.data, MultiScaleData) else layer.data + # after loading a channel into a new layer, the new layer is selected, but one can + # still load additional channels as new layers. If deselecting the checkbox for adding + # channels as new layers, without the next if block one would get a crash due to + # going out of index + if data.ndim < len(current_point): + return + displayed = self._viewer.dims.displayed if layer.multiscale: for i, (lo_size, hi_size, cord) in enumerate( diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index cca44a92..0e28f3b4 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -28,6 +28,7 @@ _get_ellipses_from_circles, _get_init_metadata_adata, _get_transform, + _obtain_channel_image, _transform_coordinates, get_duplicate_element_names, get_napari_version, @@ -451,10 +452,14 @@ def clean_worker(self) -> None: """Clean the worker.""" self.worker = None - def add_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> None: - self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi)) + def add_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> None: + self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi, channel_name)) - def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Image: + def get_sdata_image( + self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None + ) -> Image: """ Add an image in a spatial data object to the viewer. @@ -474,7 +479,12 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: original_name = original_name[: original_name.rfind("_")] affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True) - rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) + if channel_name: + image = _obtain_channel_image(element=sdata.images[original_name], channel_name=channel_name) + rgb = False + key = key + f"_ch:{channel_name}" + else: + image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name]) @@ -482,7 +492,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: # TODO: type check return Image( - rgb_image, + image, rgb=rgb, name=key, affine=affine, diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index 7cab1d64..e27e5acd 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -18,12 +18,7 @@ from napari.viewer import Viewer from qtpy import QtCore, QtWidgets from qtpy.QtCore import Qt, Signal -from sklearn.preprocessing import MinMaxScaler from spatialdata._types import ArrayLike -from superqt import QRangeSlider -from vispy import scene -from vispy.color.colormap import Colormap, MatplotlibColormap -from vispy.scene.widgets import ColorBarWidget # See https://github.com/scverse/squidpy/issues/1061 for more details. # Scanpy 0.11.x-0.12.x renamed set_default_colors_for_categorical_obs to _set_default_colors_for_categorical_obs @@ -36,7 +31,7 @@ from napari_spatialdata._model import DataModel from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version -__all__ = ["AListWidget", "CBarWidget", "RangeSliderWidget", "ComponentWidget"] +__all__ = ["AListWidget", "ComponentWidget"] # label string: attribute name # TODO(giovp): remove since layer controls private? @@ -63,8 +58,9 @@ def __init__(self, viewer: napari.Viewer | None, unique: bool = True, multiselec self._index: int | str = 0 self._unique = unique self._viewer = viewer + self._pre_click_selection: tuple[str, ...] = () - self.itemDoubleClicked.connect(lambda item: self._onAction((item.text(),))) + self.itemDoubleClicked.connect(self._on_item_double_clicked) self.enterPressed.connect(self._onAction) self.indexChanged.connect(self._onAction) @@ -93,6 +89,17 @@ def addItems(self, labels: str | Iterable[str] | None) -> None: super().addItems(labels) # self.sortItems(QtCore.Qt.AscendingOrder) + def mousePressEvent(self, event: QtCore.QEvent) -> None: + self._pre_click_selection = tuple(s.text() for s in self.selectedItems()) + super().mousePressEvent(event) + + def _on_item_double_clicked(self, item: QtWidgets.QListWidgetItem) -> None: + pre = self._pre_click_selection + if len(pre) > 1 and item.text() in pre: + self._onAction(pre) + else: + self._onAction((item.text(),)) + def keyPressEvent(self, event: QtCore.QEvent) -> None: if event.key() == QtCore.Qt.Key_Return: event.accept() @@ -103,6 +110,7 @@ def keyPressEvent(self, event: QtCore.QEvent) -> None: class AListWidget(ListWidget): layerChanged = Signal() + load_channels = Signal(object) # emits a tuple[str, ...] of channel names def __init__(self, viewer: Viewer | None, model: DataModel, attr: str, **kwargs: Any): if attr != "None" and attr not in DataModel.VALID_ATTRIBUTES: @@ -111,8 +119,8 @@ def __init__(self, viewer: Viewer | None, model: DataModel, attr: str, **kwargs: self._viewer = viewer self._model = model - self._attr = attr + self.add_in_new_layer = False if attr == "None": self._getter: Callable[..., Any] = lambda: None @@ -128,10 +136,17 @@ def _onChange(self) -> None: self.addItems(self.model.get_items(self._attr)) def _onAction(self, items: Iterable[str]) -> None: + channels_to_load: list[str] = [] for item in sorted(set(items)): if isinstance(self.model.layer, (Image)): - i = self.model.layer.metadata["adata"].var.index.get_loc(item) - self.viewer.dims.set_point(0, i) + if self.add_in_new_layer: + channels_to_load.append(item) + else: + layer = self.model.layer + data_ndim = layer.data[-1].ndim if layer.multiscale else layer.data.ndim + if data_ndim > 2: + i = layer.metadata["adata"].var.index.get_loc(item) + self.viewer.dims.set_point(0, i) else: vec, name, index = self._getter(item, index=self.getIndex()) @@ -168,6 +183,9 @@ def _onAction(self, items: Iterable[str]) -> None: # TODO(giovp): make layer editable? # self.viewer.layers[layer_name].editable = False + if channels_to_load: + self.load_channels.emit(tuple(channels_to_load)) + def setAdataLayer(self, layer: str | None) -> None: if layer in ("default", "None", "X"): layer = None @@ -416,193 +434,6 @@ def attr(self, field: str | None) -> None: self._attr = field -class CBarWidget(QtWidgets.QWidget): - FORMAT = "{0:0.2f}" - - cmapChanged = Signal(str) - climChanged = Signal((float, float)) - - def __init__( - self, - model: DataModel, - cmap: str = "viridis", - label: str | None = None, - width: int | None = 250, - height: int | None = 50, - **kwargs: Any, - ): - super().__init__(**kwargs) - - self._model = model - - self._clim = (0.0, 1.0) - self._oclim = self._clim - - self._width = width - self._height = height - self._label = label - - self.__init_UI() - - def __init_UI(self) -> None: - self.setFixedWidth(self._width) - self.setFixedHeight(self._height) - - # use napari's BG color for dark mode - self._canvas = scene.SceneCanvas( - size=(self._width, self._height), bgcolor="#262930", parent=self, decorate=False, resizable=False, dpi=150 - ) - self._colorbar = ColorBarWidget( - self._create_colormap(self.cmap), - orientation="top", - label=self._label, - label_color="white", - clim=self.getClim(), - border_width=1.0, - border_color="black", - padding=(0.3, 0.167), - axis_ratio=0.05, - ) - - self._canvas.central_widget.add_widget(self._colorbar) - - self.climChanged.connect(self.onClimChanged) - self.cmapChanged.connect(self.onCmapChanged) - - def _create_colormap(self, cmap: str) -> Colormap: - ominn, omaxx = self.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - - assert 0 <= minn <= 1, f"Expected `min` to be in `[0, 1]`, found `{minn}`" - assert 0 <= maxx <= 1, f"Expected `maxx` to be in `[0, 1]`, found `{maxx}`" - - cm = MatplotlibColormap(cmap) - - return Colormap(cm[np.linspace(minn, maxx, len(cm.colors))], interpolation="linear") - - def getCmap(self) -> str: - return self.cmap - - def onCmapChanged(self, value: str) -> None: - # this does not trigger update for some reason... - self._colorbar.cmap = self._create_colormap(value) - self._colorbar._colorbar._update() - - def setClim(self, value: tuple[float, float]) -> None: - if value == self._clim: - return - - self._clim = value - self.climChanged.emit(*value) - - def getClim(self) -> tuple[float, float]: - return self._clim - - def getOclim(self) -> tuple[float, float]: - return self._oclim - - def setOclim(self, value: tuple[float, float]) -> None: - # original color limit used for 0-1 normalization - self._oclim = value - - def onClimChanged(self, minn: float, maxx: float) -> None: - # ticks are not working with vispy's colorbar - self._colorbar.cmap = self._create_colormap(self.cmap) - self._colorbar.clim = (self.FORMAT.format(minn), self.FORMAT.format(maxx)) - - def getCanvas(self) -> scene.SceneCanvas: - return self._canvas - - def getColorBar(self) -> ColorBarWidget: - return self._colorbar - - def setLayout(self, layout: QtWidgets.QLayout) -> None: - layout.addWidget(self.getCanvas().native) - super().setLayout(layout) - - def update_color(self) -> None: - # when changing selected layers that have the same limit - # could also trigger it as self._colorbar.clim = self.getClim() - # but the above option also updates geometry - # cbarwidget->cbar->cbarvisual - self._colorbar._colorbar._colorbar._update() - - @property - def cmap(self) -> str: - return self._model.cmap - - -class RangeSliderWidget(QRangeSlider): - def __init__(self, viewer: Viewer, model: DataModel, colorbar: CBarWidget, **kwargs: Any): - super().__init__(**kwargs) - - self._viewer = viewer - self._model = model - self._colorbar = colorbar - self._cmap = plt.get_cmap(self._colorbar.cmap) - self.setValue((0, 100)) - self.setSliderPosition((0, 100)) - self.setSingleStep(0.01) - self.setOrientation(Qt.Horizontal) - self.valueChanged.connect(self._onValueChange) - - def _onLayerChange(self) -> None: - layer = self.viewer.layers.selection.active - if layer is not None: - self._onValueChange((0, 100)) - - def _onValueChange(self, percentile: tuple[float, float]) -> None: - layer = self.viewer.layers.selection.active - # TODO(michalk8): use constants - if "data" not in layer.metadata: - return None # noqa: RET501 - v = layer.metadata["data"] - # this code is currently not used since the slider is not enabled; so I silenced the mypy error; 2. there is a - # mismatch for this error with the mypy in the CI, so I silenced the unused-ignore from the local mypy. - # when this code is re-enabled, let's fix mypy - clipped = np.clip(v, *np.percentile(v, percentile)) # type: ignore[misc,unused-ignore] - - if isinstance(layer, Points): - layer.metadata = {**layer.metadata, "perc": percentile} - layer.face_color = "value" - layer.properties = {"value": clipped} - layer.refresh_colors() - elif isinstance(layer, Labels): - norm_vec = self._scale_vec(clipped) - color_vec = self._cmap(norm_vec) - layer.color = dict(zip(layer.color.keys(), color_vec, strict=False)) - layer.properties = {"value": clipped} - layer.refresh() - - self._colorbar.setOclim(layer.metadata["minmax"]) - self._colorbar.setClim((np.min(layer.properties["value"]), np.max(layer.properties["value"]))) - self._colorbar.update_color() - - def _scale_vec(self, vec: ArrayLike) -> ArrayLike: - ominn, omaxx = self._colorbar.getOclim() - delta = omaxx - ominn + 1e-12 - - minn, maxx = self._colorbar.getClim() - minn = (minn - ominn) / delta - maxx = (maxx - ominn) / delta - scaler = MinMaxScaler(feature_range=(minn, maxx)) - return scaler.fit_transform(vec.reshape(-1, 1)) - - @property - def viewer(self) -> napari.Viewer: - """:mod:`napari` viewer.""" - return self._viewer - - @property - def model(self) -> DataModel: - """:mod:`napari` viewer.""" - return self._model - - class SaveDialog(QtWidgets.QDialog): def __init__(self, layer: Layer, table_name: str) -> None: super().__init__() diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index f46d8cef..0472ae39 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from functools import wraps from random import randint -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, TypeVar import numpy as np import packaging.version @@ -44,7 +44,7 @@ from napari.utils.events import EventedList from qtpy.QtWidgets import QListWidgetItem - from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget + from napari_spatialdata._sdata_widgets import ListWidget from spatialdata._types import ArrayLike @@ -252,6 +252,31 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi return out +def _datatree_to_dataarray_list(new_raster: DataArray | DataTree) -> DataArray | list[DataArray]: + if isinstance(new_raster, DataTree): + list_of_xdata = [] + for k in new_raster: + v = new_raster[k].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + list_of_xdata.append(xdata) + return list_of_xdata + return new_raster + + +def _obtain_channel_image(element: DataArray | DataTree, channel_name: str | int) -> DataArray | list[DataArray]: + is_multiscale_int_ch = isinstance(element, DataTree) and np.issubdtype( + element["scale0"].c.to_numpy().dtype, np.integer + ) + is_int_ch = isinstance(element, DataArray) and np.issubdtype(element.c.to_numpy().dtype, np.integer) + if isinstance(channel_name, str) and (is_multiscale_int_ch or is_int_ch): + channel_name = int(channel_name) + + # works for both DataArray and DataTree + new_raster = element.sel(c=channel_name) + return _datatree_to_dataarray_list(new_raster) + + def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]: """Swap the axes to y, x, c and check if an image supports rgb(a) visualization. @@ -295,14 +320,7 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l rgb = False new_raster = element - if isinstance(new_raster, DataTree): - list_of_xdata = [] - for k in new_raster: - v = new_raster[k].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - list_of_xdata.append(xdata) - new_raster = list_of_xdata + new_raster = _datatree_to_dataarray_list(new_raster) return new_raster, rgb @@ -418,9 +436,7 @@ def _get_init_metadata_adata(sdata: SpatialData, table_name: str | None, element return adata -def get_itemindex_by_text( - list_widget: CoordinateSystemWidget | ElementWidget, item_text: str -) -> None | QListWidgetItem: +def get_itemindex_by_text(list_widget: ListWidget, item_text: str) -> None | QListWidgetItem: """ Get the item in a listwidget based on its text. @@ -543,3 +559,7 @@ def block_signals(widget: QObject) -> Generator[None]: yield finally: widget.blockSignals(False) + + +WidgetType = Literal["coordinate_system", "element", "channel"] +F = TypeVar("F", bound=Callable[..., Any]) diff --git a/tests/conftest.py b/tests/conftest.py index 553f1cea..178f2dcc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,7 +118,7 @@ def _safe_get_max_texture_sizes(): # type: ignore[no-untyped-def] from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs -from spatialdata.models import PointsModel, ShapesModel, TableModel +from spatialdata.models import Image2DModel, PointsModel, ShapesModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation from napari_spatialdata.utils._test_utils import export_figure, save_image @@ -231,6 +231,18 @@ def sdata_blobs() -> SpatialData: return blobs() +@pytest.fixture() +def sdata_channel_images() -> SpatialData: + sdata = blobs() + sdata["blobs_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"] + ) + sdata["blobs_multiscale_image_str_ch"] = Image2DModel.parse( + sdata["blobs_image"], c_coords=["channel1", "channel2", "channel3"], scale_factors=[2, 2] + ) + return sdata + + @pytest.fixture def image(): _, image = _get_blobs_galaxy() diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index 504f08dd..4c2b5009 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -41,7 +41,7 @@ def test_3d_points_visualization( widget = SdataWidget(viewer, EventedList([sdata_3d_points])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("points_3d") viewer.dims.ndisplay = 3 @@ -76,7 +76,7 @@ def test_2_5d_shapes_visualization( widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("shapes_2.5d") assert len(viewer.layers) == 1 @@ -111,7 +111,7 @@ def test_2_5d_circles_visualization( widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("circles_2.5d") assert len(viewer.layers) == 1 @@ -145,7 +145,7 @@ def test_affine_transform_preserves_dimensionality( widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("points_3d") assert len(viewer.layers) == 1 @@ -196,7 +196,7 @@ def test_save_points_z_handling( widget = SdataWidget(viewer, EventedList([sdata_3d_points])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("points_3d") layer = viewer.layers[0] @@ -245,7 +245,7 @@ def test_save_shapes_z_handling( widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("shapes_2.5d") layer = viewer.layers[0] @@ -287,7 +287,7 @@ def test_toggle_affects_loaded_points( widget = SdataWidget(viewer, EventedList([sdata_3d_points])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("points_3d") assert viewer.layers[0].data.shape[1] == 2 @@ -336,7 +336,7 @@ def test_mixed_dimension_visualization( widget = SdataWidget(viewer, EventedList([combined_sdata])) widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("points_3d") assert viewer.layers[0].data.shape[1] == points_dim diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index 44320efc..e3da69cc 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -22,7 +22,7 @@ from xarray import DataArray, DataTree from napari_spatialdata import QtAdataViewWidget -from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget +from napari_spatialdata._sdata_widgets import ListWidget, SdataWidget from napari_spatialdata.constants import config from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem from tests.conftest import OFFSCREEN @@ -32,10 +32,10 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = ElementWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "element") assert widget._sdata is not None assert not widget._elements - widget._onItemChange("global") + widget._onCsItemChange("global") assert widget._elements for name in blobs_extra_cs.images: assert widget._elements[name]["element_type"] == "images" @@ -50,7 +50,7 @@ def test_elementwidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): def test_coordinatewidget(make_napari_viewer: Any, blobs_extra_cs: SpatialData): _ = make_napari_viewer() - widget = CoordinateSystemWidget(EventedList([blobs_extra_cs])) + widget = ListWidget(EventedList([blobs_extra_cs]), "coordinate_system") items = [widget.item(x).text() for x in range(widget.count())] assert len(items) == len(blobs_extra_cs.coordinate_systems) for item in items: @@ -62,13 +62,13 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.images.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert isinstance(widget.viewer_model.viewer.layers[0], Image) assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.images.keys())[0] blobs_extra_cs.images["image"] = to_multiscale(blobs_extra_cs.images["blobs_image"], [2, 4]) - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick("image") assert len(widget.viewer_model.viewer.layers) == 2 @@ -76,12 +76,79 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData del blobs_extra_cs.images["image"] +@pytest.mark.parametrize( + "images", [["blobs_image", "blobs_image_str_ch"], ["blobs_multiscale_image", "blobs_multiscale_image_str_ch"]] +) +def test_channel_selection(qtbot, make_napari_viewer, sdata_channel_images, images): + """Test selecting a channel from an image via _onClick (the underlying mechanism used by 'add in new layer').""" + # Create a viewer + viewer = make_napari_viewer() + + # Create the SdataWidget + widget = SdataWidget(viewer, EventedList([sdata_channel_images])) + + # Click on 'global' coordinate system + center_pos = get_center_pos_listitem(widget.coordinate_system_widget, "global") + click_list_widget_item(qtbot, widget.coordinate_system_widget, center_pos, "currentItemChanged") + + # Click on the image element + center_pos = get_center_pos_listitem(widget.elements_widget, images[0]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + # Load a specific channel via _onClick (triggered by "add in new layer" checkbox in the View widget) + widget._onClick(images[0], "1") + + # Verify that the layer has been added with the correct name and data + assert len(viewer.layers) == 1 + assert viewer.layers[0].name == f"{images[0]}_ch:1" + + # Verify that the layer contains only the selected channel + assert viewer.layers[0].data.shape == (512, 512) + + center_pos = get_center_pos_listitem(widget.elements_widget, images[1]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + widget._onClick(images[1], "channel2") + + assert len(viewer.layers) == 2 + assert viewer.layers[1].name == f"{images[1]}_ch:channel2" + + +@pytest.mark.parametrize( + "images", [["blobs_image", "blobs_image_str_ch"], ["blobs_multiscale_image", "blobs_multiscale_image_str_ch"]] +) +def test_add_in_new_layer_multiple_channels(qtbot, make_napari_viewer, sdata_channel_images, images): + """Test that enqueueing multiple channels loads each as a separate layer in order.""" + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_channel_images])) + + center_pos = get_center_pos_listitem(widget.coordinate_system_widget, "global") + click_list_widget_item(qtbot, widget.coordinate_system_widget, center_pos, "currentItemChanged") + center_pos = get_center_pos_listitem(widget.elements_widget, images[0]) + click_list_widget_item(qtbot, widget.elements_widget, center_pos, "currentItemChanged") + + # Enqueue all three channels in one go — mirrors what _load_channels_in_new_layer does + # when the user has multiple vars selected with "add in new layer" checked. + for ch in ["0", "1", "2"]: + widget._enqueue_channel(images[0], ch) + + qtbot.waitUntil(lambda: len(viewer.layers) == 3, timeout=5000) + + layer_names = {layer.name for layer in viewer.layers} + for ch in ["0", "1", "2"]: + assert f"{images[0]}_ch:{ch}" in layer_names + + for layer in viewer.layers: + assert isinstance(layer, Image) + assert layer.data.shape == (512, 512) + + def test_sdatawidget_labels(qtbot, make_napari_viewer: Any, blobs_extra_cs: SpatialData): viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.labels.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.labels.keys())[0] @@ -123,7 +190,7 @@ def test_sdatawidget_points(caplog, make_napari_viewer: Any, blobs_extra_cs: Spa widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") + widget.elements_widget._onCsItemChange("global") widget._onClick(list(blobs_extra_cs.points.keys())[0]) assert len(widget.viewer_model.viewer.layers) == 1 assert widget.viewer_model.viewer.layers[0].name == list(blobs_extra_cs.points.keys())[0]