Hmm, but is that the problem? @NolanKoblischke's issue was that things didn't work with a batch size of 1, or with no batch dimension at all.
Can we instead add a piece of code in the codecs' encode functions to ensure a batch dimension is present?
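A minimal sketch of what such a guard could look like, using NumPy for illustration (the same `ndim` and `[None]` indexing idioms apply to PyTorch tensors); `ensure_batch_dim` and `expected_ndim` are hypothetical names, not part of the codebase:

```python
import numpy as np

def ensure_batch_dim(x, expected_ndim):
    """Prepend a batch axis of size 1 if the input arrives unbatched (hypothetical helper)."""
    if x.ndim == expected_ndim - 1:
        x = x[None]  # add a leading batch dimension
    return x

unbatched = np.zeros((3, 32, 32))         # a single [c, h, w] image
batched = ensure_batch_dim(unbatched, 4)  # -> shape [1, c, h, w]
```

An encode function could call this at entry, so downstream code can always assume a leading batch axis regardless of how the caller shaped the input.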
EiffL left a comment
Code Review: Massive update to subsampler
Overview
This PR makes a targeted fix to the subsampler.py file, changing a generic squeeze() call to explicitly specify squeeze(-1) when calculating the scales tensor. Despite the "massive update" title, this is a minimal but important change.
Analysis
Code Quality & Correctness ✅
- Good fix: The change from `squeeze()` to `squeeze(-1)` is a sound improvement
- Explicit dimension handling: Specifying the dimension to squeeze makes the code more predictable and safer
- Maintains functionality: The mathematical operation remains the same but with better dimension control
Following Project Conventions ✅
- Consistent with PyTorch best practices: Explicit dimension specification is recommended
- Aligns with codebase style: The change follows the project's tensor operation patterns
- Type safety: Helps maintain tensor shape consistency with the jaxtyping annotations used throughout the project
Technical Implications
Benefits:
- Prevents silent errors: Generic `squeeze()` could remove unintended dimensions if tensor shapes change
- Better debugging: Explicit dimension specification makes errors more traceable
- Shape consistency: Ensures the `scales` tensor always has the expected shape for broadcasting
Potential considerations:
- Runtime behavior: In PyTorch, `squeeze(-1)` is a no-op when the last dimension is not size 1, so the tensor passes through unchanged (unlike NumPy, where `squeeze(axis=-1)` raises an error in that case)
- Backward compatibility: Could break code that relied on the previous behavior with unexpected tensor shapes
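The batch-size-1 hazard that motivates this fix can be demonstrated in a few lines. NumPy is used here for illustration; PyTorch's generic `squeeze()` collapses size-1 dimensions the same way:

```python
import numpy as np

# scales before squeezing has shape [b, 1]; with batch size 1 that is [1, 1]
scales_b1 = np.ones((1, 1))

# generic squeeze() removes *every* size-1 axis, collapsing to a scalar
assert scales_b1.squeeze().shape == ()

# squeeze(-1) removes only the trailing axis, preserving the batch dimension
assert scales_b1.squeeze(-1).shape == (1,)

# with batch size > 1 both forms happen to agree
assert np.ones((4, 1)).squeeze(-1).shape == (4,)
```

With a generic `squeeze()`, a batch of 1 silently produces a 0-d tensor, which then breaks the `scales[:, None, None, None]` broadcasting used later.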
Context Analysis
Looking at the code:
- `scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1)`
- `label_sizes` has shape `[b, 1]` (from `keepdim=True`)
- The division and power operations maintain this shape
- `squeeze(-1)` removes the last dimension, resulting in shape `[b]`
- This is then used for broadcasting: `scales[:, None, None, None]`
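The shape walkthrough above can be reproduced with a small NumPy sketch; the `dim_in` value and the tensor shapes are made up for the example:

```python
import numpy as np

b, c, h, w = 2, 3, 8, 8
label_sizes = np.array([[4.0], [16.0]])  # shape [b, 1], as with keepdim=True
dim_in = 64.0                            # stand-in value for self.dim_in

# mirrors: scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1)
scales = ((dim_in / label_sizes) ** 0.5).squeeze(-1)
assert scales.shape == (b,)

# the broadcasting pattern used later in the code
x = np.random.randn(b, c, h, w)
scaled = x * scales[:, None, None, None]  # [b] -> [b, 1, 1, 1], broadcast over [b, c, h, w]
assert scaled.shape == (b, c, h, w)
```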
Verdict
LGTM - This is a good defensive programming practice that improves code reliability without changing functionality. The change is solid and should be merged.
Minor suggestion: Consider adding a comment explaining why the last dimension is squeezed, especially since this relates to the broadcasting pattern used later.
This pull request includes a minor fix to the `subsampler.py` file. The change ensures that the `squeeze` method explicitly removes the last dimension (`dim=-1`) when calculating the `scales` tensor.