Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5734,19 +5734,45 @@ def bitwise_not(context, node):
context.add(x)


def _bitwise_as_logical_if_boolean(context, node, op_name, logical_handler):
    """Shared lowering for bitwise_and / bitwise_or / bitwise_xor.

    Core ML has no true bitwise op on integers, so we lower to the logical
    counterpart whenever at least one operand is bool — which covers the common
    "combine boolean masks" pattern in attention/transformer code (where
    torch.export may produce a mixed bool/float pair). Pure non-bool inputs are
    still rejected so we don't silently change semantics for genuine integer
    bitwise math.

    Parameters
    ----------
    context, node:
        The usual torch-frontend conversion pair.
    op_name: str
        Torch op name used only in the error message (e.g. "bitwise_and").
    logical_handler: callable
        The logical_* translator to delegate to (e.g. ``logical_and``).

    Raises
    ------
    NotImplementedError
        If no input is boolean — true integer bitwise math is unsupported.
    """
    inputs = _get_inputs(context, node)

    input_dtypes = [i.dtype for i in inputs]
    if any(types.is_bool(d) for d in input_dtypes):
        logical_handler(context, node)
    else:
        # Message reflects the `any()` gate above: at least one bool operand is
        # required (was "only supports boolean input, but get ..." — both
        # ungrammatical and stale after the all() -> any() relaxation).
        raise NotImplementedError(
            f"The `{op_name}` op requires at least one boolean input, but got {input_dtypes}."
        )


@register_torch_op(torch_alias=["and"])
def bitwise_and(context, node):
    """Lower torch.bitwise_and (and the sanitized "and" alias) via the shared
    bool-to-logical helper, delegating to logical_and."""
    _bitwise_as_logical_if_boolean(context, node, "bitwise_and", logical_and)


# "or" and "xor" cover the post-sanitize form of "aten::__or__" / "aten::__xor__"
# which torch.export emits for `tensor | tensor` / `tensor ^ tensor`. These are
# common when building boolean attention masks (e.g. Gemma combines a causal
# mask with a padding mask via __or__).
@register_torch_op(torch_alias=["or"])
def bitwise_or(context, node):
    """Lower torch.bitwise_or / `tensor | tensor` (sanitized alias "or") via the
    shared bool-to-logical helper, delegating to logical_or."""
    _bitwise_as_logical_if_boolean(context, node, "bitwise_or", logical_or)


@register_torch_op(torch_alias=["xor"])
def bitwise_xor(context, node):
    """Lower torch.bitwise_xor / `tensor ^ tensor` (sanitized alias "xor") via
    the shared bool-to-logical helper, delegating to logical_xor."""
    _bitwise_as_logical_if_boolean(context, node, "bitwise_xor", logical_xor)


@register_torch_op
def logical_not(context, node):
# There is an optional `out` parameter in torch.logical_not.
Expand Down Expand Up @@ -6663,6 +6689,16 @@ def new_zeros(context, node):
context.add(mb.fill(shape=shape, value=0., name=node.name))


@register_torch_op
def new_ones(context, node):
    """Convert tensor.new_ones(size): same shape semantics as new_zeros, but
    filling with 1.

    Routed through _make_fill_op so float-typed shape inputs (which
    torch.export sometimes produces) are coerced to int32 automatically.
    """
    # inputs[0] is the reference tensor; inputs[1] is the requested size.
    size = _get_inputs(context, node)[1]
    context.add(_make_fill_op(size, 1.0, node.name))


@register_torch_op
def scalar_tensor(context, node):
x = _get_inputs(context, node, expected=[1, 5])[0]
Expand Down Expand Up @@ -7443,7 +7479,7 @@ def _nonzero_as_tuple(context, node, x):
context.add(result, node.name)


@register_torch_op(torch_alias=["where.self"])
@register_torch_op(torch_alias=["where.self", "where.scalarother"])
def where(context, node):
inputs = _get_inputs(context, node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,39 @@
from .. import utils
from ..converter import TranscriptionContext
from ..internal_graph import InternalTorchIRNode
from ..utils import sanitize_op_kind


class TestSanitizeOpKind:
    """Tests for the op-name canonicalizer shared by the TorchScript and EXIR
    frontends.

    The hardest inputs are dunder overloads such as ``aten::__or__.Tensor``:
    the ``__name__`` wrapper only becomes visible after the ``aten::`` prefix
    and ``.Tensor`` overload suffix are stripped, and it must still be unwrapped
    so the registry lookup key matches the legacy ``__or__`` form (``"or"``).
    """

    @pytest.mark.parametrize(
        "raw, expected",
        [
            ("add", "add"),  # already canonical — round-trips
            ("logical_or", "logical_or"),
            ("__add__", "add"),  # legacy single-token dunder form
            ("__or__", "or"),
            ("aten::add.Tensor", "add"),  # namespace + overload suffix
            ("aten::bmm.default", "bmm"),
            ("aten::pow.Tensor_Scalar", "pow"),
            # Dunder hidden behind namespace + overload suffix — the case that
            # used to slip through and leave "__or__" as the lookup key.
            ("aten::__or__.Tensor", "or"),
            ("aten::__and__.Tensor", "and"),
            ("aten::__xor__.Tensor", "xor"),
        ],
    )
    def test_sanitize_op_kind(self, raw, expected):
        # Pure string-in / string-out check; no torch model involved.
        assert sanitize_op_kind(raw) == expected


class TestTorchOps:
Expand Down
97 changes: 97 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4501,6 +4501,34 @@ def forward(self, x):
)


class TestNewOnes(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, frontend, shape",
        itertools.product(
            compute_units,
            backends,
            frontends,
            [(1,), (2, 3), (1, 1, 2, 5, 1)],
        ),
    )
    def test_new_ones_static(self, compute_unit, backend, frontend, shape):
        # x.new_ones(x.shape): static shape, so the fill size is a constant.
        class Model(nn.Module):
            def forward(self, x):
                return x.new_ones(x.shape)

        self.run_compare_torch(
            shape,
            Model().eval(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
        )


class TestNewFull(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, rank",
Expand Down Expand Up @@ -13316,6 +13344,75 @@ def forward(self, x, y):
)


class TestBitwiseOr(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, frontend",
        itertools.product(compute_units, backends, frontends),
    )
    def test_bitwise_or(self, compute_unit, backend, frontend):
        # torch.bitwise_or on two boolean masks lowers to logical_or.
        class OrModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.bitwise_or(x, y)

        mask_a = torch.rand(2, 3) > 0.2
        mask_b = torch.rand(2, 3) < 0.8
        self.run_compare_torch(
            [mask_a, mask_b],
            OrModel(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,
        )

    @pytest.mark.parametrize(
        "compute_unit, backend, frontend",
        itertools.product(compute_units, backends, frontends),
    )
    def test_or_operator(self, compute_unit, backend, frontend):
        # Exercises tensor.__or__ (i.e. `x | y`), which sanitizes to "or" and
        # is the form torch.export emits when building boolean attention masks.
        class PipeModel(torch.nn.Module):
            def forward(self, x, y):
                return x | y

        mask_a = torch.rand(2, 3) > 0.2
        mask_b = torch.rand(2, 3) < 0.8
        self.run_compare_torch(
            [mask_a, mask_b],
            PipeModel(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,
        )


class TestBitwiseXor(TorchBaseTest):
    @pytest.mark.parametrize(
        "compute_unit, backend, frontend",
        itertools.product(compute_units, backends, frontends),
    )
    def test_bitwise_xor(self, compute_unit, backend, frontend):
        # torch.bitwise_xor on two boolean masks lowers to logical_xor.
        class XorModel(torch.nn.Module):
            def forward(self, x, y):
                return torch.bitwise_xor(x, y)

        mask_a = torch.rand(2, 3) > 0.2
        mask_b = torch.rand(2, 3) < 0.8
        self.run_compare_torch(
            [mask_a, mask_b],
            XorModel(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
            input_as_shape=False,
        )


class TestUnfold(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, input_shape, is_dynamic_hw, kernel_size, dilation, padding, stride",
Expand Down
6 changes: 6 additions & 0 deletions coremltools/converters/mil/frontend/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ def skip_default_prefix_and_suffix_with_deliminator(
op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, "::")
op_kind = skip_default_prefix_and_suffix_with_deliminator(op_kind, ".")

# 4. Strip the "__name__" wrapper again. The dunder may only become visible
# after stripping the namespace and overload suffix above, e.g.
# "aten::__or__.Tensor" -> "__or__" -> "or".
if op_kind.startswith("__") and op_kind.endswith("__"):
op_kind = op_kind[2:-2]

return op_kind


Expand Down