Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions braggtrack/cli/embed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def build_parser() -> argparse.ArgumentParser:
"--backend",
choices=("auto", "mock", "torch"),
default="auto",
help="Embedding backend (mock needs no PyTorch; torch uses Dinov2-small)",
help="Embedding backend (mock needs no PyTorch; torch uses DINOv3)",
)
p.add_argument("--model", default="facebook/dinov2-small", help="HF model id when backend=torch")
p.add_argument("--model", default="facebook/dinov3-vitb16-pretrain-lvd1689m", help="HF model id when backend=torch")
return p


Expand Down
4 changes: 2 additions & 2 deletions braggtrack/cli/segment_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def build_parser() -> argparse.ArgumentParser:
help="HuggingFace model ID for the DINO torch backend",
)
parser.add_argument("--dino-pca-components", type=int, default=16, help="PCA components for DINO feature reduction")
parser.add_argument("--dino-min-cluster-size", type=int, default=3, help="HDBSCAN min_cluster_size")
parser.add_argument("--dino-min-cluster-size", type=int, default=5, help="HDBSCAN min_cluster_size")
parser.add_argument("--dino-min-samples", type=int, default=2, help="HDBSCAN min_samples")
parser.add_argument(
"--dino-min-overlap",
type=float,
default=0.3,
default=0.2,
help="Min overlap fraction for 3D slice stitching",
)
return parser
Expand Down
110 changes: 92 additions & 18 deletions braggtrack/segmentation/dino_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,46 +55,101 @@ def _extract_slice_features(
return np.stack(feature_maps, axis=0), slice_hw


def _patch_foreground_masks(
volume: np.ndarray,
threshold: float,
patch_size: int,
axis: int = 0,
) -> list[np.ndarray]:
"""Boolean foreground mask at patch resolution for each slice.

A patch is foreground if any of its pixels exceed *threshold*.
"""
n_slices = volume.shape[axis]
masks: list[np.ndarray] = []
for i in range(n_slices):
slc = np.take(volume, i, axis=axis)
fg = (slc >= threshold).astype(np.float64)
h, w = slc.shape
hp = max(1, h // patch_size)
wp = max(1, w // patch_size)
cropped = fg[: hp * patch_size, : wp * patch_size]
blocks = cropped.reshape(hp, patch_size, wp, patch_size)
masks.append(blocks.max(axis=(1, 3)) > 0)
return masks


def _cluster_feature_map(
features: np.ndarray,
*,
n_components_pca: int = 16,
min_cluster_size: int = 3,
min_cluster_size: int = 5,
min_samples: int = 2,
foreground_mask: np.ndarray | None = None,
pca_model: object | None = None,
) -> np.ndarray:
"""Cluster a 2D feature map ``(H_p, W_p, D)`` into instance labels.

Uses PCA dimensionality reduction followed by HDBSCAN.
Returns ``(H_p, W_p)`` int array, 0 = background/noise.

When *foreground_mask* is given only foreground patches participate
in clustering; background patches get label 0. When *pca_model* is
given it is used for the PCA transform instead of fitting per-slice.
"""
from sklearn.cluster import HDBSCAN
from sklearn.decomposition import PCA

h_p, w_p, d = features.shape
n_patches = h_p * w_p
flat = features.reshape(-1, d)

n_comp = min(n_components_pca, d, n_patches)
if n_comp < 2:
# Too few patches to cluster — assign all to a single region.
return np.ones((h_p, w_p), dtype=np.int32)
if foreground_mask is not None:
fg_flat = foreground_mask.ravel()
n_fg = int(fg_flat.sum())
if n_fg == 0:
return np.zeros((h_p, w_p), dtype=np.int32)
if n_fg < 2:
out = np.zeros((h_p, w_p), dtype=np.int32)
out[foreground_mask] = 1
return out
else:
fg_flat = np.ones(h_p * w_p, dtype=bool)
n_fg = h_p * w_p

reduced = PCA(n_components=n_comp).fit_transform(flat)
fg_features = flat[fg_flat]

effective_min_cluster = max(2, min(min_cluster_size, n_patches // 2))
if pca_model is not None:
reduced = pca_model.transform(fg_features)
else:
n_comp = min(n_components_pca, d, n_fg)
if n_comp < 2:
out = np.zeros((h_p, w_p), dtype=np.int32)
if foreground_mask is not None:
out[foreground_mask] = 1
else:
out[:] = 1
return out
reduced = PCA(n_components=n_comp).fit_transform(fg_features)

effective_min_cluster = max(2, min(min_cluster_size, n_fg // 2))
clusterer = HDBSCAN(
min_cluster_size=effective_min_cluster,
min_samples=max(1, min(min_samples, effective_min_cluster - 1)),
)
raw_labels = clusterer.fit_predict(reduced)
# HDBSCAN labels: -1 = noise, 0..K = clusters. Shift to 1-based.
labels = np.where(raw_labels >= 0, raw_labels + 1, 0)
cluster_labels = np.where(raw_labels >= 0, raw_labels + 1, 0)

# If HDBSCAN assigned everything to noise, treat all patches as one region.
if not np.any(labels > 0):
return np.ones((h_p, w_p), dtype=np.int32)
if not np.any(cluster_labels > 0):
out = np.zeros((h_p, w_p), dtype=np.int32)
if foreground_mask is not None:
out[foreground_mask] = 1
else:
out[:] = 1
return out

return labels.reshape(h_p, w_p).astype(np.int32)
out = np.zeros(h_p * w_p, dtype=np.int32)
out[fg_flat] = cluster_labels
return out.reshape(h_p, w_p)


def _upsample_labels(
Expand Down Expand Up @@ -205,10 +260,10 @@ def segment_dino(
model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m",
torch_device: str | None = None,
n_components_pca: int = 16,
min_cluster_size: int = 3,
min_cluster_size: int = 5,
min_samples: int = 2,
threshold_fraction: float = 1.0,
min_overlap_fraction: float = 0.3,
min_overlap_fraction: float = 0.2,
axis: int = 0,
) -> DinoSegmentationResult:
"""Segment a 3D volume using DINOv3 patch-level features + HDBSCAN.
Expand All @@ -232,6 +287,8 @@ def segment_dino(
axis
Axis to slice along (0 = mu/z, typically the narrowest).
"""
from sklearn.decomposition import PCA

from braggtrack.segmentation.otsu import otsu_threshold
from braggtrack.semantic.dino import make_patch_encoder

Expand All @@ -243,13 +300,31 @@ def segment_dino(

features, slice_hw = _extract_slice_features(volume, encoder, axis=axis)

# Foreground mask at patch resolution — exclude background from clustering.
fg_masks = _patch_foreground_masks(volume, threshold, encoder.patch_size, axis=axis)

# Global PCA across all foreground patches for a consistent feature space
# (per-slice PCA causes wildly different cluster counts).
n_slices = features.shape[0]
d = features.shape[-1]
all_fg = [features[i][fg_masks[i]] for i in range(n_slices)]
all_fg_cat = np.concatenate(all_fg, axis=0) if any(f.size > 0 for f in all_fg) else np.empty((0, d))

n_comp = min(n_components_pca, d, all_fg_cat.shape[0])
pca = None
if n_comp >= 2:
pca = PCA(n_components=n_comp)
pca.fit(all_fg_cat)

per_slice_labels: list[np.ndarray] = []
for i in range(features.shape[0]):
for i in range(n_slices):
patch_labels = _cluster_feature_map(
features[i],
n_components_pca=n_components_pca,
min_cluster_size=min_cluster_size,
min_samples=min_samples,
foreground_mask=fg_masks[i],
pca_model=pca,
)
full_labels = _upsample_labels(patch_labels, slice_hw, encoder.patch_size)
per_slice_labels.append(full_labels)
Expand All @@ -261,7 +336,6 @@ def segment_dino(
labels_3d = np.where(foreground, labels_3d, 0).astype(np.int32)

component_count = len(np.unique(labels_3d[labels_3d > 0]))
# response is empty — DINO segmentation has no LoG-equivalent response surface.
return DinoSegmentationResult(
threshold=threshold,
seed_count=component_count,
Expand Down
18 changes: 10 additions & 8 deletions braggtrack/semantic/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

* ``mock`` — deterministic CPU-only vectors from image bytes (default when
PyTorch is unavailable).
* ``torch`` — Hugging Face DINOv3/v2 model (CLS or patch tokens).
* ``torch`` — Hugging Face DINOv3 model (CLS or patch tokens).
* ``auto`` — use ``torch`` if import succeeds, else ``mock``.
"""

Expand Down Expand Up @@ -73,7 +73,7 @@ def embed(self, mip_mu: np.ndarray, mip_chi: np.ndarray, mip_d: np.ndarray) -> n


class TorchDinoMultiviewEncoder:
"""Loads Dinov2 once; call :meth:`embed` per spot."""
"""Loads DINOv3 once; call :meth:`embed` per spot."""

def __init__(self, model_name: str, device: str | None = None) -> None:
import torch
Expand Down Expand Up @@ -119,7 +119,7 @@ def _requested_backend(explicit: BackendName | None) -> BackendName:
def make_multiview_encoder(
backend: BackendName | None = None,
*,
model_name: str = "facebook/dinov2-small",
model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m",
torch_device: str | None = None,
) -> MultiviewEncoder:
"""Construct a reusable encoder (loads torch weights at most once)."""
Expand Down Expand Up @@ -154,7 +154,7 @@ class MockPatchEncoder:

@property
def patch_size(self) -> int:
return 14
return 16

@property
def feature_dim(self) -> int:
Expand All @@ -177,7 +177,7 @@ def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray:


class TorchDinoPatchEncoder:
"""Extracts DINOv2/v3 patch tokens as a spatial feature map."""
"""Extracts DINOv3 patch tokens as a spatial feature map."""

def __init__(self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", device: str | None = None) -> None:
import torch
Expand All @@ -189,8 +189,9 @@ def __init__(self, model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m",
self._model = AutoModel.from_pretrained(model_name)
self._model.eval()
self._model.to(self._device)
self._patch_size = getattr(self._model.config, "patch_size", 14)
self._patch_size = getattr(self._model.config, "patch_size", 16)
self._feature_dim = int(self._model.config.hidden_size)
self._num_register_tokens = getattr(self._model.config, "num_register_tokens", 0)

@property
def patch_size(self) -> int:
Expand All @@ -206,7 +207,8 @@ def extract_patch_features(self, image_2d: np.ndarray) -> np.ndarray:
inputs = self._proc(images=[rgb], return_tensors="pt")
inputs = {k: v.to(self._device) for k, v in inputs.items()}
out = self._model(**inputs)
patch_tokens = out.last_hidden_state[:, 1:, :].squeeze(0)
skip = 1 + self._num_register_tokens
patch_tokens = out.last_hidden_state[:, skip:, :].squeeze(0)
h_img = inputs["pixel_values"].shape[2]
w_img = inputs["pixel_values"].shape[3]
h_p = h_img // self._patch_size
Expand Down Expand Up @@ -238,7 +240,7 @@ def embed_multiview_mips(
mip_d: np.ndarray,
*,
backend: BackendName | None = None,
model_name: str = "facebook/dinov2-small",
model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m",
torch_device: str | None = None,
) -> np.ndarray:
"""Return a single L2-normalised concatenated feature vector."""
Expand Down
Loading
Loading