This PR enables the `_APPLY_VIEW_MM_VIEW_PATTERN` einsum fusion by default, tightens the pattern matchers to prevent false positives, and generalizes the einsum FLOP counter. The recommended review order is: `graph_utils.py` (pattern matcher changes) → `compute_estimation.py` (FLOP counter) → `api.py` (flag flip) → tests.

Einsum fusion

PyTorch decomposes 3D-input `nn.Linear` into `view → mm → view`, which folds the batch and sequence dimensions into a single axis. This prevents the ILP solver from discovering sequence-parallel strategies, since the sequence dimension is invisible in the flattened 2D `[B*S, D]` representation. The einsum fusion restores the original `[B, S, D]` shape, giving the solver access to the sequence dimension as a sharding axis.

With the default cost model, the einsum fusion produces the same FSDP+TP solution as the mm path (identical compute and comm costs). The benefit materializes when combined with the NCCL cost model, where the solver discovers sequence-parallel strategies for GQA attention that avoid expensive activation all-gathers.

ILP overhead (LLaMA3-8B, NCCL cost model, `repeated_subgraphs=True`):

|             | 2 layers          | 32 layers          |
|-------------|-------------------|--------------------|
| MM baseline | 37.8s             | 80.1s              |
| Einsum      | 34.4s (9% faster) | 88.2s (10% slower) |

At 2 layers, einsum is faster due to fewer graph nodes. At 32 layers the solve is slightly slower (24 strategies per einsum vs 16 per mm), but the clustering algorithm handles the denser strategy space well.
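A minimal sketch of the decomposition described above (all shapes are hypothetical, chosen only for illustration): the mm path flattens `[B, S, D]` to `[B*S, D]`, hiding the sequence dimension, while the fused einsum keeps all three dimensions visible.

```python
import torch

B, S, D, N = 2, 8, 16, 32  # hypothetical batch, seq, in-features, out-features
x = torch.randn(B, S, D)
w = torch.randn(N, D)  # nn.Linear stores weight as [out, in]

# Decomposed path: batch and sequence fold into a single flattened axis,
# so S is no longer visible as a standalone sharding axis.
flat = x.view(B * S, D)          # [B*S, D]
out_flat = flat.mm(w.t())        # [B*S, N]
out_mm = out_flat.view(B, S, N)  # unflatten back to [B, S, N]

# Fused einsum: same compute, but the [B, S, D] operand survives intact,
# keeping S available to the solver as a sharding dimension.
out_einsum = torch.einsum("bsd,nd->bsn", x, w)

assert torch.allclose(out_mm, out_einsum, atol=1e-5)
```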
Pattern matcher tightening

The matchers now verify:

- Forward: the input view is a canonical flatten `[*batch, K] → [prod(batch), K]`, the output view is the matching unflatten, the weight shape is `[K, N]`, the input rank is ≥ 3, and batch dims match between input and output
- Backward: both permutes are exactly `[1, 0]`, both views are canonical flattenings, and batch dims match
- Graph lint (`gm.graph.lint()`) runs after the rewrite

The matchers are intentionally conservative with view args: they compare integer shapes directly, which means symbolic shapes won't match. This will be addressed when `dynamic=True` becomes the default.

einsum_flop generalization

The FLOP counter now handles arbitrary batch rank (previously only 3D). Forward: `(r+1)D × 2D` computes `prod(batch) * N * K * 2`. Backward: `(r+1)D × (r+1)D` with matching batch dims.

seq_nr metadata fix

The einsum replacement now copies `seq_nr` from the `mm` node (the core compute) rather than the outer view/permute, so forward/backward einsum pairs remain correctly matched by autograd's sequence numbering.

Authored with Claude.
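The two FLOP formulas from the einsum_flop section can be sketched as follows (this is an illustrative reimplementation, not the PR's actual `einsum_flop` code; the function name and shape conventions are assumptions):

```python
import math

def einsum_flop_sketch(a_shape, b_shape):
    """Hypothetical FLOP count for the two einsum patterns described above.

    Forward  (r+1)D x 2D:      [*batch, K] x [K, N]      -> prod(batch) * N * K * 2
    Backward (r+1)D x (r+1)D:  [*batch, N] x [*batch, K] -> prod(batch) * N * K * 2
    """
    if len(b_shape) == 2:
        # Forward: second operand is the [K, N] weight.
        *batch, k = a_shape
        k2, n = b_shape
        assert k == k2, "contraction dims must agree"
    else:
        # Backward: both operands carry the same batch dims, contracted away.
        assert a_shape[:-1] == b_shape[:-1], "batch dims must match"
        *batch, n = a_shape
        k = b_shape[-1]
    return math.prod(batch) * n * k * 2

# Forward [2, 8, 16] x [16, 32] and backward [2, 8, 32] x [2, 8, 16]
# perform the same multiply-accumulate work.
print(einsum_flop_sketch([2, 8, 16], [16, 32]))   # 2 * 8 * 32 * 16 * 2
print(einsum_flop_sketch([2, 8, 32], [2, 8, 16]))  # same count
```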
Problem: the einsum fusion changed the backward grad weight computation from 2D (`[N, B*S] @ [B*S, K]`) to 3D (`[B, S, N] @ [B, S, K]`), which moves the contracted batch dimension from position 1 to position 0 of input 0. The test expected `Shard(1)` unconditionally, but einsum correctly produces `Shard(0)`.

Fix: updated the test at tests/test_optimize_placement.py:170-175 to expect `Shard(0)` for einsum and `Shard(1)` for mm backward grad weight nodes. The two placements are semantically equivalent (both shard the contracted batch dimension) but differ in dimension index due to the tensor rank.
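A small numeric sketch (hypothetical shapes) of why `Shard(0)` on the 3D einsum operands and `Shard(1)` on the 2D mm operand are the same strategy: both shard the contracted batch axis, so each shard produces a partial grad weight that reduces to the full result.

```python
import torch

B, S, N, K = 2, 8, 6, 4
grad_out = torch.randn(B, S, N)
x = torch.randn(B, S, K)

# mm path: grad_w = grad_out.T @ x over the flattened [B*S] axis.
# The contracted axis is dim 1 of the [N, B*S] operand -> Shard(1).
go2d = grad_out.reshape(B * S, N)
x2d = x.reshape(B * S, K)
grad_w_mm = go2d.t().mm(x2d)

# einsum path: same contraction, but the batch axis sits at dim 0
# of the 3D operands -> Shard(0).
grad_w_einsum = torch.einsum("bsn,bsk->nk", grad_out, x)

# Sharding the contracted dim yields partial sums that reduce to the full grad
# (i.e. a partial placement that an all-reduce would resolve).
partials = [torch.einsum("bsn,bsk->nk", g, xi)
            for g, xi in zip(grad_out.chunk(2, dim=0), x.chunk(2, dim=0))]

assert torch.allclose(grad_w_mm, grad_w_einsum, atol=1e-5)
assert torch.allclose(sum(partials), grad_w_einsum, atol=1e-5)
```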