From 8a2b941d5562f25bc649535ca2d334171906d9f3 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 19 Jun 2026 20:24:03 +0200 Subject: [PATCH 1/7] feat(centroids): add return_area + SpatialData persist to get_centroids Surface the per-instance area that the labels bincount already computes (and geometric area for shapes), and add a SpatialData dispatch that writes centroids/area into the element's annotating table, squidpy-style. - `return_area` on all element overloads: labels return the pixel/voxel count (free, already computed); shapes return `pi*r**2` for circles and `geometry.area` for polygons; points raise. Carried as a feature column on the returned Points element. - Harmonise the three element overload signatures (they previously disagreed on `return_background`). - New `get_centroids(sdata, element_name, ..., persist_as="adata")` dispatch: resolves the element's annotating table and writes `obsm["spatial"]` (x, y[, z]) and `obs["area"]` at the element's rows, in place. Raises (pointing to `persist_as="Points"`) when no table annotates the element. No standalone AnnData is ever created. - Element-level `persist_as="adata"` raises, directing to the SpatialData form (which can resolve the table). - Fast path: the table-writing path applies the coordinate transform to the tiny per-instance centroid array in-memory and short-circuits when it is an identity, avoiding the dask `transform()` round-trip. `coordinate_system=None` returns intrinsic coordinates. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/spatialdata/_core/centroids.py | 319 ++++++++++++++++++++++++----- tests/core/test_centroids.py | 100 +++++++++ 2 files changed, 372 insertions(+), 47 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index cde583e15..b8e6a5a02 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -1,24 +1,29 @@ from __future__ import annotations from functools import singledispatch +from typing import Literal import numpy as np import pandas as pd import xarray as xr +from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from shapely import MultiPolygon, Point, Polygon from xarray import DataArray, DataTree from spatialdata._core.operations.transform import transform -from spatialdata.models import get_axes_names +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import get_axes_names, get_table_keys from spatialdata.models._utils import SpatialElement -from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model +from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model from spatialdata.transformations.operations import get_transformation -from spatialdata.transformations.transformations import BaseTransformation +from spatialdata.transformations.transformations import BaseTransformation, Identity BoundingBoxDescription = dict[str, tuple[float, float]] +PersistAs = Literal["Points", "adata"] + def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: d = get_transformation(e, get_all=True) @@ -29,37 +34,97 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No ) +def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, allow_adata: bool) -> None: + if persist_as not in ("Points", "adata"): + raise ValueError(f"`persist_as` must be 'Points' or 'adata', got {persist_as!r}.") + if persist_as == "adata" and not allow_adata: + raise ValueError( + "persist_as='adata' writes centroids into the element's annotating table, which needs the " + "`SpatialData` object: call `get_centroids(sdata, element_name, ..., persist_as='adata')`. " + "To get the centroids as a standalone element instead, use persist_as='Points'." + ) + # ``coordinate_system=None`` means "intrinsic coordinates, do not transform". An intrinsic Points + # element is ill-defined (Points always carry a coordinate system), so intrinsic coords are only + # meaningful when writing into a table (persist_as='adata'). + if coordinate_system is None and persist_as != "adata": + raise ValueError("`coordinate_system=None` (intrinsic coordinates) is only supported with persist_as='adata'.") + + +def _transform_centroid_coords( + xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None +) -> np.ndarray: + """Apply the element's transformation to centroid coordinates in-memory. + + Centroids are tiny (one row per instance), so the affine matrix is applied directly here instead + of routing through the dask :func:`~spatialdata.transform` machinery (which builds a dask graph, + computes twice and re-validates regardless of the transformation). ``coordinate_system=None`` and + an :class:`~spatialdata.transformations.Identity` transformation short-circuit to the intrinsic + coordinates unchanged. ``axes`` is the order of the columns of ``xy`` (e.g. ``["x", "y"]``). + """ + if coordinate_system is None: + return xy + t = get_transformation(e, coordinate_system) + if isinstance(t, Identity): + return xy + matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes)) + n = len(axes) + return xy @ matrix[:n, :n].T + matrix[:n, n] + + @singledispatch def get_centroids( - e: SpatialElement, - coordinate_system: str = "global", + e: SpatialElement | SpatialData, + coordinate_system: str | None = "global", return_background: bool = False, -) -> DaskDataFrame: + return_area: bool = False, + persist_as: PersistAs = "Points", +) -> DaskDataFrame | SpatialData: """ - Get the centroids of the geometries contained in a SpatialElement, as a new Points element. + Get the centroids of the geometries contained in a SpatialElement. Parameters ---------- e - The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. + The SpatialElement (points, shapes — circles, polygons and multipolygons — or labels), or a + :class:`~spatialdata.SpatialData` object. When a ``SpatialData`` is passed, the second + positional argument is the name of the element to measure (see the ``SpatialData`` overload). coordinate_system - The coordinate system in which the centroids are computed. + The coordinate system in which the centroids are computed. ``None`` returns the intrinsic + coordinates without applying any transformation (only supported with ``persist_as="adata"``). return_background - If True, the centroid of the background label (0) is included in the output. + If True, the centroid of the background label (0) is included in the output (labels only). + return_area + If True, also return the per-instance area: the pixel/voxel count for labels and the geometric + area for shapes (``pi * r**2`` for circles). Not supported for points (raises). With + ``persist_as="Points"`` the area is added as a feature column of the returned Points element. + persist_as + ``"Points"`` (default) returns the centroids as a new Points element, transformed into + ``coordinate_system``. ``"adata"`` writes the centroids (and area) into the element's + annotating table and is only available through the :class:`~spatialdata.SpatialData` overload, + which can resolve that table. + + Returns + ------- + A Points element (``persist_as="Points"``), or the mutated ``SpatialData`` when writing centroids + into an annotating table (``SpatialData`` overload with ``persist_as="adata"``). Notes ----- For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute - each :class:`~shapely.Multipolygon`. + each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution + ``scale0`` level. Because centroids are tiny (one row per instance), the table-writing path applies the coordinate + transformation in-memory and short-circuits when it is an identity, avoiding the cost of the dask transform. """ raise ValueError(f"The object type {type(e)} is not supported.") -def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: +def _get_centroids_for_labels(xdata: xr.DataArray, return_area: bool = False) -> pd.DataFrame: """ Compute centroids for all labels in a DataArray in a single O(n_voxels) pass. - Works for any number of spatial dimensions (2D and 3D labels). + Works for any number of spatial dimensions (2D and 3D labels). When ``return_area`` is True, an + ``area`` column (the per-label pixel/voxel count) is added; it is already computed for the + centroids, so this is free. """ arr = xdata.data.compute() axes = list(xdata.dims) @@ -77,66 +142,226 @@ def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float)) data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this) - return pd.DataFrame(data, index=label_ids) + df = pd.DataFrame(data, index=label_ids) + if return_area: + df["area"] = counts.astype(float) + return df + + +def _get_centroids_for_shapes(e: GeoDataFrame, return_area: bool) -> tuple[pd.DataFrame, np.ndarray | None]: + """Intrinsic per-shape centroids (``x, y`` columns indexed by the element's index) and optional area.""" + first_geometry = e["geometry"].iloc[0] + if isinstance(first_geometry, Point): + xy = e.geometry.get_coordinates().values + # shapely .area is 0 for circles (Point geometry); the radius column carries the size. + area = np.pi * np.asarray(e["radius"], dtype=float) ** 2 if return_area else None + else: + assert isinstance(first_geometry, Polygon | MultiPolygon), ( + f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or" + f" Polygons/MultiPolygons. Found {type(first_geometry)} instead." + ) + xy = e.centroid.get_coordinates().values + area = e.geometry.area.to_numpy() if return_area else None + xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy()) + return xy_df, area + + +def _intrinsic_centroid_frame( + element: SpatialElement, return_background: bool, return_area: bool +) -> tuple[pd.DataFrame, np.ndarray | None, SpatialElement]: + """Per-instance intrinsic centroids (coordinate columns, indexed by instance id), optional area. + + Also returns the element the centroids live on (for labels, the ``scale0`` level of a multiscale + raster), which carries the transformation to apply downstream. + """ + model = get_model(element) + if model in (Labels2DModel, Labels3DModel): + raster = next(iter(element["scale0"].values())) if isinstance(element, DataTree) else element + df = _get_centroids_for_labels(raster, return_area=return_area) + if not return_background and 0 in df.index: + df = df.drop(index=0) # drop the background label (its area, if any, goes with it) + area = df.pop("area").to_numpy() if return_area else None + return df, area, raster + if model is ShapesModel: + xy_df, area = _get_centroids_for_shapes(element, return_area) + return xy_df, area, element + if model is PointsModel: + if return_area: + raise ValueError("`return_area` is not supported for points elements (points have no area).") + axes = get_axes_names(element) + assert axes in [("x", "y"), ("x", "y", "z")] + return element[list(axes)].compute(), None, element + raise ValueError(f"Centroids are not supported for the {model.__name__} element {element!r}.") + + +def _points_from_centroids( + df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str +) -> DaskDataFrame: + """Build a Points element from intrinsic centroids, transformed into ``coordinate_system``.""" + out = df.copy() + if area is not None: + out["area"] = np.asarray(area, dtype=float) + t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) + points = PointsModel.parse(out, transformations={coordinate_system: t}) + return transform(points, to_coordinate_system=coordinate_system) @get_centroids.register(DataArray) @get_centroids.register(DataTree) def _( e: DataArray | DataTree, - coordinate_system: str = "global", + coordinate_system: str | None = "global", return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", ) -> DaskDataFrame: """Get the centroids of a Labels element (2D or 3D).""" model = get_model(e) if model not in [Labels2DModel, Labels3DModel]: raise ValueError("Expected a `Labels` element. Found an `Image` instead.") + _validate_persist_args(persist_as, coordinate_system, allow_adata=False) _validate_coordinate_system(e, coordinate_system) - - if isinstance(e, DataTree): - assert len(e["scale0"]) == 1 - e = next(iter(e["scale0"].values())) - - df = _get_centroids_for_labels(e) - if not return_background and 0 in df.index: - df = df.drop(index=0) # drop the background label - t = get_transformation(e, coordinate_system) - centroids = PointsModel.parse(df, transformations={coordinate_system: t}) - return transform(centroids, to_coordinate_system=coordinate_system) + df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) + return _points_from_centroids(df, area, raster, coordinate_system) @get_centroids.register(GeoDataFrame) -def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: +def _( + e: GeoDataFrame, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", +) -> DaskDataFrame: """Get the centroids of a Shapes element (circles or polygons/multipolygons).""" + _validate_persist_args(persist_as, coordinate_system, allow_adata=False) _validate_coordinate_system(e, coordinate_system) - t = get_transformation(e, coordinate_system) - assert isinstance(t, BaseTransformation) - # separate points from (multi-)polygons - first_geometry = e["geometry"].iloc[0] - if isinstance(first_geometry, Point): - xy = e.geometry.get_coordinates().values - else: - assert isinstance(first_geometry, Polygon | MultiPolygon), ( - f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or" - f" Polygons/MultiPolygons. Found {type(first_geometry)} instead." - ) - xy = e.centroid.get_coordinates().values - xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy()) - points = PointsModel.parse(xy_df, transformations={coordinate_system: t}) - return transform(points, to_coordinate_system=coordinate_system) + xy_df, area = _get_centroids_for_shapes(e, return_area) + return _points_from_centroids(xy_df, area, e, coordinate_system) @get_centroids.register(DaskDataFrame) -def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: +def _( + e: DaskDataFrame, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", +) -> DaskDataFrame: """Get the centroids of a Points element.""" + if return_area: + raise ValueError("`return_area` is not supported for points elements (points have no area).") + _validate_persist_args(persist_as, coordinate_system, allow_adata=False) _validate_coordinate_system(e, coordinate_system) axes = get_axes_names(e) assert axes in [("x", "y"), ("x", "y", "z")] coords = e[list(axes)].compute() - t = get_transformation(e, coordinate_system) - assert isinstance(t, BaseTransformation) - centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) - return transform(centroids, to_coordinate_system=coordinate_system) + return _points_from_centroids(coords, None, e, coordinate_system) + + +def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str: + """Resolve the single table that annotates ``element_name`` (where centroids are written).""" + from spatialdata._core.query.relational_query import get_element_annotators + + if table_name is not None: + if table_name not in sdata.tables: + raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.") + return table_name + annotators = sorted(get_element_annotators(sdata, element_name)) + if not annotators: + raise ValueError( + f"Element {element_name!r} has no annotating table to write centroids into. Use " + f"persist_as='Points' to get the centroids as a Points element instead, or annotate the " + f"element with a table first." + ) + if len(annotators) > 1: + raise ValueError( + f"Element {element_name!r} is annotated by multiple tables ({', '.join(annotators)}); " + f"pass `table_name=` to choose one." + ) + return str(annotators[0]) + + +def _write_centroids_into_table( + table: AnnData, + element_name: str, + centroids: pd.DataFrame, + area: np.ndarray | None, +) -> None: + """Write centroids into ``obsm["spatial"]`` and area into ``obs["area"]`` at the element's rows. + + Only the table rows annotating ``element_name`` are touched (a table may annotate several + elements); instances annotated but absent from the element are written as NaN. + """ + _, region_key, instance_key = get_table_keys(table) + mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy() + if not mask.any(): + raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).") + + keys = table.obs[instance_key].to_numpy()[mask] + coord_cols = list(centroids.columns) + ndim = len(coord_cols) + coords = centroids.reindex(keys)[coord_cols].to_numpy(dtype=float) + + existing = np.asarray(table.obsm["spatial"]) if "spatial" in table.obsm else None + if existing is not None and existing.shape == (table.n_obs, ndim): + spatial = existing.astype(float, copy=True) + else: + spatial = np.full((table.n_obs, ndim), np.nan) + spatial[mask] = coords + table.obsm["spatial"] = spatial + + if area is not None: + area_for_keys = pd.Series(np.asarray(area, dtype=float), index=centroids.index).reindex(keys).to_numpy() + col = ( + table.obs["area"].to_numpy(dtype=float).copy() + if "area" in table.obs + else np.full(table.n_obs, np.nan) + ) + col[mask] = area_for_keys + table.obs["area"] = col + + +@get_centroids.register(SpatialData) +def _( + e: SpatialData, + element_name: str, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: PersistAs = "Points", + table_name: str | None = None, +) -> DaskDataFrame | SpatialData: + """Get (or persist) the centroids of a named element of a :class:`~spatialdata.SpatialData`. + + With ``persist_as="Points"`` (default) this returns the centroids of ``element_name`` as a Points + element, identical to calling :func:`get_centroids` on the element directly. With + ``persist_as="adata"`` the centroids are written, squidpy-style, into the element's annotating + table — ``obsm["spatial"]`` (columns ordered ``x, y[, z]``) and, if ``return_area``, + ``obs["area"]`` — and the mutated ``SpatialData`` is returned. The annotating table is resolved + automatically (pass ``table_name=`` to disambiguate); if the element has no annotating table this + raises, and you should use ``persist_as="Points"`` instead. + """ + _validate_persist_args(persist_as, coordinate_system, allow_adata=True) + element = e[element_name] + + if persist_as == "Points": + return get_centroids( + element, + coordinate_system=coordinate_system, + return_background=return_background, + return_area=return_area, + ) + + # persist_as == "adata": resolve the annotating table and write centroids into it (in place). + table_name = _resolve_annotating_table(e, element_name, table_name) + df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area) + coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order) + coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system) + centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index) + _write_centroids_into_table(e.tables[table_name], element_name, centroids, area) + return e ## diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index f8c8be1da..4b0c9352a 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -172,6 +172,106 @@ def test_get_centroids_labels( assert np.allclose(centroids.compute().values, centroids_transformed) +def test_get_centroids_labels_area(labels): + # area for labels is the per-label pixel count; it rides along as a feature column of the Points. + element = labels["labels2d"] + centroids = get_centroids(element, return_area=True) + assert "area" in centroids.columns + ids, counts = np.unique(np.asarray(element.data), return_counts=True) + expected = dict(zip(ids, counts, strict=True)) + got = centroids[["area"]].compute()["area"] + assert not (got.index == 0).any() # background dropped + for label_id, area in got.items(): + assert area == expected[label_id] + + +def test_get_centroids_shapes_area_circles(shapes): + element = shapes["circles"] + centroids = get_centroids(element, return_area=True) + expected = np.pi * np.asarray(element["radius"], dtype=float) ** 2 + assert np.allclose(centroids["area"].compute().to_numpy(), expected) + + +@pytest.mark.parametrize("shapes_name", ["poly", "multipoly"]) +def test_get_centroids_shapes_area_polygons(shapes, shapes_name: str): + element = shapes[shapes_name] + centroids = get_centroids(element, return_area=True) + assert np.allclose(centroids["area"].compute().to_numpy(), element.geometry.area.to_numpy()) + + +def test_get_centroids_points_area_raises(points): + with pytest.raises(ValueError, match="not supported for points"): + get_centroids(points["points_0"], return_area=True) + + +def test_get_centroids_element_persist_adata_raises(labels): + # an element on its own has no annotating table; persist_as='adata' needs the SpatialData. + with pytest.raises(ValueError, match="persist_as='adata'"): + get_centroids(labels["labels2d"], persist_as="adata") + + +def test_get_centroids_sdata_persist_into_table(full_sdata): + # `table` annotates `labels2d` (instance_id 0..99 == label values); background (0) -> NaN. + table = full_sdata["table"] + assert "spatial" not in table.obsm + out = get_centroids(full_sdata, "labels2d", coordinate_system="global", return_area=True, persist_as="adata") + assert out is full_sdata # written in place + + table = full_sdata["table"] + assert table.obsm["spatial"].shape == (table.n_obs, 2) + assert "area" in table.obs + + inst = table.obs["instance_id"].to_numpy() + finite = np.isfinite(table.obsm["spatial"]).all(axis=1) + assert finite[inst != 0].all() # every non-background label got a centroid + assert not finite[inst == 0].any() # background row stays NaN + + # coordinates must match the element-level Points (global transform is the identity here) + pts = get_centroids(full_sdata["labels2d"], coordinate_system="global").compute() + written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) + common = pts.index.intersection(written.index[inst != 0]) + assert len(common) > 0 + assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + + # area must equal the pixel counts of the corresponding labels + ids, counts = np.unique(np.asarray(full_sdata["labels2d"].data), return_counts=True) + count_of = dict(zip(ids, counts, strict=True)) + area = table.obs["area"].to_numpy() + for row, label_id in enumerate(inst): + if label_id != 0: + assert area[row] == count_of[label_id] + + +def test_get_centroids_sdata_persist_fastpath_matches_transform(full_sdata): + # the in-memory affine fast path (adata) must equal the dask transform() path (element Points). + set_transformation(full_sdata["labels2d"], affine, "aligned") + get_centroids(full_sdata, "labels2d", coordinate_system="aligned", persist_as="adata") + + table = full_sdata["table"] + inst = table.obs["instance_id"].to_numpy() + written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) + pts = get_centroids(full_sdata["labels2d"], coordinate_system="aligned").compute() + common = pts.index.intersection(written.index[inst != 0]) + assert len(common) > 0 + assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + + +def test_get_centroids_sdata_persist_intrinsic_matches_identity(full_sdata): + # coordinate_system=None (intrinsic) equals a coordinate system whose transform is the identity. + get_centroids(full_sdata, "labels2d", coordinate_system="global", persist_as="adata") + global_spatial = full_sdata["table"].obsm["spatial"].copy() + get_centroids(full_sdata, "labels2d", coordinate_system=None, persist_as="adata") + intrinsic_spatial = full_sdata["table"].obsm["spatial"] + finite = np.isfinite(global_spatial).all(axis=1) + assert np.allclose(global_spatial[finite], intrinsic_spatial[finite]) + + +def test_get_centroids_sdata_no_table_raises(full_sdata): + # points_0 is not annotated by any table -> the error points the user to persist_as='Points'. + with pytest.raises(ValueError, match="persist_as='Points'"): + get_centroids(full_sdata, "points_0", persist_as="adata") + + def test_get_centroids_invalid_element(images): # cannot compute centroids for images with pytest.raises(ValueError, match="Expected a `Labels` element. Found an `Image` instead."): From a78494163455b84253ff0271dbb52aa98c418c96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jun 2026 18:26:16 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/centroids.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index b8e6a5a02..19be1efbd 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -314,11 +314,7 @@ def _write_centroids_into_table( if area is not None: area_for_keys = pd.Series(np.asarray(area, dtype=float), index=centroids.index).reindex(keys).to_numpy() - col = ( - table.obs["area"].to_numpy(dtype=float).copy() - if "area" in table.obs - else np.full(table.n_obs, np.nan) - ) + col = table.obs["area"].to_numpy(dtype=float).copy() if "area" in table.obs else np.full(table.n_obs, np.nan) col[mask] = area_for_keys table.obs["area"] = col From 40d3ed704dc33e711d32e3bba76dfcfbeed86810 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 21 Jun 2026 14:40:14 +0200 Subject: [PATCH 3/7] fix(centroids): satisfy mypy for str|None coord system and widened return type - narrow transformation to BaseTransformation before to_affine_matrix - assert coordinate_system is not None after _validate_persist_args in the element overloads (None is rejected there for persist_as != 'adata') - cast element-dispatch get_centroids results to DaskDataFrame at the two callers (vectorize, datasets); only the SpatialData overload returns SpatialData --- src/spatialdata/_core/centroids.py | 4 ++++ src/spatialdata/_core/operations/vectorize.py | 5 +++-- src/spatialdata/dataloader/datasets.py | 5 +++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index 19be1efbd..f2dd45ab1 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -64,6 +64,7 @@ def _transform_centroid_coords( if coordinate_system is None: return xy t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) if isinstance(t, Identity): return xy matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes)) @@ -221,6 +222,7 @@ def _( if model not in [Labels2DModel, Labels3DModel]: raise ValueError("Expected a `Labels` element. Found an `Image` instead.") _validate_persist_args(persist_as, coordinate_system, allow_adata=False) + assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) _validate_coordinate_system(e, coordinate_system) df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) return _points_from_centroids(df, area, raster, coordinate_system) @@ -236,6 +238,7 @@ def _( ) -> DaskDataFrame: """Get the centroids of a Shapes element (circles or polygons/multipolygons).""" _validate_persist_args(persist_as, coordinate_system, allow_adata=False) + assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) _validate_coordinate_system(e, coordinate_system) xy_df, area = _get_centroids_for_shapes(e, return_area) return _points_from_centroids(xy_df, area, e, coordinate_system) @@ -253,6 +256,7 @@ def _( if return_area: raise ValueError("`return_area` is not supported for points elements (points have no area).") _validate_persist_args(persist_as, coordinate_system, allow_adata=False) + assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) _validate_coordinate_system(e, coordinate_system) axes = get_axes_names(e) assert axes in [("x", "y"), ("x", "y", "z")] diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py index 414584589..d12a3f1aa 100644 --- a/src/spatialdata/_core/operations/vectorize.py +++ b/src/spatialdata/_core/operations/vectorize.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import dask import numpy as np @@ -137,7 +137,8 @@ def _get_centroids(element: SpatialElement) -> pd.DataFrame: if INTRINSIC_COORDINATE_SYSTEM in d: raise RuntimeError(f"The name {INTRINSIC_COORDINATE_SYSTEM} is reserved.") d[INTRINSIC_COORDINATE_SYSTEM] = Identity() - centroids = get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM).compute() + # get_centroids returns DaskDataFrame for elements (only the SpatialData overload returns SpatialData). + centroids = cast(DaskDataFrame, get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM)).compute() del d[INTRINSIC_COORDINATE_SYSTEM] return centroids diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 03879abc8..07abacf60 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -5,12 +5,13 @@ from functools import partial from itertools import chain from types import MappingProxyType -from typing import Any +from typing import Any, cast import anndata as ad import numpy as np import pandas as pd from anndata import AnnData +from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from pandas import CategoricalDtype from scipy.sparse import issparse @@ -494,7 +495,7 @@ def _get_tile_coords( # extent, aka the tile size extent = (circles.radius * 2).values.reshape(-1, 1) - centroids_points = get_centroids(circles, coordinate_system=cs) + centroids_points = cast(DaskDataFrame, get_centroids(circles, coordinate_system=cs)) axes = get_axes_names(centroids_points) centroids_numpy = centroids_points.compute().values From 2843c71ae04fbb84ba5f852c1b002c4bb5a67456 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 21 Jun 2026 15:16:39 +0200 Subject: [PATCH 4/7] refactor(centroids): harden table-write, collapse element overloads, trim docs Correctness: - _write_centroids_into_table: align instance ids via a single get_indexer; raise on a total miss (instance_key dtype != element index) instead of silently writing an all-NaN obsm['spatial']; partial misses (background, filtered instances) still NaN-fill as documented. - refuse to overwrite obsm['spatial'] when an existing array has a different width (would wipe coordinates of other regions sharing the table). - validate coordinate_system on the persist_as='adata' path too. Cleanup (no behavior change): - collapse the three near-identical element overloads into one stacked singledispatch registration delegating to _intrinsic_centroid_frame. - trim verbose docstrings/comments. Tests: instance_key dtype mismatch raises; dimension mismatch refuses to overwrite; invalid-element message updated. --- src/spatialdata/_core/centroids.py | 99 +++++++++++------------------- tests/core/test_centroids.py | 20 +++++- 2 files changed, 55 insertions(+), 64 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index f2dd45ab1..c63885bf6 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -53,13 +53,9 @@ def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, al def _transform_centroid_coords( xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None ) -> np.ndarray: - """Apply the element's transformation to centroid coordinates in-memory. + """Apply the element's affine to centroid coords in-memory; ``None``/``Identity`` pass through. - Centroids are tiny (one row per instance), so the affine matrix is applied directly here instead - of routing through the dask :func:`~spatialdata.transform` machinery (which builds a dask graph, - computes twice and re-validates regardless of the transformation). ``coordinate_system=None`` and - an :class:`~spatialdata.transformations.Identity` transformation short-circuit to the intrinsic - coordinates unchanged. ``axes`` is the order of the columns of ``xy`` (e.g. ``["x", "y"]``). + ``axes`` is the column order of ``xy`` (e.g. ``["x", "y"]``). """ if coordinate_system is None: return xy @@ -113,8 +109,7 @@ def get_centroids( ----- For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution - ``scale0`` level. Because centroids are tiny (one row per instance), the table-writing path applies the coordinate - transformation in-memory and short-circuits when it is an identity, avoiding the cost of the dask transform. + ``scale0`` level. """ raise ValueError(f"The object type {type(e)} is not supported.") @@ -192,7 +187,7 @@ def _intrinsic_centroid_frame( axes = get_axes_names(element) assert axes in [("x", "y"), ("x", "y", "z")] return element[list(axes)].compute(), None, element - raise ValueError(f"Centroids are not supported for the {model.__name__} element {element!r}.") + raise ValueError(f"Centroids are not supported for {model.__name__}; expected a Labels, Shapes or Points element.") def _points_from_centroids( @@ -210,58 +205,21 @@ def _points_from_centroids( @get_centroids.register(DataArray) @get_centroids.register(DataTree) -def _( - e: DataArray | DataTree, - coordinate_system: str | None = "global", - return_background: bool = False, - return_area: bool = False, - persist_as: PersistAs = "Points", -) -> DaskDataFrame: - """Get the centroids of a Labels element (2D or 3D).""" - model = get_model(e) - if model not in [Labels2DModel, Labels3DModel]: - raise ValueError("Expected a `Labels` element. Found an `Image` instead.") - _validate_persist_args(persist_as, coordinate_system, allow_adata=False) - assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) - _validate_coordinate_system(e, coordinate_system) - df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) - return _points_from_centroids(df, area, raster, coordinate_system) - - @get_centroids.register(GeoDataFrame) -def _( - e: GeoDataFrame, - coordinate_system: str | None = "global", - return_background: bool = False, - return_area: bool = False, - persist_as: PersistAs = "Points", -) -> DaskDataFrame: - """Get the centroids of a Shapes element (circles or polygons/multipolygons).""" - _validate_persist_args(persist_as, coordinate_system, allow_adata=False) - assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) - _validate_coordinate_system(e, coordinate_system) - xy_df, area = _get_centroids_for_shapes(e, return_area) - return _points_from_centroids(xy_df, area, e, coordinate_system) - - @get_centroids.register(DaskDataFrame) def _( - e: DaskDataFrame, + e: SpatialElement, coordinate_system: str | None = "global", return_background: bool = False, return_area: bool = False, persist_as: PersistAs = "Points", ) -> DaskDataFrame: - """Get the centroids of a Points element.""" - if return_area: - raise ValueError("`return_area` is not supported for points elements (points have no area).") + """Get the centroids of a Labels, Shapes or Points element.""" _validate_persist_args(persist_as, coordinate_system, allow_adata=False) assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) _validate_coordinate_system(e, coordinate_system) - axes = get_axes_names(e) - assert axes in [("x", "y"), ("x", "y", "z")] - coords = e[list(axes)].compute() - return _points_from_centroids(coords, None, e, coordinate_system) + df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) + return _points_from_centroids(df, area, raster, coordinate_system) def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str: @@ -303,23 +261,40 @@ def _write_centroids_into_table( if not mask.any(): raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).") + # Map each annotated instance id to its centroid row (-1 where absent, e.g. background or filtered + # instances -> NaN). A *total* miss means the instance_key and the element index never align (e.g. + # string vs integer ids); fail loudly rather than silently writing an all-NaN obsm["spatial"]. keys = table.obs[instance_key].to_numpy()[mask] + idx = centroids.index.get_indexer(keys) + if (idx == -1).all(): + raise ValueError( + f"No instance id annotating {element_name!r} matches a centroid; check that the table's " + f"`{instance_key}` dtype matches the element's instance ids." + ) + hit = idx != -1 coord_cols = list(centroids.columns) ndim = len(coord_cols) - coords = centroids.reindex(keys)[coord_cols].to_numpy(dtype=float) existing = np.asarray(table.obsm["spatial"]) if "spatial" in table.obsm else None + if existing is not None and existing.shape[0] == table.n_obs and existing.shape[1] != ndim: + raise ValueError( + f"Existing obsm['spatial'] has {existing.shape[1]} columns but {element_name!r} centroids have {ndim}; " + f"refusing to overwrite the coordinates of other regions. Persist these centroids with persist_as='Points'." + ) if existing is not None and existing.shape == (table.n_obs, ndim): spatial = existing.astype(float, copy=True) else: spatial = np.full((table.n_obs, ndim), np.nan) - spatial[mask] = coords + written = np.full((len(idx), ndim), np.nan) + written[hit] = centroids[coord_cols].to_numpy(dtype=float)[idx[hit]] + spatial[mask] = written table.obsm["spatial"] = spatial if area is not None: - area_for_keys = pd.Series(np.asarray(area, dtype=float), index=centroids.index).reindex(keys).to_numpy() col = table.obs["area"].to_numpy(dtype=float).copy() if "area" in table.obs else np.full(table.n_obs, np.nan) - col[mask] = area_for_keys + written_area = np.full(len(idx), np.nan) + written_area[hit] = np.asarray(area, dtype=float)[idx[hit]] + col[mask] = written_area table.obs["area"] = col @@ -333,15 +308,11 @@ def _( persist_as: PersistAs = "Points", table_name: str | None = None, ) -> DaskDataFrame | SpatialData: - """Get (or persist) the centroids of a named element of a :class:`~spatialdata.SpatialData`. - - With ``persist_as="Points"`` (default) this returns the centroids of ``element_name`` as a Points - element, identical to calling :func:`get_centroids` on the element directly. With - ``persist_as="adata"`` the centroids are written, squidpy-style, into the element's annotating - table — ``obsm["spatial"]`` (columns ordered ``x, y[, z]``) and, if ``return_area``, - ``obs["area"]`` — and the mutated ``SpatialData`` is returned. The annotating table is resolved - automatically (pass ``table_name=`` to disambiguate); if the element has no annotating table this - raises, and you should use ``persist_as="Points"`` instead. + """Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table. + + With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) + of the resolved annotating table (``table_name=`` disambiguates), and the mutated ``SpatialData`` is + returned. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element directly. """ _validate_persist_args(persist_as, coordinate_system, allow_adata=True) element = e[element_name] @@ -355,6 +326,8 @@ def _( ) # persist_as == "adata": resolve the annotating table and write centroids into it (in place). + if coordinate_system is not None: + _validate_coordinate_system(element, coordinate_system) table_name = _resolve_annotating_table(e, element_name, table_name) df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area) coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order) diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index 4b0c9352a..1e43e6f78 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -266,6 +266,24 @@ def test_get_centroids_sdata_persist_intrinsic_matches_identity(full_sdata): assert np.allclose(global_spatial[finite], intrinsic_spatial[finite]) +def test_get_centroids_sdata_persist_instance_key_mismatch_raises(full_sdata): + # instance ids stored with a dtype that doesn't match the element's integer labels must fail + # loudly instead of silently writing NaN coordinates into obsm["spatial"]. + table = full_sdata["table"] + table.obs["instance_id"] = table.obs["instance_id"].astype(str) + with pytest.raises(ValueError, match="No instance id annotating"): + get_centroids(full_sdata, "labels2d", persist_as="adata") + + +def test_get_centroids_sdata_persist_refuses_dim_mismatch(full_sdata): + # an existing obsm["spatial"] of a different width must not be silently overwritten (that would + # wipe the coordinates of other regions sharing the table). + table = full_sdata["table"] + table.obsm["spatial"] = np.zeros((table.n_obs, 3)) + with pytest.raises(ValueError, match="refusing to overwrite"): + get_centroids(full_sdata, "labels2d", persist_as="adata") + + def test_get_centroids_sdata_no_table_raises(full_sdata): # points_0 is not annotated by any table -> the error points the user to persist_as='Points'. with pytest.raises(ValueError, match="persist_as='Points'"): @@ -274,7 +292,7 @@ def test_get_centroids_sdata_no_table_raises(full_sdata): def test_get_centroids_invalid_element(images): # cannot compute centroids for images - with pytest.raises(ValueError, match="Expected a `Labels` element. Found an `Image` instead."): + with pytest.raises(ValueError, match="Centroids are not supported for Image2DModel"): get_centroids(images["image2d"]) # cannot compute centroids for tables From cdfcf71736db18e2ec433872d656eab3f4b74dd3 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 21 Jun 2026 15:39:33 +0200 Subject: [PATCH 5/7] feat(centroids): add inplace/copy semantics + SpatialData.get_centroids method - persist_as='adata' gains inplace: bool = True (mirrors sanitize_table): inplace=True mutates the resolved table and returns None; inplace=False copies only that AnnData, writes into the copy, returns it (sdata untouched). No whole-SpatialData copy. - add SpatialData.get_centroids(element_name, ...) method delegating to the module function (mirrors sdata.aggregate), so both sd.get_centroids(sdata, ...) and sdata.get_centroids(...) work. Calls the named _get_centroids_sdata overload directly so the full signature type-checks. - widen return type to DaskDataFrame | AnnData | None; update tests. --- src/spatialdata/_core/centroids.py | 26 ++++++++++++++---------- src/spatialdata/_core/spatialdata.py | 30 ++++++++++++++++++++++++++++ tests/core/test_centroids.py | 21 ++++++++++++++++++- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index c63885bf6..063e89412 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -75,7 +75,7 @@ def get_centroids( return_background: bool = False, return_area: bool = False, persist_as: PersistAs = "Points", -) -> DaskDataFrame | SpatialData: +) -> DaskDataFrame | AnnData | None: """ Get the centroids of the geometries contained in a SpatialElement. @@ -102,8 +102,8 @@ def get_centroids( Returns ------- - A Points element (``persist_as="Points"``), or the mutated ``SpatialData`` when writing centroids - into an annotating table (``SpatialData`` overload with ``persist_as="adata"``). + A Points element (``persist_as="Points"``). With ``persist_as="adata"`` (``SpatialData`` overload), + ``None`` when written in place, or the new ``AnnData`` table when ``inplace=False``. Notes ----- @@ -299,7 +299,7 @@ def _write_centroids_into_table( @get_centroids.register(SpatialData) -def _( +def _get_centroids_sdata( e: SpatialData, element_name: str, coordinate_system: str | None = "global", @@ -307,12 +307,14 @@ def _( return_area: bool = False, persist_as: PersistAs = "Points", table_name: str | None = None, -) -> DaskDataFrame | SpatialData: + inplace: bool = True, +) -> DaskDataFrame | AnnData | None: """Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table. - With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) - of the resolved annotating table (``table_name=`` disambiguates), and the mutated ``SpatialData`` is - returned. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element directly. + With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) of the + resolved annotating table (``table_name=`` disambiguates). ``inplace=True`` (default) mutates that table and + returns ``None``; ``inplace=False`` writes into a copy of *only that table* and returns the new ``AnnData``, + leaving ``e`` untouched. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element. """ _validate_persist_args(persist_as, coordinate_system, allow_adata=True) element = e[element_name] @@ -325,7 +327,7 @@ def _( return_area=return_area, ) - # persist_as == "adata": resolve the annotating table and write centroids into it (in place). + # persist_as == "adata": resolve the annotating table and write the centroids into it. if coordinate_system is not None: _validate_coordinate_system(element, coordinate_system) table_name = _resolve_annotating_table(e, element_name, table_name) @@ -333,8 +335,10 @@ def _( coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order) coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system) centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index) - _write_centroids_into_table(e.tables[table_name], element_name, centroids, area) - return e + + table = e.tables[table_name] if inplace else e.tables[table_name].copy() + _write_centroids_into_table(table, element_name, centroids, area) + return None if inplace else table ## diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index fb55ab086..97bb2fa29 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -543,6 +543,36 @@ def aggregate( **kwargs, ) + def get_centroids( + self, + element_name: str, + coordinate_system: str | None = "global", + return_background: bool = False, + return_area: bool = False, + persist_as: Literal["Points", "adata"] = "Points", + table_name: str | None = None, + inplace: bool = True, + ) -> DaskDataFrame | AnnData | None: + """Get the centroids of ``element_name``, or persist them into its annotating table. + + Convenience method for :func:`spatialdata.get_centroids` called on ``self``; see that function for the + complete docstring. With ``persist_as="adata"`` the centroids are written into ``obsm["spatial"]`` (and area + into ``obs["area"]``) of the resolved annotating table; ``inplace=True`` mutates it and returns ``None``, + ``inplace=False`` returns a modified copy of that table. + """ + from spatialdata._core.centroids import _get_centroids_sdata + + return _get_centroids_sdata( + self, + element_name, + coordinate_system=coordinate_system, + return_background=return_background, + return_area=return_area, + persist_as=persist_as, + table_name=table_name, + inplace=inplace, + ) + def is_backed(self) -> bool: """Check if the data is backed by a Zarr storage or if it is in-memory.""" return self.path is not None diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index 1e43e6f78..4b6ba6bbe 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -215,7 +215,7 @@ def test_get_centroids_sdata_persist_into_table(full_sdata): table = full_sdata["table"] assert "spatial" not in table.obsm out = get_centroids(full_sdata, "labels2d", coordinate_system="global", return_area=True, persist_as="adata") - assert out is full_sdata # written in place + assert out is None # inplace=True (default) mutates the table and returns nothing table = full_sdata["table"] assert table.obsm["spatial"].shape == (table.n_obs, 2) @@ -266,6 +266,25 @@ def test_get_centroids_sdata_persist_intrinsic_matches_identity(full_sdata): assert np.allclose(global_spatial[finite], intrinsic_spatial[finite]) +def test_get_centroids_sdata_persist_inplace_false_returns_copy(full_sdata): + # inplace=False copies only the target table, writes into the copy, and leaves the sdata untouched. + out = get_centroids(full_sdata, "labels2d", return_area=True, persist_as="adata", inplace=False) + assert isinstance(out, AnnData) + assert out is not full_sdata["table"] + assert "spatial" in out.obsm and "area" in out.obs + assert "spatial" not in full_sdata["table"].obsm # original table not modified + + +def test_spatialdata_get_centroids_method(full_sdata): + # the method mirrors the module-level function for both persistence modes. + pts = full_sdata.get_centroids("labels2d", coordinate_system="global") + expected = get_centroids(full_sdata["labels2d"], coordinate_system="global") + assert np.allclose(pts.compute().to_numpy(), expected.compute().to_numpy()) + + assert full_sdata.get_centroids("labels2d", persist_as="adata") is None + assert "spatial" in full_sdata["table"].obsm + + def test_get_centroids_sdata_persist_instance_key_mismatch_raises(full_sdata): # instance ids stored with a dtype that doesn't match the element's integer labels must fail # loudly instead of silently writing NaN coordinates into obsm["spatial"]. From efbb469e8aa3f7b8b32dc532a2d474e610c8d158 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 21 Jun 2026 16:04:41 +0200 Subject: [PATCH 6/7] refactor(centroids): harden table-write edges, reuse affine helper, dedup scatter Correctness hardening (review round 2): - raise a clear error when the element index is non-unique (was a cryptic pandas InvalidIndexError from get_indexer on e.g. duplicate-index shapes). - obsm width-mismatch guard is now 1-D-safe (compare full .shape instead of indexing .shape[1], which raised IndexError on a 1-D obsm['spatial']). - reword the total-miss error (it's 'no shared instances', not only a dtype bug). Cleanup: - reuse the existing _affine_matrix_multiplication helper in _transform_centroid_coords (drops the inline matmul; gives the dead helper a caller). - collapse the duplicated coords/area NaN-fill into a local _scatter(). - _points_from_centroids: df.assign instead of copy-then-mutate. - name the squidpy storage keys (_SPATIAL_KEY/_AREA_KEY) instead of magic strings. --- src/spatialdata/_core/centroids.py | 68 ++++++++++++++++-------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index 063e89412..2c0954c21 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -14,6 +14,7 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.spatialdata import SpatialData +from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import get_axes_names, get_table_keys from spatialdata.models._utils import SpatialElement from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model @@ -23,6 +24,9 @@ BoundingBoxDescription = dict[str, tuple[float, float]] PersistAs = Literal["Points", "adata"] +# squidpy-style storage keys for persist_as="adata". +_SPATIAL_KEY = "spatial" +_AREA_KEY = "area" def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: @@ -64,8 +68,7 @@ def _transform_centroid_coords( if isinstance(t, Identity): return xy matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes)) - n = len(axes) - return xy @ matrix[:n, :n].T + matrix[:n, n] + return _affine_matrix_multiplication(matrix, xy) @singledispatch @@ -194,9 +197,7 @@ def _points_from_centroids( df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str ) -> DaskDataFrame: """Build a Points element from intrinsic centroids, transformed into ``coordinate_system``.""" - out = df.copy() - if area is not None: - out["area"] = np.asarray(area, dtype=float) + out = df.assign(area=np.asarray(area, dtype=float)) if area is not None else df t = get_transformation(e, coordinate_system) assert isinstance(t, BaseTransformation) points = PointsModel.parse(out, transformations={coordinate_system: t}) @@ -256,46 +257,51 @@ def _write_centroids_into_table( Only the table rows annotating ``element_name`` are touched (a table may annotate several elements); instances annotated but absent from the element are written as NaN. """ + if not centroids.index.is_unique: + raise ValueError(f"Cannot persist centroids for {element_name!r}: its instance index has duplicate values.") _, region_key, instance_key = get_table_keys(table) mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy() if not mask.any(): raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).") - # Map each annotated instance id to its centroid row (-1 where absent, e.g. background or filtered - # instances -> NaN). A *total* miss means the instance_key and the element index never align (e.g. - # string vs integer ids); fail loudly rather than silently writing an all-NaN obsm["spatial"]. + # Map each annotated instance to its centroid row (-1 where absent -> NaN). A *total* miss means the + # instance ids never align with the element index (mismatched dtype, or no shared instances). keys = table.obs[instance_key].to_numpy()[mask] idx = centroids.index.get_indexer(keys) if (idx == -1).all(): raise ValueError( - f"No instance id annotating {element_name!r} matches a centroid; check that the table's " - f"`{instance_key}` dtype matches the element's instance ids." + f"No instance id annotating {element_name!r} is present in the element; check the table's " + f"`{instance_key}` values and dtype." ) hit = idx != -1 - coord_cols = list(centroids.columns) - ndim = len(coord_cols) - existing = np.asarray(table.obsm["spatial"]) if "spatial" in table.obsm else None - if existing is not None and existing.shape[0] == table.n_obs and existing.shape[1] != ndim: - raise ValueError( - f"Existing obsm['spatial'] has {existing.shape[1]} columns but {element_name!r} centroids have {ndim}; " - f"refusing to overwrite the coordinates of other regions. Persist these centroids with persist_as='Points'." - ) - if existing is not None and existing.shape == (table.n_obs, ndim): - spatial = existing.astype(float, copy=True) - else: - spatial = np.full((table.n_obs, ndim), np.nan) - written = np.full((len(idx), ndim), np.nan) - written[hit] = centroids[coord_cols].to_numpy(dtype=float)[idx[hit]] - spatial[mask] = written - table.obsm["spatial"] = spatial + def _scatter(values: np.ndarray) -> np.ndarray: + """Gather ``values`` (ordered like ``centroids``) onto the masked rows, NaN where absent.""" + out = np.full((len(idx), *values.shape[1:]), np.nan) + out[hit] = values[idx[hit]] + return out + + ndim = centroids.shape[1] + spatial = np.full((table.n_obs, ndim), np.nan) + existing = table.obsm.get(_SPATIAL_KEY) + if existing is not None: + existing = np.asarray(existing) + if existing.shape == (table.n_obs, ndim): + spatial = existing.astype(float, copy=True) # preserve other regions' coordinates + elif existing.shape[0] == table.n_obs: + raise ValueError( + f"Existing obsm['{_SPATIAL_KEY}'] {existing.shape} is incompatible with {ndim}-D centroids for " + f"{element_name!r}; refusing to overwrite other regions. Persist with persist_as='Points' instead." + ) + spatial[mask] = _scatter(centroids.to_numpy(dtype=float)) + table.obsm[_SPATIAL_KEY] = spatial if area is not None: - col = table.obs["area"].to_numpy(dtype=float).copy() if "area" in table.obs else np.full(table.n_obs, np.nan) - written_area = np.full(len(idx), np.nan) - written_area[hit] = np.asarray(area, dtype=float)[idx[hit]] - col[mask] = written_area - table.obs["area"] = col + col = np.full(table.n_obs, np.nan) + if _AREA_KEY in table.obs: + col = table.obs[_AREA_KEY].to_numpy(dtype=float).copy() + col[mask] = _scatter(np.asarray(area, dtype=float)) + table.obs[_AREA_KEY] = col @get_centroids.register(SpatialData) From 582f16306ba89e71f57aeecaca7d2bbc08545fe9 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 21 Jun 2026 16:23:12 +0200 Subject: [PATCH 7/7] refactor(centroids): drop redundant str() cast, dedup obsm-vs-points test assertion - _resolve_annotating_table returns annotators[0] directly (get_element_annotators already yields set[str]; the str() wrap was noise). - extract _assert_obsm_matches_points test helper to remove two near-identical written-obsm vs element-Points comparison blocks. --- src/spatialdata/_core/centroids.py | 2 +- tests/core/test_centroids.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index 2c0954c21..be8c5cf6f 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -243,7 +243,7 @@ def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: f"Element {element_name!r} is annotated by multiple tables ({', '.join(annotators)}); " f"pass `table_name=` to choose one." ) - return str(annotators[0]) + return annotators[0] def _write_centroids_into_table( diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index 4b6ba6bbe..6485ac4ef 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -16,6 +16,15 @@ RNG = default_rng(42) +def _assert_obsm_matches_points(table: AnnData, pts: pd.DataFrame) -> None: + # written obsm["spatial"] must match the element-level Points centroids on shared (non-background) ids. + inst = table.obs["instance_id"].to_numpy() + written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) + common = pts.index.intersection(written.index[inst != 0]) + assert len(common) > 0 + assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + + def _get_affine() -> Affine: theta: float = math.pi / 18 k = 10.0 @@ -228,10 +237,7 @@ def test_get_centroids_sdata_persist_into_table(full_sdata): # coordinates must match the element-level Points (global transform is the identity here) pts = get_centroids(full_sdata["labels2d"], coordinate_system="global").compute() - written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) - common = pts.index.intersection(written.index[inst != 0]) - assert len(common) > 0 - assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + _assert_obsm_matches_points(table, pts) # area must equal the pixel counts of the corresponding labels ids, counts = np.unique(np.asarray(full_sdata["labels2d"].data), return_counts=True) @@ -247,13 +253,8 @@ def test_get_centroids_sdata_persist_fastpath_matches_transform(full_sdata): set_transformation(full_sdata["labels2d"], affine, "aligned") get_centroids(full_sdata, "labels2d", coordinate_system="aligned", persist_as="adata") - table = full_sdata["table"] - inst = table.obs["instance_id"].to_numpy() - written = pd.DataFrame(table.obsm["spatial"], index=inst, columns=["x", "y"]) pts = get_centroids(full_sdata["labels2d"], coordinate_system="aligned").compute() - common = pts.index.intersection(written.index[inst != 0]) - assert len(common) > 0 - assert np.allclose(written.loc[common].to_numpy(), pts.loc[common][["x", "y"]].to_numpy()) + _assert_obsm_matches_points(full_sdata["table"], pts) def test_get_centroids_sdata_persist_intrinsic_matches_identity(full_sdata):