
Fix pytree input validation and use solved placements for forward validation #423

Merged: fmassa merged 3 commits into main from fmassa/fix_pytree_support on Apr 19, 2026
@fmassa (Contributor) commented Apr 18, 2026

The simple API (auto_parallel) advertised pytree support (dict inputs, nested structures) but had two bugs that broke it at runtime:

  1. _compute_expected_inputs iterated over traced_inputs directly, so a dict element hit the non-tensor branch and sharding was never applied to its tensor leaves. The forward function similarly passed the raw pytree structure to the compiled graph, which expects flat tensor args.
  2. Forward validation used self.input_constraints (the user's request) rather than the solver's solution to determine expected input shapes. If the user set no constraints and the solver chose something other than the hardcoded default Shard(0), validation would reject correct inputs.
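The first bug can be illustrated with a minimal pure-Python sketch. `FakeTensor`, `flatten_leaves`, and the two counting helpers are hypothetical stand-ins, not code from this repository (the real code presumably relies on PyTorch's pytree utilities). The sketch only shows why iterating over the top-level inputs misses tensors nested inside a dict.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; only the shape matters here."""
    shape: tuple


def flatten_leaves(tree):
    """Recursively collect leaves from nested dicts, lists, and tuples.

    Dict values are visited in insertion order (an assumption of this sketch).
    """
    if isinstance(tree, dict):
        leaves = []
        for key in tree:
            leaves.extend(flatten_leaves(tree[key]))
        return leaves
    if isinstance(tree, (list, tuple)):
        leaves = []
        for item in tree:
            leaves.extend(flatten_leaves(item))
        return leaves
    return [tree]


def count_tensors_top_level(inputs):
    """Buggy pattern: a dict element is not a tensor, so its leaves are skipped."""
    return sum(isinstance(x, FakeTensor) for x in inputs)


def count_tensors_flattened(inputs):
    """Fixed pattern: flatten first, then inspect every leaf."""
    return sum(isinstance(x, FakeTensor) for x in flatten_leaves(inputs))


traced_inputs = (
    FakeTensor((8, 16)),
    {"mask": FakeTensor((8, 16)), "scale": 0.5},
)
```

With these inputs, the top-level loop sees one tensor and one dict, while the flattened view sees both tensors plus the non-tensor leaf `0.5`.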

This PR fixes both by pytree-flattening inputs in _compute_expected_inputs and forward(), and by extracting the actual solved placements from the solution dict via get_plain_input_and_grad_nodes. This also removes the duplicated default placement logic from _compute_expected_inputs, making it a pure shape calculator.
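To make the second fix concrete, here is a rough sketch of a "pure shape calculator" that derives the expected per-device shape from whatever placements it is handed. The `Shard` and `Replicate` classes are defined locally as illustrative stand-ins, and `local_shape` and `validate` are hypothetical names. Even divisibility is assumed, and none of this is the repository's actual implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Stand-in placement: split the tensor along `dim` across a mesh dimension."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Stand-in placement: every device holds the full tensor."""


def local_shape(global_shape, placements, mesh_shape):
    """Per-device shape, given one placement per mesh dimension.

    Assumes each sharded dimension divides evenly by the mesh size.
    """
    shape = list(global_shape)
    for placement, mesh_size in zip(placements, mesh_shape):
        if isinstance(placement, Shard):
            if shape[placement.dim] % mesh_size != 0:
                raise ValueError("uneven sharding is not handled in this sketch")
            shape[placement.dim] //= mesh_size
    return tuple(shape)


def validate(input_shape, global_shape, placements, mesh_shape):
    """True if a runtime input has the shape implied by the given placements."""
    return tuple(input_shape) == local_shape(global_shape, placements, mesh_shape)


mesh_shape = (4,)
global_shape = (8, 16)
solved = (Shard(1),)          # what the solver picked (hypothetical)
hardcoded_default = (Shard(0),)

runtime_input_shape = local_shape(global_shape, solved, mesh_shape)
```

In this example the correctly sharded input has shape `(8, 4)`. Checking it against the hardcoded default expects `(2, 16)` and rejects it, whereas checking it against the solved placement accepts it.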

Other cleanups along the way:

  • _extract_input_info now returns per-tensor devices (previously _make_input_fn hardcoded device="cuda")
  • apply_placement takes sharding_placement as a required argument (every call site already passed it explicitly)

Authored with Claude.

@meta-cla bot added the CLA Signed label Apr 18, 2026
@fmassa merged commit 75ecc6a into main Apr 19, 2026
9 of 11 checks passed
@fmassa deleted the fmassa/fix_pytree_support branch April 19, 2026 07:04