Skip to content
28 changes: 27 additions & 1 deletion braggtrack/cli/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Shared CLI helpers for volume loading, CSV I/O, and synthetic fallback."""
"""Shared CLI helpers for volume loading, CSV I/O, notebooks, and synthetic fallback."""

from __future__ import annotations

import csv
import hashlib
import json
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -59,3 +60,28 @@ def write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
writer = csv.DictWriter(fh, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)


def write_qc_notebook(path: Path, *, title: str, code_source: list[str]) -> None:
    """Write a minimal QC notebook with a markdown header and one code cell.

    Args:
        path: Destination ``.ipynb`` file. Missing parent directories are created.
        title: Text rendered as the level-1 markdown heading of the first cell.
        code_source: Source lines of the single code cell, in order.
    """
    header_cell = {
        "cell_type": "markdown",
        # nbformat 4.5 (declared below) requires a unique ``id`` on every cell.
        "id": "qc-header",
        "metadata": {},
        "source": [f"# {title}\n"],
    }
    code_cell = {
        "cell_type": "code",
        "execution_count": None,
        "id": "qc-code",
        "metadata": {},
        "outputs": [],
        # list() lets callers pass any sequence of lines (e.g. a tuple).
        "source": list(code_source),
    }
    nb = {
        "cells": [header_cell, code_cell],
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            }
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    # Notebooks are JSON documents; pin UTF-8 instead of the locale default.
    path.write_text(json.dumps(nb, indent=2), encoding="utf-8")
10 changes: 6 additions & 4 deletions braggtrack/cli/embed_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Compute Week 4 multi-view MIP embeddings for segmented spots."""
"""Compute multi-view MIP embeddings for segmented spots."""

from __future__ import annotations

Expand Down Expand Up @@ -26,8 +26,10 @@ def build_parser() -> argparse.ArgumentParser:
default=None,
help="Dataset root with scan folders (default: data/sample_operando if present, else .)",
)
p.add_argument("--segdir", default="artifacts/week2", help="Segmentation output with features.csv + labels.npz")
p.add_argument("--outdir", default="artifacts/week4", help="Embedding output root")
p.add_argument(
"--segdir", default="artifacts/segmentation", help="Segmentation output with features.csv + labels.npz"
)
p.add_argument("--outdir", default="artifacts/embedding", help="Embedding output root")
p.add_argument("--margin", type=int, default=2, help="Voxel padding around each spot bbox")
p.add_argument(
"--backend",
Expand Down Expand Up @@ -122,7 +124,7 @@ def main() -> int:
"dim": dim,
"backend": args.backend,
"model": args.model if args.backend == "torch" else "mock-hash",
"schema_version": "week4.v1",
"schema_version": "embedding.v1",
}
(scan_out / "embedding_manifest.json").write_text(json.dumps(manifest, indent=2))
summaries.append({**manifest, "embeddings": str(scan_out / "embeddings.npz")})
Expand Down
162 changes: 101 additions & 61 deletions braggtrack/cli/segment_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Run classical segmentation over discovered scan files and write artifacts."""
"""Run segmentation over discovered scan files and write artifacts."""

from __future__ import annotations

Expand All @@ -8,7 +8,7 @@

import numpy as np

from braggtrack.cli._utils import synth_volume_from_file, write_csv
from braggtrack.cli._utils import synth_volume_from_file, write_csv, write_qc_notebook
from braggtrack.io import (
MissingH5DependencyError,
discover_operando_scans,
Expand All @@ -18,10 +18,12 @@
from braggtrack.segmentation import (
extract_instance_table,
fill_holes_binary,
merge_nearby_labels,
otsu_threshold,
relabel_sequential,
remove_small_objects,
segment_classical,
segment_dino,
)


Expand All @@ -33,9 +35,15 @@ def build_parser() -> argparse.ArgumentParser:
default=None,
help="Dataset root with scan folders (default: data/sample_operando if present, else .)",
)
parser.add_argument("--outdir", default="artifacts/week2", help="Output artifact directory")
parser.add_argument("--outdir", default="artifacts/segmentation", help="Output artifact directory")
parser.add_argument(
"--method",
choices=["classical", "dino"],
default="classical",
help="Segmentation method: classical (LoG + watershed) or dino (DINOv3 features + HDBSCAN)",
)
parser.add_argument("--blur-passes", type=int, default=1)
parser.add_argument("--seed-separation", type=int, default=1)
parser.add_argument("--seed-separation", type=int, default=2)
parser.add_argument("--h-value", type=float, default=0.1)
parser.add_argument("--min-size", type=int, default=8)
parser.add_argument(
Expand All @@ -50,52 +58,61 @@ def build_parser() -> argparse.ArgumentParser:
default=99.95,
help="Seed must also exceed this percentile of the LoG response inside the foreground",
)
parser.add_argument(
"--threshold-fraction",
type=float,
default=1.0,
help="Multiply Otsu threshold by this fraction to capture diffuse spots (1.0 = original Otsu)",
)
parser.add_argument(
"--merge-distance",
type=float,
default=15.0,
help="Merge adjacent labels whose centroids are within this many voxels (0 disables)",
)
# DINO-specific arguments
parser.add_argument(
"--dino-backend",
choices=["auto", "mock", "torch"],
default=None,
help="DINO backend (default: auto-detect torch, fall back to mock)",
)
parser.add_argument(
"--dino-model",
default="facebook/dinov3-vitb16-pretrain-lvd1689m",
help="HuggingFace model ID for the DINO torch backend",
)
parser.add_argument("--dino-pca-components", type=int, default=16, help="PCA components for DINO feature reduction")
parser.add_argument("--dino-min-cluster-size", type=int, default=3, help="HDBSCAN min_cluster_size")
parser.add_argument("--dino-min-samples", type=int, default=2, help="HDBSCAN min_samples")
parser.add_argument(
"--dino-min-overlap",
type=float,
default=0.3,
help="Min overlap fraction for 3D slice stitching",
)
return parser


def _write_notebook(path: Path) -> None:
    """Write the Week 2 visual-QC notebook to *path*.

    The notebook has one markdown header cell and one code cell. The code
    cell reads each ``scan*/features.csv`` under ``artifacts/week2`` and plots
    the integrated intensity of the 20 brightest objects per scan. Missing
    parent directories of *path* are created.
    """
    # Source of the code cell. The indentation inside these strings is
    # significant: it is the Python that runs when the notebook is executed.
    code_lines = [
        "import csv, json\n",
        "from pathlib import Path\n",
        "import matplotlib.pyplot as plt\n",
        "root = Path('artifacts/week2')\n",
        "for scan_dir in sorted(root.glob('scan*')):\n",
        "    table = scan_dir / 'features.csv'\n",
        "    if not table.exists():\n",
        "        continue\n",
        "    rows = list(csv.DictReader(table.open()))\n",
        "    rows = sorted(rows, key=lambda r: float(r['integrated_intensity']), reverse=True)[:20]\n",
        "    print(scan_dir.name, 'objects:', len(rows))\n",
        "    fig, ax = plt.subplots(figsize=(8, 4))\n",
        "    ax.set_title(f'{scan_dir.name} top-20 object intensities')\n",
        "    ax.bar(range(len(rows)), [float(r['integrated_intensity']) for r in rows])\n",
        "    ax.set_xlabel('Object rank')\n",
        "    ax.set_ylabel('Integrated intensity')\n",
        "    plt.show()\n",
    ]
    nb = {
        "cells": [
            {
                "cell_type": "markdown",
                # nbformat 4.5 (declared below) requires a unique ``id`` per cell.
                "id": "qc-header",
                "metadata": {},
                "source": [
                    "# Week 2 Visual QC\n",
                    "Loads per-scan feature tables and overlays up to 20 representative objects.\n",
                ],
            },
            {
                "cell_type": "code",
                "execution_count": None,
                "id": "qc-code",
                "metadata": {},
                "outputs": [],
                "source": code_lines,
            },
        ],
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            }
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    # Notebooks are JSON documents; pin UTF-8 instead of the locale default.
    path.write_text(json.dumps(nb, indent=2), encoding="utf-8")
# Source lines of the code cell in the segmentation QC notebook.  The
# indentation inside these strings is significant: it is the Python that runs
# when the notebook is executed, so the loop body and the ``continue`` guard
# must stay nested exactly as written here.
_SEGMENTATION_QC_CODE = [
    "import csv, json\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "root = Path('artifacts/segmentation')\n",
    "for scan_dir in sorted(root.glob('scan*')):\n",
    "    table = scan_dir / 'features.csv'\n",
    "    if not table.exists():\n",
    "        continue\n",
    "    rows = list(csv.DictReader(table.open()))\n",
    "    rows = sorted(rows, key=lambda r: float(r['integrated_intensity']), reverse=True)[:20]\n",
    "    print(scan_dir.name, 'objects:', len(rows))\n",
    "    fig, ax = plt.subplots(figsize=(8, 4))\n",
    "    ax.set_title(f'{scan_dir.name} top-20 object intensities')\n",
    "    ax.bar(range(len(rows)), [float(r['integrated_intensity']) for r in rows])\n",
    "    ax.set_xlabel('Object rank')\n",
    "    ax.set_ylabel('Integrated intensity')\n",
    "    plt.show()\n",
]


def main() -> int:
Expand All @@ -119,21 +136,36 @@ def main() -> int:
volume = synth_volume_from_file(scan.path)
source = "synthetic_fallback"

threshold = otsu_threshold(volume.ravel())
result = segment_classical(
volume,
threshold=threshold,
blur_passes=max(1, args.blur_passes),
h_value=float(args.h_value),
min_seed_separation=max(1, args.seed_separation),
seed_peak_fraction=float(args.seed_peak_fraction),
seed_response_percentile=float(args.seed_response_percentile),
)
if args.method == "dino":
result = segment_dino(
volume,
backend=args.dino_backend,
model_name=args.dino_model,
n_components_pca=args.dino_pca_components,
min_cluster_size=args.dino_min_cluster_size,
min_samples=args.dino_min_samples,
threshold_fraction=float(args.threshold_fraction),
min_overlap_fraction=args.dino_min_overlap,
)
else:
raw_threshold = otsu_threshold(volume.ravel())
threshold = raw_threshold * float(args.threshold_fraction)
result = segment_classical(
volume,
threshold=threshold,
blur_passes=max(1, args.blur_passes),
h_value=float(args.h_value),
min_seed_separation=max(1, args.seed_separation),
seed_peak_fraction=float(args.seed_peak_fraction),
seed_response_percentile=float(args.seed_response_percentile),
)

labels = remove_small_objects(result.labeled_volume, min_size=max(1, args.min_size))
binary = labels > 0
binary = fill_holes_binary(binary)
labels = np.where(binary, labels, 0)
if args.merge_distance > 0:
labels = merge_nearby_labels(labels, volume, max_centroid_distance=args.merge_distance)
labels = relabel_sequential(labels)

table = extract_instance_table(labels, volume)
Expand All @@ -145,10 +177,14 @@ def main() -> int:
"scan": scan.scan_name,
"file": str(scan.path),
"source": source,
"threshold": threshold,
"method": args.method,
"threshold": result.threshold,
"threshold_fraction": args.threshold_fraction,
"effective_threshold": result.threshold,
"merge_distance": args.merge_distance,
"seed_count": result.seed_count,
"component_count": len(table),
"schema_version": "week2.v1",
"schema_version": "segmentation.v1",
"labels_archive": str(scan_out / "labels.npz"),
},
indent=2,
Expand All @@ -165,13 +201,17 @@ def main() -> int:
"summary": str(scan_out / "summary.json"),
"features": str(scan_out / "features.csv"),
"labels_archive": str(scan_out / "labels.npz"),
"schema_version": "week2.v1",
"schema_version": "segmentation.v1",
}
)

(outdir / "segmentation_summary.json").write_text(json.dumps(summaries, indent=2))
write_csv(outdir / "segmentation_summary.csv", summaries)
_write_notebook(outdir / "qc" / "week2_visual_qc.ipynb")
write_qc_notebook(
outdir / "qc" / "segmentation_visual_qc.ipynb",
title="Segmentation Visual QC",
code_source=_SEGMENTATION_QC_CODE,
)

print(json.dumps(summaries, indent=2))
return 0 if summaries else 1
Expand Down
2 changes: 1 addition & 1 deletion braggtrack/cli/segment_synthetic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Run a Week 2 segmentation smoke test on a synthetic 3D volume."""
"""Run a segmentation smoke test on a synthetic 3D volume."""

from __future__ import annotations

Expand Down
Loading
Loading