diff --git a/braggtrack/cli/embed_dataset.py b/braggtrack/cli/embed_dataset.py index 3e6d602..885d2d5 100644 --- a/braggtrack/cli/embed_dataset.py +++ b/braggtrack/cli/embed_dataset.py @@ -35,9 +35,9 @@ def build_parser() -> argparse.ArgumentParser: "--backend", choices=("auto", "mock", "torch"), default="auto", - help="Embedding backend (mock needs no PyTorch; torch uses Dinov2-small)", + help="Embedding backend (mock needs no PyTorch; torch uses DINOv3)", ) - p.add_argument("--model", default="facebook/dinov2-small", help="HF model id when backend=torch") + p.add_argument("--model", default="facebook/dinov3-vitb16-pretrain-lvd1689m", help="HF model id when backend=torch") return p diff --git a/braggtrack/cli/segment_dataset.py b/braggtrack/cli/segment_dataset.py index ec16efe..17fdd83 100644 --- a/braggtrack/cli/segment_dataset.py +++ b/braggtrack/cli/segment_dataset.py @@ -83,12 +83,12 @@ def build_parser() -> argparse.ArgumentParser: help="HuggingFace model ID for the DINO torch backend", ) parser.add_argument("--dino-pca-components", type=int, default=16, help="PCA components for DINO feature reduction") - parser.add_argument("--dino-min-cluster-size", type=int, default=3, help="HDBSCAN min_cluster_size") + parser.add_argument("--dino-min-cluster-size", type=int, default=5, help="HDBSCAN min_cluster_size") parser.add_argument("--dino-min-samples", type=int, default=2, help="HDBSCAN min_samples") parser.add_argument( "--dino-min-overlap", type=float, - default=0.3, + default=0.2, help="Min overlap fraction for 3D slice stitching", ) return parser diff --git a/braggtrack/segmentation/dino_segment.py b/braggtrack/segmentation/dino_segment.py index 885c402..4130631 100644 --- a/braggtrack/segmentation/dino_segment.py +++ b/braggtrack/segmentation/dino_segment.py @@ -55,46 +55,101 @@ def _extract_slice_features( return np.stack(feature_maps, axis=0), slice_hw +def _patch_foreground_masks( + volume: np.ndarray, + threshold: float, + patch_size: int, + axis: int = 0, 
+) -> list[np.ndarray]: + """Boolean foreground mask at patch resolution for each slice. + + A patch is foreground if any of its pixels is at or above *threshold* (``>=``). + """ + n_slices = volume.shape[axis] + masks: list[np.ndarray] = [] + for i in range(n_slices): + slc = np.take(volume, i, axis=axis) + fg = (slc >= threshold).astype(np.float64) + h, w = slc.shape + hp = max(1, h // patch_size) + wp = max(1, w // patch_size) + cropped = fg[: hp * patch_size, : wp * patch_size] + blocks = cropped.reshape(hp, patch_size, wp, patch_size) + masks.append(blocks.max(axis=(1, 3)) > 0) + return masks + + def _cluster_feature_map( features: np.ndarray, *, n_components_pca: int = 16, - min_cluster_size: int = 3, + min_cluster_size: int = 5, min_samples: int = 2, + foreground_mask: np.ndarray | None = None, + pca_model: object | None = None, ) -> np.ndarray: """Cluster a 2D feature map ``(H_p, W_p, D)`` into instance labels. Uses PCA dimensionality reduction followed by HDBSCAN. Returns ``(H_p, W_p)`` int array, 0 = background/noise. + + When *foreground_mask* is given only foreground patches participate + in clustering; background patches get label 0. When *pca_model* is + given it is used for the PCA transform instead of fitting per-slice. """ from sklearn.cluster import HDBSCAN from sklearn.decomposition import PCA h_p, w_p, d = features.shape - n_patches = h_p * w_p flat = features.reshape(-1, d) - n_comp = min(n_components_pca, d, n_patches) - if n_comp < 2: - # Too few patches to cluster — assign all to a single region. 
- return np.ones((h_p, w_p), dtype=np.int32) + if foreground_mask is not None: + fg_flat = foreground_mask.ravel() + n_fg = int(fg_flat.sum()) + if n_fg == 0: + return np.zeros((h_p, w_p), dtype=np.int32) + if n_fg < 2: + out = np.zeros((h_p, w_p), dtype=np.int32) + out[foreground_mask] = 1 + return out + else: + fg_flat = np.ones(h_p * w_p, dtype=bool) + n_fg = h_p * w_p - reduced = PCA(n_components=n_comp).fit_transform(flat) + fg_features = flat[fg_flat] - effective_min_cluster = max(2, min(min_cluster_size, n_patches // 2)) + if pca_model is not None: + reduced = pca_model.transform(fg_features) + else: + n_comp = min(n_components_pca, d, n_fg) + if n_comp < 2: + out = np.zeros((h_p, w_p), dtype=np.int32) + if foreground_mask is not None: + out[foreground_mask] = 1 + else: + out[:] = 1 + return out + reduced = PCA(n_components=n_comp).fit_transform(fg_features) + + effective_min_cluster = max(2, min(min_cluster_size, n_fg // 2)) clusterer = HDBSCAN( min_cluster_size=effective_min_cluster, min_samples=max(1, min(min_samples, effective_min_cluster - 1)), ) raw_labels = clusterer.fit_predict(reduced) - # HDBSCAN labels: -1 = noise, 0..K = clusters. Shift to 1-based. - labels = np.where(raw_labels >= 0, raw_labels + 1, 0) + cluster_labels = np.where(raw_labels >= 0, raw_labels + 1, 0) - # If HDBSCAN assigned everything to noise, treat all patches as one region. 
- if not np.any(labels > 0): - return np.ones((h_p, w_p), dtype=np.int32) + if not np.any(cluster_labels > 0): + out = np.zeros((h_p, w_p), dtype=np.int32) + if foreground_mask is not None: + out[foreground_mask] = 1 + else: + out[:] = 1 + return out - return labels.reshape(h_p, w_p).astype(np.int32) + out = np.zeros(h_p * w_p, dtype=np.int32) + out[fg_flat] = cluster_labels + return out.reshape(h_p, w_p) def _upsample_labels( @@ -205,10 +260,10 @@ def segment_dino( model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", torch_device: str | None = None, n_components_pca: int = 16, - min_cluster_size: int = 3, + min_cluster_size: int = 5, min_samples: int = 2, threshold_fraction: float = 1.0, - min_overlap_fraction: float = 0.3, + min_overlap_fraction: float = 0.2, axis: int = 0, ) -> DinoSegmentationResult: """Segment a 3D volume using DINOv3 patch-level features + HDBSCAN. @@ -232,6 +287,8 @@ def segment_dino( axis Axis to slice along (0 = mu/z, typically the narrowest). """ + from sklearn.decomposition import PCA + from braggtrack.segmentation.otsu import otsu_threshold from braggtrack.semantic.dino import make_patch_encoder @@ -243,13 +300,31 @@ def segment_dino( features, slice_hw = _extract_slice_features(volume, encoder, axis=axis) + # Foreground mask at patch resolution — exclude background from clustering. + fg_masks = _patch_foreground_masks(volume, threshold, encoder.patch_size, axis=axis) + + # Global PCA across all foreground patches for a consistent feature space + # (per-slice PCA causes wildly different cluster counts). 
+ n_slices = features.shape[0] + d = features.shape[-1] + all_fg = [features[i][fg_masks[i]] for i in range(n_slices)] + all_fg_cat = np.concatenate(all_fg, axis=0) if any(f.size > 0 for f in all_fg) else np.empty((0, d)) + + n_comp = min(n_components_pca, d, all_fg_cat.shape[0]) + pca = None + if n_comp >= 2: + pca = PCA(n_components=n_comp) + pca.fit(all_fg_cat) + per_slice_labels: list[np.ndarray] = [] - for i in range(features.shape[0]): + for i in range(n_slices): patch_labels = _cluster_feature_map( features[i], n_components_pca=n_components_pca, min_cluster_size=min_cluster_size, min_samples=min_samples, + foreground_mask=fg_masks[i], + pca_model=pca, ) full_labels = _upsample_labels(patch_labels, slice_hw, encoder.patch_size) per_slice_labels.append(full_labels) @@ -261,7 +336,6 @@ def segment_dino( labels_3d = np.where(foreground, labels_3d, 0).astype(np.int32) component_count = len(np.unique(labels_3d[labels_3d > 0])) - # response is empty — DINO segmentation has no LoG-equivalent response surface. return DinoSegmentationResult( threshold=threshold, seed_count=component_count, diff --git a/braggtrack/semantic/dino.py b/braggtrack/semantic/dino.py index 0c8c27e..ef280d8 100644 --- a/braggtrack/semantic/dino.py +++ b/braggtrack/semantic/dino.py @@ -4,7 +4,7 @@ * ``mock`` — deterministic CPU-only vectors from image bytes (default when PyTorch is unavailable). -* ``torch`` — Hugging Face DINOv3/v2 model (CLS or patch tokens). +* ``torch`` — Hugging Face DINOv3 model (CLS or patch tokens). * ``auto`` — use ``torch`` if import succeeds, else ``mock``. 
""" @@ -73,7 +73,7 @@ def embed(self, mip_mu: np.ndarray, mip_chi: np.ndarray, mip_d: np.ndarray) -> n class TorchDinoMultiviewEncoder: - """Loads Dinov2 once; call :meth:`embed` per spot.""" + """Loads DINOv3 once; call :meth:`embed` per spot.""" def __init__(self, model_name: str, device: str | None = None) -> None: import torch @@ -119,7 +119,7 @@ def _requested_backend(explicit: BackendName | None) -> BackendName: def make_multiview_encoder( backend: BackendName | None = None, *, - model_name: str = "facebook/dinov2-small", + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", torch_device: str | None = None, ) -> MultiviewEncoder: """Construct a reusable encoder (loads torch weights at most once).""" @@ -154,7 +154,7 @@ class MockPatchEncoder: @property def patch_size(self) -> int: - return 14 + return 16 @property def feature_dim(self) -> int: @@ -177,7 +177,7 @@ def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray: class TorchDinoPatchEncoder: - """Extracts DINOv2/v3 patch tokens as a spatial feature map.""" + """Extracts DINOv3 patch tokens as a spatial feature map.""" def __init__(self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", device: str | None = None) -> None: import torch @@ -189,8 +189,9 @@ def __init__(self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", self._model = AutoModel.from_pretrained(model_name) self._model.eval() self._model.to(self._device) - self._patch_size = getattr(self._model.config, "patch_size", 14) + self._patch_size = getattr(self._model.config, "patch_size", 16) self._feature_dim = int(self._model.config.hidden_size) + self._num_register_tokens = getattr(self._model.config, "num_register_tokens", 0) @property def patch_size(self) -> int: @@ -206,7 +207,8 @@ def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray: inputs = self._proc(images=[rgb], return_tensors="pt") inputs = {k: v.to(self._device) for k, v in inputs.items()} out = self._model(**inputs) - 
patch_tokens = out.last_hidden_state[:, 1:, :].squeeze(0) + skip = 1 + self._num_register_tokens + patch_tokens = out.last_hidden_state[:, skip:, :].squeeze(0) h_img = inputs["pixel_values"].shape[2] w_img = inputs["pixel_values"].shape[3] h_p = h_img // self._patch_size @@ -238,7 +240,7 @@ def embed_multiview_mips( mip_d: np.ndarray, *, backend: BackendName | None = None, - model_name: str = "facebook/dinov2-small", + model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", torch_device: str | None = None, ) -> np.ndarray: """Return a single L2-normalised concatenated feature vector.""" diff --git a/notebooks/dino_segmentation_comparison.ipynb b/notebooks/dino_segmentation_comparison.ipynb index 5204d15..b46334e 100644 --- a/notebooks/dino_segmentation_comparison.ipynb +++ b/notebooks/dino_segmentation_comparison.ipynb @@ -3,7 +3,7 @@ { "cell_type": "markdown", "id": "ddb2fb4c", - "source": "# DINO vs Classical Segmentation Comparison\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BASE-Laboratory/BraggTrack/blob/main/notebooks/dino_segmentation_comparison.ipynb)\n\nThis notebook runs both segmentation backends on the bundled `data/sample_operando/` scans and compares their outputs side-by-side.\n\n| Method | How it works | Strengths |\n|--------|-------------|-----------|\n| **Classical** | Otsu threshold \u2192 LoG enhancement \u2192 h-maxima seeds \u2192 seeded watershed \u2192 merge nearby | Fast, interpretable, well-tuned for this beamline |\n| **DINO** | DINOv3 patch features \u2192 PCA \u2192 HDBSCAN clustering \u2192 3D slice stitching \u2192 Otsu foreground mask | Learns in feature space \u2014 should generalise across beamlines/detectors without re-tuning |\n\nUses the **mock** DINO backend by default (no GPU required). 
Set `BRAGGTRACK_DINO_BACKEND=torch` for real DINOv3 features.", + "source": "# DINO vs Classical Segmentation Comparison\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BASE-Laboratory/BraggTrack/blob/main/notebooks/dino_segmentation_comparison.ipynb)\n\nThis notebook runs both segmentation backends on the bundled `data/sample_operando/` scans and compares their outputs side-by-side.\n\n| Method | How it works | Strengths |\n|--------|-------------|-----------|\n| **Classical** | Otsu threshold → LoG enhancement → h-maxima seeds → seeded watershed → merge nearby | Fast, interpretable, well-tuned for this beamline |\n| **DINOv3** | Frozen DINOv3 ViT-B/16 patch features → global PCA → HDBSCAN clustering → 3D slice stitching → Otsu foreground mask | Clusters in a learned feature space — should generalise across beamlines/detectors without re-tuning |\n\nOn **Colab** (GPU + internet) this notebook uses real DINOv3 weights (`facebook/dinov3-vitb16-pretrain-lvd1689m`). 
Locally without PyTorch it falls back to the mock backend automatically.", "metadata": {} }, { @@ -15,7 +15,7 @@ { "cell_type": "code", "id": "320c0082", - "source": "import os, subprocess, sys\n\n_ON_COLAB = \"google.colab\" in sys.modules or os.environ.get(\"COLAB_RELEASE_TAG\")\n\nif _ON_COLAB:\n print(\"Colab detected \u2014 installing BraggTrack + sample data...\")\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"braggtrack[notebook] @ git+https://github.com/BASE-Laboratory/BraggTrack.git\",\n ])\n if not os.path.isdir(\"data/sample_operando\"):\n subprocess.check_call([\n \"git\", \"clone\", \"--depth=1\", \"--filter=blob:none\", \"--sparse\",\n \"https://github.com/BASE-Laboratory/BraggTrack.git\", \"_braggtrack_repo\",\n ])\n subprocess.check_call(\n [\"git\", \"sparse-checkout\", \"set\", \"data/sample_operando\"],\n cwd=\"_braggtrack_repo\",\n )\n os.makedirs(\"data\", exist_ok=True)\n os.rename(\"_braggtrack_repo/data/sample_operando\", \"data/sample_operando\")\n subprocess.check_call([\"rm\", \"-rf\", \"_braggtrack_repo\"])\n os.environ.setdefault(\"BRAGGTRACK_DATA_ROOT\", os.path.abspath(\"data/sample_operando\"))\n print(\"Done.\")\nelse:\n print(\"Local environment \u2014 skipping Colab setup.\")", + "source": "import os, subprocess, sys\n\n_ON_COLAB = \"google.colab\" in sys.modules or os.environ.get(\"COLAB_RELEASE_TAG\")\n\nif _ON_COLAB:\n print(\"Colab detected — installing BraggTrack + dependencies...\")\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n \"braggtrack[notebook] @ git+https://github.com/BASE-Laboratory/BraggTrack.git\",\n ])\n subprocess.check_call([\n sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"torch\", \"torchvision\", \"transformers\",\n ])\n # DINOv3 is a gated model — you must accept the license at\n # https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m\n # then paste your HF token when prompted.\n from huggingface_hub import 
login\n login()\n\n if not os.path.isdir(\"data/sample_operando\"):\n subprocess.check_call([\n \"git\", \"clone\", \"--depth=1\", \"--filter=blob:none\", \"--sparse\",\n \"https://github.com/BASE-Laboratory/BraggTrack.git\", \"_braggtrack_repo\",\n ])\n subprocess.check_call(\n [\"git\", \"sparse-checkout\", \"set\", \"data/sample_operando\"],\n cwd=\"_braggtrack_repo\",\n )\n os.makedirs(\"data\", exist_ok=True)\n os.rename(\"_braggtrack_repo/data/sample_operando\", \"data/sample_operando\")\n subprocess.check_call([\"rm\", \"-rf\", \"_braggtrack_repo\"])\n os.environ.setdefault(\"BRAGGTRACK_DATA_ROOT\", os.path.abspath(\"data/sample_operando\"))\n print(\"Done.\")\nelse:\n print(\"Local environment — skipping Colab setup.\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -31,7 +31,7 @@ { "cell_type": "markdown", "id": "12606936", - "source": "## 1 \u2014 Load real data\n\nRead the largest 3D numeric dataset from each H5 file (bypasses the fixed NeXus path shortlist).", + "source": "## 1 — Load real data\n\nRead the largest 3D numeric dataset from each H5 file (bypasses the fixed NeXus path shortlist).", "metadata": {} }, { @@ -45,7 +45,7 @@ { "cell_type": "markdown", "id": "61e5cbd4", - "source": "## 2 \u2014 Run both segmentation methods\n\n### Classical pipeline\nOtsu \u2192 LoG \u2192 h-maxima \u2192 seeded watershed \u2192 remove small \u2192 fill holes \u2192 merge nearby \u2192 relabel.", + "source": "## 2 — Run both segmentation methods\n\n### Classical pipeline\nOtsu → LoG → h-maxima → seeded watershed → remove small → fill holes → merge nearby → relabel.", "metadata": {} }, { @@ -59,13 +59,13 @@ { "cell_type": "markdown", "id": "a5cfd4e0", - "source": "### DINO pipeline\nDINOv3 patch features \u2192 PCA \u2192 HDBSCAN \u2192 upsample \u2192 3D stitch \u2192 Otsu foreground mask \u2192 post-process.\n\nThe post-processing (remove small, fill holes, merge nearby, relabel) is identical to keep the comparison fair.", + "source": "### DINOv3 
pipeline\nFrozen DINOv3 ViT-B/16 patch features → global PCA → HDBSCAN → upsample → 3D stitch → Otsu foreground mask → post-process.\n\n`backend=\"auto\"` uses real DINOv3 weights when PyTorch is available (e.g. Colab), otherwise falls back to the mock backend for CI. The post-processing (remove small, fill holes, merge nearby, relabel) is identical to keep the comparison fair.", "metadata": {} }, { "cell_type": "code", "id": "2ab096d1", - "source": "def run_dino(volume: np.ndarray) -> np.ndarray:\n res = segment_dino(volume, backend=\"mock\")\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = np.where(binary, labels, 0)\n labels = merge_nearby_labels(labels, volume, max_centroid_distance=MERGE_DISTANCE)\n return relabel_sequential(labels)\n\ndino_labels = []\nfor s, v in zip(scans, all_volumes):\n lab = run_dino(v)\n dino_labels.append(lab)\n print(f\"{s.scan_name} DINO: {int(lab.max())} spots\")", + "source": "from braggtrack.semantic.dino import _resolve_backend\n\nDINO_BACKEND = _resolve_backend(\"auto\")\nprint(f\"DINO backend: {DINO_BACKEND}\")\n\ndef run_dino(volume: np.ndarray) -> np.ndarray:\n res = segment_dino(volume, backend=\"auto\")\n labels = remove_small_objects(res.labeled_volume, min_size=8)\n binary = fill_holes_binary(labels > 0)\n labels = np.where(binary, labels, 0)\n labels = merge_nearby_labels(labels, volume, max_centroid_distance=MERGE_DISTANCE)\n return relabel_sequential(labels)\n\ndino_labels = []\nfor s, v in zip(scans, all_volumes):\n lab = run_dino(v)\n dino_labels.append(lab)\n print(f\"{s.scan_name} DINO ({DINO_BACKEND}): {int(lab.max())} spots\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -73,7 +73,7 @@ { "cell_type": "markdown", "id": "1b93c867", - "source": "## 3 \u2014 Spot count comparison", + "source": "## 3 — Spot count comparison", "metadata": {} }, { @@ -87,7 +87,7 @@ { "cell_type": "code", "id": "1de56eb0", - "source": "x = 
np.arange(len(scans))\nwidth = 0.35\n\nfig, ax = plt.subplots(figsize=(8, 4))\nbars1 = ax.bar(x - width / 2, classical_counts, width, label=\"Classical\", color=\"#1f77b4\")\nbars2 = ax.bar(x + width / 2, dino_counts, width, label=\"DINO (mock)\", color=\"#ff7f0e\")\nax.set_xlabel(\"Scan\")\nax.set_ylabel(\"Spot count\")\nax.set_title(\"Spot counts: Classical vs DINO\")\nax.set_xticks(x)\nax.set_xticklabels(scan_names)\nax.legend()\n\nfor bars in [bars1, bars2]:\n for bar in bars:\n ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,\n str(int(bar.get_height())), ha=\"center\", va=\"bottom\", fontsize=10)\n\nplt.tight_layout()\nplt.show()", + "source": "x = np.arange(len(scans))\nwidth = 0.35\n\nfig, ax = plt.subplots(figsize=(8, 4))\nbars1 = ax.bar(x - width / 2, classical_counts, width, label=\"Classical\", color=\"#1f77b4\")\nbars2 = ax.bar(x + width / 2, dino_counts, width, label=f\"DINOv3 ({DINO_BACKEND})\", color=\"#ff7f0e\")\nax.set_xlabel(\"Scan\")\nax.set_ylabel(\"Spot count\")\nax.set_title(\"Spot counts: Classical vs DINOv3\")\nax.set_xticks(x)\nax.set_xticklabels(scan_names)\nax.legend()\n\nfor bars in [bars1, bars2]:\n for bar in bars:\n ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,\n str(int(bar.get_height())), ha=\"center\", va=\"bottom\", fontsize=10)\n\nplt.tight_layout()\nplt.show()", "metadata": {}, "execution_count": null, "outputs": [] @@ -95,13 +95,13 @@ { "cell_type": "markdown", "id": "dfd3d99c", - "source": "## 4 \u2014 Visual comparison: tri-axis label projections\n\nSide-by-side label overlays for each scan, projected along all three physical axes (\u03bc, \u03c7, d). Each row is a scan; left column = classical, right column = DINO.", + "source": "## 4 — Visual comparison: tri-axis label projections\n\nSide-by-side label overlays for each scan, projected along all three physical axes (μ, χ, d). 
Each row is a scan; the left three columns are classical, the right three columns are DINO.", "metadata": {} }, { "cell_type": "code", "id": "569f7b23", - "source": "# Build a shared colormap large enough for both methods.\nmax_labels = max(\n max(int(l.max()) for l in classical_labels),\n max(int(l.max()) for l in dino_labels),\n) + 1\nrng_cm = np.random.RandomState(42)\nlabel_colors = np.zeros((max_labels, 4))\nlabel_colors[0] = [0, 0, 0, 0]\nfor i in range(1, max_labels):\n label_colors[i] = [*rng_cm.uniform(0.2, 0.95, 3), 0.65]\nlabel_cmap = ListedColormap(label_colors)\n\naxis_info = [\n (0, \"MIP along mu\", \"chi\", \"d\"),\n (1, \"MIP along chi\", \"d\", \"mu\"),\n (2, \"MIP along d\", \"chi\", \"mu\"),\n]\n\nfig, axes = plt.subplots(len(scans), 6, figsize=(22, len(scans) * 3.5))\n\nfor row, (s, v, c_lab, d_lab) in enumerate(zip(scans, all_volumes, classical_labels, dino_labels)):\n for col_offset, (method_name, labels) in enumerate([(\"Classical\", c_lab), (\"DINO\", d_lab)]):\n for ax_idx, (axis_id, title, xlabel, ylabel) in enumerate(axis_info):\n ax = axes[row, col_offset * 3 + ax_idx]\n mip = v.max(axis=axis_id)\n floor = otsu_floor_from_mip(v, axis=axis_id)\n proj_l = label_projection_by_intensity(v, labels, axis=axis_id, mip_floor=floor)\n\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n mask = np.ma.masked_where(proj_l == 0, proj_l)\n ax.imshow(mask, cmap=label_cmap, interpolation=\"nearest\", vmin=0, vmax=max_labels - 1)\n\n if row == 0:\n ax.set_title(f\"{method_name}\\n{title}\", fontsize=9)\n ax.tick_params(labelsize=6)\n if ax_idx == 0 and col_offset == 0:\n n_c = int(c_lab.max())\n n_d = int(d_lab.max())\n ax.set_ylabel(f\"{s.scan_name}\\nC={n_c} D={n_d}\", fontsize=9)\n\nplt.suptitle(\"Classical (left 3 cols) vs DINO (right 3 cols) \u2014 tri-axis label projection\", y=1.01, fontsize=13)\nplt.tight_layout()\nplt.show()", + "source": "# Build a shared colormap large enough for both methods.\nmax_labels = max(\n 
max(int(l.max()) for l in classical_labels),\n max(int(l.max()) for l in dino_labels),\n) + 1\nrng_cm = np.random.RandomState(42)\nlabel_colors = np.zeros((max_labels, 4))\nlabel_colors[0] = [0, 0, 0, 0]\nfor i in range(1, max_labels):\n label_colors[i] = [*rng_cm.uniform(0.2, 0.95, 3), 0.65]\nlabel_cmap = ListedColormap(label_colors)\n\naxis_info = [\n (0, \"MIP along mu\", \"chi\", \"d\"),\n (1, \"MIP along chi\", \"d\", \"mu\"),\n (2, \"MIP along d\", \"chi\", \"mu\"),\n]\n\nfig, axes = plt.subplots(len(scans), 6, figsize=(22, len(scans) * 3.5))\n\nfor row, (s, v, c_lab, d_lab) in enumerate(zip(scans, all_volumes, classical_labels, dino_labels)):\n for col_offset, (method_name, labels) in enumerate([(\"Classical\", c_lab), (\"DINO\", d_lab)]):\n for ax_idx, (axis_id, title, xlabel, ylabel) in enumerate(axis_info):\n ax = axes[row, col_offset * 3 + ax_idx]\n mip = v.max(axis=axis_id)\n floor = otsu_floor_from_mip(v, axis=axis_id)\n proj_l = label_projection_by_intensity(v, labels, axis=axis_id, mip_floor=floor)\n\n vlo, vhi = np.percentile(mip, [1, 99.9])\n ax.imshow(mip, cmap=\"gray\", vmin=vlo, vmax=vhi)\n mask = np.ma.masked_where(proj_l == 0, proj_l)\n ax.imshow(mask, cmap=label_cmap, interpolation=\"nearest\", vmin=0, vmax=max_labels - 1)\n\n if row == 0:\n ax.set_title(f\"{method_name}\\n{title}\", fontsize=9)\n ax.tick_params(labelsize=6)\n if ax_idx == 0 and col_offset == 0:\n n_c = int(c_lab.max())\n n_d = int(d_lab.max())\n ax.set_ylabel(f\"{s.scan_name}\\nC={n_c} D={n_d}\", fontsize=9)\n\nplt.suptitle(\"Classical (left 3 cols) vs DINO (right 3 cols) — tri-axis label projection\", y=1.01, fontsize=13)\nplt.tight_layout()\nplt.show()", "metadata": {}, "execution_count": null, "outputs": [] @@ -109,7 +109,7 @@ { "cell_type": "markdown", "id": "49b9b0e3", - "source": "## 5 \u2014 Instance feature comparison\n\nCompare the per-spot properties (voxel count, integrated intensity, centroid, eigenvalues) between the two methods.", + "source": "## 5 — Instance 
feature comparison\n\nCompare the per-spot properties (voxel count, integrated intensity, centroid, eigenvalues) between the two methods.", "metadata": {} }, { @@ -131,7 +131,7 @@ { "cell_type": "markdown", "id": "d3693a3f", - "source": "## 6 \u2014 Spatial overlap (Dice coefficient)\n\nFor each scan, compute the Dice coefficient between the binary foreground masks produced by the two methods. This measures how much the methods agree on *where* spots are, regardless of how they partition them into instances.", + "source": "## 6 — Spatial overlap (Dice coefficient)\n\nFor each scan, compute the Dice coefficient between the binary foreground masks produced by the two methods. This measures how much the methods agree on *where* spots are, regardless of how they partition them into instances.", "metadata": {} }, { @@ -145,7 +145,7 @@ { "cell_type": "markdown", "id": "4b7a8bbc", - "source": "## 7 \u2014 Centroid scatter: classical vs DINO\n\nPlot the centroids from both methods on the same axes. Matching centroids (spots found by both methods) will overlap; method-unique detections will stand alone.", + "source": "## 7 — Centroid scatter: classical vs DINO\n\nPlot the centroids from both methods on the same axes. Matching centroids (spots found by both methods) will overlap; method-unique detections will stand alone.", "metadata": {} }, { @@ -159,13 +159,13 @@ { "cell_type": "markdown", "id": "77d2e6e2", - "source": "## 8 \u2014 Consistency across scans\n\nA key motivation for DINO-based segmentation is consistency: the same physical spots should produce the same segmentation across consecutive scans. Compare the coefficient of variation (std/mean) of spot counts across scans for each method.", + "source": "## 8 — Consistency across scans\n\nA key motivation for DINO-based segmentation is consistency: the same physical spots should produce the same segmentation across consecutive scans. 
Compare the coefficient of variation (std/mean) of spot counts across scans for each method.", "metadata": {} }, { "cell_type": "code", "id": "82ec00b3", - "source": "c_arr = np.array(classical_counts, dtype=float)\nd_arr = np.array(dino_counts, dtype=float)\n\nc_cv = c_arr.std() / c_arr.mean() if c_arr.mean() > 0 else 0\nd_cv = d_arr.std() / d_arr.mean() if d_arr.mean() > 0 else 0\n\nconsistency = pd.DataFrame({\n \"Method\": [\"Classical\", \"DINO (mock)\"],\n \"Mean spots\": [c_arr.mean(), d_arr.mean()],\n \"Std spots\": [c_arr.std(), d_arr.std()],\n \"CV (std/mean)\": [c_cv, d_cv],\n \"Spread (max-min)\": [c_arr.max() - c_arr.min(), d_arr.max() - d_arr.min()],\n})\nprint(consistency.to_string(index=False))", + "source": "c_arr = np.array(classical_counts, dtype=float)\nd_arr = np.array(dino_counts, dtype=float)\n\nc_cv = c_arr.std() / c_arr.mean() if c_arr.mean() > 0 else 0\nd_cv = d_arr.std() / d_arr.mean() if d_arr.mean() > 0 else 0\n\nconsistency = pd.DataFrame({\n \"Method\": [\"Classical\", f\"DINOv3 ({DINO_BACKEND})\"],\n \"Mean spots\": [c_arr.mean(), d_arr.mean()],\n \"Std spots\": [c_arr.std(), d_arr.std()],\n \"CV (std/mean)\": [c_cv, d_cv],\n \"Spread (max-min)\": [c_arr.max() - c_arr.min(), d_arr.max() - d_arr.min()],\n})\nprint(consistency.to_string(index=False))", "metadata": {}, "execution_count": null, "outputs": [] @@ -173,13 +173,13 @@ { "cell_type": "markdown", "id": "ec6c4ef9", - "source": "## 9 \u2014 Per-scan feature tables\n\nFull feature tables for both methods on scan 1, for detailed inspection.", + "source": "## 9 — Per-scan feature tables\n\nFull feature tables for both methods on scan 1, for detailed inspection.", "metadata": {} }, { "cell_type": "code", "id": "49e37ae2", - "source": "cols = [\"label\", \"voxel_count\", \"integrated_intensity\",\n \"centroid_mu\", \"centroid_chi\", \"centroid_d\",\n \"eig_1\", \"eig_2\", \"eig_3\"]\n\nprint(\"=== Classical \u2014 scan0001 ===\")\ndisplay(pd.DataFrame(classical_features[0])[cols]) if 
classical_features[0] else print(\"(no spots)\")\n\nprint(\"\\n=== DINO \u2014 scan0001 ===\")\ndisplay(pd.DataFrame(dino_features[0])[cols]) if dino_features[0] else print(\"(no spots)\")", + "source": "cols = [\"label\", \"voxel_count\", \"integrated_intensity\",\n \"centroid_mu\", \"centroid_chi\", \"centroid_d\",\n \"eig_1\", \"eig_2\", \"eig_3\"]\n\nprint(\"=== Classical — scan0001 ===\")\ndisplay(pd.DataFrame(classical_features[0])[cols]) if classical_features[0] else print(\"(no spots)\")\n\nprint(\"\\n=== DINO — scan0001 ===\")\ndisplay(pd.DataFrame(dino_features[0])[cols]) if dino_features[0] else print(\"(no spots)\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -187,7 +187,7 @@ { "cell_type": "markdown", "id": "a64fefce", - "source": "## 10 \u2014 Notes and next steps\n\n**What the mock backend shows:** The mock DINO backend produces hash-based random features, so the clustering is not semantically meaningful \u2014 it tests the *pipeline plumbing* (slice extraction \u2192 PCA \u2192 HDBSCAN \u2192 stitching \u2192 foreground mask) but not the quality of learned representations.\n\n**What changes with real DINOv3 weights:**\n- Set `BRAGGTRACK_DINO_BACKEND=torch` (requires `torch` + `transformers` + GPU)\n- The encoder extracts genuine patch-level features where similar textures cluster together\n- Expect better instance separation without hand-tuned LoG/watershed parameters\n- The same model should work across different beamlines and detectors\n\n**CLI equivalent:**\n```bash\n# Classical\nbraggtrack-segment-dataset --method classical --outdir artifacts/classical\n\n# DINO (mock)\nbraggtrack-segment-dataset --method dino --dino-backend mock --outdir artifacts/dino\n\n# DINO (real weights)\nbraggtrack-segment-dataset --method dino --dino-backend torch --outdir artifacts/dino_real\n```", + "source": "## 10 — Notes\n\n**Backend selection:** This notebook uses `backend=\"auto\"` — on Colab (or any environment with PyTorch + `transformers`) it 
loads the real **DINOv3 ViT-B/16** weights (`facebook/dinov3-vitb16-pretrain-lvd1689m`, 86M params). Without PyTorch it falls back to a deterministic mock for CI.\n\n**DINOv3** ([arXiv 2508.10104](https://arxiv.org/abs/2508.10104)) is a family of vision foundation models from Meta AI that produce high-quality dense patch features without fine-tuning. The frozen patch tokens are clustered per-slice via HDBSCAN using a shared PCA space across all slices, then stitched into 3D labels.\n\n**Gated model:** DINOv3 requires accepting a license on HuggingFace before weights can be downloaded. The setup cell handles the login prompt on Colab.\n\n**CLI equivalent:**\n```bash\n# Classical\npython -m braggtrack.cli.segment_dataset --method classical --outdir artifacts/classical\n\n# DINOv3 (real weights — requires torch + transformers + HF token)\npython -m braggtrack.cli.segment_dataset --method dino --dino-backend torch --outdir artifacts/dino\n\n# DINOv3 (mock — for CI / no GPU)\npython -m braggtrack.cli.segment_dataset --method dino --dino-backend mock --outdir artifacts/dino_mock\n```", "metadata": {} } ], diff --git a/tests/test_segmentation_dino.py b/tests/test_segmentation_dino.py index cabc289..7a65857 100644 --- a/tests/test_segmentation_dino.py +++ b/tests/test_segmentation_dino.py @@ -19,7 +19,7 @@ def test_output_shape(self) -> None: enc = MockPatchEncoder() img = np.random.default_rng(42).standard_normal((56, 56)) features = enc.extract_patch_features(img) - self.assertEqual(features.shape, (4, 4, 384)) + self.assertEqual(features.shape, (3, 3, 384)) def test_deterministic(self) -> None: enc = MockPatchEncoder() @@ -40,7 +40,7 @@ class TestMakePatchEncoder(unittest.TestCase): def test_mock_backend(self) -> None: enc = make_patch_encoder("mock") self.assertIsInstance(enc, MockPatchEncoder) - self.assertEqual(enc.patch_size, 14) + self.assertEqual(enc.patch_size, 16) self.assertEqual(enc.feature_dim, 384)