From eb9a4d0675e016ea704c84b3f18fbb04db555803 Mon Sep 17 00:00:00 2001 From: LongYinan Date: Thu, 14 May 2026 15:24:40 +0800 Subject: [PATCH] [Metal] Reject tensor-scale nvfp4 in qqmm QQMatmul::eval_gpu on Metal silently dropped global_scale_x / global_scale_w in the gemv special case (pre-quantized w, M==1), producing numerically incorrect results when tensor-scale nvfp4 weights were in use. The general case already throws NYI. Add a backend-level guard at the top of QQMatmul::eval_gpu that rejects nvfp4 with global scales packed into inputs, matching the local throw style and keeping the check out of the backend-agnostic ops.cpp. Fixes #3550. --- mlx/backend/metal/quantized.cpp | 11 +++++++++++ python/tests/test_quantized.py | 26 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index c8d5a31cb4..3d0deeb992 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1614,6 +1614,17 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto mode = quantization_mode_to_string(mode_); bool w_quantized = (inputs[1].dtype() == uint32); + // Tensor-scale nvfp4 (global_scale_x / global_scale_w) is packed into + // inputs by ops.cpp but no Metal qqmm kernel currently consumes the + // global scales. Reject the request rather than silently dropping them + // in the gemv path below. + int base_size = w_quantized ? 3 : 2; + if (mode_ == QuantizationMode::Nvfp4 && + static_cast(inputs.size()) > base_size) { + throw std::runtime_error( + "[QQMatmul] Global scale (tensor-scale nvfp4) is not supported " + "on the Metal backend."); + } if (w_quantized && inputs[0].shape(-2) == 1) { out.set_data(allocator::malloc(out.nbytes())); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index b036979044..d73927889e 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -205,6 +205,32 @@ def test_qqmv(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qqmm_metal_global_scale_rejected(self): + # Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not + # implemented in the Metal qqmm kernels. mx.qqmm must reject the + # request on Metal rather than silently dropping the global scales + # in the gemv path and producing incorrect results. + if not mx.metal.is_available(): + return + + w = mx.random.normal(shape=(64, 64)) + w_q, scales = mx.quantize(w, mode="nvfp4") + x = mx.random.normal(shape=(1, 64)) + gx = mx.array(1.0, dtype=mx.float32) + gw = mx.array(1.0, dtype=mx.float32) + + with self.assertRaises(RuntimeError): + y = mx.qqmm( + x, + w_q, + scales, + mode="nvfp4", + global_scale_x=gx, + global_scale_w=gw, + stream=mx.gpu, + ) + mx.eval(y) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key)