
Support dynamic batch sizes via symbolic shapes#418

Open
fmassa wants to merge 19 commits into main from fmassa/dynamic_shapes

Conversation

@fmassa
Contributor

@fmassa fmassa commented Apr 13, 2026

When dynamic=True, AutoParallel traces the model with symbolic batch dimensions so the parallel model accepts arbitrary batch sizes at runtime, without re-running the ILP optimizer.

The implementation has three phases, best reviewed in this order:

  1. Joint graph tracing with symbolic inputs (api.py, input_validation.py): Input tensors get all dimensions marked as dynamic via StatelessSymbolicContext. PyTorch's ShapeEnv automatically concretizes non-batch dimensions (hidden dim, nheads, etc.) through guards when they interact with concrete parameter shapes, so only truly dynamic dims survive as symbolic. _check_forward_args is relaxed to accept any size for SymInt dimensions.
  2. ILP solver isolation (optimize_sharding.py): Shape-computation nodes (sym_size, operator.mul) are skipped in build_sharding_metadata — they produce scalars, not tensors. All user_args and user_strats are concretized (both bare SymInts and FakeTensors with symbolic shapes) at the boundary so the entire solver sees the same concrete values as the non-dynamic case. self.nodes and self.node_map are derived from strats to keep indices consistent.
  3. Parallel graph lowering (apply_sharding.py): View ops bypass DTensor dispatch (which can't handle symbolic expressions) and instead compute local shapes via view_groups dim mapping applied to the local input tensor's shape. After the two-pass lowering, _re_symbolize_graph re-traces the parallel graph outside the FakeTensorMode to replace the joint graph's SymInts with fresh symbols that Inductor can codegen.

Also adds apply_cuda_patches to conftest.py so that tests work on machines with fewer GPUs than the fake world size (256 devices). Previously, tests failed on 1-GPU machines due to deferred CUDA capability checks trying to access non-existent devices.

Authored with Claude.

fmassa added 4 commits April 12, 2026 15:17
29 tests pass in test_dynamic_shapes.py:

  Unit tests (23, no GPU/mesh required):
  - TestCheckForwardArgs (5 tests) — concrete match, mismatch, SymInt accepts any batch, rejects wrong static dim, ndim mismatch
  - TestMakeInputsDynamic (4 tests) — all dims symbolic, non-tensor passthrough, concretization after mm, meta tensors no allocation
  - TestConcretizeArgs (2 tests) — SymInt/FakeTensor concretization, concrete passthrough
  - TestProducesTensor (4 tests) — tensor, tuple of tensors, nested with None, scalars
  - TestComputeLocalViewShape (4 tests) — flatten BS, unflatten BS, split heads with TP, merge heads
  - TestReSymbolizeGraph (1 test) — fresh symbols from different ShapeEnv
  - TestConcretizeShape (2 tests) — concrete passthrough, SymInt to hint
  - TestArgsHaveSymints (1 test, inside other tests)

  Integration tests (6, need mesh fixtures):
  - ILP placement consistency (static vs dynamic) for 1D and 2D meshes
  - apply_placement success for FFN and TransformerBlock
  - Joint graph has symbolic shapes
  - _check_forward_args accepts different batch sizes
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Apr 13, 2026
fmassa added 2 commits April 13, 2026 09:44
  Description:

  Replace the two separate code paths for creating local args (shard_placeholder_inputs with DTensor redistribution for static, _make_symbolic_local_args for dynamic) with a single _make_local_args that computes local shapes by directly dividing sharded dims. Both paths produce identical results since the ILP validates even sharding, making DTensor's ceiling division a no-op.

  Also removes the meta=False path from shard_node_given_placements since its only caller was the now-removed shard_placeholder_inputs.

  Authored with Claude.
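The "directly dividing sharded dims" idea above can be sketched in a few lines. This is a hedged illustration (function and argument names are assumed, not the PR's): since the ILP has already validated even sharding, dividing each sharded dim by the mesh size along its mesh axis gives exactly the local shape DTensor would compute.

```python
def local_shape(global_shape, placements, mesh_shape):
    """placements[i]: tensor dim sharded along mesh axis i, or None if replicated."""
    shape = list(global_shape)
    for mesh_axis, dim in enumerate(placements):
        if dim is not None:
            # The ILP validated even sharding, so this division never truncates
            # and matches DTensor's ceiling division exactly.
            assert shape[dim] % mesh_shape[mesh_axis] == 0
            shape[dim] //= mesh_shape[mesh_axis]
    return tuple(shape)

# 2D mesh (dp=8, tp=4): shard dim 0 on dp, dim 1 on tp.
print(local_shape((64, 1024), [0, 1], (8, 4)))  # → (8, 256)
```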
@fmassa fmassa requested a review from xmfan April 13, 2026 12:48
Contributor

@pianpwk pianpwk left a comment


We did add better dynamic shape support for views in pytorch/pytorch#174629, I wonder if the issues you hit might be solved in pytorch/pytorch#175287

but also it seems this is not using unbacked symbols?

@aditvenk
Contributor

We did add better dynamic shape support for views in pytorch/pytorch#174629, I wonder if the issues you hit might be solved in pytorch/pytorch#175287

but also it seems this is not using unbacked symbols?

Should this use unbacked symbols?

@ezyang
Contributor

ezyang commented Apr 14, 2026

@laithsakka

@ezyang
Contributor

ezyang commented Apr 14, 2026

I am not a huge fan of the written strategy for (1); it seems to me that it would be simpler and less error-prone to just force explicit annotation of dynamic inputs

@fmassa
Contributor Author

fmassa commented Apr 14, 2026

@ezyang can you clarify your point about

I am not a huge fan of the written strategy for (1); it seems to me that it would be simpler and less error-prone to just force explicit annotation of dynamic inputs

Given that we assume full-graph capture, and that we only mark shapes as dynamic for input arguments (i.e., parameters and buffers are static), only the batch / sequence dimensions really end up being dynamic through the lifetime of the program, which I think is pretty reasonable?

What are the concerns you have?

That being said, I'm still having issues with this PR in conjunction with the dynamic shapes introduced by DeepSeekV3 MOE within the local_map, so I'd appreciate some help there

fmassa added 10 commits April 15, 2026 16:04
… local args

  Description:

  This rewrites the dynamic shapes lowering pipeline to be simpler and more correct, removing ~170 lines of code in the process.

  The key insight: by swapping a fresh ShapeEnv onto the FakeTensorMode before creating local args, make_fx with tracing_mode="symbolic" naturally propagates fresh symbols through the entire parallel graph. This eliminates the need for post-hoc _re_symbolize_graph (which changed graph structure by decomposing HOP tuple outputs) and _compute_local_view_shape (which reimplemented DTensor's view_groups dim mapping).

  Changes:

  1. ShapeEnv swap in apply_sharding_to_model: Swaps a fresh ShapeEnv onto the FakeTensorMode before lowering. The new ShapeEnv is kept permanently (not restored) since update_joint_with_descriptors copies the parallel graph's metadata.
  2. _make_local_args rewritten: Uses DTensor from_local().redistribute() for correct local shapes/strides (restoring the original approach from main), then re-creates tensors with fresh SymInts via from_tensor + StatelessSymbolicContext. Concretizes any old SymInts before DTensor to avoid ShapeEnv conflicts.
  3. Unified view/factory shape handling: In dynamic mode, shape args computed from local tensors (SymInts) are already local — kept as-is. Concrete shape args (global constants baked in the graph) are adjusted by dividing sharded dims by mesh size. Same logic for both factory ops and view ops. In static mode, view ops use DTensor from_local wrapping as before.
  4. Removed _re_symbolize_graph: No longer needed — fresh symbols propagate naturally from _make_local_args through make_fx.
  5. Removed _compute_local_view_shape and view_groups machinery: ~80 lines of dim mapping code (_compute_local_dim, _compute_local_view_shape, _RESHAPE_OPS, Flatten/Split/InputDim/Singleton imports) removed. Dynamic view ops just execute directly since shape args are already local.
  6. ILP solver fixes for local_map HOPs: concretize_symint/concretize_args consolidated as two primitives. _concretize_tensor_meta returns None for tensors with unbacked SymInts (HOP-internal activations). get_local_map_placement_option robustly handles non-tensor inputs, non-OpStrategy specs, and filtered input_specs.
  7. Cost estimation: _shard_args_for_node concretizes SymInts, _concretize_val concretizes placeholder FakeTensors, preventing SymFloat costs from leaking into the ILP.
  8. DSV3 model fix: Reordered local_mapped_region arguments to match Dynamo's freevar reordering.

  Authored with Claude.
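The ShapeEnv swap in change 1 above is small enough to sketch directly. Assuming PyTorch's FakeTensorMode/ShapeEnv internals (the function name is illustrative), the idea is simply to install a fresh ShapeEnv on the existing mode so that subsequent symbolic tracing mints fresh symbols:

```python
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def swap_shape_env(fake_mode):
    """Install a fresh ShapeEnv on an existing FakeTensorMode; return the old one."""
    old_env = fake_mode.shape_env
    fake_mode.shape_env = ShapeEnv()  # fresh symbols from here on
    return old_env

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
old_env = swap_shape_env(fake_mode)
```

Per the commit message, the fresh ShapeEnv is kept permanently rather than restored, since update_joint_with_descriptors copies the parallel graph's metadata.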
  Description:

  Fixes two correctness issues in the dynamic shapes support:

  1. requires_grad preservation: _make_inputs_dynamic (api.py) and _make_local_args (apply_sharding.py) were creating meta tensors without requires_grad, silently dropping input gradient tracking. This changed the joint graph's autograd semantics — output descriptors that should be GradAOTOutput(grad_of=PlainAOTInput(...)) became None, meaning the backward pass no longer computed gradients w.r.t. inputs.
  2. Input validation for concretized SymInt dims: _check_forward_args was treating all SymInt dimensions as wildcards (isinstance(exp, torch.SymInt): continue), but some SymInts get concretized by guards during tracing (e.g., hidden dim after mm with a concrete weight). These have expr.is_number == True and should be validated against their concrete value. Added _get_expected_dim_value which returns None (skip) only for genuinely dynamic dims, and the concrete int(expr) for concretized ones.

  Authored with Claude.
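The validation rule in item 2 could be sketched like this (helper names and the exact checks are illustrative, not the PR's code): a SymInt whose expression has been concretized by guards (expr.is_number) is validated against its concrete value, while genuinely dynamic dims act as wildcards.

```python
import torch

def expected_dim_value(exp):
    """Return the concrete value to validate against, or None for a wildcard."""
    if isinstance(exp, torch.SymInt):
        expr = exp.node.expr
        # Concretized-by-guard SymInts reduce to a number; dynamic ones don't.
        return int(expr) if expr.is_number else None
    return int(exp)

def check_forward_dims(expected_shape, actual_shape):
    if len(expected_shape) != len(actual_shape):
        raise ValueError("ndim mismatch")
    for exp, act in zip(expected_shape, actual_shape):
        want = expected_dim_value(exp)
        if want is not None and want != act:
            raise ValueError(f"expected dim {want}, got {act}")
```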
  Description:

  Code review cleanup addressing three issues:

  1. Remove dead _filter_specs_for_local_map: This helper was defined and unit-tested but never called — the actual filtering logic in _redistribute_and_adjust_args does its own inline handling. Removed the function and its 4 tests.
  2. Tighten _get_input_nodes / _all_input_nodes: Both ApplyShardingInterpreter._get_input_nodes and ShardingOptimizer._all_input_nodes were silently dropping any node not in sharding_placement/strats. Now they assert that non-get_attr nodes missing from the placement map must be non-tensor-producing (scalar shape-computation nodes like sym_size, operator.mul). This catches accidental omissions of tensor-producing nodes that would lead to hard-to-debug downstream shape mismatches.
  3. Assert shared FakeTensorMode in ShapeEnv swap: apply_sharding_to_model now verifies that all placeholder FakeTensors share the same FakeTensorMode before swapping the ShapeEnv, making the implicit assumption explicit.

  Authored with Claude.
  Description:

  Refactors the ad-hoc shape localization logic in _redistribute_and_adjust_args into a single _localize_shape_arg helper. Factory ops and dynamic-mode view ops both do the same thing: compute a local shape from the global shape in node.meta["val"], dividing sharded dims by mesh size, while preserving SymInt values that are already local. The helper unifies this, and the interpreter now reads as a clear sequence of steps:

  1. Factory ops → localize args[0]
  2. Dynamic view ops → localize args[1]
  3. Static view ops → DTensor from_local wrapping

  Authored with Claude.
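The _localize_shape_arg behavior described above can be illustrated with a toy version (a string stands in for a SymInt here; names and types are illustrative): SymInt entries computed from local tensors are kept, while concrete global constants on sharded dims are divided by the mesh size.

```python
def localize_shape_arg(shape, sharded_dims, mesh_size):
    out = []
    for dim, s in enumerate(shape):
        if not isinstance(s, int):     # stand-in for a SymInt: already local
            out.append(s)
        elif dim in sharded_dims:      # global constant: divide by mesh size
            out.append(s // mesh_size)
        else:                          # unsharded concrete dim: unchanged
            out.append(s)
    return out

# Symbolic batch dim preserved, concrete hidden dim localized for tp=4.
print(localize_shape_arg(["s0", 1024], {1}, 4))  # → ['s0', 256]
```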
Here's a summary of the new tests added:

  Unit tests (9 new):

  - TestLocalizeShapeArg (3 tests):
    - test_fully_concrete_shape — all concrete shape args divided by mesh size
    - test_mixed_symint_and_concrete — SymInt preserved, concrete divided
    - test_multi_dim_sharding — same dim sharded on two mesh dims (dp + ep)
  - TestCheckForwardArgsMultiInput (3 tests):
    - test_two_inputs_different_batch_accepted — dynamic batch dims accepted on both inputs
    - test_two_inputs_wrong_feature_rejected — wrong static dim on second input rejected
    - test_arg_count_mismatch_rejected — wrong number of args rejected

  Integration tests (3 new):

  - test_dynamic_apply_placement_view_heavy — view/reshape-heavy model (unflatten → permute → view) with dynamic=True, exercises _localize_shape_arg for view ops
  - test_dynamic_apply_placement_factory_op — model with torch.zeros(x.shape[0], ...) factory op with dynamic=True, exercises _localize_shape_arg for factory ops
  - test_dynamic_vs_static_parity_view_heavy — verifies dynamic=True and dynamic=False produce same ILP placements AND both succeed in apply_placement for the view-heavy model

  Test models (2 new):

  - ViewHeavyModel — linear → unflatten to heads → permute → view back, exercising view/reshape shape localization
  - FactoryOpModel — linear + torch.zeros(batch, seq) factory op with input-dependent shape
The parity test now:
  - Param placements: exact match required (these determine weight distribution)
  - Intermediate nodes: verifies the same set of ops exist in both graphs (by target), but allows the ILP to choose different equivalent-cost strategies for intermediate nodes

  This is the right tradeoff — it catches structural differences (missing ops) while accepting that the ILP is a cost-optimization problem with potentially multiple optimal solutions for intermediate nodes.
  Description:

  With dynamic shapes, the joint graph contains sym_size_int and operator.mul nodes that are not in strats (they produce scalars, not tensors). get_identical_regions was iterating all graph nodes and looking up strategies[node], which raised KeyError for these nodes. Same issue in _hash_node, which accessed strategies[s] for node.all_input_nodes.

  Fixed by skipping nodes not in strategies during both the clustering loop and the input-node hashing. This is the same filtering pattern used in build_sharding_metadata, _all_input_nodes, and _get_input_nodes.

  Only triggered when repeated_subgraphs=True (graph clustering enabled), which example_llama3.py uses.

  Authored with Claude.
  Description:

  Instead of scattering concretize_symint / concretize_args calls across the optimization pipeline (placement options, cost estimation, build_sharding_metadata), create an ephemeral concretized copy of the graph module upfront and let the optimizer work on it exclusively.

  The change is best read starting with concretize_gm() in optimize_sharding.py, then the ShardingOptimizer.__init__ changes that wire it in, then the simplifications it enables in build_sharding_metadata, compute_estimation.py, and placement_options.py.

  The key design rule: any public ShardingOptimizer method that accepts or returns nodes normalizes between original and concrete graphs internally via _normalize_node(). Solutions returned by get_solution() / resolve() are translated back to original-graph nodes so apply_sharding works unchanged.

  Authored with Claude.
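The normalization rule above can be shown with a toy class (not the real ShardingOptimizer; strings stand in for graph nodes): everything is keyed by concrete-graph nodes internally, and solutions are translated back to original-graph nodes at the public boundary.

```python
class TinyOptimizer:
    def __init__(self, original_nodes):
        # Ephemeral "concretized copy": one concrete node per original node.
        self.concrete_of = {n: f"{n}'" for n in original_nodes}
        self.original_of = {c: n for n, c in self.concrete_of.items()}
        self._placements = {}

    def _normalize_node(self, node):
        # Accept nodes from either graph; always key by the concrete one.
        return self.concrete_of.get(node, node)

    def set_placement(self, node, placement):
        self._placements[self._normalize_node(node)] = placement

    def get_solution(self):
        # Translate back so downstream apply_sharding sees original nodes.
        return {self.original_of[c]: p for c, p in self._placements.items()}

opt = TinyOptimizer(["mm", "relu"])
opt.set_placement("mm", "Shard(0)")      # original-graph node accepted
opt.set_placement("relu'", "Replicate")  # concrete-graph node accepted
print(opt.get_solution())  # → {'mm': 'Shard(0)', 'relu': 'Replicate'}
```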
Comment on lines +959 to +960
# Dynamo reorders captured variables (lifted freevars) before explicit
# arguments, so x must come first in the input order and placements.
Member

@xmfan xmfan Apr 18, 2026


I took a look: these are only lifted freevars under dynamic shapes, so this order works ONLY for dynamic shapes and will error with static shapes. I'm looking at a Dynamo fix

fmassa added 3 commits April 18, 2026 15:33
Both fixes were the same root cause — mixing original-graph nodes with concrete-graph-keyed data:

  - test_clustering_high_coverage: _clustering_stats was called with autop.gm.graph (original) but autop.sharding_optimizer.strats (concrete). No original nodes were in strats, so per_layer_total was empty → division by zero. Fixed by using autop.sharding_optimizer.graph.
  - test_chosen_param_placement_matches_grad: param_node came from opt.graph (concrete) but solution[param_node] used original-node-keyed solution. Fixed by returning original-node counterparts from the setup function.
