diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index c8d5a31cb4..3d0deeb992 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1614,6 +1614,17 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto mode = quantization_mode_to_string(mode_); bool w_quantized = (inputs[1].dtype() == uint32); + // Tensor-scale nvfp4 (global_scale_x / global_scale_w) is packed into + // inputs by ops.cpp but no Metal qqmm kernel currently consumes the + // global scales. Reject the request rather than silently dropping them + // in the gemv path below. + int base_size = w_quantized ? 3 : 2; + if (mode_ == QuantizationMode::Nvfp4 && + static_cast(inputs.size()) > base_size) { + throw std::runtime_error( + "[QQMatmul] Global scale (tensor-scale nvfp4) is not supported " + "on the Metal backend."); + } if (w_quantized && inputs[0].shape(-2) == 1) { out.set_data(allocator::malloc(out.nbytes())); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index b036979044..d73927889e 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -205,6 +205,32 @@ def test_qqmv(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qqmm_metal_global_scale_rejected(self): + # Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not + # implemented in the Metal qqmm kernels. mx.qqmm must reject the + # request on Metal rather than silently dropping the global scales + # in the gemv path and producing incorrect results. + if not mx.metal.is_available(): + return + + w = mx.random.normal(shape=(64, 64)) + w_q, scales = mx.quantize(w, mode="nvfp4") + x = mx.random.normal(shape=(1, 64)) + gx = mx.array(1.0, dtype=mx.float32) + gw = mx.array(1.0, dtype=mx.float32) + + with self.assertRaises(RuntimeError): + y = mx.qqmm( + x, + w_q, + scales, + mode="nvfp4", + global_scale_x=gx, + global_scale_w=gw, + stream=mx.gpu, + ) + mx.eval(y) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key)