Enable view-mm-view → einsum fusion by default #424

Open
fmassa wants to merge 2 commits into main from fmassa/improve_einsum_pattern

Conversation

Contributor

@fmassa fmassa commented Apr 19, 2026

This PR enables the _APPLY_VIEW_MM_VIEW_PATTERN einsum fusion by default, tightens the pattern matchers to prevent false positives, and generalizes the einsum FLOP counter.

The recommended review order is: graph_utils.py (pattern matcher changes) → compute_estimation.py (flop counter) → api.py (flag flip) → tests.

Einsum fusion

PyTorch decomposes 3D-input nn.Linear into view → mm → view, which folds the batch and sequence dimensions into a single axis. This prevents the ILP solver from discovering sequence-parallel strategies since the sequence dimension is invisible in the flattened 2D [B*S, D] representation. The einsum fusion restores the original [B, S, D] shape, giving the solver access to the sequence dimension as a sharding axis.
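To make the flattening concrete, here is a shape-level sketch of the decomposition (illustrative only; the function name is mine, not actual library code):

```python
from math import prod

def decomposed_linear_shapes(input_shape, weight_shape):
    # PyTorch's decomposition of an nn.Linear with a 3D+ input:
    # flatten the leading batch dims, run a plain 2D mm, unflatten.
    *batch, K = input_shape
    K2, N = weight_shape
    assert K == K2, "inner dims must match"
    flat = [prod(batch), K]        # view: [B, S, K] -> [B*S, K]
    mm_out = [flat[0], N]          # mm:   [B*S, K] @ [K, N] -> [B*S, N]
    return [*batch, N]             # view: [B*S, N] -> [B, S, N]

# The sequence axis S exists only before/after the mm; inside the mm it is
# fused into B*S, which is why the solver cannot shard on it. The fused
# einsum "bsk,kn->bsn" keeps S visible throughout.
print(decomposed_linear_shapes([2, 8, 64], [64, 128]))  # [2, 8, 128]
```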

With the default cost model, the einsum fusion produces the same FSDP+TP solution as the mm path (identical compute and comm costs). The benefit materializes when combined with the NCCL cost model, where the solver discovers sequence-parallel strategies for GQA attention that avoid expensive activation all-gathers.

ILP overhead (LLaMA3-8B, NCCL cost model, repeated_subgraphs=True):

┌─────────────┬───────────────────┬────────────────────┐
│             │     2 layers      │     32 layers      │
├─────────────┼───────────────────┼────────────────────┤
│ MM baseline │ 37.8s             │ 80.1s              │
├─────────────┼───────────────────┼────────────────────┤
│ Einsum      │ 34.4s (9% faster) │ 88.2s (10% slower) │
└─────────────┴───────────────────┴────────────────────┘

At 2 layers, einsum is faster due to fewer graph nodes. At 32 layers the solve is slightly slower (24 strategies per einsum vs 16 per mm), but the clustering algorithm handles the denser strategy space well.

Pattern matcher tightening

The matchers now verify:

  • Forward: input view is a canonical flatten [*batch, K] → [prod(batch), K], output view is the matching unflatten, weight shape is [K, N], input rank ≥ 3, and batch dims match between input and output
  • Backward: both permutes are exactly [1, 0], both views are canonical flattenings, batch dims match
  • Graph lint (gm.graph.lint()) runs after rewrite
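As a rough sketch of the forward-side shape checks (helper names are mine, not the actual graph_utils.py code):

```python
from math import prod

def is_canonical_flatten(input_shape, view_shape):
    # Input [*batch, K] with rank >= 3 must flatten to exactly
    # [prod(batch), K]; anything else is rejected as a false positive.
    if len(input_shape) < 3 or len(view_shape) != 2:
        return False
    *batch, K = input_shape
    return list(view_shape) == [prod(batch), K]

def is_matching_unflatten(input_shape, out_view_shape, N):
    # The output view must restore the same batch dims, with N appended.
    *batch, _K = input_shape
    return list(out_view_shape) == [*batch, N]

print(is_canonical_flatten([2, 8, 64], [16, 64]))        # True
print(is_canonical_flatten([2, 8, 64], [8, 128]))        # False
print(is_matching_unflatten([2, 8, 64], [2, 8, 128], 128))  # True
```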

The matchers are intentionally conservative with view args — they compare integer shapes directly, which means symbolic shapes won't match. This will be addressed when dynamic=True becomes the default.

einsum_flop generalization

The FLOP counter now handles arbitrary batch rank (previously only 3D). Forward: (r+1)D × 2D computes prod(batch) * N * K * 2. Backward: (r+1)D × (r+1)D with matching batch dims.
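For concreteness, the forward count works out as follows (a sketch of the formula, not the actual compute_estimation.py code):

```python
from math import prod

def einsum_fwd_flops(batch_dims, K, N):
    # (r+1)D x 2D einsum "b...k,kn->b...n": prod(batch) independent
    # [K] @ [K, N] contractions, counting a multiply-add as 2 FLOPs.
    return prod(batch_dims) * N * K * 2

print(einsum_fwd_flops([2, 8], 64, 128))  # 262144
```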

seq_nr metadata fix

The einsum replacement now copies seq_nr from the mm node (the core compute) rather than the outer view/permute, so forward/backward einsum pairs remain correctly matched by autograd's sequence numbering.
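A minimal sketch of the metadata copy, using a plain dict in place of the FX node meta (the helper name is mine):

```python
def copy_seq_nr(einsum_meta, mm_meta):
    # Take seq_nr from the mm node (the core compute), not from the
    # surrounding view/permute, so autograd's forward/backward pairing
    # by sequence number survives the rewrite.
    if "seq_nr" in mm_meta:
        einsum_meta["seq_nr"] = mm_meta["seq_nr"]
    return einsum_meta

print(copy_seq_nr({}, {"seq_nr": 42}))  # {'seq_nr': 42}
```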

Authored with Claude.

meta-cla bot added the CLA Signed label on Apr 19, 2026
Problem: The einsum fusion changed the backward grad-weight tensor layout from 2D ([N, B*S] @ [B*S, K]) to 3D ([B, S, N] @ [B, S, K]), which shifts the batch dimension from position 1 to position 0 in input 0. The test expected Shard(1) unconditionally, but einsum correctly produces Shard(0).

Fix: Updated the test at tests/test_optimize_placement.py:170-175 to use Shard(0) for einsum and Shard(1) for mm backward grad-weight nodes; both are semantically equivalent (sharding on the contracted batch dimension) but differ in dimension index due to the tensor rank.
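The layout shift can be seen directly from the operand shapes (a sketch with a hypothetical helper; the contracted batch axis is the one being sharded):

```python
def grad_weight_input_shapes(B, S, K, N, einsum_fused):
    # Backward grad-weight operands under both layouts. With mm, the
    # batch*seq axis is flattened, so the contracted axis sits at dim 1
    # of input 0 (the transposed [N, B*S] grad) -> Shard(1). With the
    # fused einsum the inputs stay 3D and the batch axis is dim 0
    # -> Shard(0).
    if einsum_fused:
        return (B, S, N), (B, S, K)
    return (N, B * S), (B * S, K)

print(grad_weight_input_shapes(2, 8, 64, 128, einsum_fused=False)[0])  # (128, 16)
print(grad_weight_input_shapes(2, 8, 64, 128, einsum_fused=True)[0])   # (2, 8, 128)
```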