Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 55 additions & 33 deletions crates/larql-inference/src/forward/ple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,21 @@ pub fn precompute_per_layer_inputs(
let seq_len = token_ids.len();
let hidden = weights.hidden_size;

// Stream 1: model projection from main embeddings
let model_proj_key = match arch.per_layer_model_projection_key() {
Some(k) => k,
None => return Vec::new(),
};
let w_model_proj = match weights.tensors.get(&model_proj_key) {
Some(w) => w,
None => return Vec::new(),
};
// Stream 1: model projection from main embeddings.
// `arch.has_per_layer_embeddings()` is true here, so the key must
// exist and the loader must have populated it — both invariants are
// enforced at vindex-load time by `load_model_weights*` (see
// chrishayuk/larql#49). A panic here means the vindex was loaded
// through a path that skipped validation, which is a bug.
let model_proj_key = arch
.per_layer_model_projection_key()
.expect("PLE arch must expose per_layer_model_projection_key");
let w_model_proj = weights.tensors.get(&model_proj_key).unwrap_or_else(|| {
panic!(
"PLE tensor `{model_proj_key}` missing from weights — rebuild this vindex \
(chrishayuk/larql#49)"
)
});
let projected = dot_proj(main_embeds, w_model_proj);
let model_proj_scale = (hidden as f32).powf(-0.5);

Expand Down Expand Up @@ -120,27 +126,35 @@ pub(crate) fn apply_per_layer_embedding(
per_layer_input: Option<&Array2<f32>>,
) -> Array2<f32> {
let arch = &*weights.arch;
// `per_layer_input == None` is the only legitimate "skip PLE for
// this layer" path — it's how callers without precomputed PLE
// inputs opt out (e.g. unit tests, non-PLE archs that route
// through here defensively). The tensor-missing branches that
// used to live here have been collapsed into expects because the
// loader now refuses to load a PLE arch without them (#49).
let per_layer_input = match per_layer_input {
Some(p) => p,
None => return h.clone(),
};

let gate_key = match arch.per_layer_input_gate_key(layer) {
Some(k) => k,
None => return h.clone(),
};
let proj_key = match arch.per_layer_projection_key(layer) {
Some(k) => k,
None => return h.clone(),
};
let w_gate = match weights.tensors.get(&gate_key) {
Some(w) => w,
None => return h.clone(),
};
let w_proj = match weights.tensors.get(&proj_key) {
Some(w) => w,
None => return h.clone(),
};
let gate_key = arch
.per_layer_input_gate_key(layer)
.unwrap_or_else(|| panic!("PLE arch missing per_layer_input_gate_key for layer {layer}"));
let proj_key = arch
.per_layer_projection_key(layer)
.unwrap_or_else(|| panic!("PLE arch missing per_layer_projection_key for layer {layer}"));
let w_gate = weights.tensors.get(&gate_key).unwrap_or_else(|| {
panic!(
"PLE tensor `{gate_key}` missing from weights — rebuild this vindex \
(chrishayuk/larql#49)"
)
});
let w_proj = weights.tensors.get(&proj_key).unwrap_or_else(|| {
panic!(
"PLE tensor `{proj_key}` missing from weights — rebuild this vindex \
(chrishayuk/larql#49)"
)
});

// gate = h @ w_gate.T → [seq, ple_dim]
let mut gate = dot_proj(h, w_gate);
Expand Down Expand Up @@ -197,9 +211,12 @@ mod tests {
}

#[test]
fn precompute_returns_empty_when_projection_weight_missing() {
// Even if arch claims PLE support, missing weight → empty return.
// TinyModel arch doesn't enable PLE so this exercises the same early exit.
fn precompute_returns_empty_on_non_ple_arch() {
// Re-asserts the `!arch.has_per_layer_embeddings()` early exit on a
// different input shape — TinyModel doesn't enable PLE so we never
// reach the tensor lookups. The old fallback ("PLE arch but tensor
// missing → silent empty") was removed; that case now panics at
// load time before the forward path ever runs (#49).
let weights = make_test_weights();
let embeds = Array2::zeros((1, weights.hidden_size));
let result = precompute_per_layer_inputs(&weights, &embeds, &[0u32]);
Expand All @@ -218,14 +235,19 @@ mod tests {
}

#[test]
fn apply_ple_missing_gate_weight_returns_h_unchanged() {
#[should_panic(expected = "PLE arch missing per_layer_input_gate_key")]
fn apply_ple_with_input_on_non_ple_arch_panics() {
// Previously the function silently returned h unchanged when the
// gate key was missing — that's the exact failure mode behind
// #49 (extract drops PLE → forward silently no-ops → garbage).
// The contract now: if a caller hands us a precomputed PLE input,
// the arch must actually be PLE-capable, otherwise it's a bug.
// `precompute_per_layer_inputs` enforces this naturally because
// it returns an empty Vec on non-PLE archs.
let weights = make_test_weights();
let h = input(1, weights.hidden_size);
// Provide a per_layer_input, but TinyModel has no per_layer gate tensors
let dummy_input = Array2::zeros((1, 4));
let result = apply_per_layer_embedding(&weights, &h, 0, Some(&dummy_input));
// Gate key doesn't exist in TinyModel → returns h unchanged
assert_eq!(result, h, "missing gate weight should return h unchanged");
let _ = apply_per_layer_embedding(&weights, &h, 0, Some(&dummy_input));
}

#[test]
Expand Down
78 changes: 78 additions & 0 deletions crates/larql-vindex/src/format/weights/load/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,91 @@ pub fn load_model_weights_with_opts(
tensors.insert(entry.key.clone(), arr.into_shared());
}
}
kind::TENSOR_F16 => {
// Gemma 4 PLE sidecars are always written as f16 (see
// `weights::ple_sidecar`). The byte-count vs expected
// floats already picked F16 in `actual_dtype` above, so
// `floats` is already decoded — same handling as TENSOR
// from here, just routed via a distinct manifest kind.
if entry.shape.len() != 2 {
continue;
}
let arr = Array2::from_shape_vec((entry.shape[0], entry.shape[1]), floats)
.map_err(|e| VindexError::Parse(e.to_string()))?;
tensors.insert(entry.key.clone(), arr.into_shared());
}
kind::VECTOR => {
vectors.insert(entry.key.clone(), floats);
}
_ => {}
}
}

// ── Gemma-4 PLE invariant ──
//
// The PLE forward path (crates/larql-inference/src/forward/ple.rs)
// looks tensors up unconditionally once `arch.has_per_layer_embeddings()`
// is true. Pre-#49, the writer dropped them on --quant none and the
// forward path silently fell through to `return h.clone()`, producing
// garbage tokens with no diagnostic. Validate at load time instead so
// a stale (pre-fix) vindex fails loudly with actionable guidance.
if arch.has_per_layer_embeddings() {
let mut missing: Vec<String> = Vec::new();
let require_tensor = |key: &str, missing: &mut Vec<String>| {
if !tensors.contains_key(key) {
missing.push(key.to_string());
}
};
require_tensor("per_layer_model_projection.weight", &mut missing);
if let Some(k) = arch.per_layer_embed_key() {
require_tensor(&k, &mut missing);
}
for layer in 0..config.num_layers {
if let Some(k) = arch.per_layer_input_gate_key(layer) {
require_tensor(&k, &mut missing);
}
if let Some(k) = arch.per_layer_projection_key(layer) {
require_tensor(&k, &mut missing);
}
if let Some(k) = arch.post_per_layer_input_norm_key(layer) {
if !vectors.contains_key(&k) {
missing.push(k);
}
}
}
if let Some(k) = arch.per_layer_projection_norm_key() {
if !vectors.contains_key(&k) {
missing.push(k);
}
}
// `layer_scalar_key` returns Some only on Gemma-4; treat it as a
// per-layer requirement on those models. Out-of-the-box Llama /
// Gemma-2/3 leave this None, so the loop below is a no-op there.
if arch.layer_scalar_key(0).is_some() {
for layer in 0..config.num_layers {
if let Some(k) = arch.layer_scalar_key(layer) {
if !vectors.contains_key(&k) {
missing.push(k);
}
}
}
}
if !missing.is_empty() {
let sample = missing
.iter()
.take(6)
.cloned()
.collect::<Vec<_>>()
.join(", ");
return Err(VindexError::Parse(format!(
"vindex is missing Gemma-4 PLE sidecar tensors ({} entries, e.g. {sample}). \
Rebuild with current larql — older extracts dropped these on --quant none / f16 \
and silently produced garbage INFER (chrishayuk/larql#49).",
missing.len(),
)));
}
}

// Gate vectors from gate_vectors.bin — only when running in non-Q4 mode.
//
// In Q4 vindexes (quant=q4k) the forward pass reads FFN weights straight
Expand Down
1 change: 1 addition & 0 deletions crates/larql-vindex/src/format/weights/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod capabilities;
pub mod load;
pub mod manifest;
mod ple_sidecar;
pub mod write_f32;
pub mod write_kquant;
pub mod write_layers;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
//! Stage 5 — `ple_weights.bin` (Gemma 4 E2B Per-Layer Embeddings).
//! Shared writer for Gemma-4 Per-Layer Embedding (PLE) sidecars.
//!
//! Stored as f16 — NOT Q4_K. The two globals
//! (`per_layer_model_projection`, `embed_tokens_per_layer`) and the
//! per-layer input_gate/projection matrices behave like embedding
//! tables: each super-block of 256 values spans a wide dynamic range
//! with a handful of outliers, and Q4_K's per-super-block (d, dmin)
//! calibration zeros out the majority of cells to accommodate those
//! outliers. PLE contributions are additive into every layer's
//! residual, so the cell-level noise compounds across 35 layers — the
//! observable result was "arrays" / "amphibians" instead of "Paris" on
//! Gemma 4 E2B. f16 halves the BF16 footprint (~4.7 GB for the big
//! lookup on E2B) and preserves enough precision for accurate
//! per-token PLE retrieval.
//! PLE tensors are large (~4.7 GB for E2B's `embed_tokens_per_layer`)
//! but behave like embedding tables: each super-block of 256 values
//! spans a wide dynamic range with a handful of outliers. Q4_K's
//! per-super-block calibration zeros out the majority of cells to
//! accommodate those outliers, and the cell-level noise compounds
//! over 35+ layers of additive contribution — the observable result
//! was garbage tokens on Gemma 4 E2B / E4B. f16 halves the BF16
//! footprint and preserves enough precision for accurate per-token
//! retrieval, so PLE is stored as `kind::TENSOR_F16` in
//! `ple_weights.bin` regardless of the rest of the vindex's quant
//! mode.
//!
//! Manifest entries are appended to the running norms manifest so
//! `weight_manifest.json` references everything in one list.
//! Both writers (`write_f32` and `write_q4k`) call into this helper
//! so the on-disk layout — and the manifest entries the loader
//! validates — stay byte-identical. Keeps Gemma-4 inference correct
//! across `--quant none`, `--quant q4k`, and any future quant modes.
//! Regression context: chrishayuk/larql#49.

use std::io::{BufWriter, Write};
use std::path::Path;

use crate::error::VindexError;
use crate::format::filenames::*;

use super::super::write_f32::{kind, WeightEntry, WeightSource};
use super::write_f32::{kind, WeightEntry, WeightSource};

/// Write `ple_weights.bin` and append `tensor_f16` manifest entries
/// for every Gemma-4 PLE tensor. No-op when the architecture has no
/// PLE (i.e. `!arch.has_per_layer_embeddings()`).
///
/// `manifest_entries` is the running `Vec<WeightEntry>` that the
/// caller threads through every weight-writing stage; this function
/// appends to it in place.
pub(super) fn write_ple_weights(
source: &dyn WeightSource,
dir: &Path,
num_layers: usize,
norm_entries: &mut Vec<WeightEntry>,
manifest_entries: &mut Vec<WeightEntry>,
) -> Result<(), VindexError> {
let arch = source.arch();
if !arch.has_per_layer_embeddings() {
Expand Down Expand Up @@ -65,7 +74,7 @@ pub(super) fn write_ple_weights(
// Global: model projection [ple_dim·num_layers, hidden]
write_tensor(
&mut ple_file,
norm_entries,
manifest_entries,
&mut ple_offset,
"per_layer_model_projection.weight".into(),
source.get_tensor("per_layer_model_projection.weight"),
Expand All @@ -75,7 +84,7 @@ pub(super) fn write_ple_weights(
if let Some(key) = arch.per_layer_embed_key() {
write_tensor(
&mut ple_file,
norm_entries,
manifest_entries,
&mut ple_offset,
key.clone(),
source.get_tensor(&key),
Expand All @@ -87,7 +96,7 @@ pub(super) fn write_ple_weights(
if let Some(k) = arch.per_layer_input_gate_key(layer) {
write_tensor(
&mut ple_file,
norm_entries,
manifest_entries,
&mut ple_offset,
k.clone(),
source.get_tensor(&k),
Expand All @@ -96,7 +105,7 @@ pub(super) fn write_ple_weights(
if let Some(k) = arch.per_layer_projection_key(layer) {
write_tensor(
&mut ple_file,
norm_entries,
manifest_entries,
&mut ple_offset,
k.clone(),
source.get_tensor(&k),
Expand Down
40 changes: 40 additions & 0 deletions crates/larql-vindex/src/format/weights/write_f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,19 @@ pub fn write_model_weights_with_opts(
Some(arch.post_attention_layernorm_key(layer)),
arch.pre_feedforward_layernorm_key(layer),
arch.post_feedforward_layernorm_key(layer),
// Gemma 4 per-layer scalar multiplier. Returned by
// arch as Option<String>; None on non-Gemma-4 archs.
// Omitting it on the float writer silently broke
// Gemma-4 inference the same way #49 broke E4B PLE.
arch.layer_scalar_key(layer),
// Gemma 4 per-layer embedding post-norm. Lives with
// the per-layer norms so the loader sees it as a
// standard `kind::VECTOR` entry.
if arch.has_per_layer_embeddings() {
arch.post_per_layer_input_norm_key(layer)
} else {
None
},
]
.into_iter()
.flatten()
Expand Down Expand Up @@ -520,10 +533,37 @@ pub fn write_model_weights_with_opts(
length: bytes.len() as u64,
file: NORMS_BIN.into(),
});
norms_offset += bytes.len() as u64;
}

// Gemma 4 PLE global projection norm (small vector). Pairs
// with the per-layer post_per_layer_input_norm entries above.
// Same layout as the Q4_K writer (see
// `write_q4k/norms.rs::write_norms_and_router`).
if arch.has_per_layer_embeddings() {
if let Some(data) = source.get_vector("per_layer_projection_norm.weight") {
let bytes = crate::config::dtype::encode_floats(&data, dtype);
norms_file.write_all(&bytes)?;
entries.push(WeightEntry {
key: "per_layer_projection_norm.weight".into(),
kind: kind::VECTOR.into(),
shape: vec![data.len()],
offset: norms_offset,
length: bytes.len() as u64,
file: NORMS_BIN.into(),
});
}
}
norms_file.flush()?;
}

// ── PLE sidecar ── (Gemma 4 only — helper no-ops otherwise).
// Mirrors the Q4_K writer so non-Q4 extracts capture the same
// tensors the inference path expects in weights.tensors. Before
// this, `larql extract --quant none` against E4B silently produced
// a vindex with zero PLE entries → garbage INFER output (#49).
super::ple_sidecar::write_ple_weights(source, dir, num_layers, &mut entries)?;

// ── LM Head ── (skipped when level < Inference)
if write_lm_head {
if let Some((data, rows, cols)) = source.lm_head() {
Expand Down
Loading
Loading