From d0b915b9b1ea50677f94cb1c4529d68b1dee1d50 Mon Sep 17 00:00:00 2001 From: Mykhailo Korobkov Date: Thu, 14 May 2026 10:49:37 +0300 Subject: [PATCH 1/5] fix(gguf): map deepseek_v4/deepseekv4 arch string to DeepSeekV4Arch --- crates/larql-models/src/loading/gguf.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/larql-models/src/loading/gguf.rs b/crates/larql-models/src/loading/gguf.rs index 8903cec3..f7ffc9b1 100644 --- a/crates/larql-models/src/loading/gguf.rs +++ b/crates/larql-models/src/loading/gguf.rs @@ -369,6 +369,7 @@ impl GgufFile { "phi" | "phi2" | "phi3" => "phi", "gpt2" => "gpt2", "deepseek" | "deepseek2" => "deepseek_v2", + "deepseek_v4" | "deepseekv4" => "deepseek_v4", other => other, }; From d93797fe374374e23c0b1d695efdb5db8b0fd22a Mon Sep 17 00:00:00 2001 From: Mykhailo Korobkov Date: Thu, 14 May 2026 11:55:37 +0300 Subject: [PATCH 2/5] feat(gqa): add gqa_attention_asym for MLA-absorbed asymmetric head dims DS-V3 absorbed attention has qk_head_dim=192 (nope=128+rope=64) but v_head_dim=128. The existing gqa_attention uses a single head_dim for all projections, which would corrupt V slicing and output shape. gqa_attention_asym accepts separate qk_head_dim and v_head_dim: - Q/K sliced with qk_head_dim (dot-product stays in the larger space) - V sliced and output written with v_head_dim - Returns (seq, num_q * v_head_dim) When qk_head_dim == v_head_dim the function is numerically identical to gqa_attention (verified by asym_sym_equivalence_when_dims_equal test). 4 tests added: shape, finiteness, sym-equivalence, seq=1 causal. Note: gqa kernels live in larql-compute (post-ADR-0022 Step 2d); this commit places the asym variant alongside the existing gqa_attention there. --- crates/larql-compute/src/attention/gqa.rs | 170 ++++++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/crates/larql-compute/src/attention/gqa.rs b/crates/larql-compute/src/attention/gqa.rs index ae062845..b86d9409 100644 --- a/crates/larql-compute/src/attention/gqa.rs +++ b/crates/larql-compute/src/attention/gqa.rs @@ -254,6 +254,73 @@ fn gqa_attention_capture( (out, weights, all_weights) } +/// GQA with asymmetric Q/K vs V head dimensions — required for MLA-absorbed attention. +/// +/// `qk_head_dim`: head dimension for Q and K (e.g. 192 for DS-V3: nope=128 + rope=64). +/// `v_head_dim`: head dimension for V and the output (e.g. 128 for DS-V3). +/// +/// q: (seq, num_q * qk_head_dim), k: (seq, num_kv * qk_head_dim), v: (seq, num_kv * v_head_dim) +/// Returns: (seq, num_q * v_head_dim) +#[allow(clippy::too_many_arguments)] +pub fn gqa_attention_asym( + q: &Array2, + k: &Array2, + v: &Array2, + num_q: usize, + qk_head_dim: usize, + v_head_dim: usize, + reps: usize, + scale: f64, + seq_len: usize, +) -> Array2 { + let mut out = Array2::::zeros((seq_len, num_q * v_head_dim)); + let scale_f32 = scale as f32; + let mut scores_buf = vec![0.0f32; seq_len]; + + for h in 0..num_q { + let kv_h = h / reps; + let q_off = h * qk_head_dim; + let kv_qk_off = kv_h * qk_head_dim; + let kv_v_off = kv_h * v_head_dim; + let out_off = h * v_head_dim; + + for qi in 0..seq_len { + let causal_len = qi + 1; + let q_row = q.slice(ndarray::s![qi, q_off..q_off + qk_head_dim]); + let k_block = k.slice(ndarray::s![0..causal_len, kv_qk_off..kv_qk_off + qk_head_dim]); + let raw_scores = k_block.dot(&q_row); + + for i in 0..causal_len { + scores_buf[i] = raw_scores[i] * scale_f32; + } + let max_val = scores_buf[..causal_len] + .iter() + .copied() + .fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f64; + for score in scores_buf.iter_mut().take(causal_len) { + let e = ((*score - max_val) as f64).exp(); + *score = e as f32; + sum += e; + } + let inv_sum = (1.0 / sum) as f32; + for score in scores_buf.iter_mut().take(causal_len) { + *score *= inv_sum; + } + + let v_block = v.slice(ndarray::s![0..causal_len, kv_v_off..kv_v_off + v_head_dim]); + for d in 0..v_head_dim { + let mut acc = 0.0f32; + for i in 0..causal_len { + acc += scores_buf[i] * v_block[(i, d)]; + } + out[(qi, out_off + d)] = acc; + } + } + } + out +} + #[cfg(test)] mod tests { use super::*; @@ -657,4 +724,107 @@ mod tests { ); } } + + // ── gqa_attention_asym — MLA-absorbed asymmetric head dims ────────────── + + #[test] + fn asym_output_shape() { + // DS-V3 style: qk_head_dim=6, v_head_dim=4, num_q=2, reps=2 (num_kv=1) + let seq = 3usize; + let qk_hd = 6usize; + let v_hd = 4usize; + let num_q = 2usize; + let reps = 2usize; + let q = small(seq, num_q * qk_hd, 0.01); + let k = small(seq, (num_q / reps) * qk_hd, 0.01); + let v = small(seq, (num_q / reps) * v_hd, 0.01); + let out = gqa_attention_asym( + &q, + &k, + &v, + num_q, + qk_hd, + v_hd, + reps, + 1.0 / (qk_hd as f64).sqrt(), + seq, + ); + assert_eq!( + out.shape(), + &[seq, num_q * v_hd], + "asym output shape should be [seq, num_q * v_head_dim]" + ); + } + + #[test] + fn asym_output_finite() { + let seq = 4usize; + let qk_hd = 8usize; + let v_hd = 6usize; + let num_q = 4usize; + let reps = 2usize; + let num_kv = num_q / reps; + let q = small(seq, num_q * qk_hd, 0.01); + let k = small(seq, num_kv * qk_hd, 0.01); + let v = small(seq, num_kv * v_hd, 0.01); + let out = gqa_attention_asym( + &q, + &k, + &v, + num_q, + qk_hd, + v_hd, + reps, + 1.0 / (qk_hd as f64).sqrt(), + seq, + ); + assert!( + out.iter().all(|x| x.is_finite()), + "asym GQA output has non-finite values" + ); + } + + #[test] + fn asym_sym_equivalence_when_dims_equal() { + // When qk_head_dim == v_head_dim, asym must match sym exactly. + let seq = 3usize; + let hd = 4usize; + let num_q = 2usize; + let reps = 2usize; + let num_kv = num_q / reps; + let q = small(seq, num_q * hd, 0.05); + let k = small(seq, num_kv * hd, 0.05); + let v = small(seq, num_kv * hd, 0.05); + let scale = 1.0 / (hd as f64).sqrt(); + let sym = gqa_attention(&q, &k, &v, num_q, hd, reps, scale, seq); + let asym = gqa_attention_asym(&q, &k, &v, num_q, hd, hd, reps, scale, seq); + for (a, b) in sym.iter().zip(asym.iter()) { + assert!( + (a - b).abs() < 1e-5, + "asym must match sym when dims are equal: {a} vs {b}" + ); + } + } + + #[test] + fn asym_single_token_causal() { + // seq=1 → causal_len=1 everywhere; trivial softmax (weight=1.0) + let seq = 1usize; + let qk_hd = 4usize; + let v_hd = 2usize; + let q = small(seq, qk_hd, 0.1); + let k = small(seq, qk_hd, 0.1); + let v = small(seq, v_hd, 0.1); + let out = + gqa_attention_asym(&q, &k, &v, 1, qk_hd, v_hd, 1, 1.0 / (qk_hd as f64).sqrt(), seq); + // Output must equal V exactly (weight=1 on single token). + let v_row: Vec = v.row(0).to_vec(); + let out_row: Vec = out.row(0).to_vec(); + for (vv, ov) in v_row.iter().zip(out_row.iter()) { + assert!( + (vv - ov).abs() < 1e-5, + "seq=1 output must equal V: {vv} vs {ov}" + ); + } + } } From 9b56cd2ec877feacaf69af3c05f29fb09bbb495c Mon Sep 17 00:00:00 2001 From: Mykhailo Korobkov Date: Thu, 14 May 2026 11:57:48 +0300 Subject: [PATCH 3/5] feat(mla): add qk_nope/rope/v_head_dim fields for DS-V3 MLA absorption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three new optional fields on ModelConfig: qk_nope_head_dim — non-RoPE part of Q/K head dim (DS-V3: 128) qk_rope_head_dim — RoPE-rotated part of Q/K head dim (DS-V3: 64) v_head_dim — V projection head dim (DS-V3: 128) Parsed from config.json (qk_nope_head_dim / qk_rope_head_dim / v_head_dim). Trait accessors added to ModelArchitecture with None defaults. DeepSeekArch overrides to read from config. DS-V3 detection test extended to verify all three fields round-trip. Two GGUF test-only ModelConfig literals updated to include None stubs. --- .../src/architectures/deepseek.rs | 12 +++++++++++ crates/larql-models/src/config.rs | 21 +++++++++++++++++++ crates/larql-models/src/detect/parser.rs | 6 ++++++ crates/larql-models/src/detect/tests.rs | 10 ++++++++- crates/larql-models/src/loading/gguf.rs | 6 ++++++ 5 files changed, 54 insertions(+), 1 deletion(-) diff --git a/crates/larql-models/src/architectures/deepseek.rs b/crates/larql-models/src/architectures/deepseek.rs index 55430e3a..bc6ac8b7 100644 --- a/crates/larql-models/src/architectures/deepseek.rs +++ b/crates/larql-models/src/architectures/deepseek.rs @@ -105,6 +105,18 @@ impl ModelArchitecture for DeepSeekArch { self.config.q_lora_rank.unwrap_or(1536) } + fn mla_qk_nope_head_dim(&self) -> Option { + self.config.qk_nope_head_dim + } + + fn mla_qk_rope_head_dim(&self) -> Option { + self.config.qk_rope_head_dim + } + + fn mla_v_head_dim(&self) -> Option { + self.config.v_head_dim + } + fn mla_kv_a_key(&self, layer: usize) -> Option { Some(format!( "{}self_attn.kv_a_proj_with_mqa.weight", diff --git a/crates/larql-models/src/config.rs b/crates/larql-models/src/config.rs index 3393c4f2..685c5564 100644 --- a/crates/larql-models/src/config.rs +++ b/crates/larql-models/src/config.rs @@ -117,6 +117,12 @@ pub struct ModelConfig { // MLA fields pub kv_lora_rank: Option, pub q_lora_rank: Option, + /// DS-V3 MLA: non-RoPE part of head dim (nope). qk_head_dim = qk_nope_head_dim + qk_rope_head_dim. + pub qk_nope_head_dim: Option, + /// DS-V3 MLA: RoPE part of head dim. + pub qk_rope_head_dim: Option, + /// DS-V3 MLA: V head dim (may differ from qk_nope+rope total). + pub v_head_dim: Option, // RoPE scaling pub rope_scaling: Option, // Softcapping (Gemma2) @@ -791,6 +797,21 @@ pub trait ModelArchitecture: Send + Sync { None } + /// DS-V3 MLA: non-RoPE head dim (nope). Combined qk_head_dim = nope + rope. + fn mla_qk_nope_head_dim(&self) -> Option { + None + } + + /// DS-V3 MLA: RoPE head dim portion. + fn mla_qk_rope_head_dim(&self) -> Option { + None + } + + /// DS-V3 MLA: V head dim (after absorption may differ from qk dims). + fn mla_v_head_dim(&self) -> Option { + None + } + // ── RoPE scaling ── /// RoPE scaling type (None, "linear", "yarn", "dynamic", "llama3"). diff --git a/crates/larql-models/src/detect/parser.rs b/crates/larql-models/src/detect/parser.rs index c8613c36..3a263712 100644 --- a/crates/larql-models/src/detect/parser.rs +++ b/crates/larql-models/src/detect/parser.rs @@ -166,6 +166,9 @@ pub(super) fn parse_model_config(config: &serde_json::Value) -> ModelConfig { // MLA fields let kv_lora_rank = text_config["kv_lora_rank"].as_u64().map(|v| v as usize); let q_lora_rank = text_config["q_lora_rank"].as_u64().map(|v| v as usize); + let qk_nope_head_dim = text_config["qk_nope_head_dim"].as_u64().map(|v| v as usize); + let qk_rope_head_dim = text_config["qk_rope_head_dim"].as_u64().map(|v| v as usize); + let v_head_dim = text_config["v_head_dim"].as_u64().map(|v| v as usize); // RoPE scaling. Four shapes appear in the wild: // @@ -295,6 +298,9 @@ pub(super) fn parse_model_config(config: &serde_json::Value) -> ModelConfig { num_shared_experts, kv_lora_rank, q_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, rope_scaling, attn_logit_softcapping, final_logit_softcapping, diff --git a/crates/larql-models/src/detect/tests.rs b/crates/larql-models/src/detect/tests.rs index fa266547..06f2e12f 100644 --- a/crates/larql-models/src/detect/tests.rs +++ b/crates/larql-models/src/detect/tests.rs @@ -564,7 +564,10 @@ fn test_detect_deepseek_v3() { "num_experts_per_tok": 8, "n_shared_experts": 1, "kv_lora_rank": 512, - "q_lora_rank": 1536 + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128 }); let arch = detect_from_json(&config); @@ -573,6 +576,11 @@ fn test_detect_deepseek_v3() { assert_eq!(arch.num_experts(), 256); assert_eq!(arch.num_experts_per_token(), 8); assert_eq!(arch.num_shared_experts(), 1); + + // MLA geometry fields + assert_eq!(arch.mla_qk_nope_head_dim(), Some(128)); + assert_eq!(arch.mla_qk_rope_head_dim(), Some(64)); + assert_eq!(arch.mla_v_head_dim(), Some(128)); } #[test] diff --git a/crates/larql-models/src/loading/gguf.rs b/crates/larql-models/src/loading/gguf.rs index f7ffc9b1..cfcd124b 100644 --- a/crates/larql-models/src/loading/gguf.rs +++ b/crates/larql-models/src/loading/gguf.rs @@ -1019,6 +1019,9 @@ mod tests { moe_intermediate_size: None, kv_lora_rank: None, q_lora_rank: None, + qk_nope_head_dim: None, + qk_rope_head_dim: None, + v_head_dim: None, rope_scaling: None, attn_logit_softcapping: None, final_logit_softcapping: None, @@ -1158,6 +1161,9 @@ mod tests { moe_intermediate_size: None, kv_lora_rank: None, q_lora_rank: None, + qk_nope_head_dim: None, + qk_rope_head_dim: None, + v_head_dim: None, rope_scaling: None, attn_logit_softcapping: None, final_logit_softcapping: None, From 2a1fc079a2e1da08764e62c8f543811b6be1cc1b Mon Sep 17 00:00:00 2001 From: Mykhailo Korobkov Date: Thu, 14 May 2026 12:05:00 +0300 Subject: [PATCH 4/5] =?UTF-8?q?feat(vindex):=20MLA=20absorption=20?= =?UTF-8?q?=E2=80=94=20fuse=20DS-V3=20low-rank=20Q/K/V=20into=20dense=20we?= =?UTF-8?q?ight=20matrices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements `mla_absorb::absorb()` which converts the four MLA weight matrices (kv_a, kv_b, q_a, q_b) into standard dense Q/K/V tensors compatible with `gqa_attention_asym`. Key correctness points: - rope-K is MQA: single row in kv_a[kv_lora..] replicated num_kv times in absorbed K (not per-head in the input tensor) - DS-V3 native per-head layout [nope|rope] → LARQL convention [rope|nope] applied symmetrically to Q and K during absorption - V: straightforward kv_b[nope+v_hd slice] @ kv_compress Three tests (3 passed): - absorbed_forward_matches_reference: reference MLA forward vs absorbed path through gqa_attention_asym must match within 1e-4 - absorbed_shapes: output tensor dimensions - rope_k_is_broadcast_not_zero: single rope-K correctly replicated across heads --- crates/larql-vindex/Cargo.toml | 1 + .../src/format/weights/mla_absorb.rs | 327 ++++++++++++++++++ crates/larql-vindex/src/format/weights/mod.rs | 1 + 3 files changed, 329 insertions(+) create mode 100644 crates/larql-vindex/src/format/weights/mla_absorb.rs diff --git a/crates/larql-vindex/Cargo.toml b/crates/larql-vindex/Cargo.toml index a5957e67..7823eb19 100644 --- a/crates/larql-vindex/Cargo.toml +++ b/crates/larql-vindex/Cargo.toml @@ -65,6 +65,7 @@ default = [] gpu = ["dep:larql-compute-metal"] [dev-dependencies] +larql-inference = { path = "../larql-inference" } criterion = "0.5" tempfile = "3" # HTTP mocking for the publish trio (lfs / remote / upload). diff --git a/crates/larql-vindex/src/format/weights/mla_absorb.rs b/crates/larql-vindex/src/format/weights/mla_absorb.rs new file mode 100644 index 00000000..c2cb5524 --- /dev/null +++ b/crates/larql-vindex/src/format/weights/mla_absorb.rs @@ -0,0 +1,327 @@ +/// MLA absorption — fuse DS-V3 low-rank attention projections into standard Q/K/V tensors. +/// +/// DS-V3 stores attention as four weight matrices: +/// +/// kv_a shape (kv_lora_rank + qk_rope, hidden) — KV compressor + shared RoPE-K +/// kv_b shape (num_kv*(qk_nope+v_hd), kv_lora) — KV decompressor (K_nope interleaved with V) +/// q_a shape (q_lora, hidden) — Q compressor +/// q_b shape (num_q*qk_head_dim, q_lora) — Q decompressor +/// +/// After absorption the caller obtains three dense tensors in LARQL convention: +/// +/// Q shape (num_q * qk_head_dim, hidden) per-head layout [rope | nope] +/// K shape (num_kv * qk_head_dim, hidden) per-head layout [rope | nope] (rope replicated) +/// V shape (num_kv * v_head_dim, hidden) +/// +/// The absorbed tensors feed directly into `gqa_attention_asym` because they have the +/// asymmetric qk_head_dim / v_head_dim that function expects. +use ndarray::{Array2, ArrayView2, s}; + +pub struct MlaGeometry { + pub num_q: usize, + pub num_kv: usize, + pub qk_nope: usize, + pub qk_rope: usize, + pub v_hd: usize, + pub kv_lora: usize, + pub q_lora: usize, +} + +impl MlaGeometry { + pub fn qk_head_dim(&self) -> usize { + self.qk_nope + self.qk_rope + } +} + +/// Absorb MLA projections into standard dense Q/K/V weight matrices. +/// +/// Returns `(Q, K, V)` with shapes as documented above. +/// +/// # Panics +/// Panics on shape mismatch (programming error, not runtime input error). +pub fn absorb( + kv_a: &Array2, + kv_b: &Array2, + q_a: &Array2, + q_b: &Array2, + g: &MlaGeometry, +) -> (Array2, Array2, Array2) { + let MlaGeometry { + num_q, + num_kv, + qk_nope, + qk_rope, + v_hd, + kv_lora, + q_lora, + } = *g; + let qk_head_dim = qk_nope + qk_rope; + let hidden = kv_a.ncols(); + + // Dimension assertions + assert_eq!( + kv_a.nrows(), + kv_lora + qk_rope, + "kv_a rows = kv_lora + qk_rope (MQA: single rope-K)" + ); + assert_eq!( + kv_b.nrows(), + num_kv * (qk_nope + v_hd), + "kv_b rows = num_kv * (qk_nope + v_hd)" + ); + assert_eq!(kv_b.ncols(), kv_lora); + assert_eq!(q_a.nrows(), q_lora); + assert_eq!(q_a.ncols(), hidden); + assert_eq!(q_b.nrows(), num_q * qk_head_dim); + assert_eq!(q_b.ncols(), q_lora); + + let kv_compress: ArrayView2 = kv_a.slice(s![..kv_lora, ..]); + // MQA: single rope-K row shared across all KV heads + let k_rope_row: ArrayView2 = kv_a.slice(s![kv_lora.., ..]); + assert_eq!(k_rope_row.nrows(), qk_rope); + + // ── Q ────────────────────────────────────────────────────────────────── + // Absorbed Q = q_b @ q_a (shape: num_q*qk_head_dim × hidden) + // DS-V3 native per-head layout: [nope_dims | rope_dims] + // LARQL convention: [rope_dims | nope_dims] — swap within each head + let q_native = q_b.dot(q_a); // (num_q*qk_head_dim, hidden) + let mut q_out = Array2::::zeros((num_q * qk_head_dim, hidden)); + for h in 0..num_q { + let src_base = h * qk_head_dim; + let dst_base = h * qk_head_dim; + // rope part: native[nope..qk_head_dim] → dst[0..qk_rope] + q_out + .slice_mut(s![dst_base..dst_base + qk_rope, ..]) + .assign(&q_native.slice(s![src_base + qk_nope..src_base + qk_head_dim, ..])); + // nope part: native[0..qk_nope] → dst[qk_rope..qk_head_dim] + q_out + .slice_mut(s![dst_base + qk_rope..dst_base + qk_head_dim, ..]) + .assign(&q_native.slice(s![src_base..src_base + qk_nope, ..])); + } + + // ── K ────────────────────────────────────────────────────────────────── + // K_nope[h] = kv_b[h*(nope+v_hd) .. h*(nope+v_hd)+nope, :] @ kv_compress → (qk_nope, hidden) + // K_rope = k_rope_row @ identity → (qk_rope, hidden) shared + // Per head, LARQL layout: [rope_dims | nope_dims] + let k_rope_dense = k_rope_row.dot(&Array2::eye(hidden)); // (qk_rope, hidden) + let mut k_out = Array2::::zeros((num_kv * qk_head_dim, hidden)); + for h in 0..num_kv { + let kv_base = h * (qk_nope + v_hd); + let dst_base = h * qk_head_dim; + // rope first (broadcast single MQA rope-K) + k_out + .slice_mut(s![dst_base..dst_base + qk_rope, ..]) + .assign(&k_rope_dense); + // nope: absorb + let k_nope_h = kv_b + .slice(s![kv_base..kv_base + qk_nope, ..]) + .dot(&kv_compress); + k_out + .slice_mut(s![dst_base + qk_rope..dst_base + qk_head_dim, ..]) + .assign(&k_nope_h); + } + + // ── V ────────────────────────────────────────────────────────────────── + let mut v_out = Array2::::zeros((num_kv * v_hd, hidden)); + for h in 0..num_kv { + let kv_base = h * (qk_nope + v_hd); + let dst_base = h * v_hd; + let v_h = kv_b + .slice(s![kv_base + qk_nope..kv_base + qk_nope + v_hd, ..]) + .dot(&kv_compress); + v_out + .slice_mut(s![dst_base..dst_base + v_hd, ..]) + .assign(&v_h); + } + + (q_out, k_out, v_out) +} + +#[cfg(test)] +mod tests { + use super::*; + use larql_inference::attention::gqa::gqa_attention_asym; + use ndarray::Array2; + + fn randn(rows: usize, cols: usize, seed: u64) -> Array2 { + // Simple deterministic "random" via LCG + let mut state = seed; + let data: Vec = (0..rows * cols) + .map(|_| { + state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + let bits = (state >> 33) as u32; + (bits as f32 / u32::MAX as f32) * 2.0 - 1.0 + }) + .collect(); + Array2::from_shape_vec((rows, cols), data).unwrap() + } + + /// Reference MLA forward pass (matches DS-V3 math, rope-first in output). + /// + /// Returns (q, k, v) projected activations for a single token x (shape 1×hidden). + fn mla_reference_forward( + x: &Array2, + kv_a: &Array2, + kv_b: &Array2, + q_a: &Array2, + q_b: &Array2, + g: &MlaGeometry, + ) -> (Array2, Array2, Array2) { + let MlaGeometry { + num_q, + num_kv, + qk_nope, + qk_rope, + v_hd, + kv_lora, + .. + } = *g; + let qk_head_dim = qk_nope + qk_rope; + let seq = x.nrows(); + let _hidden = x.ncols(); + + // KV latent and shared rope-K + let kv_latent = x.dot(&kv_a.slice(s![..kv_lora, ..]).t()); // (seq, kv_lora) + let k_rope_global = x.dot(&kv_a.slice(s![kv_lora.., ..]).t()); // (seq, qk_rope) + + // Q: compress → decompress → reorder rope-first + let q_latent = x.dot(&q_a.t()); // (seq, q_lora) + let q_native = q_latent.dot(&q_b.t()); // (seq, num_q*qk_head_dim) + let mut q_out = Array2::::zeros((seq, num_q * qk_head_dim)); + for h in 0..num_q { + let src_base = h * qk_head_dim; + let dst_base = h * qk_head_dim; + // rope-first + q_out + .slice_mut(s![.., dst_base..dst_base + qk_rope]) + .assign(&q_native.slice(s![.., src_base + qk_nope..src_base + qk_head_dim])); + q_out + .slice_mut(s![.., dst_base + qk_rope..dst_base + qk_head_dim]) + .assign(&q_native.slice(s![.., src_base..src_base + qk_nope])); + } + + // K: nope absorbed, rope replicated, rope-first + let mut k_out = Array2::::zeros((seq, num_kv * qk_head_dim)); + for h in 0..num_kv { + let kv_base = h * (qk_nope + v_hd); + let dst_base = h * qk_head_dim; + // rope (broadcast single shared K_rope) + k_out + .slice_mut(s![.., dst_base..dst_base + qk_rope]) + .assign(&k_rope_global); + // nope + let k_nope_h = kv_latent + .dot(&kv_b.slice(s![kv_base..kv_base + qk_nope, ..]).t()); + k_out + .slice_mut(s![.., dst_base + qk_rope..dst_base + qk_head_dim]) + .assign(&k_nope_h); + } + + // V + let mut v_out = Array2::::zeros((seq, num_kv * v_hd)); + for h in 0..num_kv { + let kv_base = h * (qk_nope + v_hd); + let dst_base = h * v_hd; + let v_h = kv_latent.dot(&kv_b.slice(s![kv_base + qk_nope..kv_base + qk_nope + v_hd, ..]).t()); + v_out + .slice_mut(s![.., dst_base..dst_base + v_hd]) + .assign(&v_h); + } + + (q_out, k_out, v_out) + } + + fn geometry() -> MlaGeometry { + MlaGeometry { + num_q: 4, + num_kv: 2, + qk_nope: 4, + qk_rope: 2, + v_hd: 4, + kv_lora: 8, + q_lora: 8, + } + } + + fn weights(g: &MlaGeometry) -> (Array2, Array2, Array2, Array2) { + let hidden = 16; + let qk_head_dim = g.qk_head_dim(); + // kv_a: (kv_lora + qk_rope, hidden) — MQA: one shared rope-K + let kv_a = randn(g.kv_lora + g.qk_rope, hidden, 1); + // kv_b: (num_kv*(qk_nope+v_hd), kv_lora) + let kv_b = randn(g.num_kv * (g.qk_nope + g.v_hd), g.kv_lora, 2); + let q_a = randn(g.q_lora, hidden, 3); + let q_b = randn(g.num_q * qk_head_dim, g.q_lora, 4); + (kv_a, kv_b, q_a, q_b) + } + + #[test] + fn absorbed_forward_matches_reference() { + let g = geometry(); + let (kv_a, kv_b, q_a, q_b) = weights(&g); + let hidden = 16usize; + let seq = 3usize; + + // Compute absorbed weight matrices + let (w_q, w_k, w_v) = absorb(&kv_a, &kv_b, &q_a, &q_b, &g); + + // Random input sequence + let x = randn(seq, hidden, 99); + + // Reference path: project each token through MLA, then run gqa_attention_asym + let (q_ref, k_ref, v_ref) = mla_reference_forward(&x, &kv_a, &kv_b, &q_a, &q_b, &g); + let qk_head_dim = g.qk_head_dim(); + let reps = g.num_q / g.num_kv; + let scale = 1.0 / (qk_head_dim as f64).sqrt(); + let ref_out = gqa_attention_asym(&q_ref, &k_ref, &v_ref, g.num_q, qk_head_dim, g.v_hd, reps, scale, seq); + + // Absorbed path: project through absorbed weight matrices, then run gqa_attention_asym + let q_abs = x.dot(&w_q.t()); + let k_abs = x.dot(&w_k.t()); + let v_abs = x.dot(&w_v.t()); + let abs_out = gqa_attention_asym(&q_abs, &k_abs, &v_abs, g.num_q, qk_head_dim, g.v_hd, reps, scale, seq); + + // Must match numerically (within float precision) + let max_diff = ref_out + .iter() + .zip(abs_out.iter()) + .map(|(&a, &b)| (a - b).abs()) + .fold(0.0f32, f32::max); + assert!( + max_diff < 1e-4, + "absorbed forward must match reference, max_diff={max_diff}" + ); + } + + #[test] + fn absorbed_shapes() { + let g = geometry(); + let (kv_a, kv_b, q_a, q_b) = weights(&g); + let hidden = 16usize; + let qk_head_dim = g.qk_head_dim(); + let (w_q, w_k, w_v) = absorb(&kv_a, &kv_b, &q_a, &q_b, &g); + assert_eq!(w_q.shape(), &[g.num_q * qk_head_dim, hidden]); + assert_eq!(w_k.shape(), &[g.num_kv * qk_head_dim, hidden]); + assert_eq!(w_v.shape(), &[g.num_kv * g.v_hd, hidden]); + } + + #[test] + fn rope_k_is_broadcast_not_zero() { + // The absorbed K rope section for each KV head must be non-zero + // and identical across heads (proving the broadcast replicated correctly). + let g = geometry(); + let (kv_a, kv_b, q_a, q_b) = weights(&g); + let (_, w_k, _) = absorb(&kv_a, &kv_b, &q_a, &q_b, &g); + let qk_head_dim = g.qk_head_dim(); + let head0_rope: Vec = w_k.slice(s![..g.qk_rope, ..]).iter().copied().collect(); + let head1_rope: Vec = w_k + .slice(s![qk_head_dim..qk_head_dim + g.qk_rope, ..]) + .iter() + .copied() + .collect(); + assert!(head0_rope.iter().any(|v| v.abs() > 1e-6), "rope-K must be non-zero"); + for (a, b) in head0_rope.iter().zip(head1_rope.iter()) { + assert!((a - b).abs() < 1e-6, "rope-K must be identical across heads: {a} vs {b}"); + } + } +} diff --git a/crates/larql-vindex/src/format/weights/mod.rs b/crates/larql-vindex/src/format/weights/mod.rs index d2c9be18..41299490 100644 --- a/crates/larql-vindex/src/format/weights/mod.rs +++ b/crates/larql-vindex/src/format/weights/mod.rs @@ -18,6 +18,7 @@ mod capabilities; pub mod load; pub mod manifest; +pub mod mla_absorb; mod ple_sidecar; pub mod write_f32; pub mod write_kquant; From 2d10daa8213d99ecb8105faa8120c37e01b62664 Mon Sep 17 00:00:00 2001 From: Mykhailo Korobkov Date: Thu, 14 May 2026 12:08:49 +0300 Subject: [PATCH 5/5] feat(vindex): wire MLA absorption into f32 weight writer write_model_weights_with_opts now accepts DS-V3 / MLA architectures when all three geometry fields (qk_nope_head_dim, qk_rope_head_dim, v_head_dim) are present in config.json. When detected: - skips the standard-attention guard - per layer: fetches kv_a/kv_b/q_a/q_b projections, calls mla_absorb::absorb, writes the resulting dense Q/K/V under the standard attn_q/k/v key names - O projection is passed through unchanged (no absorption needed) The loader remains MLA-unaware: it reads standard Q/K/V tensors just as for any Llama/Mistral model. The extra storage cost (absorbed K replicates the MQA rope-K row num_kv times) is acceptable for DS-V3 full scale (~3.5 GB extra per 61 layers on num_kv=128). All 971 larql-vindex unit + integration tests pass. --- .github/workflows/larql-vindex.yml | 9 ++ Cargo.lock | 1 + crates/larql-compute/src/attention/gqa.rs | 18 ++- .../larql-models/src/architectures/gemma3.rs | 3 + .../tests/test_expert_endpoint.rs | 3 + .../src/format/weights/mla_absorb.rs | 49 +++++-- .../src/format/weights/write_f32.rs | 133 ++++++++++++++++-- 7 files changed, 195 insertions(+), 21 deletions(-) diff --git a/.github/workflows/larql-vindex.yml b/.github/workflows/larql-vindex.yml index 3c5fd2e7..0b105105 100644 --- a/.github/workflows/larql-vindex.yml +++ b/.github/workflows/larql-vindex.yml @@ -67,6 +67,15 @@ jobs: "VCPKG_ROOT=$vcpkgRoot" | Out-File -FilePath $env:GITHUB_ENV -Append & "$vcpkgRoot\vcpkg.exe" install openblas:x64-windows + # protoc on PATH satisfies the `cfg(not(windows))` skip in + # `larql-router-protocol`'s build.rs. Required by larql-vindex's + # dev-dep on `larql-inference` → transitive on `larql-router-protocol` + # (added in #96 for MLA absorption integration tests). + - name: Install protoc (Windows) + if: runner.os == 'Windows' + shell: pwsh + run: choco install protoc -y --no-progress + - name: Cache cargo registry + build artefacts uses: actions/cache@v5 with: diff --git a/Cargo.lock b/Cargo.lock index 5d90e7c7..0bb7279b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2397,6 +2397,7 @@ dependencies = [ "larql-compute", "larql-compute-metal", "larql-core", + "larql-inference", "larql-models", "larql-vindex-spec", "libc", diff --git a/crates/larql-compute/src/attention/gqa.rs b/crates/larql-compute/src/attention/gqa.rs index b86d9409..2d538f6f 100644 --- a/crates/larql-compute/src/attention/gqa.rs +++ b/crates/larql-compute/src/attention/gqa.rs @@ -287,7 +287,10 @@ pub fn gqa_attention_asym( for qi in 0..seq_len { let causal_len = qi + 1; let q_row = q.slice(ndarray::s![qi, q_off..q_off + qk_head_dim]); - let k_block = k.slice(ndarray::s![0..causal_len, kv_qk_off..kv_qk_off + qk_head_dim]); + let k_block = k.slice(ndarray::s![ + 0..causal_len, + kv_qk_off..kv_qk_off + qk_head_dim + ]); let raw_scores = k_block.dot(&q_row); for i in 0..causal_len { @@ -815,8 +818,17 @@ mod tests { let q = small(seq, qk_hd, 0.1); let k = small(seq, qk_hd, 0.1); let v = small(seq, v_hd, 0.1); - let out = - gqa_attention_asym(&q, &k, &v, 1, qk_hd, v_hd, 1, 1.0 / (qk_hd as f64).sqrt(), seq); + let out = gqa_attention_asym( + &q, + &k, + &v, + 1, + qk_hd, + v_hd, + 1, + 1.0 / (qk_hd as f64).sqrt(), + seq, + ); // Output must equal V exactly (weight=1 on single token). let v_row: Vec = v.row(0).to_vec(); let out_row: Vec = out.row(0).to_vec(); diff --git a/crates/larql-models/src/architectures/gemma3.rs b/crates/larql-models/src/architectures/gemma3.rs index 945a0884..1fc75431 100644 --- a/crates/larql-models/src/architectures/gemma3.rs +++ b/crates/larql-models/src/architectures/gemma3.rs @@ -146,6 +146,9 @@ mod tests { moe_intermediate_size: None, kv_lora_rank: None, q_lora_rank: None, + qk_nope_head_dim: None, + qk_rope_head_dim: None, + v_head_dim: None, rope_scaling, attn_logit_softcapping: None, final_logit_softcapping: None, diff --git a/crates/larql-server/tests/test_expert_endpoint.rs b/crates/larql-server/tests/test_expert_endpoint.rs index 78c894ae..8bc31a74 100644 --- a/crates/larql-server/tests/test_expert_endpoint.rs +++ b/crates/larql-server/tests/test_expert_endpoint.rs @@ -78,6 +78,9 @@ impl TestMoeArch { moe_intermediate_size: Some(INTER), kv_lora_rank: None, q_lora_rank: None, + qk_nope_head_dim: None, + qk_rope_head_dim: None, + v_head_dim: None, rope_scaling: None, attn_logit_softcapping: None, final_logit_softcapping: None, diff --git a/crates/larql-vindex/src/format/weights/mla_absorb.rs b/crates/larql-vindex/src/format/weights/mla_absorb.rs index c2cb5524..7cb7fbc8 100644 --- a/crates/larql-vindex/src/format/weights/mla_absorb.rs +++ b/crates/larql-vindex/src/format/weights/mla_absorb.rs @@ -15,7 +15,7 @@ /// /// The absorbed tensors feed directly into `gqa_attention_asym` because they have the /// asymmetric qk_head_dim / v_head_dim that function expects. -use ndarray::{Array2, ArrayView2, s}; +use ndarray::{s, Array2, ArrayView2}; pub struct MlaGeometry { pub num_q: usize, @@ -148,7 +148,9 @@ mod tests { let mut state = seed; let data: Vec = (0..rows * cols) .map(|_| { - state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); let bits = (state >> 33) as u32; (bits as f32 / u32::MAX as f32) * 2.0 - 1.0 }) @@ -210,8 +212,7 @@ mod tests { .slice_mut(s![.., dst_base..dst_base + qk_rope]) .assign(&k_rope_global); // nope - let k_nope_h = kv_latent - .dot(&kv_b.slice(s![kv_base..kv_base + qk_nope, ..]).t()); + let k_nope_h = kv_latent.dot(&kv_b.slice(s![kv_base..kv_base + qk_nope, ..]).t()); k_out .slice_mut(s![.., dst_base + qk_rope..dst_base + qk_head_dim]) .assign(&k_nope_h); @@ -222,7 +223,11 @@ mod tests { for h in 0..num_kv { let kv_base = h * (qk_nope + v_hd); let dst_base = h * v_hd; - let v_h = kv_latent.dot(&kv_b.slice(s![kv_base + qk_nope..kv_base + qk_nope + v_hd, ..]).t()); + let v_h = kv_latent.dot( + &kv_b + .slice(s![kv_base + qk_nope..kv_base + qk_nope + v_hd, ..]) + .t(), + ); v_out .slice_mut(s![.., dst_base..dst_base + v_hd]) .assign(&v_h); @@ -273,13 +278,33 @@ mod tests { let qk_head_dim = g.qk_head_dim(); let reps = g.num_q / g.num_kv; let scale = 1.0 / (qk_head_dim as f64).sqrt(); - let ref_out = gqa_attention_asym(&q_ref, &k_ref, &v_ref, g.num_q, qk_head_dim, g.v_hd, reps, scale, seq); + let ref_out = gqa_attention_asym( + &q_ref, + &k_ref, + &v_ref, + g.num_q, + qk_head_dim, + g.v_hd, + reps, + scale, + seq, + ); // Absorbed path: project through absorbed weight matrices, then run gqa_attention_asym let q_abs = x.dot(&w_q.t()); let k_abs = x.dot(&w_k.t()); let v_abs = x.dot(&w_v.t()); - let abs_out = gqa_attention_asym(&q_abs, &k_abs, &v_abs, g.num_q, qk_head_dim, g.v_hd, reps, scale, seq); + let abs_out = gqa_attention_asym( + &q_abs, + &k_abs, + &v_abs, + g.num_q, + qk_head_dim, + g.v_hd, + reps, + scale, + seq, + ); // Must match numerically (within float precision) let max_diff = ref_out @@ -319,9 +344,15 @@ mod tests { .iter() .copied() .collect(); - assert!(head0_rope.iter().any(|v| v.abs() > 1e-6), "rope-K must be non-zero"); + assert!( + head0_rope.iter().any(|v| v.abs() > 1e-6), + "rope-K must be non-zero" + ); for (a, b) in head0_rope.iter().zip(head1_rope.iter()) { - assert!((a - b).abs() < 1e-6, "rope-K must be identical across heads: {a} vs {b}"); + assert!( + (a - b).abs() < 1e-6, + "rope-K must be identical across heads: {a} vs {b}" + ); } } } diff --git a/crates/larql-vindex/src/format/weights/write_f32.rs b/crates/larql-vindex/src/format/weights/write_f32.rs index ab741a82..2dc112e9 100644 --- a/crates/larql-vindex/src/format/weights/write_f32.rs +++ b/crates/larql-vindex/src/format/weights/write_f32.rs @@ -25,6 +25,7 @@ use crate::format::filenames::*; use crate::format::load::load_vindex_config; use super::capabilities::{ensure_standard_attention_supported, SURFACE_F32_WEIGHT_WRITER}; +use super::mla_absorb::{self, MlaGeometry}; use larql_models::ModelWeights; /// Manifest `kind` discriminators — wire-format strings written into @@ -283,7 +284,32 @@ pub fn write_model_weights_with_opts( .unwrap_or(crate::config::dtype::StorageDtype::F32); let arch = source.arch(); - ensure_standard_attention_supported(arch, SURFACE_F32_WEIGHT_WRITER)?; + // MLA absorption: if the architecture uses MLA and all geometry dims are + // present, we absorb Q/K/V in-flight and write standard dense tensors. + // Otherwise fall through to the standard guard (which rejects MLA). + let mla_geom: Option = if arch.uses_mla() { + match ( + arch.mla_qk_nope_head_dim(), + arch.mla_qk_rope_head_dim(), + arch.mla_v_head_dim(), + ) { + (Some(qk_nope), Some(qk_rope), Some(v_hd)) => Some(MlaGeometry { + num_q: arch.config().num_q_heads, + num_kv: arch.config().num_kv_heads, + qk_nope, + qk_rope, + v_hd, + kv_lora: arch.kv_lora_rank(), + q_lora: arch.q_lora_rank(), + }), + _ => None, + } + } else { + None + }; + if mla_geom.is_none() { + ensure_standard_attention_supported(arch, SURFACE_F32_WEIGHT_WRITER)?; + } let num_layers = source.num_layers(); let mut entries: Vec = Vec::new(); @@ -299,16 +325,84 @@ pub fn write_model_weights_with_opts( for layer in 0..num_layers { callbacks.on_layer_start(COMP_ATTN_WEIGHTS, layer, num_layers); - for key in &[ - arch.attn_q_key(layer), - arch.attn_k_key(layer), - arch.attn_v_key(layer), - arch.attn_o_key(layer), - ] { - if let Some((data, rows, cols)) = source.get_tensor(key) { + + if let Some(ref g) = mla_geom { + // MLA absorption path: fetch the four low-rank projections, + // absorb into dense Q/K/V, write under the standard key names + // so the loader needs no MLA awareness. + let hidden = arch.config().hidden_size; + let qk_hd = g.qk_nope + g.qk_rope; + let kv_a_key = arch.mla_kv_a_key(layer).unwrap_or_default(); + let kv_b_key = arch.mla_kv_b_key(layer).unwrap_or_default(); + let q_a_key = arch.mla_q_a_key(layer).unwrap_or_default(); + let q_b_key = arch.mla_q_b_key(layer).unwrap_or_default(); + + let kv_a_raw = source.get_tensor(&kv_a_key); + let kv_b_raw = source.get_tensor(&kv_b_key); + let q_a_raw = source.get_tensor(&q_a_key); + let q_b_raw = source.get_tensor(&q_b_key); + + if let ( + Some((kv_a_d, _, _)), + Some((kv_b_d, _, _)), + Some((q_a_d, _, _)), + Some((q_b_d, _, _)), + ) = (kv_a_raw, kv_b_raw, q_a_raw, q_b_raw) + { + use ndarray::Array2; + let kv_a = Array2::from_shape_vec((g.kv_lora + g.qk_rope, hidden), kv_a_d) + .expect("kv_a shape mismatch"); + let kv_b = Array2::from_shape_vec( + (g.num_kv * (g.qk_nope + g.v_hd), g.kv_lora), + kv_b_d, + ) + .expect("kv_b shape mismatch"); + let q_a = Array2::from_shape_vec((g.q_lora, hidden), q_a_d) + .expect("q_a shape mismatch"); + let q_b = Array2::from_shape_vec((g.num_q * qk_hd, g.q_lora), q_b_d) + .expect("q_b shape mismatch"); + + let (w_q, w_k, w_v) = mla_absorb::absorb(&kv_a, &kv_b, &q_a, &q_b, g); + + for (tensor, key, rows, cols) in [ + ( + w_q.into_raw_vec_and_offset().0, + arch.attn_q_key(layer), + g.num_q * qk_hd, + hidden, + ), + ( + w_k.into_raw_vec_and_offset().0, + arch.attn_k_key(layer), + g.num_kv * qk_hd, + hidden, + ), + ( + w_v.into_raw_vec_and_offset().0, + arch.attn_v_key(layer), + g.num_kv * g.v_hd, + hidden, + ), + ] { + let len = write_floats(&mut attn_file, &tensor, dtype)?; + entries.push(WeightEntry { + key, + kind: kind::TENSOR.into(), + shape: vec![rows, cols], + offset: attn_offset, + length: len, + file: ATTN_WEIGHTS_BIN.into(), + }); + attn_offset += len; + } + } + + // O projection is a standard linear — no absorption needed + let o_key = arch.attn_o_key(layer); + if let Some((data, rows, cols)) = source.get_tensor(&o_key) { let len = write_floats(&mut attn_file, &data, dtype)?; entries.push(WeightEntry { - key: key.clone(), + key: o_key, kind: kind::TENSOR.into(), shape: vec![rows, cols], offset: attn_offset, @@ -317,6 +411,27 @@ pub fn write_model_weights_with_opts( }); attn_offset += len; } + } else { + // Standard Q/K/V/O path + for key in &[ + arch.attn_q_key(layer), + arch.attn_k_key(layer), + arch.attn_v_key(layer), + arch.attn_o_key(layer), + ] { + if let Some((data, rows, cols)) = source.get_tensor(key) { + let len = write_floats(&mut attn_file, &data, dtype)?; + entries.push(WeightEntry { + key: key.clone(), + kind: kind::TENSOR.into(), + shape: vec![rows, cols], + offset: attn_offset, + length: len, + file: ATTN_WEIGHTS_BIN.into(), + }); + attn_offset += len; + } + } } // QK norms (1D vectors, stored alongside attention)