Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ jobs:
fail-fast: false
matrix:
include:
- {os: windows-latest, python: "3.11", dask-version: "2026.3.0", name: "min dask"}
- {os: windows-latest, python: "3.12", dask-version: "2026.3.0", name: "min dask"}
- {os: windows-latest, python: "3.14", dask-version: "latest"}
- {os: ubuntu-latest, python: "3.11", dask-version: "latest"}
- {os: ubuntu-latest, python: "3.12", dask-version: "latest"}
- {os: ubuntu-latest, python: "3.13", dask-version: "latest"}
- {os: ubuntu-latest, python: "3.14", dask-version: "latest"}
- {os: macos-latest, python: "3.11", dask-version: "latest"}
- {os: macos-latest, python: "3.12", dask-version: "latest"}
- {os: macos-latest, python: "3.14", prerelease: "allow", name: "prerelease"}
env:
OS: ${{ matrix.os }}
Expand All @@ -41,7 +42,6 @@ jobs:
- name: Install dependencies
run: |
if [[ "${PRERELEASE}" == "allow" ]]; then
sed -i '' 's/requires-python.*//' pyproject.toml # otherwise uv complains that anndata requires python>=3.12 and we only do >=3.11 😱
uv add git+https://github.com/scverse/anndata.git
uv add pandas>=3.dev0
fi
Expand Down
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.11
python_version = 3.12

ignore_errors = False
warn_redundant_casts = True
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ maintainers = [
urls.Documentation = "https://spatialdata.scverse.org/en/latest"
urls.Source = "https://github.com/scverse/spatialdata.git"
urls.Home-page = "https://github.com/scverse/spatialdata.git"
requires-python = ">=3.11"
requires-python = ">=3.12"
dynamic= [
"version" # allow version to be set by git tags
]
Expand Down Expand Up @@ -145,7 +145,7 @@ exclude = [

]
line-length = 120
target-version = "py311"
target-version = "py312"

[tool.ruff.lint]
ignore = [
Expand Down
12 changes: 4 additions & 8 deletions src/spatialdata/_core/operations/rasterize_bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,10 @@ def rasterize_bins(
table = sdata.tables[table_name]
if not isinstance(element, GeoDataFrame | DaskDataFrame | DataArray):
raise ValueError("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.")
if isinstance(element, DataArray):
if "c" in element.dims:
raise ValueError(
"If bins is a DataArray, it should hold labels; found a image element instead, with"
f" 'c': {element.dims}."
)
if not np.issubdtype(element.dtype, np.integer):
raise ValueError(f"If bins is a DataArray, it should hold integers. Found dtype {element.dtype}.")
if isinstance(element, DataArray) and "c" in element.dims:
raise ValueError(
f"If bins is a DataArray, it should hold labels; found a image element instead, with 'c': {element.dims}."
)

_, region_key, instance_key = get_table_keys(table)
if not table.obs[region_key].dtype == "category":
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _bounding_box_mask_points(
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
) -> list[ArrayLike]:
) -> list[np.ndarray]:
"""Compute a mask that is true for the points inside axis-aligned bounding boxes.

Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) ->
channel_names = element["scale0"]["image"].coords["c"].data.tolist()

channel_metadata = [{"label": name} for name in channel_names]
# This is required here as we do not use the load node API of ome-zarr
omero_meta = group.attrs.get("omero", None) or group.attrs.get("ome", {}).get("omero")
# We don't use the ome-zarr load node API, and ome-zarr-py >= 0.18 emits no `omero` block, so default to empty.
omero_meta = group.attrs.get("omero") or group.attrs.get("ome", {}).get("omero") or {}
omero_meta["channels"] = channel_metadata
if ome_meta := group.attrs.get("ome", None):
ome_meta["omero"] = omero_meta
Expand Down
12 changes: 4 additions & 8 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from spatialdata._io._utils import (
_get_transformations_from_ngff_dict,
overwrite_channel_names,
overwrite_coordinate_transformations_raster,
)
from spatialdata._io.format import (
Expand All @@ -28,7 +29,6 @@
get_ome_zarr_format,
)
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channel_names
from spatialdata.models.models import ATTRS_KEY
from spatialdata.models.pyramids_utils import dask_arrays_to_datatree
from spatialdata.transformations._utils import (
Expand Down Expand Up @@ -301,13 +301,6 @@ def _write_raster(
metadata["name"] = name
metadata["label_metadata"] = label_metadata

# convert channel names to channel metadata in omero
if raster_type == "image":
metadata["metadata"] = {"omero": {"channels": []}}
channels = get_channel_names(raster_data)
for c in channels:
metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload]

if isinstance(raster_data, DataArray):
_write_raster_dataarray(
raster_type,
Expand All @@ -334,6 +327,9 @@ def _write_raster(
raise ValueError("Not a valid labels object")

group = group["labels"][name] if raster_type == "labels" else group
if raster_type == "image":
# ome-zarr-py >= 0.18 no longer writes the omero channel metadata, so we write it ourselves.
overwrite_channel_names(group, raster_data)
if ATTRS_KEY not in group.attrs:
group.attrs[ATTRS_KEY] = {}
attrs = group.attrs[ATTRS_KEY]
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, TypeAlias
from typing import Any

import numpy as np
from xarray import DataArray, DataTree
Expand All @@ -12,5 +12,5 @@
ArrayLike = NDArray[np.floating[Any]]
IntArrayLike = NDArray[np.integer[Any]]

Raster_T: TypeAlias = DataArray | DataTree
type Raster_T = DataArray | DataTree
ColorLike = tuple[float, ...] | str
8 changes: 3 additions & 5 deletions src/spatialdata/dataloader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import spatialdata
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from spatialdata.dataloader.datasets import ImageTilesDataset
Expand All @@ -12,10 +10,10 @@
]


def __getattr__(attr_name: str) -> ImageTilesDataset | Any:
def __getattr__(attr_name: str) -> type[ImageTilesDataset]:
if attr_name == "ImageTilesDataset":
from spatialdata.dataloader.datasets import ImageTilesDataset

return ImageTilesDataset

return getattr(spatialdata.dataloader, attr_name)
raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}")
4 changes: 2 additions & 2 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from functools import singledispatch
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any

import dask.dataframe as dd
import geopandas
Expand All @@ -17,7 +17,7 @@
from spatialdata._utils import _check_match_length_channels_c_dim
from spatialdata.transformations.transformations import BaseTransformation

SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
type SpatialElement = DataArray | DataTree | GeoDataFrame | DaskDataFrame
TRANSFORM_KEY = "transform"
DEFAULT_COORDINATE_SYSTEM = "global"
ValidAxis_t = str
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/models/chunks_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from typing import Any

Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
type Chunks_t = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]


def normalize_chunks(
Expand Down
21 changes: 19 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Mapping, Sequence
from functools import singledispatchmethod
from pathlib import Path
from typing import Any, Literal, TypeAlias
from typing import Any, Literal

import dask.dataframe as dd
import numpy as np
Expand Down Expand Up @@ -398,6 +398,13 @@ def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None:
for d in data:
cls._check_chunk_size_not_too_large(data[d][name])

def _validate_labels_dtype(data: DataArray | DataTree) -> None:
dtype = data.dtype if isinstance(data, DataArray) else data["scale0"]["image"].dtype
if not (np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_)):
raise ValueError(
f"Labels must have an integer dtype, found {dtype}. Cast the data, e.g. `.astype(np.uint16)`."
)


class Labels2DModel(RasterSchema):
dims = (Y, X)
Expand All @@ -412,6 +419,11 @@ def parse( # noqa: D102
raise ValueError("`c_coords` is not supported for labels")
return super().parse(*args, **kwargs)

@classmethod
def validate(cls, data: Any) -> None:
super().validate(data)
cls._validate_labels_dtype(data)


class Labels3DModel(RasterSchema):
dims = (Z, Y, X)
Expand All @@ -422,6 +434,11 @@ def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D10
raise ValueError("`c_coords` is not supported for labels")
return super().parse(*args, **kwargs)

@classmethod
def validate(cls, data: Any) -> None:
super().validate(data)
cls._validate_labels_dtype(data)


class Image2DModel(RasterSchema):
dims = (C, Y, X)
Expand Down Expand Up @@ -1252,7 +1269,7 @@ def parse(
return adata


Schema_t: TypeAlias = (
type Schema_t = (
type[Image2DModel]
| type[Image3DModel]
| type[Labels2DModel]
Expand Down
18 changes: 0 additions & 18 deletions tests/core/operations/test_rasterize_bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from spatialdata._types import ArrayLike
from spatialdata.models.models import (
Image2DModel,
Labels2DModel,
PointsModel,
ShapesModel,
TableModel,
Expand Down Expand Up @@ -276,23 +275,6 @@ def _get_sdata(n: int):
value_key="instance_id",
)

# if bins is a DataArray, it should hold integers
image = Labels2DModel.parse(RNG.normal(size=(3, 3)), dims=("y", "x"))
del sdata["points"]
sdata["points"] = image
with pytest.raises(
ValueError,
match=f"If bins is a DataArray, it should hold integers. Found dtype {image.dtype}.",
):
_ = rasterize_bins(
sdata=sdata,
bins="points",
table_name="table",
col_key="col_index",
row_key="row_index",
value_key="instance_id",
)


def test_relabel_labels(caplog):
obs = DataFrame(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_query_raster(
shape = (10,) + shape
shape = (n_channels,) + shape if not is_labels else (1,) + shape

image = np.zeros(shape)
image = np.zeros(shape, dtype=int if is_labels else float)
axes = ["y", "x"]
if is_3d:
image[:, 2:7, 5::, 0:5] = 1
Expand Down
16 changes: 8 additions & 8 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ def test_raster_schema(
converter = partial(converter, dims=dims)
elif converter is to_spatial_image:
converter = partial(converter, dims=model.dims)
if n_dims == 2:
image: ArrayLike = RNG.uniform(size=(10, 10))
elif n_dims == 3:
image: ArrayLike = RNG.uniform(size=(3, 10, 10))
elif n_dims == 4:
image: ArrayLike = RNG.uniform(size=(2, 3, 10, 10))
# labels must be integer-valued, images can be float
shape = {2: (10, 10), 3: (3, 10, 10), 4: (2, 3, 10, 10)}[n_dims]
if model in [Labels2DModel, Labels3DModel]:
image: ArrayLike = RNG.integers(0, 100, size=shape)
else:
image = RNG.uniform(size=shape)
image = converter(image)
self._parse_transformation_from_multiple_places(model, image)
spatial_image = model.parse(image)
Expand Down Expand Up @@ -891,8 +891,8 @@ def test_label_no_c_coords(model: Labels2DModel | Labels3DModel):


def test_warning_on_large_chunks():
data_small = DataArray(dask.array.zeros((100, 100), chunks=(50, 50)), dims=["x", "y"])
data_large = DataArray(dask.array.zeros((50000, 50000), chunks=(50000, 50000)), dims=["x", "y"])
data_small = DataArray(dask.array.zeros((100, 100), chunks=(50, 50), dtype=np.int64), dims=["x", "y"])
data_large = DataArray(dask.array.zeros((50000, 50000), chunks=(50000, 50000), dtype=np.int64), dims=["x", "y"])
assert np.array(data_large.shape).prod().item() > LARGE_CHUNK_THRESHOLD_BYTES

# single and multiscale, small chunk size
Expand Down
Loading