diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index cde583e1..be8c5cf6 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -1,24 +1,33 @@ from __future__ import annotations from functools import singledispatch +from typing import Literal import numpy as np import pandas as pd import xarray as xr +from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from shapely import MultiPolygon, Point, Polygon from xarray import DataArray, DataTree from spatialdata._core.operations.transform import transform -from spatialdata.models import get_axes_names +from spatialdata._core.spatialdata import SpatialData +from spatialdata._utils import _affine_matrix_multiplication +from spatialdata.models import get_axes_names, get_table_keys from spatialdata.models._utils import SpatialElement -from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model +from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model from spatialdata.transformations.operations import get_transformation -from spatialdata.transformations.transformations import BaseTransformation +from spatialdata.transformations.transformations import BaseTransformation, Identity BoundingBoxDescription = dict[str, tuple[float, float]] +PersistAs = Literal["Points", "adata"] +# squidpy-style storage keys for persist_as="adata". +_SPATIAL_KEY = "spatial" +_AREA_KEY = "area" + def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: d = get_transformation(e, get_all=True) @@ -29,37 +38,92 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No ) +def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, allow_adata: bool) -> None: + if persist_as not in ("Points", "adata"): + raise ValueError(f"`persist_as` must be 'Points' or 'adata', got {persist_as!r}.") + if persist_as == "adata" and not allow_adata: + raise ValueError( + "persist_as='adata' writes centroids into the element's annotating table, which needs the " + "`SpatialData` object: call `get_centroids(sdata, element_name, ..., persist_as='adata')`. " + "To get the centroids as a standalone element instead, use persist_as='Points'." + ) + # ``coordinate_system=None`` means "intrinsic coordinates, do not transform". An intrinsic Points + # element is ill-defined (Points always carry a coordinate system), so intrinsic coords are only + # meaningful when writing into a table (persist_as='adata'). + if coordinate_system is None and persist_as != "adata": + raise ValueError("`coordinate_system=None` (intrinsic coordinates) is only supported with persist_as='adata'.") + + +def _transform_centroid_coords( + xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None +) -> np.ndarray: + """Apply the element's affine to centroid coords in-memory; ``None``/``Identity`` pass through. + + ``axes`` is the column order of ``xy`` (e.g. ``["x", "y"]``). + """ + if coordinate_system is None: + return xy + t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) + if isinstance(t, Identity): + return xy + matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes)) + return _affine_matrix_multiplication(matrix, xy) + + @singledispatch def get_centroids( - e: SpatialElement, - coordinate_system: str = "global", + e: SpatialElement | SpatialData, + coordinate_system: str | None = "global", return_background: bool = False, -) -> DaskDataFrame: + return_area: bool = False, + persist_as: PersistAs = "Points", +) -> DaskDataFrame | AnnData | None: """ - Get the centroids of the geometries contained in a SpatialElement, as a new Points element. + Get the centroids of the geometries contained in a SpatialElement. Parameters ---------- e - The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. + The SpatialElement (points, shapes — circles, polygons and multipolygons — or labels), or a + :class:`~spatialdata.SpatialData` object. When a ``SpatialData`` is passed, the second + positional argument is the name of the element to measure (see the ``SpatialData`` overload). coordinate_system - The coordinate system in which the centroids are computed. + The coordinate system in which the centroids are computed. ``None`` returns the intrinsic + coordinates without applying any transformation (only supported with ``persist_as="adata"``). return_background - If True, the centroid of the background label (0) is included in the output. + If True, the centroid of the background label (0) is included in the output (labels only). + return_area + If True, also return the per-instance area: the pixel/voxel count for labels and the geometric + area for shapes (``pi * r**2`` for circles). Not supported for points (raises). With + ``persist_as="Points"`` the area is added as a feature column of the returned Points element. + persist_as + ``"Points"`` (default) returns the centroids as a new Points element, transformed into + ``coordinate_system``. ``"adata"`` writes the centroids (and area) into the element's + annotating table and is only available through the :class:`~spatialdata.SpatialData` overload, + which can resolve that table. + + Returns + ------- + A Points element (``persist_as="Points"``). With ``persist_as="adata"`` (``SpatialData`` overload), + ``None`` when written in place, or the new ``AnnData`` table when ``inplace=False``. Notes ----- For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute - each :class:`~shapely.Multipolygon`. + each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution + ``scale0`` level. """ raise ValueError(f"The object type {type(e)} is not supported.") -def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: +def _get_centroids_for_labels(xdata: xr.DataArray, return_area: bool = False) -> pd.DataFrame: """ Compute centroids for all labels in a DataArray in a single O(n_voxels) pass. - Works for any number of spatial dimensions (2D and 3D labels). + Works for any number of spatial dimensions (2D and 3D labels). When ``return_area`` is True, an + ``area`` column (the per-label pixel/voxel count) is added; it is already computed for the + centroids, so this is free. """ arr = xdata.data.compute() axes = list(xdata.dims) @@ -77,66 +141,210 @@ def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float)) data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this) - return pd.DataFrame(data, index=label_ids) + df = pd.DataFrame(data, index=label_ids) + if return_area: + df["area"] = counts.astype(float) + return df -@get_centroids.register(DataArray) -@get_centroids.register(DataTree) -def _( - e: DataArray | DataTree, - coordinate_system: str = "global", - return_background: bool = False, -) -> DaskDataFrame: - """Get the centroids of a Labels element (2D or 3D).""" - model = get_model(e) - if model not in [Labels2DModel, Labels3DModel]: - raise ValueError("Expected a `Labels` element. Found an `Image` instead.") - _validate_coordinate_system(e, coordinate_system) - - if isinstance(e, DataTree): - assert len(e["scale0"]) == 1 - e = next(iter(e["scale0"].values())) - - df = _get_centroids_for_labels(e) - if not return_background and 0 in df.index: - df = df.drop(index=0) # drop the background label - t = get_transformation(e, coordinate_system) - centroids = PointsModel.parse(df, transformations={coordinate_system: t}) - return transform(centroids, to_coordinate_system=coordinate_system) - - -@get_centroids.register(GeoDataFrame) -def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: - """Get the centroids of a Shapes element (circles or polygons/multipolygons).""" - _validate_coordinate_system(e, coordinate_system) - t = get_transformation(e, coordinate_system) - assert isinstance(t, BaseTransformation) - # separate points from (multi-)polygons +def _get_centroids_for_shapes(e: GeoDataFrame, return_area: bool) -> tuple[pd.DataFrame, np.ndarray | None]: + """Intrinsic per-shape centroids (``x, y`` columns indexed by the element's index) and optional area.""" first_geometry = e["geometry"].iloc[0] if isinstance(first_geometry, Point): xy = e.geometry.get_coordinates().values + # shapely .area is 0 for circles (Point geometry); the radius column carries the size. + area = np.pi * np.asarray(e["radius"], dtype=float) ** 2 if return_area else None else: assert isinstance(first_geometry, Polygon | MultiPolygon), ( f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or" f" Polygons/MultiPolygons. Found {type(first_geometry)} instead." ) xy = e.centroid.get_coordinates().values + area = e.geometry.area.to_numpy() if return_area else None xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy()) - points = PointsModel.parse(xy_df, transformations={coordinate_system: t}) + return xy_df, area + + +def _intrinsic_centroid_frame( + element: SpatialElement, return_background: bool, return_area: bool +) -> tuple[pd.DataFrame, np.ndarray | None, SpatialElement]: + """Per-instance intrinsic centroids (coordinate columns, indexed by instance id), optional area. + + Also returns the element the centroids live on (for labels, the ``scale0`` level of a multiscale + raster), which carries the transformation to apply downstream. + """ + model = get_model(element) + if model in (Labels2DModel, Labels3DModel): + raster = next(iter(element["scale0"].values())) if isinstance(element, DataTree) else element + df = _get_centroids_for_labels(raster, return_area=return_area) + if not return_background and 0 in df.index: + df = df.drop(index=0) # drop the background label (its area, if any, goes with it) + area = df.pop("area").to_numpy() if return_area else None + return df, area, raster + if model is ShapesModel: + xy_df, area = _get_centroids_for_shapes(element, return_area) + return xy_df, area, element + if model is PointsModel: + if return_area: + raise ValueError("`return_area` is not supported for points elements (points have no area).") + axes = get_axes_names(element) + assert axes in [("x", "y"), ("x", "y", "z")] + return element[list(axes)].compute(), None, element + raise ValueError(f"Centroids are not supported for {model.__name__}; expected a Labels, Shapes or Points element.") + + +def _points_from_centroids( + df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str +) -> DaskDataFrame: + """Build a Points element from intrinsic centroids, transformed into ``coordinate_system``.""" + out = df.assign(area=np.asarray(area, dtype=float)) if area is not None else df + t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) + points = PointsModel.parse(out, transformations={coordinate_system: t}) return transform(points, to_coordinate_system=coordinate_system) +@get_centroids.register(DataArray) +@get_centroids.register(DataTree) +@get_centroids.register(GeoDataFrame) @get_centroids.register(DaskDataFrame) -def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: - """Get the centroids of a Points element.""" +def _( + e: SpatialElement, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", +) -> DaskDataFrame: + """Get the centroids of a Labels, Shapes or Points element.""" + _validate_persist_args(persist_as, coordinate_system, allow_adata=False) + assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) _validate_coordinate_system(e, coordinate_system) - axes = get_axes_names(e) - assert axes in [("x", "y"), ("x", "y", "z")] - coords = e[list(axes)].compute() - t = get_transformation(e, coordinate_system) - assert isinstance(t, BaseTransformation) - centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) - return transform(centroids, to_coordinate_system=coordinate_system) + df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) + return _points_from_centroids(df, area, raster, coordinate_system) + + +def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str: + """Resolve the single table that annotates ``element_name`` (where centroids are written).""" + from spatialdata._core.query.relational_query import get_element_annotators + + if table_name is not None: + if table_name not in sdata.tables: + raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.") + return table_name + annotators = sorted(get_element_annotators(sdata, element_name)) + if not annotators: + raise ValueError( + f"Element {element_name!r} has no annotating table to write centroids into. Use " + f"persist_as='Points' to get the centroids as a Points element instead, or annotate the " + f"element with a table first." + ) + if len(annotators) > 1: + raise ValueError( + f"Element {element_name!r} is annotated by multiple tables ({', '.join(annotators)}); " + f"pass `table_name=` to choose one." + ) + return annotators[0] + + +def _write_centroids_into_table( + table: AnnData, + element_name: str, + centroids: pd.DataFrame, + area: np.ndarray | None, +) -> None: + """Write centroids into ``obsm["spatial"]`` and area into ``obs["area"]`` at the element's rows. + + Only the table rows annotating ``element_name`` are touched (a table may annotate several + elements); instances annotated but absent from the element are written as NaN. + """ + if not centroids.index.is_unique: + raise ValueError(f"Cannot persist centroids for {element_name!r}: its instance index has duplicate values.") + _, region_key, instance_key = get_table_keys(table) + mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy() + if not mask.any(): + raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).") + + # Map each annotated instance to its centroid row (-1 where absent -> NaN). A *total* miss means the + # instance ids never align with the element index (mismatched dtype, or no shared instances). + keys = table.obs[instance_key].to_numpy()[mask] + idx = centroids.index.get_indexer(keys) + if (idx == -1).all(): + raise ValueError( + f"No instance id annotating {element_name!r} is present in the element; check the table's " + f"`{instance_key}` values and dtype." + ) + hit = idx != -1 + + def _scatter(values: np.ndarray) -> np.ndarray: + """Gather ``values`` (ordered like ``centroids``) onto the masked rows, NaN where absent.""" + out = np.full((len(idx), *values.shape[1:]), np.nan) + out[hit] = values[idx[hit]] + return out + + ndim = centroids.shape[1] + spatial = np.full((table.n_obs, ndim), np.nan) + existing = table.obsm.get(_SPATIAL_KEY) + if existing is not None: + existing = np.asarray(existing) + if existing.shape == (table.n_obs, ndim): + spatial = existing.astype(float, copy=True) # preserve other regions' coordinates + elif existing.shape[0] == table.n_obs: + raise ValueError( + f"Existing obsm['{_SPATIAL_KEY}'] {existing.shape} is incompatible with {ndim}-D centroids for " + f"{element_name!r}; refusing to overwrite other regions. Persist with persist_as='Points' instead." + ) + spatial[mask] = _scatter(centroids.to_numpy(dtype=float)) + table.obsm[_SPATIAL_KEY] = spatial + + if area is not None: + col = np.full(table.n_obs, np.nan) + if _AREA_KEY in table.obs: + col = table.obs[_AREA_KEY].to_numpy(dtype=float).copy() + col[mask] = _scatter(np.asarray(area, dtype=float)) + table.obs[_AREA_KEY] = col + + +@get_centroids.register(SpatialData) +def _get_centroids_sdata( + e: SpatialData, + element_name: str, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", + table_name: str | None = None, + inplace: bool = True, +) -> DaskDataFrame | AnnData | None: + """Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table. + + With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) of the + resolved annotating table (``table_name=`` disambiguates). ``inplace=True`` (default) mutates that table and + returns ``None``; ``inplace=False`` writes into a copy of *only that table* and returns the new ``AnnData``, + leaving ``e`` untouched. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element. + """ + _validate_persist_args(persist_as, coordinate_system, allow_adata=True) + element = e[element_name] + + if persist_as == "Points": + return get_centroids( + element, + coordinate_system=coordinate_system, + return_background=return_background, + return_area=return_area, + ) + + # persist_as == "adata": resolve the annotating table and write the centroids into it. + if coordinate_system is not None: + _validate_coordinate_system(element, coordinate_system) + table_name = _resolve_annotating_table(e, element_name, table_name) + df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area) + coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order) + coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system) + centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index) + + table = e.tables[table_name] if inplace else e.tables[table_name].copy() + _write_centroids_into_table(table, element_name, centroids, area) + return None if inplace else table ## diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py index 41458458..d12a3f1a 100644 --- a/src/spatialdata/_core/operations/vectorize.py +++ b/src/spatialdata/_core/operations/vectorize.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import dask import numpy as np @@ -137,7 +137,8 @@ def _get_centroids(element: SpatialElement) -> pd.DataFrame: if INTRINSIC_COORDINATE_SYSTEM in d: raise RuntimeError(f"The name {INTRINSIC_COORDINATE_SYSTEM} is reserved.") d[INTRINSIC_COORDINATE_SYSTEM] = Identity() - centroids = get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM).compute() + # get_centroids returns DaskDataFrame for elements (only the SpatialData overload returns SpatialData). + centroids = cast(DaskDataFrame, get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM)).compute() del d[INTRINSIC_COORDINATE_SYSTEM] return centroids diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index fb55ab08..97bb2fa2 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -543,6 +543,36 @@ def aggregate( **kwargs, ) + def get_centroids( + self, + element_name: str, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: Literal["Points", "adata"] = "Points", + table_name: str | None = None, + inplace: bool = True, + ) -> DaskDataFrame | AnnData | None: + """Get the centroids of ``element_name``, or persist them into its annotating table. + + Convenience method for :func:`spatialdata.get_centroids` called on ``self``; see that function for the + complete docstring. With ``persist_as="adata"`` the centroids are written into ``obsm["spatial"]`` (and area + into ``obs["area"]``) of the resolved annotating table; ``inplace=True`` mutates it and returns ``None``, + ``inplace=False`` returns a modified copy of that table. + """ + from spatialdata._core.centroids import _get_centroids_sdata + + return _get_centroids_sdata( + self, + element_name, + coordinate_system=coordinate_system, + return_background=return_background, + return_area=return_area, + persist_as=persist_as, + table_name=table_name, + inplace=inplace, + ) + def is_backed(self) -> bool: """Check if the data is backed by a Zarr storage or if it is in-memory.""" return self.path is not None diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 03879abc..07abacf6 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -5,12 +5,13 @@ from functools import partial from itertools import chain from types import MappingProxyType -from typing import Any +from typing import Any, cast import anndata as ad import numpy as np import pandas as pd from anndata import AnnData +from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from pandas import CategoricalDtype from scipy.sparse import issparse @@ -494,7 +495,7 @@ def _get_tile_coords( # extent, aka the tile size extent = (circles.radius * 2).values.reshape(-1, 1) - centroids_points = get_centroids(circles, coordinate_system=cs) + centroids_points = cast(DaskDataFrame, get_centroids(circles, coordinate_system=cs)) axes = get_axes_names(centroids_points) centroids_numpy = centroids_points.compute().values diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index f8c8be1d..6485ac4e 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -16,6 +16,15 @@ RNG = default_rng(42) +def _assert_obsm_matches_points(table: AnnData, pts: pd.DataFrame) -> None: + # written obsm["spatial"] must match the element-level Points centroids on shared (non-background) ids. + inst = table.obs["instance_id"].to_numpy() + written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) + common = pts.index.intersection(written.index[inst != 0]) + assert len(common) > 0 + assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + + def _get_affine() -> Affine: theta: float = math.pi / 18 k = 10.0 @@ -172,9 +181,138 @@ def test_get_centroids_labels( assert np.allclose(centroids.compute().values, centroids_transformed) +def test_get_centroids_labels_area(labels): + # area for labels is the per-label pixel count; it rides along as a feature column of the Points. + element = labels["labels2d"] + centroids = get_centroids(element, return_area=True) + assert "area" in centroids.columns + ids, counts = np.unique(np.asarray(element.data), return_counts=True) + expected = dict(zip(ids, counts, strict=True)) + got = centroids[["area"]].compute()["area"] + assert not (got.index == 0).any() # background dropped + for label_id, area in got.items(): + assert area == expected[label_id] + + +def test_get_centroids_shapes_area_circles(shapes): + element = shapes["circles"] + centroids = get_centroids(element, return_area=True) + expected = np.pi * np.asarray(element["radius"], dtype=float) ** 2 + assert np.allclose(centroids["area"].compute().to_numpy(), expected) + + +@pytest.mark.parametrize("shapes_name", ["poly", "multipoly"]) +def test_get_centroids_shapes_area_polygons(shapes, shapes_name: str): + element = shapes[shapes_name] + centroids = get_centroids(element, return_area=True) + assert np.allclose(centroids["area"].compute().to_numpy(), element.geometry.area.to_numpy()) + + +def test_get_centroids_points_area_raises(points): + with pytest.raises(ValueError, match="not supported for points"): + get_centroids(points["points_0"], return_area=True) + + +def test_get_centroids_element_persist_adata_raises(labels): + # an element on its own has no annotating table; persist_as='adata' needs the SpatialData. + with pytest.raises(ValueError, match="persist_as='adata'"): + get_centroids(labels["labels2d"], persist_as="adata") + + +def test_get_centroids_sdata_persist_into_table(full_sdata): + # `table` annotates `labels2d` (instance_id 0..99 == label values); background (0) -> NaN. + table = full_sdata["table"] + assert "spatial" not in table.obsm + out = get_centroids(full_sdata, "labels2d", coordinate_system="global", return_area=True, persist_as="adata") + assert out is None # inplace=True (default) mutates the table and returns nothing + + table = full_sdata["table"] + assert table.obsm["spatial"].shape == (table.n_obs, 2) + assert "area" in table.obs + + inst = table.obs["instance_id"].to_numpy() + finite = np.isfinite(table.obsm["spatial"]).all(axis=1) + assert finite[inst != 0].all() # every non-background label got a centroid + assert not finite[inst == 0].any() # background row stays NaN + + # coordinates must match the element-level Points (global transform is the identity here) + pts = get_centroids(full_sdata["labels2d"], coordinate_system="global").compute() + _assert_obsm_matches_points(table, pts) + + # area must equal the pixel counts of the corresponding labels + ids, counts = np.unique(np.asarray(full_sdata["labels2d"].data), return_counts=True) + count_of = dict(zip(ids, counts, strict=True)) + area = table.obs["area"].to_numpy() + for row, label_id in enumerate(inst): + if label_id != 0: + assert area[row] == count_of[label_id] + + +def test_get_centroids_sdata_persist_fastpath_matches_transform(full_sdata): + # the in-memory affine fast path (adata) must equal the dask transform() path (element Points). + set_transformation(full_sdata["labels2d"], affine, "aligned") + get_centroids(full_sdata, "labels2d", coordinate_system="aligned", persist_as="adata") + + pts = get_centroids(full_sdata["labels2d"], coordinate_system="aligned").compute() + _assert_obsm_matches_points(full_sdata["table"], pts) + + +def test_get_centroids_sdata_persist_intrinsic_matches_identity(full_sdata): + # coordinate_system=None (intrinsic) equals a coordinate system whose transform is the identity. + get_centroids(full_sdata, "labels2d", coordinate_system="global", persist_as="adata") + global_spatial = full_sdata["table"].obsm["spatial"].copy() + get_centroids(full_sdata, "labels2d", coordinate_system=None, persist_as="adata") + intrinsic_spatial = full_sdata["table"].obsm["spatial"] + finite = np.isfinite(global_spatial).all(axis=1) + assert np.allclose(global_spatial[finite], intrinsic_spatial[finite]) + + +def test_get_centroids_sdata_persist_inplace_false_returns_copy(full_sdata): + # inplace=False copies only the target table, writes into the copy, and leaves the sdata untouched. + out = get_centroids(full_sdata, "labels2d", return_area=True, persist_as="adata", inplace=False) + assert isinstance(out, AnnData) + assert out is not full_sdata["table"] + assert "spatial" in out.obsm and "area" in out.obs + assert "spatial" not in full_sdata["table"].obsm # original table not modified + + +def test_spatialdata_get_centroids_method(full_sdata): + # the method mirrors the module-level function for both persistence modes. + pts = full_sdata.get_centroids("labels2d", coordinate_system="global") + expected = get_centroids(full_sdata["labels2d"], coordinate_system="global") + assert np.allclose(pts.compute().to_numpy(), expected.compute().to_numpy()) + + assert full_sdata.get_centroids("labels2d", persist_as="adata") is None + assert "spatial" in full_sdata["table"].obsm + + +def test_get_centroids_sdata_persist_instance_key_mismatch_raises(full_sdata): + # instance ids stored with a dtype that doesn't match the element's integer labels must fail + # loudly instead of silently writing NaN coordinates into obsm["spatial"]. + table = full_sdata["table"] + table.obs["instance_id"] = table.obs["instance_id"].astype(str) + with pytest.raises(ValueError, match="No instance id annotating"): + get_centroids(full_sdata, "labels2d", persist_as="adata") + + +def test_get_centroids_sdata_persist_refuses_dim_mismatch(full_sdata): + # an existing obsm["spatial"] of a different width must not be silently overwritten (that would + # wipe the coordinates of other regions sharing the table). + table = full_sdata["table"] + table.obsm["spatial"] = np.zeros((table.n_obs, 3)) + with pytest.raises(ValueError, match="refusing to overwrite"): + get_centroids(full_sdata, "labels2d", persist_as="adata") + + +def test_get_centroids_sdata_no_table_raises(full_sdata): + # points_0 is not annotated by any table -> the error points the user to persist_as='Points'. + with pytest.raises(ValueError, match="persist_as='Points'"): + get_centroids(full_sdata, "points_0", persist_as="adata") + + def test_get_centroids_invalid_element(images): # cannot compute centroids for images - with pytest.raises(ValueError, match="Expected a `Labels` element. Found an `Image` instead."): + with pytest.raises(ValueError, match="Centroids are not supported for Image2DModel"): get_centroids(images["image2d"]) # cannot compute centroids for tables