Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
move_to_fake,
)

_APPLY_VIEW_MM_VIEW_PATTERN = False
_APPLY_VIEW_MM_VIEW_PATTERN = True

logger = logging.getLogger(__name__)

Expand Down
38 changes: 21 additions & 17 deletions autoparallel/cost_models/compute_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,33 @@

@register_flop_formula(torch.ops.aten.einsum, get_raw=True)
def einsum_flop(equation, tensors, out=None, **kwargs) -> int:
# from torch.distributed.tensor._ops._einsum_strategy import EinsumDims
assert len(tensors) == 2
a_shape, b_shape = [x.shape for x in tensors]

# parse einop equation and extract dims
# TODO: generalize
# input_dims, output_dim = EinsumDims.parse_equation(equation)
# edims = EinsumDims.parse_dims(input_dims, output_dim)

if len(a_shape) == 3 and len(b_shape) == 3:
b, m, k = a_shape
b1, n, k2 = b_shape
assert b == b1
assert m == n
flop = (b * m) * k * k2 * 2
elif len(a_shape) == 3 and len(b_shape) == 2:
b, m, k = a_shape
if len(b_shape) == 2 and len(a_shape) >= 3:
# Forward linear: "{batch}k,kn->{batch}n"
# a: [*batch, K], b: [K, N]
batch_size = 1
for d in a_shape[:-1]:
batch_size *= d
k = a_shape[-1]
k2, n = b_shape
assert k == k2
flop = b * m * n * k * 2
assert k == k2, f"Contracting dim mismatch: {k} vs {k2}"
return batch_size * n * k * 2
elif len(a_shape) == len(b_shape) and len(a_shape) >= 3:
# Backward gradient-weight: "{batch}n,{batch}k->kn"
# a: [*batch, N], b: [*batch, K]
assert (
a_shape[:-1] == b_shape[:-1]
), f"Batch dims mismatch: {a_shape[:-1]} vs {b_shape[:-1]}"
batch_size = 1
for d in a_shape[:-1]:
batch_size *= d
n = a_shape[-1]
k = b_shape[-1]
return batch_size * n * k * 2
else:
raise NotImplementedError(f"Unsupported einsum shapes: {a_shape} {b_shape}")
return flop


@dataclass
Expand Down
85 changes: 77 additions & 8 deletions autoparallel/graph_passes/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,54 +256,118 @@ def _batch_dims(n: int) -> str:
return "".join(chr(97 + i) for i in range(n))


def _is_canonical_flatten(input_shape, view_args):
"""Check that view_args represent [*batch, K] -> [prod(batch), K]."""
if len(view_args) != 2:
return False
expected_flat = 1
for d in input_shape[:-1]:
expected_flat *= d
return view_args[0] == expected_flat and view_args[1] == input_shape[-1]


def _is_canonical_unflatten(input_shape, view_args):
"""Check that view_args represent [prod(batch), N] -> [*batch, N]."""
if len(view_args) < 3:
return False
batch_dims = view_args[:-1]
expected_flat = 1
for d in batch_dims:
expected_flat *= d
return expected_flat == input_shape[0] and view_args[-1] == input_shape[-1]


def _match_forward_linear(mm_node):
"""Match the forward pattern: view -> mm -> view.

Verifies canonical flatten/unflatten shapes:
input [*batch, K] -> view [prod(batch), K] -> mm [prod(batch), N] -> view [*batch, N]

Returns (inputs, replaced_node, equation) or None.
"""
first_input, second_input = mm_node.all_input_nodes
if first_input.target != torch.ops.aten.view.default:
return None
view_input = first_input.all_input_nodes[0]
input_shape = view_input.meta["val"].shape
if input_shape.numel() == 0 or len(input_shape) < 3:
return None
# Verify the input view is a canonical flatten [*batch, K] -> [prod(batch), K]
flatten_args = first_input.args[1]
if not _is_canonical_flatten(input_shape, flatten_args):
return None
users = list(mm_node.users)
if not (
len(users) == 1
and users[0].target == torch.ops.aten.view.default
and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1]
and second_input.meta["val"].ndim == 2
):
return None
ndim = view_input.meta["val"].ndim
assert 1 < ndim <= 26, "Only support up to 26D for now"
output_view = users[0]
output_shape = output_view.meta["val"].shape
# Verify the output view is a canonical unflatten [prod(batch), N] -> [*batch, N]
unflatten_args = output_view.args[1]
mm_shape = mm_node.meta["val"].shape
if not _is_canonical_unflatten(mm_shape, unflatten_args):
return None
# Verify batch dims match between input and output
if input_shape[:-1] != output_shape[:-1]:
return None
# Verify weight shape is [K, N] matching the flatten dimensions
weight_shape = second_input.meta["val"].shape
if weight_shape[0] != input_shape[-1] or weight_shape[1] != output_shape[-1]:
return None
ndim = len(input_shape)
dims = _batch_dims(ndim - 1)
equation = f"{dims}k,kn->{dims}n"
return [view_input, second_input], users[0], equation
return [view_input, second_input], output_view, equation


def _match_backward_linear(mm_node):
"""Match the backward pattern: view -> permute -> mm -> permute.

The backward of einsum "{batch}k,kn->{batch}n" produces a gradient-weight
computation: permute(view(grad, [flat, N]), [1,0]) @ view(x, [flat, K]) -> [N, K],
followed by permute([N, K], [1, 0]) -> [K, N].

Verifies canonical flatten shapes and [1,0] permute orders.

Returns (inputs, replaced_node, equation) or None.
"""
first_input, second_input = mm_node.all_input_nodes
if second_input.target != torch.ops.aten.view.default:
return None
if first_input.target != torch.ops.aten.permute.default:
return None
if first_input.all_input_nodes[0].target != torch.ops.aten.view.default:
first_view = first_input.all_input_nodes[0]
if first_view.target != torch.ops.aten.view.default:
return None
orig_first = first_input.all_input_nodes[0].all_input_nodes[0]
# Verify the input permute is [1, 0] (transpose)
perm_dims = list(first_input.args[1])
if perm_dims != [1, 0]:
return None
orig_first = first_view.all_input_nodes[0]
orig_second = second_input.all_input_nodes[0]
# Verify both views are canonical flattenings
if not _is_canonical_flatten(orig_first.meta["val"].shape, first_view.args[1]):
return None
if not _is_canonical_flatten(orig_second.meta["val"].shape, second_input.args[1]):
return None
users = list(mm_node.users)
if not (
len(users) == 1
and users[0].target == torch.ops.aten.permute.default
and orig_first.meta["val"].shape[:-1] == orig_second.meta["val"].shape[:-1]
and mm_node.meta["val"].ndim == 2
):
return None
# Verify the output permute is [1, 0]
out_perm_dims = list(users[0].args[1])
if out_perm_dims != [1, 0]:
return None
# Verify batch dims match
if orig_first.meta["val"].shape[:-1] != orig_second.meta["val"].shape[:-1]:
return None
ndim = orig_first.meta["val"].ndim
assert 1 < ndim <= 26, "Only support up to 26D for now"
dims = _batch_dims(ndim - 1)
equation = f"{dims}n,{dims}k->kn"
return [orig_first, orig_second], users[0], equation
Expand All @@ -322,8 +386,13 @@ def _replace_view_mm_view_with_einsum(gm):
args=(equation, inputs),
)
new_node.meta.update(replaced_node.meta)
# Preserve the mm node's seq_nr so that forward/backward
# einsum pairs remain matched by autograd's sequence numbering.
if "seq_nr" in node.meta:
new_node.meta["seq_nr"] = node.meta["seq_nr"]
replaced_node.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()


Expand Down
18 changes: 14 additions & 4 deletions tests/test_activation_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,17 @@ def _is_inside_checkpointed_fn(node):
# ---------------------------------------------------------------------------


def _find_linear_nodes(graph):
"""Find mm or einsum nodes (depending on whether einsum fusion is enabled)."""
mm_nodes = graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default)
einsum_nodes = graph.find_nodes(
op="call_function", target=torch.ops.aten.einsum.default
)
return mm_nodes + einsum_nodes


def test_user_ac_recompute_tags_on_targeted_ops(device_mesh_1d):
"""SDPA ops get MUST_RECOMPUTE and mm ops get MUST_SAVE from user policy."""
"""SDPA ops get MUST_RECOMPUTE and mm/einsum ops get MUST_SAVE from user policy."""
context_fn = functools.partial(
create_selective_checkpoint_contexts, _must_save_policy
)
Expand All @@ -141,9 +150,9 @@ def input_fn():

gm = _build_joint_graph(model, input_fn, device_mesh_1d)

mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default)
assert len(mm_nodes) > 0
for n in mm_nodes:
linear_nodes = _find_linear_nodes(gm.graph)
assert len(linear_nodes) > 0
for n in linear_nodes:
if n.meta.get("partitioner_tag", "") == "is_backward":
continue
if _is_inside_checkpointed_fn(n):
Expand Down Expand Up @@ -334,6 +343,7 @@ def test_ac_joint_pass_apply_ac_policy_saves_mm_and_sdpa(device_mesh_1d):

save_list = {
torch.ops.aten.mm.default,
torch.ops.aten.einsum.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
Expand Down
152 changes: 152 additions & 0 deletions tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,155 @@ def f(x, w):

assert _count_ops(gm, torch.ops.aten.einsum.default) == 0
assert _count_ops(gm, torch.ops.aten.mm.default) == 1


# --- False-positive rejection tests ---


def test_no_match_non_canonical_flatten():
    """Input with fewer than 3 dims should not be matched."""
    k_dim, n_dim = 4, 8

    def linear_like(inp, weight):
        two_d = torch.ops.aten.view.default(inp, [6, k_dim])
        prod = torch.ops.aten.mm.default(two_d, weight)
        return torch.ops.aten.view.default(prod, [2, 3, n_dim])

    traced = make_fx(linear_like, tracing_mode="fake")(
        torch.randn(6, k_dim), torch.randn(k_dim, n_dim)
    )
    _replace_view_mm_view_with_einsum(traced)

    # The 2D input cannot be a canonical [*batch, K] flatten, so the mm survives.
    assert _count_ops(traced, torch.ops.aten.einsum.default) == 0
    assert _count_ops(traced, torch.ops.aten.mm.default) == 1


def test_no_match_unflatten_reorders_batch_dims():
    """Canonical flatten but unflatten reorders batch dims: must be rejected.

    Input [2, 3, 4] flattens to [6, 4] (canonical), but the output view
    unflattens to [3, 2, N] instead of [2, 3, N]. This would produce a
    semantically wrong einsum.
    """
    batch, seq, k_dim, n_dim = 2, 3, 4, 8

    def swapped(inp, weight):
        flattened = torch.ops.aten.view.default(inp, [batch * seq, k_dim])
        prod = torch.ops.aten.mm.default(flattened, weight)
        # Batch dims come back transposed relative to the input.
        return torch.ops.aten.view.default(prod, [seq, batch, n_dim])

    traced = make_fx(swapped, tracing_mode="fake")(
        torch.randn(batch, seq, k_dim), torch.randn(k_dim, n_dim)
    )
    _replace_view_mm_view_with_einsum(traced)

    assert _count_ops(traced, torch.ops.aten.einsum.default) == 0
    assert _count_ops(traced, torch.ops.aten.mm.default) == 1


def test_no_match_wrong_permute_order():
    """Backward pattern with non-[1,0] permute must be rejected."""
    batch, seq, k_dim = 2, 8, 16

    def not_a_transpose(grad_out, inp):
        flat_grad = torch.ops.aten.view.default(grad_out, [batch * seq, k_dim])
        # [0, 1] is the identity permutation, not the transpose the
        # backward-linear pattern requires.
        same_order = torch.ops.aten.permute.default(flat_grad, [0, 1])
        flat_inp = torch.ops.aten.view.default(inp, [batch * seq, k_dim])
        prod = torch.ops.aten.mm.default(same_order, flat_inp)
        return torch.ops.aten.permute.default(prod, [1, 0])

    traced = make_fx(not_a_transpose, tracing_mode="fake")(
        torch.randn(batch, seq, k_dim), torch.randn(batch, seq, k_dim)
    )
    _replace_view_mm_view_with_einsum(traced)

    assert _count_ops(traced, torch.ops.aten.einsum.default) == 0
    assert _count_ops(traced, torch.ops.aten.mm.default) == 1


def test_no_match_2d_input():
    """2D input (no batch dims) should not be matched."""
    k_dim, n_dim = 16, 32

    def plain_mm(inp, weight):
        same_shape = torch.ops.aten.view.default(inp, [k_dim, k_dim])
        prod = torch.ops.aten.mm.default(same_shape, weight)
        return torch.ops.aten.view.default(prod, [k_dim, n_dim])

    traced = make_fx(plain_mm, tracing_mode="fake")(
        torch.randn(k_dim, k_dim), torch.randn(k_dim, n_dim)
    )
    _replace_view_mm_view_with_einsum(traced)

    assert _count_ops(traced, torch.ops.aten.einsum.default) == 0
    assert _count_ops(traced, torch.ops.aten.mm.default) == 1


# --- Numerical equivalence tests ---


def test_forward_numerical_equivalence_3d():
    """Einsum replacement produces the same numerical result as view-mm-view."""
    batch, seq, k_dim, n_dim = 2, 8, 16, 32
    inp = torch.randn(batch, seq, k_dim)
    weight = torch.randn(k_dim, n_dim)

    def linear_like(x, w):
        flattened = torch.ops.aten.view.default(x, [batch * seq, k_dim])
        prod = torch.ops.aten.mm.default(flattened, w)
        return torch.ops.aten.view.default(prod, [batch, seq, n_dim])

    # Reference result from the un-rewritten eager computation.
    reference = linear_like(inp, weight)

    traced = make_fx(linear_like, tracing_mode="fake")(inp, weight)
    _replace_view_mm_view_with_einsum(traced)

    torch.testing.assert_close(traced(inp, weight), reference)


def test_forward_numerical_equivalence_4d():
    """Einsum replacement produces the same result for 4D input."""
    batch, seq, tok, k_dim, n_dim = 2, 4, 3, 16, 32
    inp = torch.randn(batch, seq, tok, k_dim)
    weight = torch.randn(k_dim, n_dim)

    def linear_like(x, w):
        flattened = torch.ops.aten.view.default(x, [batch * seq * tok, k_dim])
        prod = torch.ops.aten.mm.default(flattened, w)
        return torch.ops.aten.view.default(prod, [batch, seq, tok, n_dim])

    # Reference result from the un-rewritten eager computation.
    reference = linear_like(inp, weight)

    traced = make_fx(linear_like, tracing_mode="fake")(inp, weight)
    _replace_view_mm_view_with_einsum(traced)

    torch.testing.assert_close(traced(inp, weight), reference)


def test_backward_numerical_equivalence():
    """Einsum replacement for the backward gradient-weight pattern."""
    batch, seq, k_dim, n_dim = 2, 8, 16, 32
    grad = torch.randn(batch, seq, n_dim)
    inp = torch.randn(batch, seq, k_dim)

    def grad_weight(g, x):
        flat_grad = torch.ops.aten.view.default(g, [batch * seq, n_dim])
        grad_t = torch.ops.aten.permute.default(flat_grad, [1, 0])
        flat_x = torch.ops.aten.view.default(x, [batch * seq, k_dim])
        prod = torch.ops.aten.mm.default(grad_t, flat_x)
        return torch.ops.aten.permute.default(prod, [1, 0])

    # Reference result from the un-rewritten eager computation.
    reference = grad_weight(grad, inp)

    traced = make_fx(grad_weight, tracing_mode="fake")(grad, inp)
    _replace_view_mm_view_with_einsum(traced)

    torch.testing.assert_close(traced(grad, inp), reference)
Loading
Loading