Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/larql-vindex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ jobs:
"VCPKG_ROOT=$vcpkgRoot" | Out-File -FilePath $env:GITHUB_ENV -Append
& "$vcpkgRoot\vcpkg.exe" install openblas:x64-windows

# protoc on PATH satisfies the `cfg(not(windows))` skip in
# `larql-router-protocol`'s build.rs. Required by larql-vindex's
# dev-dep on `larql-inference` → transitive on `larql-router-protocol`
# (added in #96 for MLA absorption integration tests).
- name: Install protoc (Windows)
if: runner.os == 'Windows'
shell: pwsh
run: choco install protoc -y --no-progress

- name: Cache cargo registry + build artefacts
uses: actions/cache@v5
with:
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

182 changes: 182 additions & 0 deletions crates/larql-compute/src/attention/gqa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,76 @@ fn gqa_attention_capture(
(out, weights, all_weights)
}

/// GQA with asymmetric Q/K vs V head dimensions — required for MLA-absorbed attention.
///
/// `qk_head_dim`: head dimension for Q and K (e.g. 192 for DS-V3: nope=128 + rope=64).
/// `v_head_dim`: head dimension for V and the output (e.g. 128 for DS-V3).
///
/// q: (seq, num_q * qk_head_dim), k: (seq, num_kv * qk_head_dim), v: (seq, num_kv * v_head_dim)
/// Returns: (seq, num_q * v_head_dim)
#[allow(clippy::too_many_arguments)]
pub fn gqa_attention_asym(
q: &Array2<f32>,
k: &Array2<f32>,
v: &Array2<f32>,
num_q: usize,
qk_head_dim: usize,
v_head_dim: usize,
reps: usize,
scale: f64,
seq_len: usize,
) -> Array2<f32> {
let mut out = Array2::<f32>::zeros((seq_len, num_q * v_head_dim));
let scale_f32 = scale as f32;
let mut scores_buf = vec![0.0f32; seq_len];

for h in 0..num_q {
let kv_h = h / reps;
let q_off = h * qk_head_dim;
let kv_qk_off = kv_h * qk_head_dim;
let kv_v_off = kv_h * v_head_dim;
let out_off = h * v_head_dim;

for qi in 0..seq_len {
let causal_len = qi + 1;
let q_row = q.slice(ndarray::s![qi, q_off..q_off + qk_head_dim]);
let k_block = k.slice(ndarray::s![
0..causal_len,
kv_qk_off..kv_qk_off + qk_head_dim
]);
let raw_scores = k_block.dot(&q_row);

for i in 0..causal_len {
scores_buf[i] = raw_scores[i] * scale_f32;
}
let max_val = scores_buf[..causal_len]
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f64;
for score in scores_buf.iter_mut().take(causal_len) {
let e = ((*score - max_val) as f64).exp();
*score = e as f32;
sum += e;
}
let inv_sum = (1.0 / sum) as f32;
for score in scores_buf.iter_mut().take(causal_len) {
*score *= inv_sum;
}

let v_block = v.slice(ndarray::s![0..causal_len, kv_v_off..kv_v_off + v_head_dim]);
for d in 0..v_head_dim {
let mut acc = 0.0f32;
for i in 0..causal_len {
acc += scores_buf[i] * v_block[(i, d)];
}
out[(qi, out_off + d)] = acc;
}
}
}
out
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -657,4 +727,116 @@ mod tests {
);
}
}

// ── gqa_attention_asym — MLA-absorbed asymmetric head dims ──────────────

#[test]
fn asym_output_shape() {
// DS-V3 style: qk_head_dim=6, v_head_dim=4, num_q=2, reps=2 (num_kv=1)
let seq = 3usize;
let qk_hd = 6usize;
let v_hd = 4usize;
let num_q = 2usize;
let reps = 2usize;
let q = small(seq, num_q * qk_hd, 0.01);
let k = small(seq, (num_q / reps) * qk_hd, 0.01);
let v = small(seq, (num_q / reps) * v_hd, 0.01);
let out = gqa_attention_asym(
&q,
&k,
&v,
num_q,
qk_hd,
v_hd,
reps,
1.0 / (qk_hd as f64).sqrt(),
seq,
);
assert_eq!(
out.shape(),
&[seq, num_q * v_hd],
"asym output shape should be [seq, num_q * v_head_dim]"
);
}

#[test]
fn asym_output_finite() {
let seq = 4usize;
let qk_hd = 8usize;
let v_hd = 6usize;
let num_q = 4usize;
let reps = 2usize;
let num_kv = num_q / reps;
let q = small(seq, num_q * qk_hd, 0.01);
let k = small(seq, num_kv * qk_hd, 0.01);
let v = small(seq, num_kv * v_hd, 0.01);
let out = gqa_attention_asym(
&q,
&k,
&v,
num_q,
qk_hd,
v_hd,
reps,
1.0 / (qk_hd as f64).sqrt(),
seq,
);
assert!(
out.iter().all(|x| x.is_finite()),
"asym GQA output has non-finite values"
);
}

#[test]
fn asym_sym_equivalence_when_dims_equal() {
// When qk_head_dim == v_head_dim, asym must match sym exactly.
let seq = 3usize;
let hd = 4usize;
let num_q = 2usize;
let reps = 2usize;
let num_kv = num_q / reps;
let q = small(seq, num_q * hd, 0.05);
let k = small(seq, num_kv * hd, 0.05);
let v = small(seq, num_kv * hd, 0.05);
let scale = 1.0 / (hd as f64).sqrt();
let sym = gqa_attention(&q, &k, &v, num_q, hd, reps, scale, seq);
let asym = gqa_attention_asym(&q, &k, &v, num_q, hd, hd, reps, scale, seq);
for (a, b) in sym.iter().zip(asym.iter()) {
assert!(
(a - b).abs() < 1e-5,
"asym must match sym when dims are equal: {a} vs {b}"
);
}
}

#[test]
fn asym_single_token_causal() {
// seq=1 → causal_len=1 everywhere; trivial softmax (weight=1.0)
let seq = 1usize;
let qk_hd = 4usize;
let v_hd = 2usize;
let q = small(seq, qk_hd, 0.1);
let k = small(seq, qk_hd, 0.1);
let v = small(seq, v_hd, 0.1);
let out = gqa_attention_asym(
&q,
&k,
&v,
1,
qk_hd,
v_hd,
1,
1.0 / (qk_hd as f64).sqrt(),
seq,
);
// Output must equal V exactly (weight=1 on single token).
let v_row: Vec<f32> = v.row(0).to_vec();
let out_row: Vec<f32> = out.row(0).to_vec();
for (vv, ov) in v_row.iter().zip(out_row.iter()) {
assert!(
(vv - ov).abs() < 1e-5,
"seq=1 output must equal V: {vv} vs {ov}"
);
}
}
}
12 changes: 12 additions & 0 deletions crates/larql-models/src/architectures/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ impl ModelArchitecture for DeepSeekArch {
self.config.q_lora_rank.unwrap_or(1536)
}

fn mla_qk_nope_head_dim(&self) -> Option<usize> {
self.config.qk_nope_head_dim
}

fn mla_qk_rope_head_dim(&self) -> Option<usize> {
self.config.qk_rope_head_dim
}

fn mla_v_head_dim(&self) -> Option<usize> {
self.config.v_head_dim
}

fn mla_kv_a_key(&self, layer: usize) -> Option<String> {
Some(format!(
"{}self_attn.kv_a_proj_with_mqa.weight",
Expand Down
3 changes: 3 additions & 0 deletions crates/larql-models/src/architectures/gemma3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ mod tests {
moe_intermediate_size: None,
kv_lora_rank: None,
q_lora_rank: None,
qk_nope_head_dim: None,
qk_rope_head_dim: None,
v_head_dim: None,
rope_scaling,
attn_logit_softcapping: None,
final_logit_softcapping: None,
Expand Down
21 changes: 21 additions & 0 deletions crates/larql-models/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ pub struct ModelConfig {
// MLA fields
pub kv_lora_rank: Option<usize>,
pub q_lora_rank: Option<usize>,
/// DS-V3 MLA: non-RoPE part of head dim (nope). qk_head_dim = qk_nope_head_dim + qk_rope_head_dim.
pub qk_nope_head_dim: Option<usize>,
/// DS-V3 MLA: RoPE part of head dim.
pub qk_rope_head_dim: Option<usize>,
/// DS-V3 MLA: V head dim (may differ from qk_nope+rope total).
pub v_head_dim: Option<usize>,
// RoPE scaling
pub rope_scaling: Option<RopeScaling>,
// Softcapping (Gemma2)
Expand Down Expand Up @@ -791,6 +797,21 @@ pub trait ModelArchitecture: Send + Sync {
None
}

/// DS-V3 MLA: non-RoPE head dim (nope). Combined qk_head_dim = nope + rope.
fn mla_qk_nope_head_dim(&self) -> Option<usize> {
None
}

/// DS-V3 MLA: RoPE head dim portion.
fn mla_qk_rope_head_dim(&self) -> Option<usize> {
None
}

/// DS-V3 MLA: V head dim (after absorption may differ from qk dims).
fn mla_v_head_dim(&self) -> Option<usize> {
None
}

// ── RoPE scaling ──

/// RoPE scaling type (None, "linear", "yarn", "dynamic", "llama3").
Expand Down
6 changes: 6 additions & 0 deletions crates/larql-models/src/detect/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ pub(super) fn parse_model_config(config: &serde_json::Value) -> ModelConfig {
// MLA fields
let kv_lora_rank = text_config["kv_lora_rank"].as_u64().map(|v| v as usize);
let q_lora_rank = text_config["q_lora_rank"].as_u64().map(|v| v as usize);
let qk_nope_head_dim = text_config["qk_nope_head_dim"].as_u64().map(|v| v as usize);
let qk_rope_head_dim = text_config["qk_rope_head_dim"].as_u64().map(|v| v as usize);
let v_head_dim = text_config["v_head_dim"].as_u64().map(|v| v as usize);

// RoPE scaling. Four shapes appear in the wild:
//
Expand Down Expand Up @@ -295,6 +298,9 @@ pub(super) fn parse_model_config(config: &serde_json::Value) -> ModelConfig {
num_shared_experts,
kv_lora_rank,
q_lora_rank,
qk_nope_head_dim,
qk_rope_head_dim,
v_head_dim,
rope_scaling,
attn_logit_softcapping,
final_logit_softcapping,
Expand Down
10 changes: 9 additions & 1 deletion crates/larql-models/src/detect/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,10 @@ fn test_detect_deepseek_v3() {
"num_experts_per_tok": 8,
"n_shared_experts": 1,
"kv_lora_rank": 512,
"q_lora_rank": 1536
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128
});

let arch = detect_from_json(&config);
Expand All @@ -573,6 +576,11 @@ fn test_detect_deepseek_v3() {
assert_eq!(arch.num_experts(), 256);
assert_eq!(arch.num_experts_per_token(), 8);
assert_eq!(arch.num_shared_experts(), 1);

// MLA geometry fields
assert_eq!(arch.mla_qk_nope_head_dim(), Some(128));
assert_eq!(arch.mla_qk_rope_head_dim(), Some(64));
assert_eq!(arch.mla_v_head_dim(), Some(128));
}

#[test]
Expand Down
7 changes: 7 additions & 0 deletions crates/larql-models/src/loading/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ impl GgufFile {
"phi" | "phi2" | "phi3" => "phi",
"gpt2" => "gpt2",
"deepseek" | "deepseek2" => "deepseek_v2",
"deepseek_v4" | "deepseekv4" => "deepseek_v4",
other => other,
};

Expand Down Expand Up @@ -1018,6 +1019,9 @@ mod tests {
moe_intermediate_size: None,
kv_lora_rank: None,
q_lora_rank: None,
qk_nope_head_dim: None,
qk_rope_head_dim: None,
v_head_dim: None,
rope_scaling: None,
attn_logit_softcapping: None,
final_logit_softcapping: None,
Expand Down Expand Up @@ -1157,6 +1161,9 @@ mod tests {
moe_intermediate_size: None,
kv_lora_rank: None,
q_lora_rank: None,
qk_nope_head_dim: None,
qk_rope_head_dim: None,
v_head_dim: None,
rope_scaling: None,
attn_logit_softcapping: None,
final_logit_softcapping: None,
Expand Down
3 changes: 3 additions & 0 deletions crates/larql-server/tests/test_expert_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ impl TestMoeArch {
moe_intermediate_size: Some(INTER),
kv_lora_rank: None,
q_lora_rank: None,
qk_nope_head_dim: None,
qk_rope_head_dim: None,
v_head_dim: None,
rope_scaling: None,
attn_logit_softcapping: None,
final_logit_softcapping: None,
Expand Down
1 change: 1 addition & 0 deletions crates/larql-vindex/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ default = []
gpu = ["dep:larql-compute-metal"]

[dev-dependencies]
larql-inference = { path = "../larql-inference" }
criterion = "0.5"
tempfile = "3"
# HTTP mocking for the publish trio (lfs / remote / upload).
Expand Down
Loading
Loading