diff --git a/crates/larql-inference/src/forward/ple.rs b/crates/larql-inference/src/forward/ple.rs index 97782277a..dd2e59c53 100644 --- a/crates/larql-inference/src/forward/ple.rs +++ b/crates/larql-inference/src/forward/ple.rs @@ -33,15 +33,21 @@ pub fn precompute_per_layer_inputs( let seq_len = token_ids.len(); let hidden = weights.hidden_size; - // Stream 1: model projection from main embeddings - let model_proj_key = match arch.per_layer_model_projection_key() { - Some(k) => k, - None => return Vec::new(), - }; - let w_model_proj = match weights.tensors.get(&model_proj_key) { - Some(w) => w, - None => return Vec::new(), - }; + // Stream 1: model projection from main embeddings. + // `arch.has_per_layer_embeddings()` is true here, so the key must + // exist and the loader must have populated it — both invariants are + // enforced at vindex-load time by `load_model_weights*` (see + // chrishayuk/larql#49). A panic here means the vindex was loaded + // through a path that skipped validation, which is a bug. + let model_proj_key = arch + .per_layer_model_projection_key() + .expect("PLE arch must expose per_layer_model_projection_key"); + let w_model_proj = weights.tensors.get(&model_proj_key).unwrap_or_else(|| { + panic!( + "PLE tensor `{model_proj_key}` missing from weights — rebuild this vindex \ + (chrishayuk/larql#49)" + ) + }); let projected = dot_proj(main_embeds, w_model_proj); let model_proj_scale = (hidden as f32).powf(-0.5); @@ -120,27 +126,35 @@ pub(crate) fn apply_per_layer_embedding( per_layer_input: Option<&Array2>, ) -> Array2 { let arch = &*weights.arch; + // `per_layer_input == None` is the only legitimate "skip PLE for + // this layer" path — it's how callers without precomputed PLE + // inputs opt out (e.g. unit tests, non-PLE archs that route + // through here defensively). The tensor-missing branches that + // used to live here have been collapsed into expects because the + // loader now refuses to load a PLE arch without them (#49). 
let per_layer_input = match per_layer_input { Some(p) => p, None => return h.clone(), }; - let gate_key = match arch.per_layer_input_gate_key(layer) { - Some(k) => k, - None => return h.clone(), - }; - let proj_key = match arch.per_layer_projection_key(layer) { - Some(k) => k, - None => return h.clone(), - }; - let w_gate = match weights.tensors.get(&gate_key) { - Some(w) => w, - None => return h.clone(), - }; - let w_proj = match weights.tensors.get(&proj_key) { - Some(w) => w, - None => return h.clone(), - }; + let gate_key = arch + .per_layer_input_gate_key(layer) + .unwrap_or_else(|| panic!("PLE arch missing per_layer_input_gate_key for layer {layer}")); + let proj_key = arch + .per_layer_projection_key(layer) + .unwrap_or_else(|| panic!("PLE arch missing per_layer_projection_key for layer {layer}")); + let w_gate = weights.tensors.get(&gate_key).unwrap_or_else(|| { + panic!( + "PLE tensor `{gate_key}` missing from weights — rebuild this vindex \ + (chrishayuk/larql#49)" + ) + }); + let w_proj = weights.tensors.get(&proj_key).unwrap_or_else(|| { + panic!( + "PLE tensor `{proj_key}` missing from weights — rebuild this vindex \ + (chrishayuk/larql#49)" + ) + }); // gate = h @ w_gate.T → [seq, ple_dim] let mut gate = dot_proj(h, w_gate); @@ -197,9 +211,12 @@ mod tests { } #[test] - fn precompute_returns_empty_when_projection_weight_missing() { - // Even if arch claims PLE support, missing weight → empty return. - // TinyModel arch doesn't enable PLE so this exercises the same early exit. + fn precompute_returns_empty_on_non_ple_arch() { + // Re-asserts the `!arch.has_per_layer_embeddings()` early exit on a + // different input shape — TinyModel doesn't enable PLE so we never + // reach the tensor lookups. The old fallback ("PLE arch but tensor + // missing → silent empty") was removed; that case now panics at + // load time before the forward path ever runs (#49). 
let weights = make_test_weights(); let embeds = Array2::zeros((1, weights.hidden_size)); let result = precompute_per_layer_inputs(&weights, &embeds, &[0u32]); @@ -218,14 +235,19 @@ mod tests { } #[test] - fn apply_ple_missing_gate_weight_returns_h_unchanged() { + #[should_panic(expected = "PLE arch missing per_layer_input_gate_key")] + fn apply_ple_with_input_on_non_ple_arch_panics() { + // Previously the function silently returned h unchanged when the + // gate key was missing — that's the exact failure mode behind + // #49 (extract drops PLE → forward silently no-ops → garbage). + // The contract now: if a caller hands us a precomputed PLE input, + // the arch must actually be PLE-capable, otherwise it's a bug. + // `precompute_per_layer_inputs` enforces this naturally because + // it returns an empty Vec on non-PLE archs. let weights = make_test_weights(); let h = input(1, weights.hidden_size); - // Provide a per_layer_input, but TinyModel has no per_layer gate tensors let dummy_input = Array2::zeros((1, 4)); - let result = apply_per_layer_embedding(&weights, &h, 0, Some(&dummy_input)); - // Gate key doesn't exist in TinyModel → returns h unchanged - assert_eq!(result, h, "missing gate weight should return h unchanged"); + let _ = apply_per_layer_embedding(&weights, &h, 0, Some(&dummy_input)); } #[test] diff --git a/crates/larql-vindex/src/format/weights/load/f32.rs b/crates/larql-vindex/src/format/weights/load/f32.rs index 4a5807b99..64b8a123a 100644 --- a/crates/larql-vindex/src/format/weights/load/f32.rs +++ b/crates/larql-vindex/src/format/weights/load/f32.rs @@ -141,6 +141,19 @@ pub fn load_model_weights_with_opts( tensors.insert(entry.key.clone(), arr.into_shared()); } } + kind::TENSOR_F16 => { + // Gemma 4 PLE sidecars are always written as f16 (see + // `weights::ple_sidecar`). 
The byte-count vs expected + // floats already picked F16 in `actual_dtype` above, so + // `floats` is already decoded — same handling as TENSOR + // from here, just routed via a distinct manifest kind. + if entry.shape.len() != 2 { + continue; + } + let arr = Array2::from_shape_vec((entry.shape[0], entry.shape[1]), floats) + .map_err(|e| VindexError::Parse(e.to_string()))?; + tensors.insert(entry.key.clone(), arr.into_shared()); + } kind::VECTOR => { vectors.insert(entry.key.clone(), floats); } @@ -148,6 +161,71 @@ pub fn load_model_weights_with_opts( } } + // ── Gemma-4 PLE invariant ── + // + // The PLE forward path (crates/larql-inference/src/forward/ple.rs) + // looks tensors up unconditionally once `arch.has_per_layer_embeddings()` + // is true. Pre-#49, the writer dropped them on --quant none and the + // forward path silently fell through to `return h.clone()`, producing + // garbage tokens with no diagnostic. Validate at load time instead so + // a stale (pre-fix) vindex fails loudly with actionable guidance. + if arch.has_per_layer_embeddings() { + let mut missing: Vec = Vec::new(); + let require_tensor = |key: &str, missing: &mut Vec| { + if !tensors.contains_key(key) { + missing.push(key.to_string()); + } + }; + require_tensor("per_layer_model_projection.weight", &mut missing); + if let Some(k) = arch.per_layer_embed_key() { + require_tensor(&k, &mut missing); + } + for layer in 0..config.num_layers { + if let Some(k) = arch.per_layer_input_gate_key(layer) { + require_tensor(&k, &mut missing); + } + if let Some(k) = arch.per_layer_projection_key(layer) { + require_tensor(&k, &mut missing); + } + if let Some(k) = arch.post_per_layer_input_norm_key(layer) { + if !vectors.contains_key(&k) { + missing.push(k); + } + } + } + if let Some(k) = arch.per_layer_projection_norm_key() { + if !vectors.contains_key(&k) { + missing.push(k); + } + } + // `layer_scalar_key` returns Some only on Gemma-4; treat it as a + // per-layer requirement on those models. 
Out-of-the-box Llama / + // Gemma-2/3 leave this None, so the loop below is a no-op there. + if arch.layer_scalar_key(0).is_some() { + for layer in 0..config.num_layers { + if let Some(k) = arch.layer_scalar_key(layer) { + if !vectors.contains_key(&k) { + missing.push(k); + } + } + } + } + if !missing.is_empty() { + let sample = missing + .iter() + .take(6) + .cloned() + .collect::>() + .join(", "); + return Err(VindexError::Parse(format!( + "vindex is missing Gemma-4 PLE sidecar tensors ({} entries, e.g. {sample}). \ + Rebuild with current larql — older extracts dropped these on --quant none / f16 \ + and silently produced garbage INFER (chrishayuk/larql#49).", + missing.len(), + ))); + } + } + // Gate vectors from gate_vectors.bin — only when running in non-Q4 mode. // // In Q4 vindexes (quant=q4k) the forward pass reads FFN weights straight diff --git a/crates/larql-vindex/src/format/weights/mod.rs b/crates/larql-vindex/src/format/weights/mod.rs index e191bb84f..d2c9be187 100644 --- a/crates/larql-vindex/src/format/weights/mod.rs +++ b/crates/larql-vindex/src/format/weights/mod.rs @@ -18,6 +18,7 @@ mod capabilities; pub mod load; pub mod manifest; +mod ple_sidecar; pub mod write_f32; pub mod write_kquant; pub mod write_layers; diff --git a/crates/larql-vindex/src/format/weights/write_kquant/ple.rs b/crates/larql-vindex/src/format/weights/ple_sidecar.rs similarity index 60% rename from crates/larql-vindex/src/format/weights/write_kquant/ple.rs rename to crates/larql-vindex/src/format/weights/ple_sidecar.rs index 8245118b2..b185dcdf8 100644 --- a/crates/larql-vindex/src/format/weights/write_kquant/ple.rs +++ b/crates/larql-vindex/src/format/weights/ple_sidecar.rs @@ -1,20 +1,22 @@ -//! Stage 5 — `ple_weights.bin` (Gemma 4 E2B Per-Layer Embeddings). +//! Shared writer for Gemma-4 Per-Layer Embedding (PLE) sidecars. //! -//! Stored as f16 — NOT Q4_K. The two globals -//! (`per_layer_model_projection`, `embed_tokens_per_layer`) and the -//! 
per-layer input_gate/projection matrices behave like embedding -//! tables: each super-block of 256 values spans a wide dynamic range -//! with a handful of outliers, and Q4_K's per-super-block (d, dmin) -//! calibration zeros out the majority of cells to accommodate those -//! outliers. PLE contributions are additive into every layer's -//! residual, so the cell-level noise compounds across 35 layers — the -//! observable result was "arrays" / "amphibians" instead of "Paris" on -//! Gemma 4 E2B. f16 halves the BF16 footprint (~4.7 GB for the big -//! lookup on E2B) and preserves enough precision for accurate -//! per-token PLE retrieval. +//! PLE tensors are large (~4.7 GB for E2B's `embed_tokens_per_layer`) +//! but behave like embedding tables: each super-block of 256 values +//! spans a wide dynamic range with a handful of outliers. Q4_K's +//! per-super-block calibration zeros out the majority of cells to +//! accommodate those outliers, and the cell-level noise compounds +//! over 35+ layers of additive contribution — the observable result +//! was garbage tokens on Gemma 4 E2B / E4B. f16 halves the BF16 +//! footprint and preserves enough precision for accurate per-token +//! retrieval, so PLE is stored as `kind::TENSOR_F16` in +//! `ple_weights.bin` regardless of the rest of the vindex's quant +//! mode. //! -//! Manifest entries are appended to the running norms manifest so -//! `weight_manifest.json` references everything in one list. +//! Both writers (`write_f32` and `write_q4k`) call into this helper +//! so the on-disk layout — and the manifest entries the loader +//! validates — stay byte-identical. Keeps Gemma-4 inference correct +//! across `--quant none`, `--quant q4k`, and any future quant modes. +//! Regression context: chrishayuk/larql#49. 
use std::io::{BufWriter, Write}; use std::path::Path; @@ -22,13 +24,20 @@ use std::path::Path; use crate::error::VindexError; use crate::format::filenames::*; -use super::super::write_f32::{kind, WeightEntry, WeightSource}; +use super::write_f32::{kind, WeightEntry, WeightSource}; +/// Write `ple_weights.bin` and append `tensor_f16` manifest entries +/// for every Gemma-4 PLE tensor. No-op when the architecture has no +/// PLE (i.e. `!arch.has_per_layer_embeddings()`). +/// +/// `manifest_entries` is the running `Vec` that the +/// caller threads through every weight-writing stage; this function +/// appends to it in place. pub(super) fn write_ple_weights( source: &dyn WeightSource, dir: &Path, num_layers: usize, - norm_entries: &mut Vec, + manifest_entries: &mut Vec, ) -> Result<(), VindexError> { let arch = source.arch(); if !arch.has_per_layer_embeddings() { @@ -65,7 +74,7 @@ pub(super) fn write_ple_weights( // Global: model projection [ple_dim·num_layers, hidden] write_tensor( &mut ple_file, - norm_entries, + manifest_entries, &mut ple_offset, "per_layer_model_projection.weight".into(), source.get_tensor("per_layer_model_projection.weight"), @@ -75,7 +84,7 @@ pub(super) fn write_ple_weights( if let Some(key) = arch.per_layer_embed_key() { write_tensor( &mut ple_file, - norm_entries, + manifest_entries, &mut ple_offset, key.clone(), source.get_tensor(&key), @@ -87,7 +96,7 @@ pub(super) fn write_ple_weights( if let Some(k) = arch.per_layer_input_gate_key(layer) { write_tensor( &mut ple_file, - norm_entries, + manifest_entries, &mut ple_offset, k.clone(), source.get_tensor(&k), @@ -96,7 +105,7 @@ pub(super) fn write_ple_weights( if let Some(k) = arch.per_layer_projection_key(layer) { write_tensor( &mut ple_file, - norm_entries, + manifest_entries, &mut ple_offset, k.clone(), source.get_tensor(&k), diff --git a/crates/larql-vindex/src/format/weights/write_f32.rs b/crates/larql-vindex/src/format/weights/write_f32.rs index 54c9565fa..ab741a82b 100644 --- 
a/crates/larql-vindex/src/format/weights/write_f32.rs +++ b/crates/larql-vindex/src/format/weights/write_f32.rs @@ -468,6 +468,19 @@ pub fn write_model_weights_with_opts( Some(arch.post_attention_layernorm_key(layer)), arch.pre_feedforward_layernorm_key(layer), arch.post_feedforward_layernorm_key(layer), + // Gemma 4 per-layer scalar multiplier. Returned by + // arch as Option; None on non-Gemma-4 archs. + // Omitting it on the float writer silently broke + // Gemma-4 inference the same way #49 broke E4B PLE. + arch.layer_scalar_key(layer), + // Gemma 4 per-layer embedding post-norm. Lives with + // the per-layer norms so the loader sees it as a + // standard `kind::VECTOR` entry. + if arch.has_per_layer_embeddings() { + arch.post_per_layer_input_norm_key(layer) + } else { + None + }, ] .into_iter() .flatten() @@ -520,10 +533,37 @@ pub fn write_model_weights_with_opts( length: bytes.len() as u64, file: NORMS_BIN.into(), }); + norms_offset += bytes.len() as u64; + } + + // Gemma 4 PLE global projection norm (small vector). Pairs + // with the per-layer post_per_layer_input_norm entries above. + // Same layout as the Q4_K writer (see + // `write_q4k/norms.rs::write_norms_and_router`). + if arch.has_per_layer_embeddings() { + if let Some(data) = source.get_vector("per_layer_projection_norm.weight") { + let bytes = crate::config::dtype::encode_floats(&data, dtype); + norms_file.write_all(&bytes)?; + entries.push(WeightEntry { + key: "per_layer_projection_norm.weight".into(), + kind: kind::VECTOR.into(), + shape: vec![data.len()], + offset: norms_offset, + length: bytes.len() as u64, + file: NORMS_BIN.into(), + }); + } } norms_file.flush()?; } + // ── PLE sidecar ── (Gemma 4 only — helper no-ops otherwise). + // Mirrors the Q4_K writer so non-Q4 extracts capture the same + // tensors the inference path expects in weights.tensors. Before + // this, `larql extract --quant none` against E4B silently produced + // a vindex with zero PLE entries → garbage INFER output (#49). 
+ super::ple_sidecar::write_ple_weights(source, dir, num_layers, &mut entries)?; + // ── LM Head ── (skipped when level < Inference) if write_lm_head { if let Some((data, rows, cols)) = source.lm_head() { diff --git a/crates/larql-vindex/src/format/weights/write_kquant/mod.rs b/crates/larql-vindex/src/format/weights/write_kquant/mod.rs index 044ee72fa..d6b4be80c 100644 --- a/crates/larql-vindex/src/format/weights/write_kquant/mod.rs +++ b/crates/larql-vindex/src/format/weights/write_kquant/mod.rs @@ -10,7 +10,9 @@ //! - [`ffn`] — `interleaved_kquant.bin` (+ opt `down_features_q4k.bin`) //! - [`moe_layers`] — `layers/layer_{L:02}.weights` (hybrid MoE) //! - [`norms`] — `norms.bin` (norms + MoE router/scales) -//! - [`ple`] — `ple_weights.bin` (Gemma 4 E2B PLE, f16) +//! - `super::ple_sidecar` — `ple_weights.bin` (Gemma 4 PLE, f16, +//! shared with the `write_f32` writer so non-Q4 extracts capture +//! the same sidecars; see chrishayuk/larql#49) //! - [`lm_head`] — `lm_head_q4.bin` //! //! 
The orchestrator below threads the running `Vec` @@ -35,7 +37,6 @@ mod ffn; mod lm_head; mod moe_layers; mod norms; -mod ple; pub mod feature_major_down; @@ -249,7 +250,7 @@ pub fn write_model_weights_kquant_with_opts( ffn::write_interleaved_ffn_kquant(source, dir, num_layers, opts, callbacks)?; moe_layers::write_per_layer_moe_kquant(source, dir, num_layers)?; let mut entries = norms::write_norms_and_router(source, dir, num_layers)?; - ple::write_ple_weights(source, dir, num_layers, &mut entries)?; + super::ple_sidecar::write_ple_weights(source, dir, num_layers, &mut entries)?; lm_head::write_lm_head_kquant(source, dir, &mut entries)?; let manifest_json = diff --git a/crates/larql-vindex/tests/test_vindex.rs b/crates/larql-vindex/tests/test_vindex.rs index 64eb17f95..c3efccb9d 100644 --- a/crates/larql-vindex/tests/test_vindex.rs +++ b/crates/larql-vindex/tests/test_vindex.rs @@ -3547,37 +3547,31 @@ fn adaptive_gate_knn_uses_pinned() { ); } -// ─── PLE tensors survive Q4_K extract → load round-trip ───────── +// ── Shared Gemma-4 PLE fixture ────────────────────────────────── // -// Regression test for the Gemma 4 E2B "predict returns garbage on -// Q4K vindex" bug: the extractor used to drop the six Per-Layer -// Embedding tensors, so `precompute_per_layer_inputs` silently -// returned an empty Vec and PLE was never applied. Extraction now -// writes `ple_weights.bin` (Q4_K-packed tensors) plus the two small -// PLE norms into norms.bin. This test builds a Gemma 4-shaped -// synthetic safetensors, runs the real extract pipeline, loads via -// `load_model_weights_kquant`, and asserts every PLE tensor is back in -// `weights.tensors` / `weights.vectors` with the right shape. 
-#[test] -fn streaming_extract_q4k_carries_ple_tensors() { - use larql_vindex::QuantFormat; +// Writes a Gemma-4-shaped HuggingFace model dir on disk: config.json +// with `hidden_size_per_layer_input` set (the knob +// `has_per_layer_embeddings()` keys off), tokenizer.json stub, and a +// safetensors with every tensor the extractor expects — including the +// six PLE tensors per layer plus the three globals AND the per-layer +// `layer_scalar` (Gemma-4-only, 0-D scalar surfaced as a 1-element +// vector). Shared between the Q4_K and the `--quant none` regression +// tests so they exercise the exact same fixture. +// +// Returns the populated model dir path; caller decides where to write +// the vindex output and what quant to extract with. +#[allow(clippy::type_complexity)] +fn write_gemma4_ple_fixture( + model_dir: &std::path::Path, + num_layers: usize, + hidden: usize, + intermediate: usize, + vocab: usize, + ple_dim: usize, +) { use std::collections::HashMap; - let model_dir = std::env::temp_dir().join("larql_test_streaming_q4k_ple_model"); - let output_dir = std::env::temp_dir().join("larql_test_streaming_q4k_ple_output"); - let _ = std::fs::remove_dir_all(&model_dir); - let _ = std::fs::remove_dir_all(&output_dir); - std::fs::create_dir_all(&model_dir).unwrap(); - - // E2B-shaped config at a test-friendly scale. `hidden_size_per_layer_input` - // is the knob `has_per_layer_embeddings()` keys off, so it must be present - // AND non-zero for the extractor to hit the PLE path. Gemma 4 uses the - // text_config wrapper; detect_from_json handles that. 
- let hidden = 256usize; // multiple of 256 so Q/K/V/O skip the padder - let intermediate = 256usize; - let num_layers = 2usize; - let vocab = 256usize; - let ple_dim = 256usize; + std::fs::create_dir_all(model_dir).unwrap(); let config = serde_json::json!({ "model_type": "gemma4", @@ -3699,6 +3693,17 @@ fn streaming_extract_q4k_carries_ple_tensors() { &format!("{lp}.self_attn.k_norm.weight"), vec![hidden], ); + // Gemma-4 per-layer scalar multiplier (`layer_scalar`). Real + // models ship it as a 0-D scalar; we use a 1-element vector + // because that's how `WeightSource::get_vector` surfaces it. + // Omitting it on the writer side silently broke Gemma-4 + // inference (same family of bug as #49 for PLE). + push( + &mut tensors, + &mut metadata, + &format!("{lp}.layer_scalar"), + vec![1], + ); // ── PLE per-layer tensors (the regression surface) ── push( @@ -3766,6 +3771,41 @@ fn streaming_extract_q4k_carries_ple_tensors() { let tok_json = r#"{"version":"1.0","model":{"type":"BPE","vocab":{},"merges":[]},"added_tokens":[]}"#; std::fs::write(model_dir.join("tokenizer.json"), tok_json).unwrap(); +} + +// ─── PLE tensors survive Q4_K extract → load round-trip ───────── +// +// Regression test for the Gemma 4 E2B "predict returns garbage on +// Q4K vindex" bug: the extractor used to drop the six Per-Layer +// Embedding tensors, so `precompute_per_layer_inputs` silently +// returned an empty Vec and PLE was never applied. Extraction now +// writes `ple_weights.bin` (f16 tensors, never Q4_K) plus the two small +// PLE norms into norms.bin. This test builds a Gemma 4-shaped +// synthetic safetensors, runs the real extract pipeline, loads via +// `load_model_weights_kquant`, and asserts every PLE tensor is back in +// `weights.tensors` / `weights.vectors` with the right shape. 
+#[test] +fn streaming_extract_q4k_carries_ple_tensors() { + use larql_vindex::QuantFormat; + + let model_dir = std::env::temp_dir().join("larql_test_streaming_q4k_ple_model"); + let output_dir = std::env::temp_dir().join("larql_test_streaming_q4k_ple_output"); + let _ = std::fs::remove_dir_all(&model_dir); + let _ = std::fs::remove_dir_all(&output_dir); + + // E2B-shaped config at a test-friendly scale. `hidden_size_per_layer_input` + // is the knob `has_per_layer_embeddings()` keys off, so it must be present + // AND non-zero for the extractor to hit the PLE path. Gemma 4 uses the + // text_config wrapper; detect_from_json handles that. + let hidden = 256usize; // multiple of 256 so Q/K/V/O skip the padder + let intermediate = 256usize; + let num_layers = 2usize; + let vocab = 256usize; + let ple_dim = 256usize; + + write_gemma4_ple_fixture(&model_dir, num_layers, hidden, intermediate, vocab, ple_dim); + let tok_json = + r#"{"version":"1.0","model":{"type":"BPE","vocab":{},"merges":[]},"added_tokens":[]}"#; let tokenizer = larql_vindex::tokenizers::Tokenizer::from_bytes(tok_json.as_bytes()).unwrap(); let mut cb = larql_vindex::SilentBuildCallbacks; @@ -3882,6 +3922,19 @@ fn streaming_extract_q4k_carries_ple_tensors() { "global PLE norm missing from loaded weights.vectors" ); + // Gemma-4 layer_scalar: a per-layer 0-D scalar surfaced as a 1-element + // vector. The forward path multiplies h by this value after FFN; omitting + // it silently produced garbage on the 31B model. Previously uncovered + // even on the Q4_K path — same failure mode shape as the PLE drop in #49, + // so we pin it here too. + for layer in 0..num_layers { + let key = format!("layers.{layer}.layer_scalar"); + assert!( + weights.vectors.contains_key(&key), + "layer {layer} layer_scalar missing from loaded weights.vectors" + ); + } + // final_logit_softcapping must survive the round-trip. Missing it // lets predict_kquant peak the softmax on the wrong token. 
let cfg = larql_vindex::load_vindex_config(&output_dir).unwrap(); @@ -3902,6 +3955,267 @@ fn streaming_extract_q4k_carries_ple_tensors() { let _ = std::fs::remove_dir_all(&output_dir); } +// ─── PLE tensors survive --quant none extract → load round-trip ───── +// +// The companion regression test to `streaming_extract_q4k_carries_ple_tensors` +// for the `--quant none` / `--quant f16` extract path. Q4_K already wrote +// the PLE sidecar correctly; the float writer silently dropped every PLE +// tensor (the bug from chrishayuk/larql#49). After the fix, both writers +// route through `weights::ple_sidecar::write_ple_weights` and the float +// loader validates the invariant on load, refusing to surface a vindex that +// would silently produce garbage INFER output. ExtractLevel must be +// Inference (not Browse) because non-Q4 extracts only write model weights +// when the level demands attn — see `maybe_write_model_weights`. +#[test] +fn streaming_extract_noquant_carries_ple_tensors() { + use larql_vindex::QuantFormat; + + let model_dir = std::env::temp_dir().join("larql_test_streaming_noquant_ple_model"); + let output_dir = std::env::temp_dir().join("larql_test_streaming_noquant_ple_output"); + let _ = std::fs::remove_dir_all(&model_dir); + let _ = std::fs::remove_dir_all(&output_dir); + + let hidden = 256usize; + let intermediate = 256usize; + let num_layers = 2usize; + let vocab = 256usize; + let ple_dim = 256usize; + + write_gemma4_ple_fixture(&model_dir, num_layers, hidden, intermediate, vocab, ple_dim); + let tok_json = + r#"{"version":"1.0","model":{"type":"BPE","vocab":{},"merges":[]},"added_tokens":[]}"#; + let tokenizer = larql_vindex::tokenizers::Tokenizer::from_bytes(tok_json.as_bytes()).unwrap(); + + let mut cb = larql_vindex::SilentBuildCallbacks; + larql_vindex::build_vindex_streaming( + &model_dir, + &tokenizer, + "test/streaming-noquant-ple", + &output_dir, + 5, + // Inference (not Browse): non-Q4 only writes model weights when + // the level includes 
attn. + larql_vindex::ExtractLevel::Inference, + larql_vindex::StorageDtype::F32, + QuantFormat::None, + larql_vindex::WriteWeightsOptions::default(), + larql_vindex::KquantWriteOptions::default(), + false, + &mut cb, + ) + .unwrap(); + + // ── ple_weights.bin must exist and the manifest must list the 2 + // global + (2 per layer) PLE tensor_f16 entries. Identical + // layout to the Q4_K case because both writers call into the + // shared `weights::ple_sidecar`. ── + assert!( + output_dir.join("ple_weights.bin").exists(), + "non-Q4 extract should emit ple_weights.bin when the arch has PLE (#49)" + ); + + let manifest_json = std::fs::read_to_string(output_dir.join("weight_manifest.json")).unwrap(); + let manifest: Vec = serde_json::from_str(&manifest_json).unwrap(); + let ple_tensor_keys: Vec<&str> = manifest + .iter() + .filter(|e| e["kind"] == "tensor_f16") + .filter_map(|e| e["key"].as_str()) + .collect(); + + assert_eq!( + ple_tensor_keys.len(), + 2 + 2 * num_layers, + "expected {} PLE tensor_f16 entries, got: {:?}", + 2 + 2 * num_layers, + ple_tensor_keys + ); + assert!( + ple_tensor_keys.contains(&"per_layer_model_projection.weight"), + "global model projection missing from manifest" + ); + assert!( + ple_tensor_keys.contains(&"embed_tokens_per_layer.weight"), + "global per-layer embed missing from manifest" + ); + + // ── PLE norms (per-layer + global) must land in norms.bin as + // vector entries. 
── + let ple_vector_keys: Vec<&str> = manifest + .iter() + .filter(|e| e["kind"] == "vector") + .filter_map(|e| e["key"].as_str()) + .filter(|k| k.contains("per_layer")) + .collect(); + assert!( + ple_vector_keys.contains(&"per_layer_projection_norm.weight"), + "global PLE norm missing from norms.bin manifest: {ple_vector_keys:?}" + ); + for layer in 0..num_layers { + let k = format!("layers.{layer}.post_per_layer_input_norm.weight"); + assert!( + ple_vector_keys.iter().any(|v| *v == k), + "layer {layer} post-PLE norm missing: {ple_vector_keys:?}" + ); + } + + // ── layer_scalar (Gemma-4 only) must also land in norms.bin. ── + let scalar_keys: Vec<&str> = manifest + .iter() + .filter(|e| e["kind"] == "vector") + .filter_map(|e| e["key"].as_str()) + .filter(|k| k.contains("layer_scalar")) + .collect(); + assert_eq!( + scalar_keys.len(), + num_layers, + "expected one layer_scalar entry per layer, got: {scalar_keys:?}" + ); + + // ── Load back via the float loader and verify every PLE tensor + + // vector is hydrated. The loader now validates this invariant + // and would refuse to return Ok if anything is missing. 
── + let mut lcb = larql_vindex::SilentLoadCallbacks; + let weights = larql_vindex::load_model_weights(&output_dir, &mut lcb).unwrap(); + + let proj = weights + .tensors + .get("per_layer_model_projection.weight") + .expect("per_layer_model_projection missing after load"); + assert_eq!(proj.shape(), &[ple_dim * num_layers, hidden]); + + let embed_ple = weights + .tensors + .get("embed_tokens_per_layer.weight") + .expect("embed_tokens_per_layer missing after load"); + assert_eq!(embed_ple.shape(), &[vocab, ple_dim * num_layers]); + + for layer in 0..num_layers { + let gate_key = format!("layers.{layer}.per_layer_input_gate.weight"); + let proj_key = format!("layers.{layer}.per_layer_projection.weight"); + let gate = weights + .tensors + .get(&gate_key) + .unwrap_or_else(|| panic!("{gate_key} missing")); + assert_eq!(gate.shape(), &[ple_dim, hidden]); + let proj = weights + .tensors + .get(&proj_key) + .unwrap_or_else(|| panic!("{proj_key} missing")); + assert_eq!(proj.shape(), &[hidden, ple_dim]); + + let norm_key = format!("layers.{layer}.post_per_layer_input_norm.weight"); + assert!( + weights.vectors.contains_key(&norm_key), + "{norm_key} missing from weights.vectors" + ); + + let scalar_key = format!("layers.{layer}.layer_scalar"); + assert!( + weights.vectors.contains_key(&scalar_key), + "{scalar_key} missing from weights.vectors" + ); + } + + assert!( + weights + .vectors + .contains_key("per_layer_projection_norm.weight"), + "global PLE norm missing from loaded weights.vectors" + ); + + let _ = std::fs::remove_dir_all(&model_dir); + let _ = std::fs::remove_dir_all(&output_dir); +} + +// ─── load_model_weights rejects PLE-arch vindexes with missing sidecars ── +// +// Pre-#49, dropping PLE tensors on the writer side was paired with a +// silent "missing tensor → return empty" in the forward path, so an +// affected vindex loaded fine but produced garbage INFER. 
The fix moves +// the failure into the load path: any PLE-active vindex whose +// `weights.tensors` / `weights.vectors` don't carry the required PLE +// keys is rejected with an actionable rebuild hint. This test stages +// that failure by writing the vindex normally, then nuking the manifest +// entries for the PLE tensors before calling `load_model_weights` — +// exercises the error branch in `load/f32.rs` end-to-end. +#[test] +fn load_model_weights_rejects_ple_arch_with_missing_sidecars() { + use larql_vindex::QuantFormat; + + let model_dir = std::env::temp_dir().join("larql_test_ple_missing_sidecar_model"); + let output_dir = std::env::temp_dir().join("larql_test_ple_missing_sidecar_output"); + let _ = std::fs::remove_dir_all(&model_dir); + let _ = std::fs::remove_dir_all(&output_dir); + + let hidden = 256usize; + let intermediate = 256usize; + let num_layers = 2usize; + let vocab = 256usize; + let ple_dim = 256usize; + + write_gemma4_ple_fixture(&model_dir, num_layers, hidden, intermediate, vocab, ple_dim); + let tok_json = + r#"{"version":"1.0","model":{"type":"BPE","vocab":{},"merges":[]},"added_tokens":[]}"#; + let tokenizer = larql_vindex::tokenizers::Tokenizer::from_bytes(tok_json.as_bytes()).unwrap(); + + let mut cb = larql_vindex::SilentBuildCallbacks; + larql_vindex::build_vindex_streaming( + &model_dir, + &tokenizer, + "test/ple-missing-sidecar", + &output_dir, + 5, + larql_vindex::ExtractLevel::Inference, + larql_vindex::StorageDtype::F32, + QuantFormat::None, + larql_vindex::WriteWeightsOptions::default(), + larql_vindex::KquantWriteOptions::default(), + false, + &mut cb, + ) + .unwrap(); + + // Simulate a stale (pre-fix) vindex: drop the manifest entries that + // would otherwise hydrate the PLE tensors at load time, then re-write + // the manifest. The bytes in `ple_weights.bin` / `norms.bin` are left + // alone — only the manifest table changes, which is exactly what the + // pre-fix writer did (it never wrote the entries in the first place). 
+ let manifest_path = output_dir.join("weight_manifest.json"); + let manifest_text = std::fs::read_to_string(&manifest_path).unwrap(); + let mut entries: Vec = serde_json::from_str(&manifest_text).unwrap(); + entries.retain(|e| { + let key = e["key"].as_str().unwrap_or(""); + !key.contains("per_layer") + && !key.contains("embed_tokens_per_layer") + && !key.contains("layer_scalar") + }); + std::fs::write( + &manifest_path, + serde_json::to_string_pretty(&entries).unwrap(), + ) + .unwrap(); + + let mut lcb = larql_vindex::SilentLoadCallbacks; + // `ModelWeights` doesn't implement Debug, so use a match here + // instead of `Result::expect_err` (which requires T: Debug). + let err = match larql_vindex::load_model_weights(&output_dir, &mut lcb) { + Ok(_) => panic!("load must reject PLE-arch vindex with missing sidecars"), + Err(e) => e, + }; + let msg = format!("{err}"); + assert!( + msg.contains("Gemma-4 PLE sidecar"), + "error must call out the PLE sidecars — got: {msg}" + ); + assert!( + msg.contains("chrishayuk/larql#49"), + "error must point at the issue for the rebuild hint — got: {msg}" + ); + + let _ = std::fs::remove_dir_all(&model_dir); + let _ = std::fs::remove_dir_all(&output_dir); +} + // ─── Variable per-layer intermediate size (Gemma 4 E2B double-wide MLP) ── // // E2B's `use_double_wide_mlp=True` gives half the layers a 2× intermediate