Add PyTorch DPT decoder module with shared head and task-specific decoders#18
Conversation
Add pytorch/decoders.py implementing the DPT dense prediction head hierarchy, mirroring the Scenic/Flax decoder structure:

- DPTHead: shared backbone (ReassembleBlocks + FeatureFusion)
- Decoder: base class wrapping a DPTHead
- SegmentationDecoder: per-pixel class logits (num_classes output)
- DepthDecoder: monocular depth (1-channel output)
- NormalsDecoder: surface normals (3-channel output)
- load_decoder_weights(): loads pre-converted .npz checkpoints with backward-compatible key remapping from the legacy flat format
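The backward-compatible key remapping in load_decoder_weights() might look roughly like this. A minimal sketch only: the legacy and new prefix names below are illustrative assumptions, not the repository's actual checkpoint keys.

```python
# Hypothetical sketch of legacy-key remapping; the concrete prefixes
# ("conv1" -> "head.reassemble.conv1", etc.) are assumptions for
# illustration, not the repository's real parameter names.
def remap_legacy_keys(state_dict: dict) -> dict:
    """Map legacy flat keys onto the new DPTHead-based hierarchy."""
    legacy_to_new = {
        "conv1": "head.reassemble.conv1",  # assumed legacy name
        "fusion": "head.fusion",           # assumed legacy name
    }
    remapped = {}
    for key, value in state_dict.items():
        prefix = key.split(".", 1)[0]
        if prefix in legacy_to_new:
            # Rewrite only the leading component; keep the suffix intact.
            key = legacy_to_new[prefix] + key[len(prefix):]
        remapped[key] = value
    return remapped
```

Keys that already use the new hierarchy pass through untouched, which is what makes the remapping backward compatible.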
…line DPT definition

- Remove ~200 lines of redundant inline DPTSegmentationHead code
- Import SegmentationDecoder and load_decoder_weights from tips.pytorch.decoders (PR #18)
- Clone add-decoders-module branch for the decoders module
- Use per-variant INTERMEDIATE_LAYERS_MAP instead of hardcoded layer indices
- Download ADE20K images at runtime instead of hosting in repo
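A per-variant intermediate-layer map of the kind referenced above might look like this. Sketch only: the variant names follow common ViT sizes and the indices are assumptions, not the repository's actual values.

```python
# Illustrative per-variant map of which transformer blocks feed the
# DPT decoder; variant keys and indices are assumed, not taken from
# the actual repo.
INTERMEDIATE_LAYERS_MAP = {
    "vit-b": (2, 5, 8, 11),    # 12-layer backbone (assumed)
    "vit-l": (4, 11, 17, 23),  # 24-layer backbone (assumed)
}

def get_intermediate_layers(variant: str) -> tuple:
    """Look up the decoder tap points for a backbone variant."""
    try:
        return INTERMEDIATE_LAYERS_MAP[variant]
    except KeyError:
        raise ValueError(f"Unknown variant: {variant!r}")
```

Keying on the variant name avoids silently reusing ViT-B indices for a deeper backbone, which is the failure mode hardcoded indices invite.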
pytorch/decoders.py
Outdated
state_dict parameter names and values are NumPy arrays in PyTorch layout
(already transposed from Flax/JAX format).

Use ``scripts/convert_segmentation_checkpoint.py`` to produce these
please remove this internal detail
# ---------------------------------------------------------------------------
# Task-specific decoders (Refactored as Thin Wrappers)
nit. remove "(Refactored as Thin Wrappers)"
pytorch/decoders.py
Outdated
This module provides a shared DPT backbone (ReassembleBlocks + fusion) and
task-specific decoder subclasses for segmentation, depth, and surface normals,
mirroring the Scenic/Flax decoder hierarchy.
Remove mention of our scenic/flax internals.
pytorch/decoders.py
Outdated
residual = self.residual_unit(residual)
x = x + residual
x = self.main_unit(x)
# Upsample 2x with align_corners=True (matches Scenic reference).

super().__init__(out_channels=num_classes, **kwargs)


class DepthDecoder(Decoder):
DepthDecoder actually is classification based, not regression:
```python
class DepthDecoder(Decoder):
    """Decoder for monocular depth prediction using classification bins."""

    def __init__(
        self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs
    ) -> None:
        super().__init__(out_channels=256, **kwargs)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.register_buffer(
            "bin_centers", torch.linspace(min_depth, max_depth, 256)
        )

    def forward(
        self,
        intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]],
        image_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        logits = super().forward(intermediate_features, image_size=image_size)
        # Apply ReLU and shift.
        logits = torch.relu(logits) + self.min_depth
        # Normalize to probabilities along the channel dimension.
        probs = logits / torch.sum(logits, dim=1, keepdim=True)
        # Compute expectation: sum(prob * bin_center).
        depth_map = torch.einsum(
            "bchw,c->bhw", probs, self.bin_centers.to(logits.device)
        )
        return depth_map.unsqueeze(1)
```
Summary
Add `pytorch/decoders.py` implementing the DPT dense prediction decoder hierarchy in PyTorch, mirroring the Scenic/Flax decoder structure.

Architecture

The module follows a shared-head pattern where a common DPT backbone is subclassed for each task:

- `DPTHead`: Shared backbone performing feature reassembly at multiple scales (`ReassembleBlocks`) and bottom-up fusion (`FeatureFusionBlocks`).
- `Decoder`: Base class wrapping a `DPTHead` with a common interface.
- `SegmentationDecoder(Decoder)`: Per-pixel class logits (`num_classes` output channels).
- `DepthDecoder(Decoder)`: Monocular depth prediction (1-channel output).
- `NormalsDecoder(Decoder)`: Surface normals prediction (3-channel output).

Weight Loading

`load_decoder_weights()` loads pre-converted `.npz` checkpoints (produced by `scripts/convert_segmentation_checkpoint.py`) with backward-compatible key remapping from the legacy flat format.

Usage