diff --git a/docs/raster_index/creating.md b/docs/raster_index/creating.md index ce0b5fe..8146e9f 100644 --- a/docs/raster_index/creating.md +++ b/docs/raster_index/creating.md @@ -63,6 +63,25 @@ ds = rasterix.assign_index(ds) ds ``` +### Zarr spatial convention + +For data following the [Zarr Spatial Convention](https://zarr-specs.readthedocs.io/en/latest/v3/conventions/spatial/v1.0.html), `spatial:` attributes are detected on the variable (array-level) or the Dataset (group-level), with array-level attributes taking precedence. The convention must be registered in the `zarr_conventions` attribute. When present, `spatial:dimensions` is used to auto-detect the spatial dimension names, so non-standard names work without passing `x_dim`/`y_dim`; the CRS is detected from the companion [proj: convention](https://github.com/zarr-conventions/geo-proj) if present: + +```{code-cell} +ds = xr.Dataset( + {"temperature": (("Y", "X"), np.random.rand(100, 100))}, + attrs={ + "zarr_conventions": [{"name": "spatial:"}, {"name": "proj:"}], + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [10.0, 0.0, 400000.0, 0.0, -10.0, 5000000.0], + "proj:code": "EPSG:32610", + }, +) + +ds = rasterix.assign_index(ds) +ds +``` + ## Direct construction with class methods For more control, you can create a {py:class}`RasterIndex` directly using class methods. This is useful when you have the transform parameters available directly, or when working with data that doesn't have embedded metadata. diff --git a/docs/raster_index/heuristics.md b/docs/raster_index/heuristics.md index 6fbe197..07a9f56 100644 --- a/docs/raster_index/heuristics.md +++ b/docs/raster_index/heuristics.md @@ -7,7 +7,7 @@ | 1 | `GeoTransform` | CF grid mapping variable (e.g. 
`spatial_ref`) | [GDAL GeoTransform](https://gdal.org/en/stable/tutorials/geotransforms_tut.html) | | 2 | `proj:transform` | DataArray `.attrs` | [STAC Projection Extension](https://github.com/stac-extensions/projection) | | 3 | `model_tiepoint` + `model_pixel_scale` | DataArray `.attrs` | [GeoTIFF spec](https://docs.ogc.org/is/19-008r4/19-008r4.html) | -| 4 | `spatial:transform` | DataArray `.attrs` | [Zarr Spatial Convention](https://zarr-specs.readthedocs.io/en/latest/v3/conventions/spatial/v1.0.html) | +| 4 | `spatial:transform` | DataArray, data variable, or Dataset `.attrs` | [Zarr Spatial Convention](https://zarr-specs.readthedocs.io/en/latest/v3/conventions/spatial/v1.0.html) | | 5 | 1D coordinate arrays | Coordinate variables for x/y dims | Common in NetCDF | ## Grid mapping variable lookup @@ -18,6 +18,12 @@ For priority 1, the grid mapping variable is found following CF conventions: - **Dataset**: uses the first `grid_mapping` attribute found across data variables - **Fallback**: a coordinate variable named `spatial_ref` +## Zarr spatial convention lookup + +For priority 4, the convention must be registered in the `zarr_conventions` attribute (by name `spatial:` or by UUID); bare `spatial:` attributes are ignored otherwise. Attributes are looked up on the DataArray for DataArrays, and on each data variable then the Dataset itself for Datasets — array-level attributes take precedence over group-level ones, per the convention. + +If a `spatial:dimensions` attribute is present, it is also used to auto-detect the x/y dimension names when `x_dim`/`y_dim` are not passed to {py:func}`~rasterix.assign_index`. The listed names are interpreted as `[y, x]`, following the convention's examples; `spatial:transform` uses standard [Affine](https://affine.readthedocs.io/en/latest/) coefficient ordering `[a, b, c, d, e, f]`, mapping `(column, row) -> (x, y)`. + ## Coordinate array fallback For priority 5, coordinate variables must be 1D with at least 2 values. 
Pixel spacing is computed as `x[1] - x[0]` and `y[1] - y[0]`, and coordinates are assumed to be pixel-centered. diff --git a/src/rasterix/lib.py b/src/rasterix/lib.py index a3b0e34..bbeebc5 100644 --- a/src/rasterix/lib.py +++ b/src/rasterix/lib.py @@ -1,6 +1,7 @@ """Shared library utilities for rasterix.""" import logging +from collections.abc import Mapping from typing import NotRequired, TypedDict from affine import Affine @@ -8,6 +9,9 @@ # https://github.com/zarr-conventions/spatial _ZARR_SPATIAL_CONVENTION_UUID = "689b58e2-cf7b-45e0-9fff-9cfc0883d6b4" +# https://github.com/zarr-conventions/geo-proj +_ZARR_GEO_PROJ_CONVENTION_UUID = "f17cb550-5864-4468-aeb7-f3180cfb622f" + # Define TRACE level (lower than DEBUG) TRACE = 5 @@ -117,6 +121,7 @@ def affine_from_stac_proj_metadata(metadata: dict) -> Affine | None: "_ZarrSpatialMetadata", { "zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]], + "spatial:dimensions": NotRequired[list[str]], "spatial:transform": NotRequired[list[float]], "spatial:transform_type": NotRequired[str], "spatial:registration": NotRequired[str], @@ -124,18 +129,36 @@ def affine_from_stac_proj_metadata(metadata: dict) -> Affine | None: ) -def _has_spatial_zarr_convention(metadata: _ZarrSpatialMetadata) -> bool: +_ZarrProjMetadata = TypedDict( + "_ZarrProjMetadata", + { + "zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]], + "proj:code": NotRequired[str], + "proj:wkt2": NotRequired[str], + "proj:projjson": NotRequired[object], + }, +) + + +def _has_zarr_convention(metadata: Mapping, *, uuid: str, name: str) -> bool: + """Check whether a Zarr convention is registered in the ``zarr_conventions`` attribute.""" zarr_conventions = metadata.get("zarr_conventions") if not zarr_conventions: return False for entry in zarr_conventions: - if isinstance(entry, dict) and ( - entry.get("uuid") == _ZARR_SPATIAL_CONVENTION_UUID or entry.get("name") == "spatial:" - ): + if isinstance(entry, dict) and 
(entry.get("uuid") == uuid or entry.get("name") == name): return True return False +def _has_spatial_zarr_convention(metadata: _ZarrSpatialMetadata) -> bool: + return _has_zarr_convention(metadata, uuid=_ZARR_SPATIAL_CONVENTION_UUID, name="spatial:") + + +def _has_proj_zarr_convention(metadata: _ZarrProjMetadata) -> bool: + return _has_zarr_convention(metadata, uuid=_ZARR_GEO_PROJ_CONVENTION_UUID, name="proj:") + + def affine_from_spatial_zarr_convention(metadata: dict) -> Affine | None: """Extract Affine transform from Zarr spatial convention metadata. @@ -178,3 +201,32 @@ def affine_from_spatial_zarr_convention(metadata: dict) -> Affine | None: return Affine(*map(float, transform[:6])) return None + + +def spatial_dims_from_zarr_convention(metadata: dict) -> tuple[str, str] | None: + """Extract spatial dimension names from Zarr spatial convention metadata. + + See https://github.com/zarr-conventions/spatial for the full specification. + + Parameters + ---------- + metadata : dict + Dictionary containing Zarr spatial convention metadata. + + Returns + ------- + (x_dim, y_dim) or None + Dimension names from ``spatial:dimensions``, interpreted as ``[y, x]`` + following the convention's examples. None if the convention is not + registered or ``spatial:dimensions`` is absent. 
+ """ + possibly_spatial_metadata: _ZarrSpatialMetadata = metadata # type: ignore[assignment] + + if _has_spatial_zarr_convention(possibly_spatial_metadata): + if dims := possibly_spatial_metadata.get("spatial:dimensions"): + if len(dims) != 2: + raise ValueError(f"spatial:dimensions must have exactly 2 elements, got {len(dims)}") + y_dim, x_dim = map(str, dims) + return x_dim, y_dim + + return None diff --git a/src/rasterix/raster_index.py b/src/rasterix/raster_index.py index 2140d3b..fbd3792 100644 --- a/src/rasterix/raster_index.py +++ b/src/rasterix/raster_index.py @@ -19,10 +19,15 @@ from xarray.core.types import JoinOptions from xproj.typing import CRSAwareIndex +from rasterix.lib import spatial_dims_from_zarr_convention from rasterix.odc_compat import BoundingBox, bbox_intersection, bbox_union, maybe_int, snap_grid from rasterix.options import get_options as get_rasterix_options from rasterix.rioxarray_compat import guess_dims -from rasterix.utils import get_affine, get_crs_from_proj_zarr_convention +from rasterix.utils import ( + _iter_spatial_zarr_metadata, + get_affine, + get_crs_from_proj_zarr_convention, +) T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") @@ -48,9 +53,13 @@ def assign_index( obj : xarray.DataArray or xarray.Dataset The object to assign the index to. x_dim : str, optional - Name of the x dimension. If None, will be automatically detected. + Name of the x dimension. If None, will be automatically detected from + Zarr ``spatial:dimensions`` convention metadata if present, else from + common dimension names and CF attributes. y_dim : str, optional - Name of the y dimension. If None, will be automatically detected. + Name of the y dimension. If None, will be automatically detected from + Zarr ``spatial:dimensions`` convention metadata if present, else from + common dimension names and CF attributes. crs: bool, optional Auto-detect CRS using xproj? 
def _iter_spatial_zarr_metadata(
    obj: xr.Dataset | xr.DataArray,
) -> Iterator[tuple[dict[str, Any], tuple[dict[str, Any], ...]]]:
    """Iterate over candidate attribute dicts for the Zarr ``spatial:`` convention.

    Each item is a ``(metadata, sources)`` pair. ``metadata`` is the effective
    mapping to interpret; ``sources`` holds the live attrs dicts it was built
    from, so a caller can remove consumed keys from them.

    For a DataArray the only candidate is its own attrs. For a Dataset, every
    data variable contributes its attrs layered over the Dataset attrs (the
    variable's keys win, i.e. array-level beats group-level), and the bare
    Dataset attrs are yielded last.
    """
    if not isinstance(obj, xr.DataArray):
        for data_var in obj.data_vars.values():
            # Later keys override earlier ones: variable attrs shadow group attrs.
            layered = {**obj.attrs, **data_var.attrs}
            yield layered, (data_var.attrs, obj.attrs)
    # Shared tail: the object's own attrs (the sole candidate for a DataArray,
    # the group-level fallback for a Dataset).
    yield obj.attrs, (obj.attrs,)
NotRequired[list[_ZarrConventionRegistration | dict]], - "proj:code": NotRequired[str], - "proj:wkt2": NotRequired[str], - "proj:projjson": NotRequired[object], - }, -) - - -def _has_proj_zarr_convention(metadata: _ZarrProjMetadata) -> bool: - zarr_conventions = metadata.get("zarr_conventions") - if not zarr_conventions: - return False - for entry in zarr_conventions: - if isinstance(entry, dict) and ( - entry.get("uuid") == _ZARR_GEO_PROJ_CONVENTION_UUID or entry.get("name") == "proj:" - ): - return True - return False - - def get_crs_from_proj_zarr_convention(obj: xr.Dataset | xr.DataArray) -> CRS | None: """Extract CRS from Zarr proj: convention metadata if present. diff --git a/tests/test_raster_index.py b/tests/test_raster_index.py index 9443faf..4973653 100644 --- a/tests/test_raster_index.py +++ b/tests/test_raster_index.py @@ -676,6 +676,121 @@ def test_assign_index_with_spatial_zarr_convention_registration_not_implemented( assign_index(da) +def test_assign_index_with_spatial_zarr_convention_dimensions(): + # spatial:dimensions is interpreted as [y, x], following the convention's examples + da = xr.DataArray( + np.ones((10, 12)), + dims=("Y", "X"), + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + + result = assign_index(da) + + assert isinstance(result.xindexes["X"], RasterIndex) + assert isinstance(result.xindexes["Y"], RasterIndex) + assert result.xindexes["X"].transform() == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + assert result.sizes == {"Y": 10, "X": 12} + + +def test_assign_index_with_spatial_zarr_convention_dataset_group_attrs(): + ds = xr.Dataset( + {"data": (("Y", "X"), np.ones((10, 12)))}, + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + + result = assign_index(ds) + + assert isinstance(result.xindexes["X"], RasterIndex) + assert 
isinstance(result.xindexes["Y"], RasterIndex) + assert result.xindexes["X"].transform() == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + assert "spatial:transform" not in result.attrs + + +def test_assign_index_with_spatial_zarr_convention_variable_attrs(): + ds = xr.Dataset( + { + "data": ( + ("Y", "X"), + np.ones((10, 12)), + { + "zarr_conventions": [{"name": "spatial:"}], + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + } + ) + + result = assign_index(ds) + + assert isinstance(result.xindexes["X"], RasterIndex) + assert isinstance(result.xindexes["Y"], RasterIndex) + assert result.xindexes["X"].transform() == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + assert "spatial:transform" not in result["data"].attrs + + +def test_assign_index_with_spatial_zarr_convention_variable_attrs_group_registration(): + # group-level registration applies to child arrays that don't define their own + ds = xr.Dataset( + { + "data": ( + ("Y", "X"), + np.ones((10, 12)), + { + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + }, + attrs={"zarr_conventions": [{"name": "spatial:"}]}, + ) + + result = assign_index(ds) + + assert isinstance(result.xindexes["X"], RasterIndex) + assert isinstance(result.xindexes["Y"], RasterIndex) + assert result.xindexes["X"].transform() == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + assert "spatial:transform" not in result["data"].attrs + + +def test_assign_index_with_spatial_zarr_convention_unregistered_ignored(): + # without the zarr_conventions registration, spatial: attrs are ignored + ds = xr.Dataset( + {"data": (("Y", "X"), np.ones((10, 12)))}, + attrs={ + "spatial:dimensions": ["Y", "X"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + + with pytest.raises(ValueError, match="do not have explicit coordinate values"): + # spatial:transform is ignored, so this falls through to the coordinate fallback + assign_index(ds, x_dim="X", y_dim="Y") 
+ + +def test_assign_index_with_spatial_zarr_convention_dimensions_mismatch(): + da = xr.DataArray( + np.ones((10, 12)), + dims=("Y", "X"), + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:dimensions": ["lat", "lon"], + "spatial:transform": [1.0, 0.0, 0.0, 0.0, -1.0, 10.0], + }, + ) + + with pytest.raises(ValueError, match="spatial:dimensions"): + assign_index(da) + + def test_assign_index_no_coords_no_metadata(): """Test that assign_index raises error when coords are missing and no transform metadata.""" da = xr.DataArray(np.ones((10, 10)), dims=("y", "x"))