Fix pytree input validation and use solved placements for forward validation #423
Merged
Conversation
The simple API (`auto_parallel`) advertised pytree support (dict inputs, nested structures) but had two bugs that broke it at runtime:

1. `_compute_expected_inputs` iterated over `traced_inputs` directly, so a dict element hit the non-tensor branch and sharding was never applied to its tensor leaves. The forward function similarly passed the raw pytree structure to the compiled graph, which expects flat tensor args.
2. Forward validation used `self.input_constraints` (the user's request) rather than the solver's solution to determine expected input shapes. If the user set no constraints and the solver chose something other than the hardcoded default `Shard(0)`, validation would reject correct inputs.

This PR fixes both by pytree-flattening inputs in `_compute_expected_inputs` and `forward()`, and by extracting the actual solved placements from the solution dict via `get_plain_input_and_grad_nodes`. This also removes the duplicated default placement logic from `_compute_expected_inputs`, making it a pure shape calculator.

Other cleanups along the way:

- `_extract_input_info` now returns per-tensor devices (previously `_make_input_fn` hardcoded `device="cuda"`)
- `apply_placement` takes `sharding_placement` as a required argument (every call site already passed it explicitly)

Authored with Claude.
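The first bug boils down to iterating top-level container elements instead of tensor leaves. A minimal sketch of the before/after difference, using a simplified pure-Python stand-in for `torch.utils._pytree.tree_flatten` and a `FakeTensor` placeholder (both hypothetical; the real code operates on traced FX inputs):

```python
def tree_flatten(obj):
    """Simplified stand-in for torch.utils._pytree.tree_flatten:
    returns the leaves of a nested dict/list/tuple structure."""
    if isinstance(obj, dict):
        return [leaf for v in obj.values() for leaf in tree_flatten(v)]
    if isinstance(obj, (list, tuple)):
        return [leaf for v in obj for leaf in tree_flatten(v)]
    return [obj]  # a leaf (in the real code: a tensor)

class FakeTensor:
    """Stand-in for a traced tensor input."""
    def __init__(self, shape):
        self.shape = shape

def compute_expected_inputs_buggy(traced_inputs):
    # Bug: iterating the top level means a dict element is "not a tensor",
    # so its tensor leaves are silently skipped.
    return [inp.shape for inp in traced_inputs if isinstance(inp, FakeTensor)]

def compute_expected_inputs_fixed(traced_inputs):
    # Fix: flatten the pytree first, so every tensor leaf is visited.
    return [leaf.shape for leaf in tree_flatten(traced_inputs)
            if isinstance(leaf, FakeTensor)]

inputs = [FakeTensor((8, 4)), {"ids": FakeTensor((8,)), "mask": FakeTensor((8,))}]
print(compute_expected_inputs_buggy(inputs))  # [(8, 4)] -- dict leaves lost
print(compute_expected_inputs_fixed(inputs))  # [(8, 4), (8,), (8,)]
```

The same flattening in `forward()` is what lets the compiled graph receive flat tensor args regardless of how the caller nests them.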
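For the second bug, the key idea is that expected per-rank input shapes must be derived from the placements the solver actually chose, not from the user's constraints. A sketch of that shape calculation under DTensor-style `Shard`/`Replicate` placements (the classes and helper here are illustrative, not the PR's actual code):

```python
from dataclasses import dataclass

# Illustrative placement types following DTensor naming; the real solver
# attaches placements to input nodes rather than using these classes.
@dataclass
class Shard:
    dim: int

@dataclass
class Replicate:
    pass

def expected_local_shape(global_shape, placements, mesh_shape):
    """Per-rank shape implied by the *solved* placements (one placement
    per mesh dimension). Validating against this accepts whatever
    sharding the solver chose, even with no user constraints."""
    shape = list(global_shape)
    for mesh_dim, p in enumerate(placements):
        if isinstance(p, Shard):
            shape[p.dim] //= mesh_shape[mesh_dim]  # assumes even divisibility
    return tuple(shape)

# Suppose the solver chose Shard(1) on a 1-D mesh of 4 ranks; validating
# against a hardcoded default Shard(0) would reject this correct input.
print(expected_local_shape((8, 16), [Shard(1)], (4,)))      # (8, 4)
print(expected_local_shape((8, 16), [Replicate()], (4,)))   # (8, 16)
```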