diff --git a/mussel/cli/aggregate_sample_features.py b/mussel/cli/aggregate_sample_features.py index d335efcd..8ff7b951 100644 --- a/mussel/cli/aggregate_sample_features.py +++ b/mussel/cli/aggregate_sample_features.py @@ -1,12 +1,16 @@ import logging +import os from dataclasses import dataclass, field from typing import List, Optional +import h5py import hydra +import numpy as np from hydra.conf import HelpConf, HydraConf from hydra.core.config_store import ConfigStore from mussel.utils.feature_extract import aggregate_sample_features +from mussel.utils.file import save_hdf5 logger = logging.getLogger(__name__) @@ -97,16 +101,39 @@ def main(cfg: AggregateSampleFeaturesConfig): cfg.seed, ) - aggregate_sample_features( - patch_features_h5_paths=list(cfg.patch_features_h5_paths), - sample_ids=list(cfg.sample_ids), - output_dir=cfg.output_dir, - output_h5_suffix=cfg.output_h5_suffix, + patch_features_h5_paths = list(cfg.patch_features_h5_paths) + sample_ids = list(cfg.sample_ids) + + if len(patch_features_h5_paths) != len(sample_ids): + raise ValueError( + f"patch_features_h5_paths ({len(patch_features_h5_paths)}) and " + f"sample_ids ({len(sample_ids)}) must have the same length." + ) + + features_list = [] + coords_list = [] + for h5_path in patch_features_h5_paths: + with h5py.File(h5_path, "r") as h5: + features_list.append(np.array(h5["features"])) + coords_list.append(h5["coords"][:]) + + results = aggregate_sample_features( + features_list=features_list, + coords_list=coords_list, + sample_ids=sample_ids, max_tiles=cfg.max_tiles, subsampling_strategy=cfg.subsampling_strategy, seed=cfg.seed, ) + os.makedirs(cfg.output_dir, exist_ok=True) + for sample_id, (features, coords) in results.items(): + out_path = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}") + save_hdf5(out_path, {"features": features, "coords": coords}, mode="w") + logger.info( + "Wrote %s (%d tiles, dim=%d)", out_path, len(features), features.shape[1] + ) + logger.info("Done.") diff --git a/mussel/utils/feature_extract.py b/mussel/utils/feature_extract.py index 6b897356..36144f18 100644 --- a/mussel/utils/feature_extract.py +++ b/mussel/utils/feature_extract.py @@ -1904,50 +1904,49 @@ def subsample_tiles( def aggregate_sample_features( - patch_features_h5_paths: List[str], + features_list: List[np.ndarray], + coords_list: List[np.ndarray], sample_ids: List[str], - output_dir: str, - output_h5_suffix: str = "features.h5", max_tiles: Optional[int] = None, subsampling_strategy: str = "random", seed: int = 42, -) -> None: - """Concatenate per-slide patch features into one H5 per sample. +) -> dict: + """Concatenate per-slide patch features into one array per sample. - Reads per-slide feature H5 files (each with ``features`` (N_i, D) and - ``coords`` (N_i, 2) datasets), groups them by ``sample_id``, concatenates - on the tile axis, optionally subsamples to ``max_tiles``, and writes one - output H5 per unique sample. + Groups slides by ``sample_id``, concatenates their features and coordinates + on the tile axis, and optionally subsamples to ``max_tiles``. Args: - patch_features_h5_paths: Paths to per-slide feature H5 files - (produced by ``extract_features``). + features_list: Per-slide feature arrays, each of shape ``(N_i, D)``. + coords_list: Per-slide coordinate arrays, each of shape ``(N_i, 2)``. + Must have the same length as ``features_list``. sample_ids: Sample identifier for each slide (same length as - ``patch_features_h5_paths``). Slides with the same ``sample_id`` - are concatenated together. - output_dir: Directory where one ``{sample_id}.{output_h5_suffix}`` file - is written per unique sample. - output_h5_suffix: Filename suffix for output files (default - ``"features.h5"``). + ``features_list``). Slides with the same ``sample_id`` are + concatenated together. max_tiles: If set, subsample each sample to at most this many tiles after concatenation. ``None`` keeps all tiles. subsampling_strategy: Strategy when subsampling — ``"random"``, ``"proportional"``, or ``"equal"``. Ignored when ``max_tiles`` is ``None`` or total tiles ≤ ``max_tiles``. seed: Random seed for subsampling reproducibility (default ``42``). + + Returns: + Ordered dict mapping each unique ``sample_id`` to a + ``(features, coords)`` tuple of concatenated (and optionally + subsampled) arrays. """ - if len(patch_features_h5_paths) != len(sample_ids): + if not (len(features_list) == len(coords_list) == len(sample_ids)): raise ValueError( - f"patch_features_h5_paths ({len(patch_features_h5_paths)}) and " - f"sample_ids ({len(sample_ids)}) must have the same length." + f"features_list ({len(features_list)}), coords_list ({len(coords_list)}), " + f"and sample_ids ({len(sample_ids)}) must all have the same length." ) - os.makedirs(output_dir, exist_ok=True) - groups: dict = collections.OrderedDict() for idx, sid in enumerate(sample_ids): groups.setdefault(sid, []).append(idx) + results: dict = collections.OrderedDict() + for sample_id, indices in groups.items(): logger.info("Aggregating sample %s from %d slide(s)", sample_id, len(indices)) @@ -1956,14 +1955,12 @@ def aggregate_sample_features( slide_sizes = [] for i in indices: - h5_path = patch_features_h5_paths[i] - with h5py.File(h5_path, "r") as h5: - feats = np.array(h5["features"]) - coords = h5["coords"][:] + feats = features_list[i] + coords = coords_list[i] all_features.append(feats) all_coords.append(coords) slide_sizes.append(len(feats)) - logger.debug(" slide %d: %d tiles from %s", i, len(feats), h5_path) + logger.debug(" slide %d: %d tiles", i, len(feats)) features = np.concatenate(all_features, axis=0) coords = np.concatenate(all_coords, axis=0) @@ -1984,11 +1981,12 @@ def aggregate_sample_features( max_tiles, ) - out_path = os.path.join(output_dir, f"{sample_id}.{output_h5_suffix}") - save_hdf5(out_path, {"features": features, "coords": coords}, mode="w") logger.info( - "Wrote %s (%d tiles, dim=%d)", out_path, len(features), features.shape[1] + "Sample %s: %d tiles, dim=%d", sample_id, len(features), features.shape[1] ) + results[sample_id] = (features, coords) + + return results @timed diff --git a/tests/mussel/cli/test_aggregate_sample_features.py b/tests/mussel/cli/test_aggregate_sample_features.py index 4f6d31e6..4d0f9a1e 100644 --- a/tests/mussel/cli/test_aggregate_sample_features.py +++ b/tests/mussel/cli/test_aggregate_sample_features.py @@ -125,88 +125,77 @@ def test_subsample_tiles_invalid_strategy(): def test_aggregate_sample_features_single_slide(tmp_path): """One slide per sample — output equals input.""" - h5_a = tmp_path / "slide_a.h5" - feats_a, coords_a = _write_fake_h5(h5_a, n_tiles=30) + feats_a, coords_a = _make_data(30) - _aggregate_sample_features( - patch_features_h5_paths=[str(h5_a)], + results = _aggregate_sample_features( + features_list=[feats_a], + coords_list=[coords_a], sample_ids=["sample1"], - output_dir=str(tmp_path / "out"), - output_h5_suffix="features.h5", max_tiles=None, subsampling_strategy="random", seed=42, ) - out_h5 = tmp_path / "out" / "sample1.features.h5" - assert out_h5.exists() - with h5py.File(out_h5) as f: - np.testing.assert_array_equal(f["features"][:], feats_a) - np.testing.assert_array_equal(f["coords"][:], coords_a) + assert "sample1" in results + np.testing.assert_array_equal(results["sample1"][0], feats_a) + np.testing.assert_array_equal(results["sample1"][1], coords_a) def test_aggregate_sample_features_multi_slide(tmp_path): """Two slides per sample — features are concatenated.""" - h5_a = tmp_path / "slide_a.h5" - h5_b = tmp_path / "slide_b.h5" - _write_fake_h5(h5_a, n_tiles=20, seed=1) - _write_fake_h5(h5_b, n_tiles=15, seed=2) - - _aggregate_sample_features( - patch_features_h5_paths=[str(h5_a), str(h5_b)], + rng = np.random.default_rng(0) + feats_a = rng.random((20, 4)).astype(np.float32) + coords_a = rng.integers(0, 1000, (20, 2)) + feats_b = rng.random((15, 4)).astype(np.float32) + coords_b = rng.integers(0, 1000, (15, 2)) + + results = _aggregate_sample_features( + features_list=[feats_a, feats_b], + coords_list=[coords_a, coords_b], sample_ids=["sampleX", "sampleX"], - output_dir=str(tmp_path / "out"), max_tiles=None, ) - out_h5 = tmp_path / "out" / "sampleX.features.h5" - assert out_h5.exists() - with h5py.File(out_h5) as f: - assert f["features"].shape == (35, 4) - assert f["coords"].shape == (35, 2) + assert results["sampleX"][0].shape == (35, 4) + assert results["sampleX"][1].shape == (35, 2) def test_aggregate_sample_features_two_samples(tmp_path): - """Three slides, two samples — two output files.""" - paths = [tmp_path / f"s{i}.h5" for i in range(3)] - for i, p in enumerate(paths): - _write_fake_h5(p, n_tiles=10, seed=i) + """Three slides, two samples — two entries in result.""" + rng = np.random.default_rng(0) + slides = [(rng.random((10, 4)).astype(np.float32), rng.integers(0, 1000, (10, 2))) for _ in range(3)] + features_list = [f for f, _ in slides] + coords_list = [c for _, c in slides] - _aggregate_sample_features( - patch_features_h5_paths=[str(p) for p in paths], + results = _aggregate_sample_features( + features_list=features_list, + coords_list=coords_list, sample_ids=["sA", "sA", "sB"], - output_dir=str(tmp_path / "out"), max_tiles=None, ) - out_a = tmp_path / "out" / "sA.features.h5" - out_b = tmp_path / "out" / "sB.features.h5" - assert out_a.exists() and out_b.exists() - with h5py.File(out_a) as f: - assert f["features"].shape[0] == 20 - with h5py.File(out_b) as f: - assert f["features"].shape[0] == 10 + assert results["sA"][0].shape[0] == 20 + assert results["sB"][0].shape[0] == 10 def test_aggregate_sample_features_with_subsampling(tmp_path): """Subsampling reduces output to max_tiles.""" - h5_a = tmp_path / "s0.h5" - h5_b = tmp_path / "s1.h5" - _write_fake_h5(h5_a, n_tiles=80, seed=0) - _write_fake_h5(h5_b, n_tiles=60, seed=1) - - _aggregate_sample_features( - patch_features_h5_paths=[str(h5_a), str(h5_b)], + rng = np.random.default_rng(0) + feats_a = rng.random((80, 4)).astype(np.float32) + coords_a = rng.integers(0, 1000, (80, 2)) + feats_b = rng.random((60, 4)).astype(np.float32) + coords_b = rng.integers(0, 1000, (60, 2)) + + results = _aggregate_sample_features( + features_list=[feats_a, feats_b], + coords_list=[coords_a, coords_b], sample_ids=["big", "big"], - output_dir=str(tmp_path / "out"), max_tiles=50, subsampling_strategy="random", seed=99, ) - out_h5 = tmp_path / "out" / "big.features.h5" - with h5py.File(out_h5) as f: - assert f["features"].shape[0] == 50 + assert results["big"][0].shape[0] == 50 # =============================================================================