Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion src/napari_spatialdata/_sdata_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@
from packaging.version import parse as parse_version
from qtpy.QtCore import QThread, Signal
from qtpy.QtGui import QIcon
from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget
from qtpy.QtWidgets import (
QCheckBox,
QLabel,
QListWidget,
QListWidgetItem,
QProgressBar,
QVBoxLayout,
QWidget,
)
from spatialdata import SpatialData
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM

from napari_spatialdata._viewer import SpatialDataViewer
from napari_spatialdata.constants import config
from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD
from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping

Expand Down Expand Up @@ -174,11 +183,39 @@ def __init__(self, viewer: Viewer, sdata: EventedList):
self.slider.setRange(0, 0)
self.slider.setVisible(False)

self.discard_z_points = QCheckBox("Discard z for 3D points")
self.discard_z_points.setChecked(config.PROJECT_3D_POINTS_TO_2D)
self.discard_z_points.setToolTip(
"When checked, the z coordinate of new points layers is discarded so they are loaded in 2D. "
"Only applies to new layers; layers already displayed are not affected."
)
self.discard_z_points.toggled.connect(self._on_discard_z_points_toggled)

self.discard_z_shapes = QCheckBox("Discard z for 2.5D shapes")
self.discard_z_shapes.setChecked(config.PROJECT_2_5D_SHAPES_TO_2D)
self.discard_z_shapes.setToolTip(
"When checked, the z coordinate of new shapes layers is discarded so they are loaded in 2D. "
"Only applies to new layers; layers already displayed are not affected."
)
self.discard_z_shapes.toggled.connect(self._on_discard_z_shapes_toggled)

# The 3D toggles only matter when at least one element across the loaded
# SpatialData objects has a z axis. Otherwise we hide them to save screen
# real estate for users working with 2D-only data.
self._has_z_data = self._sdatas_have_z_axis(self._sdata)
self._three_d_settings_label = QLabel("3D Settings:")
self._three_d_settings_label.setVisible(self._has_z_data)
self.discard_z_points.setVisible(self._has_z_data)
self.discard_z_shapes.setVisible(self._has_z_data)

self.layout().addWidget(self.slider)
self.layout().addWidget(QLabel("Coordinate System:"))
self.layout().addWidget(self.coordinate_system_widget)
self.layout().addWidget(QLabel("Elements:"))
self.layout().addWidget(self.elements_widget)
self.layout().addWidget(self._three_d_settings_label)
self.layout().addWidget(self.discard_z_points)
self.layout().addWidget(self.discard_z_shapes)
self.elements_widget.itemDoubleClicked.connect(self._on_click_item)
self.coordinate_system_widget.currentItemChanged.connect(
lambda item: self.elements_widget._onItemChange(item.text())
Expand Down Expand Up @@ -256,6 +293,24 @@ def _update_layers_visibility(self) -> None:
layer.metadata["_active_in_cs"].add(coordinate_system)
layer.metadata["_current_cs"] = coordinate_system

def _on_discard_z_points_toggled(self, checked: bool) -> None:
config.PROJECT_3D_POINTS_TO_2D = checked

def _on_discard_z_shapes_toggled(self, checked: bool) -> None:
config.PROJECT_2_5D_SHAPES_TO_2D = checked

@staticmethod
def _sdatas_have_z_axis(sdatas: EventedList) -> bool:
"""Return ``True`` if any element across the given ``SpatialData`` objects has a z axis.

Used to decide whether to expose the 3D / 2.5D projection toggles in the widget.
"""
for sdata in sdatas:
for _, _, element in sdata._gen_elements():
if SpatialDataViewer._has_z_axis(element):
return True
return False

def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points:
original_name = key[: key.rfind("_")] if multi else key

Expand Down
71 changes: 57 additions & 14 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ def _save_points_to_sdata(
raise ValueError("Cannot export a points element with no points")
transformed_data = np.array([layer_to_save.data_to_world(xy) for xy in layer_to_save.data])
swap_data = np.fliplr(transformed_data)
# ignore z axis if present
if swap_data.shape[1] == 3:
swap_data = swap_data[:, :2]
parsed = PointsModel.parse(swap_data, transformations=transformation)

# saving to disk of points temporarily disabled until the interface update that will unify the view widget,
Expand Down Expand Up @@ -261,14 +258,21 @@ def _save_shapes_to_sdata(
for shape in layer_to_save._data_view.shapes
]

def _fix_coords(coords: ArrayLike) -> ArrayLike:
remove_z = coords.shape[1] == 3
first_index = 1 if remove_z else 0
coords = coords[:, first_index::]
return np.fliplr(coords)
has_z = coords[0].shape[1] == 3

polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords]
gdf = GeoDataFrame({"geometry": polygons})
def _fix_coords(coords: ArrayLike) -> tuple[ArrayLike, float | None]:
if coords.shape[1] == 3:
z_val = float(coords[0, 0])
yx = coords[:, 1:]
return np.fliplr(yx), z_val
return np.fliplr(coords), None

fixed = [_fix_coords(p) for p in coords]
polygons: list[Polygon] = [Polygon(xy) for xy, _ in fixed]
gdf_dict: dict[str, Any] = {"geometry": polygons}
if has_z:
gdf_dict["z"] = [z_val for _, z_val in fixed]
gdf = GeoDataFrame(gdf_dict)

force_2d(gdf)
parsed = ShapesModel.parse(gdf, transformations=transformation)
Expand Down Expand Up @@ -514,11 +518,15 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
original_name = original_name[: original_name.rfind("_")]

df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)
axes = get_axes_names(df)
include_z = "z" in axes and not config.PROJECT_2_5D_SHAPES_TO_2D
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)

# 2.5D circles not supported yet
xy = np.array([df.geometry.x, df.geometry.y]).T
yx = np.fliplr(xy)
if include_z:
z_vals = df["z"].to_numpy()
yx = np.column_stack([z_vals, yx])
radii = df.radius.to_numpy()

adata, table_name, table_names = self._get_table_data(sdata, original_name)
Expand Down Expand Up @@ -561,7 +569,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
else:
kwargs |= {"border_color": "white"}
# useful code to have readily available to debug the correct radius of circles when represented as points
ellipses = _get_ellipses_from_circles(yx=yx, radii=radii)
ellipses = _get_ellipses_from_circles(coords=yx, radii=radii)
layer = Shapes(
ellipses,
shape_type="ellipse",
Expand Down Expand Up @@ -804,8 +812,43 @@ def _affine_transform_layers(self, coordinate_system: str) -> None:
sdata = metadata["sdata"]
element_name = metadata["name"]
element_data = sdata[element_name]
affine = _get_transform(element_data, coordinate_system)
include_z = self._should_include_z(element_data)
affine = _get_transform(element_data, coordinate_system, include_z=include_z)
Comment thread
LucaMarconato marked this conversation as resolved.
if affine is not None:
layer.affine = affine
if layer._type_string == "points":
self._adjust_radii_of_points_layer(layer, affine)

@staticmethod
def _has_z_axis(element: Any) -> bool:
"""Return ``True`` if ``element`` exposes a ``z`` axis.

For raster elements (images / labels) the ``z`` axis is reported by
:func:`spatialdata.models.get_axes_names`. For vector elements (points
as :class:`~dask.dataframe.DataFrame`, shapes as
:class:`~geopandas.GeoDataFrame`) the same helper is used.
"""
from xarray import DataArray, DataTree

if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
return False
return "z" in get_axes_names(element)

@staticmethod
def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool:
Comment thread
LucaMarconato marked this conversation as resolved.
"""Determine whether to include the z axis for a given spatial element.

For raster data (images, labels) z is always included when present.
For vector data (points, shapes) z inclusion depends on the user-facing
projection config flags.
"""
from xarray import DataArray, DataTree

if isinstance(element, DataArray | DataTree):
return True
axes = get_axes_names(element)
if "z" not in axes:
return False
if isinstance(element, DaskDataFrame):
return not config.PROJECT_3D_POINTS_TO_2D
return not config.PROJECT_2_5D_SHAPES_TO_2D
65 changes: 55 additions & 10 deletions src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,32 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
def _get_transform(
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
) -> None | ArrayLike:
"""Return the affine matrix for ``element`` in the given coordinate system.

The z axis is included in the returned affine when **both**:

* ``include_z`` is truthy, **and**
* the element (and therefore its underlying transformation) has a ``z`` axis,
as reported by :func:`spatialdata.models.get_axes_names`.

If ``include_z`` is requested but the element / transformation does not expose a
``z`` axis, the flag is silently ignored and a 2D ``(y, x)`` affine is returned.

Parameters
----------
element
The :class:`spatialdata.models.SpatialElement` for which to compute the affine.
coordinate_system_name
Coordinate system to use. If ``None``, the first available is selected.
include_z
Whether to include the z axis in the affine. The z is only included when the
element / transformation also has a z axis; otherwise this flag is ignored.

Returns
-------
The affine matrix as an ``ArrayLike`` (``(3, 3)`` for 2D and ``(4, 4)`` for 2.5D/3D),
or ``None`` if no transformation is defined for the requested coordinate system.
"""
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
raise RuntimeError("Cannot get transform for {type(element)}")

Expand Down Expand Up @@ -459,13 +485,17 @@ def generate_random_color_hex() -> str:
return f"#{randint(0, 255):02x}{randint(0, 255):02x}{randint(0, 255):02x}ff"


def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
def _get_ellipses_from_circles(coords: ArrayLike, radii: ArrayLike) -> ArrayLike:
"""Convert circles to ellipses.

Supports both 2D and 2.5D centroids. For 2.5D input the radius is applied only to
y and x while z is kept constant across the four corner vertices.

Parameters
----------
yx
Centroids of the circles.
coords
Centroids of the circles with shape ``(N, 2)`` in ``(y, x)`` order or ``(N, 3)``
in ``(z, y, x)`` order.
radii
Radii of the circles.

Expand All @@ -474,14 +504,29 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
ArrayLike
Ellipses.
"""
ndim = yx.shape[1]
assert ndim == 2
r = np.stack([radii] * ndim, axis=1)
lower_left = yx - r
upper_right = yx + r
ndim = coords.shape[1]
if ndim not in (2, 3):
raise ValueError(f"Expected centroids with 2 or 3 columns (yx or zyx), got shape {coords.shape}.")

if ndim == 3:
z = coords[:, :1]
yx_2d = coords[:, 1:]
else:
yx_2d = coords

r = np.stack([radii, radii], axis=1)
lower_left = yx_2d - r
upper_right = yx_2d + r
r[:, 0] = -r[:, 0]
lower_right = yx - r
upper_left = yx + r
lower_right = yx_2d - r
upper_left = yx_2d + r

if ndim == 3:
lower_left = np.column_stack([z, lower_left])
lower_right = np.column_stack([z, lower_right])
upper_right = np.column_stack([z, upper_right])
upper_left = np.column_stack([z, upper_left])

ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1)
assert isinstance(ellipses, np.ndarray)
return ellipses
Expand Down
58 changes: 57 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _safe_get_max_texture_sizes(): # type: ignore[no-untyped-def]
from spatialdata._types import ArrayLike
from spatialdata.datasets import blobs
from spatialdata.models import PointsModel, ShapesModel, TableModel
from spatialdata.transformations import Identity, set_transformation
from spatialdata.transformations import Affine, Identity, set_transformation

from napari_spatialdata.utils._test_utils import export_figure, save_image

Expand Down Expand Up @@ -415,3 +415,59 @@ def sdata_2_5d_shapes() -> SpatialData:
shapes["shapes_2.5d"] = shape_element

return SpatialData(shapes=shapes)


@pytest.fixture
def sdata_2_5d_circles() -> SpatialData:
"""Create a SpatialData object with 2.5D circles (circles at different z levels)."""
n_circles = 10
rng = np.random.default_rng(SEED)
gdf = gpd.GeoDataFrame(
{
"geometry": gpd.points_from_xy(
rng.uniform(0, 100, n_circles),
rng.uniform(0, 100, n_circles),
),
"radius": rng.uniform(5, 15, n_circles),
"z": rng.uniform(0, 50, n_circles),
}
)
circles = ShapesModel.parse(gdf)
set_transformation(circles, {"global": Identity()}, set_all=True)

return SpatialData(shapes={"circles_2.5d": circles})


@pytest.fixture
def sdata_3d_points_two_cs() -> SpatialData:
"""Create a SpatialData with 3D points registered to two coordinate systems.

The element lives in ``global`` (identity) and in ``scaled`` (2x scale
with a 10-unit z-translation). This is useful for testing that
``_affine_transform_layers`` produces a correctly-sized affine matrix
when switching between coordinate systems.
"""
n_points = 5
rng = np.random.default_rng(SEED)
df = pd.DataFrame(
{
"x": rng.uniform(0, 100, n_points),
"y": rng.uniform(0, 100, n_points),
"z": rng.uniform(0, 50, n_points),
}
)
dask_df = from_pandas(df, npartitions=1)
points = PointsModel.parse(dask_df)

affine_matrix = np.array(
[
[2.0, 0.0, 0.0, 0.0],
[0.0, 2.0, 0.0, 0.0],
[0.0, 0.0, 2.0, 10.0],
[0.0, 0.0, 0.0, 1.0],
]
)
scaled_affine = Affine(affine_matrix, input_axes=("x", "y", "z"), output_axes=("x", "y", "z"))
set_transformation(points, {"global": Identity(), "scaled": scaled_affine}, set_all=True)

return SpatialData(points={"points_3d": points})
Loading
Loading