diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 7985935bc..f59dec2f6 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -5734,18 +5734,45 @@ def bitwise_not(context, node):
     context.add(x)
 
 
-@register_torch_op(torch_alias=["and"])
-def bitwise_and(context, node):
+def _bitwise_as_logical_if_boolean(context, node, op_name, logical_handler):
+    """Shared body for bitwise_and/or/xor.
+
+    Core ML has no true bitwise op on integers, so we lower to the logical
+    counterpart whenever at least one operand is bool — which covers the common
+    "combine boolean masks" pattern in attention/transformer code (where
+    torch.export may produce a mixed bool/float pair). Pure non-bool inputs are
+    still rejected so we don't silently change semantics for genuine integer
+    bitwise math.
+    """
     inputs = _get_inputs(context, node)
     input_dtypes = [i.dtype for i in inputs]
-    if all(types.is_bool(input_dtype) for input_dtype in input_dtypes):
-        logical_and(context, node)
+    if any(types.is_bool(d) for d in input_dtypes):
+        logical_handler(context, node)
     else:
         raise NotImplementedError(
-            f"The `bitwise_and` op only supports boolean input, but get {input_dtypes}."
+            f"The `{op_name}` op only supports boolean input, but get {input_dtypes}."
         )
 
 
+@register_torch_op(torch_alias=["and"])
+def bitwise_and(context, node):
+    _bitwise_as_logical_if_boolean(context, node, "bitwise_and", logical_and)
+
+
+# "or" and "xor" cover the post-sanitize form of "aten::__or__" / "aten::__xor__"
+# which torch.export emits for `tensor | tensor` / `tensor ^ tensor`. These are
+# common when building boolean attention masks (e.g. Gemma combines a causal
+# mask with a padding mask via __or__).
+@register_torch_op(torch_alias=["or"])
+def bitwise_or(context, node):
+    _bitwise_as_logical_if_boolean(context, node, "bitwise_or", logical_or)
+
+
+@register_torch_op(torch_alias=["xor"])
+def bitwise_xor(context, node):
+    _bitwise_as_logical_if_boolean(context, node, "bitwise_xor", logical_xor)
+
+
 @register_torch_op
 def logical_not(context, node):
     # There is an optional `out` parameter in torch.logical_not.
@@ -6663,6 +6690,16 @@ def new_zeros(context, node):
     context.add(mb.fill(shape=shape, value=0., name=node.name))
 
 
+@register_torch_op
+def new_ones(context, node):
+    # tensor.new_ones(size) — same shape semantics as new_zeros, value is 1.
+    # Use _make_fill_op so float-typed shape inputs (which torch.export sometimes
+    # produces) are coerced to int32 automatically.
+    inputs = _get_inputs(context, node)
+    result = _make_fill_op(inputs[1], 1.0, node.name)
+    context.add(result)
+
+
 @register_torch_op
 def scalar_tensor(context, node):
     x = _get_inputs(context, node, expected=[1, 5])[0]
@@ -7443,7 +7480,7 @@ def _nonzero_as_tuple(context, node, x):
     context.add(result, node.name)
 
 
-@register_torch_op(torch_alias=["where.self"])
+@register_torch_op(torch_alias=["where.self", "where.scalarother"])
 def where(context, node):
     inputs = _get_inputs(context, node)
diff --git a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py
index f16663864..a8458f6b1 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py
@@ -21,6 +21,39 @@
 from .. import utils
 from ..converter import TranscriptionContext
 from ..internal_graph import InternalTorchIRNode
+from ..utils import sanitize_op_kind
+
+
+class TestSanitizeOpKind:
+    """Unit tests for the op-name canonicalizer used by both TorchScript and
+    EXIR frontends. The trickiest case is op overloads whose canonical name
+    only contains a "__name__" wrapper after the namespace prefix is stripped
+    — e.g. aten::__or__.Tensor must canonicalize to "or" so it resolves
+    against the same registry entry as the legacy "__or__" form.
+    """
+
+    @pytest.mark.parametrize(
+        "raw, expected",
+        [
+            # Already-canonical names round-trip.
+            ("add", "add"),
+            ("logical_or", "logical_or"),
+            # Legacy double-underscore form (single token).
+            ("__add__", "add"),
+            ("__or__", "or"),
+            # ATen / overload-suffixed forms.
+            ("aten::add.Tensor", "add"),
+            ("aten::bmm.default", "bmm"),
+            ("aten::pow.Tensor_Scalar", "pow"),
+            # Dunder hidden behind a namespace + overload suffix — the case that
+            # used to slip through and produce e.g. "__or__" as the lookup key.
+            ("aten::__or__.Tensor", "or"),
+            ("aten::__and__.Tensor", "and"),
+            ("aten::__xor__.Tensor", "xor"),
+        ],
+    )
+    def test_sanitize_op_kind(self, raw, expected):
+        assert sanitize_op_kind(raw) == expected
 
 
 class TestTorchOps:
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 83fae65a4..432a04dc1 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -4501,6 +4501,34 @@
         )
 
 
+class TestNewOnes(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend, shape",
+        itertools.product(
+            compute_units,
+            backends,
+            frontends,
+            [
+                (1,),
+                (2, 3),
+                (1, 1, 2, 5, 1),
+            ],
+        ),
+    )
+    def test_new_ones_static(self, compute_unit, backend, frontend, shape):
+        class OnesStaticModel(nn.Module):
+            def forward(self, x):
+                return x.new_ones(x.shape)
+
+        self.run_compare_torch(
+            shape,
+            OnesStaticModel().eval(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+        )
+
+
 class TestNewFull(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, rank",
@@ -13316,6 +13344,75 @@
         )
 
 
+class TestBitwiseOr(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend",
+        itertools.product(compute_units, backends, frontends),
+    )
+    def test_bitwise_or(self, compute_unit, backend, frontend):
+        class TestModel(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.bitwise_or(x, y)
+
+        input_shape = (2, 3)
+        input_data_x = torch.rand(*input_shape) > 0.2
+        input_data_y = torch.rand(*input_shape) < 0.8
+        self.run_compare_torch(
+            [input_data_x, input_data_y],
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+            input_as_shape=False,
+        )
+
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend",
+        itertools.product(compute_units, backends, frontends),
+    )
+    def test_or_operator(self, compute_unit, backend, frontend):
+        # Exercises tensor.__or__ (i.e. `x | y`) which sanitizes to "or" and
+        # is the form torch.export emits when building boolean attention masks.
+        class TestModel(torch.nn.Module):
+            def forward(self, x, y):
+                return x | y
+
+        input_shape = (2, 3)
+        input_data_x = torch.rand(*input_shape) > 0.2
+        input_data_y = torch.rand(*input_shape) < 0.8
+        self.run_compare_torch(
+            [input_data_x, input_data_y],
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+            input_as_shape=False,
+        )
+
+
+class TestBitwiseXor(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend",
+        itertools.product(compute_units, backends, frontends),
+    )
+    def test_bitwise_xor(self, compute_unit, backend, frontend):
+        class TestModel(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.bitwise_xor(x, y)
+
+        input_shape = (2, 3)
+        input_data_x = torch.rand(*input_shape) > 0.2
+        input_data_y = torch.rand(*input_shape) < 0.8
+        self.run_compare_torch(
+            [input_data_x, input_data_y],
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+            input_as_shape=False,
+        )
+
+
 class TestUnfold(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, input_shape, is_dynamic_hw, kernel_size, dilation, padding, stride",
diff --git a/coremltools/converters/mil/frontend/torch/utils.py b/coremltools/converters/mil/frontend/torch/utils.py
index ffda8623d..e8bee727d 100644
--- a/coremltools/converters/mil/frontend/torch/utils.py
+++ b/coremltools/converters/mil/frontend/torch/utils.py
@@ -163,6 +163,12 @@ def skip_default_prefix_and_suffix_with_deliminator(
     op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, "::")
     op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, ".")
 
+    # 4. Strip the "__name__" wrapper again. The dunder may only become visible
+    # after stripping the namespace and overload suffix above, e.g.
+    # "aten::__or__.Tensor" -> "__or__" -> "or".
+    if op_kind.startswith("__") and op_kind.endswith("__"):
+        op_kind = op_kind[2:-2]
+
     return op_kind
 
 