Skip to content
320 changes: 264 additions & 56 deletions src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
from __future__ import annotations

from functools import singledispatch
from typing import Literal

import numpy as np
import pandas as pd
import xarray as xr
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Point, Polygon
from xarray import DataArray, DataTree

from spatialdata._core.operations.transform import transform
from spatialdata.models import get_axes_names
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import _affine_matrix_multiplication
from spatialdata.models import get_axes_names, get_table_keys
from spatialdata.models._utils import SpatialElement
from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model
from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import BaseTransformation
from spatialdata.transformations.transformations import BaseTransformation, Identity

BoundingBoxDescription = dict[str, tuple[float, float]]

PersistAs = Literal["Points", "adata"]
# squidpy-style storage keys for persist_as="adata".
_SPATIAL_KEY = "spatial"
_AREA_KEY = "area"


def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None:
d = get_transformation(e, get_all=True)
Expand All @@ -29,37 +38,92 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No
)


def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, allow_adata: bool) -> None:
if persist_as not in ("Points", "adata"):
raise ValueError(f"`persist_as` must be 'Points' or 'adata', got {persist_as!r}.")
if persist_as == "adata" and not allow_adata:
raise ValueError(
"persist_as='adata' writes centroids into the element's annotating table, which needs the "
"`SpatialData` object: call `get_centroids(sdata, element_name, ..., persist_as='adata')`. "
"To get the centroids as a standalone element instead, use persist_as='Points'."
)
# ``coordinate_system=None`` means "intrinsic coordinates, do not transform". An intrinsic Points
# element is ill-defined (Points always carry a coordinate system), so intrinsic coords are only
# meaningful when writing into a table (persist_as='adata').
if coordinate_system is None and persist_as != "adata":
raise ValueError("`coordinate_system=None` (intrinsic coordinates) is only supported with persist_as='adata'.")


def _transform_centroid_coords(
xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None
) -> np.ndarray:
"""Apply the element's affine to centroid coords in-memory; ``None``/``Identity`` pass through.

``axes`` is the column order of ``xy`` (e.g. ``["x", "y"]``).
"""
if coordinate_system is None:
return xy
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
if isinstance(t, Identity):
return xy
matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes))
return _affine_matrix_multiplication(matrix, xy)


@singledispatch
def get_centroids(
e: SpatialElement,
coordinate_system: str = "global",
e: SpatialElement | SpatialData,
coordinate_system: str | None = "global",
return_background: bool = False,
) -> DaskDataFrame:
return_area: bool = False,
persist_as: PersistAs = "Points",
) -> DaskDataFrame | AnnData | None:
"""
Get the centroids of the geometries contained in a SpatialElement, as a new Points element.
Get the centroids of the geometries contained in a SpatialElement.

Parameters
----------
e
The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported.
The SpatialElement (points, shapes — circles, polygons and multipolygons — or labels), or a
:class:`~spatialdata.SpatialData` object. When a ``SpatialData`` is passed, the second
positional argument is the name of the element to measure (see the ``SpatialData`` overload).
coordinate_system
The coordinate system in which the centroids are computed.
The coordinate system in which the centroids are computed. ``None`` returns the intrinsic
coordinates without applying any transformation (only supported with ``persist_as="adata"``).
return_background
If True, the centroid of the background label (0) is included in the output.
If True, the centroid of the background label (0) is included in the output (labels only).
return_area
If True, also return the per-instance area: the pixel/voxel count for labels and the geometric
area for shapes (``pi * r**2`` for circles). Not supported for points (raises). With
``persist_as="Points"`` the area is added as a feature column of the returned Points element.
persist_as
``"Points"`` (default) returns the centroids as a new Points element, transformed into
``coordinate_system``. ``"adata"`` writes the centroids (and area) into the element's
annotating table and is only available through the :class:`~spatialdata.SpatialData` overload,
which can resolve that table.

Returns
-------
A Points element (``persist_as="Points"``). With ``persist_as="adata"`` (``SpatialData`` overload),
``None`` when written in place, or the new ``AnnData`` table when ``inplace=False``.

Notes
-----
For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute
each :class:`~shapely.Multipolygon`.
each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution
``scale0`` level.
"""
raise ValueError(f"The object type {type(e)} is not supported.")


def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame:
def _get_centroids_for_labels(xdata: xr.DataArray, return_area: bool = False) -> pd.DataFrame:
"""
Compute centroids for all labels in a DataArray in a single O(n_voxels) pass.

Works for any number of spatial dimensions (2D and 3D labels).
Works for any number of spatial dimensions (2D and 3D labels). When ``return_area`` is True, an
``area`` column (the per-label pixel/voxel count) is added; it is already computed for the
centroids, so this is free.
"""
arr = xdata.data.compute()
axes = list(xdata.dims)
Expand All @@ -77,66 +141,210 @@ def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame:
coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float))
data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this)

return pd.DataFrame(data, index=label_ids)
df = pd.DataFrame(data, index=label_ids)
if return_area:
df["area"] = counts.astype(float)
return df


@get_centroids.register(DataArray)
@get_centroids.register(DataTree)
def _(
e: DataArray | DataTree,
coordinate_system: str = "global",
return_background: bool = False,
) -> DaskDataFrame:
"""Get the centroids of a Labels element (2D or 3D)."""
model = get_model(e)
if model not in [Labels2DModel, Labels3DModel]:
raise ValueError("Expected a `Labels` element. Found an `Image` instead.")
_validate_coordinate_system(e, coordinate_system)

if isinstance(e, DataTree):
assert len(e["scale0"]) == 1
e = next(iter(e["scale0"].values()))

df = _get_centroids_for_labels(e)
if not return_background and 0 in df.index:
df = df.drop(index=0) # drop the background label
t = get_transformation(e, coordinate_system)
centroids = PointsModel.parse(df, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)


@get_centroids.register(GeoDataFrame)
def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Shapes element (circles or polygons/multipolygons)."""
_validate_coordinate_system(e, coordinate_system)
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
# separate points from (multi-)polygons
def _get_centroids_for_shapes(e: GeoDataFrame, return_area: bool) -> tuple[pd.DataFrame, np.ndarray | None]:
"""Intrinsic per-shape centroids (``x, y`` columns indexed by the element's index) and optional area."""
first_geometry = e["geometry"].iloc[0]
if isinstance(first_geometry, Point):
xy = e.geometry.get_coordinates().values
# shapely .area is 0 for circles (Point geometry); the radius column carries the size.
area = np.pi * np.asarray(e["radius"], dtype=float) ** 2 if return_area else None
else:
assert isinstance(first_geometry, Polygon | MultiPolygon), (
f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
)
xy = e.centroid.get_coordinates().values
area = e.geometry.area.to_numpy() if return_area else None
xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy())
points = PointsModel.parse(xy_df, transformations={coordinate_system: t})
return xy_df, area


def _intrinsic_centroid_frame(
element: SpatialElement, return_background: bool, return_area: bool
) -> tuple[pd.DataFrame, np.ndarray | None, SpatialElement]:
"""Per-instance intrinsic centroids (coordinate columns, indexed by instance id), optional area.

Also returns the element the centroids live on (for labels, the ``scale0`` level of a multiscale
raster), which carries the transformation to apply downstream.
"""
model = get_model(element)
if model in (Labels2DModel, Labels3DModel):
raster = next(iter(element["scale0"].values())) if isinstance(element, DataTree) else element
df = _get_centroids_for_labels(raster, return_area=return_area)
if not return_background and 0 in df.index:
df = df.drop(index=0) # drop the background label (its area, if any, goes with it)
area = df.pop("area").to_numpy() if return_area else None
return df, area, raster
if model is ShapesModel:
xy_df, area = _get_centroids_for_shapes(element, return_area)
return xy_df, area, element
if model is PointsModel:
if return_area:
raise ValueError("`return_area` is not supported for points elements (points have no area).")
axes = get_axes_names(element)
assert axes in [("x", "y"), ("x", "y", "z")]
return element[list(axes)].compute(), None, element
raise ValueError(f"Centroids are not supported for {model.__name__}; expected a Labels, Shapes or Points element.")


def _points_from_centroids(
df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str
) -> DaskDataFrame:
"""Build a Points element from intrinsic centroids, transformed into ``coordinate_system``."""
out = df.assign(area=np.asarray(area, dtype=float)) if area is not None else df
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
points = PointsModel.parse(out, transformations={coordinate_system: t})
return transform(points, to_coordinate_system=coordinate_system)


@get_centroids.register(DataArray)
@get_centroids.register(DataTree)
@get_centroids.register(GeoDataFrame)
@get_centroids.register(DaskDataFrame)
def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Points element."""
def _(
e: SpatialElement,
coordinate_system: str | None = "global",
return_background: bool = False,
return_area: bool = False,
persist_as: PersistAs = "Points",
) -> DaskDataFrame:
"""Get the centroids of a Labels, Shapes or Points element."""
_validate_persist_args(persist_as, coordinate_system, allow_adata=False)
assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False)
_validate_coordinate_system(e, coordinate_system)
axes = get_axes_names(e)
assert axes in [("x", "y"), ("x", "y", "z")]
coords = e[list(axes)].compute()
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
centroids = PointsModel.parse(coords, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)
df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area)
return _points_from_centroids(df, area, raster, coordinate_system)


def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str:
"""Resolve the single table that annotates ``element_name`` (where centroids are written)."""
from spatialdata._core.query.relational_query import get_element_annotators

if table_name is not None:
if table_name not in sdata.tables:
raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.")
return table_name
annotators = sorted(get_element_annotators(sdata, element_name))
if not annotators:
raise ValueError(
f"Element {element_name!r} has no annotating table to write centroids into. Use "
f"persist_as='Points' to get the centroids as a Points element instead, or annotate the "
f"element with a table first."
)
if len(annotators) > 1:
raise ValueError(
f"Element {element_name!r} is annotated by multiple tables ({', '.join(annotators)}); "
f"pass `table_name=` to choose one."
)
return annotators[0]


def _write_centroids_into_table(
table: AnnData,
element_name: str,
centroids: pd.DataFrame,
area: np.ndarray | None,
) -> None:
"""Write centroids into ``obsm["spatial"]`` and area into ``obs["area"]`` at the element's rows.

Only the table rows annotating ``element_name`` are touched (a table may annotate several
elements); instances annotated but absent from the element are written as NaN.
"""
if not centroids.index.is_unique:
raise ValueError(f"Cannot persist centroids for {element_name!r}: its instance index has duplicate values.")
_, region_key, instance_key = get_table_keys(table)
mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy()
if not mask.any():
raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).")

# Map each annotated instance to its centroid row (-1 where absent -> NaN). A *total* miss means the
# instance ids never align with the element index (mismatched dtype, or no shared instances).
keys = table.obs[instance_key].to_numpy()[mask]
idx = centroids.index.get_indexer(keys)
if (idx == -1).all():
raise ValueError(
f"No instance id annotating {element_name!r} is present in the element; check the table's "
f"`{instance_key}` values and dtype."
)
hit = idx != -1

def _scatter(values: np.ndarray) -> np.ndarray:
"""Gather ``values`` (ordered like ``centroids``) onto the masked rows, NaN where absent."""
out = np.full((len(idx), *values.shape[1:]), np.nan)
out[hit] = values[idx[hit]]
return out

ndim = centroids.shape[1]
spatial = np.full((table.n_obs, ndim), np.nan)
existing = table.obsm.get(_SPATIAL_KEY)
if existing is not None:
existing = np.asarray(existing)
if existing.shape == (table.n_obs, ndim):
spatial = existing.astype(float, copy=True) # preserve other regions' coordinates
elif existing.shape[0] == table.n_obs:
raise ValueError(
f"Existing obsm['{_SPATIAL_KEY}'] {existing.shape} is incompatible with {ndim}-D centroids for "
f"{element_name!r}; refusing to overwrite other regions. Persist with persist_as='Points' instead."
)
spatial[mask] = _scatter(centroids.to_numpy(dtype=float))
table.obsm[_SPATIAL_KEY] = spatial

if area is not None:
col = np.full(table.n_obs, np.nan)
if _AREA_KEY in table.obs:
col = table.obs[_AREA_KEY].to_numpy(dtype=float).copy()
col[mask] = _scatter(np.asarray(area, dtype=float))
table.obs[_AREA_KEY] = col


@get_centroids.register(SpatialData)
def _get_centroids_sdata(
e: SpatialData,
element_name: str,
coordinate_system: str | None = "global",
return_background: bool = False,
return_area: bool = False,
persist_as: PersistAs = "Points",
table_name: str | None = None,
inplace: bool = True,
) -> DaskDataFrame | AnnData | None:
"""Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table.

With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) of the
resolved annotating table (``table_name=`` disambiguates). ``inplace=True`` (default) mutates that table and
returns ``None``; ``inplace=False`` writes into a copy of *only that table* and returns the new ``AnnData``,
leaving ``e`` untouched. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element.
"""
_validate_persist_args(persist_as, coordinate_system, allow_adata=True)
element = e[element_name]

if persist_as == "Points":
return get_centroids(
element,
coordinate_system=coordinate_system,
return_background=return_background,
return_area=return_area,
)

# persist_as == "adata": resolve the annotating table and write the centroids into it.
if coordinate_system is not None:
_validate_coordinate_system(element, coordinate_system)
table_name = _resolve_annotating_table(e, element_name, table_name)
df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area)
coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order)
coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system)
centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index)

table = e.tables[table_name] if inplace else e.tables[table_name].copy()
_write_centroids_into_table(table, element_name, centroids, area)
return None if inplace else table


##
Loading