diff --git a/pyproject.toml b/pyproject.toml
index 793ddac..260af01 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,10 @@ optional-dependencies.mixmil = [
 ]
 optional-dependencies.ml = [
   "pytorch-lightning",
+  "cellink[annbatch]",
+]
+optional-dependencies.annbatch = [
+  "annbatch[torch,zarrs]>=0.1.2; python_version >= '3.12'",
 ]
 optional-dependencies.rvat = [
   # installation wit pip doesn't work, install via conda install -c conda-forge chiscore
diff --git a/src/cellink/io/__init__.py b/src/cellink/io/__init__.py
index 4502c3d..8a8b79c 100644
--- a/src/cellink/io/__init__.py
+++ b/src/cellink/io/__init__.py
@@ -1,4 +1,55 @@
+import importlib
+from typing import Any
+
 from ._export import to_plink, write_variants_to_vcf
 from ._readwrite import read_dd, read_h5_dd, read_zarr_dd
 from ._sgkit import from_sgkit_dataset, read_bgen, read_plink, read_sgkit_zarr
 from ._pgen import stream_pgen_to_zarr, read_pgen_zarr
+
+# Lazy re-exports for the optional `annbatch` extra. We don't want importing
+# `cellink.io` to fail when annbatch isn't installed -- only the relevant
+# attribute access should error, with a clear hint pointing at the extra.
+_annbatch_exports = {
+    "write_annbatch_collection": "_annbatch",
+    "open_annbatch_loader": "_annbatch",
+}
+
+__all__ = [
+    "to_plink",
+    "write_variants_to_vcf",
+    "read_dd",
+    "read_h5_dd",
+    "read_zarr_dd",
+    "from_sgkit_dataset",
+    "read_bgen",
+    "read_plink",
+    "read_sgkit_zarr",
+    "stream_pgen_to_zarr",
+    "read_pgen_zarr",
+    *_annbatch_exports,
+]
+
+
+def __getattr__(name: str) -> Any:
+    """Lazy import for optional-extra symbols in `cellink.io`.
+
+    Currently used for the `annbatch` extra: we only attempt the import when
+    `cellink.io.write_annbatch_collection` (or sibling) is actually accessed,
+    and surface a clear ImportError pointing at the extra otherwise.
+    """
+    if name in _annbatch_exports:
+        module_name = _annbatch_exports[name]
+        try:
+            module = importlib.import_module(f"{__name__}.{module_name}")
+        except ImportError as e:
+            raise ImportError(
+                f"Cannot import `{name}` from `cellink.io.{module_name}`: "
+                "this feature requires the `annbatch` extra. Install with:\n\n"
+                " pip install cellink[annbatch]"
+            ) from e
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__() -> list[str]:
+    return sorted(set(__all__))
diff --git a/src/cellink/io/_annbatch.py b/src/cellink/io/_annbatch.py
new file mode 100644
index 0000000..7e171bc
--- /dev/null
+++ b/src/cellink/io/_annbatch.py
@@ -0,0 +1,298 @@
+"""IO helpers for streaming cell-level data via the `annbatch` package.
+
+This module exposes two thin wrappers around `annbatch` that make it easy to
+materialize the cell-level AnnData inside a `DonorData` (i.e. ``dd.C``) as a
+sharded zarr collection, and to open such a collection as a configured
+``annbatch.Loader``.
+
+The donor-side AnnData (``dd.G``) is intentionally not handled here -- its
+``obs`` axis is small (donors) and streaming is the wrong tool.
+
+The optional dependency stack is installable via:
+
+    pip install cellink[annbatch]
+"""
+
+from __future__ import annotations
+
+import shutil
+import uuid
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+# The `annbatch` extra is optional. Importing this module without it should
+# raise a clear, actionable error rather than a confusing ModuleNotFoundError
+# from deep inside a dependency.
+try:
+    import anndata
+    import zarr
+    from annbatch import DatasetCollection, Loader
+except ImportError as e:  # pragma: no cover - exercised only without the extra
+    raise ImportError(
+        "Cannot import `cellink.io._annbatch`: this feature requires the "
+        "`annbatch` extra. Install with:\n\n pip install cellink[annbatch]"
+    ) from e
+
+from anndata import AnnData
+
+from cellink._core import DonorData
+
+if TYPE_CHECKING:
+    from os import PathLike
+
+
+__all__ = ["write_annbatch_collection", "open_annbatch_loader"]
+
+
+def _configure_zarrs_codec_pipeline() -> None:
+    # Recommended by annbatch docs for performance on zarr v3 stores. Setting
+    # it inside our wrappers means callers don't have to remember.
+    zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
+
+
+def _materialize_to_h5ad(adata: AnnData, tmp_dir: Path) -> Path:
+    """Write ``adata`` to a uniquely named h5ad inside ``tmp_dir`` and return its path."""
+    tmp_dir.mkdir(parents=True, exist_ok=True)
+    h5_path = tmp_dir / f"{uuid.uuid4().hex}.h5ad"
+    adata.write_h5ad(h5_path)
+    return h5_path
+
+
+def _resolve_source_paths(
+    source: AnnData | DonorData | str | Path | Sequence[str | Path],
+    tmp_dir: Path,
+    *,
+    layer: str | None,
+) -> tuple[list[Path], bool]:
+    """Resolve ``source`` to a list of on-disk h5ad/zarr paths.
+
+    Returns ``(paths, materialized)`` where ``materialized`` is True when we
+    wrote a temp h5ad that the caller is responsible for cleaning up.
+    """
+    if isinstance(source, AnnData):
+        if layer is not None and layer not in source.layers:
+            raise KeyError(f"Layer {layer!r} not present in source AnnData")
+        return [_materialize_to_h5ad(source, tmp_dir)], True
+    if isinstance(source, DonorData):
+        adata = source.C
+        if not isinstance(adata, AnnData):
+            raise TypeError(
+                "DonorData.C must be an AnnData for annbatch IO; "
+                f"got {type(adata).__name__}"
+            )
+        if layer is not None and layer not in adata.layers:
+            raise KeyError(f"Layer {layer!r} not present in DonorData.C.layers")
+        return [_materialize_to_h5ad(adata, tmp_dir)], True
+    if isinstance(source, str | Path):
+        return [Path(source)], False
+    # Assume an iterable of paths.
+    paths = [Path(p) for p in source]
+    if not paths:
+        raise ValueError("`source` is an empty sequence of paths")
+    return paths, False
+
+
+def write_annbatch_collection(
+    source: AnnData | DonorData | str | Path | Sequence[str | Path],
+    path: str | PathLike[str],
+    *,
+    obs_keys: Sequence[str],
+    layer: str | None = None,
+    shuffle: bool = True,
+    rng: np.random.Generator | None = None,
+    **add_adatas_kwargs: Any,
+) -> Path:
+    """Write a sharded zarr `annbatch` collection from cell-level data.
+
+    Parameters
+    ----------
+    source
+        Cell-level data to stream into the collection. May be:
+
+        - an :class:`anndata.AnnData` (will first be written to a temp h5ad),
+        - a :class:`cellink.DonorData` (uses ``source.C``; same temp-write
+          fallback),
+        - a path or list of paths to existing h5ad/zarr files.
+    path
+        Directory path for the output zarr collection.
+    obs_keys
+        ``obs`` columns the future loader will need to expose. Validated to be
+        non-empty so callers don't accidentally drop the donor id column at
+        read time. Filtering itself happens at read time in
+        :func:`open_annbatch_loader` via the ``load_adata`` closure -- nothing
+        is dropped at write time.
+    layer
+        Name of a layer to expose to the loader as ``X``. ``None`` means the
+        loader will read ``X`` directly. When ``source`` is in-memory data,
+        the layer is validated for presence; when ``source`` is a path, no
+        validation is performed.
+    shuffle
+        Forwarded to :meth:`annbatch.DatasetCollection.add_adatas`. Default
+        ``True`` produces a globally-shuffled collection, which is usually
+        what you want for IID per-cell mini-batching.
+    rng
+        Random number generator used for shuffling at write time. Pass a
+        seeded :class:`numpy.random.Generator` for reproducible collections.
+    **add_adatas_kwargs
+        Forwarded verbatim to :meth:`annbatch.DatasetCollection.add_adatas`
+        (e.g. ``shard_size``, ``dataset_size``, ``shuffle_chunk_size``,
+        ``var_subset``, ``zarr_compressor``).
+
+    Returns
+    -------
+    Path
+        Path to the created collection directory.
+
+    Raises
+    ------
+    ValueError
+        If ``obs_keys`` is empty or ``source`` is an empty sequence of paths.
+    KeyError
+        If ``layer`` is given but absent from an in-memory ``source``.
+    TypeError
+        If ``source`` is a ``DonorData`` whose ``C`` is not an ``AnnData``.
+    """
+    if not obs_keys:
+        raise ValueError("`obs_keys` must be a non-empty sequence")
+
+    _configure_zarrs_codec_pipeline()
+
+    out_path = Path(path)
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Place temp inputs alongside the output so they share a filesystem (cheap
+    # rename / hard-linkable) and any stale residue is easy to spot.
+    tmp_dir = out_path.parent / f".{out_path.name}.tmp_inputs"
+    # Only in-memory sources get staged into ``tmp_dir``. Decide ownership
+    # before resolving so that a failed temp write is cleaned up as well.
+    owns_tmp_dir = isinstance(source, (AnnData, DonorData))
+
+    try:
+        paths, _ = _resolve_source_paths(source, tmp_dir, layer=layer)
+        collection = DatasetCollection(str(out_path))
+        collection.add_adatas(
+            adata_paths=[str(p) for p in paths],
+            shuffle=shuffle,
+            rng=rng,
+            **add_adatas_kwargs,
+        )
+    finally:
+        if owns_tmp_dir:
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
+    return out_path
+
+
+def _make_load_adata(
+    obs_keys: Sequence[str],
+    layer: str | None,
+) -> Callable[[Any], AnnData]:
+    """Build a ``load_adata`` closure for ``Loader.use_collection``.
+
+    The closure restricts ``obs`` to ``obs_keys`` and selects ``X`` or
+    ``layers[layer]`` as the feature matrix. Restricting ``obs`` is important
+    for performance -- the default loader pulls every column.
+    """
+    obs_keys = list(obs_keys)
+
+    def _load_adata(g: Any) -> AnnData:
+        if layer is None:
+            x_node = g["X"]
+        else:
+            x_node = g["layers"][layer]
+        # Sparse vs dense is auto-detected: zarr.Array == dense; anything else
+        # is a group encoding a sparse matrix and needs `sparse_dataset`.
+        if isinstance(x_node, zarr.Array):
+            X = x_node
+        else:
+            X = anndata.io.sparse_dataset(x_node)
+        obs = anndata.io.read_elem(g["obs"])
+        missing = [k for k in obs_keys if k not in obs.columns]
+        if missing:
+            raise KeyError(
+                f"obs_keys {missing!r} not found in on-disk obs (have: "
+                f"{list(obs.columns)})"
+            )
+        return AnnData(X=X, obs=obs[obs_keys])
+
+    return _load_adata
+
+
+def open_annbatch_loader(
+    path: str | PathLike[str],
+    *,
+    obs_keys: Sequence[str],
+    layer: str | None = None,
+    batch_size: int = 4096,
+    chunk_size: int = 32,
+    preload_nchunks: int = 256,
+    to_torch: bool = True,
+    rng: np.random.Generator | None = None,
+    **loader_kwargs: Any,
+) -> Loader:
+    """Open an annbatch collection at ``path`` and return a configured loader.
+
+    The returned :class:`annbatch.Loader` is bound to the collection via a
+    ``load_adata`` closure that:
+
+    - exposes ``X`` (if ``layer is None``) or ``layers[layer]`` as the feature
+      matrix, and
+    - subsets ``obs`` to ``obs_keys``.
+
+    Parameters
+    ----------
+    path
+        Path to a collection produced by :func:`write_annbatch_collection`.
+    obs_keys
+        ``obs`` columns to surface in each yielded batch.
+    layer
+        Layer name to read; ``None`` => ``X``.
+    batch_size, chunk_size, preload_nchunks, to_torch
+        Forwarded to :class:`annbatch.Loader`. The defaults match the values
+        suggested in the annbatch docs for cell-level IID training.
+    rng
+        Random number generator used for the loader's shuffling. Pass a
+        seeded :class:`numpy.random.Generator` for reproducible iteration
+        order. Mutually exclusive with a custom ``batch_sampler`` (which
+        would carry its own rng).
+    **loader_kwargs
+        Forwarded verbatim to :class:`annbatch.Loader` (e.g. ``shuffle``,
+        ``drop_last``, ``return_index``, ``preload_to_gpu``,
+        ``concat_strategy``, ``batch_sampler``).
+
+    Returns
+    -------
+    annbatch.Loader
+        A loader ready to iterate. Callers can wrap it in a
+        :class:`torch.utils.data.DataLoader` if multi-worker IO is desired.
+
+    Raises
+    ------
+    ValueError
+        If ``obs_keys`` is empty.
+    FileNotFoundError
+        If ``path`` does not exist.
+    """
+    if not obs_keys:
+        raise ValueError("`obs_keys` must be a non-empty sequence")
+
+    _configure_zarrs_codec_pipeline()
+
+    coll_path = Path(path)
+    if not coll_path.exists():
+        raise FileNotFoundError(f"annbatch collection not found at {coll_path}")
+
+    collection = DatasetCollection(str(coll_path), mode="r")
+
+    loader = Loader(
+        batch_size=batch_size,
+        chunk_size=chunk_size,
+        preload_nchunks=preload_nchunks,
+        to_torch=to_torch,
+        rng=rng,
+        **loader_kwargs,
+    )
+    return loader.use_collection(collection, load_adata=_make_load_adata(obs_keys, layer))
diff --git a/tests/test_annbatch_io.py b/tests/test_annbatch_io.py
new file mode 100644
index 0000000..08bc833
--- /dev/null
+++ b/tests/test_annbatch_io.py
@@ -0,0 +1,188 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+import scipy.sparse as sp
+from anndata import AnnData
+
+# Skip the entire module if the optional extra isn't installed.
+pytest.importorskip("annbatch", reason="needs cellink[annbatch] extra")
+
+from cellink import DonorData
+from cellink._core.dummy_data import sim_adata, sim_gdata
+from cellink.io import open_annbatch_loader, write_annbatch_collection
+
+
+def _make_cells(n_obs: int = 200, n_vars: int = 50, *, with_layer: bool = False, seed: int = 0) -> AnnData:
+    X = sp.random(n_obs, n_vars, density=0.2, format="csr", random_state=seed).astype(np.float32)
+    obs = pd.DataFrame(
+        {
+            "donor_id": [f"D{i % 5}" for i in range(n_obs)],
+            "cell_label": [f"L{i % 3}" for i in range(n_obs)],
+        },
+        index=[f"C{i}" for i in range(n_obs)],
+    )
+    var = pd.DataFrame(index=[f"G{i}" for i in range(n_vars)])
+    adata = AnnData(X=X, obs=obs, var=var)
+    if with_layer:
+        # Use distinctive values so we can tell layer from X by content.
+        counts = sp.random(n_obs, n_vars, density=0.3, format="csr", random_state=seed + 1).astype(np.float32)
+        counts.data += 100.0  # make values clearly larger than X
+        adata.layers["counts"] = counts
+    return adata
+
+
+def _unpack_batch(batch):
+    """Return (X, obs) from an annbatch LoaderOutput (a TypedDict)."""
+    return batch["X"], batch["obs"]
+
+
+def _iter_loader(loader):
+    """Iterate the loader and collect (obs frames, X-row counts, n_cols)."""
+    obs_frames: list[pd.DataFrame] = []
+    x_rows: list[int] = []
+    x_cols: int | None = None
+    for batch in loader:
+        x, obs = _unpack_batch(batch)
+        n_rows = int(x.shape[0])
+        x_rows.append(n_rows)
+        if x_cols is None:
+            x_cols = int(x.shape[1])
+        else:
+            assert int(x.shape[1]) == x_cols
+        obs_frames.append(obs if isinstance(obs, pd.DataFrame) else pd.DataFrame(obs))
+    return obs_frames, x_rows, x_cols
+
+
+def test_roundtrip_anndata(tmp_path: Path) -> None:
+    adata = _make_cells(n_obs=200, n_vars=50)
+    out = tmp_path / "collection.zarr"
+    written = write_annbatch_collection(
+        adata,
+        out,
+        obs_keys=["donor_id", "cell_label"],
+    )
+    assert Path(written).exists()
+    assert Path(written) == out
+
+    loader = open_annbatch_loader(
+        out,
+        obs_keys=["donor_id", "cell_label"],
+        batch_size=64,
+        chunk_size=16,
+        preload_nchunks=8,
+        to_torch=False,
+        shuffle=False,
+    )
+    obs_frames, x_rows, x_cols = _iter_loader(loader)
+
+    assert sum(x_rows) == adata.n_obs
+    assert all(r <= 64 for r in x_rows)
+    assert x_cols == adata.n_vars
+
+    obs_concat = pd.concat(obs_frames, ignore_index=True)
+    assert list(obs_concat.columns) == ["donor_id", "cell_label"]
+    assert len(obs_concat) == adata.n_obs
+
+
+def test_roundtrip_donordata(tmp_path: Path) -> None:
+    adata = sim_adata()
+    gdata = sim_gdata(adata=adata)
+    dd = DonorData(G=gdata, C=adata)
+
+    out = tmp_path / "collection.zarr"
+    write_annbatch_collection(dd, out, obs_keys=["donor_id", "celltype"])
+
+    loader = open_annbatch_loader(
+        out,
+        obs_keys=["donor_id", "celltype"],
+        batch_size=32,
+        chunk_size=8,
+        preload_nchunks=4,
+        to_torch=False,
+        shuffle=False,
+    )
+    obs_frames, x_rows, _ = _iter_loader(loader)
+
+    assert sum(x_rows) == dd.C.n_obs
+
+    obs_concat = pd.concat(obs_frames, ignore_index=True)
+    seen_donors = set(obs_concat["donor_id"].astype(str).unique())
+    known_donors = set(dd.G.obs_names.astype(str))
+    assert seen_donors.issubset(known_donors)
+
+
+def test_layer_selection(tmp_path: Path) -> None:
+    adata = _make_cells(n_obs=200, n_vars=50, with_layer=True)
+    out = tmp_path / "collection.zarr"
+    write_annbatch_collection(
+        adata,
+        out,
+        obs_keys=["donor_id", "cell_label"],
+        layer="counts",
+    )
+
+    loader = open_annbatch_loader(
+        out,
+        obs_keys=["donor_id", "cell_label"],
+        layer="counts",
+        batch_size=64,
+        chunk_size=16,
+        preload_nchunks=8,
+        to_torch=False,
+        shuffle=False,
+    )
+    _, _, _ = _iter_loader(loader)
+
+    # The stored layer values had +100.0 added; X values do not. So checking
+    # the maximum across all batches discriminates between X and layers["counts"].
+    loader2 = open_annbatch_loader(
+        out,
+        obs_keys=["donor_id", "cell_label"],
+        layer="counts",
+        batch_size=64,
+        chunk_size=16,
+        preload_nchunks=8,
+        to_torch=False,
+        shuffle=False,
+    )
+    max_layer_val = 0.0
+    for batch in loader2:
+        x, _ = _unpack_batch(batch)
+        arr = x.toarray() if hasattr(x, "toarray") else np.asarray(x)
+        if arr.size and arr.max() > max_layer_val:
+            max_layer_val = float(arr.max())
+    assert max_layer_val >= 100.0, "Expected layer values to exceed 100.0; got X instead?"
+
+
+def test_temp_inputs_removed_when_materialization_fails(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """A failed temp-h5ad write must not leave the staging directory behind."""
+    adata = _make_cells(n_obs=20, n_vars=5)
+    out = tmp_path / "collection.zarr"
+
+    def _fail_write(self, *args, **kwargs):
+        raise OSError("simulated write failure")
+
+    monkeypatch.setattr(AnnData, "write_h5ad", _fail_write)
+    with pytest.raises(OSError, match="simulated write failure"):
+        write_annbatch_collection(adata, out, obs_keys=["donor_id"])
+
+    assert not (tmp_path / ".collection.zarr.tmp_inputs").exists()
+
+
+def test_missing_extra_error_message() -> None:
+    """If `annbatch` is missing, accessing the symbol gives a clear hint.
+
+    The whole module is gated by ``importorskip("annbatch")`` at the top, so
+    when the extra is installed we can only sanity-check that the error path
+    in ``cellink.io.__init__`` references the install hint.
+    """
+    import cellink.io as cio
+
+    src = Path(cio.__file__).read_text()
+    assert "pip install cellink[annbatch]" in src