Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions mussel/cli/aggregate_sample_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,32 @@ def main(cfg: AggregateSampleFeaturesConfig):
f"sample_ids ({len(sample_ids)}) must have the same length."
)

features_list = []
coords_list = []
for h5_path in patch_features_h5_paths:
with h5py.File(h5_path, "r") as h5:
features_list.append(np.array(h5["features"]))
coords_list.append(h5["coords"][:])

results = aggregate_sample_features(
features_list=features_list,
coords_list=coords_list,
sample_ids=sample_ids,
max_tiles=cfg.max_tiles,
subsampling_strategy=cfg.subsampling_strategy,
seed=cfg.seed,
)
# Group indices by sample_id first so we only hold one sample's slides in
# memory at a time, keeping peak memory proportional to the largest sample
# rather than the entire input set.
groups: dict[str, list[int]] = {}
for idx, sid in enumerate(sample_ids):
groups.setdefault(sid, []).append(idx)

os.makedirs(cfg.output_dir, exist_ok=True)
for sample_id, (features, coords) in results.items():
for sample_id, indices in groups.items():
features_list = []
coords_list = []
for i in indices:
with h5py.File(patch_features_h5_paths[i], "r") as h5:
features_list.append(np.array(h5["features"]))
coords_list.append(h5["coords"][:])

result = aggregate_sample_features(
features_list=features_list,
coords_list=coords_list,
sample_ids=[sample_id] * len(indices),
max_tiles=cfg.max_tiles,
subsampling_strategy=cfg.subsampling_strategy,
seed=cfg.seed,
)

features, coords = result[sample_id]
out_h5 = os.path.join(cfg.output_dir, f"{sample_id}.{cfg.output_h5_suffix}")
save_hdf5(out_h5, {"features": features, "coords": coords}, mode="w")
logger.info(
Expand Down
24 changes: 21 additions & 3 deletions mussel/utils/feature_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ def aggregate_sample_features(
max_tiles: Optional[int] = None,
subsampling_strategy: str = "random",
seed: int = 42,
) -> dict:
) -> dict[str, tuple[np.ndarray, np.ndarray]]:
"""Concatenate per-slide patch features into one array per sample.

Groups slides by ``sample_id``, concatenates their features and coordinates
Expand Down Expand Up @@ -1968,11 +1968,29 @@ def aggregate_sample_features(
f"and sample_ids ({len(sample_ids)}) must all have the same length."
)

groups: dict = collections.OrderedDict()
for i, (feats, coords) in enumerate(zip(features_list, coords_list)):
sid = sample_ids[i]
if feats.ndim != 2:
raise ValueError(
f"features_list[{i}] (sample_id={sid!r}) must be 2-D, "
f"got shape {feats.shape}."
)
if coords.ndim != 2 or coords.shape[1] != 2:
raise ValueError(
f"coords_list[{i}] (sample_id={sid!r}) must have shape (N, 2), "
f"got shape {coords.shape}."
)
if len(feats) != len(coords):
raise ValueError(
f"features_list[{i}] and coords_list[{i}] (sample_id={sid!r}) "
f"have different lengths: {len(feats)} vs {len(coords)}."
)

groups: dict[str, list[int]] = {}
for idx, sid in enumerate(sample_ids):
groups.setdefault(sid, []).append(idx)

results: dict = collections.OrderedDict()
results: dict[str, tuple[np.ndarray, np.ndarray]] = {}

for sample_id, indices in groups.items():
logger.info("Aggregating sample %s from %d slide(s)", sample_id, len(indices))
Expand Down
81 changes: 66 additions & 15 deletions tests/mussel/cli/test_aggregate_sample_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,35 @@ def test_subsample_tiles_invalid_strategy():
aggregate_sample_features as _aggregate_sample_features


def test_aggregate_sample_features_invalid_shapes():
    """Malformed per-slide arrays are rejected with a descriptive ValueError."""
    good_feats, good_coords = _make_data(10)

    # Each case: (features, coords, regex the error message must match).
    bad_inputs = [
        # features flattened to 1-D
        (good_feats.ravel(), good_coords, "2-D"),
        # coords sliced down to a single column
        (good_feats, good_coords[:, :1], r"\(N, 2\)"),
        # coords truncated to the first 5 rows while features are left whole
        (good_feats, good_coords[:5], "different lengths"),
    ]

    for bad_feats, bad_coords, expected_msg in bad_inputs:
        with pytest.raises(ValueError, match=expected_msg):
            _aggregate_sample_features(
                features_list=[bad_feats],
                coords_list=[bad_coords],
                sample_ids=["s"],
            )


def test_aggregate_sample_features_single_slide(tmp_path):
"""One slide per sample — output equals input."""
feats_a, coords_a = _make_data(30)
Expand Down Expand Up @@ -159,6 +188,12 @@ def test_aggregate_sample_features_multi_slide(tmp_path):

assert results["sampleX"][0].shape == (35, 4)
assert results["sampleX"][1].shape == (35, 2)
np.testing.assert_array_equal(
results["sampleX"][0], np.concatenate([feats_a, feats_b], axis=0)
)
np.testing.assert_array_equal(
results["sampleX"][1], np.concatenate([coords_a, coords_b], axis=0)
)


def test_aggregate_sample_features_save_pt_false(tmp_path):
Expand Down Expand Up @@ -197,23 +232,39 @@ def test_aggregate_sample_features_two_samples(tmp_path):


def test_aggregate_sample_features_with_subsampling(tmp_path):
    """Subsampling caps output at max_tiles, is reproducible, and stays aligned.

    Two slides of one sample are aggregated with ``max_tiles=50`` and the
    ``"random"`` strategy. The test checks the output row count, that two
    runs with the same seed give identical arrays, and that every selected
    feature row still matches its coordinate row.
    """
    # Use identifiable rows: feature row i has value i in all dims, coord row i
    # is (i, i). After subsampling, each selected feature row must equal its
    # corresponding coord row, proving the two arrays stay in sync.
    n_a, n_b = 80, 60
    feats_a = np.tile(np.arange(n_a, dtype=np.float32)[:, None], (1, 4))
    coords_a = np.tile(np.arange(n_a)[:, None], (1, 2))
    feats_b = np.tile(np.arange(n_a, n_a + n_b, dtype=np.float32)[:, None], (1, 4))
    coords_b = np.tile(np.arange(n_a, n_a + n_b)[:, None], (1, 2))

    def run():
        # Same inputs and seed on every call so the results can be compared.
        return _aggregate_sample_features(
            features_list=[feats_a, feats_b],
            coords_list=[coords_a, coords_b],
            sample_ids=["big", "big"],
            max_tiles=50,
            subsampling_strategy="random",
            seed=99,
        )

    r1 = run()
    r2 = run()
    f_out, c_out = r1["big"]

    assert f_out.shape[0] == 50
    assert c_out.shape[0] == 50

    # Reproducible with same seed
    np.testing.assert_array_equal(r1["big"][0], r2["big"][0])
    np.testing.assert_array_equal(r1["big"][1], r2["big"][1])

    # features and coords remain aligned: feature value == coord value for each row
    np.testing.assert_array_equal(f_out[:, 0].astype(np.int64), c_out[:, 0])


# =============================================================================
Expand Down
Loading