diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index f59dec2f6..285540cb6 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -1805,9 +1805,18 @@ def mish(context, node): inputs = _get_inputs(context, node, expected=1) x = inputs[0] - softplus = mb.softplus(x=x) - tanh = mb.tanh(x=softplus) - res = mb.mul(x=x, y=tanh, name=node.name) + # e = exp(x) + # mish = x / (1 + 2 / (e * (e + 2))) + # Clamp x to avoid -inf producing NaN (exp(-inf)=0 causes 0/0, and -inf/finite=-inf). + # mish(-inf) is mathematically 0; mish(-100) ≈ 0 to full precision. + x = mb.clip(x=x, alpha=-100.0, beta=float("inf")) + e = mb.exp(x=x) + ep2 = mb.add(x=e, y=2.0) + emep2 = mb.mul(x=e, y=ep2) + tdemep2 = mb.real_div(x=2.0, y=emep2) + optdemep2 = mb.add(x=1.0, y=tdemep2) + res = mb.real_div(x=x, y=optdemep2, name=node.name) + context.add(res) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 432a04dc1..267669434 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -6628,6 +6628,51 @@ def test_mish(self, compute_unit, backend, frontend, shape): shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit ) + @pytest.mark.parametrize( + "compute_unit, backend, frontend, scale", + itertools.product(compute_units, backends, frontends, [0.1, 3.5, 11.0]), + ) + def test_mish_stability(self, compute_unit, backend, frontend, scale): + class MishModel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same") + self.act = nn.Mish() + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(28 * 28 * 16, 10) + + def forward(self, x): + x = self.act(self.conv1(x)) + x = self.flatten(x) + x = self.fc1(x) + return x + + model = MishModel().eval() + + # Fixed weights: conv weight=1.0, bias=0.0 + # Each interior conv output pixel = sum of 9 input values ≈ 9 * local_value + # Mish input interval ≈ [-9*scale, 9*scale] + # scale=0.1 → mish interval ≈ [-0.9, 0.9] (small values) + # scale=3.5 → mish interval ≈ [-31.5, 31.5] (covers x=-30 regime) + # scale=11.0 → mish interval ≈ [-99, 99] (covers x=-100 regime) + with torch.no_grad(): + model.conv1.weight.fill_(1.0) + model.conv1.bias.fill_(0.0) + model.fc1.weight.fill_(0.01) + model.fc1.bias.fill_(0.0) + + # Fixed input: 28x28 values from -scale to +scale + x = torch.linspace(-scale, scale, 28 * 28).reshape(1, 1, 28, 28) + + TorchBaseTest.run_compare_torch( + x, + model, + input_as_shape=False, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + @pytest.mark.parametrize( "compute_unit, backend, frontend, shape", itertools.product(compute_units, backends, frontends, COMMON_SHAPES_ALL),