
Add PyTorch DPT decoder module with shared head and task-specific decoders#18

Merged
koertchen merged 3 commits into main from add-decoders-module
Apr 14, 2026

Conversation

@bingyic (Collaborator) commented Apr 14, 2026

Summary

Add pytorch/decoders.py implementing the DPT dense prediction decoder hierarchy in PyTorch, mirroring the Scenic/Flax decoder structure.

Architecture

The module follows a shared-head pattern where a common DPT backbone is subclassed for each task:

  • DPTHead: Shared backbone performing feature reassembly at multiple scales (ReassembleBlocks) and bottom-up fusion (FeatureFusionBlocks).
  • Decoder: Base class wrapping a DPTHead with a common interface.
  • SegmentationDecoder(Decoder): Per-pixel class logits (num_classes output channels).
  • DepthDecoder(Decoder): Monocular depth prediction (1-channel output).
  • NormalsDecoder(Decoder): Surface normals prediction (3-channel output).
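The shared-head pattern above can be sketched roughly as follows. This is an illustrative stand-in only: the real classes in `tips/pytorch/decoders.py` are `torch.nn.Module` subclasses with full reassembly/fusion logic, and the constructor signatures here are simplified assumptions.

```python
# Minimal sketch of the shared-head pattern (bodies elided; the real
# DPTHead performs multi-scale reassembly and bottom-up fusion).

class DPTHead:
    """Shared DPT backbone; only the output width is modeled here."""

    def __init__(self, out_channels: int):
        self.out_channels = out_channels


class Decoder:
    """Base class wrapping a DPTHead behind a common interface."""

    def __init__(self, out_channels: int):
        self.head = DPTHead(out_channels)


class SegmentationDecoder(Decoder):
    """Per-pixel class logits: num_classes output channels."""

    def __init__(self, num_classes: int):
        super().__init__(out_channels=num_classes)


class DepthDecoder(Decoder):
    """Monocular depth: 1-channel output (per the PR description)."""

    def __init__(self):
        super().__init__(out_channels=1)


class NormalsDecoder(Decoder):
    """Surface normals: 3-channel output."""

    def __init__(self):
        super().__init__(out_channels=3)
```

Each task decoder only chooses the output width; everything else is inherited from the shared head.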

Weight Loading

load_decoder_weights() loads pre-converted .npz checkpoints (produced by scripts/convert_segmentation_checkpoint.py) with backward-compatible key remapping from the legacy flat format.
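The backward-compatible remapping can be sketched as a dictionary rename pass. The specific key names below are hypothetical placeholders; the actual legacy-to-new mapping is defined inside `load_decoder_weights()`.

```python
# Illustrative sketch of legacy-key remapping (key names are made up for
# illustration; the real table lives in tips/pytorch/decoders.py).
LEGACY_KEY_MAP = {
    # hypothetical flat-format name -> module-qualified name
    "conv_0/kernel": "head.reassemble.0.conv.weight",
}


def remap_legacy_keys(state: dict) -> dict:
    """Rename legacy flat keys; pass already-qualified keys through unchanged."""
    return {LEGACY_KEY_MAP.get(key, key): value for key, value in state.items()}
```

New-format checkpoints pass through the remap as a no-op, so one loader handles both formats.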

Usage

from tips.pytorch.decoders import SegmentationDecoder, load_decoder_weights

decoder = SegmentationDecoder(num_classes=150, input_embed_dim=1024)
load_decoder_weights(decoder, 'path/to/checkpoint.npz')
logits = decoder(intermediate_features, image_size=(480, 640))

Add pytorch/decoders.py implementing the DPT dense prediction head
hierarchy, mirroring the Scenic/Flax decoder structure:

- DPTHead: shared backbone (ReassembleBlocks + FeatureFusion)
- Decoder: base class wrapping a DPTHead
- SegmentationDecoder: per-pixel class logits (num_classes output)
- DepthDecoder: monocular depth (1-channel output)
- NormalsDecoder: surface normals (3-channel output)
- load_decoder_weights(): loads pre-converted .npz checkpoints with
  backward-compatible key remapping from the legacy flat format
bingyic added a commit that referenced this pull request Apr 14, 2026
…line DPT definition

- Remove ~200 lines of redundant inline DPTSegmentationHead code
- Import SegmentationDecoder and load_decoder_weights from tips.pytorch.decoders (PR #18)
- Clone add-decoders-module branch for the decoders module
- Use per-variant INTERMEDIATE_LAYERS_MAP instead of hardcoded layer indices
- Download ADE20K images at runtime instead of hosting in repo

state_dict parameter names and values are NumPy arrays in PyTorch layout
(already transposed from Flax/JAX format).

Use ``scripts/convert_segmentation_checkpoint.py`` to produce these

@koertchen koertchen Apr 14, 2026


please remove this internal detail



# ---------------------------------------------------------------------------
# Task-specific decoders (Refactored as Thin Wrappers)

nit. remove "(Refactored as Thin Wrappers)"


This module provides a shared DPT backbone (ReassembleBlocks + fusion) and
task-specific decoder subclasses for segmentation, depth, and surface normals,
mirroring the Scenic/Flax decoder hierarchy.

Remove mention of our scenic/flax internals.

residual = self.residual_unit(residual)
x = x + residual
x = self.main_unit(x)
# Upsample 2x with align_corners=True (matches Scenic reference).

also here

super().__init__(out_channels=num_classes, **kwargs)


class DepthDecoder(Decoder):

DepthDecoder is actually classification-based, not regression:

class DepthDecoder(Decoder):
    """Decoder for monocular depth prediction using classification bins."""

    def __init__(
        self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs
    ) -> None:
        super().__init__(out_channels=256, **kwargs)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.register_buffer(
            "bin_centers", torch.linspace(min_depth, max_depth, 256)
        )

    def forward(
        self,
        intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]],
        image_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        logits = super().forward(intermediate_features, image_size=image_size)
        # Apply ReLU and shift so every bin weight is positive.
        logits = torch.relu(logits) + self.min_depth
        # Normalize to probabilities along the channel dimension.
        probs = logits / torch.sum(logits, dim=1, keepdim=True)
        # Compute expectation: sum(prob * bin_center).
        depth_map = torch.einsum(
            "bchw,c->bhw", probs, self.bin_centers.to(logits.device)
        )
        return depth_map.unsqueeze(1)
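The bin-expectation step in the suggested DepthDecoder can be checked with a small NumPy stand-in (shapes and random inputs are illustrative; NumPy replaces torch here to keep the check self-contained):

```python
import numpy as np

# NumPy stand-in for the bin-expectation depth computation: positive bin
# weights, normalized over the 256-bin channel axis, then an expectation
# over the bin centers.
min_depth, max_depth = 0.001, 10.0
bin_centers = np.linspace(min_depth, max_depth, 256)           # (256,)

logits = np.random.rand(1, 256, 4, 4)                          # (B, C, H, W)
logits = np.maximum(logits, 0.0) + min_depth                   # ReLU + shift
probs = logits / logits.sum(axis=1, keepdims=True)             # sums to 1 over C
depth = np.einsum("bchw,c->bhw", probs, bin_centers)[:, None]  # (B, 1, H, W)
```

Because the output is a convex combination of the bin centers, every predicted depth is guaranteed to lie in [min_depth, max_depth].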

@koertchen koertchen merged commit cf7d8d6 into main Apr 14, 2026
1 check passed