From 4efc4e9c679416f641d1400355c5667d8378167c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 19 Apr 2026 08:09:33 +0000 Subject: [PATCH 1/2] =?UTF-8?q?Enable=20view-mm-view=20=E2=86=92=20einsum?= =?UTF-8?q?=20fusion=20by=20default?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR enables the _APPLY_VIEW_MM_VIEW_PATTERN einsum fusion by default, tightens the pattern matchers to prevent false positives, and generalizes the einsum FLOP counter. The recommended review order is: graph_utils.py (pattern matcher changes) → compute_estimation.py (flop counter) → api.py (flag flip) → tests. Einsum fusion PyTorch decomposes 3D-input nn.Linear into view → mm → view, which folds the batch and sequence dimensions into a single axis. This prevents the ILP solver from discovering sequence-parallel strategies since the sequence dimension is invisible in the flattened 2D [B*S, D] representation. The einsum fusion restores the original [B, S, D] shape, giving the solver access to the sequence dimension as a sharding axis. With the default cost model, the einsum fusion produces the same FSDP+TP solution as the mm path (identical compute and comm costs). The benefit materializes when combined with the NCCL cost model, where the solver discovers sequence-parallel strategies for GQA attention that avoid expensive activation all-gathers. ILP overhead (LLaMA3-8B, NCCL cost model, repeated_subgraphs=True): ┌─────────────┬───────────────────┬────────────────────┐ │ │ 2 layers │ 32 layers │ ├─────────────┼───────────────────┼────────────────────┤ │ MM baseline │ 37.8s │ 80.1s │ ├─────────────┼───────────────────┼────────────────────┤ │ Einsum │ 34.4s (9% faster) │ 88.2s (10% slower) │ └─────────────┴───────────────────┴────────────────────┘ At 2 layers, einsum is faster due to fewer graph nodes. At 32 layers the solve is slightly slower (24 strategies per einsum vs 16 per mm), but the clustering algorithm handles the denser strategy space well. Pattern matcher tightening The matchers now verify: - Forward: input view is a canonical flatten [*batch, K] → [prod(batch), K], output view is the matching unflatten, weight shape is [K, N], input rank ≥ 3, and batch dims match between input and output - Backward: both permutes are exactly [1, 0], both views are canonical flattenings, batch dims match - Graph lint (gm.graph.lint()) runs after rewrite The matchers are intentionally conservative with view args — they compare integer shapes directly, which means symbolic shapes won't match. This will be addressed when dynamic=True becomes the default. einsum_flop generalization The FLOP counter now handles arbitrary batch rank (previously only 3D). Forward: (r+1)D × 2D computes prod(batch) * N * K * 2. Backward: (r+1)D × (r+1)D with matching batch dims. seq_nr metadata fix The einsum replacement now copies seq_nr from the mm node (the core compute) rather than the outer view/permute, so forward/backward einsum pairs remain correctly matched by autograd's sequence numbering. Authored with Claude. --- autoparallel/api.py | 2 +- .../cost_models/compute_estimation.py | 38 +++-- autoparallel/graph_passes/graph_utils.py | 85 +++++++++- tests/test_activation_checkpointing.py | 18 ++- tests/test_graph_utils.py | 152 ++++++++++++++++++ tests/test_optimize_placement.py | 45 ++++-- 6 files changed, 297 insertions(+), 43 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 76eb6a84..49854555 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -60,7 +60,7 @@ move_to_fake, ) -_APPLY_VIEW_MM_VIEW_PATTERN = False +_APPLY_VIEW_MM_VIEW_PATTERN = True logger = logging.getLogger(__name__) diff --git a/autoparallel/cost_models/compute_estimation.py b/autoparallel/cost_models/compute_estimation.py index eee031ee..247acaa3 100644 --- a/autoparallel/cost_models/compute_estimation.py +++ b/autoparallel/cost_models/compute_estimation.py @@ -15,29 +15,33 @@ @register_flop_formula(torch.ops.aten.einsum, get_raw=True) def einsum_flop(equation, tensors, out=None, **kwargs) -> int: - # from torch.distributed.tensor._ops._einsum_strategy import EinsumDims assert len(tensors) == 2 a_shape, b_shape = [x.shape for x in tensors] - # parse einop equation and extract dims - # TODO: generalize - # input_dims, output_dim = EinsumDims.parse_equation(equation) - # edims = EinsumDims.parse_dims(input_dims, output_dim) - - if len(a_shape) == 3 and len(b_shape) == 3: - b, m, k = a_shape - b1, n, k2 = b_shape - assert b == b1 - assert m == n - flop = (b * m) * k * k2 * 2 - elif len(a_shape) == 3 and len(b_shape) == 2: - b, m, k = a_shape + if len(b_shape) == 2 and len(a_shape) >= 3: + # Forward linear: "{batch}k,kn->{batch}n" + # a: [*batch, K], b: [K, N] + batch_size = 1 + for d in a_shape[:-1]: + batch_size *= d + k = a_shape[-1] k2, n = b_shape - assert k == k2 - flop = b * m * n * k * 2 + assert k == k2, f"Contracting dim mismatch: {k} vs {k2}" + return batch_size * n * k * 2 + elif len(a_shape) == len(b_shape) and len(a_shape) >= 3: + # Backward gradient-weight: "{batch}n,{batch}k->kn" + # a: [*batch, N], b: [*batch, K] + assert ( + a_shape[:-1] == b_shape[:-1] + ), f"Batch dims mismatch: {a_shape[:-1]} vs {b_shape[:-1]}" + batch_size = 1 + for d in a_shape[:-1]: + batch_size *= d + n = a_shape[-1] + k = b_shape[-1] + return batch_size * n * k * 2 else: raise NotImplementedError(f"Unsupported einsum shapes: {a_shape} {b_shape}") - return flop @dataclass diff --git a/autoparallel/graph_passes/graph_utils.py b/autoparallel/graph_passes/graph_utils.py index 3b395f53..ed6f166e 100644 --- a/autoparallel/graph_passes/graph_utils.py +++ b/autoparallel/graph_passes/graph_utils.py @@ -256,33 +256,82 @@ def _batch_dims(n: int) -> str: return "".join(chr(97 + i) for i in range(n)) +def _is_canonical_flatten(input_shape, view_args): + """Check that view_args represent [*batch, K] -> [prod(batch), K].""" + if len(view_args) != 2: + return False + expected_flat = 1 + for d in input_shape[:-1]: + expected_flat *= d + return view_args[0] == expected_flat and view_args[1] == input_shape[-1] + + +def _is_canonical_unflatten(input_shape, view_args): + """Check that view_args represent [prod(batch), N] -> [*batch, N].""" + if len(view_args) < 3: + return False + batch_dims = view_args[:-1] + expected_flat = 1 + for d in batch_dims: + expected_flat *= d + return expected_flat == input_shape[0] and view_args[-1] == input_shape[-1] + + def _match_forward_linear(mm_node): """Match the forward pattern: view -> mm -> view. + Verifies canonical flatten/unflatten shapes: + input [*batch, K] -> view [prod(batch), K] -> mm [prod(batch), N] -> view [*batch, N] + Returns (inputs, replaced_node, equation) or None. """ first_input, second_input = mm_node.all_input_nodes if first_input.target != torch.ops.aten.view.default: return None view_input = first_input.all_input_nodes[0] + input_shape = view_input.meta["val"].shape + if input_shape.numel() == 0 or len(input_shape) < 3: + return None + # Verify the input view is a canonical flatten [*batch, K] -> [prod(batch), K] + flatten_args = first_input.args[1] + if not _is_canonical_flatten(input_shape, flatten_args): + return None users = list(mm_node.users) if not ( len(users) == 1 and users[0].target == torch.ops.aten.view.default - and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1] and second_input.meta["val"].ndim == 2 ): return None - ndim = view_input.meta["val"].ndim - assert 1 < ndim <= 26, "Only support up to 26D for now" + output_view = users[0] + output_shape = output_view.meta["val"].shape + # Verify the output view is a canonical unflatten [prod(batch), N] -> [*batch, N] + unflatten_args = output_view.args[1] + mm_shape = mm_node.meta["val"].shape + if not _is_canonical_unflatten(mm_shape, unflatten_args): + return None + # Verify batch dims match between input and output + if input_shape[:-1] != output_shape[:-1]: + return None + # Verify weight shape is [K, N] matching the flatten dimensions + weight_shape = second_input.meta["val"].shape + if weight_shape[0] != input_shape[-1] or weight_shape[1] != output_shape[-1]: + return None + ndim = len(input_shape) dims = _batch_dims(ndim - 1) equation = f"{dims}k,kn->{dims}n" - return [view_input, second_input], users[0], equation + return [view_input, second_input], output_view, equation def _match_backward_linear(mm_node): """Match the backward pattern: view -> permute -> mm -> permute. + The backward of einsum "{batch}k,kn->{batch}n" produces a gradient-weight + computation: permute(view(grad, [flat, N]), [1,0]) @ view(x, [flat, K]) -> [N, K], + followed by permute([N, K], [1, 0]) -> [K, N]. + + Verifies canonical flatten shapes and [1,0] permute orders. + Returns (inputs, replaced_node, equation) or None. """ first_input, second_input = mm_node.all_input_nodes @@ -290,20 +339,35 @@ def _match_backward_linear(mm_node): return None if first_input.target != torch.ops.aten.permute.default: return None - if first_input.all_input_nodes[0].target != torch.ops.aten.view.default: + first_view = first_input.all_input_nodes[0] + if first_view.target != torch.ops.aten.view.default: return None - orig_first = first_input.all_input_nodes[0].all_input_nodes[0] + # Verify the input permute is [1, 0] (transpose) + perm_dims = list(first_input.args[1]) + if perm_dims != [1, 0]: + return None + orig_first = first_view.all_input_nodes[0] orig_second = second_input.all_input_nodes[0] + # Verify both views are canonical flattenings + if not _is_canonical_flatten(orig_first.meta["val"].shape, first_view.args[1]): + return None + if not _is_canonical_flatten(orig_second.meta["val"].shape, second_input.args[1]): + return None users = list(mm_node.users) if not ( len(users) == 1 and users[0].target == torch.ops.aten.permute.default - and orig_first.meta["val"].shape[:-1] == orig_second.meta["val"].shape[:-1] and mm_node.meta["val"].ndim == 2 ): return None + # Verify the output permute is [1, 0] + out_perm_dims = list(users[0].args[1]) + if out_perm_dims != [1, 0]: + return None + # Verify batch dims match + if orig_first.meta["val"].shape[:-1] != orig_second.meta["val"].shape[:-1]: + return None ndim = orig_first.meta["val"].ndim - assert 1 < ndim <= 26, "Only support up to 26D for now" dims = _batch_dims(ndim - 1) equation = f"{dims}n,{dims}k->kn" return [orig_first, orig_second], users[0], equation @@ -322,8 +386,13 @@ def _replace_view_mm_view_with_einsum(gm): args=(equation, inputs), ) new_node.meta.update(replaced_node.meta) + # Preserve the mm node's seq_nr so that forward/backward + # einsum pairs remain matched by autograd's sequence numbering. + if "seq_nr" in node.meta: + new_node.meta["seq_nr"] = node.meta["seq_nr"] replaced_node.replace_all_uses_with(new_node) gm.graph.eliminate_dead_code() + gm.graph.lint() gm.recompile() diff --git a/tests/test_activation_checkpointing.py b/tests/test_activation_checkpointing.py index c3ce81c2..2c865ef6 100644 --- a/tests/test_activation_checkpointing.py +++ b/tests/test_activation_checkpointing.py @@ -125,8 +125,17 @@ def _is_inside_checkpointed_fn(node): # --------------------------------------------------------------------------- +def _find_linear_nodes(graph): + """Find mm or einsum nodes (depending on whether einsum fusion is enabled).""" + mm_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) + einsum_nodes = graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + return mm_nodes + einsum_nodes + + def test_user_ac_recompute_tags_on_targeted_ops(device_mesh_1d): - """SDPA ops get MUST_RECOMPUTE and mm ops get MUST_SAVE from user policy.""" + """SDPA ops get MUST_RECOMPUTE and mm/einsum ops get MUST_SAVE from user policy.""" context_fn = functools.partial( create_selective_checkpoint_contexts, _must_save_policy ) @@ -141,9 +150,9 @@ def input_fn(): gm = _build_joint_graph(model, input_fn, device_mesh_1d) - mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) - assert len(mm_nodes) > 0 - for n in mm_nodes: + linear_nodes = _find_linear_nodes(gm.graph) + assert len(linear_nodes) > 0 + for n in linear_nodes: if n.meta.get("partitioner_tag", "") == "is_backward": continue if _is_inside_checkpointed_fn(n): @@ -334,6 +343,7 @@ def test_ac_joint_pass_apply_ac_policy_saves_mm_and_sdpa(device_mesh_1d): save_list = { torch.ops.aten.mm.default, + torch.ops.aten.einsum.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, torch.ops.aten._scaled_dot_product_cudnn_attention.default, diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py index c44039a8..c8740012 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -112,3 +112,155 @@ def f(x, w): assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +# --- False-positive rejection tests --- + + +def test_no_match_non_canonical_flatten(): + """Input with fewer than 3 dims should not be matched.""" + K, N = 4, 8 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [6, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [2, 3, N]) + + x = torch.randn(6, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_unflatten_reorders_batch_dims(): + """Canonical flatten but unflatten reorders batch dims: must be rejected. + + Input [2, 3, 4] flattens to [6, 4] (canonical), but the output view + unflattens to [3, 2, N] instead of [2, 3, N]. This would produce a + semantically wrong einsum. + """ + B, S, K, N = 2, 3, 4, 8 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [S, B, N]) + + x = torch.randn(B, S, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_wrong_permute_order(): + """Backward pattern with non-[1,0] permute must be rejected.""" + B, S, K = 2, 8, 16 + + def f(grad_out, x): + flat_grad = torch.ops.aten.view.default(grad_out, [B * S, K]) + # Wrong permute -- [0, 1] is identity, not transpose + perm_grad = torch.ops.aten.permute.default(flat_grad, [0, 1]) + flat_x = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(perm_grad, flat_x) + return torch.ops.aten.permute.default(out, [1, 0]) + + grad_out = torch.randn(B, S, K) + x = torch.randn(B, S, K) + gm = make_fx(f, tracing_mode="fake")(grad_out, x) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_2d_input(): + """2D input (no batch dims) should not be matched.""" + K, N = 16, 32 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [K, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [K, N]) + + x = torch.randn(K, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +# --- Numerical equivalence tests --- + + +def test_forward_numerical_equivalence_3d(): + """Einsum replacement produces the same numerical result as view-mm-view.""" + B, S, K, N = 2, 8, 16, 32 + x = torch.randn(B, S, K) + w = torch.randn(K, N) + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [B, S, N]) + + expected = f(x, w) + + gm = make_fx(f, tracing_mode="fake")(x, w) + _replace_view_mm_view_with_einsum(gm) + actual = gm(x, w) + + torch.testing.assert_close(actual, expected) + + +def test_forward_numerical_equivalence_4d(): + """Einsum replacement produces the same result for 4D input.""" + B, S, T, K, N = 2, 4, 3, 16, 32 + x = torch.randn(B, S, T, K) + w = torch.randn(K, N) + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S * T, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [B, S, T, N]) + + expected = f(x, w) + + gm = make_fx(f, tracing_mode="fake")(x, w) + _replace_view_mm_view_with_einsum(gm) + actual = gm(x, w) + + torch.testing.assert_close(actual, expected) + + +def test_backward_numerical_equivalence(): + """Einsum replacement for the backward gradient-weight pattern.""" + B, S, K, N = 2, 8, 16, 32 + grad_out = torch.randn(B, S, N) + x = torch.randn(B, S, K) + + def f(grad_out, x): + flat_grad = torch.ops.aten.view.default(grad_out, [B * S, N]) + perm_grad = torch.ops.aten.permute.default(flat_grad, [1, 0]) + flat_x = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(perm_grad, flat_x) + return torch.ops.aten.permute.default(out, [1, 0]) + + expected = f(grad_out, x) + + gm = make_fx(f, tracing_mode="fake")(grad_out, x) + _replace_view_mm_view_with_einsum(gm) + actual = gm(grad_out, x) + + torch.testing.assert_close(actual, expected) diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py index d3131ec3..a2e80eb2 100644 --- a/tests/test_optimize_placement.py +++ b/tests/test_optimize_placement.py @@ -125,22 +125,41 @@ def test_optimization_finds_fsdp_and_ddp_1d(device_mesh_1d, high_mem, model_type mm_nodes = autop.gm.graph.find_nodes( op="call_function", target=torch.ops.aten.mm.default ) - len_mm_nodes = {"ffn_with_multiple_input_output": 5, "transformer_block": 18}[ - model_type - ] - len_fwd_mm_nodes = {"ffn_with_multiple_input_output": 2, "transformer_block": 6}[ - model_type - ] - assert len(mm_nodes) == len_mm_nodes - fwd_mm_nodes = mm_nodes[0:len_fwd_mm_nodes] - bwd_mm_grad_weight_nodes = mm_nodes[len_fwd_mm_nodes::2] - bwd_mm_grad_input_nodes = mm_nodes[(len_fwd_mm_nodes + 1) :: 2] + einsum_nodes = autop.gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + linear_nodes = mm_nodes + einsum_nodes + is_einsum = len(einsum_nodes) > 0 + + if is_einsum: + len_linear_nodes = { + "ffn_with_multiple_input_output": 5, + "transformer_block": 18, + }[model_type] + len_fwd_linear_nodes = { + "ffn_with_multiple_input_output": 2, + "transformer_block": 6, + }[model_type] + else: + len_linear_nodes = { + "ffn_with_multiple_input_output": 5, + "transformer_block": 18, + }[model_type] + len_fwd_linear_nodes = { + "ffn_with_multiple_input_output": 2, + "transformer_block": 6, + }[model_type] + + assert len(linear_nodes) == len_linear_nodes + fwd_linear_nodes = linear_nodes[0:len_fwd_linear_nodes] + bwd_linear_grad_weight_nodes = linear_nodes[len_fwd_linear_nodes::2] + bwd_linear_grad_input_nodes = linear_nodes[(len_fwd_linear_nodes + 1) :: 2] # and check that matmuls have full replication on weights during fwd, # which maps to DDP / FSDP # fwd - for node in fwd_mm_nodes: + for node in fwd_linear_nodes: p = sharding_placement[node] # input and output are sharded on batch assert p.input_specs[0].placements == (Shard(0),) @@ -149,14 +168,14 @@ def test_optimization_finds_fsdp_and_ddp_1d(device_mesh_1d, high_mem, model_type assert p.input_specs[1].placements == (Replicate(),) # bwd grad weight - for node in bwd_mm_grad_weight_nodes: + for node in bwd_linear_grad_weight_nodes: p = sharding_placement[node] assert p.input_specs[0].placements == (Shard(1),) assert p.output_specs.placements == (Partial("sum"),) assert p.input_specs[1].placements == (Shard(0),) # bwd grad inputs - for node in bwd_mm_grad_input_nodes: + for node in bwd_linear_grad_input_nodes: p = sharding_placement[node] assert p.input_specs[0].placements == (Shard(0),) assert p.output_specs.placements == (Shard(0),) From a299465bd636c4ab7edd93326e5578626aad76a4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 19 Apr 2026 09:18:13 +0000 Subject: [PATCH 2/2] Fix test Problem: The einsum fusion changed the backward grad weight tensor layout from 2D ([N, B*S] @ [B*S, K]) to 3D ([B,S,N] @ [B,S,K]), which shifts the batch dimension from position 1 to position 0 in input 0. The test expected Shard(1) unconditionally, but einsum correctly produces Shard(0). Fix: Updated the test at tests/test_optimize_placement.py:170-175 to use Shard(0) for einsum and Shard(1) for mm backward grad weight nodes, since both are semantically equivalent (sharding on the contracted batch dimension) but differ in dimension index due to the tensor rank. --- tests/test_optimize_placement.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py index a2e80eb2..e9bdefff 100644 --- a/tests/test_optimize_placement.py +++ b/tests/test_optimize_placement.py @@ -168,9 +168,12 @@ def test_optimization_finds_fsdp_and_ddp_1d(device_mesh_1d, high_mem, model_type assert p.input_specs[1].placements == (Replicate(),) # bwd grad weight + # For mm: [N, B*S] @ [B*S, K] → batch dim is at position 1 for input 0 + # For einsum: bsn,bsk->nk → batch dim is at position 0 for both inputs + bwd_grad_weight_shard = (Shard(0),) if is_einsum else (Shard(1),) for node in bwd_linear_grad_weight_nodes: p = sharding_placement[node] - assert p.input_specs[0].placements == (Shard(1),) + assert p.input_specs[0].placements == bwd_grad_weight_shard assert p.output_specs.placements == (Partial("sum"),) assert p.input_specs[1].placements == (Shard(0),)