diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index bf7b0bafd6..a823915123 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -653,6 +653,9 @@ template QUANT_SDPA_DISPATCH(Affine, 32, 4) QUANT_SDPA_DISPATCH(Affine, 32, 6) QUANT_SDPA_DISPATCH(Affine, 32, 8) + QUANT_SDPA_DISPATCH(Affine, 64, 4) + QUANT_SDPA_DISPATCH(Affine, 64, 6) + QUANT_SDPA_DISPATCH(Affine, 64, 8) QUANT_SDPA_DISPATCH(Mxfp4, 32, 4) QUANT_SDPA_DISPATCH(Nvfp4, 16, 4) QUANT_SDPA_DISPATCH(Mxfp8, 32, 8) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index aa3f88fd36..bddc569b55 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -901,9 +901,9 @@ array quantized_scaled_dot_product_attention( // Validate mode-specific group_size and bits if (is_affine) { - if (group_size != 32) { + if (group_size != 32 && group_size != 64) { std::ostringstream msg; - msg << "[" << tag << "] Affine mode supports group_size 32 " + msg << "[" << tag << "] Affine mode supports group_size 32 or 64 " << "but received " << group_size << "."; throw std::invalid_argument(msg.str()); }