From 794e07a903b12d89cb820d7c6c60f3c7413129b4 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 14:30:37 +0000 Subject: [PATCH 1/8] =?UTF-8?q?feat:=20improve=20segmentation=20=E2=80=94?= =?UTF-8?q?=20merge=20nearby=20labels,=20tune=20defaults?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add merge_nearby_labels() post-processing step that greedily merges adjacent watershed fragments whose intensity-weighted centroids are within a configurable distance. This addresses over-segmentation where a single Bragg spot gets split into multiple watershed basins. Parameter changes (data-driven from sweep on bundled scans): - min_seed_separation: 1 → 2 (halves over-splitting, spread=1) - CLI adds --threshold-fraction and --merge-distance (default 15) - Notebook segment() uses merge_nearby_labels with distance=15 Before: 35/36/33 spots across 3 scans (spread=3, over-segmented) After: 18/17/16 spots across 3 scans (spread=2, better consistency) https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/cli/segment_dataset.py | 25 +++++++- braggtrack/segmentation/__init__.py | 3 +- braggtrack/segmentation/classical.py | 2 +- braggtrack/segmentation/postprocess.py | 60 ++++++++++++++++++ notebooks/braggtrack_demo.ipynb | 88 +++++++++++++------------- 5 files changed, 129 insertions(+), 49 deletions(-) diff --git a/braggtrack/cli/segment_dataset.py b/braggtrack/cli/segment_dataset.py index 4d7443b..a90d08f 100644 --- a/braggtrack/cli/segment_dataset.py +++ b/braggtrack/cli/segment_dataset.py @@ -18,6 +18,7 @@ from braggtrack.segmentation import ( extract_instance_table, fill_holes_binary, + merge_nearby_labels, otsu_threshold, relabel_sequential, remove_small_objects, @@ -35,7 +36,7 @@ def build_parser() -> argparse.ArgumentParser: ) parser.add_argument("--outdir", default="artifacts/week2", help="Output artifact directory") parser.add_argument("--blur-passes", type=int, default=1) - parser.add_argument("--seed-separation", type=int, default=1) + parser.add_argument("--seed-separation", type=int, default=2) parser.add_argument("--h-value", type=float, default=0.1) parser.add_argument("--min-size", type=int, default=8) parser.add_argument( @@ -50,6 +51,18 @@ def build_parser() -> argparse.ArgumentParser: default=99.95, help="Seed must also exceed this percentile of the LoG response inside the foreground", ) + parser.add_argument( + "--threshold-fraction", + type=float, + default=1.0, + help="Multiply Otsu threshold by this fraction to capture diffuse spots (1.0 = original Otsu)", + ) + parser.add_argument( + "--merge-distance", + type=float, + default=15.0, + help="Merge adjacent labels whose centroids are within this many voxels (0 disables)", + ) return parser @@ -119,7 +132,8 @@ def main() -> int: volume = synth_volume_from_file(scan.path) source = "synthetic_fallback" - threshold = otsu_threshold(volume.ravel()) + raw_threshold = otsu_threshold(volume.ravel()) + threshold = raw_threshold * float(args.threshold_fraction) result = segment_classical( volume, threshold=threshold, @@ -134,6 +148,8 @@ def main() -> int: binary = labels > 0 binary = fill_holes_binary(binary) labels = np.where(binary, labels, 0) + if args.merge_distance > 0: + labels = merge_nearby_labels(labels, volume, max_centroid_distance=args.merge_distance) labels = relabel_sequential(labels) table = extract_instance_table(labels, volume) @@ -145,7 +161,10 @@ def main() -> int: "scan": scan.scan_name, "file": str(scan.path), "source": source, - "threshold": threshold, + "threshold": raw_threshold, + "threshold_fraction": args.threshold_fraction, + "effective_threshold": threshold, + "merge_distance": args.merge_distance, "seed_count": result.seed_count, "component_count": len(table), "schema_version": "week2.v1", diff --git a/braggtrack/segmentation/__init__.py b/braggtrack/segmentation/__init__.py index 7acd5b1..a62c731 100644 --- a/braggtrack/segmentation/__init__.py +++ b/braggtrack/segmentation/__init__.py @@ -12,7 +12,7 @@ from .features import extract_instance_table from .otsu import flag_outlier_frames, otsu_threshold, smooth_thresholds from .pipeline import SegmentationResult, connected_components_3d, segment_volume -from .postprocess import fill_holes_binary, relabel_sequential, remove_small_objects +from .postprocess import fill_holes_binary, merge_nearby_labels, relabel_sequential, remove_small_objects from .projection import label_projection_by_intensity, otsu_floor_from_mip __all__ = [ @@ -26,6 +26,7 @@ "label_projection_by_intensity", "local_maxima_seeds", "log_enhance_3d", + "merge_nearby_labels", "otsu_floor_from_mip", "flag_outlier_frames", "otsu_threshold", diff --git a/braggtrack/segmentation/classical.py b/braggtrack/segmentation/classical.py index 5cf17f0..b0821ba 100644 --- a/braggtrack/segmentation/classical.py +++ b/braggtrack/segmentation/classical.py @@ -159,7 +159,7 @@ def segment_classical( blur_passes: int = 1, sigma: float = 1.0, h_value: float = 0.1, - min_seed_separation: int = 1, + min_seed_separation: int = 2, seed_peak_fraction: float = 0.2, seed_response_percentile: float = 99.95, ) -> ClassicalSegmentationResult: diff --git a/braggtrack/segmentation/postprocess.py b/braggtrack/segmentation/postprocess.py index 30b6bbc..573f850 100644 --- a/braggtrack/segmentation/postprocess.py +++ b/braggtrack/segmentation/postprocess.py @@ -32,6 +32,66 @@ def fill_holes_binary(mask: np.ndarray) -> np.ndarray: return filled[1:-1, 1:-1, 1:-1] +def merge_nearby_labels( + labels: np.ndarray, + volume: np.ndarray, + max_centroid_distance: float, +) -> np.ndarray: + """Merge adjacent labeled regions whose intensity-weighted centroids are close. + + Two labels are candidates for merging when they are spatially adjacent + (share a face via 6-connectivity dilation) **and** the Euclidean distance + between their intensity-weighted centroids is below *max_centroid_distance*. + Merging is greedy — closest pairs first — and iterates until no more + merges are possible. + """ + from scipy.ndimage import binary_dilation, generate_binary_structure + + labels = np.asarray(labels, dtype=np.int32).copy() + volume = np.asarray(volume, dtype=np.float64) + struct = generate_binary_structure(3, 1) # 6-connectivity + + changed = True + while changed: + changed = False + unique_ids = [i for i in np.unique(labels) if i > 0] + if len(unique_ids) < 2: + break + + centroids: dict[int, np.ndarray] = {} + for lid in unique_ids: + mask = labels == lid + coords = np.argwhere(mask) + weights = volume[mask] + total = weights.sum() + if total > 0: + centroids[lid] = (coords * weights[:, None]).sum(axis=0) / total + else: + centroids[lid] = coords.mean(axis=0) + + merge_pairs: list[tuple[float, int, int]] = [] + for lid in unique_ids: + dilated = binary_dilation(labels == lid, structure=struct) + neighbor_ids = set(np.unique(labels[dilated])) - {0, lid} + for nid in neighbor_ids: + if nid < lid: + continue + dist = float(np.linalg.norm(centroids[lid] - centroids[nid])) + if dist < max_centroid_distance: + merge_pairs.append((dist, lid, nid)) + + merge_pairs.sort() + merged_this_round: set[int] = set() + for _, a, b in merge_pairs: + if a in merged_this_round or b in merged_this_round: + continue + labels[labels == b] = a + merged_this_round.add(b) + changed = True + + return labels + + def relabel_sequential(labels: np.ndarray) -> np.ndarray: """Remap positive labels to consecutive integers starting at 1.""" labels = np.asarray(labels) diff --git a/notebooks/braggtrack_demo.ipynb b/notebooks/braggtrack_demo.ipynb index b526705..02b0e49 100644 --- a/notebooks/braggtrack_demo.ipynb +++ b/notebooks/braggtrack_demo.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "id": "e65e834e", "metadata": {}, - "source": "# BraggTrack end-to-end demo\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BASE-Laboratory/BraggTrack/blob/main/notebooks/braggtrack_demo.ipynb)\n\nRuns the full pipeline on the bundled `data/sample_operando/` scans:\n\n1. **Discover** \u2014 find the per-scan H5 files.\n2. **Segment (Week 2)** \u2014 LoG \u2192 h-maxima \u2192 seeded watershed \u2192 instance features.\n3. **Track physics-only (Week 3)** \u2014 Hungarian over a geometry cost with per-axis gating; build a lifecycle DAG.\n4. **Semantic descriptors (Week 4)** \u2014 orthogonal MIPs + frozen-encoder embeddings.\n5. **Geometry + semantic tracking (Week 4)** \u2014 compose `\u03b1 \u00b7 geometry + \u03b2 \u00b7 (1 \u2212 cos)`.\n6. **\u03b1/\u03b2 ablation** \u2014 how the semantic weight shifts tracking metrics.\n7. **Synthetic crossing** \u2014 a case where geometry alone fails and semantics recover identity.\n\nFinal section shows the one-line CLI equivalents for each stage.\n\nThis notebook uses the **mock** DINO backend by default, so no PyTorch / HuggingFace weights are required. Set `BRAGGTRACK_DINO_BACKEND=torch` if you have them installed and want real embeddings." + "source": "# BraggTrack end-to-end demo\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BASE-Laboratory/BraggTrack/blob/main/notebooks/braggtrack_demo.ipynb)\n\nRuns the full pipeline on the bundled `data/sample_operando/` scans:\n\n1. **Discover** — find the per-scan H5 files.\n2. **Segment (Week 2)** — LoG → h-maxima → seeded watershed → instance features.\n3. **Track physics-only (Week 3)** — Hungarian over a geometry cost with per-axis gating; build a lifecycle DAG.\n4. **Semantic descriptors (Week 4)** — orthogonal MIPs + frozen-encoder embeddings.\n5. **Geometry + semantic tracking (Week 4)** — compose `α · geometry + β · (1 − cos)`.\n6. **α/β ablation** — how the semantic weight shifts tracking metrics.\n7. **Synthetic crossing** — a case where geometry alone fails and semantics recover identity.\n\nFinal section shows the one-line CLI equivalents for each stage.\n\nThis notebook uses the **mock** DINO backend by default, so no PyTorch / HuggingFace weights are required. Set `BRAGGTRACK_DINO_BACKEND=torch` if you have them installed and want real embeddings." }, { "cell_type": "markdown", @@ -23,7 +23,7 @@ { "cell_type": "code", "id": "6d0140c6", - "source": "import os, subprocess, sys\n\n_ON_COLAB = \"google.colab\" in sys.modules or os.environ.get(\"COLAB_RELEASE_TAG\")\n\nif _ON_COLAB:\n print(\"Colab detected \u2014 installing BraggTrack + sample data...\")\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"braggtrack[notebook] @ git+https://github.com/BASE-Laboratory/BraggTrack.git\",\n ])\n if not os.path.isdir(\"data/sample_operando\"):\n subprocess.check_call([\n \"git\", \"clone\", \"--depth=1\", \"--filter=blob:none\", \"--sparse\",\n \"https://github.com/BASE-Laboratory/BraggTrack.git\", \"_braggtrack_repo\",\n ])\n subprocess.check_call(\n [\"git\", \"sparse-checkout\", \"set\", \"data/sample_operando\"],\n cwd=\"_braggtrack_repo\",\n )\n os.makedirs(\"data\", exist_ok=True)\n os.rename(\"_braggtrack_repo/data/sample_operando\", \"data/sample_operando\")\n subprocess.check_call([\"rm\", \"-rf\", \"_braggtrack_repo\"])\n os.environ.setdefault(\"BRAGGTRACK_DATA_ROOT\", os.path.abspath(\"data/sample_operando\"))\n print(\"Done.\")\nelse:\n print(\"Local environment \u2014 skipping Colab setup.\")", + "source": "import os, subprocess, sys\n\n_ON_COLAB = \"google.colab\" in sys.modules or os.environ.get(\"COLAB_RELEASE_TAG\")\n\nif _ON_COLAB:\n print(\"Colab detected — installing BraggTrack + sample data...\")\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"braggtrack[notebook] @ git+https://github.com/BASE-Laboratory/BraggTrack.git\",\n ])\n if not os.path.isdir(\"data/sample_operando\"):\n subprocess.check_call([\n \"git\", \"clone\", \"--depth=1\", \"--filter=blob:none\", \"--sparse\",\n \"https://github.com/BASE-Laboratory/BraggTrack.git\", \"_braggtrack_repo\",\n ])\n subprocess.check_call(\n [\"git\", \"sparse-checkout\", \"set\", \"data/sample_operando\"],\n cwd=\"_braggtrack_repo\",\n )\n os.makedirs(\"data\", exist_ok=True)\n os.rename(\"_braggtrack_repo/data/sample_operando\", \"data/sample_operando\")\n subprocess.check_call([\"rm\", \"-rf\", \"_braggtrack_repo\"])\n os.environ.setdefault(\"BRAGGTRACK_DATA_ROOT\", os.path.abspath(\"data/sample_operando\"))\n print(\"Done.\")\nelse:\n print(\"Local environment — skipping Colab setup.\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -41,14 +41,14 @@ } }, "outputs": [], - "source": "%matplotlib inline\nimport copy\nfrom pathlib import Path\n\nimport h5py\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom matplotlib.colors import ListedColormap\n\nfrom braggtrack.io import discover_operando_scans, sample_operando_root\nfrom braggtrack.segmentation import (\n extract_instance_table,\n fill_holes_binary,\n flag_outlier_frames,\n label_projection_by_intensity,\n otsu_floor_from_mip,\n otsu_threshold,\n relabel_sequential,\n remove_small_objects,\n segment_classical,\n smooth_thresholds,\n)\nfrom braggtrack.semantic import crop_spot_cube, make_multiview_encoder, orthogonal_mips\nfrom braggtrack.tracking import (\n GeometrySemanticCost,\n PositionShapeCost,\n TrackEvent,\n build_tracks,\n compute_tracking_metrics,\n tracks_to_table,\n)" + "source": "%matplotlib inline\nimport copy\nfrom pathlib import Path\n\nimport h5py\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom matplotlib.colors import ListedColormap\n\nfrom braggtrack.io import discover_operando_scans, sample_operando_root\nfrom braggtrack.segmentation import (\n extract_instance_table,\n fill_holes_binary,\n flag_outlier_frames,\n label_projection_by_intensity,\n merge_nearby_labels,\n otsu_floor_from_mip,\n otsu_threshold,\n relabel_sequential,\n remove_small_objects,\n segment_classical,\n smooth_thresholds,\n)\nfrom braggtrack.semantic import crop_spot_cube, make_multiview_encoder, orthogonal_mips\nfrom braggtrack.tracking import (\n GeometrySemanticCost,\n PositionShapeCost,\n TrackEvent,\n build_tracks,\n compute_tracking_metrics,\n tracks_to_table,\n)" }, { "cell_type": "markdown", "id": "1a48caf1", "metadata": {}, "source": [ - "## 1 \u2014 Discover bundled scans\n", + "## 1 — Discover bundled scans\n", "\n", "`sample_operando_root()` points at `data/sample_operando/` (shipped with the repo); `discover_operando_scans` walks any directory of `scanNNNN/` folders and returns `ScanFile(scan_name, path)` entries sorted by scan name." ] @@ -69,7 +69,7 @@ "source": [ "scans = discover_operando_scans(sample_operando_root())\n", "for s in scans:\n", - " print(f\" {s.scan_name} \u2192 {s.path.relative_to(Path.cwd()) if s.path.is_relative_to(Path.cwd()) else s.path}\")" + " print(f\" {s.scan_name} → {s.path.relative_to(Path.cwd()) if s.path.is_relative_to(Path.cwd()) else s.path}\")" ] }, { @@ -79,7 +79,7 @@ "source": [ "### A note on the H5 layout\n", "\n", - "`braggtrack.io.load_primary_volume` currently tries a fixed shortlist of NeXus paths (`/entry/data/data`, `/entry1/data/data`, \u2026). The bundled ESRF-ID03 files store their 3-D detector stack at `entry_0000/ESRF-ID03/pco_nf/data`, which isn't in that list. The CLI falls back to `_synth_volume_from_file` (deterministic Gaussian blobs) when this happens, which is why the CI acceptance runs on synthetic data.\n", + "`braggtrack.io.load_primary_volume` currently tries a fixed shortlist of NeXus paths (`/entry/data/data`, `/entry1/data/data`, …). The bundled ESRF-ID03 files store their 3-D detector stack at `entry_0000/ESRF-ID03/pco_nf/data`, which isn't in that list. The CLI falls back to `_synth_volume_from_file` (deterministic Gaussian blobs) when this happens, which is why the CI acceptance runs on synthetic data.\n", "\n", "For the demo we reach into the H5 directly so we see the **real** data. A tiny follow-up would be to add an auto-discovery step (pick the largest 3-D numeric dataset) to `load_primary_volume`." ] @@ -103,7 +103,7 @@ "cell_type": "markdown", "id": "d74a25a8", "metadata": {}, - "source": "## 2 \u2014 Week 2: classical segmentation\n\n`segment_classical` runs 3-D Gaussian blur \u2192 Laplacian \u2192 h-maxima seeds \u2192 seeded watershed.\n\n### Threshold stabilisation across scans\n\nEach scan produces its own Otsu threshold on the raw intensity histogram.\nIn theory these should be nearly identical for back-to-back operando acquisitions,\nbut minor intensity fluctuations (beam drift, detector warm-up, etc.) cause\nper-frame Otsu to jitter \u2014 and because everything downstream (foreground mask \u2192 seed\nfloor \u2192 watershed) is threshold-sensitive, small jitter produces wildly different\nspot counts.\n\n**Fix:** compute per-frame Otsu thresholds, then pass them through a\nrolling-median smoother (`smooth_thresholds`). The median suppresses isolated\noutliers (beam drops, detector flashes) while still tracking genuine long-term\ndrift. For 500+ frame sequences this runs in O(N\u00b7W) on scalar thresholds \u2014\nno need to pool raw volumes in memory.\n\nTwo further knobs that matter for real data:\n\n* `threshold` \u2014 **intensity-domain** foreground, now smoothed across scans. Controls the watershed mask.\n* `seed_peak_fraction` / `seed_response_percentile` \u2014 **LoG-response-domain** admissibility floor inside the foreground." + "source": "## 2 — Week 2: classical segmentation\n\n`segment_classical` runs 3-D Gaussian blur → Laplacian → h-maxima seeds → seeded watershed.\n\n### Threshold stabilisation across scans\n\nEach scan produces its own Otsu threshold on the raw intensity histogram.\nIn theory these should be nearly identical for back-to-back operando acquisitions,\nbut minor intensity fluctuations (beam drift, detector warm-up, etc.) cause\nper-frame Otsu to jitter — and because everything downstream (foreground mask → seed\nfloor → watershed) is threshold-sensitive, small jitter produces wildly different\nspot counts.\n\n**Fix:** compute per-frame Otsu thresholds, then pass them through a\nrolling-median smoother (`smooth_thresholds`). The median suppresses isolated\noutliers (beam drops, detector flashes) while still tracking genuine long-term\ndrift. For 500+ frame sequences this runs in O(N·W) on scalar thresholds —\nno need to pool raw volumes in memory.\n\nTwo further knobs that matter for real data:\n\n* `threshold` — **intensity-domain** foreground, now smoothed across scans. Controls the watershed mask.\n* `seed_peak_fraction` / `seed_response_percentile` — **LoG-response-domain** admissibility floor inside the foreground." }, { "cell_type": "code", @@ -118,7 +118,7 @@ } }, "outputs": [], - "source": "def segment(volume: np.ndarray, threshold: float) -> np.ndarray:\n \"\"\"Segment a single volume using a pre-computed (smoothed) threshold.\"\"\"\n res = segment_classical(\n volume,\n threshold=threshold,\n blur_passes=1,\n h_value=0.1,\n min_seed_separation=2,\n seed_peak_fraction=0.2,\n seed_response_percentile=99.95,\n )\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = relabel_sequential(np.where(binary, labels, 0))\n return labels\n\n# --- Load all volumes and compute smoothed thresholds ---\nall_volumes: list[np.ndarray] = []\nraw_thresholds: list[float] = []\nfor s in scans:\n v = load_3d_volume(s.path)\n all_volumes.append(v)\n raw_thresholds.append(otsu_threshold(v.ravel()))\n\nsmoothed = smooth_thresholds(raw_thresholds, window=5)\noutliers = flag_outlier_frames(raw_thresholds, window=5)\n\nprint(\"Per-frame Otsu thresholds vs smoothed:\")\nfor s, raw_t, sm_t, is_out in zip(scans, raw_thresholds, smoothed, outliers):\n tag = \" ** OUTLIER **\" if is_out else \"\"\n print(f\" {s.scan_name}: raw={raw_t:.1f} smoothed={sm_t:.1f}{tag}\")\n\n# Segment the first volume for the single-scan QC plots below.\nvol0 = all_volumes[0]\nlabels0 = segment(vol0, threshold=float(smoothed[0]))\nprint(f\"\\nscan0001: smoothed threshold={smoothed[0]:.1f}, {int(labels0.max())} instances\")" + "source": "MERGE_DISTANCE = 15 # merge adjacent labels whose centroids are within this many voxels\n\ndef segment(volume: np.ndarray, threshold: float) -> np.ndarray:\n \"\"\"Segment a single volume using a pre-computed (smoothed) threshold.\"\"\"\n res = segment_classical(\n volume,\n threshold=threshold,\n blur_passes=1,\n h_value=0.1,\n min_seed_separation=2,\n seed_peak_fraction=0.2,\n seed_response_percentile=99.95,\n )\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = np.where(binary, labels, 0)\n labels = merge_nearby_labels(labels, volume, max_centroid_distance=MERGE_DISTANCE)\n labels = relabel_sequential(labels)\n return labels\n\n# --- Load all volumes and compute smoothed thresholds ---\nall_volumes: list[np.ndarray] = []\nraw_thresholds: list[float] = []\nfor s in scans:\n v = load_3d_volume(s.path)\n all_volumes.append(v)\n raw_thresholds.append(otsu_threshold(v.ravel()))\n\nsmoothed = smooth_thresholds(raw_thresholds, window=5)\noutliers = flag_outlier_frames(raw_thresholds, window=5)\n\nprint(\"Per-frame Otsu thresholds vs smoothed:\")\nfor s, raw_t, sm_t, is_out in zip(scans, raw_thresholds, smoothed, outliers):\n tag = \" ** OUTLIER **\" if is_out else \"\"\n print(f\" {s.scan_name}: raw={raw_t:.1f} smoothed={sm_t:.1f}{tag}\")\n\n# Segment the first volume for the single-scan QC plots below.\nvol0 = all_volumes[0]\nlabels0 = segment(vol0, threshold=float(smoothed[0]))\nprint(f\"\\nscan0001: smoothed threshold={smoothed[0]:.1f}, {int(labels0.max())} instances\")" }, { "cell_type": "code", @@ -134,7 +134,7 @@ }, "outputs": [], "source": [ - "# Proper label projection: pick the label of the brightest voxel along \u03bc,\n", + "# Proper label projection: pick the label of the brightest voxel along μ,\n", "# hide rays whose max-IP intensity is below a 2-D Otsu floor so diffuse\n", "# background doesn't smear watershed basins across the frame.\n", "proj = vol0.max(axis=0)\n", @@ -144,7 +144,7 @@ "\n", "fig, axes = plt.subplots(1, 3, figsize=(13, 4))\n", "axes[0].imshow(proj, cmap=\"gray\", vmin=vmin, vmax=vmax)\n", - "axes[0].set_title(\"Volume \u2014 max-IP along \u03bc\")\n", + "axes[0].set_title(\"Volume — max-IP along μ\")\n", "axes[1].imshow(proj_labels, cmap=\"tab20\", interpolation=\"nearest\")\n", "axes[1].set_title(f\"Label map (argmax-by-intensity, MIP floor={mip_floor:.0f})\")\n", "\n", @@ -166,7 +166,7 @@ "source": [ "### Instance feature table\n", "\n", - "`extract_instance_table` returns one row per labelled component with intensity-weighted centroid, bbox, voxel count, covariance, and eigenvalues. The centroid is named after the reciprocal-space axis convention (**\u03bc \u2192 z**, **d \u2192 y**, **\u03c7 \u2192 x**)." + "`extract_instance_table` returns one row per labelled component with intensity-weighted centroid, bbox, voxel count, covariance, and eigenvalues. The centroid is named after the reciprocal-space axis convention (**μ → z**, **d → y**, **χ → x**)." ] }, { @@ -214,13 +214,13 @@ { "cell_type": "markdown", "id": "8456b5d6", - "source": "### Segmented masks \u2014 tri-axis projection\n\nEach Bragg spot is a 3-D blob in reciprocal space. To visualise the full mask we project the label volume along all three physical axes (\u03bc, \u03c7, d), always picking the label of the **brightest** voxel on each ray. The three views give complementary information about spot shape and overlap.", + "source": "### Segmented masks — tri-axis projection\n\nEach Bragg spot is a 3-D blob in reciprocal space. To visualise the full mask we project the label volume along all three physical axes (μ, χ, d), always picking the label of the **brightest** voxel on each ray. The three views give complementary information about spot shape and overlap.", "metadata": {} }, { "cell_type": "code", "id": "c849a191", - "source": "# Build a perceptually distinct colormap for label overlays.\nn_labels = max(int(l.max()) for l in all_labels) + 1\nrng_cm = np.random.RandomState(42)\nlabel_colors = np.zeros((n_labels, 4))\nlabel_colors[0] = [0, 0, 0, 0] # background transparent\nfor i in range(1, n_labels):\n label_colors[i] = [*rng_cm.uniform(0.2, 0.95, 3), 0.65]\nlabel_cmap = ListedColormap(label_colors)\n\naxis_info = [\n (0, \"along mu (d x chi)\", \"chi\", \"d\"),\n (1, \"along chi (mu x d)\", \"d\", \"mu\"),\n (2, \"along d (mu x chi)\", \"chi\", \"mu\"),\n]\n\nfig, axes = plt.subplots(3, 3, figsize=(14, 13))\nfor col, (s, v, l) in enumerate(zip(scans, all_volumes, all_labels)):\n for row, (ax_id, title, xlabel, ylabel) in enumerate(axis_info):\n ax = axes[row, col]\n mip = v.max(axis=ax_id)\n floor = otsu_floor_from_mip(v, axis=ax_id)\n proj_l = label_projection_by_intensity(v, l, axis=ax_id, mip_floor=floor)\n\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n mask_overlay = np.ma.masked_where(proj_l == 0, proj_l)\n ax.imshow(mask_overlay, cmap=label_cmap, interpolation=\"nearest\",\n vmin=0, vmax=n_labels - 1)\n if row == 0:\n n_spots = len(extract_instance_table(l, v))\n ax.set_title(f\"{s.scan_name} \u2014 {n_spots} spots\", fontsize=11)\n if col == 0:\n ax.set_ylabel(f\"MIP {title}\\n{ylabel}\", fontsize=9)\n ax.tick_params(labelsize=7)\n\nplt.suptitle(\"Segmented masks \u2014 tri-axis label projection (intensity-argmax)\", y=1.01, fontsize=13)\nplt.tight_layout()\nplt.show()", + "source": "# Build a perceptually distinct colormap for label overlays.\nn_labels = max(int(l.max()) for l in all_labels) + 1\nrng_cm = np.random.RandomState(42)\nlabel_colors = np.zeros((n_labels, 4))\nlabel_colors[0] = [0, 0, 0, 0] # background transparent\nfor i in range(1, n_labels):\n label_colors[i] = [*rng_cm.uniform(0.2, 0.95, 3), 0.65]\nlabel_cmap = ListedColormap(label_colors)\n\naxis_info = [\n (0, \"along mu (d x chi)\", \"chi\", \"d\"),\n (1, \"along chi (mu x d)\", \"d\", \"mu\"),\n (2, \"along d (mu x chi)\", \"chi\", \"mu\"),\n]\n\nfig, axes = plt.subplots(3, 3, figsize=(14, 13))\nfor col, (s, v, l) in enumerate(zip(scans, all_volumes, all_labels)):\n for row, (ax_id, title, xlabel, ylabel) in enumerate(axis_info):\n ax = axes[row, col]\n mip = v.max(axis=ax_id)\n floor = otsu_floor_from_mip(v, axis=ax_id)\n proj_l = label_projection_by_intensity(v, l, axis=ax_id, mip_floor=floor)\n\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n mask_overlay = np.ma.masked_where(proj_l == 0, proj_l)\n ax.imshow(mask_overlay, cmap=label_cmap, interpolation=\"nearest\",\n vmin=0, vmax=n_labels - 1)\n if row == 0:\n n_spots = len(extract_instance_table(l, v))\n ax.set_title(f\"{s.scan_name} — {n_spots} spots\", fontsize=11)\n if col == 0:\n ax.set_ylabel(f\"MIP {title}\\n{ylabel}\", fontsize=9)\n ax.tick_params(labelsize=7)\n\nplt.suptitle(\"Segmented masks — tri-axis label projection (intensity-argmax)\", y=1.01, fontsize=13)\nplt.tight_layout()\nplt.show()", "metadata": {}, "execution_count": null, "outputs": [] @@ -230,7 +230,7 @@ "id": "6fc3abf4", "metadata": {}, "source": [ - "## 3 \u2014 Week 3: physics-only tracking\n", + "## 3 — Week 3: physics-only tracking\n", "\n", "`PositionShapeCost` combines squared centroid distance with squared eigenvalue distance; `build_tracks` runs pairwise Hungarian assignments and stitches them into a NetworkX `DiGraph` with `TrackEvent` annotations." ] @@ -286,7 +286,7 @@ " axes[1].plot(scan_ids, [r[\"centroid_chi\"] for r in obs], \"o-\")\n", " axes[2].plot(scan_ids, [r[\"centroid_d\"] for r in obs], \"o-\")\n", "\n", - "for ax, lbl in zip(axes, [\"centroid_mu (\u03bc)\", \"centroid_chi (\u03c7)\", \"centroid_d (d)\"]):\n", + "for ax, lbl in zip(axes, [\"centroid_mu (μ)\", \"centroid_chi (χ)\", \"centroid_d (d)\"]):\n", " ax.set_xlabel(\"scan index\")\n", " ax.set_ylabel(lbl)\n", " ax.set_xticks(range(len(scans)))\n", @@ -326,7 +326,7 @@ "cell_type": "markdown", "id": "f88c035a", "metadata": {}, - "source": "### Reading these plots honestly\n\n* **Spot counts are stabilised across scans** thanks to the rolling-median threshold smoother. Without it, independent per-scan Otsu thresholds produce wildly different counts (11 / 22 / 36) on scans taken seconds apart. The smoothed thresholds converge to the local median, giving consistent segmentation.\n* **Markers don't always sit on the brightest MIP pixel.** The trajectory points are the intensity-weighted 3-D centroid of an instance, not the argmax of the 2-D projection. For anisotropic spots the two can be a few voxels apart.\n* **Identity can drift.** Hungarian assignment minimises pairwise cost \u2014 it doesn't know whether `T3` in scan 2 *is* the same crystallite as `T3` in scan 1 unless the geometry cost says so. That's the motivation for the semantic term in Section 5." + "source": "### Reading these plots honestly\n\n* **Spot counts are stabilised across scans** thanks to the rolling-median threshold smoother. Without it, independent per-scan Otsu thresholds produce wildly different counts (11 / 22 / 36) on scans taken seconds apart. The smoothed thresholds converge to the local median, giving consistent segmentation.\n* **Markers don't always sit on the brightest MIP pixel.** The trajectory points are the intensity-weighted 3-D centroid of an instance, not the argmax of the 2-D projection. For anisotropic spots the two can be a few voxels apart.\n* **Identity can drift.** Hungarian assignment minimises pairwise cost — it doesn't know whether `T3` in scan 2 *is* the same crystallite as `T3` in scan 1 unless the geometry cost says so. That's the motivation for the semantic term in Section 5." }, { "cell_type": "code", @@ -342,7 +342,7 @@ }, "outputs": [], "source": [ - "# Per-scan gallery: \u03bc-max-IP + label overlay + track-observation markers,\n", + "# Per-scan gallery: μ-max-IP + label overlay + track-observation markers,\n", "# with *pooled* vmin/vmax across scans so intensity changes aren't a display artefact.\n", "mips = [v.max(axis=0) for v in all_volumes]\n", "vmin = float(np.percentile(np.stack(mips), 1))\n", @@ -383,9 +383,9 @@ "id": "94bdcacb", "metadata": {}, "source": [ - "## 4 \u2014 Week 4: multi-view MIPs\n", + "## 4 — Week 4: multi-view MIPs\n", "\n", - "For each spot, crop a padded sub-volume, zero out voxels that don't belong to the instance, and take three maximum-intensity projections \u2014 one along each physical axis." + "For each spot, crop a padded sub-volume, zero out voxels that don't belong to the instance, and take three maximum-intensity projections — one along each physical axis." ] }, { @@ -411,7 +411,7 @@ "for ax, im, title in zip(\n", " axes,\n", " [mip_mu, mip_chi, mip_d],\n", - " [\"MIP along \u03bc (d \u00d7 \u03c7)\", \"MIP along \u03c7 (\u03bc \u00d7 d)\", \"MIP along d (\u03bc \u00d7 \u03c7)\"],\n", + " [\"MIP along μ (d × χ)\", \"MIP along χ (μ × d)\", \"MIP along d (μ × χ)\"],\n", "):\n", " ax.imshow(im, cmap=\"magma\")\n", " ax.set_title(title)\n", @@ -428,7 +428,7 @@ "source": [ "### Mock multi-view embeddings\n", "\n", - "`make_multiview_encoder(\"mock\")` hashes each MIP's bytes into a deterministic 384-d unit vector. The real Dinov2 backend is drop-in (`\"torch\"` or `\"auto\"`) but needs `torch`+`transformers` weights. The mock is **deterministic per cell geometry** but not physically meaningful \u2014 it's useful for plumbing, CI, and seeing the cost-function interface react to real data without GPU dependencies." + "`make_multiview_encoder(\"mock\")` hashes each MIP's bytes into a deterministic 384-d unit vector. The real Dinov2 backend is drop-in (`\"torch\"` or `\"auto\"`) but needs `torch`+`transformers` weights. The mock is **deterministic per cell geometry** but not physically meaningful — it's useful for plumbing, CI, and seeing the cost-function interface react to real data without GPU dependencies." ] }, { @@ -459,7 +459,7 @@ " embed_spots(v, l, f) for v, l, f in zip(all_volumes, all_labels, all_features)\n", "]\n", "for s, emb in zip(scans, embeddings_per_scan):\n", - " print(f\"{s.scan_name}: {emb.shape[0]} spots \u00d7 {emb.shape[1]}-d embeddings\")" + " print(f\"{s.scan_name}: {emb.shape[0]} spots × {emb.shape[1]}-d embeddings\")" ] }, { @@ -500,7 +500,7 @@ { "cell_type": "code", "id": "58b1570e", - "source": "# MIP gallery: top-K brightest spots per scan, 3 MIP views each.\nTOP_K = 5\n\nfig, axes = plt.subplots(\n len(scans), TOP_K * 3, figsize=(TOP_K * 4.5, len(scans) * 2.2),\n gridspec_kw={\"wspace\": 0.05, \"hspace\": 0.35},\n)\n\nfor scan_idx, (s, v, l, feats) in enumerate(zip(scans, all_volumes, all_labels, all_features)):\n ranked = sorted(feats, key=lambda r: r[\"integrated_intensity\"], reverse=True)[:TOP_K]\n for spot_j, row in enumerate(ranked):\n masked, _ = crop_spot_cube(v, l, int(row[\"label\"]), row, margin=3)\n mips = orthogonal_mips(masked)\n for k, (im, view) in enumerate(zip(mips, [\"mu\", \"chi\", \"d\"])):\n ax = axes[scan_idx, spot_j * 3 + k]\n ax.imshow(im, cmap=\"magma\", aspect=\"auto\")\n ax.set_xticks([]); ax.set_yticks([])\n if scan_idx == 0:\n ax.set_title(f\"L{int(row['label'])}\\n{view}\", fontsize=8)\n else:\n ax.set_title(f\"L{int(row['label'])} {view}\", fontsize=8)\n # Label the row\n axes[scan_idx, 0].set_ylabel(s.scan_name, fontsize=10, rotation=0, labelpad=50, va=\"center\")\n\nplt.suptitle(f\"Semantic MIP fingerprints \u2014 top {TOP_K} spots per scan\", fontsize=13, y=1.02)\nplt.show()", + "source": "# MIP gallery: top-K brightest spots per scan, 3 MIP views each.\nTOP_K = 5\n\nfig, axes = plt.subplots(\n len(scans), TOP_K * 3, figsize=(TOP_K * 4.5, len(scans) * 2.2),\n gridspec_kw={\"wspace\": 0.05, \"hspace\": 0.35},\n)\n\nfor scan_idx, (s, v, l, feats) in enumerate(zip(scans, all_volumes, all_labels, all_features)):\n ranked = sorted(feats, key=lambda r: r[\"integrated_intensity\"], reverse=True)[:TOP_K]\n for spot_j, row in enumerate(ranked):\n masked, _ = crop_spot_cube(v, l, int(row[\"label\"]), row, margin=3)\n mips = orthogonal_mips(masked)\n for k, (im, view) in enumerate(zip(mips, [\"mu\", \"chi\", \"d\"])):\n ax = axes[scan_idx, spot_j * 3 + k]\n ax.imshow(im, cmap=\"magma\", aspect=\"auto\")\n ax.set_xticks([]); ax.set_yticks([])\n if scan_idx == 0:\n ax.set_title(f\"L{int(row['label'])}\\n{view}\", fontsize=8)\n else:\n ax.set_title(f\"L{int(row['label'])} {view}\", fontsize=8)\n # Label the row\n axes[scan_idx, 0].set_ylabel(s.scan_name, fontsize=10, rotation=0, labelpad=50, va=\"center\")\n\nplt.suptitle(f\"Semantic MIP fingerprints — top {TOP_K} spots per scan\", fontsize=13, y=1.02)\nplt.show()", "metadata": {}, "execution_count": null, "outputs": [] @@ -508,13 +508,13 @@ { "cell_type": "markdown", "id": "f77a0c5f", - "source": "### Embedding space \u2014 PCA projection\n\nThe 384-d embedding vectors projected onto their first two principal components, coloured by scan. In a well-trained encoder, matched spots across scans would cluster together; with the mock backend the points are scattered (hash-based, no semantic structure).", + "source": "### Embedding space — PCA projection\n\nThe 384-d embedding vectors projected onto their first two principal components, coloured by scan. In a well-trained encoder, matched spots across scans would cluster together; with the mock backend the points are scattered (hash-based, no semantic structure).", "metadata": {} }, { "cell_type": "code", "id": "49abcaba", - "source": "# PCA of all embeddings, coloured by scan and sized by integrated intensity.\nall_emb = np.concatenate(embeddings_per_scan, axis=0)\nscan_ids = np.concatenate([np.full(e.shape[0], i) for i, e in enumerate(embeddings_per_scan)])\nintensities = np.concatenate([\n np.array([r[\"integrated_intensity\"] for r in feats])\n for feats in all_features\n])\n\n# Centre and project onto top-2 PCs.\nmu = all_emb.mean(axis=0, keepdims=True)\ncentred = all_emb - mu\nU, S, Vt = np.linalg.svd(centred, full_matrices=False)\npc = centred @ Vt[:2].T\nvar_explained = S[:2] ** 2 / (S ** 2).sum() * 100\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n# Left: PCA coloured by scan\nscan_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\"]\nfor i, s in enumerate(scans):\n mask = scan_ids == i\n ax1.scatter(pc[mask, 0], pc[mask, 1], c=scan_colors[i], label=s.scan_name,\n s=40 + 120 * (intensities[mask] / intensities.max()), alpha=0.7,\n edgecolors=\"k\", linewidths=0.3)\nax1.set_xlabel(f\"PC1 ({var_explained[0]:.1f}% var)\")\nax1.set_ylabel(f\"PC2 ({var_explained[1]:.1f}% var)\")\nax1.legend(fontsize=9)\nax1.set_title(\"Embedding space (PCA) \u2014 by scan\")\n\n# Right: cross-scan similarity matrix (all scans)\nsim_all = all_emb @ all_emb.T\nim = ax2.imshow(sim_all, cmap=\"RdBu_r\", vmin=-1, vmax=1)\n# Draw scan boundaries\ncum = 0\nfor emb in embeddings_per_scan:\n cum += emb.shape[0]\n ax2.axhline(cum - 0.5, color=\"white\", lw=1)\n ax2.axvline(cum - 0.5, color=\"white\", lw=1)\nax2.set_title(\"Pairwise cosine similarity (all spots)\")\nax2.set_xlabel(\"spot index (all scans)\")\nax2.set_ylabel(\"spot index (all scans)\")\nplt.colorbar(im, ax=ax2, shrink=0.8)\n\nplt.tight_layout()\nplt.show()", + "source": "# PCA of all embeddings, coloured by scan and sized by integrated intensity.\nall_emb = np.concatenate(embeddings_per_scan, axis=0)\nscan_ids = np.concatenate([np.full(e.shape[0], i) for i, e in enumerate(embeddings_per_scan)])\nintensities = np.concatenate([\n np.array([r[\"integrated_intensity\"] for r in feats])\n for feats in all_features\n])\n\n# Centre and project onto top-2 PCs.\nmu = all_emb.mean(axis=0, keepdims=True)\ncentred = all_emb - mu\nU, S, Vt = np.linalg.svd(centred, full_matrices=False)\npc = centred @ Vt[:2].T\nvar_explained = S[:2] ** 2 / (S ** 2).sum() * 100\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))\n\n# Left: PCA coloured by scan\nscan_colors = [\"#e41a1c\", \"#377eb8\", \"#4daf4a\"]\nfor i, s in enumerate(scans):\n mask = scan_ids == i\n ax1.scatter(pc[mask, 0], pc[mask, 1], c=scan_colors[i], label=s.scan_name,\n s=40 + 120 * (intensities[mask] / intensities.max()), alpha=0.7,\n edgecolors=\"k\", linewidths=0.3)\nax1.set_xlabel(f\"PC1 ({var_explained[0]:.1f}% var)\")\nax1.set_ylabel(f\"PC2 ({var_explained[1]:.1f}% var)\")\nax1.legend(fontsize=9)\nax1.set_title(\"Embedding space (PCA) — by scan\")\n\n# Right: cross-scan similarity matrix (all scans)\nsim_all = all_emb @ all_emb.T\nim = ax2.imshow(sim_all, cmap=\"RdBu_r\", vmin=-1, vmax=1)\n# Draw scan boundaries\ncum = 0\nfor emb in embeddings_per_scan:\n cum += emb.shape[0]\n ax2.axhline(cum - 0.5, color=\"white\", lw=1)\n ax2.axvline(cum - 0.5, color=\"white\", lw=1)\nax2.set_title(\"Pairwise cosine similarity (all spots)\")\nax2.set_xlabel(\"spot index (all scans)\")\nax2.set_ylabel(\"spot index (all scans)\")\nplt.colorbar(im, ax=ax2, shrink=0.8)\n\nplt.tight_layout()\nplt.show()", "metadata": {}, "execution_count": null, "outputs": [] @@ -523,7 +523,7 @@ "cell_type": "markdown", "id": "35122be4", "metadata": {}, - "source": "## 5 \u2014 Geometry + semantic tracking\n\n`GeometrySemanticCost(geometry, cost_alpha, cost_beta)` wraps the geometry cost and adds `\u03b2 \u00b7 (1 \u2212 cos(f_i, f_j))`. With the **mock** encoder the embeddings are hash-based pseudo-vectors with no real semantic structure, so adding them to the cost doesn't improve (or noticeably worsen) tracking on this well-separated data. The metric tables below will show identical fragmentation for physics-only vs. geometry+semantic \u2014 that's expected with mock embeddings on easy data.\n\nThe payoff of the semantic term shows up in Section 7's synthetic crossing, where geometry alone makes the wrong assignment and the embedding corrects it. With real DINOv2 features on ambiguous real data, the same knob reduces ID swaps." + "source": "## 5 — Geometry + semantic tracking\n\n`GeometrySemanticCost(geometry, cost_alpha, cost_beta)` wraps the geometry cost and adds `β · (1 − cos(f_i, f_j))`. With the **mock** encoder the embeddings are hash-based pseudo-vectors with no real semantic structure, so adding them to the cost doesn't improve (or noticeably worsen) tracking on this well-separated data. The metric tables below will show identical fragmentation for physics-only vs. geometry+semantic — that's expected with mock embeddings on easy data.\n\nThe payoff of the semantic term shows up in Section 7's synthetic crossing, where geometry alone makes the wrong assignment and the embedding corrects it. With real DINOv2 features on ambiguous real data, the same knob reduces ID swaps." }, { "cell_type": "code", @@ -551,7 +551,7 @@ "\n", "comparison = pd.DataFrame({\n", " \"physics-only\": metrics_geo,\n", - " \"\u03b1=1, \u03b2=0.5 (mock)\": metrics_sem,\n", + " \"α=1, β=0.5 (mock)\": metrics_sem,\n", "})\n", "comparison" ] @@ -560,7 +560,7 @@ "cell_type": "markdown", "id": "b0aedc86", "metadata": {}, - "source": "## 6 \u2014 \u03b1/\u03b2 ablation\n\nSweep `cost_beta` with `cost_alpha=1` fixed. With the mock encoder on well-separated data, metrics stay flat (the hash-based embeddings add noise that doesn't shift the Hungarian assignment). With real DINOv2 embeddings on ambiguous data, increasing \u03b2 trades geometry confidence for identity preservation \u2014 you'd see fragmentation drop and ID-switch rate change." + "source": "## 6 — α/β ablation\n\nSweep `cost_beta` with `cost_alpha=1` fixed. With the mock encoder on well-separated data, metrics stay flat (the hash-based embeddings add noise that doesn't shift the Hungarian assignment). With real DINOv2 embeddings on ambiguous data, increasing β trades geometry confidence for identity preservation — you'd see fragmentation drop and ID-switch rate change." }, { "cell_type": "code", @@ -613,7 +613,7 @@ "ax2.set_ylabel(\"total tracks\", color=\"tab:red\")\n", "ax2.tick_params(axis=\"y\", labelcolor=\"tab:red\")\n", "\n", - "plt.title(\"\u03b1=1 fixed; \u03b2 sweep (mock embeddings)\")\n", + "plt.title(\"α=1 fixed; β sweep (mock embeddings)\")\n", "plt.tight_layout()\n", "plt.show()" ] @@ -623,7 +623,7 @@ "id": "89db0aea", "metadata": {}, "source": [ - "## 7 \u2014 Near-overlap: when semantics actually help\n", + "## 7 — Near-overlap: when semantics actually help\n", "\n", "The generic crossing scenarios never quite force a mistake because the solver sees the correct assignment is still the nearest. To make the failure concrete we construct a 2-spot, 2-scan case where the **wrong** pairing has lower geometric cost than the right one. Without embeddings the Hungarian solver happily swaps the two tracks. Giving each physical identity a distinct unit-vector embedding restores the correct assignment." ] @@ -648,9 +648,9 @@ "e_red = np.array([1.0, 0.0], dtype=np.float64)\n", "e_blue = np.array([0.0, 1.0], dtype=np.float64)\n", "\n", - "# The red spot moves from \u03bc=0.0 to \u03bc=0.3; the blue spot barely moves (0.2 \u2192 0.1).\n", - "# In scan-t+1, the \"blue_next\" at \u03bc=0.1 is *closer* to the red_prev at 0.0 than to\n", - "# the blue_prev at 0.2, so geometry-only matches {red\u2192blue, blue\u2192red}: swapped.\n", + "# The red spot moves from μ=0.0 to μ=0.3; the blue spot barely moves (0.2 → 0.1).\n", + "# In scan-t+1, the \"blue_next\" at μ=0.1 is *closer* to the red_prev at 0.0 than to\n", + "# the blue_prev at 0.2, so geometry-only matches {red→blue, blue→red}: swapped.\n", "def _row(mu: float, embedding: np.ndarray | None = None) -> dict:\n", " r = {\n", " \"label\": 1, \"voxel_count\": 10, \"integrated_intensity\": 100.0,\n", @@ -668,15 +668,15 @@ "geo_strict = PositionShapeCost(position_weight=1.0, shape_weight=0.0, gate_mu=20, gate_chi=20, gate_d=20)\n", "matches_geo, _, _ = associate_frames(scan_t, scan_t1, geo_strict)\n", "\n", - "# With embeddings: attach identity vectors and compose with \u03b2=1.\n", + "# With embeddings: attach identity vectors and compose with β=1.\n", "scan_t_sem = [_row(0.0, e_red), _row(0.2, e_blue)]\n", "scan_t1_sem = [_row(0.1, e_blue), _row(0.3, e_red)]\n", "cost_sem = GeometrySemanticCost(geo_strict, cost_alpha=1.0, cost_beta=1.0)\n", "matches_sem, _, _ = associate_frames(scan_t_sem, scan_t1_sem, cost_sem)\n", "\n", - "print(\"Truth: red (idx 0 @ t) \u2192 idx 1 (\u03bc=0.3) @ t+1\")\n", - "print(\" blue (idx 1 @ t) \u2192 idx 0 (\u03bc=0.1) @ t+1\")\n", - "print(f\"Geometry only: {dict(matches_geo)} (wrong \u2014 identities swapped)\")\n", + "print(\"Truth: red (idx 0 @ t) → idx 1 (μ=0.3) @ t+1\")\n", + "print(\" blue (idx 1 @ t) → idx 0 (μ=0.1) @ t+1\")\n", + "print(f\"Geometry only: {dict(matches_geo)} (wrong — identities swapped)\")\n", "print(f\"Geometry + semantic: {dict(matches_sem)} (correct)\")" ] }, @@ -711,8 +711,8 @@ " ax.set_title(title)\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)\n", - "_plot_pairing(axes[0], \"Geometry only \u2014 identities swapped\", matches_geo)\n", - "_plot_pairing(axes[1], \"Geometry + semantic \u2014 identities preserved\", matches_sem)\n", + "_plot_pairing(axes[0], \"Geometry only — identities swapped\", matches_geo)\n", + "_plot_pairing(axes[1], \"Geometry + semantic — identities preserved\", matches_sem)\n", "# Identity key\n", "axes[0].text(0.02, 0.28, \"red = identity 1\", color=\"tab:red\", fontsize=9)\n", "axes[0].text(0.02, 0.26, \"blue = identity 2\", color=\"tab:blue\", fontsize=9)\n", @@ -725,9 +725,9 @@ "id": "7e605de2", "metadata": {}, "source": [ - "## 8 \u2014 The same pipeline from the command line\n", + "## 8 — The same pipeline from the command line\n", "\n", - "Every library call above is exposed as a CLI \u2014 feed a dataset root and an output directory, get reproducible artifacts under `artifacts/`.\n", + "Every library call above is exposed as a CLI — feed a dataset root and an output directory, get reproducible artifacts under `artifacts/`.\n", "\n", "```bash\n", "# 1. Segment every scan under data/sample_operando/\n", @@ -736,13 +736,13 @@ "# 2. Compute mock multi-view embeddings\n", "python -m braggtrack.cli.embed_dataset --segdir artifacts/week2 --outdir artifacts/week4 --backend mock\n", "\n", - "# 3. Track with geometry + semantic cost (\u03b2=0.5)\n", + "# 3. Track with geometry + semantic cost (β=0.5)\n", "python -m braggtrack.cli.track_dataset artifacts/week2 \\\n", " --outdir artifacts/week3 \\\n", " --embedding-dir artifacts/week4 \\\n", " --cost-alpha 1.0 --cost-beta 0.5\n", "\n", - "# 4. Ablate \u03b1/\u03b2 and write a JSON report\n", + "# 4. Ablate α/β and write a JSON report\n", "python scripts/ablation_week4.py \\\n", " --indir artifacts/week2 \\\n", " --embedding-dir artifacts/week4 \\\n", @@ -776,4 +776,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From 756de34db77336f7442b9f858aae85159aaacf10 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 14:49:23 +0000 Subject: [PATCH 2/8] feat: add DINO-based segmentation backend (--method dino) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DINOv3 patch-level features → PCA → HDBSCAN clustering → 3D slice stitching via union-find, replacing hand-tuned intensity-domain parameters with learned feature-space representations that generalise across beamlines and detectors. - Add PatchFeatureEncoder protocol + MockPatchEncoder + TorchDinoPatchEncoder - Add segment_dino() pipeline with slice extraction, clustering, upsampling, 3D stitching, and Otsu foreground masking - Wire --method classical|dino flag and DINO-specific CLI args - Add scikit-learn>=1.3 to core dependencies (for HDBSCAN) - Add 16 tests covering mock encoder, clustering, stitching, and end-to-end https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/cli/segment_dataset.py | 62 +++++- braggtrack/segmentation/__init__.py | 3 + braggtrack/segmentation/dino_segment.py | 251 ++++++++++++++++++++++++ braggtrack/semantic/__init__.py | 10 +- braggtrack/semantic/dino.py | 109 +++++++++- pyproject.toml | 1 + tests/test_segmentation_dino.py | 153 +++++++++++++++ 7 files changed, 573 insertions(+), 16 deletions(-) create mode 100644 braggtrack/segmentation/dino_segment.py create mode 100644 tests/test_segmentation_dino.py diff --git a/braggtrack/cli/segment_dataset.py b/braggtrack/cli/segment_dataset.py index a90d08f..60a292f 100644 --- a/braggtrack/cli/segment_dataset.py +++ b/braggtrack/cli/segment_dataset.py @@ -1,4 +1,4 @@ -"""Run classical segmentation over discovered scan files and write artifacts.""" +"""Run segmentation over discovered scan files and write artifacts.""" from __future__ import annotations @@ -23,6 +23,7 @@ relabel_sequential, remove_small_objects, segment_classical, + segment_dino, ) @@ -35,6 +36,12 @@ def build_parser() -> argparse.ArgumentParser: help="Dataset root with scan folders (default: data/sample_operando if present, else .)", ) parser.add_argument("--outdir", default="artifacts/week2", help="Output artifact directory") + parser.add_argument( + "--method", + choices=["classical", "dino"], + default="classical", + help="Segmentation method: classical (LoG + watershed) or dino (DINOv3 features + HDBSCAN)", + ) parser.add_argument("--blur-passes", type=int, default=1) parser.add_argument("--seed-separation", type=int, default=2) parser.add_argument("--h-value", type=float, default=0.1) @@ -63,6 +70,27 @@ def build_parser() -> argparse.ArgumentParser: default=15.0, help="Merge adjacent labels whose centroids are within this many voxels (0 disables)", ) + # DINO-specific arguments + parser.add_argument( + "--dino-backend", + choices=["auto", "mock", "torch"], + default=None, + help="DINO backend (default: auto-detect torch, fall back to mock)", + ) + parser.add_argument( + "--dino-model", + default="facebook/dinov3-vitb16-pretrain-lvd1689m", + help="HuggingFace model ID for the DINO torch backend", + ) + parser.add_argument("--dino-pca-components", type=int, default=16, help="PCA components for DINO feature reduction") + parser.add_argument("--dino-min-cluster-size", type=int, default=3, help="HDBSCAN min_cluster_size") + parser.add_argument("--dino-min-samples", type=int, default=2, help="HDBSCAN min_samples") + parser.add_argument( + "--dino-min-overlap", + type=float, + default=0.3, + help="Min overlap fraction for 3D slice stitching", + ) return parser @@ -134,15 +162,28 @@ def main() -> int: raw_threshold = otsu_threshold(volume.ravel()) threshold = raw_threshold * float(args.threshold_fraction) - result = segment_classical( - volume, - threshold=threshold, - blur_passes=max(1, args.blur_passes), - h_value=float(args.h_value), - min_seed_separation=max(1, args.seed_separation), - seed_peak_fraction=float(args.seed_peak_fraction), - seed_response_percentile=float(args.seed_response_percentile), - ) + + if args.method == "dino": + result = segment_dino( + volume, + backend=args.dino_backend, + model_name=args.dino_model, + n_components_pca=args.dino_pca_components, + min_cluster_size=args.dino_min_cluster_size, + min_samples=args.dino_min_samples, + threshold_fraction=float(args.threshold_fraction), + min_overlap_fraction=args.dino_min_overlap, + ) + else: + result = segment_classical( + volume, + threshold=threshold, + blur_passes=max(1, args.blur_passes), + h_value=float(args.h_value), + min_seed_separation=max(1, args.seed_separation), + seed_peak_fraction=float(args.seed_peak_fraction), + seed_response_percentile=float(args.seed_response_percentile), + ) labels = remove_small_objects(result.labeled_volume, min_size=max(1, args.min_size)) binary = labels > 0 @@ -161,6 +202,7 @@ def main() -> int: "scan": scan.scan_name, "file": str(scan.path), "source": source, + "method": args.method, "threshold": raw_threshold, "threshold_fraction": args.threshold_fraction, "effective_threshold": threshold, diff --git a/braggtrack/segmentation/__init__.py b/braggtrack/segmentation/__init__.py index a62c731..7c54b98 100644 --- a/braggtrack/segmentation/__init__.py +++ b/braggtrack/segmentation/__init__.py @@ -9,6 +9,7 @@ segment_classical, watershed_from_seeds, ) +from .dino_segment import DinoSegmentationResult, segment_dino from .features import extract_instance_table from .otsu import flag_outlier_frames, otsu_threshold, smooth_thresholds from .pipeline import SegmentationResult, connected_components_3d, segment_volume @@ -17,6 +18,7 @@ __all__ = [ "ClassicalSegmentationResult", + "DinoSegmentationResult", "SegmentationResult", "connected_components_3d", "extract_instance_table", @@ -34,6 +36,7 @@ "relabel_sequential", "remove_small_objects", "segment_classical", + "segment_dino", "segment_volume", "watershed_from_seeds", ] diff --git a/braggtrack/segmentation/dino_segment.py b/braggtrack/segmentation/dino_segment.py new file mode 100644 index 0000000..1963151 --- /dev/null +++ b/braggtrack/segmentation/dino_segment.py @@ -0,0 +1,251 @@ +"""DINO-based 3D segmentation: patch feature clustering + 3D stitching. + +Replaces the hand-tuned LoG + watershed pipeline with frozen DINOv3 patch +features clustered via HDBSCAN, producing instance labels that generalise +across beamlines and detectors without per-instrument parameter tuning. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass(frozen=True) +class DinoSegmentationResult: + """Output of :func:`segment_dino`, field-compatible with ClassicalSegmentationResult.""" + + threshold: float + seed_count: int + component_count: int + labeled_volume: np.ndarray + response: np.ndarray + + +def _extract_slice_features( + volume: np.ndarray, + encoder: object, + *, + axis: int = 0, +) -> tuple[np.ndarray, tuple[int, int]]: + """Extract patch features for each slice along *axis*. + + Returns ``(features, slice_hw)`` where *features* has shape + ``(n_slices, H_patches, W_patches, D)`` and *slice_hw* is the + pixel-level ``(H, W)`` of each slice. + """ + n_slices = volume.shape[axis] + feature_maps: list[np.ndarray] = [] + for i in range(n_slices): + slc = np.take(volume, i, axis=axis) + fmap = encoder.extract_patch_features(slc) # type: ignore[union-attr] + feature_maps.append(fmap) + + slice_hw = (volume.shape[1] if axis == 0 else volume.shape[0], volume.shape[2] if axis != 2 else volume.shape[1]) + return np.stack(feature_maps, axis=0), slice_hw + + +def _cluster_feature_map( + features: np.ndarray, + *, + n_components_pca: int = 16, + min_cluster_size: int = 3, + min_samples: int = 2, +) -> np.ndarray: + """Cluster a 2D feature map ``(H_p, W_p, D)`` into instance labels. + + Uses PCA dimensionality reduction followed by HDBSCAN. + Returns ``(H_p, W_p)`` int array, 0 = background/noise. + """ + from sklearn.cluster import HDBSCAN + from sklearn.decomposition import PCA + + h_p, w_p, d = features.shape + flat = features.reshape(-1, d) + + n_comp = min(n_components_pca, d, flat.shape[0]) + if n_comp < 2: + return np.zeros((h_p, w_p), dtype=np.int32) + + reduced = PCA(n_components=n_comp).fit_transform(flat) + + clusterer = HDBSCAN( + min_cluster_size=max(2, min_cluster_size), + min_samples=max(1, min_samples), + ) + raw_labels = clusterer.fit_predict(reduced) + # HDBSCAN labels: -1 = noise, 0..K = clusters. Shift to 1-based. + labels = np.where(raw_labels >= 0, raw_labels + 1, 0) + return labels.reshape(h_p, w_p).astype(np.int32) + + +def _upsample_labels( + patch_labels: np.ndarray, + target_shape: tuple[int, int], + patch_size: int, +) -> np.ndarray: + """Nearest-neighbor upsample patch-resolution labels to pixel resolution.""" + h_p, w_p = patch_labels.shape + out = np.zeros(target_shape, dtype=np.int32) + for py in range(h_p): + for px in range(w_p): + y0 = py * patch_size + x0 = px * patch_size + y1 = min(y0 + patch_size, target_shape[0]) + x1 = min(x0 + patch_size, target_shape[1]) + out[y0:y1, x0:x1] = patch_labels[py, px] + return out + + +def _stitch_slices_3d( + per_slice_labels: list[np.ndarray], + *, + min_overlap_fraction: float = 0.3, +) -> np.ndarray: + """Stitch 2D per-slice labels into a consistent 3D label volume. + + Two labels on adjacent slices are merged when their spatial overlap + exceeds *min_overlap_fraction* of the smaller region. Uses a + union-find for global relabelling. + """ + if not per_slice_labels: + return np.zeros((0, 0, 0), dtype=np.int32) + + h, w = per_slice_labels[0].shape + n_slices = len(per_slice_labels) + + # Assign globally unique label offsets per slice. + offset = 0 + global_slices: list[np.ndarray] = [] + for sl in per_slice_labels: + shifted = np.where(sl > 0, sl + offset, 0) + global_slices.append(shifted.astype(np.int32)) + mx = int(sl.max()) + offset += mx + + total_labels = offset + 1 + + # Union-find. + parent = list(range(total_labels)) + + def find(x: int) -> int: + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a: int, b: int) -> None: + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + # Merge labels across adjacent slices by overlap. + for i in range(n_slices - 1): + sl_a = global_slices[i] + sl_b = global_slices[i + 1] + pairs_a = sl_a.ravel() + pairs_b = sl_b.ravel() + # Only look at pixels where both slices have a label. + mask = (pairs_a > 0) & (pairs_b > 0) + if not mask.any(): + continue + + la = pairs_a[mask] + lb = pairs_b[mask] + unique_pairs, counts = np.unique(np.stack([la, lb], axis=1), axis=0, return_counts=True) + for (lid_a, lid_b), cnt in zip(unique_pairs, counts): + size_a = int(np.count_nonzero(sl_a == lid_a)) + size_b = int(np.count_nonzero(sl_b == lid_b)) + min_size = min(size_a, size_b) + if min_size > 0 and cnt / min_size >= min_overlap_fraction: + union(int(lid_a), int(lid_b)) + + # Flatten union-find and relabel sequentially. + volume_3d = np.stack(global_slices, axis=0) + flat = volume_3d.ravel() + root_map = np.zeros(total_labels, dtype=np.int32) + for lbl in range(total_labels): + root_map[lbl] = find(lbl) + + flat = root_map[flat] + # Relabel to sequential. + unique_roots = np.unique(flat[flat > 0]) + new_map = np.zeros(total_labels, dtype=np.int32) + for new_id, old_root in enumerate(unique_roots, start=1): + new_map[old_root] = new_id + + flat = new_map[flat] + return flat.reshape(n_slices, h, w).astype(np.int32) + + +def segment_dino( + volume: np.ndarray, + *, + backend: str | None = None, + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", + torch_device: str | None = None, + n_components_pca: int = 16, + min_cluster_size: int = 3, + min_samples: int = 2, + threshold_fraction: float = 1.0, + min_overlap_fraction: float = 0.3, + axis: int = 0, +) -> DinoSegmentationResult: + """Segment a 3D volume using DINOv3 patch-level features + HDBSCAN. + + Parameters + ---------- + volume + Raw 3-D intensity cube (z, y, x), typically float64. + backend + DINO backend: ``"auto"``, ``"mock"``, or ``"torch"``. + model_name + HuggingFace model ID for the torch backend. + n_components_pca + Number of PCA components for dimensionality reduction. + min_cluster_size, min_samples + HDBSCAN density parameters. + threshold_fraction + Multiply Otsu threshold by this for foreground masking. + min_overlap_fraction + Minimum overlap for stitching 2D labels across slices. + axis + Axis to slice along (0 = mu/z, typically the narrowest). + """ + from braggtrack.segmentation.otsu import otsu_threshold + from braggtrack.semantic.dino import make_patch_encoder + + volume = np.asarray(volume, dtype=np.float64) + encoder = make_patch_encoder(backend, model_name=model_name, torch_device=torch_device) + + raw_threshold = otsu_threshold(volume.ravel()) + threshold = raw_threshold * threshold_fraction + + features, slice_hw = _extract_slice_features(volume, encoder, axis=axis) + + per_slice_labels: list[np.ndarray] = [] + for i in range(features.shape[0]): + patch_labels = _cluster_feature_map( + features[i], + n_components_pca=n_components_pca, + min_cluster_size=min_cluster_size, + min_samples=min_samples, + ) + full_labels = _upsample_labels(patch_labels, slice_hw, encoder.patch_size) + per_slice_labels.append(full_labels) + + labels_3d = _stitch_slices_3d(per_slice_labels, min_overlap_fraction=min_overlap_fraction) + + # Apply foreground mask from raw intensity. + foreground = volume >= threshold + labels_3d = np.where(foreground, labels_3d, 0).astype(np.int32) + + component_count = len(np.unique(labels_3d[labels_3d > 0])) + return DinoSegmentationResult( + threshold=threshold, + seed_count=component_count, + component_count=component_count, + labeled_volume=labels_3d, + response=np.zeros_like(volume), + ) diff --git a/braggtrack/semantic/__init__.py b/braggtrack/semantic/__init__.py index c014a37..cce3e96 100644 --- a/braggtrack/semantic/__init__.py +++ b/braggtrack/semantic/__init__.py @@ -1,6 +1,12 @@ """Week 4 multi-view semantic features (orthogonal MIPs + frozen ViT embeddings).""" -from .dino import embed_multiview_mips, make_multiview_encoder +from .dino import embed_multiview_mips, make_multiview_encoder, make_patch_encoder from .mips import crop_spot_cube, orthogonal_mips -__all__ = ["crop_spot_cube", "embed_multiview_mips", "make_multiview_encoder", "orthogonal_mips"] +__all__ = [ + "crop_spot_cube", + "embed_multiview_mips", + "make_multiview_encoder", + "make_patch_encoder", + "orthogonal_mips", +] diff --git a/braggtrack/semantic/dino.py b/braggtrack/semantic/dino.py index 458fd66..0c8c27e 100644 --- a/braggtrack/semantic/dino.py +++ b/braggtrack/semantic/dino.py @@ -1,11 +1,10 @@ -"""Frozen DINO-style embeddings for three orthogonal MIPs (Week 4). +"""Frozen DINO-style embeddings and patch-level feature extraction. Backend is selected with env ``BRAGGTRACK_DINO_BACKEND``: -* ``mock`` — deterministic CPU-only vector from image bytes (default when +* ``mock`` — deterministic CPU-only vectors from image bytes (default when PyTorch is unavailable). -* ``torch`` — Hugging Face ``facebook/dinov2-small`` (one CLS embedding per view, - concatenated and L2-normalised). +* ``torch`` — Hugging Face DINOv3/v2 model (CLS or patch tokens). * ``auto`` — use ``torch`` if import succeeds, else ``mock``. """ @@ -131,6 +130,108 @@ def make_multiview_encoder( return TorchDinoMultiviewEncoder(model_name, torch_device) +# --------------------------------------------------------------------------- +# Patch-level feature extraction (for DINO-based segmentation) +# --------------------------------------------------------------------------- + + +class PatchFeatureEncoder(Protocol): + """Interface for extracting spatially-resolved patch features.""" + + def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray: + """Return (H_patches, W_patches, D_features) feature map from a 2D grayscale image.""" + ... + + @property + def patch_size(self) -> int: ... + + @property + def feature_dim(self) -> int: ... + + +class MockPatchEncoder: + """Deterministic hash-based patch features for CI (no GPU).""" + + @property + def patch_size(self) -> int: + return 14 + + @property + def feature_dim(self) -> int: + return 384 + + def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray: + h, w = image_2d.shape[:2] + h_p = max(1, h // self.patch_size) + w_p = max(1, w // self.patch_size) + seed = int.from_bytes( + hashlib.sha256(np.asarray(image_2d, dtype=np.float32).tobytes()).digest()[:8], + "little", + signed=False, + ) + rng = np.random.default_rng(seed) + features = rng.standard_normal((h_p, w_p, self.feature_dim)).astype(np.float32) + norms = np.linalg.norm(features, axis=-1, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) + return (features / norms).astype(np.float32) + + +class TorchDinoPatchEncoder: + """Extracts DINOv2/v3 patch tokens as a spatial feature map.""" + + def __init__(self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", device: str | None = None) -> None: + import torch + from transformers import AutoImageProcessor, AutoModel + + self._torch = torch + self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self._proc = AutoImageProcessor.from_pretrained(model_name) + self._model = AutoModel.from_pretrained(model_name) + self._model.eval() + self._model.to(self._device) + self._patch_size = getattr(self._model.config, "patch_size", 14) + self._feature_dim = int(self._model.config.hidden_size) + + @property + def patch_size(self) -> int: + return self._patch_size + + @property + def feature_dim(self) -> int: + return self._feature_dim + + def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray: + rgb = _mips_to_rgb_uint8(image_2d) + with self._torch.no_grad(): + inputs = self._proc(images=[rgb], return_tensors="pt") + inputs = {k: v.to(self._device) for k, v in inputs.items()} + out = self._model(**inputs) + patch_tokens = out.last_hidden_state[:, 1:, :].squeeze(0) + h_img = inputs["pixel_values"].shape[2] + w_img = inputs["pixel_values"].shape[3] + h_p = h_img // self._patch_size + w_p = w_img // self._patch_size + features = patch_tokens[: h_p * w_p].reshape(h_p, w_p, -1) + x = features.float().cpu().numpy() + norms = np.linalg.norm(x, axis=-1, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) + return (x / norms).astype(np.float32) + + +def make_patch_encoder( + backend: BackendName | None = None, + *, + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", + torch_device: str | None = None, +) -> PatchFeatureEncoder: + """Construct a reusable patch-level encoder.""" + req = _requested_backend(backend) + use = _resolve_backend(req) + if use == "mock": + return MockPatchEncoder() + return TorchDinoPatchEncoder(model_name, torch_device) + + def embed_multiview_mips( mip_mu: np.ndarray, mip_chi: np.ndarray, diff --git a/pyproject.toml b/pyproject.toml index 49b9c17..2fbd8b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "numpy>=1.26", "scipy>=1.12", "scikit-image>=0.22", + "scikit-learn>=1.3", "networkx>=3.2", "h5py>=3.10", ] diff --git a/tests/test_segmentation_dino.py b/tests/test_segmentation_dino.py new file mode 100644 index 0000000..f29b1aa --- /dev/null +++ b/tests/test_segmentation_dino.py @@ -0,0 +1,153 @@ +"""Tests for DINO-based segmentation (mock backend, no GPU).""" + +import os +import unittest + +import numpy as np + +os.environ.setdefault("BRAGGTRACK_DINO_BACKEND", "mock") + +from braggtrack.segmentation.dino_segment import ( + DinoSegmentationResult, + _cluster_feature_map, + _stitch_slices_3d, + _upsample_labels, + segment_dino, +) +from braggtrack.semantic.dino import MockPatchEncoder, make_patch_encoder + + +class TestMockPatchEncoder(unittest.TestCase): + def test_output_shape(self) -> None: + enc = MockPatchEncoder() + img = np.random.default_rng(42).standard_normal((56, 56)) + features = enc.extract_patch_features(img) + self.assertEqual(features.shape, (4, 4, 384)) + + def test_deterministic(self) -> None: + enc = MockPatchEncoder() + img = np.ones((28, 28)) + a = enc.extract_patch_features(img) + b = enc.extract_patch_features(img) + np.testing.assert_array_equal(a, b) + + def test_l2_normalized(self) -> None: + enc = MockPatchEncoder() + img = np.random.default_rng(7).standard_normal((42, 42)) + features = enc.extract_patch_features(img) + norms = np.linalg.norm(features, axis=-1) + np.testing.assert_allclose(norms, 1.0, atol=1e-5) + + +class TestMakePatchEncoder(unittest.TestCase): + def test_mock_backend(self) -> None: + enc = make_patch_encoder("mock") + self.assertIsInstance(enc, MockPatchEncoder) + self.assertEqual(enc.patch_size, 14) + self.assertEqual(enc.feature_dim, 384) + + +class TestClusterFeatureMap(unittest.TestCase): + def test_basic_clustering(self) -> None: + rng = np.random.default_rng(99) + features = np.zeros((6, 6, 32), dtype=np.float32) + features[:3, :, :] = rng.standard_normal((3, 6, 32)) + 5.0 + features[3:, :, :] = rng.standard_normal((3, 6, 32)) - 5.0 + labels = _cluster_feature_map(features, n_components_pca=8, min_cluster_size=3, min_samples=1) + self.assertEqual(labels.shape, (6, 6)) + self.assertGreaterEqual(len(np.unique(labels[labels > 0])), 1) + + def test_tiny_input_returns_zeros(self) -> None: + features = np.ones((1, 1, 4), dtype=np.float32) + labels = _cluster_feature_map(features, n_components_pca=2) + self.assertEqual(labels.shape, (1, 1)) + + +class TestUpsampleLabels(unittest.TestCase): + def test_basic_upsample(self) -> None: + patch_labels = np.array([[1, 2], [3, 0]], dtype=np.int32) + out = _upsample_labels(patch_labels, target_shape=(28, 28), patch_size=14) + self.assertEqual(out.shape, (28, 28)) + self.assertEqual(out[0, 0], 1) + self.assertEqual(out[0, 14], 2) + self.assertEqual(out[14, 0], 3) + self.assertEqual(out[14, 14], 0) + + def test_edge_handling(self) -> None: + patch_labels = np.array([[1]], dtype=np.int32) + out = _upsample_labels(patch_labels, target_shape=(10, 10), patch_size=14) + self.assertEqual(out.shape, (10, 10)) + self.assertTrue(np.all(out == 1)) + + +class TestStitchSlices3D(unittest.TestCase): + def test_empty_input(self) -> None: + result = _stitch_slices_3d([]) + self.assertEqual(result.shape, (0, 0, 0)) + + def test_single_slice(self) -> None: + sl = np.array([[1, 0], [0, 2]], dtype=np.int32) + result = _stitch_slices_3d([sl]) + self.assertEqual(result.shape, (1, 2, 2)) + self.assertEqual(len(np.unique(result[result > 0])), 2) + + def test_overlapping_slices_merge(self) -> None: + sl1 = np.array([[1, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=np.int32) + sl2 = np.array([[2, 2, 0], [2, 0, 0], [0, 0, 0]], dtype=np.int32) + result = _stitch_slices_3d([sl1, sl2], min_overlap_fraction=0.3) + self.assertEqual(result.shape, (2, 3, 3)) + labels_s0 = result[0][result[0] > 0] + labels_s1 = result[1][result[1] > 0] + self.assertEqual(len(np.unique(labels_s0)), 1) + self.assertEqual(len(np.unique(labels_s1)), 1) + self.assertEqual(np.unique(labels_s0)[0], np.unique(labels_s1)[0]) + + def test_non_overlapping_stay_separate(self) -> None: + sl1 = np.array([[1, 0], [0, 0]], dtype=np.int32) + sl2 = np.array([[0, 0], [0, 2]], dtype=np.int32) + result = _stitch_slices_3d([sl1, sl2], min_overlap_fraction=0.3) + labels_s0 = result[0][result[0] > 0] + labels_s1 = result[1][result[1] > 0] + self.assertTrue(len(labels_s0) > 0) + self.assertTrue(len(labels_s1) > 0) + self.assertNotEqual(np.unique(labels_s0)[0], np.unique(labels_s1)[0]) + + +class TestSegmentDino(unittest.TestCase): + def test_mock_backend_runs(self) -> None: + rng = np.random.default_rng(42) + volume = rng.standard_normal((6, 28, 28)).astype(np.float64) + volume[2:4, 10:18, 10:18] += 10.0 + result = segment_dino(volume, backend="mock") + self.assertIsInstance(result, DinoSegmentationResult) + self.assertEqual(result.labeled_volume.shape, volume.shape) + self.assertEqual(result.response.shape, volume.shape) + self.assertGreater(result.threshold, 0) + + def test_result_field_compatible(self) -> None: + volume = np.random.default_rng(0).standard_normal((4, 28, 28)).astype(np.float64) + volume += 5.0 + result = segment_dino(volume, backend="mock") + self.assertIsInstance(result.threshold, float) + self.assertIsInstance(result.seed_count, int) + self.assertIsInstance(result.component_count, int) + self.assertIsInstance(result.labeled_volume, np.ndarray) + self.assertIsInstance(result.response, np.ndarray) + + def test_foreground_mask_applied(self) -> None: + rng = np.random.default_rng(123) + volume = rng.standard_normal((4, 28, 28)).astype(np.float64) + volume[:, :14, :14] += 20.0 + result = segment_dino(volume, backend="mock") + below_threshold = volume < result.threshold + self.assertTrue(np.all(result.labeled_volume[below_threshold] == 0)) + + def test_labels_are_nonnegative(self) -> None: + volume = np.random.default_rng(55).standard_normal((4, 28, 28)).astype(np.float64) + volume += 3.0 + result = segment_dino(volume, backend="mock") + self.assertTrue(np.all(result.labeled_volume >= 0)) + + +if __name__ == "__main__": + unittest.main() From 630fafaabf40553e16cfd6b1d3746bec96ae3058 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 14:56:42 +0000 Subject: [PATCH 3/8] test: add DINO segmentation acceptance pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runs `--method dino --dino-backend mock` on bundled sample data and validates artifacts (labels, features, summary) for all 3 scans. Also fixes HDBSCAN on tiny volumes where few patches exist — treats single-patch slices as one region instead of returning empty labels. - Add scripts/check_dino_acceptance.py (mirrors check_week2_acceptance) - Add tests/test_dino_acceptance.py - Wire DINO acceptance gate into scripts/ci_report.py - Fix _cluster_feature_map to handle small volumes gracefully https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/segmentation/dino_segment.py | 16 +++-- scripts/check_dino_acceptance.py | 90 +++++++++++++++++++++++++ scripts/ci_report.py | 6 +- tests/test_dino_acceptance.py | 26 +++++++ tests/test_segmentation_dino.py | 3 +- 5 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 scripts/check_dino_acceptance.py create mode 100644 tests/test_dino_acceptance.py diff --git a/braggtrack/segmentation/dino_segment.py b/braggtrack/segmentation/dino_segment.py index 1963151..6253829 100644 --- a/braggtrack/segmentation/dino_segment.py +++ b/braggtrack/segmentation/dino_segment.py @@ -62,21 +62,29 @@ def _cluster_feature_map( from sklearn.decomposition import PCA h_p, w_p, d = features.shape + n_patches = h_p * w_p flat = features.reshape(-1, d) - n_comp = min(n_components_pca, d, flat.shape[0]) + n_comp = min(n_components_pca, d, n_patches) if n_comp < 2: - return np.zeros((h_p, w_p), dtype=np.int32) + # Too few patches to cluster — assign all to a single region. + return np.ones((h_p, w_p), dtype=np.int32) reduced = PCA(n_components=n_comp).fit_transform(flat) + effective_min_cluster = max(2, min(min_cluster_size, n_patches // 2)) clusterer = HDBSCAN( - min_cluster_size=max(2, min_cluster_size), - min_samples=max(1, min_samples), + min_cluster_size=effective_min_cluster, + min_samples=max(1, min(min_samples, effective_min_cluster - 1)), ) raw_labels = clusterer.fit_predict(reduced) # HDBSCAN labels: -1 = noise, 0..K = clusters. Shift to 1-based. labels = np.where(raw_labels >= 0, raw_labels + 1, 0) + + # If HDBSCAN assigned everything to noise, treat all patches as one region. + if not np.any(labels > 0): + return np.ones((h_p, w_p), dtype=np.int32) + return labels.reshape(h_p, w_p).astype(np.int32) diff --git a/scripts/check_dino_acceptance.py b/scripts/check_dino_acceptance.py new file mode 100644 index 0000000..37bb721 --- /dev/null +++ b/scripts/check_dino_acceptance.py @@ -0,0 +1,90 @@ +"""DINO segmentation acceptance checks on bundled sample data.""" + +from __future__ import annotations + +import csv +import json +import subprocess +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from braggtrack.io import resolve_dataset_root + +OUTDIR = Path("artifacts/dino") +DATASET_ROOT = resolve_dataset_root(None) + + +def main() -> int: + proc = subprocess.run( + [ + sys.executable, + "-m", + "braggtrack.cli.segment_dataset", + str(DATASET_ROOT), + "--outdir", + str(OUTDIR), + "--method", + "dino", + "--dino-backend", + "mock", + ], + check=False, + capture_output=True, + text=True, + ) + payload = json.loads(proc.stdout) if proc.stdout.strip() else [] + + failures: list[str] = [] + + if proc.returncode != 0: + failures.append(f"segment_dataset exited with code {proc.returncode}: {proc.stderr.strip()}") + + if len(payload) != 3: + failures.append(f"Expected 3 scans in output, found {len(payload)}") + + for row in payload: + scan = row.get("scan", "?") + if row.get("component_count", 0) <= 0: + failures.append(f"{scan}: component_count must be > 0") + if row.get("schema_version") != "week2.v1": + failures.append(f"{scan}: schema_version mismatch") + + summary_csv = OUTDIR / "segmentation_summary.csv" + if not summary_csv.exists(): + failures.append("Missing segmentation_summary.csv") + else: + with summary_csv.open() as fh: + rows = list(csv.DictReader(fh)) + if len(rows) != 3: + failures.append(f"segmentation_summary.csv expected 3 rows, found {len(rows)}") + + for scan_dir in sorted(OUTDIR.glob("scan*")): + summary_path = scan_dir / "summary.json" + if not summary_path.exists(): + failures.append(f"{scan_dir.name}: missing summary.json") + continue + summary = json.loads(summary_path.read_text()) + if summary.get("method") != "dino": + failures.append(f"{scan_dir.name}: method should be 'dino', got {summary.get('method')}") + if not (scan_dir / "features.csv").exists(): + failures.append(f"{scan_dir.name}: missing features.csv") + if not (scan_dir / "labels.npz").exists(): + failures.append(f"{scan_dir.name}: missing labels.npz") + + report = { + "method": "dino", + "scan_count": len(payload), + "non_empty_components": sum(1 for r in payload if r.get("component_count", 0) > 0), + "schema_consistent": all(r.get("schema_version") == "week2.v1" for r in payload), + "failures": failures, + } + print(json.dumps(report, indent=2)) + return 0 if not failures else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ci_report.py b/scripts/ci_report.py index 2262c65..5292186 100644 --- a/scripts/ci_report.py +++ b/scripts/ci_report.py @@ -123,7 +123,10 @@ def main() -> int: ) wk4_ok = wk4_ok_cmd and isinstance(wk4_payload, dict) and wk4_payload.get("failures") == [] - all_ok = unit_ok and acc_ok and smoke_ok and wk2_ok and wk3_ok and wk4_ok + dino_ok_cmd, dino_payload, _ = run_cmd_json([sys.executable, "scripts/check_dino_acceptance.py"], "DINO Acceptance") + dino_ok = dino_ok_cmd and isinstance(dino_payload, dict) and dino_payload.get("failures") == [] + + all_ok = unit_ok and acc_ok and smoke_ok and wk2_ok and wk3_ok and wk4_ok and dino_ok print("\n=== Summary ===") print(f"unit_tests={'PASS' if unit_ok else 'FAIL'}") print(f"acceptance={'PASS' if acc_ok else 'FAIL'}") @@ -131,6 +134,7 @@ def main() -> int: print(f"week2_acceptance={'PASS' if wk2_ok else 'FAIL'}") print(f"week3_acceptance={'PASS' if wk3_ok else 'FAIL'}") print(f"week4_acceptance={'PASS' if wk4_ok else 'FAIL'}") + print(f"dino_acceptance={'PASS' if dino_ok else 'FAIL'}") print(f"overall={'PASS' if all_ok else 'FAIL'}") return 0 if all_ok else 1 diff --git a/tests/test_dino_acceptance.py b/tests/test_dino_acceptance.py new file mode 100644 index 0000000..1975615 --- /dev/null +++ b/tests/test_dino_acceptance.py @@ -0,0 +1,26 @@ +"""Acceptance test: DINO segmentation pipeline on bundled sample data.""" + +import json +import subprocess +import sys +import unittest + + +class DinoAcceptanceTests(unittest.TestCase): + def test_dino_acceptance_script(self) -> None: + proc = subprocess.run( + [sys.executable, "scripts/check_dino_acceptance.py"], + check=False, + capture_output=True, + text=True, + ) + self.assertEqual(proc.returncode, 0, msg=proc.stdout + proc.stderr) + payload = json.loads(proc.stdout) + self.assertEqual(payload["method"], "dino") + self.assertEqual(payload["scan_count"], 3) + self.assertEqual(payload["non_empty_components"], 3) + self.assertEqual(payload["failures"], []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_segmentation_dino.py b/tests/test_segmentation_dino.py index f29b1aa..1407822 100644 --- a/tests/test_segmentation_dino.py +++ b/tests/test_segmentation_dino.py @@ -57,10 +57,11 @@ def test_basic_clustering(self) -> None: self.assertEqual(labels.shape, (6, 6)) self.assertGreaterEqual(len(np.unique(labels[labels > 0])), 1) - def test_tiny_input_returns_zeros(self) -> None: + def test_tiny_input_single_region(self) -> None: features = np.ones((1, 1, 4), dtype=np.float32) labels = _cluster_feature_map(features, n_components_pca=2) self.assertEqual(labels.shape, (1, 1)) + self.assertEqual(labels[0, 0], 1) class TestUpsampleLabels(unittest.TestCase): From 8a3143e2ffb277421134257f32e22494a5128b00 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 15:07:45 +0000 Subject: [PATCH 4/8] feat: add DINO vs classical segmentation comparison notebook Side-by-side comparison on bundled sample data: spot counts, tri-axis label projections, feature distributions, Dice overlap, centroid scatter, and cross-scan consistency metrics. Uses mock backend by default. https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- notebooks/dino_segmentation_comparison.ipynb | 207 +++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 notebooks/dino_segmentation_comparison.ipynb diff --git a/notebooks/dino_segmentation_comparison.ipynb b/notebooks/dino_segmentation_comparison.ipynb new file mode 100644 index 0000000..a216609 --- /dev/null +++ b/notebooks/dino_segmentation_comparison.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ddb2fb4c", + "source": "# DINO vs Classical Segmentation Comparison\n\nThis notebook runs both segmentation backends on the bundled `data/sample_operando/` scans and compares their outputs side-by-side.\n\n| Method | How it works | Strengths |\n|--------|-------------|-----------|\n| **Classical** | Otsu threshold → LoG enhancement → h-maxima seeds → seeded watershed → merge nearby | Fast, interpretable, well-tuned for this beamline |\n| **DINO** | DINOv3 patch features → PCA → HDBSCAN clustering → 3D slice stitching → Otsu foreground mask | Learns in feature space — should generalise across beamlines/detectors without re-tuning |\n\nUses the **mock** DINO backend by default (no GPU required). Set `BRAGGTRACK_DINO_BACKEND=torch` for real DINOv3 features.", + "metadata": {} + }, + { + "cell_type": "markdown", + "id": "23d7b278", + "source": "## Setup", + "metadata": {} + }, + { + "cell_type": "code", + "id": "320c0082", + "source": "import os, subprocess, sys\n\n_ON_COLAB = \"google.colab\" in sys.modules or os.environ.get(\"COLAB_RELEASE_TAG\")\n\nif _ON_COLAB:\n print(\"Colab detected — installing BraggTrack + sample data...\")\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"braggtrack[notebook] @ git+https://github.com/BASE-Laboratory/BraggTrack.git\",\n ])\n if not os.path.isdir(\"data/sample_operando\"):\n subprocess.check_call([\n \"git\", \"clone\", \"--depth=1\", \"--filter=blob:none\", \"--sparse\",\n \"https://github.com/BASE-Laboratory/BraggTrack.git\", \"_braggtrack_repo\",\n ])\n subprocess.check_call(\n [\"git\", \"sparse-checkout\", \"set\", \"data/sample_operando\"],\n cwd=\"_braggtrack_repo\",\n )\n os.makedirs(\"data\", exist_ok=True)\n os.rename(\"_braggtrack_repo/data/sample_operando\", \"data/sample_operando\")\n subprocess.check_call([\"rm\", \"-rf\", \"_braggtrack_repo\"])\n os.environ.setdefault(\"BRAGGTRACK_DATA_ROOT\", os.path.abspath(\"data/sample_operando\"))\n print(\"Done.\")\nelse:\n print(\"Local environment — skipping Colab setup.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "9f1e4a10", + "source": "%matplotlib inline\nfrom pathlib import Path\n\nimport h5py\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom matplotlib.colors import ListedColormap\n\nfrom braggtrack.io import discover_operando_scans, sample_operando_root\nfrom braggtrack.segmentation import (\n extract_instance_table,\n fill_holes_binary,\n label_projection_by_intensity,\n merge_nearby_labels,\n otsu_floor_from_mip,\n otsu_threshold,\n relabel_sequential,\n remove_small_objects,\n segment_classical,\n segment_dino,\n smooth_thresholds,\n)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "12606936", + "source": "## 1 — Load real data\n\nRead the largest 3D numeric dataset from each H5 file (bypasses the fixed NeXus path shortlist).", + "metadata": {} + }, + { + "cell_type": "code", + "id": "379a76d7", + "source": "def load_3d_volume(h5_path: Path) -> np.ndarray:\n \"\"\"Pick the largest 3-D numeric dataset in an H5 file.\"\"\"\n candidates: list[tuple[str, tuple[int, ...]]] = []\n with h5py.File(h5_path, \"r\") as f:\n def _visit(name, obj):\n if isinstance(obj, h5py.Dataset) and obj.ndim == 3 and np.issubdtype(obj.dtype, np.number):\n candidates.append((name, obj.shape))\n f.visititems(_visit)\n if not candidates:\n raise KeyError(f\"No 3D numeric dataset in {h5_path}\")\n name = max(candidates, key=lambda t: int(np.prod(t[1])))[0]\n return np.asarray(f[name][...], dtype=np.float64)\n\nscans = discover_operando_scans(sample_operando_root())\nall_volumes = [load_3d_volume(s.path) for s in scans]\n\nfor s, v in zip(scans, all_volumes):\n print(f\"{s.scan_name}: shape={v.shape} intensity=[{v.min():.0f}, {v.max():.0f}]\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "61e5cbd4", + "source": "## 2 — Run both segmentation methods\n\n### Classical pipeline\nOtsu → LoG → h-maxima → seeded watershed → remove small → fill holes → merge nearby → relabel.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "1dbae8f3", + "source": "MERGE_DISTANCE = 15\n\ndef run_classical(volume: np.ndarray, threshold: float) -> np.ndarray:\n res = segment_classical(\n volume,\n threshold=threshold,\n blur_passes=1,\n h_value=0.1,\n min_seed_separation=2,\n seed_peak_fraction=0.2,\n seed_response_percentile=99.95,\n )\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = np.where(binary, labels, 0)\n labels = merge_nearby_labels(labels, volume, max_centroid_distance=MERGE_DISTANCE)\n return relabel_sequential(labels)\n\nraw_thresholds = [otsu_threshold(v.ravel()) for v in all_volumes]\nsmoothed = smooth_thresholds(raw_thresholds, window=5)\n\nclassical_labels = []\nfor s, v, thr in zip(scans, all_volumes, smoothed):\n lab = run_classical(v, float(thr))\n classical_labels.append(lab)\n print(f\"{s.scan_name} classical: threshold={thr:.1f}, {int(lab.max())} spots\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "a5cfd4e0", + "source": "### DINO pipeline\nDINOv3 patch features → PCA → HDBSCAN → upsample → 3D stitch → Otsu foreground mask → post-process.\n\nThe post-processing (remove small, fill holes, merge nearby, relabel) is identical to keep the comparison fair.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "2ab096d1", + "source": "def run_dino(volume: np.ndarray) -> np.ndarray:\n res = segment_dino(volume, backend=\"mock\")\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = np.where(binary, labels, 0)\n labels = merge_nearby_labels(labels, volume, max_centroid_distance=MERGE_DISTANCE)\n return relabel_sequential(labels)\n\ndino_labels = []\nfor s, v in zip(scans, all_volumes):\n lab = run_dino(v)\n dino_labels.append(lab)\n print(f\"{s.scan_name} DINO: {int(lab.max())} spots\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "1b93c867", + "source": "## 3 — Spot count comparison", + "metadata": {} + }, + { + "cell_type": "code", + "id": "e1799372", + "source": "scan_names = [s.scan_name for s in scans]\nclassical_counts = [int(l.max()) for l in classical_labels]\ndino_counts = [int(l.max()) for l in dino_labels]\n\ncomparison_df = pd.DataFrame({\n \"scan\": scan_names,\n \"classical_spots\": classical_counts,\n \"dino_spots\": dino_counts,\n})\ncomparison_df[\"difference\"] = comparison_df[\"dino_spots\"] - comparison_df[\"classical_spots\"]\nprint(comparison_df.to_string(index=False))\n\nprint(f\"\\nClassical spread (max-min): {max(classical_counts) - min(classical_counts)}\")\nprint(f\"DINO spread (max-min): {max(dino_counts) - min(dino_counts)}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "1de56eb0", + "source": "x = np.arange(len(scans))\nwidth = 0.35\n\nfig, ax = plt.subplots(figsize=(8, 4))\nbars1 = ax.bar(x - width / 2, classical_counts, width, label=\"Classical\", color=\"#1f77b4\")\nbars2 = ax.bar(x + width / 2, dino_counts, width, label=\"DINO (mock)\", color=\"#ff7f0e\")\nax.set_xlabel(\"Scan\")\nax.set_ylabel(\"Spot count\")\nax.set_title(\"Spot counts: Classical vs DINO\")\nax.set_xticks(x)\nax.set_xticklabels(scan_names)\nax.legend()\n\nfor bars in [bars1, bars2]:\n for bar in bars:\n ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,\n str(int(bar.get_height())), ha=\"center\", va=\"bottom\", fontsize=10)\n\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "dfd3d99c", + "source": "## 4 — Visual comparison: tri-axis label projections\n\nSide-by-side label overlays for each scan, projected along all three physical axes (μ, χ, d). Each row is a scan; left column = classical, right column = DINO.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "569f7b23", + "source": "# Build a shared colormap large enough for both methods.\nmax_labels = max(\n max(int(l.max()) for l in classical_labels),\n max(int(l.max()) for l in dino_labels),\n) + 1\nrng_cm = np.random.RandomState(42)\nlabel_colors = np.zeros((max_labels, 4))\nlabel_colors[0] = [0, 0, 0, 0]\nfor i in range(1, max_labels):\n label_colors[i] = [*rng_cm.uniform(0.2, 0.95, 3), 0.65]\nlabel_cmap = ListedColormap(label_colors)\n\naxis_info = [\n (0, \"MIP along mu\", \"chi\", \"d\"),\n (1, \"MIP along chi\", \"d\", \"mu\"),\n (2, \"MIP along d\", \"chi\", \"mu\"),\n]\n\nfig, axes = plt.subplots(len(scans), 6, figsize=(22, len(scans) * 3.5))\n\nfor row, (s, v, c_lab, d_lab) in enumerate(zip(scans, all_volumes, classical_labels, dino_labels)):\n for col_offset, (method_name, labels) in enumerate([(\"Classical\", c_lab), (\"DINO\", d_lab)]):\n for ax_idx, (axis_id, title, xlabel, ylabel) in enumerate(axis_info):\n ax = axes[row, col_offset * 3 + ax_idx]\n mip = v.max(axis=axis_id)\n floor = otsu_floor_from_mip(v, axis=axis_id)\n proj_l = label_projection_by_intensity(v, labels, axis=axis_id, mip_floor=floor)\n\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n mask = np.ma.masked_where(proj_l == 0, proj_l)\n ax.imshow(mask, cmap=label_cmap, interpolation=\"nearest\", vmin=0, vmax=max_labels - 1)\n\n if row == 0:\n ax.set_title(f\"{method_name}\\n{title}\", fontsize=9)\n ax.tick_params(labelsize=6)\n if ax_idx == 0 and col_offset == 0:\n n_c = int(c_lab.max())\n n_d = int(d_lab.max())\n ax.set_ylabel(f\"{s.scan_name}\\nC={n_c} D={n_d}\", fontsize=9)\n\nplt.suptitle(\"Classical (left 3 cols) vs DINO (right 3 cols) — tri-axis label projection\", y=1.01, fontsize=13)\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "49b9b0e3", + "source": "## 5 — Instance feature comparison\n\nCompare the per-spot properties (voxel count, integrated intensity, centroid, eigenvalues) between the two methods.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "3af857cf", + "source": "classical_features = [extract_instance_table(l, v) for l, v in zip(classical_labels, all_volumes)]\ndino_features = [extract_instance_table(l, v) for l, v in zip(dino_labels, all_volumes)]\n\nfor s, cf, df in zip(scans, classical_features, dino_features):\n print(f\"\\n{s.scan_name}:\")\n c_df = pd.DataFrame(cf)\n d_df = pd.DataFrame(df)\n print(f\" Classical: {len(c_df)} spots, \"\n f\"mean voxels={c_df['voxel_count'].mean():.1f}, \"\n f\"total intensity={c_df['integrated_intensity'].sum():.0f}\")\n print(f\" DINO: {len(d_df)} spots, \"\n f\"mean voxels={d_df['voxel_count'].mean():.1f}, \"\n f\"total intensity={d_df['integrated_intensity'].sum():.0f}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "29a5bcea", + "source": "# Voxel count and intensity distributions\nfig, axes = plt.subplots(1, 2, figsize=(12, 4))\n\nfor scan_idx, (s, cf, df) in enumerate(zip(scans, classical_features, dino_features)):\n c_vox = [r[\"voxel_count\"] for r in cf]\n d_vox = [r[\"voxel_count\"] for r in df]\n axes[0].scatter([scan_idx - 0.1] * len(c_vox), c_vox, c=\"#1f77b4\", alpha=0.6, s=40,\n label=\"Classical\" if scan_idx == 0 else None)\n axes[0].scatter([scan_idx + 0.1] * len(d_vox), d_vox, c=\"#ff7f0e\", alpha=0.6, s=40,\n label=\"DINO\" if scan_idx == 0 else None)\n\n c_int = [r[\"integrated_intensity\"] for r in cf]\n d_int = [r[\"integrated_intensity\"] for r in df]\n axes[1].scatter([scan_idx - 0.1] * len(c_int), c_int, c=\"#1f77b4\", alpha=0.6, s=40)\n axes[1].scatter([scan_idx + 0.1] * len(d_int), d_int, c=\"#ff7f0e\", alpha=0.6, s=40)\n\naxes[0].set_ylabel(\"Voxel count\")\naxes[0].set_title(\"Voxel count per spot\")\naxes[0].legend()\naxes[1].set_ylabel(\"Integrated intensity\")\naxes[1].set_title(\"Integrated intensity per spot\")\nfor ax in axes:\n ax.set_xticks(range(len(scans)))\n ax.set_xticklabels([s.scan_name for s in scans])\n ax.set_xlabel(\"Scan\")\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "d3693a3f", + "source": "## 6 — Spatial overlap (Dice coefficient)\n\nFor each scan, compute the Dice coefficient between the binary foreground masks produced by the two methods. This measures how much the methods agree on *where* spots are, regardless of how they partition them into instances.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "5e9bdbee", + "source": "def dice(a: np.ndarray, b: np.ndarray) -> float:\n a_bool = a > 0\n b_bool = b > 0\n intersection = np.count_nonzero(a_bool & b_bool)\n total = np.count_nonzero(a_bool) + np.count_nonzero(b_bool)\n return 2.0 * intersection / total if total > 0 else 1.0\n\nfor s, c_lab, d_lab in zip(scans, classical_labels, dino_labels):\n d = dice(c_lab, d_lab)\n c_fg = np.count_nonzero(c_lab > 0)\n d_fg = np.count_nonzero(d_lab > 0)\n print(f\"{s.scan_name}: Dice={d:.3f} (classical fg={c_fg} voxels, DINO fg={d_fg} voxels)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "4b7a8bbc", + "source": "## 7 — Centroid scatter: classical vs DINO\n\nPlot the centroids from both methods on the same axes. Matching centroids (spots found by both methods) will overlap; method-unique detections will stand alone.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "7beadc17", + "source": "fig, axes = plt.subplots(1, len(scans), figsize=(5 * len(scans), 4.5))\nif len(scans) == 1:\n axes = [axes]\n\nfor ax, s, v, cf, df in zip(axes, scans, all_volumes, classical_features, dino_features):\n mip = v.max(axis=0)\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n\n for r in cf:\n ax.plot(r[\"centroid_chi\"], r[\"centroid_d\"], \"o\", mfc=\"none\",\n mec=\"#1f77b4\", mew=1.5, ms=12)\n for r in df:\n ax.plot(r[\"centroid_chi\"], r[\"centroid_d\"], \"x\",\n mec=\"#ff7f0e\", mew=1.5, ms=10)\n\n ax.set_title(f\"{s.scan_name}\\nblue O = classical, orange X = DINO\")\n ax.axis(\"off\")\n\nplt.tight_layout()\nplt.show()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "77d2e6e2", + "source": "## 8 — Consistency across scans\n\nA key motivation for DINO-based segmentation is consistency: the same physical spots should produce the same segmentation across consecutive scans. Compare the coefficient of variation (std/mean) of spot counts across scans for each method.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "82ec00b3", + "source": "c_arr = np.array(classical_counts, dtype=float)\nd_arr = np.array(dino_counts, dtype=float)\n\nc_cv = c_arr.std() / c_arr.mean() if c_arr.mean() > 0 else 0\nd_cv = d_arr.std() / d_arr.mean() if d_arr.mean() > 0 else 0\n\nconsistency = pd.DataFrame({\n \"Method\": [\"Classical\", \"DINO (mock)\"],\n \"Mean spots\": [c_arr.mean(), d_arr.mean()],\n \"Std spots\": [c_arr.std(), d_arr.std()],\n \"CV (std/mean)\": [c_cv, d_cv],\n \"Spread (max-min)\": [c_arr.max() - c_arr.min(), d_arr.max() - d_arr.min()],\n})\nprint(consistency.to_string(index=False))", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "ec6c4ef9", + "source": "## 9 — Per-scan feature tables\n\nFull feature tables for both methods on scan 1, for detailed inspection.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "49e37ae2", + "source": "cols = [\"label\", \"voxel_count\", \"integrated_intensity\",\n \"centroid_mu\", \"centroid_chi\", \"centroid_d\",\n \"eig_1\", \"eig_2\", \"eig_3\"]\n\nprint(\"=== Classical — scan0001 ===\")\ndisplay(pd.DataFrame(classical_features[0])[cols]) if classical_features[0] else print(\"(no spots)\")\n\nprint(\"\\n=== DINO — scan0001 ===\")\ndisplay(pd.DataFrame(dino_features[0])[cols]) if dino_features[0] else print(\"(no spots)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "a64fefce", + "source": "## 10 — Notes and next steps\n\n**What the mock backend shows:** The mock DINO backend produces hash-based random features, so the clustering is not semantically meaningful — it tests the *pipeline plumbing* (slice extraction → PCA → HDBSCAN → stitching → foreground mask) but not the quality of learned representations.\n\n**What changes with real DINOv3 weights:**\n- Set `BRAGGTRACK_DINO_BACKEND=torch` (requires `torch` + `transformers` + GPU)\n- The encoder extracts genuine patch-level features where similar textures cluster together\n- Expect better instance separation without hand-tuned LoG/watershed parameters\n- The same model should work across different beamlines and detectors\n\n**CLI equivalent:**\n```bash\n# Classical\nbraggtrack-segment-dataset --method classical --outdir artifacts/classical\n\n# DINO (mock)\nbraggtrack-segment-dataset --method dino --dino-backend mock --outdir artifacts/dino\n\n# DINO (real weights)\nbraggtrack-segment-dataset --method dino --dino-backend torch --outdir artifacts/dino_real\n```", + "metadata": {} + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 27455a4c51a8695cc87689d9ff9c777996eb97b1 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 15:22:54 +0000 Subject: [PATCH 5/8] fix: address code review findings for DINO segmentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Type encoder param as PatchFeatureEncoder instead of object - Type backend param as BackendName instead of str - Vectorize _upsample_labels with np.repeat (was Python double-loop) - Pre-compute label sizes in _stitch_slices_3d (avoids O(n_pairs*n_pixels)) - Eliminate double Otsu in CLI — compute threshold only in classical branch - Use result.threshold in summary JSON (works for both methods) - Make response field zero-size array (no LoG response in DINO path) - Clarify slice_hw with if/elif/else instead of nested ternary - Remove os.environ side effect from test file https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/cli/segment_dataset.py | 9 +++-- braggtrack/segmentation/dino_segment.py | 44 ++++++++++++++++--------- tests/test_segmentation_dino.py | 5 +-- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/braggtrack/cli/segment_dataset.py b/braggtrack/cli/segment_dataset.py index 60a292f..e21e36e 100644 --- a/braggtrack/cli/segment_dataset.py +++ b/braggtrack/cli/segment_dataset.py @@ -160,9 +160,6 @@ def main() -> int: volume = synth_volume_from_file(scan.path) source = "synthetic_fallback" - raw_threshold = otsu_threshold(volume.ravel()) - threshold = raw_threshold * float(args.threshold_fraction) - if args.method == "dino": result = segment_dino( volume, @@ -175,6 +172,8 @@ def main() -> int: min_overlap_fraction=args.dino_min_overlap, ) else: + raw_threshold = otsu_threshold(volume.ravel()) + threshold = raw_threshold * float(args.threshold_fraction) result = segment_classical( volume, threshold=threshold, @@ -203,9 +202,9 @@ def main() -> int: "file": str(scan.path), "source": source, "method": args.method, - "threshold": raw_threshold, + "threshold": result.threshold, "threshold_fraction": args.threshold_fraction, - "effective_threshold": threshold, + "effective_threshold": result.threshold, "merge_distance": args.merge_distance, "seed_count": result.seed_count, "component_count": len(table), diff --git a/braggtrack/segmentation/dino_segment.py b/braggtrack/segmentation/dino_segment.py index 6253829..885c402 100644 --- a/braggtrack/segmentation/dino_segment.py +++ b/braggtrack/segmentation/dino_segment.py @@ -8,9 +8,13 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from braggtrack.semantic.dino import BackendName, PatchFeatureEncoder + @dataclass(frozen=True) class DinoSegmentationResult: @@ -25,7 +29,7 @@ class DinoSegmentationResult: def _extract_slice_features( volume: np.ndarray, - encoder: object, + encoder: PatchFeatureEncoder, *, axis: int = 0, ) -> tuple[np.ndarray, tuple[int, int]]: @@ -39,10 +43,15 @@ def _extract_slice_features( feature_maps: list[np.ndarray] = [] for i in range(n_slices): slc = np.take(volume, i, axis=axis) - fmap = encoder.extract_patch_features(slc) # type: ignore[union-attr] + fmap = encoder.extract_patch_features(slc) feature_maps.append(fmap) - slice_hw = (volume.shape[1] if axis == 0 else volume.shape[0], volume.shape[2] if axis != 2 else volume.shape[1]) + if axis == 0: + slice_hw = (volume.shape[1], volume.shape[2]) + elif axis == 1: + slice_hw = (volume.shape[0], volume.shape[2]) + else: + slice_hw = (volume.shape[0], volume.shape[1]) return np.stack(feature_maps, axis=0), slice_hw @@ -94,15 +103,11 @@ def _upsample_labels( patch_size: int, ) -> np.ndarray: """Nearest-neighbor upsample patch-resolution labels to pixel resolution.""" - h_p, w_p = patch_labels.shape + expanded = np.repeat(np.repeat(patch_labels, patch_size, axis=0), patch_size, axis=1) out = np.zeros(target_shape, dtype=np.int32) - for py in range(h_p): - for px in range(w_p): - y0 = py * patch_size - x0 = px * patch_size - y1 = min(y0 + patch_size, target_shape[0]) - x1 = min(x0 + patch_size, target_shape[1]) - out[y0:y1, x0:x1] = patch_labels[py, px] + h = min(expanded.shape[0], target_shape[0]) + w = min(expanded.shape[1], target_shape[1]) + out[:h, :w] = expanded[:h, :w] return out @@ -148,6 +153,12 @@ def union(a: int, b: int) -> None: if ra != rb: parent[rb] = ra + # Pre-compute label sizes per slice for O(1) lookup during merging. + slice_label_sizes: list[dict[int, int]] = [] + for sl in global_slices: + unique_labels, label_counts = np.unique(sl[sl > 0], return_counts=True) + slice_label_sizes.append(dict(zip(unique_labels.tolist(), label_counts.tolist()))) + # Merge labels across adjacent slices by overlap. for i in range(n_slices - 1): sl_a = global_slices[i] @@ -162,10 +173,10 @@ def union(a: int, b: int) -> None: la = pairs_a[mask] lb = pairs_b[mask] unique_pairs, counts = np.unique(np.stack([la, lb], axis=1), axis=0, return_counts=True) + sizes_a = slice_label_sizes[i] + sizes_b = slice_label_sizes[i + 1] for (lid_a, lid_b), cnt in zip(unique_pairs, counts): - size_a = int(np.count_nonzero(sl_a == lid_a)) - size_b = int(np.count_nonzero(sl_b == lid_b)) - min_size = min(size_a, size_b) + min_size = min(sizes_a.get(int(lid_a), 0), sizes_b.get(int(lid_b), 0)) if min_size > 0 and cnt / min_size >= min_overlap_fraction: union(int(lid_a), int(lid_b)) @@ -190,7 +201,7 @@ def union(a: int, b: int) -> None: def segment_dino( volume: np.ndarray, *, - backend: str | None = None, + backend: BackendName | None = None, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", torch_device: str | None = None, n_components_pca: int = 16, @@ -250,10 +261,11 @@ def segment_dino( labels_3d = np.where(foreground, labels_3d, 0).astype(np.int32) component_count = len(np.unique(labels_3d[labels_3d > 0])) + # response is empty — DINO segmentation has no LoG-equivalent response surface. return DinoSegmentationResult( threshold=threshold, seed_count=component_count, component_count=component_count, labeled_volume=labels_3d, - response=np.zeros_like(volume), + response=np.empty(0, dtype=np.float64), ) diff --git a/tests/test_segmentation_dino.py b/tests/test_segmentation_dino.py index 1407822..cabc289 100644 --- a/tests/test_segmentation_dino.py +++ b/tests/test_segmentation_dino.py @@ -1,12 +1,9 @@ """Tests for DINO-based segmentation (mock backend, no GPU).""" -import os import unittest import numpy as np -os.environ.setdefault("BRAGGTRACK_DINO_BACKEND", "mock") - from braggtrack.segmentation.dino_segment import ( DinoSegmentationResult, _cluster_feature_map, @@ -122,7 +119,7 @@ def test_mock_backend_runs(self) -> None: result = segment_dino(volume, backend="mock") self.assertIsInstance(result, DinoSegmentationResult) self.assertEqual(result.labeled_volume.shape, volume.shape) - self.assertEqual(result.response.shape, volume.shape) + self.assertEqual(result.response.size, 0) self.assertGreater(result.threshold, 0) def test_result_field_compatible(self) -> None: From fbf7f7c4905a1a2482e6e1c7525a16e6f81af8a3 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 19:48:23 +0000 Subject: [PATCH 6/8] feat: add per-grain kinematic time-series tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compute physically meaningful evolution quantities for each tracked grain: strain (Δd/d₀), misorientation (angular drift in μ and χ), growth/dissolution (relative intensity and volume changes), and shape evolution (anisotropy, covariance trace). Includes summary statistics and flat-table export. https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/tracking/__init__.py | 12 ++ braggtrack/tracking/kinematics.py | 223 ++++++++++++++++++++++++++ tests/test_kinematics.py | 249 ++++++++++++++++++++++++++++++ 3 files changed, 484 insertions(+) create mode 100644 braggtrack/tracking/kinematics.py create mode 100644 tests/test_kinematics.py diff --git a/braggtrack/tracking/__init__.py b/braggtrack/tracking/__init__.py index 8f426d3..c4eada5 100644 --- a/braggtrack/tracking/__init__.py +++ b/braggtrack/tracking/__init__.py @@ -2,16 +2,28 @@ from .assignment import associate_frames from .cost import CostFunction, GeometrySemanticCost, PositionShapeCost +from .kinematics import ( + GrainKinematics, + KinematicsSummary, + compute_grain_kinematics, + kinematics_to_table, + summarize_kinematics, +) from .lifecycle import TrackEvent, build_tracks, tracks_to_table from .metrics import compute_tracking_metrics __all__ = [ "CostFunction", "GeometrySemanticCost", + "GrainKinematics", + "KinematicsSummary", "PositionShapeCost", "TrackEvent", "associate_frames", "build_tracks", + "compute_grain_kinematics", "compute_tracking_metrics", + "kinematics_to_table", + "summarize_kinematics", "tracks_to_table", ] diff --git a/braggtrack/tracking/kinematics.py b/braggtrack/tracking/kinematics.py new file mode 100644 index 0000000..eb277da --- /dev/null +++ b/braggtrack/tracking/kinematics.py @@ -0,0 +1,223 @@ +"""Per-grain kinematic time-series from tracked diffraction spots. + +Computes physically meaningful evolution quantities for each tracked grain: +strain (Δd/d₀), misorientation (angular drift in μ and χ), growth/dissolution +(relative intensity and volume changes), and shape evolution (anisotropy, +covariance trace). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + + +@dataclass +class GrainKinematics: + """Time-resolved physical evolution of a single tracked grain.""" + + track_id: int + scan_indices: list[int] + + # Centroids per observation + centroid_mu: np.ndarray + centroid_chi: np.ndarray + centroid_d: np.ndarray + + # Strain: (d - d₀) / d₀ + strain: np.ndarray + + # Misorientation: cumulative angular drift from first observation + misorientation_mu: np.ndarray + misorientation_chi: np.ndarray + misorientation_total: np.ndarray + + # Growth / dissolution + integrated_intensity: np.ndarray + voxel_count: np.ndarray + relative_intensity: np.ndarray + relative_volume: np.ndarray + + # Shape evolution + eigenvalues: np.ndarray # (n_obs, 3) + anisotropy: np.ndarray # eig_1 / eig_3 + covariance_trace: np.ndarray # eig_1 + eig_2 + eig_3 + + +@dataclass +class KinematicsSummary: + """Aggregate statistics across all tracked grains.""" + + n_tracks: int + n_full_tracks: int + n_scans: int + + # Per-grain summaries (one entry per track, same order as grain_kinematics) + track_ids: list[int] + max_strain: list[float] + total_misorientation: list[float] + intensity_change_frac: list[float] + volume_change_frac: list[float] + max_anisotropy: list[float] + + grain_kinematics: list[GrainKinematics] = field(repr=False) + + +def compute_grain_kinematics( + track_table: list[dict[str, Any]], +) -> list[GrainKinematics]: + """Compute kinematic time-series for each tracked grain. + + Parameters + ---------- + track_table + Output of :func:`~braggtrack.tracking.lifecycle.tracks_to_table`: + list of dicts with ``track_id``, ``scan_idx``, ``centroid_mu``, + ``centroid_chi``, ``centroid_d``, ``eig_1``, ``eig_2``, ``eig_3``, + ``integrated_intensity``, ``voxel_count``. + + Returns + ------- + list[GrainKinematics] + One entry per track, sorted by track_id. + """ + tracks: dict[int, list[dict]] = {} + for row in track_table: + tracks.setdefault(int(row["track_id"]), []).append(row) + + results: list[GrainKinematics] = [] + for tid in sorted(tracks): + obs = sorted(tracks[tid], key=lambda r: r["scan_idx"]) + n = len(obs) + + scan_indices = [int(r["scan_idx"]) for r in obs] + mu = np.array([float(r["centroid_mu"]) for r in obs]) + chi = np.array([float(r["centroid_chi"]) for r in obs]) + d = np.array([float(r["centroid_d"]) for r in obs]) + + # Strain: (d - d₀) / d₀ + d0 = d[0] + strain = (d - d0) / d0 if d0 != 0 else np.zeros(n) + + # Misorientation: drift from initial position + dmu = mu - mu[0] + dchi = chi - chi[0] + misorientation_total = np.sqrt(dmu**2 + dchi**2) + + intensity = np.array([float(r["integrated_intensity"]) for r in obs]) + voxels = np.array([float(r["voxel_count"]) for r in obs]) + + i0 = intensity[0] + v0 = voxels[0] + rel_intensity = (intensity - i0) / i0 if i0 != 0 else np.zeros(n) + rel_volume = (voxels - v0) / v0 if v0 != 0 else np.zeros(n) + + eigs = np.array([[float(r["eig_1"]), float(r["eig_2"]), float(r["eig_3"])] for r in obs]) + e3 = eigs[:, 2] + anisotropy = np.where(e3 > 0, eigs[:, 0] / e3, 1.0) + cov_trace = eigs.sum(axis=1) + + results.append( + GrainKinematics( + track_id=tid, + scan_indices=scan_indices, + centroid_mu=mu, + centroid_chi=chi, + centroid_d=d, + strain=strain, + misorientation_mu=dmu, + misorientation_chi=dchi, + misorientation_total=misorientation_total, + integrated_intensity=intensity, + voxel_count=voxels, + relative_intensity=rel_intensity, + relative_volume=rel_volume, + eigenvalues=eigs, + anisotropy=anisotropy, + covariance_trace=cov_trace, + ) + ) + return results + + +def summarize_kinematics( + grain_kinematics: list[GrainKinematics], + n_scans: int, +) -> KinematicsSummary: + """Aggregate per-grain kinematics into a summary table. + + Parameters + ---------- + grain_kinematics + Output of :func:`compute_grain_kinematics`. + n_scans + Total number of scans in the sequence. + """ + track_ids: list[int] = [] + max_strain: list[float] = [] + total_misorientation: list[float] = [] + intensity_change: list[float] = [] + volume_change: list[float] = [] + max_anisotropy: list[float] = [] + + for gk in grain_kinematics: + track_ids.append(gk.track_id) + max_strain.append(float(np.max(np.abs(gk.strain)))) + total_misorientation.append(float(gk.misorientation_total[-1]) if len(gk.misorientation_total) > 0 else 0.0) + intensity_change.append(float(gk.relative_intensity[-1]) if len(gk.relative_intensity) > 0 else 0.0) + volume_change.append(float(gk.relative_volume[-1]) if len(gk.relative_volume) > 0 else 0.0) + max_anisotropy.append(float(np.max(gk.anisotropy))) + + n_full = sum(1 for gk in grain_kinematics if len(gk.scan_indices) >= n_scans) + + return KinematicsSummary( + n_tracks=len(grain_kinematics), + n_full_tracks=n_full, + n_scans=n_scans, + track_ids=track_ids, + max_strain=max_strain, + total_misorientation=total_misorientation, + intensity_change_frac=intensity_change, + volume_change_frac=volume_change, + max_anisotropy=max_anisotropy, + grain_kinematics=grain_kinematics, + ) + + +def kinematics_to_table( + grain_kinematics: list[GrainKinematics], +) -> list[dict[str, Any]]: + """Flatten per-grain kinematics into a row-per-observation table. + + Each row contains the track_id, scan_idx, and all computed kinematic + quantities for that observation. Suitable for CSV export or DataFrame + construction. + """ + rows: list[dict[str, Any]] = [] + for gk in grain_kinematics: + for i, scan_idx in enumerate(gk.scan_indices): + rows.append( + { + "track_id": gk.track_id, + "scan_idx": scan_idx, + "centroid_mu": float(gk.centroid_mu[i]), + "centroid_chi": float(gk.centroid_chi[i]), + "centroid_d": float(gk.centroid_d[i]), + "strain": float(gk.strain[i]), + "misorientation_mu": float(gk.misorientation_mu[i]), + "misorientation_chi": float(gk.misorientation_chi[i]), + "misorientation_total": float(gk.misorientation_total[i]), + "integrated_intensity": float(gk.integrated_intensity[i]), + "voxel_count": float(gk.voxel_count[i]), + "relative_intensity": float(gk.relative_intensity[i]), + "relative_volume": float(gk.relative_volume[i]), + "eig_1": float(gk.eigenvalues[i, 0]), + "eig_2": float(gk.eigenvalues[i, 1]), + "eig_3": float(gk.eigenvalues[i, 2]), + "anisotropy": float(gk.anisotropy[i]), + "covariance_trace": float(gk.covariance_trace[i]), + } + ) + return rows diff --git a/tests/test_kinematics.py b/tests/test_kinematics.py new file mode 100644 index 0000000..22b9539 --- /dev/null +++ b/tests/test_kinematics.py @@ -0,0 +1,249 @@ +"""Tests for per-grain kinematic time-series.""" + +import unittest + +import numpy as np + +from braggtrack.tracking.cost import PositionShapeCost +from braggtrack.tracking.kinematics import ( + compute_grain_kinematics, + kinematics_to_table, + summarize_kinematics, +) +from braggtrack.tracking.lifecycle import build_tracks, tracks_to_table + + +def _spot( + mu: float, + chi: float, + d: float, + intensity: float = 100.0, + voxels: int = 10, + eig: tuple[float, float, float] = (0.5, 0.5, 0.5), +) -> dict: + return { + "label": 1, + "voxel_count": voxels, + "integrated_intensity": intensity, + "centroid_mu": mu, + "centroid_chi": chi, + "centroid_d": d, + "eig_1": eig[0], + "eig_2": eig[1], + "eig_3": eig[2], + } + + +def _build_table(scan_tables: list[list[dict]]) -> list[dict]: + cost_fn = PositionShapeCost() + G = build_tracks(scan_tables, cost_fn) + return tracks_to_table(G) + + +class TestComputeGrainKinematics(unittest.TestCase): + def test_single_track_three_scans(self) -> None: + table = _build_table( + [ + [_spot(1.0, 2.0, 10.0, intensity=100, voxels=20)], + [_spot(1.1, 2.2, 10.5, intensity=120, voxels=25)], + [_spot(1.3, 2.5, 11.0, intensity=150, voxels=30)], + ] + ) + results = compute_grain_kinematics(table) + self.assertEqual(len(results), 1) + gk = results[0] + self.assertEqual(len(gk.scan_indices), 3) + self.assertEqual(gk.scan_indices, [0, 1, 2]) + + def test_strain_from_d_spacing(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10.0)], + [_spot(0, 0, 10.5)], + [_spot(0, 0, 11.0)], + ] + ) + gk = compute_grain_kinematics(table)[0] + np.testing.assert_allclose(gk.strain, [0.0, 0.05, 0.10], atol=1e-10) + + def test_misorientation_from_mu_chi(self) -> None: + table = _build_table( + [ + [_spot(0.0, 0.0, 10.0)], + [_spot(3.0, 4.0, 10.0)], + ] + ) + gk = compute_grain_kinematics(table)[0] + np.testing.assert_allclose(gk.misorientation_mu, [0.0, 3.0]) + np.testing.assert_allclose(gk.misorientation_chi, [0.0, 4.0]) + np.testing.assert_allclose(gk.misorientation_total, [0.0, 5.0]) + + def test_relative_intensity_and_volume(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10, intensity=100, voxels=20)], + [_spot(0, 0, 10, intensity=200, voxels=40)], + ] + ) + gk = compute_grain_kinematics(table)[0] + np.testing.assert_allclose(gk.relative_intensity, [0.0, 1.0]) + np.testing.assert_allclose(gk.relative_volume, [0.0, 1.0]) + + def test_anisotropy(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10, eig=(3.0, 2.0, 1.0))], + [_spot(0, 0, 10, eig=(6.0, 2.0, 1.0))], + ] + ) + gk = compute_grain_kinematics(table)[0] + np.testing.assert_allclose(gk.anisotropy, [3.0, 6.0]) + + def test_covariance_trace(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10, eig=(3.0, 2.0, 1.0))], + ] + ) + gk = compute_grain_kinematics(table)[0] + np.testing.assert_allclose(gk.covariance_trace, [6.0]) + + def test_multiple_tracks(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10), _spot(50, 50, 50)], + [_spot(0.1, 0.1, 10.1), _spot(50.1, 50.1, 50.1)], + ] + ) + results = compute_grain_kinematics(table) + self.assertEqual(len(results), 2) + tids = [gk.track_id for gk in results] + self.assertEqual(len(set(tids)), 2) + + def test_birth_mid_sequence(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10)], + [_spot(0.1, 0.1, 10.1), _spot(50, 50, 50)], + [_spot(0.2, 0.2, 10.2), _spot(50.1, 50.1, 50.1)], + ] + ) + results = compute_grain_kinematics(table) + self.assertEqual(len(results), 2) + lengths = sorted(len(gk.scan_indices) for gk in results) + self.assertEqual(lengths, [2, 3]) + + def test_empty_table(self) -> None: + results = compute_grain_kinematics([]) + self.assertEqual(results, []) + + def test_zero_d_spacing_no_crash(self) -> None: + table = [ + { + "track_id": 1, + "scan_idx": 0, + "centroid_mu": 0, + "centroid_chi": 0, + "centroid_d": 0.0, + "eig_1": 1, + "eig_2": 1, + "eig_3": 1, + "integrated_intensity": 100, + "voxel_count": 10, + }, + ] + results = compute_grain_kinematics(table) + self.assertEqual(len(results), 1) + np.testing.assert_allclose(results[0].strain, [0.0]) + + +class TestSummarizeKinematics(unittest.TestCase): + def test_summary_fields(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10, intensity=100, voxels=20, eig=(3, 2, 1))], + [_spot(0.5, 0.3, 10.2, intensity=150, voxels=30, eig=(4, 2, 1))], + ] + ) + gk = compute_grain_kinematics(table) + summary = summarize_kinematics(gk, n_scans=2) + self.assertEqual(summary.n_tracks, 1) + self.assertEqual(summary.n_full_tracks, 1) + self.assertEqual(summary.n_scans, 2) + self.assertEqual(len(summary.max_strain), 1) + self.assertGreater(summary.max_strain[0], 0) + self.assertGreater(summary.total_misorientation[0], 0) + self.assertGreater(summary.intensity_change_frac[0], 0) + self.assertGreater(summary.volume_change_frac[0], 0) + + def test_full_track_count(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10)], + [_spot(0.1, 0.1, 10.1), _spot(50, 50, 50)], + [_spot(0.2, 0.2, 10.2), _spot(50.1, 50.1, 50.1)], + ] + ) + gk = compute_grain_kinematics(table) + summary = summarize_kinematics(gk, n_scans=3) + self.assertEqual(summary.n_full_tracks, 1) + self.assertEqual(summary.n_tracks, 2) + + +class TestKinematicsToTable(unittest.TestCase): + def test_row_count(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10), _spot(50, 50, 50)], + [_spot(0.1, 0.1, 10.1), _spot(50.1, 50.1, 50.1)], + ] + ) + gk = compute_grain_kinematics(table) + rows = kinematics_to_table(gk) + self.assertEqual(len(rows), 4) + + def test_required_columns(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10)], + [_spot(0.1, 0.1, 10.1)], + ] + ) + gk = compute_grain_kinematics(table) + rows = kinematics_to_table(gk) + expected = { + "track_id", + "scan_idx", + "centroid_mu", + "centroid_chi", + "centroid_d", + "strain", + "misorientation_mu", + "misorientation_chi", + "misorientation_total", + "integrated_intensity", + "voxel_count", + "relative_intensity", + "relative_volume", + "eig_1", + "eig_2", + "eig_3", + "anisotropy", + "covariance_trace", + } + self.assertEqual(set(rows[0].keys()), expected) + + def test_values_match_grain_kinematics(self) -> None: + table = _build_table( + [ + [_spot(0, 0, 10.0)], + [_spot(0, 0, 10.5)], + ] + ) + gk = compute_grain_kinematics(table) + rows = kinematics_to_table(gk) + self.assertAlmostEqual(rows[1]["strain"], 0.05) + + +if __name__ == "__main__": + unittest.main() From 68a2198f946f7da57ef7346f7c98972fee7aa7b4 Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 22:18:48 +0000 Subject: [PATCH 7/8] test: add physics ground-truth validation for kinematics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Deterministic 4-grain scenario with analytically known evolution: - Grain A: linear elastic loading (strain = 0.1%/step) - Grain B: pure rotation (0.5°/step μ, 0.3°/step χ, no strain) - Grain C: dissolution (linear intensity/volume decay) - Grain D: late nucleation (born scan 2, growing) Verifies exact recovery of strain, misorientation, growth/dissolution, shape metrics, and summary statistics through the full pipeline. https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- tests/test_kinematics_groundtruth.py | 257 +++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 tests/test_kinematics_groundtruth.py diff --git a/tests/test_kinematics_groundtruth.py b/tests/test_kinematics_groundtruth.py new file mode 100644 index 0000000..2c4424a --- /dev/null +++ b/tests/test_kinematics_groundtruth.py @@ -0,0 +1,257 @@ +"""Ground-truth validation of kinematics physics. + +Constructs a deterministic multi-grain scenario with analytically known +physical evolution, runs it through the full pipeline (tracking → kinematics), +and asserts exact recovery of prescribed strain, misorientation, growth, and +shape evolution. + +Scenario (5 scans, 4 grains well-separated so tracking is unambiguous): + + Grain A — Linear elastic loading + d increases 0.1% per step from d₀=10.0 + Expected strain: [0, 0.001, 0.002, 0.003, 0.004] + + Grain B — Pure rotation (no strain) + μ rotates +0.5°/step, χ rotates +0.3°/step, d constant + Expected misorientation_total: [0, √(0.5²+0.3²), 2×..., 3×..., 4×...] + + Grain C — Dissolution (shrinking grain) + Intensity drops linearly: 1000 → 800 → 600 → 400 → 200 + Volume drops linearly: 100 → 80 → 60 → 40 → 20 + Expected relative_intensity: [0, -0.2, -0.4, -0.6, -0.8] + + Grain D — Late nucleation (born at scan 2) + Appears at scan index 2, grows: intensity 50 → 100 → 150 + Expected: 3 observations, strain computed from its own d₀ +""" + +import unittest + +import numpy as np + +from braggtrack.tracking.cost import PositionShapeCost +from braggtrack.tracking.kinematics import compute_grain_kinematics, summarize_kinematics +from braggtrack.tracking.lifecycle import build_tracks, tracks_to_table + + +N_SCANS = 5 + + +def _spot( + mu: float, + chi: float, + d: float, + intensity: float = 100.0, + voxels: int = 10, + eig: tuple[float, float, float] = (1.0, 1.0, 1.0), +) -> dict: + return { + "label": 1, + "voxel_count": voxels, + "integrated_intensity": intensity, + "centroid_mu": mu, + "centroid_chi": chi, + "centroid_d": d, + "eig_1": eig[0], + "eig_2": eig[1], + "eig_3": eig[2], + } + + +def _build_scenario() -> list[list[dict]]: + """Build the 5-scan, 4-grain ground truth scenario.""" + scans: list[list[dict]] = [] + + for i in range(N_SCANS): + frame: list[dict] = [] + + # Grain A: linear elastic loading at (mu=10, chi=10) + d_a = 10.0 * (1.0 + 0.001 * i) + frame.append(_spot(10.0, 10.0, d_a, intensity=500, voxels=50)) + + # Grain B: pure rotation at (mu=50+0.5*i, chi=50+0.3*i, d=20) + frame.append(_spot(50.0 + 0.5 * i, 50.0 + 0.3 * i, 20.0, intensity=500, voxels=50)) + + # Grain C: dissolution at (mu=90, chi=90, d=15) + intensity_c = 1000.0 - 200.0 * i + voxels_c = 100 - 20 * i + frame.append(_spot(90.0, 90.0, 15.0, intensity=intensity_c, voxels=voxels_c)) + + # Grain D: nucleation at scan 2, at (mu=30, chi=70, d=25) + if i >= 2: + intensity_d = 50.0 + 50.0 * (i - 2) + d_d = 25.0 * (1.0 + 0.002 * (i - 2)) + frame.append( + _spot(30.0, 70.0, d_d, intensity=intensity_d, voxels=30, eig=(3.0, 2.0, 1.0)) + ) + + scans.append(frame) + + return scans + + +def _run_pipeline(scans: list[list[dict]]): + """Run full tracking + kinematics pipeline.""" + cost_fn = PositionShapeCost() + G = build_tracks(scans, cost_fn) + table = tracks_to_table(G) + return compute_grain_kinematics(table) + + +def _find_grain(results, mu_approx: float, chi_approx: float): + """Find grain by approximate initial centroid position.""" + for gk in results: + if abs(float(gk.centroid_mu[0]) - mu_approx) < 5.0 and abs(float(gk.centroid_chi[0]) - chi_approx) < 5.0: + return gk + raise ValueError(f"No grain near mu={mu_approx}, chi={chi_approx}") + + +class TestGroundTruthElasticLoading(unittest.TestCase): + """Grain A: verify strain recovery from prescribed d-spacing evolution.""" + + @classmethod + def setUpClass(cls) -> None: + cls.results = _run_pipeline(_build_scenario()) + cls.grain_a = _find_grain(cls.results, 10.0, 10.0) + + def test_track_length(self) -> None: + self.assertEqual(len(self.grain_a.scan_indices), N_SCANS) + + def test_strain_values(self) -> None: + expected = np.array([0.001 * i for i in range(N_SCANS)]) + np.testing.assert_allclose(self.grain_a.strain, expected, atol=1e-12) + + def test_no_misorientation(self) -> None: + np.testing.assert_allclose(self.grain_a.misorientation_total, 0.0, atol=1e-12) + + def test_constant_intensity(self) -> None: + np.testing.assert_allclose(self.grain_a.relative_intensity, 0.0, atol=1e-12) + + +class TestGroundTruthRotation(unittest.TestCase): + """Grain B: verify misorientation recovery from prescribed μ/χ drift.""" + + @classmethod + def setUpClass(cls) -> None: + cls.results = _run_pipeline(_build_scenario()) + cls.grain_b = _find_grain(cls.results, 50.0, 50.0) + + def test_track_length(self) -> None: + self.assertEqual(len(self.grain_b.scan_indices), N_SCANS) + + def test_misorientation_mu(self) -> None: + expected = np.array([0.5 * i for i in range(N_SCANS)]) + np.testing.assert_allclose(self.grain_b.misorientation_mu, expected, atol=1e-12) + + def test_misorientation_chi(self) -> None: + expected = np.array([0.3 * i for i in range(N_SCANS)]) + np.testing.assert_allclose(self.grain_b.misorientation_chi, expected, atol=1e-12) + + def test_misorientation_total(self) -> None: + step = np.sqrt(0.5**2 + 0.3**2) + expected = np.array([step * i for i in range(N_SCANS)]) + np.testing.assert_allclose(self.grain_b.misorientation_total, expected, atol=1e-12) + + def test_no_strain(self) -> None: + np.testing.assert_allclose(self.grain_b.strain, 0.0, atol=1e-12) + + +class TestGroundTruthDissolution(unittest.TestCase): + """Grain C: verify growth/dissolution from prescribed intensity/volume decay.""" + + @classmethod + def setUpClass(cls) -> None: + cls.results = _run_pipeline(_build_scenario()) + cls.grain_c = _find_grain(cls.results, 90.0, 90.0) + + def test_track_length(self) -> None: + self.assertEqual(len(self.grain_c.scan_indices), N_SCANS) + + def test_relative_intensity(self) -> None: + # I = [1000, 800, 600, 400, 200], I₀ = 1000 + expected = np.array([0.0, -0.2, -0.4, -0.6, -0.8]) + np.testing.assert_allclose(self.grain_c.relative_intensity, expected, atol=1e-12) + + def test_relative_volume(self) -> None: + # V = [100, 80, 60, 40, 20], V₀ = 100 + expected = np.array([0.0, -0.2, -0.4, -0.6, -0.8]) + np.testing.assert_allclose(self.grain_c.relative_volume, expected, atol=1e-12) + + def test_no_strain(self) -> None: + np.testing.assert_allclose(self.grain_c.strain, 0.0, atol=1e-12) + + def test_no_misorientation(self) -> None: + np.testing.assert_allclose(self.grain_c.misorientation_total, 0.0, atol=1e-12) + + +class TestGroundTruthNucleation(unittest.TestCase): + """Grain D: verify late-born grain with correct reference frame.""" + + @classmethod + def setUpClass(cls) -> None: + cls.results = _run_pipeline(_build_scenario()) + cls.grain_d = _find_grain(cls.results, 30.0, 70.0) + + def test_track_length(self) -> None: + self.assertEqual(len(self.grain_d.scan_indices), 3) + + def test_birth_scan(self) -> None: + self.assertEqual(self.grain_d.scan_indices[0], 2) + + def test_strain_from_own_d0(self) -> None: + # d₀ = 25.0, d = [25.0, 25.05, 25.10] + # strain = [0, 0.002, 0.004] + expected = np.array([0.0, 0.002, 0.004]) + np.testing.assert_allclose(self.grain_d.strain, expected, atol=1e-12) + + def test_growth(self) -> None: + # intensity = [50, 100, 150], I₀ = 50 + expected = np.array([0.0, 1.0, 2.0]) + np.testing.assert_allclose(self.grain_d.relative_intensity, expected, atol=1e-12) + + def test_anisotropy(self) -> None: + # eig = (3, 2, 1), anisotropy = 3/1 = 3.0 for all observations + np.testing.assert_allclose(self.grain_d.anisotropy, 3.0, atol=1e-12) + + def test_covariance_trace(self) -> None: + # eig = (3, 2, 1), trace = 6.0 + np.testing.assert_allclose(self.grain_d.covariance_trace, 6.0, atol=1e-12) + + +class TestGroundTruthSummary(unittest.TestCase): + """Verify summary statistics against known scenario.""" + + @classmethod + def setUpClass(cls) -> None: + cls.results = _run_pipeline(_build_scenario()) + cls.summary = summarize_kinematics(cls.results, n_scans=N_SCANS) + + def test_track_count(self) -> None: + self.assertEqual(self.summary.n_tracks, 4) + + def test_full_track_count(self) -> None: + # Grains A, B, C span all 5 scans; Grain D only 3 + self.assertEqual(self.summary.n_full_tracks, 3) + + def test_max_strain_grain_a(self) -> None: + # Grain A max strain = 0.004 + grain_a = _find_grain(self.results, 10.0, 10.0) + idx = self.summary.track_ids.index(grain_a.track_id) + self.assertAlmostEqual(self.summary.max_strain[idx], 0.004) + + def test_total_misorientation_grain_b(self) -> None: + # Grain B final misorientation = 4 * √(0.25 + 0.09) + grain_b = _find_grain(self.results, 50.0, 50.0) + idx = self.summary.track_ids.index(grain_b.track_id) + expected = 4.0 * np.sqrt(0.5**2 + 0.3**2) + self.assertAlmostEqual(self.summary.total_misorientation[idx], expected) + + def test_intensity_change_grain_c(self) -> None: + # Grain C final relative intensity = -0.8 + grain_c = _find_grain(self.results, 90.0, 90.0) + idx = self.summary.track_ids.index(grain_c.track_id) + self.assertAlmostEqual(self.summary.intensity_change_frac[idx], -0.8) + + +if __name__ == "__main__": + unittest.main() From a065213c9fa5903a2c53ff23548c5e98eb83e0eb Mon Sep 17 00:00:00 2001 From: James Le Houx Date: Sun, 17 May 2026 22:33:43 +0000 Subject: [PATCH 8/8] refactor: remove week naming, fix structural issues across codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Remove all "Week N" references from docstrings, artifact paths, schema versions, script names, and test names. Replace with descriptive names: segmentation, tracking, embedding, io. 2. Fix broken import in ablation script (was importing non-existent _load_feature_csv from track_dataset; now imports load_feature_csv from _utils). 3. Remove no-op assignment in otsu.py smooth_thresholds (line that assigned smoothed[outlier] = smoothed[outlier]). 4. Add tests/conftest.py with shared make_spot() fixture factory, deduplicate _spot() helpers across test files. 5. Standardize CLI arg naming: track_dataset now uses "root" instead of "indir" to match all other CLI modules. 6. Extract _write_notebook() to shared write_qc_notebook() in cli/_utils.py, eliminating duplication between segment_dataset and track_dataset. Artifact paths: week2 → segmentation, week3 → tracking, week4 → embedding Schema versions: week2.v1 → segmentation.v1, week3.v1 → tracking.v1, week4.v1 → embedding.v1/tracking_semantic.v1 Script renames: check_week2_* → check_segmentation_*, etc. Test renames: test_week1_* → test_io_*, test_week2_* → test_segmentation_*, etc. https://claude.ai/code/session_015Y9zQk4A8uKJAorKuvBoCk --- braggtrack/cli/_utils.py | 28 ++++- braggtrack/cli/embed_dataset.py | 10 +- braggtrack/cli/segment_dataset.py | 76 +++++------- braggtrack/cli/segment_synthetic.py | 2 +- braggtrack/cli/track_dataset.py | 112 +++++++----------- braggtrack/io/validation.py | 2 +- braggtrack/segmentation/__init__.py | 2 +- braggtrack/segmentation/classical.py | 2 +- braggtrack/segmentation/otsu.py | 5 - braggtrack/segmentation/pipeline.py | 2 +- braggtrack/semantic/__init__.py | 2 +- braggtrack/tracking/__init__.py | 2 +- braggtrack/tracking/cost.py | 4 +- braggtrack/tracking/synthetic.py | 2 +- ...ablation_week4.py => ablation_semantic.py} | 23 ++-- scripts/check_dino_acceptance.py | 4 +- ...tance.py => check_embedding_acceptance.py} | 12 +- ...k_acceptance.py => check_io_acceptance.py} | 2 +- ...ce.py => check_segmentation_acceptance.py} | 8 +- ...ptance.py => check_tracking_acceptance.py} | 12 +- scripts/ci_report.py | 32 ++--- tests/conftest.py | 25 ++++ tests/test_acceptance_script.py | 2 +- ...k1_validation.py => test_io_validation.py} | 2 +- tests/test_kinematics.py | 22 +--- tests/test_kinematics_groundtruth.py | 18 +-- tests/test_segment_dataset_cli.py | 4 +- ...nce.py => test_segmentation_acceptance.py} | 6 +- ...est_semantic_week4.py => test_semantic.py} | 2 +- tests/test_tracking.py | 17 +-- ...eptance.py => test_tracking_acceptance.py} | 6 +- 31 files changed, 204 insertions(+), 244 deletions(-) rename scripts/{ablation_week4.py => ablation_semantic.py} (88%) rename scripts/{check_week4_acceptance.py => check_embedding_acceptance.py} (88%) rename scripts/{check_acceptance.py => check_io_acceptance.py} (96%) rename scripts/{check_week2_acceptance.py => check_segmentation_acceptance.py} (87%) rename scripts/{check_week3_acceptance.py => check_tracking_acceptance.py} (88%) create mode 100644 tests/conftest.py rename tests/{test_week1_validation.py => test_io_validation.py} (95%) rename tests/{test_week2_acceptance.py => test_segmentation_acceptance.py} (73%) rename tests/{test_semantic_week4.py => test_semantic.py} (99%) rename tests/{test_week3_acceptance.py => test_tracking_acceptance.py} (76%) diff --git a/braggtrack/cli/_utils.py b/braggtrack/cli/_utils.py index fe70467..6e60042 100644 --- a/braggtrack/cli/_utils.py +++ b/braggtrack/cli/_utils.py @@ -1,9 +1,10 @@ -"""Shared CLI helpers for volume loading, CSV I/O, and synthetic fallback.""" +"""Shared CLI helpers for volume loading, CSV I/O, notebooks, and synthetic fallback.""" from __future__ import annotations import csv import hashlib +import json from pathlib import Path from typing import Any @@ -59,3 +60,28 @@ def write_csv(path: Path, rows: list[dict[str, Any]]) -> None: writer = csv.DictWriter(fh, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) + + +def write_qc_notebook(path: Path, *, title: str, code_source: list[str]) -> None: + """Write a minimal QC notebook with a markdown header and one code cell.""" + nb = { + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [f"# {title}\n"], + }, + { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": code_source, + }, + ], + "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, + "nbformat": 4, + "nbformat_minor": 5, + } + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(nb, indent=2)) diff --git a/braggtrack/cli/embed_dataset.py b/braggtrack/cli/embed_dataset.py index d5b52dc..3e6d602 100644 --- a/braggtrack/cli/embed_dataset.py +++ b/braggtrack/cli/embed_dataset.py @@ -1,4 +1,4 @@ -"""Compute Week 4 multi-view MIP embeddings for segmented spots.""" +"""Compute multi-view MIP embeddings for segmented spots.""" from __future__ import annotations @@ -26,8 +26,10 @@ def build_parser() -> argparse.ArgumentParser: default=None, help="Dataset root with scan folders (default: data/sample_operando if present, else .)", ) - p.add_argument("--segdir", default="artifacts/week2", help="Segmentation output with features.csv + labels.npz") - p.add_argument("--outdir", default="artifacts/week4", help="Embedding output root") + p.add_argument( + "--segdir", default="artifacts/segmentation", help="Segmentation output with features.csv + labels.npz" + ) + p.add_argument("--outdir", default="artifacts/embedding", help="Embedding output root") p.add_argument("--margin", type=int, default=2, help="Voxel padding around each spot bbox") p.add_argument( "--backend", @@ -122,7 +124,7 @@ def main() -> int: "dim": dim, "backend": args.backend, "model": args.model if args.backend == "torch" else "mock-hash", - "schema_version": "week4.v1", + "schema_version": "embedding.v1", } (scan_out / "embedding_manifest.json").write_text(json.dumps(manifest, indent=2)) summaries.append({**manifest, "embeddings": str(scan_out / "embeddings.npz")}) diff --git a/braggtrack/cli/segment_dataset.py b/braggtrack/cli/segment_dataset.py index e21e36e..ec16efe 100644 --- a/braggtrack/cli/segment_dataset.py +++ b/braggtrack/cli/segment_dataset.py @@ -8,7 +8,7 @@ import numpy as np -from braggtrack.cli._utils import synth_volume_from_file, write_csv +from braggtrack.cli._utils import synth_volume_from_file, write_csv, write_qc_notebook from braggtrack.io import ( MissingH5DependencyError, discover_operando_scans, @@ -35,7 +35,7 @@ def build_parser() -> argparse.ArgumentParser: default=None, help="Dataset root with scan folders (default: data/sample_operando if present, else .)", ) - parser.add_argument("--outdir", default="artifacts/week2", help="Output artifact directory") + parser.add_argument("--outdir", default="artifacts/segmentation", help="Output artifact directory") parser.add_argument( "--method", choices=["classical", "dino"], @@ -94,49 +94,25 @@ def build_parser() -> argparse.ArgumentParser: return parser -def _write_notebook(path: Path) -> None: - nb = { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Week 2 Visual QC\n", - "Loads per-scan feature tables and overlays up to 20 representative objects.\n", - ], - }, - { - "cell_type": "code", - "execution_count": None, - "metadata": {}, - "outputs": [], - "source": [ - "import csv, json\n", - "from pathlib import Path\n", - "import matplotlib.pyplot as plt\n", - "root = Path('artifacts/week2')\n", - "for scan_dir in sorted(root.glob('scan*')):\n", - " table = scan_dir / 'features.csv'\n", - " if not table.exists():\n", - " continue\n", - " rows = list(csv.DictReader(table.open()))\n", - " rows = sorted(rows, key=lambda r: float(r['integrated_intensity']), reverse=True)[:20]\n", - " print(scan_dir.name, 'objects:', len(rows))\n", - " fig, ax = plt.subplots(figsize=(8, 4))\n", - " ax.set_title(f'{scan_dir.name} top-20 object intensities')\n", - " ax.bar(range(len(rows)), [float(r['integrated_intensity']) for r in rows])\n", - " ax.set_xlabel('Object rank')\n", - " ax.set_ylabel('Integrated intensity')\n", - " plt.show()\n", - ], - }, - ], - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "nbformat": 4, - "nbformat_minor": 5, - } - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(nb, indent=2)) +_SEGMENTATION_QC_CODE = [ + "import csv, json\n", + "from pathlib import Path\n", + "import matplotlib.pyplot as plt\n", + "root = Path('artifacts/segmentation')\n", + "for scan_dir in sorted(root.glob('scan*')):\n", + " table = scan_dir / 'features.csv'\n", + " if not table.exists():\n", + " continue\n", + " rows = list(csv.DictReader(table.open()))\n", + " rows = sorted(rows, key=lambda r: float(r['integrated_intensity']), reverse=True)[:20]\n", + " print(scan_dir.name, 'objects:', len(rows))\n", + " fig, ax = plt.subplots(figsize=(8, 4))\n", + " ax.set_title(f'{scan_dir.name} top-20 object intensities')\n", + " ax.bar(range(len(rows)), [float(r['integrated_intensity']) for r in rows])\n", + " ax.set_xlabel('Object rank')\n", + " ax.set_ylabel('Integrated intensity')\n", + " plt.show()\n", +] def main() -> int: @@ -208,7 +184,7 @@ def main() -> int: "merge_distance": args.merge_distance, "seed_count": result.seed_count, "component_count": len(table), - "schema_version": "week2.v1", + "schema_version": "segmentation.v1", "labels_archive": str(scan_out / "labels.npz"), }, indent=2, @@ -225,13 +201,17 @@ def main() -> int: "summary": str(scan_out / "summary.json"), "features": str(scan_out / "features.csv"), "labels_archive": str(scan_out / "labels.npz"), - "schema_version": "week2.v1", + "schema_version": "segmentation.v1", } ) (outdir / "segmentation_summary.json").write_text(json.dumps(summaries, indent=2)) write_csv(outdir / "segmentation_summary.csv", summaries) - _write_notebook(outdir / "qc" / "week2_visual_qc.ipynb") + write_qc_notebook( + outdir / "qc" / "segmentation_visual_qc.ipynb", + title="Segmentation Visual QC", + code_source=_SEGMENTATION_QC_CODE, + ) print(json.dumps(summaries, indent=2)) return 0 if summaries else 1 diff --git a/braggtrack/cli/segment_synthetic.py b/braggtrack/cli/segment_synthetic.py index e7315fa..af0a035 100644 --- a/braggtrack/cli/segment_synthetic.py +++ b/braggtrack/cli/segment_synthetic.py @@ -1,4 +1,4 @@ -"""Run a Week 2 segmentation smoke test on a synthetic 3D volume.""" +"""Run a segmentation smoke test on a synthetic 3D volume.""" from __future__ import annotations diff --git a/braggtrack/cli/track_dataset.py b/braggtrack/cli/track_dataset.py index 8af3ca6..53806d0 100644 --- a/braggtrack/cli/track_dataset.py +++ b/braggtrack/cli/track_dataset.py @@ -1,4 +1,4 @@ -"""Run physics-only tracking across segmented scan feature tables.""" +"""Run tracking across segmented scan feature tables.""" from __future__ import annotations @@ -9,7 +9,7 @@ import numpy as np -from braggtrack.cli._utils import load_feature_csv, write_csv +from braggtrack.cli._utils import load_feature_csv, write_csv, write_qc_notebook from braggtrack.tracking import ( GeometrySemanticCost, PositionShapeCost, @@ -22,9 +22,9 @@ def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "indir", nargs="?", default="artifacts/week2", help="Directory with per-scan feature CSVs (Week 2 output)" + "root", nargs="?", default="artifacts/segmentation", help="Directory with per-scan feature CSVs" ) - parser.add_argument("--outdir", default="artifacts/week3", help="Output artifact directory") + parser.add_argument("--outdir", default="artifacts/tracking", help="Output artifact directory") parser.add_argument("--position-weight", type=float, default=1.0) parser.add_argument("--shape-weight", type=float, default=0.5) parser.add_argument("--gate-mu", type=float, default=float("inf")) @@ -34,7 +34,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument( "--embedding-dir", default=None, - help="Week 4 root with scanXXXX/embeddings.npz (from embed_dataset)", + help="Embedding root with scanXXXX/embeddings.npz (from embed_dataset)", ) parser.add_argument( "--cost-alpha", @@ -68,71 +68,47 @@ def _merge_embeddings(rows: list[dict[str, Any]], emb: dict[int, np.ndarray]) -> row["embedding"] = emb[lid] -def _write_notebook(path: Path) -> None: - nb = { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Week 3 Tracking QC\n", - "Visualises spot trajectories across scans and highlights near-overlap cases.\n", - ], - }, - { - "cell_type": "code", - "execution_count": None, - "metadata": {}, - "outputs": [], - "source": [ - "import csv, json\n", - "from pathlib import Path\n", - "import matplotlib.pyplot as plt\n", - "\n", - "root = Path('artifacts/week3')\n", - "tracks = list(csv.DictReader((root / 'tracks.csv').open()))\n", - "\n", - "# Group by track_id\n", - "by_track = {}\n", - "for r in tracks:\n", - " tid = int(r['track_id'])\n", - " by_track.setdefault(tid, []).append(r)\n", - "\n", - "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", - "for tid, obs in sorted(by_track.items()):\n", - " scans = [int(r['scan_idx']) for r in obs]\n", - " mus = [float(r['centroid_mu']) for r in obs]\n", - " chis = [float(r['centroid_chi']) for r in obs]\n", - " ds = [float(r['centroid_d']) for r in obs]\n", - " axes[0].plot(scans, mus, 'o-', label=f'T{tid}')\n", - " axes[1].plot(scans, chis, 'o-', label=f'T{tid}')\n", - " axes[2].plot(scans, ds, 'o-', label=f'T{tid}')\n", - "\n", - "for ax, lbl in zip(axes, ['centroid_mu', 'centroid_chi', 'centroid_d']):\n", - " ax.set_xlabel('Scan index')\n", - " ax.set_ylabel(lbl)\n", - " ax.legend(fontsize=7)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "metrics = json.loads((root / 'tracking_metrics.json').read_text())\n", - "print('Tracking metrics:')\n", - "for k, v in metrics.items():\n", - " print(f' {k}: {v}')\n", - ], - }, - ], - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "nbformat": 4, - "nbformat_minor": 5, - } - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(nb, indent=2)) +_TRACKING_QC_CODE = [ + "import csv, json\n", + "from pathlib import Path\n", + "import matplotlib.pyplot as plt\n", + "\n", + "root = Path('artifacts/tracking')\n", + "tracks = list(csv.DictReader((root / 'tracks.csv').open()))\n", + "\n", + "# Group by track_id\n", + "by_track = {}\n", + "for r in tracks:\n", + " tid = int(r['track_id'])\n", + " by_track.setdefault(tid, []).append(r)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + "for tid, obs in sorted(by_track.items()):\n", + " scans = [int(r['scan_idx']) for r in obs]\n", + " mus = [float(r['centroid_mu']) for r in obs]\n", + " chis = [float(r['centroid_chi']) for r in obs]\n", + " ds = [float(r['centroid_d']) for r in obs]\n", + " axes[0].plot(scans, mus, 'o-', label=f'T{tid}')\n", + " axes[1].plot(scans, chis, 'o-', label=f'T{tid}')\n", + " axes[2].plot(scans, ds, 'o-', label=f'T{tid}')\n", + "\n", + "for ax, lbl in zip(axes, ['centroid_mu', 'centroid_chi', 'centroid_d']):\n", + " ax.set_xlabel('Scan index')\n", + " ax.set_ylabel(lbl)\n", + " ax.legend(fontsize=7)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "metrics = json.loads((root / 'tracking_metrics.json').read_text())\n", + "print('Tracking metrics:')\n", + "for k, v in metrics.items():\n", + " print(f' {k}: {v}')\n", +] def main() -> int: args = build_parser().parse_args() - indir = Path(args.indir) + indir = Path(args.root) outdir = Path(args.outdir) outdir.mkdir(parents=True, exist_ok=True) @@ -188,7 +164,7 @@ def main() -> int: write_csv(outdir / "tracks.csv", track_rows) (outdir / "tracking_metrics.json").write_text(json.dumps(metrics, indent=2)) - schema_version = "week4.v1" if args.cost_beta != 0.0 else "week3.v1" + schema_version = "tracking_semantic.v1" if args.cost_beta != 0.0 else "tracking.v1" summary = { "scan_names": scan_names, "n_scans": len(scan_tables), @@ -200,7 +176,7 @@ def main() -> int: "embedding_dir": str(emb_root) if emb_root else None, } (outdir / "tracking_summary.json").write_text(json.dumps(summary, indent=2)) - _write_notebook(outdir / "qc" / "week3_tracking_qc.ipynb") + write_qc_notebook(outdir / "qc" / "tracking_qc.ipynb", title="Tracking QC", code_source=_TRACKING_QC_CODE) print(json.dumps(summary, indent=2)) return 0 diff --git a/braggtrack/io/validation.py b/braggtrack/io/validation.py index 46f3de0..b9cd211 100644 --- a/braggtrack/io/validation.py +++ b/braggtrack/io/validation.py @@ -1,4 +1,4 @@ -"""Dataset validation for Week 1 data-contract checks.""" +"""Dataset validation and data-contract checks.""" from __future__ import annotations diff --git a/braggtrack/segmentation/__init__.py b/braggtrack/segmentation/__init__.py index 7c54b98..fad1cec 100644 --- a/braggtrack/segmentation/__init__.py +++ b/braggtrack/segmentation/__init__.py @@ -1,4 +1,4 @@ -"""Segmentation utilities for Week 2 baselines.""" +"""Classical and DINO-based 3D segmentation of diffraction spots.""" from .classical import ( ClassicalSegmentationResult, diff --git a/braggtrack/segmentation/classical.py b/braggtrack/segmentation/classical.py index b0821ba..878921b 100644 --- a/braggtrack/segmentation/classical.py +++ b/braggtrack/segmentation/classical.py @@ -1,4 +1,4 @@ -"""Classical 3D segmentation building blocks for Week 2.""" +"""Classical 3D segmentation building blocks (LoG + watershed).""" from __future__ import annotations diff --git a/braggtrack/segmentation/otsu.py b/braggtrack/segmentation/otsu.py index 48ae332..4cef583 100644 --- a/braggtrack/segmentation/otsu.py +++ b/braggtrack/segmentation/otsu.py @@ -49,11 +49,6 @@ def smooth_thresholds( smoothed = np.empty(n, dtype=np.float64) for i in range(n): smoothed[i] = float(np.median(padded[i : i + w])) - residual = np.abs(raw - smoothed) - mad = float(np.median(residual)) if n > 1 else 0.0 - if mad > 0: - outlier = residual > mad_scale * mad - smoothed[outlier] = smoothed[outlier] # already local median return smoothed diff --git a/braggtrack/segmentation/pipeline.py b/braggtrack/segmentation/pipeline.py index c31d6eb..183ca6c 100644 --- a/braggtrack/segmentation/pipeline.py +++ b/braggtrack/segmentation/pipeline.py @@ -1,4 +1,4 @@ -"""Simple segmentation pipeline with Otsu baseline for Week 2.""" +"""Simple segmentation pipeline with Otsu thresholding baseline.""" from __future__ import annotations diff --git a/braggtrack/semantic/__init__.py b/braggtrack/semantic/__init__.py index cce3e96..5fdd5ab 100644 --- a/braggtrack/semantic/__init__.py +++ b/braggtrack/semantic/__init__.py @@ -1,4 +1,4 @@ -"""Week 4 multi-view semantic features (orthogonal MIPs + frozen ViT embeddings).""" +"""Multi-view semantic features (orthogonal MIPs + frozen ViT embeddings).""" from .dino import embed_multiview_mips, make_multiview_encoder, make_patch_encoder from .mips import crop_spot_cube, orthogonal_mips diff --git a/braggtrack/tracking/__init__.py b/braggtrack/tracking/__init__.py index c4eada5..7a019b1 100644 --- a/braggtrack/tracking/__init__.py +++ b/braggtrack/tracking/__init__.py @@ -1,4 +1,4 @@ -"""Tracking utilities for Week 3 physics-only association.""" +"""Multi-scan grain tracking with physics and semantic cost functions.""" from .assignment import associate_frames from .cost import CostFunction, GeometrySemanticCost, PositionShapeCost diff --git a/braggtrack/tracking/cost.py b/braggtrack/tracking/cost.py index 974fe0f..45906b5 100644 --- a/braggtrack/tracking/cost.py +++ b/braggtrack/tracking/cost.py @@ -1,7 +1,7 @@ """Pluggable cost functions for frame-to-frame spot association. -Week 3 provides a physics-only baseline (position + shape). -Week 4 adds a semantic term (cosine on fused DINO-style embeddings). +PositionShapeCost provides a physics-only baseline (position + shape). +GeometrySemanticCost adds a semantic term (cosine on fused embeddings). """ from __future__ import annotations diff --git a/braggtrack/tracking/synthetic.py b/braggtrack/tracking/synthetic.py index 209596d..5a92a5e 100644 --- a/braggtrack/tracking/synthetic.py +++ b/braggtrack/tracking/synthetic.py @@ -1,7 +1,7 @@ """Synthetic spot-table generator for testing tracking with known ground truth. Produces multi-frame scenarios with deliberate crossing trajectories, -birth/death events, and near-overlap cases so that Week 3 metrics +birth/death events, and near-overlap cases so that tracking metrics (ID-switch rate, fragmentation) are meaningful even without real data. """ diff --git a/scripts/ablation_week4.py b/scripts/ablation_semantic.py similarity index 88% rename from scripts/ablation_week4.py rename to scripts/ablation_semantic.py index c5de617..7a2f371 100644 --- a/scripts/ablation_week4.py +++ b/scripts/ablation_semantic.py @@ -1,4 +1,4 @@ -"""Grid search over ``cost_alpha`` and ``cost_beta`` for Week 4 semantic tracking. +"""Grid search over ``cost_alpha`` and ``cost_beta`` for semantic tracking. Loads feature tables once (and caches ``embeddings.npz`` per scan in memory), deep-copies per grid point, optionally merges embeddings, runs @@ -10,11 +10,11 @@ .. code-block:: bash - python scripts/ablation_week4.py \\ - --indir artifacts/week2 \\ - --embedding-dir artifacts/week4 \\ + python scripts/ablation_semantic.py \\ + --indir artifacts/segmentation \\ + --embedding-dir artifacts/embedding \\ --betas 0,0.25,0.5,1.0 \\ - --output artifacts/week4_ablation/report.json + --output artifacts/ablation/report.json """ from __future__ import annotations @@ -26,7 +26,8 @@ from pathlib import Path from typing import Any -from braggtrack.cli.track_dataset import _load_embeddings_npz, _load_feature_csv, _merge_embeddings +from braggtrack.cli._utils import load_feature_csv +from braggtrack.cli.track_dataset import _load_embeddings_npz, _merge_embeddings from braggtrack.tracking import ( GeometrySemanticCost, PositionShapeCost, @@ -37,11 +38,11 @@ def build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description=__doc__) - p.add_argument("--indir", default="artifacts/week2", help="Week 2 directory with scan*/features.csv") + p.add_argument("--indir", default="artifacts/segmentation", help="Directory with scan*/features.csv") p.add_argument( "--embedding-dir", default=None, - help="Week 4 embeddings root (required if any beta > 0)", + help="Embeddings root (required if any beta > 0)", ) p.add_argument( "--alphas", @@ -53,7 +54,7 @@ def build_parser() -> argparse.ArgumentParser: default="0,0.25,0.5,1.0", help="Comma-separated cost_beta values (0 = geometry-only)", ) - p.add_argument("--output", default="artifacts/week4_ablation/report.json", help="JSON output path") + p.add_argument("--output", default="artifacts/ablation/report.json", help="JSON output path") p.add_argument("--position-weight", type=float, default=1.0) p.add_argument("--shape-weight", type=float, default=0.5) p.add_argument("--gate-mu", type=float, default=math.inf) @@ -74,7 +75,7 @@ def _load_scan_tables(indir: Path) -> tuple[list[list[dict[str, Any]]], list[str for sd in scan_dirs: feat = sd / "features.csv" if feat.exists(): - tables.append(_load_feature_csv(feat)) + tables.append(load_feature_csv(feat)) names.append(sd.name) return tables, names @@ -167,7 +168,7 @@ def main() -> int: rows_out.append(row) report = { - "schema_version": "week4_ablation.v1", + "schema_version": "ablation.v1", "indir": str(indir.resolve()), "embedding_dir": str(emb_root.resolve()) if emb_root else None, "grid": {"alphas": alphas, "betas": betas}, diff --git a/scripts/check_dino_acceptance.py b/scripts/check_dino_acceptance.py index 37bb721..f8f3137 100644 --- a/scripts/check_dino_acceptance.py +++ b/scripts/check_dino_acceptance.py @@ -50,7 +50,7 @@ def main() -> int: scan = row.get("scan", "?") if row.get("component_count", 0) <= 0: failures.append(f"{scan}: component_count must be > 0") - if row.get("schema_version") != "week2.v1": + if row.get("schema_version") != "segmentation.v1": failures.append(f"{scan}: schema_version mismatch") summary_csv = OUTDIR / "segmentation_summary.csv" @@ -79,7 +79,7 @@ def main() -> int: "method": "dino", "scan_count": len(payload), "non_empty_components": sum(1 for r in payload if r.get("component_count", 0) > 0), - "schema_consistent": all(r.get("schema_version") == "week2.v1" for r in payload), + "schema_consistent": all(r.get("schema_version") == "segmentation.v1" for r in payload), "failures": failures, } print(json.dumps(report, indent=2)) diff --git a/scripts/check_week4_acceptance.py b/scripts/check_embedding_acceptance.py similarity index 88% rename from scripts/check_week4_acceptance.py rename to scripts/check_embedding_acceptance.py index 1aa1f01..bd70ce2 100644 --- a/scripts/check_week4_acceptance.py +++ b/scripts/check_embedding_acceptance.py @@ -1,4 +1,4 @@ -"""Week 4 acceptance: segmentation labels, mock embeddings, semantic tracking.""" +"""Embedding acceptance: segmentation labels, mock embeddings, semantic tracking.""" from __future__ import annotations @@ -14,10 +14,10 @@ from braggtrack.io import resolve_dataset_root -OUT_EMB = Path("artifacts/week4") +OUT_EMB = Path("artifacts/embedding") DATASET_ROOT = resolve_dataset_root(None) -OUT_TRACK = Path("artifacts/week4_track") -SEG = Path("artifacts/week2") +OUT_TRACK = Path("artifacts/tracking_semantic") +SEG = Path("artifacts/segmentation") def main() -> int: @@ -85,8 +85,8 @@ def main() -> int: if payload.get("n_scans") != 3: failures.append(f"Expected 3 scans, got {payload.get('n_scans')}") - if payload.get("schema_version") != "week4.v1": - failures.append(f"schema_version expected week4.v1, got {payload.get('schema_version')}") + if payload.get("schema_version") != "tracking_semantic.v1": + failures.append(f"schema_version expected tracking_semantic.v1, got {payload.get('schema_version')}") for fname in ("tracks.csv", "tracking_metrics.json", "tracking_summary.json"): if not (OUT_TRACK / fname).exists(): diff --git a/scripts/check_acceptance.py b/scripts/check_io_acceptance.py similarity index 96% rename from scripts/check_acceptance.py rename to scripts/check_io_acceptance.py index 0fde743..84e1ac8 100644 --- a/scripts/check_acceptance.py +++ b/scripts/check_io_acceptance.py @@ -1,4 +1,4 @@ -"""Week 1 acceptance checks for BraggTrack. +"""I/O and discovery acceptance checks for BraggTrack. Checks: 1. All three sample scans are discovered and ordered as 1,2,3. diff --git a/scripts/check_week2_acceptance.py b/scripts/check_segmentation_acceptance.py similarity index 87% rename from scripts/check_week2_acceptance.py rename to scripts/check_segmentation_acceptance.py index 1db6d5e..269eb8e 100644 --- a/scripts/check_week2_acceptance.py +++ b/scripts/check_segmentation_acceptance.py @@ -1,4 +1,4 @@ -"""Week 2 acceptance checks for segmentation artifacts.""" +"""Segmentation acceptance checks.""" from __future__ import annotations @@ -14,7 +14,7 @@ from braggtrack.io import resolve_dataset_root -OUTDIR = Path("artifacts/week2") +OUTDIR = Path("artifacts/segmentation") DATASET_ROOT = resolve_dataset_root(None) @@ -35,7 +35,7 @@ def main() -> int: for row in payload: if row.get("component_count", 0) <= 0: failures.append(f"{row.get('scan')}: component_count must be > 0") - if row.get("schema_version") != "week2.v1": + if row.get("schema_version") != "segmentation.v1": failures.append(f"{row.get('scan')}: schema_version mismatch") summary_csv = OUTDIR / "segmentation_summary.csv" @@ -50,7 +50,7 @@ def main() -> int: report = { "scan_count": len(payload), "non_empty_components": sum(1 for r in payload if r.get("component_count", 0) > 0), - "schema_consistent": all(r.get("schema_version") == "week2.v1" for r in payload), + "schema_consistent": all(r.get("schema_version") == "segmentation.v1" for r in payload), "failures": failures, } print(json.dumps(report, indent=2)) diff --git a/scripts/check_week3_acceptance.py b/scripts/check_tracking_acceptance.py similarity index 88% rename from scripts/check_week3_acceptance.py rename to scripts/check_tracking_acceptance.py index ef9b7e4..e98b82f 100644 --- a/scripts/check_week3_acceptance.py +++ b/scripts/check_tracking_acceptance.py @@ -1,4 +1,4 @@ -"""Week 3 acceptance checks for tracking artifacts. +"""Tracking acceptance checks. Acceptance criteria: 1. Tracks are generated across all three scans. @@ -19,13 +19,13 @@ from braggtrack.io import resolve_dataset_root -OUTDIR = Path("artifacts/week3") +OUTDIR = Path("artifacts/tracking") DATASET_ROOT = resolve_dataset_root(None) def main() -> int: - # Step 1 — ensure segmentation artifacts exist (run Week 2 if needed). - seg_dir = Path("artifacts/week2") + # Step 1 — ensure segmentation artifacts exist. + seg_dir = Path("artifacts/segmentation") if not (seg_dir / "segmentation_summary.json").exists(): subprocess.run( [sys.executable, "-m", "braggtrack.cli.segment_dataset", str(DATASET_ROOT), "--outdir", str(seg_dir)], @@ -70,14 +70,14 @@ def main() -> int: failures.append(f"Missing artifact: {fname}") # Check schema version. - if payload.get("schema_version") != "week3.v1": + if payload.get("schema_version") != "tracking.v1": failures.append(f"schema_version mismatch: {payload.get('schema_version')}") report = { "n_scans": n_scans, "total_tracks": total_tracks, "metrics_present": all(k in payload for k in ("fragmentation_ratio", "id_switch_rate")), - "schema_consistent": payload.get("schema_version") == "week3.v1", + "schema_consistent": payload.get("schema_version") == "tracking.v1", "failures": failures, } print(json.dumps(report, indent=2)) diff --git a/scripts/ci_report.py b/scripts/ci_report.py index 5292186..4a73ec8 100644 --- a/scripts/ci_report.py +++ b/scripts/ci_report.py @@ -100,40 +100,40 @@ def main() -> int: unit_ok = run_unit_tests() - acc_ok_cmd, acc_payload, _ = run_cmd_json([sys.executable, "scripts/check_acceptance.py"], "Week 1 Acceptance") + acc_ok_cmd, acc_payload, _ = run_cmd_json([sys.executable, "scripts/check_io_acceptance.py"], "IO Acceptance") acc_ok = acc_ok_cmd and evaluate_acceptance(acc_payload if isinstance(acc_payload, dict) else None) smoke_ok_cmd, smoke_payload, _ = run_cmd_json( - [sys.executable, "-m", "braggtrack.cli.segment_synthetic"], "Week 2 Smoke" + [sys.executable, "-m", "braggtrack.cli.segment_synthetic"], "Segmentation Smoke" ) smoke_ok = smoke_ok_cmd and evaluate_smoke(smoke_payload if isinstance(smoke_payload, dict) else None) - wk2_ok_cmd, wk2_payload, _ = run_cmd_json( - [sys.executable, "scripts/check_week2_acceptance.py"], "Week 2 Acceptance" + seg_ok_cmd, seg_payload, _ = run_cmd_json( + [sys.executable, "scripts/check_segmentation_acceptance.py"], "Segmentation Acceptance" ) - wk2_ok = wk2_ok_cmd and isinstance(wk2_payload, dict) and wk2_payload.get("failures") == [] + seg_ok = seg_ok_cmd and isinstance(seg_payload, dict) and seg_payload.get("failures") == [] - wk3_ok_cmd, wk3_payload, _ = run_cmd_json( - [sys.executable, "scripts/check_week3_acceptance.py"], "Week 3 Acceptance" + trk_ok_cmd, trk_payload, _ = run_cmd_json( + [sys.executable, "scripts/check_tracking_acceptance.py"], "Tracking Acceptance" ) - wk3_ok = wk3_ok_cmd and isinstance(wk3_payload, dict) and wk3_payload.get("failures") == [] + trk_ok = trk_ok_cmd and isinstance(trk_payload, dict) and trk_payload.get("failures") == [] - wk4_ok_cmd, wk4_payload, _ = run_cmd_json( - [sys.executable, "scripts/check_week4_acceptance.py"], "Week 4 Acceptance" + emb_ok_cmd, emb_payload, _ = run_cmd_json( + [sys.executable, "scripts/check_embedding_acceptance.py"], "Embedding Acceptance" ) - wk4_ok = wk4_ok_cmd and isinstance(wk4_payload, dict) and wk4_payload.get("failures") == [] + emb_ok = emb_ok_cmd and isinstance(emb_payload, dict) and emb_payload.get("failures") == [] dino_ok_cmd, dino_payload, _ = run_cmd_json([sys.executable, "scripts/check_dino_acceptance.py"], "DINO Acceptance") dino_ok = dino_ok_cmd and isinstance(dino_payload, dict) and dino_payload.get("failures") == [] - all_ok = unit_ok and acc_ok and smoke_ok and wk2_ok and wk3_ok and wk4_ok and dino_ok + all_ok = unit_ok and acc_ok and smoke_ok and seg_ok and trk_ok and emb_ok and dino_ok print("\n=== Summary ===") print(f"unit_tests={'PASS' if unit_ok else 'FAIL'}") - print(f"acceptance={'PASS' if acc_ok else 'FAIL'}") + print(f"io_acceptance={'PASS' if acc_ok else 'FAIL'}") print(f"smoke={'PASS' if smoke_ok else 'FAIL'}") - print(f"week2_acceptance={'PASS' if wk2_ok else 'FAIL'}") - print(f"week3_acceptance={'PASS' if wk3_ok else 'FAIL'}") - print(f"week4_acceptance={'PASS' if wk4_ok else 'FAIL'}") + print(f"segmentation_acceptance={'PASS' if seg_ok else 'FAIL'}") + print(f"tracking_acceptance={'PASS' if trk_ok else 'FAIL'}") + print(f"embedding_acceptance={'PASS' if emb_ok else 'FAIL'}") print(f"dino_acceptance={'PASS' if dino_ok else 'FAIL'}") print(f"overall={'PASS' if all_ok else 'FAIL'}") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6f5ae76 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,25 @@ +"""Shared test fixtures for BraggTrack.""" + +from __future__ import annotations + + +def make_spot( + mu: float, + chi: float, + d: float, + intensity: float = 100.0, + voxels: int = 10, + eig: tuple[float, float, float] = (0.5, 0.5, 0.5), +) -> dict: + """Create a synthetic spot dictionary for testing.""" + return { + "label": 1, + "voxel_count": voxels, + "integrated_intensity": intensity, + "centroid_mu": mu, + "centroid_chi": chi, + "centroid_d": d, + "eig_1": eig[0], + "eig_2": eig[1], + "eig_3": eig[2], + } diff --git a/tests/test_acceptance_script.py b/tests/test_acceptance_script.py index 8950691..2d37e36 100644 --- a/tests/test_acceptance_script.py +++ b/tests/test_acceptance_script.py @@ -10,7 +10,7 @@ def test_acceptance_script_runs_successfully(self) -> None: env["PYTHONPATH"] = f".{os.pathsep}{env.get('PYTHONPATH', '')}" proc = subprocess.run( - [sys.executable, "scripts/check_acceptance.py"], + [sys.executable, "scripts/check_io_acceptance.py"], check=False, capture_output=True, text=True, diff --git a/tests/test_week1_validation.py b/tests/test_io_validation.py similarity index 95% rename from tests/test_week1_validation.py rename to tests/test_io_validation.py index 0d61f00..1397823 100644 --- a/tests/test_week1_validation.py +++ b/tests/test_io_validation.py @@ -5,7 +5,7 @@ from braggtrack.io.models import ExperimentSequence, ScanVolumeMeta -class Week1ValidationTests(unittest.TestCase): +class TestIOValidation(unittest.TestCase): def test_beamline_adapter_builds_three_scan_sequence(self) -> None: adapter = BeamlineAdapter(sample_operando_root()) sequence = adapter.build_sequence() diff --git a/tests/test_kinematics.py b/tests/test_kinematics.py index 22b9539..6d215eb 100644 --- a/tests/test_kinematics.py +++ b/tests/test_kinematics.py @@ -3,6 +3,7 @@ import unittest import numpy as np +from conftest import make_spot as _spot from braggtrack.tracking.cost import PositionShapeCost from braggtrack.tracking.kinematics import ( @@ -13,27 +14,6 @@ from braggtrack.tracking.lifecycle import build_tracks, tracks_to_table -def _spot( - mu: float, - chi: float, - d: float, - intensity: float = 100.0, - voxels: int = 10, - eig: tuple[float, float, float] = (0.5, 0.5, 0.5), -) -> dict: - return { - "label": 1, - "voxel_count": voxels, - "integrated_intensity": intensity, - "centroid_mu": mu, - "centroid_chi": chi, - "centroid_d": d, - "eig_1": eig[0], - "eig_2": eig[1], - "eig_3": eig[2], - } - - def _build_table(scan_tables: list[list[dict]]) -> list[dict]: cost_fn = PositionShapeCost() G = build_tracks(scan_tables, cost_fn) diff --git a/tests/test_kinematics_groundtruth.py b/tests/test_kinematics_groundtruth.py index 2c4424a..74c9ee1 100644 --- a/tests/test_kinematics_groundtruth.py +++ b/tests/test_kinematics_groundtruth.py @@ -28,12 +28,12 @@ import unittest import numpy as np +from conftest import make_spot from braggtrack.tracking.cost import PositionShapeCost from braggtrack.tracking.kinematics import compute_grain_kinematics, summarize_kinematics from braggtrack.tracking.lifecycle import build_tracks, tracks_to_table - N_SCANS = 5 @@ -45,17 +45,7 @@ def _spot( voxels: int = 10, eig: tuple[float, float, float] = (1.0, 1.0, 1.0), ) -> dict: - return { - "label": 1, - "voxel_count": voxels, - "integrated_intensity": intensity, - "centroid_mu": mu, - "centroid_chi": chi, - "centroid_d": d, - "eig_1": eig[0], - "eig_2": eig[1], - "eig_3": eig[2], - } + return make_spot(mu, chi, d, intensity=intensity, voxels=voxels, eig=eig) def _build_scenario() -> list[list[dict]]: @@ -81,9 +71,7 @@ def _build_scenario() -> list[list[dict]]: if i >= 2: intensity_d = 50.0 + 50.0 * (i - 2) d_d = 25.0 * (1.0 + 0.002 * (i - 2)) - frame.append( - _spot(30.0, 70.0, d_d, intensity=intensity_d, voxels=30, eig=(3.0, 2.0, 1.0)) - ) + frame.append(_spot(30.0, 70.0, d_d, intensity=intensity_d, voxels=30, eig=(3.0, 2.0, 1.0))) scans.append(frame) diff --git a/tests/test_segment_dataset_cli.py b/tests/test_segment_dataset_cli.py index 8e47413..5827e08 100644 --- a/tests/test_segment_dataset_cli.py +++ b/tests/test_segment_dataset_cli.py @@ -11,9 +11,9 @@ class SegmentDatasetCliTests(unittest.TestCase): - def test_segment_dataset_writes_week2_artifacts(self) -> None: + def test_segment_dataset_writes_artifacts(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: - outdir = Path(tmpdir) / "week2" + outdir = Path(tmpdir) / "segmentation" proc = subprocess.run( [ sys.executable, diff --git a/tests/test_week2_acceptance.py b/tests/test_segmentation_acceptance.py similarity index 73% rename from tests/test_week2_acceptance.py rename to tests/test_segmentation_acceptance.py index 26fcf2d..445d53a 100644 --- a/tests/test_week2_acceptance.py +++ b/tests/test_segmentation_acceptance.py @@ -4,10 +4,10 @@ import unittest -class Week2AcceptanceTests(unittest.TestCase): - def test_week2_acceptance_script(self) -> None: +class TestSegmentationAcceptance(unittest.TestCase): + def test_segmentation_acceptance_script(self) -> None: proc = subprocess.run( - [sys.executable, "scripts/check_week2_acceptance.py"], + [sys.executable, "scripts/check_segmentation_acceptance.py"], check=False, capture_output=True, text=True, diff --git a/tests/test_semantic_week4.py b/tests/test_semantic.py similarity index 99% rename from tests/test_semantic_week4.py rename to tests/test_semantic.py index 6114d90..458239f 100644 --- a/tests/test_semantic_week4.py +++ b/tests/test_semantic.py @@ -1,4 +1,4 @@ -"""Week 4 semantic MIPs, encoder, and association cost tests.""" +"""Semantic MIPs, encoder, and association cost tests.""" import math import unittest diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 29a7fca..409cb10 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -1,9 +1,10 @@ -"""Unit tests for the Week 3 tracking module.""" +"""Unit tests for the tracking module.""" import math import unittest import numpy as np +from conftest import make_spot as _spot from braggtrack.tracking.assignment import associate_frames from braggtrack.tracking.cost import PositionShapeCost @@ -12,20 +13,6 @@ from braggtrack.tracking.synthetic import generate_crossing_scenario -def _spot(mu: float, chi: float, d: float, eig: tuple[float, float, float] = (0.5, 0.5, 0.5)) -> dict: - return { - "label": 1, - "voxel_count": 10, - "integrated_intensity": 100.0, - "centroid_mu": mu, - "centroid_chi": chi, - "centroid_d": d, - "eig_1": eig[0], - "eig_2": eig[1], - "eig_3": eig[2], - } - - class TestPositionShapeCost(unittest.TestCase): def test_pairwise_matrix_matches_scalar(self) -> None: rng = np.random.RandomState(7) diff --git a/tests/test_week3_acceptance.py b/tests/test_tracking_acceptance.py similarity index 76% rename from tests/test_week3_acceptance.py rename to tests/test_tracking_acceptance.py index b91fdd3..2291d11 100644 --- a/tests/test_week3_acceptance.py +++ b/tests/test_tracking_acceptance.py @@ -4,10 +4,10 @@ import unittest -class Week3AcceptanceTests(unittest.TestCase): - def test_week3_acceptance_script(self) -> None: +class TestTrackingAcceptance(unittest.TestCase): + def test_tracking_acceptance_script(self) -> None: proc = subprocess.run( - [sys.executable, "scripts/check_week3_acceptance.py"], + [sys.executable, "scripts/check_tracking_acceptance.py"], check=False, capture_output=True, text=True,