From f13e9436ccc3e71c3b3ad6abb0f90b320aa20666 Mon Sep 17 00:00:00 2001 From: asarigun Date: Tue, 14 Apr 2026 17:17:41 +0200 Subject: [PATCH 1/5] fix the Issue #31 --- src/napari_spatialdata/_viewer.py | 54 +++-- src/napari_spatialdata/utils/_utils.py | 31 ++- tests/conftest.py | 58 +++++- tests/test_3d_visualization.py | 265 ++++++++++++++++++++++++- 4 files changed, 386 insertions(+), 22 deletions(-) diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 8d5e3d71..ba279085 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -189,9 +189,6 @@ def _save_points_to_sdata( raise ValueError("Cannot export a points element with no points") transformed_data = np.array([layer_to_save.data_to_world(xy) for xy in layer_to_save.data]) swap_data = np.fliplr(transformed_data) - # ignore z axis if present - if swap_data.shape[1] == 3: - swap_data = swap_data[:, :2] parsed = PointsModel.parse(swap_data, transformations=transformation) # saving to disk of points temporarily disabled until the interface update that will unify the view widget, @@ -261,14 +258,21 @@ def _save_shapes_to_sdata( for shape in layer_to_save._data_view.shapes ] - def _fix_coords(coords: ArrayLike) -> ArrayLike: - remove_z = coords.shape[1] == 3 - first_index = 1 if remove_z else 0 - coords = coords[:, first_index::] - return np.fliplr(coords) + has_z = coords[0].shape[1] == 3 - polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords] - gdf = GeoDataFrame({"geometry": polygons}) + def _fix_coords(coords: ArrayLike) -> tuple[ArrayLike, float | None]: + if coords.shape[1] == 3: + z_val = float(coords[0, 0]) + yx = coords[:, 1:] + return np.fliplr(yx), z_val + return np.fliplr(coords), None + + fixed = [_fix_coords(p) for p in coords] + polygons: list[Polygon] = [Polygon(xy) for xy, _ in fixed] + gdf_dict: dict[str, Any] = {"geometry": polygons} + if has_z: + gdf_dict["z"] = [z_val for _, z_val in fixed] + gdf = GeoDataFrame(gdf_dict) force_2d(gdf) parsed = ShapesModel.parse(gdf, transformations=transformation) @@ -514,11 +518,15 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult original_name = original_name[: original_name.rfind("_")] df = sdata.shapes[original_name] - affine = _get_transform(sdata.shapes[original_name], selected_cs) + axes = get_axes_names(df) + include_z = "z" in axes and not config.PROJECT_2_5D_SHAPES_TO_2D + affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z) - # 2.5D circles not supported yet xy = np.array([df.geometry.x, df.geometry.y]).T yx = np.fliplr(xy) + if include_z: + z_vals = df["z"].to_numpy() + yx = np.column_stack([z_vals, yx]) radii = df.radius.to_numpy() adata, table_name, table_names = self._get_table_data(sdata, original_name) @@ -804,8 +812,28 @@ def _affine_transform_layers(self, coordinate_system: str) -> None: sdata = metadata["sdata"] element_name = metadata["name"] element_data = sdata[element_name] - affine = _get_transform(element_data, coordinate_system) + include_z = self._should_include_z(element_data) + affine = _get_transform(element_data, coordinate_system, include_z=include_z) if affine is not None: layer.affine = affine if layer._type_string == "points": self._adjust_radii_of_points_layer(layer, affine) + + @staticmethod + def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool: + """Determine whether to include the z axis for a given spatial element. + + For raster data (images, labels) z is always included when present. + For vector data (points, shapes) z inclusion depends on the user-facing + projection config flags. + """ + from xarray import DataArray, DataTree + + if isinstance(element, DataArray | DataTree): + return True + axes = get_axes_names(element) + if "z" not in axes: + return False + if isinstance(element, DaskDataFrame): + return not config.PROJECT_3D_POINTS_TO_2D + return not config.PROJECT_2_5D_SHAPES_TO_2D diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index cfe37e23..e44328b4 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -462,10 +462,13 @@ def generate_random_color_hex() -> str: def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: """Convert circles to ellipses. + Supports both 2D (y, x) and 2.5D (z, y, x) centroids. For 2.5D input the radius is + applied only to y and x while z is kept constant across the four corner vertices. + Parameters ---------- yx - Centroids of the circles. + Centroids of the circles with shape ``(N, 2)`` or ``(N, 3)``. radii Radii of the circles. @@ -475,13 +478,27 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: Ellipses. """ ndim = yx.shape[1] - assert ndim == 2 - r = np.stack([radii] * ndim, axis=1) - lower_left = yx - r - upper_right = yx + r + assert ndim in (2, 3) + + if ndim == 3: + z = yx[:, :1] + yx_2d = yx[:, 1:] + else: + yx_2d = yx + + r = np.stack([radii, radii], axis=1) + lower_left = yx_2d - r + upper_right = yx_2d + r r[:, 0] = -r[:, 0] - lower_right = yx - r - upper_left = yx + r + lower_right = yx_2d - r + upper_left = yx_2d + r + + if ndim == 3: + lower_left = np.column_stack([z, lower_left]) + lower_right = np.column_stack([z, lower_right]) + upper_right = np.column_stack([z, upper_right]) + upper_left = np.column_stack([z, upper_left]) + ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1) assert isinstance(ellipses, np.ndarray) return ellipses diff --git a/tests/conftest.py b/tests/conftest.py index 5089971a..553f1cea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,7 +119,7 @@ def _safe_get_max_texture_sizes(): # type: ignore[no-untyped-def] from spatialdata._types import ArrayLike from spatialdata.datasets import blobs from spatialdata.models import PointsModel, ShapesModel, TableModel -from spatialdata.transformations import Identity, set_transformation +from spatialdata.transformations import Affine, Identity, set_transformation from napari_spatialdata.utils._test_utils import export_figure, save_image @@ -415,3 +415,59 @@ def sdata_2_5d_shapes() -> SpatialData: shapes["shapes_2.5d"] = shape_element return SpatialData(shapes=shapes) + + +@pytest.fixture +def sdata_2_5d_circles() -> SpatialData: + """Create a SpatialData object with 2.5D circles (circles at different z levels).""" + n_circles = 10 + rng = np.random.default_rng(SEED) + gdf = gpd.GeoDataFrame( + { + "geometry": gpd.points_from_xy( + rng.uniform(0, 100, n_circles), + rng.uniform(0, 100, n_circles), + ), + "radius": rng.uniform(5, 15, n_circles), + "z": rng.uniform(0, 50, n_circles), + } + ) + circles = ShapesModel.parse(gdf) + set_transformation(circles, {"global": Identity()}, set_all=True) + + return SpatialData(shapes={"circles_2.5d": circles}) + + +@pytest.fixture +def sdata_3d_points_two_cs() -> SpatialData: + """Create a SpatialData with 3D points registered to two coordinate systems. + + The element lives in ``global`` (identity) and in ``scaled`` (2x scale + with a 10-unit z-translation). This is useful for testing that + ``_affine_transform_layers`` produces a correctly-sized affine matrix + when switching between coordinate systems. + """ + n_points = 5 + rng = np.random.default_rng(SEED) + df = pd.DataFrame( + { + "x": rng.uniform(0, 100, n_points), + "y": rng.uniform(0, 100, n_points), + "z": rng.uniform(0, 50, n_points), + } + ) + dask_df = from_pandas(df, npartitions=1) + points = PointsModel.parse(dask_df) + + affine_matrix = np.array( + [ + [2.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 0.0, 2.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + scaled_affine = Affine(affine_matrix, input_axes=("x", "y", "z"), output_axes=("x", "y", "z")) + set_transformation(points, {"global": Identity(), "scaled": scaled_affine}, set_all=True) + + return SpatialData(points={"points_3d": points}) diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index 41ee43f3..c8658c0a 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -1,14 +1,17 @@ -"""Tests for 3D points and 2.5D shapes visualization. +"""Tests for 3D points, 2.5D shapes and 2.5D circles visualization. For debugging tips on how to visually inspect tests, see docs/contributing.md. """ +from pathlib import Path from typing import Any +import numpy as np import pytest from napari.layers import Points, Shapes from napari.utils.events import EventedList from spatialdata import SpatialData +from spatialdata.models import get_axes_names from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata.constants import config @@ -111,6 +114,266 @@ def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: S config.PROJECT_2_5D_SHAPES_TO_2D = original_value +class Test2_5DCirclesVisualization: + """Test 2.5D circles visualization in napari.""" + + def test_2_5d_circles_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_circles: SpatialData): + """Test that 2.5D circles are projected to 2D when config flag is True.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("circles_2.5d") + + assert len(viewer.layers) == 1 + # 2D projection: coordinates should have 2 values (y, x) + assert viewer.layers[0].data.shape[1] == 2 + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_2_5d_circles_full_3d(self, make_napari_viewer: Any, sdata_2_5d_circles: SpatialData): + """Test that 2.5D circles are visualized in 3D when config flag is False.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("circles_2.5d") + + assert len(viewer.layers) == 1 + # Full 3D: coordinates should have 3 values (z, y, x) + assert viewer.layers[0].data.shape[1] == 3 + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + +class TestAffineTransformLayers: + """Test that _affine_transform_layers propagates include_z correctly.""" + + def test_affine_transform_preserves_3d_for_points( + self, + make_napari_viewer: Any, + sdata_3d_points_two_cs: SpatialData, + ): + """Switching coordinate system must produce a 4x4 affine for 3D points.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + assert len(viewer.layers) == 1 + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == 3 + + # Identity in "global" -> affine should be 4x4 identity + np.testing.assert_array_almost_equal(layer.affine.affine_matrix, np.eye(4)) + + # Switch to the "scaled" coordinate system + widget.coordinate_system_widget._select_coord_sys("scaled") + widget.viewer_model._affine_transform_layers("scaled") + + # After switching, the affine must still be 4x4 (not 3x3) + assert layer.affine.affine_matrix.shape == (4, 4) + assert not np.allclose(layer.affine.affine_matrix, np.eye(4)) + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_affine_transform_projects_to_2d_when_configured( + self, + make_napari_viewer: Any, + sdata_3d_points_two_cs: SpatialData, + ): + """When projection is enabled the affine must be 3x3 (2D).""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + assert len(viewer.layers) == 1 + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == 2 + + widget.coordinate_system_widget._select_coord_sys("scaled") + widget.viewer_model._affine_transform_layers("scaled") + + # Projected to 2D -> affine stays 3x3 + assert layer.affine.affine_matrix.shape == (3, 3) + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class TestSavePointsPreservesZ: + """Test that saving 3D points preserves the z coordinate.""" + + def test_save_3d_points_preserves_z( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """Saving a 3D points layer must retain the z column in the stored element.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + tmpdir = tmp_path / "sdata.zarr" + sdata_3d_points.write(tmpdir) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == 3 + + original_z = sdata_3d_points.points["points_3d"].compute()["z"].values.copy() + + parsed, _ = widget.viewer_model._save_points_to_sdata(layer, "points_3d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert "z" in saved_axes, "z axis must be preserved after save" + + saved_z = parsed.compute()["z"].values + np.testing.assert_array_almost_equal(saved_z, original_z) + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_save_2d_points_no_z( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """When projected to 2D, saved points must not contain a z column.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + tmpdir = tmp_path / "sdata.zarr" + sdata_3d_points.write(tmpdir) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == 2 + + parsed, _ = widget.viewer_model._save_points_to_sdata(layer, "points_3d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert "z" not in saved_axes + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class TestSaveShapesPreservesZ: + """Test that saving 2.5D shapes preserves the z coordinate.""" + + def test_save_2_5d_shapes_preserves_z( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + ): + """Saving a 2.5D shapes layer must retain the z column in the stored element.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = False + + tmpdir = tmp_path / "sdata.zarr" + sdata_2_5d_shapes.write(tmpdir) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("shapes_2.5d") + + layer = viewer.layers[0] + assert isinstance(layer, Shapes) + for shape_data in layer.data: + assert shape_data.shape[1] == 3 + + parsed, _ = widget.viewer_model._save_shapes_to_sdata(layer, "shapes_2.5d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert "z" in saved_axes, "z axis must be preserved after save" + + saved_z = parsed["z"].values + original_unique_z = np.unique(sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values) + np.testing.assert_array_almost_equal( + np.unique(saved_z), original_unique_z + ) + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_save_2d_shapes_no_z( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + ): + """When projected to 2D, saved shapes must not contain a z column.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = True + + tmpdir = tmp_path / "sdata.zarr" + sdata_2_5d_shapes.write(tmpdir) + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("shapes_2.5d") + + layer = viewer.layers[0] + assert isinstance(layer, Shapes) + for shape_data in layer.data: + assert shape_data.shape[1] == 2 + + parsed, _ = widget.viewer_model._save_shapes_to_sdata(layer, "shapes_2.5d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert "z" not in saved_axes + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + class TestMixed2D3DVisualization: """Test mixed 2D and 3D visualization scenarios.""" From 1bd1d90ce299de50473a0d5f91cf4165e18beeac Mon Sep 17 00:00:00 2001 From: asarigun Date: Tue, 14 Apr 2026 18:00:44 +0200 Subject: [PATCH 2/5] fix Issue #31 --- src/napari_spatialdata/_sdata_widgets.py | 81 ++++++++- src/napari_spatialdata/_viewer.py | 73 +++++++++ tests/test_3d_visualization.py | 200 +++++++++++++++++++++++ 3 files changed, 353 insertions(+), 1 deletion(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 9915706a..206dedd2 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -15,11 +15,22 @@ from packaging.version import parse as parse_version from qtpy.QtCore import QThread, Signal from qtpy.QtGui import QIcon -from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget +from qtpy.QtWidgets import ( + QCheckBox, + QHBoxLayout, + QLabel, + QListWidget, + QListWidgetItem, + QProgressBar, + QVBoxLayout, + QWidget, +) +from superqt import QDoubleRangeSlider from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM from napari_spatialdata._viewer import SpatialDataViewer +from napari_spatialdata.constants import config from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping @@ -174,11 +185,45 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.slider.setRange(0, 0) self.slider.setVisible(False) + self.enable_3d_points = QCheckBox("Enable 3D points") + self.enable_3d_points.setChecked(not config.PROJECT_3D_POINTS_TO_2D) + self.enable_3d_points.setToolTip("When checked, points with a z coordinate are displayed in 3D.") + self.enable_3d_points.toggled.connect(self._on_3d_points_toggled) + + self.enable_2_5d_shapes = QCheckBox("Enable 2.5D shapes") + self.enable_2_5d_shapes.setChecked(not config.PROJECT_2_5D_SHAPES_TO_2D) + self.enable_2_5d_shapes.setToolTip("When checked, shapes with a z coordinate are displayed in 2.5D.") + self.enable_2_5d_shapes.toggled.connect(self._on_2_5d_shapes_toggled) + + self.z_range_label = QLabel("Z range:") + self.z_range_value_label = QLabel("") + z_range_header = QHBoxLayout() + z_range_header.addWidget(self.z_range_label) + z_range_header.addWidget(self.z_range_value_label) + z_range_header.addStretch() + self.z_range_header_widget = QWidget() + self.z_range_header_widget.setLayout(z_range_header) + + self.z_range_slider = QDoubleRangeSlider() + self.z_range_slider.setRange(0.0, 1.0) + self.z_range_slider.setValue((0.0, 1.0)) + self.z_range_slider.setToolTip("Filter visible points and shapes by z coordinate range.") + self.z_range_slider.valueChanged.connect(self._on_z_range_changed) + + self._z_slider_visible = False + self.z_range_header_widget.setVisible(False) + self.z_range_slider.setVisible(False) + self.layout().addWidget(self.slider) self.layout().addWidget(QLabel("Coordinate System:")) self.layout().addWidget(self.coordinate_system_widget) self.layout().addWidget(QLabel("Elements:")) self.layout().addWidget(self.elements_widget) + self.layout().addWidget(QLabel("3D Settings:")) + self.layout().addWidget(self.enable_3d_points) + self.layout().addWidget(self.enable_2_5d_shapes) + self.layout().addWidget(self.z_range_header_widget) + self.layout().addWidget(self.z_range_slider) self.elements_widget.itemDoubleClicked.connect(self._on_click_item) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.elements_widget._onItemChange(item.text()) @@ -196,12 +241,14 @@ def __init__(self, viewer: Viewer, sdata: EventedList): def _on_insert_layer(self, event: Event) -> None: layer = event.value layer.events.visible.connect(self._update_visible_in_coordinate_system) + self._update_z_slider() def _on_click_item(self, item: QListWidgetItem) -> None: self._onClick(item.text()) def _hide_slider(self) -> None: self.slider.setVisible(False) + self._update_z_slider() def _onClick(self, text: str) -> None: selected_cs = self.coordinate_system_widget._system @@ -258,6 +305,38 @@ def _update_layers_visibility(self) -> None: layer.metadata["_active_in_cs"].add(coordinate_system) layer.metadata["_current_cs"] = coordinate_system + def _on_3d_points_toggled(self, checked: bool) -> None: + config.PROJECT_3D_POINTS_TO_2D = not checked + + def _on_2_5d_shapes_toggled(self, checked: bool) -> None: + config.PROJECT_2_5D_SHAPES_TO_2D = not checked + + def _update_z_slider(self) -> None: + """Show the z-range slider when layers with z data are present and update its range.""" + z_range = self.viewer_model.get_z_range() + if z_range is None: + self.z_range_header_widget.setVisible(False) + self.z_range_slider.setVisible(False) + self._z_slider_visible = False + return + + z_min, z_max = z_range + if z_min == z_max: + z_max = z_min + 1.0 + + self.z_range_slider.setRange(z_min, z_max) + if not self._z_slider_visible: + self.z_range_slider.setValue((z_min, z_max)) + self._z_slider_visible = True + self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]") + self.z_range_header_widget.setVisible(True) + self.z_range_slider.setVisible(True) + + def _on_z_range_changed(self, value: tuple[float, float]) -> None: + z_min, z_max = value + self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]") + self.viewer_model.filter_layers_by_z_range(z_min, z_max) + def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: original_name = key[: key.rfind("_")] if multi else key diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index ba279085..7e1e0628 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -837,3 +837,76 @@ def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool: if isinstance(element, DaskDataFrame): return not config.PROJECT_3D_POINTS_TO_2D return not config.PROJECT_2_5D_SHAPES_TO_2D + + def get_z_range(self) -> tuple[float, float] | None: + """Return the global (min, max) z range across all visible layers, or ``None`` if no z data exists.""" + z_min, z_max = float("inf"), float("-inf") + found = False + for layer in self.viewer.layers: + metadata = layer.metadata + if not metadata.get("sdata"): + continue + sdata = metadata["sdata"] + element_name = metadata["name"] + element_data = sdata[element_name] + axes = get_axes_names(element_data) + if "z" not in axes: + continue + if isinstance(element_data, DaskDataFrame): + z_vals = element_data["z"].compute().values + elif isinstance(element_data, GeoDataFrame): + if "z" not in element_data.columns: + continue + z_vals = element_data["z"].values + else: + continue + if len(z_vals) == 0: + continue + found = True + z_min = min(z_min, float(z_vals.min())) + z_max = max(z_max, float(z_vals.max())) + if not found: + return None + return z_min, z_max + + def filter_layers_by_z_range(self, z_min: float, z_max: float) -> None: + """Hide points/shapes outside the given z range. + + For :class:`~napari.layers.Points` layers the ``shown`` property is + used. For :class:`~napari.layers.Shapes` layers the face and edge + color alpha channels are set to 0 for shapes outside the range. + """ + for layer in self.viewer.layers: + metadata = layer.metadata + if not metadata.get("sdata"): + continue + sdata = metadata["sdata"] + element_name = metadata["name"] + element_data = sdata[element_name] + axes = get_axes_names(element_data) + if "z" not in axes: + continue + + if isinstance(layer, Points): + if layer.data.shape[1] == 3: + z_vals = layer.data[:, 0] + else: + continue + mask = (z_vals >= z_min) & (z_vals <= z_max) + layer.shown = mask + + elif isinstance(layer, Shapes): + if isinstance(element_data, GeoDataFrame) and "z" in element_data.columns: + z_vals = element_data["z"].values + n_shapes = len(layer.data) + if len(z_vals) != n_shapes: + continue + mask = (z_vals >= z_min) & (z_vals <= z_max) + face_colors = layer.face_color.copy() + edge_colors = layer.edge_color.copy() + face_colors[~mask, 3] = 0.0 + face_colors[mask, 3] = 1.0 + edge_colors[~mask, 3] = 0.0 + edge_colors[mask, 3] = 1.0 + layer.face_color = face_colors + layer.edge_color = edge_colors diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index c8658c0a..d3f67035 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -374,6 +374,206 @@ def test_save_2d_shapes_no_z( config.PROJECT_2_5D_SHAPES_TO_2D = original_value +class TestUIToggle: + """Test the 3D settings checkboxes in SdataWidget.""" + + def test_toggle_3d_points_checkbox(self, make_napari_viewer: Any): + """Toggling the 3D points checkbox must update the config flag.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([])) + + assert not widget.enable_3d_points.isChecked() + assert config.PROJECT_3D_POINTS_TO_2D is True + + widget.enable_3d_points.setChecked(True) + assert config.PROJECT_3D_POINTS_TO_2D is False + + widget.enable_3d_points.setChecked(False) + assert config.PROJECT_3D_POINTS_TO_2D is True + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_toggle_2_5d_shapes_checkbox(self, make_napari_viewer: Any): + """Toggling the 2.5D shapes checkbox must update the config flag.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([])) + + assert not widget.enable_2_5d_shapes.isChecked() + assert config.PROJECT_2_5D_SHAPES_TO_2D is True + + widget.enable_2_5d_shapes.setChecked(True) + assert config.PROJECT_2_5D_SHAPES_TO_2D is False + + widget.enable_2_5d_shapes.setChecked(False) + assert config.PROJECT_2_5D_SHAPES_TO_2D is True + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_toggle_affects_loaded_points( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """Loading points after toggling 3D on must produce 3D coordinates.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == 2 + + viewer.layers.clear() + + widget.enable_3d_points.setChecked(True) + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == 3 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class TestZBinning: + """Test z-range slider filtering for points and shapes.""" + + def test_z_slider_hidden_without_z_data(self, make_napari_viewer: Any): + """The z-range slider must be hidden when no layer has z data.""" + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([])) + assert not widget._z_slider_visible + + def test_z_slider_appears_with_3d_points( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """Adding a 3D points layer must activate the z-range slider.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + assert not widget._z_slider_visible + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + widget._update_z_slider() + assert widget._z_slider_visible + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_filter_points_by_z_range( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """Narrowing the z range must hide points outside the range via the shown mask.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == 3 + assert np.all(layer.shown) + + z_vals = layer.data[:, 0] + z_mid = (z_vals.min() + z_vals.max()) / 2.0 + widget.viewer_model.filter_layers_by_z_range(z_mid, z_vals.max()) + + expected_mask = (z_vals >= z_mid) & (z_vals <= z_vals.max()) + np.testing.assert_array_equal(layer.shown, expected_mask) + assert not np.all(layer.shown) + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + def test_filter_shapes_by_z_range( + self, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + ): + """Narrowing the z range must set alpha=0 for shapes outside the range.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("shapes_2.5d") + + layer = viewer.layers[0] + assert isinstance(layer, Shapes) + + widget.viewer_model.filter_layers_by_z_range(15.0, 25.0) + + z_vals = sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values + n_shapes = len(layer.data) + if len(z_vals) == n_shapes: + expected_visible = (z_vals >= 15.0) & (z_vals <= 25.0) + for i, visible in enumerate(expected_visible): + if visible: + assert layer.edge_color[i, 3] == 1.0 + else: + assert layer.edge_color[i, 3] == 0.0 + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + def test_get_z_range( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """get_z_range must return the global min/max z across all layers.""" + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = False + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + assert widget.viewer_model.get_z_range() is None + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") + + z_range = widget.viewer_model.get_z_range() + assert z_range is not None + z_min, z_max = z_range + + actual_z = sdata_3d_points.points["points_3d"].compute()["z"].values + np.testing.assert_almost_equal(z_min, actual_z.min()) + np.testing.assert_almost_equal(z_max, actual_z.max()) + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + class TestMixed2D3DVisualization: """Test mixed 2D and 3D visualization scenarios.""" From bc31dc5b028873577068f0f02947821fe6bacd5f Mon Sep 17 00:00:00 2001 From: asarigun Date: Wed, 15 Apr 2026 09:00:17 +0200 Subject: [PATCH 3/5] fix Issue #31 --- src/napari_spatialdata/_viewer.py | 40 ++++++++++++++++++++----------- tests/test_3d_visualization.py | 18 +++++++------- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 7e1e0628..70de58ef 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -874,7 +874,8 @@ def filter_layers_by_z_range(self, z_min: float, z_max: float) -> None: For :class:`~napari.layers.Points` layers the ``shown`` property is used. For :class:`~napari.layers.Shapes` layers the face and edge - color alpha channels are set to 0 for shapes outside the range. + color alpha channels are set to 0 for shapes outside the range while + preserving the original alpha for visible shapes. """ for layer in self.viewer.layers: metadata = layer.metadata @@ -896,17 +897,28 @@ def filter_layers_by_z_range(self, z_min: float, z_max: float) -> None: layer.shown = mask elif isinstance(layer, Shapes): - if isinstance(element_data, GeoDataFrame) and "z" in element_data.columns: - z_vals = element_data["z"].values - n_shapes = len(layer.data) - if len(z_vals) != n_shapes: + n_shapes = len(layer.data) + if n_shapes == 0: + continue + + if layer.data[0].shape[1] == 3: + z_vals = np.array([float(s[0, 0]) for s in layer.data]) + elif isinstance(element_data, GeoDataFrame) and "z" in element_data.columns: + z_raw = element_data["z"].values + if len(z_raw) != n_shapes: continue - mask = (z_vals >= z_min) & (z_vals <= z_max) - face_colors = layer.face_color.copy() - edge_colors = layer.edge_color.copy() - face_colors[~mask, 3] = 0.0 - face_colors[mask, 3] = 1.0 - edge_colors[~mask, 3] = 0.0 - edge_colors[mask, 3] = 1.0 - layer.face_color = face_colors - layer.edge_color = edge_colors + z_vals = z_raw + else: + continue + + if "_original_face_color" not in metadata: + metadata["_original_face_color"] = layer.face_color.copy() + metadata["_original_edge_color"] = layer.edge_color.copy() + + mask = (z_vals >= z_min) & (z_vals <= z_max) + face_colors = metadata["_original_face_color"].copy() + edge_colors = metadata["_original_edge_color"].copy() + face_colors[~mask, 3] = 0.0 + edge_colors[~mask, 3] = 0.0 + layer.face_color = face_colors + layer.edge_color = edge_colors diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index d3f67035..fb636230 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -530,17 +530,17 @@ def test_filter_shapes_by_z_range( layer = viewer.layers[0] assert isinstance(layer, Shapes) + z_vals = np.array([float(s[0, 0]) for s in layer.data]) + assert len(z_vals) == len(layer.data) + widget.viewer_model.filter_layers_by_z_range(15.0, 25.0) - z_vals = sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values - n_shapes = len(layer.data) - if len(z_vals) == n_shapes: - expected_visible = (z_vals >= 15.0) & (z_vals <= 25.0) - for i, visible in enumerate(expected_visible): - if visible: - assert layer.edge_color[i, 3] == 1.0 - else: - assert layer.edge_color[i, 3] == 0.0 + expected_visible = (z_vals >= 15.0) & (z_vals <= 25.0) + for i, visible in enumerate(expected_visible): + if visible: + assert layer.edge_color[i, 3] > 0.0 + else: + assert layer.edge_color[i, 3] == 0.0 finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value From 8e8a83b1ba8baa9eb39cb3dba2adb65fb233661b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 07:07:32 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/napari_spatialdata/_sdata_widgets.py | 2 +- tests/test_3d_visualization.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 206dedd2..15d4dbea 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -25,9 +25,9 @@ QVBoxLayout, QWidget, ) -from superqt import QDoubleRangeSlider from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM +from superqt import QDoubleRangeSlider from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants import config diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index fb636230..49adfed2 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -334,9 +334,7 @@ def test_save_2_5d_shapes_preserves_z( saved_z = parsed["z"].values original_unique_z = np.unique(sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values) - np.testing.assert_array_almost_equal( - np.unique(saved_z), original_unique_z - ) + np.testing.assert_array_almost_equal(np.unique(saved_z), original_unique_z) finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value From a2c765dff9ea9eb65f39001e12dd9e5748005984 Mon Sep 17 00:00:00 2001 From: asarigun Date: Thu, 14 May 2026 23:32:32 +0200 Subject: [PATCH 5/5] address review feedback PR#393 --- src/napari_spatialdata/_sdata_widgets.py | 106 ++--- src/napari_spatialdata/_viewer.py | 102 +---- src/napari_spatialdata/utils/_utils.py | 48 ++- tests/test_3d_visualization.py | 484 +++++------------------ 4 files changed, 201 insertions(+), 539 deletions(-) diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 034f0b84..549f43ca 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -17,7 +17,6 @@ from qtpy.QtGui import QIcon from qtpy.QtWidgets import ( QCheckBox, - QHBoxLayout, QLabel, QListWidget, QListWidgetItem, @@ -27,7 +26,6 @@ ) from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM -from superqt import QDoubleRangeSlider from napari_spatialdata._viewer import SpatialDataViewer from napari_spatialdata.constants import config @@ -185,45 +183,39 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.slider.setRange(0, 0) self.slider.setVisible(False) - self.enable_3d_points = QCheckBox("Enable 3D points") - self.enable_3d_points.setChecked(not config.PROJECT_3D_POINTS_TO_2D) - self.enable_3d_points.setToolTip("When checked, points with a z coordinate are displayed in 3D.") - self.enable_3d_points.toggled.connect(self._on_3d_points_toggled) - - self.enable_2_5d_shapes = QCheckBox("Enable 2.5D shapes") - self.enable_2_5d_shapes.setChecked(not config.PROJECT_2_5D_SHAPES_TO_2D) - self.enable_2_5d_shapes.setToolTip("When checked, shapes with a z coordinate are displayed in 2.5D.") - self.enable_2_5d_shapes.toggled.connect(self._on_2_5d_shapes_toggled) - - self.z_range_label = QLabel("Z range:") - self.z_range_value_label = QLabel("") - z_range_header = QHBoxLayout() - z_range_header.addWidget(self.z_range_label) - z_range_header.addWidget(self.z_range_value_label) - z_range_header.addStretch() - self.z_range_header_widget = QWidget() - self.z_range_header_widget.setLayout(z_range_header) - - self.z_range_slider = QDoubleRangeSlider() - self.z_range_slider.setRange(0.0, 1.0) - self.z_range_slider.setValue((0.0, 1.0)) - self.z_range_slider.setToolTip("Filter visible points and shapes by z coordinate range.") - self.z_range_slider.valueChanged.connect(self._on_z_range_changed) - - self._z_slider_visible = False - self.z_range_header_widget.setVisible(False) - self.z_range_slider.setVisible(False) + self.discard_z_points = QCheckBox("Discard z for 3D points") + self.discard_z_points.setChecked(config.PROJECT_3D_POINTS_TO_2D) + self.discard_z_points.setToolTip( + "When checked, the z coordinate of new points layers is discarded so they are loaded in 2D. " + "Only applies to new layers; layers already displayed are not affected." + ) + self.discard_z_points.toggled.connect(self._on_discard_z_points_toggled) + + self.discard_z_shapes = QCheckBox("Discard z for 2.5D shapes") + self.discard_z_shapes.setChecked(config.PROJECT_2_5D_SHAPES_TO_2D) + self.discard_z_shapes.setToolTip( + "When checked, the z coordinate of new shapes layers is discarded so they are loaded in 2D. " + "Only applies to new layers; layers already displayed are not affected." + ) + self.discard_z_shapes.toggled.connect(self._on_discard_z_shapes_toggled) + + # The 3D toggles only matter when at least one element across the loaded + # SpatialData objects has a z axis. Otherwise we hide them to save screen + # real estate for users working with 2D-only data. + self._has_z_data = self._sdatas_have_z_axis(self._sdata) + self._three_d_settings_label = QLabel("3D Settings:") + self._three_d_settings_label.setVisible(self._has_z_data) + self.discard_z_points.setVisible(self._has_z_data) + self.discard_z_shapes.setVisible(self._has_z_data) self.layout().addWidget(self.slider) self.layout().addWidget(QLabel("Coordinate System:")) self.layout().addWidget(self.coordinate_system_widget) self.layout().addWidget(QLabel("Elements:")) self.layout().addWidget(self.elements_widget) - self.layout().addWidget(QLabel("3D Settings:")) - self.layout().addWidget(self.enable_3d_points) - self.layout().addWidget(self.enable_2_5d_shapes) - self.layout().addWidget(self.z_range_header_widget) - self.layout().addWidget(self.z_range_slider) + self.layout().addWidget(self._three_d_settings_label) + self.layout().addWidget(self.discard_z_points) + self.layout().addWidget(self.discard_z_shapes) self.elements_widget.itemDoubleClicked.connect(self._on_click_item) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.elements_widget._onItemChange(item.text()) @@ -241,14 +233,12 @@ def __init__(self, viewer: Viewer, sdata: EventedList): def _on_insert_layer(self, event: Event) -> None: layer = event.value layer.events.visible.connect(self._update_visible_in_coordinate_system) - self._update_z_slider() def _on_click_item(self, item: QListWidgetItem) -> None: self._onClick(item.text()) def _hide_slider(self) -> None: self.slider.setVisible(False) - self._update_z_slider() def _onClick(self, text: str) -> None: selected_cs = self.coordinate_system_widget._system @@ -303,37 +293,23 @@ def _update_layers_visibility(self) -> None: layer.metadata["_active_in_cs"].add(coordinate_system) layer.metadata["_current_cs"] = coordinate_system - def _on_3d_points_toggled(self, checked: bool) -> None: - config.PROJECT_3D_POINTS_TO_2D = not checked + def _on_discard_z_points_toggled(self, checked: bool) -> None: + config.PROJECT_3D_POINTS_TO_2D = checked - def _on_2_5d_shapes_toggled(self, checked: bool) -> None: - config.PROJECT_2_5D_SHAPES_TO_2D = not checked + def _on_discard_z_shapes_toggled(self, checked: bool) -> None: + config.PROJECT_2_5D_SHAPES_TO_2D = checked - def _update_z_slider(self) -> None: - """Show the z-range slider when layers with z data are present and update its range.""" - z_range = self.viewer_model.get_z_range() - if z_range is None: - self.z_range_header_widget.setVisible(False) - self.z_range_slider.setVisible(False) - self._z_slider_visible = False - return + @staticmethod + def _sdatas_have_z_axis(sdatas: EventedList) -> bool: + """Return ``True`` if any element across the given ``SpatialData`` objects has a z axis. - z_min, z_max = z_range - if z_min == z_max: - z_max = z_min + 1.0 - - self.z_range_slider.setRange(z_min, z_max) - if not self._z_slider_visible: - self.z_range_slider.setValue((z_min, z_max)) - self._z_slider_visible = True - self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]") - self.z_range_header_widget.setVisible(True) - self.z_range_slider.setVisible(True) - - def _on_z_range_changed(self, value: tuple[float, float]) -> None: - z_min, z_max = value - self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]") - self.viewer_model.filter_layers_by_z_range(z_min, z_max) + Used to decide whether to expose the 3D / 2.5D projection toggles in the widget. + """ + for sdata in sdatas: + for _, _, element in sdata._gen_elements(): + if SpatialDataViewer._has_z_axis(element): + return True + return False def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: original_name = key[: key.rfind("_")] if multi else key diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 70de58ef..cca44a92 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -569,7 +569,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult else: kwargs |= {"border_color": "white"} # useful code to have readily available to debug the correct radius of circles when represented as points - ellipses = _get_ellipses_from_circles(yx=yx, radii=radii) + ellipses = _get_ellipses_from_circles(coords=yx, radii=radii) layer = Shapes( ellipses, shape_type="ellipse", @@ -819,6 +819,21 @@ def _affine_transform_layers(self, coordinate_system: str) -> None: if layer._type_string == "points": self._adjust_radii_of_points_layer(layer, affine) + @staticmethod + def _has_z_axis(element: Any) -> bool: + """Return ``True`` if ``element`` exposes a ``z`` axis. + + For raster elements (images / labels) the ``z`` axis is reported by + :func:`spatialdata.models.get_axes_names`. For vector elements (points + as :class:`~dask.dataframe.DataFrame`, shapes as + :class:`~geopandas.GeoDataFrame`) the same helper is used. + """ + from xarray import DataArray, DataTree + + if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): + return False + return "z" in get_axes_names(element) + @staticmethod def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool: """Determine whether to include the z axis for a given spatial element. @@ -837,88 +852,3 @@ def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool: if isinstance(element, DaskDataFrame): return not config.PROJECT_3D_POINTS_TO_2D return not config.PROJECT_2_5D_SHAPES_TO_2D - - def get_z_range(self) -> tuple[float, float] | None: - """Return the global (min, max) z range across all visible layers, or ``None`` if no z data exists.""" - z_min, z_max = float("inf"), float("-inf") - found = False - for layer in self.viewer.layers: - metadata = layer.metadata - if not metadata.get("sdata"): - continue - sdata = metadata["sdata"] - element_name = metadata["name"] - element_data = sdata[element_name] - axes = get_axes_names(element_data) - if "z" not in axes: - continue - if isinstance(element_data, DaskDataFrame): - z_vals = element_data["z"].compute().values - elif isinstance(element_data, GeoDataFrame): - if "z" not in element_data.columns: - continue - z_vals = element_data["z"].values - else: - continue - if len(z_vals) == 0: - continue - found = True - z_min = min(z_min, float(z_vals.min())) - z_max = max(z_max, float(z_vals.max())) - if not found: - return None - return z_min, z_max - - def filter_layers_by_z_range(self, z_min: float, z_max: float) -> None: - """Hide points/shapes outside the given z range. - - For :class:`~napari.layers.Points` layers the ``shown`` property is - used. For :class:`~napari.layers.Shapes` layers the face and edge - color alpha channels are set to 0 for shapes outside the range while - preserving the original alpha for visible shapes. - """ - for layer in self.viewer.layers: - metadata = layer.metadata - if not metadata.get("sdata"): - continue - sdata = metadata["sdata"] - element_name = metadata["name"] - element_data = sdata[element_name] - axes = get_axes_names(element_data) - if "z" not in axes: - continue - - if isinstance(layer, Points): - if layer.data.shape[1] == 3: - z_vals = layer.data[:, 0] - else: - continue - mask = (z_vals >= z_min) & (z_vals <= z_max) - layer.shown = mask - - elif isinstance(layer, Shapes): - n_shapes = len(layer.data) - if n_shapes == 0: - continue - - if layer.data[0].shape[1] == 3: - z_vals = np.array([float(s[0, 0]) for s in layer.data]) - elif isinstance(element_data, GeoDataFrame) and "z" in element_data.columns: - z_raw = element_data["z"].values - if len(z_raw) != n_shapes: - continue - z_vals = z_raw - else: - continue - - if "_original_face_color" not in metadata: - metadata["_original_face_color"] = layer.face_color.copy() - metadata["_original_edge_color"] = layer.edge_color.copy() - - mask = (z_vals >= z_min) & (z_vals <= z_max) - face_colors = metadata["_original_face_color"].copy() - edge_colors = metadata["_original_edge_color"].copy() - face_colors[~mask, 3] = 0.0 - edge_colors[~mask, 3] = 0.0 - layer.face_color = face_colors - layer.edge_color = edge_colors diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index e44328b4..f46d8cef 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -184,6 +184,32 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: def _get_transform( element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None ) -> None | ArrayLike: + """Return the affine matrix for ``element`` in the given coordinate system. + + The z axis is included in the returned affine when **both**: + + * ``include_z`` is truthy, **and** + * the element (and therefore its underlying transformation) has a ``z`` axis, + as reported by :func:`spatialdata.models.get_axes_names`. + + If ``include_z`` is requested but the element / transformation does not expose a + ``z`` axis, the flag is silently ignored and a 2D ``(y, x)`` affine is returned. + + Parameters + ---------- + element + The :class:`spatialdata.models.SpatialElement` for which to compute the affine. + coordinate_system_name + Coordinate system to use. If ``None``, the first available is selected. + include_z + Whether to include the z axis in the affine. The z is only included when the + element / transformation also has a z axis; otherwise this flag is ignored. + + Returns + ------- + The affine matrix as an ``ArrayLike`` (``(3, 3)`` for 2D and ``(4, 4)`` for 2.5D/3D), + or ``None`` if no transformation is defined for the requested coordinate system. + """ if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): raise RuntimeError("Cannot get transform for {type(element)}") @@ -459,16 +485,17 @@ def generate_random_color_hex() -> str: return f"#{randint(0, 255):02x}{randint(0, 255):02x}{randint(0, 255):02x}ff" -def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: +def _get_ellipses_from_circles(coords: ArrayLike, radii: ArrayLike) -> ArrayLike: """Convert circles to ellipses. - Supports both 2D (y, x) and 2.5D (z, y, x) centroids. For 2.5D input the radius is - applied only to y and x while z is kept constant across the four corner vertices. + Supports both 2D and 2.5D centroids. For 2.5D input the radius is applied only to + y and x while z is kept constant across the four corner vertices. Parameters ---------- - yx - Centroids of the circles with shape ``(N, 2)`` or ``(N, 3)``. + coords + Centroids of the circles with shape ``(N, 2)`` in ``(y, x)`` order or ``(N, 3)`` + in ``(z, y, x)`` order. radii Radii of the circles. @@ -477,14 +504,15 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: ArrayLike Ellipses. """ - ndim = yx.shape[1] - assert ndim in (2, 3) + ndim = coords.shape[1] + if ndim not in (2, 3): + raise ValueError(f"Expected centroids with 2 or 3 columns (yx or zyx), got shape {coords.shape}.") if ndim == 3: - z = yx[:, :1] - yx_2d = yx[:, 1:] + z = coords[:, :1] + yx_2d = coords[:, 1:] else: - yx_2d = yx + yx_2d = coords r = np.stack([radii, radii], axis=1) lower_left = yx_2d - r diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index 49adfed2..504f08dd 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -20,32 +20,22 @@ class Test3DPointsVisualization: """Test 3D points visualization in napari.""" - def test_3d_points_projected_to_2d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): - """Test that 3D points are projected to 2D when config flag is True.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - viewer.dims.ndisplay = 3 - - assert len(viewer.layers) == 1 - assert isinstance(viewer.layers[0], Points) - # 2D projection: points should have 2 coordinates - assert viewer.layers[0].data.shape[1] == 2 - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): - """Test that 3D points are visualized in 3D when config flag is False.""" + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_3d_points_visualization( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Points dimensionality follows the ``PROJECT_3D_POINTS_TO_2D`` config flag.""" original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_3D_POINTS_TO_2D = False + config.PROJECT_3D_POINTS_TO_2D = project_to_2d viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_3d_points])) @@ -57,8 +47,7 @@ def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: Spati assert len(viewer.layers) == 1 assert isinstance(viewer.layers[0], Points) - # Full 3D: points should have 3 coordinates (z, y, x) - assert viewer.layers[0].data.shape[1] == 3 + assert viewer.layers[0].data.shape[1] == expected_ndim finally: config.PROJECT_3D_POINTS_TO_2D = original_value @@ -66,50 +55,34 @@ def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: Spati class Test2_5DShapesVisualization: """Test 2.5D shapes visualization in napari.""" - def test_2_5d_shapes_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): - """Test that 2.5D shapes are projected to 2D when config flag is True.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D - try: - config.PROJECT_2_5D_SHAPES_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - - # Add 2.5D shapes - widget._onClick("shapes_2.5d") - - assert len(viewer.layers) == 1 - assert isinstance(viewer.layers[0], Shapes) - # 2D projection: shape coordinates should have 2 values per vertex (y, x) - for shape_data in viewer.layers[0].data: - assert shape_data.shape[1] == 2 - - finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value - - def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): - """Test that 2.5D shapes are visualized in 3D when config flag is False.""" + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_2_5d_shapes_visualization( + self, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Shape vertex dimensionality follows the ``PROJECT_2_5D_SHAPES_TO_2D`` config flag.""" original_value = config.PROJECT_2_5D_SHAPES_TO_2D try: - config.PROJECT_2_5D_SHAPES_TO_2D = False + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) widget.coordinate_system_widget._select_coord_sys("global") widget.elements_widget._onItemChange("global") - - # Add 2.5D shapes widget._onClick("shapes_2.5d") assert len(viewer.layers) == 1 assert isinstance(viewer.layers[0], Shapes) - # Full 3D: shape coordinates should have 3 values per vertex (z, y, x) for shape_data in viewer.layers[0].data: - assert shape_data.shape[1] == 3 + assert shape_data.shape[1] == expected_ndim finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value @@ -117,30 +90,22 @@ def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: S class Test2_5DCirclesVisualization: """Test 2.5D circles visualization in napari.""" - def test_2_5d_circles_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_circles: SpatialData): - """Test that 2.5D circles are projected to 2D when config flag is True.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D - try: - config.PROJECT_2_5D_SHAPES_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("circles_2.5d") - - assert len(viewer.layers) == 1 - # 2D projection: coordinates should have 2 values (y, x) - assert viewer.layers[0].data.shape[1] == 2 - finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value - - def test_2_5d_circles_full_3d(self, make_napari_viewer: Any, sdata_2_5d_circles: SpatialData): - """Test that 2.5D circles are visualized in 3D when config flag is False.""" + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_2_5d_circles_visualization( + self, + make_napari_viewer: Any, + sdata_2_5d_circles: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Circles dimensionality follows the ``PROJECT_2_5D_SHAPES_TO_2D`` config flag.""" original_value = config.PROJECT_2_5D_SHAPES_TO_2D try: - config.PROJECT_2_5D_SHAPES_TO_2D = False + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) @@ -150,24 +115,31 @@ def test_2_5d_circles_full_3d(self, make_napari_viewer: Any, sdata_2_5d_circles: widget._onClick("circles_2.5d") assert len(viewer.layers) == 1 - # Full 3D: coordinates should have 3 values (z, y, x) - assert viewer.layers[0].data.shape[1] == 3 + assert viewer.layers[0].data.shape[1] == expected_ndim finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value class TestAffineTransformLayers: - """Test that _affine_transform_layers propagates include_z correctly.""" + """Test that ``_affine_transform_layers`` propagates ``include_z`` correctly.""" - def test_affine_transform_preserves_3d_for_points( + @pytest.mark.parametrize( + ("project_to_2d", "expected_data_ndim", "expected_affine_shape"), + [(False, 3, (4, 4)), (True, 2, (3, 3))], + ids=["full_3d", "projected_to_2d"], + ) + def test_affine_transform_preserves_dimensionality( self, make_napari_viewer: Any, sdata_3d_points_two_cs: SpatialData, + project_to_2d: bool, + expected_data_ndim: int, + expected_affine_shape: tuple[int, int], ): - """Switching coordinate system must produce a 4x4 affine for 3D points.""" + """Switching coordinate system preserves the affine matrix dimensionality.""" original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_3D_POINTS_TO_2D = False + config.PROJECT_3D_POINTS_TO_2D = project_to_2d viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) @@ -179,65 +151,43 @@ def test_affine_transform_preserves_3d_for_points( assert len(viewer.layers) == 1 layer = viewer.layers[0] assert isinstance(layer, Points) - assert layer.data.shape[1] == 3 + assert layer.data.shape[1] == expected_data_ndim - # Identity in "global" -> affine should be 4x4 identity - np.testing.assert_array_almost_equal(layer.affine.affine_matrix, np.eye(4)) + # Identity in "global" -> affine should be the identity of the expected shape + np.testing.assert_array_almost_equal(layer.affine.affine_matrix, np.eye(expected_affine_shape[0])) - # Switch to the "scaled" coordinate system widget.coordinate_system_widget._select_coord_sys("scaled") widget.viewer_model._affine_transform_layers("scaled") - # After switching, the affine must still be 4x4 (not 3x3) - assert layer.affine.affine_matrix.shape == (4, 4) - assert not np.allclose(layer.affine.affine_matrix, np.eye(4)) - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_affine_transform_projects_to_2d_when_configured( - self, - make_napari_viewer: Any, - sdata_3d_points_two_cs: SpatialData, - ): - """When projection is enabled the affine must be 3x3 (2D).""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - - assert len(viewer.layers) == 1 - layer = viewer.layers[0] - assert isinstance(layer, Points) - assert layer.data.shape[1] == 2 - - widget.coordinate_system_widget._select_coord_sys("scaled") - widget.viewer_model._affine_transform_layers("scaled") - - # Projected to 2D -> affine stays 3x3 - assert layer.affine.affine_matrix.shape == (3, 3) + # After switching the affine must keep its dimensionality + assert layer.affine.affine_matrix.shape == expected_affine_shape + if not project_to_2d: + assert not np.allclose(layer.affine.affine_matrix, np.eye(expected_affine_shape[0])) finally: config.PROJECT_3D_POINTS_TO_2D = original_value class TestSavePointsPreservesZ: - """Test that saving 3D points preserves the z coordinate.""" + """Test that saving points correctly handles the z coordinate.""" - def test_save_3d_points_preserves_z( + @pytest.mark.parametrize( + ("project_to_2d", "expected_data_ndim", "z_in_axes"), + [(False, 3, True), (True, 2, False)], + ids=["preserve_z", "drop_z"], + ) + def test_save_points_z_handling( self, tmp_path: Path, make_napari_viewer: Any, sdata_3d_points: SpatialData, + project_to_2d: bool, + expected_data_ndim: int, + z_in_axes: bool, ): - """Saving a 3D points layer must retain the z column in the stored element.""" + """Saving a 3D points layer must retain or drop the z column based on the config flag.""" original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_3D_POINTS_TO_2D = False + config.PROJECT_3D_POINTS_TO_2D = project_to_2d tmpdir = tmp_path / "sdata.zarr" sdata_3d_points.write(tmpdir) @@ -251,66 +201,42 @@ def test_save_3d_points_preserves_z( layer = viewer.layers[0] assert isinstance(layer, Points) - assert layer.data.shape[1] == 3 - - original_z = sdata_3d_points.points["points_3d"].compute()["z"].values.copy() + assert layer.data.shape[1] == expected_data_ndim parsed, _ = widget.viewer_model._save_points_to_sdata(layer, "points_3d", overwrite=True) saved_axes = get_axes_names(parsed) - assert "z" in saved_axes, "z axis must be preserved after save" + assert ("z" in saved_axes) is z_in_axes - saved_z = parsed.compute()["z"].values - np.testing.assert_array_almost_equal(saved_z, original_z) - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_save_2d_points_no_z( - self, - tmp_path: Path, - make_napari_viewer: Any, - sdata_3d_points: SpatialData, - ): - """When projected to 2D, saved points must not contain a z column.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = True - - tmpdir = tmp_path / "sdata.zarr" - sdata_3d_points.write(tmpdir) - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - - layer = viewer.layers[0] - assert isinstance(layer, Points) - assert layer.data.shape[1] == 2 - - parsed, _ = widget.viewer_model._save_points_to_sdata(layer, "points_3d", overwrite=True) - - saved_axes = get_axes_names(parsed) - assert "z" not in saved_axes + if z_in_axes: + original_z = sdata_3d_points.points["points_3d"].compute()["z"].values + saved_z = parsed.compute()["z"].values + np.testing.assert_array_almost_equal(saved_z, original_z) finally: config.PROJECT_3D_POINTS_TO_2D = original_value class TestSaveShapesPreservesZ: - """Test that saving 2.5D shapes preserves the z coordinate.""" + """Test that saving shapes correctly handles the z coordinate.""" - def test_save_2_5d_shapes_preserves_z( + @pytest.mark.parametrize( + ("project_to_2d", "expected_vertex_ndim", "z_in_axes"), + [(False, 3, True), (True, 2, False)], + ids=["preserve_z", "drop_z"], + ) + def test_save_shapes_z_handling( self, tmp_path: Path, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData, + project_to_2d: bool, + expected_vertex_ndim: int, + z_in_axes: bool, ): - """Saving a 2.5D shapes layer must retain the z column in the stored element.""" + """Saving a 2.5D shapes layer must retain or drop the z column based on the config flag.""" original_value = config.PROJECT_2_5D_SHAPES_TO_2D try: - config.PROJECT_2_5D_SHAPES_TO_2D = False + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d tmpdir = tmp_path / "sdata.zarr" sdata_2_5d_shapes.write(tmpdir) @@ -325,49 +251,17 @@ def test_save_2_5d_shapes_preserves_z( layer = viewer.layers[0] assert isinstance(layer, Shapes) for shape_data in layer.data: - assert shape_data.shape[1] == 3 + assert shape_data.shape[1] == expected_vertex_ndim parsed, _ = widget.viewer_model._save_shapes_to_sdata(layer, "shapes_2.5d", overwrite=True) saved_axes = get_axes_names(parsed) - assert "z" in saved_axes, "z axis must be preserved after save" - - saved_z = parsed["z"].values - original_unique_z = np.unique(sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values) - np.testing.assert_array_almost_equal(np.unique(saved_z), original_unique_z) - finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value + assert ("z" in saved_axes) is z_in_axes - def test_save_2d_shapes_no_z( - self, - tmp_path: Path, - make_napari_viewer: Any, - sdata_2_5d_shapes: SpatialData, - ): - """When projected to 2D, saved shapes must not contain a z column.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D - try: - config.PROJECT_2_5D_SHAPES_TO_2D = True - - tmpdir = tmp_path / "sdata.zarr" - sdata_2_5d_shapes.write(tmpdir) - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("shapes_2.5d") - - layer = viewer.layers[0] - assert isinstance(layer, Shapes) - for shape_data in layer.data: - assert shape_data.shape[1] == 2 - - parsed, _ = widget.viewer_model._save_shapes_to_sdata(layer, "shapes_2.5d", overwrite=True) - - saved_axes = get_axes_names(parsed) - assert "z" not in saved_axes + if z_in_axes: + saved_z = parsed["z"].values + original_unique_z = np.unique(sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values) + np.testing.assert_array_almost_equal(np.unique(saved_z), original_unique_z) finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value @@ -375,52 +269,16 @@ def test_save_2d_shapes_no_z( class TestUIToggle: """Test the 3D settings checkboxes in SdataWidget.""" - def test_toggle_3d_points_checkbox(self, make_napari_viewer: Any): - """Toggling the 3D points checkbox must update the config flag.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([])) - - assert not widget.enable_3d_points.isChecked() - assert config.PROJECT_3D_POINTS_TO_2D is True - - widget.enable_3d_points.setChecked(True) - assert config.PROJECT_3D_POINTS_TO_2D is False - - widget.enable_3d_points.setChecked(False) - assert config.PROJECT_3D_POINTS_TO_2D is True - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_toggle_2_5d_shapes_checkbox(self, make_napari_viewer: Any): - """Toggling the 2.5D shapes checkbox must update the config flag.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D - try: - config.PROJECT_2_5D_SHAPES_TO_2D = True - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([])) - - assert not widget.enable_2_5d_shapes.isChecked() - assert config.PROJECT_2_5D_SHAPES_TO_2D is True - - widget.enable_2_5d_shapes.setChecked(True) - assert config.PROJECT_2_5D_SHAPES_TO_2D is False - - widget.enable_2_5d_shapes.setChecked(False) - assert config.PROJECT_2_5D_SHAPES_TO_2D is True - finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value - def test_toggle_affects_loaded_points( self, make_napari_viewer: Any, sdata_3d_points: SpatialData, ): - """Loading points after toggling 3D on must produce 3D coordinates.""" + """Toggling the checkbox affects the dimensionality of newly loaded layers. + + This implicitly also tests that the checkbox state and the underlying + ``config.PROJECT_3D_POINTS_TO_2D`` flag stay in sync. + """ original_value = config.PROJECT_3D_POINTS_TO_2D try: config.PROJECT_3D_POINTS_TO_2D = True @@ -436,142 +294,13 @@ def test_toggle_affects_loaded_points( viewer.layers.clear() - widget.enable_3d_points.setChecked(True) + widget.discard_z_points.setChecked(False) widget._onClick("points_3d") assert viewer.layers[0].data.shape[1] == 3 finally: config.PROJECT_3D_POINTS_TO_2D = original_value -class TestZBinning: - """Test z-range slider filtering for points and shapes.""" - - def test_z_slider_hidden_without_z_data(self, make_napari_viewer: Any): - """The z-range slider must be hidden when no layer has z data.""" - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([])) - assert not widget._z_slider_visible - - def test_z_slider_appears_with_3d_points( - self, - make_napari_viewer: Any, - sdata_3d_points: SpatialData, - ): - """Adding a 3D points layer must activate the z-range slider.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = False - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) - - assert not widget._z_slider_visible - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - - widget._update_z_slider() - assert widget._z_slider_visible - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_filter_points_by_z_range( - self, - make_napari_viewer: Any, - sdata_3d_points: SpatialData, - ): - """Narrowing the z range must hide points outside the range via the shown mask.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = False - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - - layer = viewer.layers[0] - assert isinstance(layer, Points) - assert layer.data.shape[1] == 3 - assert np.all(layer.shown) - - z_vals = layer.data[:, 0] - z_mid = (z_vals.min() + z_vals.max()) / 2.0 - widget.viewer_model.filter_layers_by_z_range(z_mid, z_vals.max()) - - expected_mask = (z_vals >= z_mid) & (z_vals <= z_vals.max()) - np.testing.assert_array_equal(layer.shown, expected_mask) - assert not np.all(layer.shown) - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - def test_filter_shapes_by_z_range( - self, - make_napari_viewer: Any, - sdata_2_5d_shapes: SpatialData, - ): - """Narrowing the z range must set alpha=0 for shapes outside the range.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D - try: - config.PROJECT_2_5D_SHAPES_TO_2D = False - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("shapes_2.5d") - - layer = viewer.layers[0] - assert isinstance(layer, Shapes) - - z_vals = np.array([float(s[0, 0]) for s in layer.data]) - assert len(z_vals) == len(layer.data) - - widget.viewer_model.filter_layers_by_z_range(15.0, 25.0) - - expected_visible = (z_vals >= 15.0) & (z_vals <= 25.0) - for i, visible in enumerate(expected_visible): - if visible: - assert layer.edge_color[i, 3] > 0.0 - else: - assert layer.edge_color[i, 3] == 0.0 - finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value - - def test_get_z_range( - self, - make_napari_viewer: Any, - sdata_3d_points: SpatialData, - ): - """get_z_range must return the global min/max z across all layers.""" - original_value = config.PROJECT_3D_POINTS_TO_2D - try: - config.PROJECT_3D_POINTS_TO_2D = False - - viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) - - assert widget.viewer_model.get_z_range() is None - - widget.coordinate_system_widget._select_coord_sys("global") - widget.elements_widget._onItemChange("global") - widget._onClick("points_3d") - - z_range = widget.viewer_model.get_z_range() - assert z_range is not None - z_min, z_max = z_range - - actual_z = sdata_3d_points.points["points_3d"].compute()["z"].values - np.testing.assert_almost_equal(z_min, actual_z.min()) - np.testing.assert_almost_equal(z_max, actual_z.max()) - finally: - config.PROJECT_3D_POINTS_TO_2D = original_value - - class TestMixed2D3DVisualization: """Test mixed 2D and 3D visualization scenarios.""" @@ -598,7 +327,6 @@ def test_mixed_dimension_visualization( config.PROJECT_3D_POINTS_TO_2D = points_dim == 2 config.PROJECT_2_5D_SHAPES_TO_2D = shapes_dim == 2 - # Create a combined SpatialData combined_sdata = SpatialData( points={"points_3d": sdata_3d_points["points_3d"]}, shapes={"shapes_2.5d": sdata_2_5d_shapes["shapes_2.5d"]},