Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions mussel/cli/aggregate_sample_features.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import h5py
import hydra
import numpy as np
import torch
from hydra.conf import HelpConf, HydraConf
from hydra.core.config_store import ConfigStore

from mussel.utils.feature_extract import aggregate_sample_features
from mussel.utils.file import save_hdf5
from mussel.utils.file import save_hdf5, save_torch_tensor

logger = logging.getLogger(__name__)

Expand All @@ -25,9 +27,13 @@ class AggregateSampleFeaturesConfig:
(produced by extract_features). Must have the same length as sample_ids.
sample_ids (List[str]): Sample identifier for each slide. Slides sharing
the same sample_id are concatenated into a single output file.
output_dir (str): Directory where output H5 files are written.
output_h5_suffix (str): Filename suffix for output files (default "features.h5").
output_dir (str): Directory where output files are written.
output_h5_suffix (str): Filename suffix for H5 output files (default "features.h5").
Each sample writes to "{output_dir}/{sample_id}.{output_h5_suffix}".
output_pt_suffix (str): Filename suffix for PT output files (default "features.pt").
Only used when save_pt=True.
save_pt (bool): Whether to also save a PyTorch .pt tensor alongside the H5
(default True).
Comment on lines +30 to +36
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CLI help text in desc_doc below still says it “writes one output H5 per unique sample”, but the default behavior now also writes a .pt tensor unless save_pt=false. Please update the CLI description/help output to mention the additional PT file and how to opt out, so the CLI’s usage docs match the actual outputs.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a761faa. The desc_doc now says "writes one output H5 and one output PT tensor per unique sample" and includes an H5-only example using save_pt=false.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already addressed during the rebase onto #115desc_doc now reads "writes one output H5 and one output PT tensor per unique sample. To write only the H5 (no PT file), pass save_pt=false" and includes an H5-only example.

max_tiles (Optional[int]): Maximum number of tiles per sample after concatenation.
When the total tile count exceeds this, tiles are subsampled.
None (default) keeps all tiles.
Expand All @@ -42,6 +48,8 @@ class AggregateSampleFeaturesConfig:
sample_ids: List[str] = field(default_factory=list)
output_dir: str = ""
output_h5_suffix: str = "features.h5"
output_pt_suffix: str = "features.pt"
save_pt: bool = True
max_tiles: Optional[int] = None
subsampling_strategy: str = "random"
seed: int = 42
Expand All @@ -53,7 +61,8 @@ class AggregateSampleFeaturesConfig:

This tool reads HDF5 feature files produced by extract_features (one per slide),
groups them by sample_id, concatenates all tiles on the tile axis, optionally
subsamples to a max_tiles budget, and writes one output H5 per unique sample.
subsamples to a max_tiles budget, and writes one output H5 and one output PT
tensor per unique sample. To write only the H5 (no PT file), pass save_pt=false.

Subsampling strategies (when max_tiles is set):
- random: uniformly sample max_tiles from the full tile pool (default)
Expand All @@ -67,6 +76,13 @@ class AggregateSampleFeaturesConfig:
output_dir=/results/samples \\
max_tiles=10000 \\
subsampling_strategy=proportional

# H5-only output (no .pt file):
aggregate_sample_features \\
'patch_features_h5_paths=[slide1.h5,slide2.h5]' \\
'sample_ids=[P001,P001]' \\
output_dir=/results/samples \\
save_pt=false
"""

parameter_doc = f"""== Available Parameters ==
Expand Down Expand Up @@ -128,12 +144,22 @@ def main(cfg: AggregateSampleFeaturesConfig):

os.makedirs(cfg.output_dir, exist_ok=True)
for sample_id, (features, coords) in results.items():
out_path = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}")
save_hdf5(out_path, {"features": features, "coords": coords}, mode="w")
out_h5 = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}")
save_hdf5(out_h5, {"features": features, "coords": coords}, mode="w")
logger.info(
"Wrote %s (%d tiles, dim=%d)", out_path, len(features), features.shape[1]
"Wrote %s (%d tiles, dim=%d)", out_h5, len(features), features.shape[1]
)

out_pt = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_pt_suffix}")
if cfg.save_pt:
save_torch_tensor(out_pt, torch.from_numpy(features))
logger.info("Wrote %s", out_pt)
else:
out_pt_path = Path(out_pt)
if out_pt_path.exists():
out_pt_path.unlink()
logger.info("Removed stale PT output %s because save_pt=False", out_pt)

logger.info("Done.")


Expand Down
20 changes: 20 additions & 0 deletions tests/mussel/cli/test_aggregate_sample_features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import h5py
import numpy as np
import pytest
import torch

from mussel.utils.feature_extract import subsample_tiles

Expand Down Expand Up @@ -160,6 +161,23 @@ def test_aggregate_sample_features_multi_slide(tmp_path):
assert results["sampleX"][1].shape == (35, 2)


def test_aggregate_sample_features_save_pt_false(tmp_path):
"""save_pt=False — only H5 is written, no PT file."""
h5_a = tmp_path / "slide_a.h5"
_write_fake_h5(h5_a, n_tiles=10, seed=0)

cfg = AggregateSampleFeaturesConfig(
patch_features_h5_paths=[str(h5_a)],
sample_ids=["s1"],
output_dir=str(tmp_path / "out"),
save_pt=False,
)
mussel.cli.aggregate_sample_features.main(OmegaConf.structured(cfg))

assert (tmp_path / "out" / "s1.features.h5").exists()
assert not (tmp_path / "out" / "s1.features.pt").exists()


def test_aggregate_sample_features_two_samples(tmp_path):
"""Three slides, two samples — two entries in result."""
rng = np.random.default_rng(0)
Expand Down Expand Up @@ -234,6 +252,8 @@ def test_aggregate_sample_features_cli(tmp_path):
assert f["features"].shape[0] == 30 # 25+20=45 → subsampled to 30
with h5py.File(tmp_path / "samples" / "P2.features.h5") as f:
assert f["features"].shape[0] == 30 # 30 ≤ 30, no subsampling
assert (tmp_path / "samples" / "P1.features.pt").exists()
assert (tmp_path / "samples" / "P2.features.pt").exists()


def test_cli_mismatched_lengths_raises(tmp_path):
Expand Down
Loading