Cast connector inputs to the live weight dtype#61
Merged
Conversation
The connectors cast incoming signals to self.input_dtype, the dtype captured at construction time. Casting the model afterwards (e.g. .float() for fp64/fp32 evaluation or debugging) then crashes with a dtype mismatch, because inputs still arrive as bfloat16 while the weights have moved on. Cast to the projection weights' current dtype instead; input_dtype remains as the construction-time module dtype. Default bf16 path is unchanged (verified identical outputs for all four connectors); after .float() each connector now runs and returns float32.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Second piece of splitting up #12 into simpler PRs (the dtype-robustness fix that the full-determinism experiments needed), extended to all four connectors.
What
LinearProjection,MLPProjection,PatchProjection, andCNNPatchProjectioncast incoming signals toself.input_dtype— the dtype captured at construction. If the model's dtype is changed afterwards (.float()for high-precision evaluation/debugging, fp64 determinism experiments), the forward crashes:Cast to the projection weights' current dtype instead.
input_dtypekeeps its construction-time role.Verification
For all four connectors: default bf16 outputs are bit-identical before/after this change; after
.float()the forward previously raised the RuntimeError above and now runs, returning float32.