Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion mlx/backend/metal/kernels/quantized_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"

enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
// Quantization layouts handled by the Metal kernels. Affine uses per-group
// scale/bias; Mx*/Nvfp4 are FP micro-scaled formats; TurboQuant3/4 decode
// packed indices through the turbo3/turbo4 codebooks defined below.
enum class QuantMode { Affine, Mxfp4, Mxfp8, Nvfp4, TurboQuant3, TurboQuant4 };

template <typename OutT, typename EncodedT>
struct DecodeValue {
Expand Down Expand Up @@ -72,6 +72,64 @@ struct QuantConfig<QuantMode::Mxfp8> {
using scale_storage_t = uint8_t;
};

// TurboQuant: codebook-based quantization with per-vector float scales.
// Keys/values are packed bit indices; scales are per-vector L2 norms / sqrt(D).
// Codebooks are Lloyd-Max optimal for N(0,1) (distribution of rotated,
// norm-normalized key coordinates scaled by sqrt(D)).
template <>
struct QuantConfig<QuantMode::TurboQuant3> {
  // No per-group bias: the codebook is symmetric around zero, so only a
  // per-vector scale is needed.
  static constant constexpr bool has_bias = false;

  // No fixed encoded element/scale types: values are packed 3-bit codebook
  // indices, and scales are carried in the compute type (see below).
  using value_type = void;
  using scale_type = void;

  // Scales are stored in the floating-point compute type T (per-vector
  // norms), unlike the uint8_t scale storage used by the MX/NV FP modes.
  template <typename T>
  using scale_storage_t = T;
};

template <>
struct QuantConfig<QuantMode::TurboQuant4> {
  // No per-group bias: the codebook is symmetric around zero, so only a
  // per-vector scale is needed.
  static constant constexpr bool has_bias = false;

  // No fixed encoded element/scale types: values are packed 4-bit codebook
  // indices, and scales are carried in the compute type (see below).
  using value_type = void;
  using scale_type = void;

  // Scales are stored in the floating-point compute type T (per-vector
  // norms), unlike the uint8_t scale storage used by the MX/NV FP modes.
  template <typename T>
  using scale_storage_t = T;
};

// N(0,1) Lloyd-Max 3-bit codebook (8 reconstruction levels).
// Boundaries: 0, ±0.332, ±0.776, ±1.399 (midpoints of adjacent centroids).
// NOTE(review): the outermost levels here (±1.7481) differ from the classic
// MSE-optimal 8-level Gaussian table (outermost ±2.152, Max 1960) and from
// the equal-probability conditional means (outermost ≈ ±1.647). Confirm the
// values match the intended TurboQuant reference before relying on them.
constant float turbo3_codebook[8] = {
    -1.7481f,
    -1.0498f,
    -0.5012f,
    -0.1624f,
    0.1624f,
    0.5012f,
    1.0498f,
    1.7481f};

// N(0,1) equal-probability 4-bit codebook (16 reconstruction levels).
// Computed as E[X | X in (b_i, b_{i+1})] for N(0,1) with equiprobable bins.
// (Sanity check: outermost level = 16 * phi(z_{1/16}) = 16 * phi(1.534)
// ≈ 1.967, matching the first/last entries.) Levels are symmetric, so
// index 15 - i maps to the negation of index i.
constant float turbo4_codebook[16] = {
    -1.9672f,
    -1.3305f,
    -1.0130f,
    -0.7811f,
    -0.5714f,
    -0.4053f,
    -0.2382f,
    -0.0784f,
    0.0784f,
    0.2382f,
    0.4053f,
    0.5714f,
    0.7811f,
    1.0130f,
    1.3305f,
    1.9672f};

template <QuantMode mode, typename T>
struct Dequant {
using Cfg = QuantConfig<mode>;
Expand All @@ -98,6 +156,32 @@ struct Dequant {
}
};

// TurboQuant 3-bit dequantizer: value = per-vector scale * codebook level.
// The unused trailing operand keeps the (value, scale, bias) call shape
// shared with the biased modes; TurboQuant has no bias term.
template <typename T>
struct Dequant<QuantMode::TurboQuant3, T> {
  // Decode the low 3 bits of `code` through the Lloyd-Max codebook.
  [[clang::always_inline]] T raw(uint8_t code) const {
    return static_cast<T>(turbo3_codebook[code & 7u]);
  }
  // Scales are stored ready to use; no exponent decoding is required.
  [[clang::always_inline]] T scale(T vec_scale) const {
    return vec_scale;
  }
  // Full dequantization; the trailing (bias) argument is ignored.
  [[clang::always_inline]] T operator()(uint8_t code, T vec_scale, T) const {
    return scale(vec_scale) * raw(code);
  }
};

// TurboQuant 4-bit dequantizer: value = per-vector scale * codebook level.
// The unused trailing operand keeps the (value, scale, bias) call shape
// shared with the biased modes; TurboQuant has no bias term.
template <typename T>
struct Dequant<QuantMode::TurboQuant4, T> {
  // Decode the low 4 bits of `code` through the equal-probability codebook.
  [[clang::always_inline]] T raw(uint8_t code) const {
    return static_cast<T>(turbo4_codebook[code & 15u]);
  }
  // Scales are stored ready to use; no exponent decoding is required.
  [[clang::always_inline]] T scale(T vec_scale) const {
    return vec_scale;
  }
  // Full dequantization; the trailing (bias) argument is ignored.
  [[clang::always_inline]] T operator()(uint8_t code, T vec_scale, T) const {
    return scale(vec_scale) * raw(code);
  }
};

// Pack metadata and unpackers for arbitrary bit-widths (wsize fixed at 32 bits)
template <int bits>
struct PackInfo {
Expand Down
18 changes: 17 additions & 1 deletion mlx/backend/metal/kernels/sdpa_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ struct QuantOps {
static constant constexpr int granularity = is_fast_path ? 4 : pack_factor;
using fast_load_t = metal::conditional_t<bits == 4, uint16_t, uint32_t>;
static constant constexpr uint32_t fast_mask = (1u << bits) - 1;
static_assert(bits == 4 || bits == 6 || bits == 8, "unsupported quant bits");
static_assert(
bits == 3 || bits == 4 || bits == 6 || bits == 8,
"unsupported quant bits");
static_assert(
!is_fast_path || (group_size % 4) == 0,
"group_size must be divisible by 4 for 4/8-bit fast path");
Expand Down Expand Up @@ -659,6 +661,20 @@ template <typename T, int D>
QUANT_SDPA_DISPATCH(Mxfp4, 32, 4)
QUANT_SDPA_DISPATCH(Nvfp4, 16, 4)
QUANT_SDPA_DISPATCH(Mxfp8, 32, 8)
// TurboQuant requires group_size == head_dim (one norm per vector).
// Gate on D to avoid the (D % group_size) == 0 static_assert failing.
if constexpr (D >= 64) {
QUANT_SDPA_DISPATCH(TurboQuant3, 64, 3)
QUANT_SDPA_DISPATCH(TurboQuant4, 64, 4)
}
if constexpr (D >= 128) {
QUANT_SDPA_DISPATCH(TurboQuant3, 128, 3)
QUANT_SDPA_DISPATCH(TurboQuant4, 128, 4)
}
if constexpr (D >= 256) {
QUANT_SDPA_DISPATCH(TurboQuant3, 256, 3)
QUANT_SDPA_DISPATCH(TurboQuant4, 256, 4)
}
#undef QUANT_SDPA_DISPATCH
}

Expand Down
4 changes: 4 additions & 0 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,10 @@ int quant_mode_to_int(QuantizationMode mode) {
return 2;
case QuantizationMode::Nvfp4:
return 3;
case QuantizationMode::TurboQuant3:
return 4;
case QuantizationMode::TurboQuant4:
return 5;
default:
throw std::invalid_argument(
"[quant_sdpa_vector_2pass] Unsupported quantization mode.");
Expand Down
61 changes: 58 additions & 3 deletions mlx/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ array quantized_scaled_dot_product_attention(
// Parse mode and get parameters
auto qmode = string_to_quantization_mode(mode, tag);
bool is_affine = qmode == QuantizationMode::Affine;
bool is_turbo = qmode == QuantizationMode::TurboQuant3 ||
qmode == QuantizationMode::TurboQuant4;
auto [group_size, bits] =
quantization_params_from_mode(qmode, group_size_, bits_);

Expand All @@ -919,6 +921,19 @@ array quantized_scaled_dot_product_attention(
"[quantized_scaled_dot_product_attention] Affine mode requires "
"key_biases and value_biases.");
}
} else if (is_turbo) {
int expected_bits = (qmode == QuantizationMode::TurboQuant3) ? 3 : 4;
if (bits != expected_bits) {
std::ostringstream msg;
msg << "[" << tag << "] Mode '" << mode << "' requires bits "
<< expected_bits << ".";
throw std::invalid_argument(msg.str());
}
if (key_biases.has_value() || value_biases.has_value()) {
throw std::invalid_argument(
"[quantized_scaled_dot_product_attention] Biases are not supported "
"for TurboQuant modes.");
}
} else {
// FP modes have fixed params - verify if user overrode them incorrectly
auto [expected_gs, expected_bits] =
Expand Down Expand Up @@ -972,12 +987,25 @@ array quantized_scaled_dot_product_attention(
"[quantized_scaled_dot_product_attention] Keys and values must be "
"uint32.");
}
if (!is_affine &&
if (!is_affine && !is_turbo &&
(key_scales.dtype() != uint8 || value_scales.dtype() != uint8)) {
throw std::invalid_argument(
"[quantized_scaled_dot_product_attention] Scales must be uint8 for fp "
"quantization.");
}
if (is_turbo) {
auto check_turbo_scale_dtype = [&](const array& s, const char* name) {
if (s.dtype() != float16 && s.dtype() != bfloat16 &&
s.dtype() != float32) {
std::ostringstream msg;
msg << "[" << tag << "] TurboQuant " << name
<< " scales must be float16, bfloat16, or float32.";
throw std::invalid_argument(msg.str());
}
};
check_turbo_scale_dtype(key_scales, "key");
check_turbo_scale_dtype(value_scales, "value");
}

// Compute and validate dimensions
auto key_head_dim = (keys.shape(-1) * 32) / bits;
Expand Down Expand Up @@ -1015,6 +1043,22 @@ array quantized_scaled_dot_product_attention(
<< " must be divisible by group_size " << group_size << ".";
throw std::invalid_argument(msg.str());
}
if (is_turbo) {
if (group_size != queries.shape(-1)) {
std::ostringstream msg;
msg << "[" << tag << "] TurboQuant requires group_size == head_dim,"
<< " got group_size=" << group_size
<< " head_dim=" << queries.shape(-1) << ".";
throw std::invalid_argument(msg.str());
}
auto head_dim = queries.shape(-1);
if (head_dim != 64 && head_dim != 128 && head_dim != 256) {
std::ostringstream msg;
msg << "[" << tag << "] TurboQuant only supports head_dim in {64, 128, 256},"
<< " got " << head_dim << ".";
throw std::invalid_argument(msg.str());
}
}

// Validate scale/bias shapes
auto expected_scale_dim = queries.shape(-1) / group_size;
Expand Down Expand Up @@ -1063,10 +1107,16 @@ array quantized_scaled_dot_product_attention(
has_arr_mask,
has_sinks,
is_affine,
is_turbo,
group_size,
bits,
mode,
s](const std::vector<array>& inputs) {
if (is_turbo) {
throw std::runtime_error(
"[quantized_scaled_dot_product_attention] TurboQuant mode requires "
"a Metal GPU device.");
}
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;

Expand Down Expand Up @@ -1181,12 +1231,17 @@ array quantized_scaled_dot_product_attention(
Shape full_mask_shape{
queries.shape(0), queries.shape(1), queries.shape(2), keys.shape(-2)};

std::vector<array> inputs = {q, keys, key_scales};
// For TurboQuant, scales are per-vector norms stored as floats (not uint8).
// Cast them to final_type so the Metal kernel sees the expected dtype.
auto ks = is_turbo ? astype(key_scales, final_type, stream) : key_scales;
auto vs = is_turbo ? astype(value_scales, final_type, stream) : value_scales;

std::vector<array> inputs = {q, keys, ks};
if (is_affine) {
inputs.push_back(*key_biases);
}
inputs.push_back(values);
inputs.push_back(value_scales);
inputs.push_back(vs);
if (is_affine) {
inputs.push_back(*value_biases);
}
Expand Down
18 changes: 17 additions & 1 deletion mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3425,9 +3425,13 @@ std::string quantization_mode_to_string(QuantizationMode mode) {
case QuantizationMode::Mxfp8:
return "mxfp8";
case QuantizationMode::Nvfp4:
default:
return "nvfp4";
case QuantizationMode::TurboQuant3:
return "turbo3";
case QuantizationMode::TurboQuant4:
return "turbo4";
}
throw std::runtime_error("Unknown quantization mode");
}

QuantizationMode string_to_quantization_mode(
Expand All @@ -3441,6 +3445,10 @@ QuantizationMode string_to_quantization_mode(
return QuantizationMode::Mxfp8;
} else if (mode == "nvfp4") {
return QuantizationMode::Nvfp4;
} else if (mode == "turbo3") {
return QuantizationMode::TurboQuant3;
} else if (mode == "turbo4") {
return QuantizationMode::TurboQuant4;
}
std::string msg;
if (!tag.empty()) {
Expand Down Expand Up @@ -3473,6 +3481,14 @@ std::pair<int, int> quantization_params_from_mode(
default_group_size = 32;
default_bits = 8;
break;
case QuantizationMode::TurboQuant3:
default_group_size = 64;
default_bits = 3;
break;
case QuantizationMode::TurboQuant4:
default_group_size = 64;
default_bits = 4;
break;
}
return {
group_size_.value_or(default_group_size), bits_.value_or(default_bits)};
Expand Down
9 changes: 8 additions & 1 deletion mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,14 @@ class MLX_API UnaryPrimitive : public Primitive {
UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
};

enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 };
// KV / weight quantization formats. String names are mapped in
// quantization_mode_to_string / string_to_quantization_mode
// ("affine", "mxfp4", "mxfp8", "nvfp4", "turbo3", "turbo4").
enum class QuantizationMode {
  Affine,
  Mxfp4,
  Mxfp8,
  Nvfp4,
  TurboQuant3, // "turbo3": codebook-based 3-bit; key/value biases unsupported
  TurboQuant4 // "turbo4": codebook-based 4-bit; key/value biases unsupported
};

std::string quantization_mode_to_string(QuantizationMode mode);
QuantizationMode string_to_quantization_mode(
Expand Down
2 changes: 1 addition & 1 deletion python/src/fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ void init_fast(nb::module_& parent_module) {
sinks (array, optional): An optional array of attention sinks with shape ``[N_q]``.
group_size (int, optional): The group size used in the KV quantization. Defaults follow the quantization ``mode``.
bits (int, optional): The bits used in the KV quantization. Defaults follow the quantization ``mode``.
mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, or ``"affine"``.
mode (str, optional): The quantization mode: ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, ``"affine"``, ``"turbo3"``, or ``"turbo4"``. TurboQuant modes use a WHT rotation + Lloyd-Max codebook and require ``group_size == head_dim``; ``k_biases`` and ``v_biases`` must be ``None``.
causal (bool, optional): Whether to apply lower-right aligned causal masking.
Cannot be used together with ``mask``.
Returns:
Expand Down
Loading