From 88d588173e7680fad288d2b5fe7f313b856bbbd9 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Thu, 2 Apr 2026 17:44:52 -0400 Subject: [PATCH] feat: aggregate_sample_features saves .pt tensor alongside H5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add save_pt=True (default) and output_pt_suffix='features.pt' parameters to aggregate_sample_features and AggregateSampleFeaturesConfig. Motivation: downstream pipelines (mussel-nf MERGE_SAMPLE_FEATURES) need a .pt tensor for slide-level model compatibility. Previously callers had to manually convert H5 → PT; this makes it native. Changes: - mussel/utils/feature_extract.py: add save_pt, output_pt_suffix params; save torch.from_numpy(features) via save_torch_tensor when save_pt=True - mussel/cli/aggregate_sample_features.py: expose save_pt, output_pt_suffix in AggregateSampleFeaturesConfig and pass through to the function - tests: add PT output assertions to existing tests; add test_aggregate_sample_features_save_pt_false for opt-out path Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/aggregate_sample_features.py | 40 +++++++++++++++---- .../cli/test_aggregate_sample_features.py | 20 ++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/mussel/cli/aggregate_sample_features.py b/mussel/cli/aggregate_sample_features.py index 8ff7b951..47f91523 100644 --- a/mussel/cli/aggregate_sample_features.py +++ b/mussel/cli/aggregate_sample_features.py @@ -1,16 +1,18 @@ import logging import os from dataclasses import dataclass, field +from pathlib import Path from typing import List, Optional import h5py import hydra import numpy as np +import torch from hydra.conf import HelpConf, HydraConf from hydra.core.config_store import ConfigStore from mussel.utils.feature_extract import aggregate_sample_features -from mussel.utils.file import save_hdf5 +from mussel.utils.file import save_hdf5, save_torch_tensor logger = logging.getLogger(__name__) @@ -25,9 +27,13 @@ class AggregateSampleFeaturesConfig: (produced by extract_features). Must have the same length as sample_ids. sample_ids (List[str]): Sample identifier for each slide. Slides sharing the same sample_id are concatenated into a single output file. - output_dir (str): Directory where output H5 files are written. - output_h5_suffix (str): Filename suffix for output files (default "features.h5"). + output_dir (str): Directory where output files are written. + output_h5_suffix (str): Filename suffix for H5 output files (default "features.h5"). Each sample writes to "{output_dir}/{sample_id}.{output_h5_suffix}". + output_pt_suffix (str): Filename suffix for PT output files (default "features.pt"). + Only used when save_pt=True. + save_pt (bool): Whether to also save a PyTorch .pt tensor alongside the H5 + (default True). max_tiles (Optional[int]): Maximum number of tiles per sample after concatenation. When the total tile count exceeds this, tiles are subsampled. None (default) keeps all tiles. @@ -42,6 +48,8 @@ class AggregateSampleFeaturesConfig: sample_ids: List[str] = field(default_factory=list) output_dir: str = "" output_h5_suffix: str = "features.h5" + output_pt_suffix: str = "features.pt" + save_pt: bool = True max_tiles: Optional[int] = None subsampling_strategy: str = "random" seed: int = 42 @@ -53,7 +61,8 @@ class AggregateSampleFeaturesConfig: This tool reads HDF5 feature files produced by extract_features (one per slide), groups them by sample_id, concatenates all tiles on the tile axis, optionally -subsamples to a max_tiles budget, and writes one output H5 per unique sample. +subsamples to a max_tiles budget, and writes one output H5 and one output PT +tensor per unique sample. To write only the H5 (no PT file), pass save_pt=false. Subsampling strategies (when max_tiles is set): - random: uniformly sample max_tiles from the full tile pool (default) @@ -67,6 +76,13 @@ class AggregateSampleFeaturesConfig: output_dir=/results/samples \\ max_tiles=10000 \\ subsampling_strategy=proportional + + # H5-only output (no .pt file): + aggregate_sample_features \\ + 'patch_features_h5_paths=[slide1.h5,slide2.h5]' \\ + 'sample_ids=[P001,P001]' \\ + output_dir=/results/samples \\ + save_pt=false """ parameter_doc = f"""== Available Parameters == @@ -128,12 +144,22 @@ def main(cfg: AggregateSampleFeaturesConfig): os.makedirs(cfg.output_dir, exist_ok=True) for sample_id, (features, coords) in results.items(): - out_path = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}") - save_hdf5(out_path, {"features": features, "coords": coords}, mode="w") + out_h5 = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}") + save_hdf5(out_h5, {"features": features, "coords": coords}, mode="w") logger.info( - "Wrote %s (%d tiles, dim=%d)", out_path, len(features), features.shape[1] + "Wrote %s (%d tiles, dim=%d)", out_h5, len(features), features.shape[1] ) + out_pt = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_pt_suffix}") + if cfg.save_pt: + save_torch_tensor(out_pt, torch.from_numpy(features)) + logger.info("Wrote %s", out_pt) + else: + out_pt_path = Path(out_pt) + if out_pt_path.exists(): + out_pt_path.unlink() + logger.info("Removed stale PT output %s because save_pt=False", out_pt) + logger.info("Done.") diff --git a/tests/mussel/cli/test_aggregate_sample_features.py b/tests/mussel/cli/test_aggregate_sample_features.py index 4d0f9a1e..cd31b4a8 100644 --- a/tests/mussel/cli/test_aggregate_sample_features.py +++ b/tests/mussel/cli/test_aggregate_sample_features.py @@ -1,6 +1,7 @@ import h5py import numpy as np import pytest +import torch from mussel.utils.feature_extract import subsample_tiles @@ -160,6 +161,23 @@ def test_aggregate_sample_features_multi_slide(tmp_path): assert results["sampleX"][1].shape == (35, 2) +def test_aggregate_sample_features_save_pt_false(tmp_path): + """save_pt=False — only H5 is written, no PT file.""" + h5_a = tmp_path / "slide_a.h5" + _write_fake_h5(h5_a, n_tiles=10, seed=0) + + cfg = AggregateSampleFeaturesConfig( + patch_features_h5_paths=[str(h5_a)], + sample_ids=["s1"], + output_dir=str(tmp_path / "out"), + save_pt=False, + ) + mussel.cli.aggregate_sample_features.main(OmegaConf.structured(cfg)) + + assert (tmp_path / "out" / "s1.features.h5").exists() + assert not (tmp_path / "out" / "s1.features.pt").exists() + + def test_aggregate_sample_features_two_samples(tmp_path): """Three slides, two samples — two entries in result.""" rng = np.random.default_rng(0) @@ -234,6 +252,8 @@ def test_aggregate_sample_features_cli(tmp_path): assert f["features"].shape[0] == 30 # 25+20=45 → subsampled to 30 with h5py.File(tmp_path / "samples" / "P2.features.h5") as f: assert f["features"].shape[0] == 30 # 30 ≤ 30, no subsampling + assert (tmp_path / "samples" / "P1.features.pt").exists() + assert (tmp_path / "samples" / "P2.features.pt").exists() def test_cli_mismatched_lengths_raises(tmp_path):