@REGISTRY.register(target=[torch.ops.aten.flip.default])
def _flip_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``aten.flip`` by emitting one reversing SliceNode per axis.

    Each requested dimension is reversed with a negative-step slice
    (``start=-1, stop=-(axis_len + 1), step=-1``). Intermediate reversals
    write to temporary slots; only the last one writes the node's output.

    Raises:
        RuntimeError: if the input node carries no ``val`` metadata.
        ValueError: if the same dimension is requested more than once.
    """
    args = P.args(n)
    require_args(args, 2, 2, "aten.flip")
    require_kwargs(P.kwargs(n), set(), "aten.flip")
    x, dims_arg = args

    # A bare int means a single-axis flip; otherwise copy the sequence.
    dims: List[int] = [dims_arg] if isinstance(dims_arg, int) else list(dims_arg)
    require_static_ints(dims, "dims", "aten.flip")

    x_meta = n.args[0].meta.get("val")
    if x_meta is None:
        raise RuntimeError("aten.flip: missing tensor metadata")
    ndim = len(x_meta.shape)

    # Normalize negative dims so duplicates like (1, -1) on a 2-D input
    # are rejected, matching aten's uniqueness requirement.
    normalized = [d % ndim for d in dims]
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"aten.flip: dims must be unique, got {dims}")

    out = P.make_or_get_slot(n)
    src = x
    last = len(dims) - 1
    for idx, d in enumerate(dims):
        # Chain the reversals: every pass but the final one lands in a
        # scratch slot that feeds the next pass.
        dst = out if idx == last else P.make_tmp_slot()[1]
        axis_len = x_meta.shape[d % ndim]
        P.emit(
            SliceNode(
                x=P.slot_to_tid(src),
                out=P.slot_to_tid(dst),
                axis=P.to_int_or_vid(d),
                start=P.to_int_or_vid(-1),
                # A stop one past the left edge makes the negative-step
                # slice cover the entire axis (Python slice semantics).
                stop=P.to_int_or_vid(-(axis_len + 1)),
                step=-1,
            )
        )
        src = dst
    return out
class FlipModel(nn.Module):
    """Applies ``torch.flip`` over a fixed tuple of dimensions."""

    def __init__(self, dims: Tuple[int, ...]):
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flip(x, dims=self.dims)


@register_test
class FlipTest(OpTestCase):
    """Test case for torch.flip()."""

    name = "flip"
    rtol = 1e-5
    atol = 1e-5

    def __init__(
        self,
        input_shape: Tuple[int, ...] = (4, 5),
        dims: Tuple[int, ...] = (0,),
    ):
        self.input_shape = input_shape
        self.dims = dims
        # Each config gets a distinguishing name, e.g. "flip_dim(0,2)".
        self.name = f"flip_dim({','.join(str(d) for d in dims)})"

    @classmethod
    def get_test_configs(cls) -> List["FlipTest"]:
        # (input_shape, dims) pairs covering 1-D through 3-D inputs,
        # multi-axis flips, and a negative dimension index.
        configs = [
            ((8,), (0,)),
            ((4, 5), (0,)),
            ((4, 5), (1,)),
            ((3, 4, 5), (2,)),
            ((3, 4, 5), (0, 2)),
            ((3, 4, 5), (0, 1, 2)),
            ((3, 4, 5), (-1,)),
        ]
        return [cls(input_shape=shape, dims=dims) for shape, dims in configs]

    def create_model(self) -> nn.Module:
        return FlipModel(self.dims)

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(self.input_shape),)