diff --git a/eltwise_support_probe.py b/eltwise_support_probe.py index e977240..21569cf 100644 --- a/eltwise_support_probe.py +++ b/eltwise_support_probe.py @@ -424,7 +424,11 @@ def build_args(base, dtype, layout, mem, device, bcast="none", variant="tt"): if variant == "ts" and base in BINARY: s = _ts_scalar(dtype) p = BPARAMS.get(base, ()) - return [a, s, *p], {}, [a_t, s, *p] + # the device op takes a python scalar, but several torch goldens (maximum, + # minimum, logical_*, ldexp, logaddexp*) reject a python scalar 2nd operand + # -> feed the golden a 0-dim tensor instead (broadcasts identically). + s_ref = torch.tensor(s, dtype=(torch.int32 if dtype in _INT_DTYPES else torch.float32)) + return [a, s, *p], {}, [a_t, s_ref, *p] bsh = _bcast_shape(base, bcast) fi = bcast != "none" # broadcast operands go interleaved (can't shard a sub-tile) if base in TERNARY: