Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ optional-dependencies.mixmil = [
]
optional-dependencies.ml = [
"pytorch-lightning",
"cellink[annbatch]",

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this for convinience only. we can remove it if you want

]
optional-dependencies.annbatch = [
"annbatch[torch,zarrs]>=0.1.2; python_version >= '3.12'",
]
optional-dependencies.rvat = [
# installation wit pip doesn't work, install via conda install -c conda-forge chiscore
Expand Down
51 changes: 51 additions & 0 deletions src/cellink/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,55 @@
import importlib
from typing import Any

from ._export import to_plink, write_variants_to_vcf
from ._readwrite import read_dd, read_h5_dd, read_zarr_dd
from ._sgkit import from_sgkit_dataset, read_bgen, read_plink, read_sgkit_zarr
from ._pgen import stream_pgen_to_zarr, read_pgen_zarr

# Lazy re-exports for the optional `annbatch` extra. We don't want importing
# `cellink.io` to fail when annbatch isn't installed -- only the relevant
# attribute access should error, with a clear hint pointing at the extra.
_annbatch_exports = {
"write_annbatch_collection": "_annbatch",
"open_annbatch_loader": "_annbatch",
}
Comment on lines +9 to +15

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just use https://docs.python.org/3/library/importlib.html#importlib.abc.MetaPathFinder.find_spec to see if annbatch is installed, and then choose to import from _annbatch.py or not?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, honestly I don't know why the LLM's kept suggesting me this in all cases and I went to a rabbithole for this. Since it was already in the codebase I went with it. Even though I know that it was likely an LLM suggestion. But I also see this in other places as well for example in Tim's PR and in spatialdata. Which are probably also LLM suggestions lol. I didn't see any big packages using __getattr__ for lazy imports I think. For example it isn't in xarray.

I am not really sure about having two different styles in the codebase itself tbh. Also my hunch would be to add for the two functions to match xarray I guess? But maybe that's an overkill given that annbatch is lightweight? idk python import times are anoyying sometimes

    try:
        import zarr
        from annbatch import DatasetCollection
    except ImportError as e:
        raise ImportError(_INSTALL_HINT) from e

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get wanting to stay consistent, but I also think it should be changed in Tim's PR FWIW. Not saying this has to happen here or now, but I prefer find_spec


__all__ = [
"to_plink",
"write_variants_to_vcf",
"read_dd",
"read_h5_dd",
"read_zarr_dd",
"from_sgkit_dataset",
"read_bgen",
"read_plink",
"read_sgkit_zarr",
"stream_pgen_to_zarr",
"read_pgen_zarr",
*_annbatch_exports,
]


def __getattr__(name: str) -> Any:
"""Lazy import for optional-extra symbols in `cellink.io`.

Currently used for the `annbatch` extra: we only attempt the import when
`cellink.io.write_annbatch_collection` (or sibling) is actually accessed,
and surface a clear ImportError pointing at the extra otherwise.
"""
if name in _annbatch_exports:
module_name = _annbatch_exports[name]
try:
module = importlib.import_module(f"{__name__}.{module_name}")
except ImportError as e:
raise ImportError(
f"Cannot import `{name}` from `cellink.io.{module_name}`: "
"this feature requires the `annbatch` extra. Install with:\n\n"
" pip install cellink[annbatch]"
) from e
return getattr(module, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> list[str]:
return sorted(set(__all__))
278 changes: 278 additions & 0 deletions src/cellink/io/_annbatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
"""IO helpers for streaming cell-level data via the `annbatch` package.

This module exposes two thin wrappers around `annbatch` that make it easy to
materialize the cell-level AnnData inside a `DonorData` (i.e. ``dd.C``) as a
sharded zarr collection, and to open such a collection as a configured
``annbatch.Loader``.

The donor-side AnnData (``dd.G``) is intentionally not handled here -- its

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an on-disk format for celllink @LArnoldt ? I don't see anything in https://cellink-docs.readthedocs.io/en/latest/api/io.html or is it just in-memory?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

``obs`` axis is small (donors) and streaming is the wrong tool.

The optional dependency stack is installable via:

pip install cellink[annbatch]
"""

from __future__ import annotations

import shutil
import uuid
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np

# The `annbatch` extra is optional. Importing this module without it should
# raise a clear, actionable error rather than a confusing ModuleNotFoundError
# from deep inside a dependency.
try:
import anndata
import zarr
from annbatch import DatasetCollection, Loader
except ImportError as e: # pragma: no cover - exercised only without the extra
raise ImportError(
"Cannot import `cellink.io._annbatch`: this feature requires the "
"`annbatch` extra. Install with:\n\n pip install cellink[annbatch]"
) from e

from anndata import AnnData

from cellink._core import DonorData

if TYPE_CHECKING:
from os import PathLike


__all__ = ["write_annbatch_collection", "open_annbatch_loader"]


def _configure_zarrs_codec_pipeline() -> None:
# Recommended by annbatch docs for performance on zarr v3 stores. Setting
# it inside our wrappers means callers don't have to remember.
zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})


def _materialize_to_h5ad(adata: AnnData, tmp_dir: Path) -> Path:
tmp_dir.mkdir(parents=True, exist_ok=True)
h5_path = tmp_dir / f"{uuid.uuid4().hex}.h5ad"
adata.write_h5ad(h5_path)
return h5_path


def _resolve_source_paths(
source: AnnData | DonorData | str | Path | Sequence[str | Path],
tmp_dir: Path,
*,
layer: str | None,
) -> tuple[list[Path], bool]:
"""Resolve ``source`` to a list of on-disk h5ad/zarr paths.

Returns ``(paths, materialized)`` where ``materialized`` is True when we
wrote a temp h5ad that the caller is responsible for cleaning up.
"""
if isinstance(source, AnnData):
if layer is not None and layer not in source.layers:
raise KeyError(f"Layer {layer!r} not present in source AnnData")
return [_materialize_to_h5ad(source, tmp_dir)], True
if isinstance(source, DonorData):
adata = source.C
if not isinstance(adata, AnnData):
raise TypeError(
"DonorData.C must be an AnnData for annbatch IO; "
f"got {type(adata).__name__}"
)
if layer is not None and layer not in adata.layers:
raise KeyError(f"Layer {layer!r} not present in DonorData.C.layers")
return [_materialize_to_h5ad(adata, tmp_dir)], True
if isinstance(source, str | Path):
return [Path(source)], False
# Assume an iterable of paths.
paths = [Path(p) for p in source]
if not paths:
raise ValueError("`source` is an empty sequence of paths")
return paths, False


def write_annbatch_collection(
source: AnnData | DonorData | str | Path | Sequence[str | Path],
path: str | PathLike[str],
*,
obs_keys: Sequence[str],
layer: str | None = None,
shuffle: bool = True,
rng: np.random.Generator | None = None,
**add_adatas_kwargs: Any,
) -> Path:
"""Write a sharded zarr `annbatch` collection from cell-level data.

Parameters
----------
source
Cell-level data to stream into the collection. May be:

- an :class:`anndata.AnnData` (will first be written to a temp h5ad),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do in-memory anndata files need to be written to disk first?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the overlook. It's AI slop. Let me be more clear with @LArnoldt on the pipeline to make sure if we even need this function.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LArnoldt , how would imagine the pipeline? For annbatch preshuffling you need to create a collection but would you want this on your module level?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But not every .dd.h5 file would be in-memory no? I guess, there could be both options from written and from memory?

- a :class:`cellink.DonorData` (uses ``source.C``; same temp-write
fallback),
- a path or list of paths to existing h5ad/zarr files.
path
Directory path for the output zarr collection.
obs_keys
``obs`` columns the future loader will need to expose. Validated to be
non-empty so callers don't accidentally drop the donor id column at
read time. Filtering itself happens at read time in
:func:`open_annbatch_loader` via the ``load_adata`` closure -- nothing
is dropped at write time.
layer
Name of a layer to expose to the loader as ``X``. ``None`` means the
loader will read ``X`` directly. When ``source`` is in-memory data,
the layer is validated for presence; when ``source`` is a path, no
validation is performed.
shuffle
Forwarded to :meth:`annbatch.DatasetCollection.add_adatas`. Default
``True`` produces a globally-shuffled collection, which is usually
what you want for IID per-cell mini-batching.
rng
Random number generator used for shuffling at write time. Pass a
seeded :class:`numpy.random.Generator` for reproducible collections.
**add_adatas_kwargs
Forwarded verbatim to :meth:`annbatch.DatasetCollection.add_adatas`
(e.g. ``shard_size``, ``dataset_size``, ``shuffle_chunk_size``,
``var_subset``, ``zarr_compressor``).

Returns
-------
Path
Path to the created collection directory.
"""
if not obs_keys:
raise ValueError("`obs_keys` must be a non-empty sequence")

_configure_zarrs_codec_pipeline()

out_path = Path(path)
out_path.parent.mkdir(parents=True, exist_ok=True)

# Place temp inputs alongside the output so they share a filesystem (cheap
# rename / hard-linkable) and any stale residue is easy to spot.
tmp_dir = out_path.parent / f".{out_path.name}.tmp_inputs"
paths, materialized = _resolve_source_paths(source, tmp_dir, layer=layer)

try:
collection = DatasetCollection(str(out_path))
collection.add_adatas(
adata_paths=[str(p) for p in paths],
shuffle=shuffle,
rng=rng,
**add_adatas_kwargs,
)
finally:
if materialized and tmp_dir.exists():
shutil.rmtree(tmp_dir, ignore_errors=True)

return out_path


def _make_load_adata(
obs_keys: Sequence[str],
layer: str | None,
):
"""Build a ``load_adata`` closure for ``Loader.use_collection``.

The closure restricts ``obs`` to ``obs_keys`` and selects ``X`` or
``layers[layer]`` as the feature matrix. Restricting ``obs`` is important
for performance -- the default loader pulls every column.
"""
obs_keys = list(obs_keys)

def _load_adata(g: Any) -> AnnData:
if layer is None:
x_node = g["X"]
else:
x_node = g["layers"][layer]
# Sparse vs dense is auto-detected: zarr.Array == dense; anything else
# is a group encoding a sparse matrix and needs `sparse_dataset`.
if isinstance(x_node, zarr.Array):
X = x_node
else:
X = anndata.io.sparse_dataset(x_node)
obs = anndata.io.read_elem(g["obs"])
missing = [k for k in obs_keys if k not in obs.columns]
if missing:
raise KeyError(
f"obs_keys {missing!r} not found in on-disk obs (have: "
f"{list(obs.columns)})"
)
return AnnData(X=X, obs=obs[obs_keys])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no var?


return _load_adata


def open_annbatch_loader(
path: str | PathLike[str],
*,
obs_keys: Sequence[str],
layer: str | None = None,
batch_size: int = 4096,
chunk_size: int = 32,
preload_nchunks: int = 256,
to_torch: bool = True,
rng: np.random.Generator | None = None,
**loader_kwargs: Any,
) -> Loader:
"""Open an annbatch collection at ``path`` and return a configured loader.

The returned :class:`annbatch.Loader` is bound to the collection via a
``load_adata`` closure that:

- exposes ``X`` (if ``layer is None``) or ``layers[layer]`` as the feature
matrix, and
- subsets ``obs`` to ``obs_keys``.

Parameters
----------
path
Path to a collection produced by :func:`write_annbatch_collection`.
obs_keys
``obs`` columns to surface in each yielded batch.
layer
Layer name to read; ``None`` => ``X``.
batch_size, chunk_size, preload_nchunks, to_torch
Forwarded to :class:`annbatch.Loader`. The defaults match the values
suggested in the annbatch docs for cell-level IID training.
rng
Random number generator used for the loader's shuffling. Pass a
seeded :class:`numpy.random.Generator` for reproducible iteration
order. Mutually exclusive with a custom ``batch_sampler`` (which
would carry its own rng).
**loader_kwargs
Forwarded verbatim to :class:`annbatch.Loader` (e.g. ``shuffle``,
``drop_last``, ``return_index``, ``preload_to_gpu``,
``concat_strategy``, ``batch_sampler``).

Returns
-------
annbatch.Loader
A loader ready to iterate. Callers can wrap it in a
:class:`torch.utils.data.DataLoader` if multi-worker IO is desired.
"""
if not obs_keys:
raise ValueError("`obs_keys` must be a non-empty sequence")

_configure_zarrs_codec_pipeline()

coll_path = Path(path)
if not coll_path.exists():
raise FileNotFoundError(f"annbatch collection not found at {coll_path}")

collection = DatasetCollection(str(coll_path), mode="r")

loader = Loader(
batch_size=batch_size,
chunk_size=chunk_size,
preload_nchunks=preload_nchunks,
to_torch=to_torch,
rng=rng,
**loader_kwargs,
)
return loader.use_collection(collection, load_adata=_make_load_adata(obs_keys, layer))
Loading
Loading