diff --git a/autoparallel/api.py b/autoparallel/api.py index 76eb6a84..49854555 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -60,7 +60,7 @@ move_to_fake, ) -_APPLY_VIEW_MM_VIEW_PATTERN = False +_APPLY_VIEW_MM_VIEW_PATTERN = True logger = logging.getLogger(__name__) diff --git a/autoparallel/cost_models/compute_estimation.py b/autoparallel/cost_models/compute_estimation.py index eee031ee..247acaa3 100644 --- a/autoparallel/cost_models/compute_estimation.py +++ b/autoparallel/cost_models/compute_estimation.py @@ -15,29 +15,33 @@ @register_flop_formula(torch.ops.aten.einsum, get_raw=True) def einsum_flop(equation, tensors, out=None, **kwargs) -> int: - # from torch.distributed.tensor._ops._einsum_strategy import EinsumDims assert len(tensors) == 2 a_shape, b_shape = [x.shape for x in tensors] - # parse einop equation and extract dims - # TODO: generalize - # input_dims, output_dim = EinsumDims.parse_equation(equation) - # edims = EinsumDims.parse_dims(input_dims, output_dim) - - if len(a_shape) == 3 and len(b_shape) == 3: - b, m, k = a_shape - b1, n, k2 = b_shape - assert b == b1 - assert m == n - flop = (b * m) * k * k2 * 2 - elif len(a_shape) == 3 and len(b_shape) == 2: - b, m, k = a_shape + if len(b_shape) == 2 and len(a_shape) >= 3: + # Forward linear: "{batch}k,kn->{batch}n" + # a: [*batch, K], b: [K, N] + batch_size = 1 + for d in a_shape[:-1]: + batch_size *= d + k = a_shape[-1] k2, n = b_shape - assert k == k2 - flop = b * m * n * k * 2 + assert k == k2, f"Contracting dim mismatch: {k} vs {k2}" + return batch_size * n * k * 2 + elif len(a_shape) == len(b_shape) and len(a_shape) >= 3: + # Backward gradient-weight: "{batch}n,{batch}k->kn" + # a: [*batch, N], b: [*batch, K] + assert ( + a_shape[:-1] == b_shape[:-1] + ), f"Batch dims mismatch: {a_shape[:-1]} vs {b_shape[:-1]}" + batch_size = 1 + for d in a_shape[:-1]: + batch_size *= d + n = a_shape[-1] + k = b_shape[-1] + return batch_size * n * k * 2 else: raise NotImplementedError(f"Unsupported 
einsum shapes: {a_shape} {b_shape}") - return flop @dataclass diff --git a/autoparallel/graph_passes/graph_utils.py b/autoparallel/graph_passes/graph_utils.py index 3b395f53..ed6f166e 100644 --- a/autoparallel/graph_passes/graph_utils.py +++ b/autoparallel/graph_passes/graph_utils.py @@ -256,33 +256,82 @@ def _batch_dims(n: int) -> str: return "".join(chr(97 + i) for i in range(n)) +def _is_canonical_flatten(input_shape, view_args): + """Check that view_args represent [*batch, K] -> [prod(batch), K].""" + if len(view_args) != 2: + return False + expected_flat = 1 + for d in input_shape[:-1]: + expected_flat *= d + return view_args[0] == expected_flat and view_args[1] == input_shape[-1] + + +def _is_canonical_unflatten(input_shape, view_args): + """Check that view_args represent [prod(batch), N] -> [*batch, N].""" + if len(view_args) < 3: + return False + batch_dims = view_args[:-1] + expected_flat = 1 + for d in batch_dims: + expected_flat *= d + return expected_flat == input_shape[0] and view_args[-1] == input_shape[-1] + + def _match_forward_linear(mm_node): """Match the forward pattern: view -> mm -> view. + Verifies canonical flatten/unflatten shapes: + input [*batch, K] -> view [prod(batch), K] -> mm [prod(batch), N] -> view [*batch, N] + Returns (inputs, replaced_node, equation) or None. 
""" first_input, second_input = mm_node.all_input_nodes if first_input.target != torch.ops.aten.view.default: return None view_input = first_input.all_input_nodes[0] + input_shape = view_input.meta["val"].shape + if input_shape.numel() == 0 or len(input_shape) < 3: + return None + # Verify the input view is a canonical flatten [*batch, K] -> [prod(batch), K] + flatten_args = first_input.args[1] + if not _is_canonical_flatten(input_shape, flatten_args): + return None users = list(mm_node.users) if not ( len(users) == 1 and users[0].target == torch.ops.aten.view.default - and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1] and second_input.meta["val"].ndim == 2 ): return None - ndim = view_input.meta["val"].ndim - assert 1 < ndim <= 26, "Only support up to 26D for now" + output_view = users[0] + output_shape = output_view.meta["val"].shape + # Verify the output view is a canonical unflatten [prod(batch), N] -> [*batch, N] + unflatten_args = output_view.args[1] + mm_shape = mm_node.meta["val"].shape + if not _is_canonical_unflatten(mm_shape, unflatten_args): + return None + # Verify batch dims match between input and output + if input_shape[:-1] != output_shape[:-1]: + return None + # Verify weight shape is [K, N] matching the flatten dimensions + weight_shape = second_input.meta["val"].shape + if weight_shape[0] != input_shape[-1] or weight_shape[1] != output_shape[-1]: + return None + ndim = len(input_shape) dims = _batch_dims(ndim - 1) equation = f"{dims}k,kn->{dims}n" - return [view_input, second_input], users[0], equation + return [view_input, second_input], output_view, equation def _match_backward_linear(mm_node): """Match the backward pattern: view -> permute -> mm -> permute. + The backward of einsum "{batch}k,kn->{batch}n" produces a gradient-weight + computation: permute(view(grad, [flat, N]), [1,0]) @ view(x, [flat, K]) -> [N, K], + followed by permute([N, K], [1, 0]) -> [K, N]. 
+ + Verifies canonical flatten shapes and [1,0] permute orders. + Returns (inputs, replaced_node, equation) or None. """ first_input, second_input = mm_node.all_input_nodes @@ -290,20 +339,35 @@ def _match_backward_linear(mm_node): return None if first_input.target != torch.ops.aten.permute.default: return None - if first_input.all_input_nodes[0].target != torch.ops.aten.view.default: + first_view = first_input.all_input_nodes[0] + if first_view.target != torch.ops.aten.view.default: return None - orig_first = first_input.all_input_nodes[0].all_input_nodes[0] + # Verify the input permute is [1, 0] (transpose) + perm_dims = list(first_input.args[1]) + if perm_dims != [1, 0]: + return None + orig_first = first_view.all_input_nodes[0] orig_second = second_input.all_input_nodes[0] + # Verify both views are canonical flattenings + if not _is_canonical_flatten(orig_first.meta["val"].shape, first_view.args[1]): + return None + if not _is_canonical_flatten(orig_second.meta["val"].shape, second_input.args[1]): + return None users = list(mm_node.users) if not ( len(users) == 1 and users[0].target == torch.ops.aten.permute.default - and orig_first.meta["val"].shape[:-1] == orig_second.meta["val"].shape[:-1] and mm_node.meta["val"].ndim == 2 ): return None + # Verify the output permute is [1, 0] + out_perm_dims = list(users[0].args[1]) + if out_perm_dims != [1, 0]: + return None + # Verify batch dims match + if orig_first.meta["val"].shape[:-1] != orig_second.meta["val"].shape[:-1]: + return None ndim = orig_first.meta["val"].ndim - assert 1 < ndim <= 26, "Only support up to 26D for now" dims = _batch_dims(ndim - 1) equation = f"{dims}n,{dims}k->kn" return [orig_first, orig_second], users[0], equation @@ -322,8 +386,13 @@ def _replace_view_mm_view_with_einsum(gm): args=(equation, inputs), ) new_node.meta.update(replaced_node.meta) + # Preserve the mm node's seq_nr so that forward/backward + # einsum pairs remain matched by autograd's sequence numbering. 
+ if "seq_nr" in node.meta: + new_node.meta["seq_nr"] = node.meta["seq_nr"] replaced_node.replace_all_uses_with(new_node) gm.graph.eliminate_dead_code() + gm.graph.lint() gm.recompile() diff --git a/tests/test_activation_checkpointing.py b/tests/test_activation_checkpointing.py index c3ce81c2..2c865ef6 100644 --- a/tests/test_activation_checkpointing.py +++ b/tests/test_activation_checkpointing.py @@ -125,8 +125,17 @@ def _is_inside_checkpointed_fn(node): # --------------------------------------------------------------------------- +def _find_linear_nodes(graph): + """Find mm or einsum nodes (depending on whether einsum fusion is enabled).""" + mm_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) + einsum_nodes = graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + return mm_nodes + einsum_nodes + + def test_user_ac_recompute_tags_on_targeted_ops(device_mesh_1d): - """SDPA ops get MUST_RECOMPUTE and mm ops get MUST_SAVE from user policy.""" + """SDPA ops get MUST_RECOMPUTE and mm/einsum ops get MUST_SAVE from user policy.""" context_fn = functools.partial( create_selective_checkpoint_contexts, _must_save_policy ) @@ -141,9 +150,9 @@ def input_fn(): gm = _build_joint_graph(model, input_fn, device_mesh_1d) - mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) - assert len(mm_nodes) > 0 - for n in mm_nodes: + linear_nodes = _find_linear_nodes(gm.graph) + assert len(linear_nodes) > 0 + for n in linear_nodes: if n.meta.get("partitioner_tag", "") == "is_backward": continue if _is_inside_checkpointed_fn(n): @@ -334,6 +343,7 @@ def test_ac_joint_pass_apply_ac_policy_saves_mm_and_sdpa(device_mesh_1d): save_list = { torch.ops.aten.mm.default, + torch.ops.aten.einsum.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, torch.ops.aten._scaled_dot_product_cudnn_attention.default, diff --git 
a/tests/test_graph_utils.py b/tests/test_graph_utils.py index c44039a8..c8740012 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -112,3 +112,155 @@ def f(x, w): assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +# --- False-positive rejection tests --- + + +def test_no_match_non_canonical_flatten(): + """Input with fewer than 3 dims should not be matched.""" + K, N = 4, 8 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [6, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [2, 3, N]) + + x = torch.randn(6, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_unflatten_reorders_batch_dims(): + """Canonical flatten but unflatten reorders batch dims: must be rejected. + + Input [2, 3, 4] flattens to [6, 4] (canonical), but the output view + unflattens to [3, 2, N] instead of [2, 3, N]. This would produce a + semantically wrong einsum. 
+ """ + B, S, K, N = 2, 3, 4, 8 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [S, B, N]) + + x = torch.randn(B, S, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_wrong_permute_order(): + """Backward pattern with non-[1,0] permute must be rejected.""" + B, S, K = 2, 8, 16 + + def f(grad_out, x): + flat_grad = torch.ops.aten.view.default(grad_out, [B * S, K]) + # Wrong permute -- [0, 1] is identity, not transpose + perm_grad = torch.ops.aten.permute.default(flat_grad, [0, 1]) + flat_x = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(perm_grad, flat_x) + return torch.ops.aten.permute.default(out, [1, 0]) + + grad_out = torch.randn(B, S, K) + x = torch.randn(B, S, K) + gm = make_fx(f, tracing_mode="fake")(grad_out, x) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +def test_no_match_2d_input(): + """2D input (no batch dims) should not be matched.""" + K, N = 16, 32 + + def f(x, w): + flat = torch.ops.aten.view.default(x, [K, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [K, N]) + + x = torch.randn(K, K) + w = torch.randn(K, N) + gm = make_fx(f, tracing_mode="fake")(x, w) + + _replace_view_mm_view_with_einsum(gm) + + assert _count_ops(gm, torch.ops.aten.einsum.default) == 0 + assert _count_ops(gm, torch.ops.aten.mm.default) == 1 + + +# --- Numerical equivalence tests --- + + +def test_forward_numerical_equivalence_3d(): + """Einsum replacement produces the same numerical result as view-mm-view.""" + B, S, K, N = 2, 8, 16, 32 + x = torch.randn(B, S, K) + w = torch.randn(K, 
N) + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [B, S, N]) + + expected = f(x, w) + + gm = make_fx(f, tracing_mode="fake")(x, w) + _replace_view_mm_view_with_einsum(gm) + actual = gm(x, w) + + torch.testing.assert_close(actual, expected) + + +def test_forward_numerical_equivalence_4d(): + """Einsum replacement produces the same result for 4D input.""" + B, S, T, K, N = 2, 4, 3, 16, 32 + x = torch.randn(B, S, T, K) + w = torch.randn(K, N) + + def f(x, w): + flat = torch.ops.aten.view.default(x, [B * S * T, K]) + out = torch.ops.aten.mm.default(flat, w) + return torch.ops.aten.view.default(out, [B, S, T, N]) + + expected = f(x, w) + + gm = make_fx(f, tracing_mode="fake")(x, w) + _replace_view_mm_view_with_einsum(gm) + actual = gm(x, w) + + torch.testing.assert_close(actual, expected) + + +def test_backward_numerical_equivalence(): + """Einsum replacement for the backward gradient-weight pattern.""" + B, S, K, N = 2, 8, 16, 32 + grad_out = torch.randn(B, S, N) + x = torch.randn(B, S, K) + + def f(grad_out, x): + flat_grad = torch.ops.aten.view.default(grad_out, [B * S, N]) + perm_grad = torch.ops.aten.permute.default(flat_grad, [1, 0]) + flat_x = torch.ops.aten.view.default(x, [B * S, K]) + out = torch.ops.aten.mm.default(perm_grad, flat_x) + return torch.ops.aten.permute.default(out, [1, 0]) + + expected = f(grad_out, x) + + gm = make_fx(f, tracing_mode="fake")(grad_out, x) + _replace_view_mm_view_with_einsum(gm) + actual = gm(grad_out, x) + + torch.testing.assert_close(actual, expected) diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py index d3131ec3..e9bdefff 100644 --- a/tests/test_optimize_placement.py +++ b/tests/test_optimize_placement.py @@ -125,22 +125,41 @@ def test_optimization_finds_fsdp_and_ddp_1d(device_mesh_1d, high_mem, model_type mm_nodes = autop.gm.graph.find_nodes( op="call_function", 
target=torch.ops.aten.mm.default ) - len_mm_nodes = {"ffn_with_multiple_input_output": 5, "transformer_block": 18}[ - model_type - ] - len_fwd_mm_nodes = {"ffn_with_multiple_input_output": 2, "transformer_block": 6}[ - model_type - ] - assert len(mm_nodes) == len_mm_nodes - fwd_mm_nodes = mm_nodes[0:len_fwd_mm_nodes] - bwd_mm_grad_weight_nodes = mm_nodes[len_fwd_mm_nodes::2] - bwd_mm_grad_input_nodes = mm_nodes[(len_fwd_mm_nodes + 1) :: 2] + einsum_nodes = autop.gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + linear_nodes = mm_nodes + einsum_nodes + is_einsum = len(einsum_nodes) > 0 + + if is_einsum: + len_linear_nodes = { + "ffn_with_multiple_input_output": 5, + "transformer_block": 18, + }[model_type] + len_fwd_linear_nodes = { + "ffn_with_multiple_input_output": 2, + "transformer_block": 6, + }[model_type] + else: + len_linear_nodes = { + "ffn_with_multiple_input_output": 5, + "transformer_block": 18, + }[model_type] + len_fwd_linear_nodes = { + "ffn_with_multiple_input_output": 2, + "transformer_block": 6, + }[model_type] + + assert len(linear_nodes) == len_linear_nodes + fwd_linear_nodes = linear_nodes[0:len_fwd_linear_nodes] + bwd_linear_grad_weight_nodes = linear_nodes[len_fwd_linear_nodes::2] + bwd_linear_grad_input_nodes = linear_nodes[(len_fwd_linear_nodes + 1) :: 2] # and check that matmuls have full replication on weights during fwd, # which maps to DDP / FSDP # fwd - for node in fwd_mm_nodes: + for node in fwd_linear_nodes: p = sharding_placement[node] # input and output are sharded on batch assert p.input_specs[0].placements == (Shard(0),) @@ -149,14 +168,17 @@ def test_optimization_finds_fsdp_and_ddp_1d(device_mesh_1d, high_mem, model_type assert p.input_specs[1].placements == (Replicate(),) # bwd grad weight - for node in bwd_mm_grad_weight_nodes: + # For mm: [N, B*S] @ [B*S, K] → batch dim is at position 1 for input 0 + # For einsum: bsn,bsk->kn → batch dim is at position 0 for both inputs + 
bwd_grad_weight_shard = (Shard(0),) if is_einsum else (Shard(1),) + for node in bwd_linear_grad_weight_nodes: p = sharding_placement[node] - assert p.input_specs[0].placements == (Shard(1),) + assert p.input_specs[0].placements == bwd_grad_weight_shard assert p.output_specs.placements == (Partial("sum"),) assert p.input_specs[1].placements == (Shard(0),) # bwd grad inputs - for node in bwd_mm_grad_input_nodes: + for node in bwd_linear_grad_input_nodes: p = sharding_placement[node] assert p.input_specs[0].placements == (Shard(0),) assert p.output_specs.placements == (Shard(0),)