diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2cf818ef..cae64399 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,7 +120,7 @@ jobs: toolchain: stable - name: Install cargo-llvm-cov - uses: taiki-e/install-action@51cd0b8c0499559d9a4d75c0f5c67bec3a894ec8 # v2 + uses: taiki-e/install-action@cca35edeb1d01366c2843b68fc3ca441446d73d3 # v2 with: tool: cargo-llvm-cov diff --git a/Cargo.lock b/Cargo.lock index 81fb968e..942dd4b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,7 +88,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -99,7 +99,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -1060,7 +1060,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -1141,7 +1141,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -1196,7 +1196,7 @@ dependencies = [ "atomic", "pear", "serde", - "toml", + "toml 0.8.23", "uncased", "version_check", ] @@ -1617,7 +1617,7 @@ dependencies = [ "serde_json", "sysinfo", "tokio", - "toml", + "toml 1.1.2+spec-1.1.0", ] [[package]] @@ -1645,10 +1645,12 @@ name = "higgs-models" version = "1.1.1" dependencies = [ "criterion", + "half", "image", "mlx-rs", "mlx-sys", "rand 0.10.1", + "safetensors", "serde", "serde_json", "tempfile", @@ -1990,16 +1992,6 @@ version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is-terminal" version = "0.4.17" @@ -2008,7 +2000,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -2536,7 +2528,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3378,7 +3370,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -3437,7 +3429,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -3470,6 +3462,16 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "safetensors" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = 
"1.0.6" @@ -3587,6 +3589,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" +dependencies = [ + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -3687,7 +3698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -3851,10 +3862,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.2", + "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -4095,9 +4106,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" dependencies = [ "bytes", "libc", @@ -4162,11 +4173,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", - "serde_spanned", + "serde_spanned 0.6.9", "toml_datetime 0.6.11", "toml_edit 0.22.27", ] +[[package]] +name = "toml" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned 1.1.1", + "toml_datetime 1.1.1+spec-1.1.0", + "toml_parser", + "toml_writer", + "winnow 1.0.2", +] + [[package]] name = "toml_datetime" version = "0.6.11" @@ -4193,7 +4219,7 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ "indexmap", "serde", - "serde_spanned", + "serde_spanned 0.6.9", "toml_datetime 0.6.11", "toml_write", "winnow 0.7.15", @@ -4251,9 +4277,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "a28f0d049ccfaa566e14e9663d304d8577427b368cb4710a20528690287a738b" dependencies = [ "base64 0.22.1", "bitflags 2.11.1", @@ -4261,7 +4287,6 @@ dependencies = [ "futures-util", "http", "http-body", - "iri-string", "mime", "pin-project-lite", "tokio", @@ -4269,6 +4294,7 @@ dependencies = [ "tower-layer", "tower-service", "tracing", + "url", "uuid", ] @@ -4828,7 +4854,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] diff --git a/crates/higgs-bench/Cargo.toml b/crates/higgs-bench/Cargo.toml index 9fffc06f..7c6e8d18 100644 --- a/crates/higgs-bench/Cargo.toml +++ b/crates/higgs-bench/Cargo.toml @@ -21,7 +21,7 @@ serde = { workspace = true } serde_json = { workspace = true } sysinfo = "0.32" tokio = { workspace = true } -toml = "0.8" +toml = "1.0" [build-dependencies] built = { version = "0.8", features = ["git2"] } diff --git 
a/crates/higgs-engine/src/model_loader.rs b/crates/higgs-engine/src/model_loader.rs
index aa8119fa..bd2f91d1 100644
--- a/crates/higgs-engine/src/model_loader.rs
+++ b/crates/higgs-engine/src/model_loader.rs
@@ -1,6 +1,8 @@
 use std::path::{Path, PathBuf};
 
-use higgs_models::{AnyModel, load_tokenizer as shared_load_tokenizer, registry, transformer};
+use higgs_models::{
+    AnyModel, error::ModelError, load_tokenizer as shared_load_tokenizer, registry, transformer,
+};
 
 use crate::error::EngineError;
 
@@ -36,6 +38,17 @@ pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<AnyModel, EngineError> {
     match config.model_type.as_str() {
         "qwen2" | "qwen3" | "llama" | "mistral" => {
+            // Packed 1.25-bpw Bonsai-Q1 checkpoints declare model_type="qwen3"
+            // but the weights are quantized to bits=1. Keep detection ahead of
+            // the fp16/Q4 transformer loader so users get an explicit error
+            // while the workspace remains on upstream oxideai/mlx-rs.
+            if is_bonsai_q1(&config.model_dir)? {
+                return Err(EngineError::Model(ModelError::UnsupportedModel(
+                    "Bonsai-Q1 requires MLX bits=1 affine quantization support; \
+                     the workspace stays on upstream oxideai/mlx-rs until that support lands"
+                        .to_owned(),
+                )));
+            }
             let model = transformer::load_model(&config.model_dir).map_err(EngineError::Model)?;
             Ok(AnyModel::Transformer(model))
         }
@@ -90,6 +103,39 @@ pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<AnyModel, EngineError> {
     }
 }
 
+/// Peek into `config.json` to detect packed 1-bit Bonsai-Q1 checkpoints.
+///
+/// Returns `true` for Qwen3-shaped `quantization.bits == 1` checkpoints using
+/// the expected group size. Returns `false` for any other model type or
+/// quantization config. A missing or malformed `config.json` propagates as an
+/// IO / JSON error; we never mask it.
+fn is_bonsai_q1(dir: &Path) -> Result<bool, EngineError> {
+    let cfg_path = dir.join("config.json");
+    let txt = std::fs::read_to_string(&cfg_path).map_err(|e| {
+        EngineError::Model(higgs_models::error::ModelError::Io(std::io::Error::new(
+            e.kind(),
+            format!("{}: {e}", cfg_path.display()),
+        )))
+    })?;
+    let cfg: serde_json::Value = serde_json::from_str(&txt)
+        .map_err(|e| EngineError::Model(higgs_models::error::ModelError::Json(e)))?;
+    let bonsai_group_size = u64::try_from(higgs_models::bonsai_q1::GROUP_SIZE)
+        .map_err(|e| EngineError::Model(ModelError::ShapeMismatch(e.to_string())))?;
+    Ok(
+        cfg.get("model_type").and_then(serde_json::Value::as_str) == Some("qwen3")
+            && cfg
+                .get("quantization")
+                .and_then(|q| q.get("bits"))
+                .and_then(serde_json::Value::as_u64)
+                == Some(1)
+            && cfg
+                .get("quantization")
+                .and_then(|q| q.get("group_size"))
+                .and_then(serde_json::Value::as_u64)
+                == Some(bonsai_group_size),
+    )
+}
+
 /// Load a tokenizer from a model directory.
 pub fn load_tokenizer<P: AsRef<Path>>(model_dir: P) -> Result<Tokenizer, EngineError> {
     shared_load_tokenizer(model_dir).map_err(|e| EngineError::Tokenization(e.to_string()))
 }
@@ -230,6 +276,58 @@ mod tests {
         ));
     }
 
+    #[test]
+    fn is_bonsai_q1_requires_qwen3_model_type_and_group_size() {
+        let (qwen3_dir, _qwen3_result) = config_from_raw(
+            r#"{
+                "model_type": "qwen3",
+                "quantization": {"bits": 1, "group_size": 128}
+            }"#,
+        );
+        assert!(is_bonsai_q1(qwen3_dir.path()).unwrap());
+
+        let (llama_dir, _llama_result) = config_from_raw(
+            r#"{
+                "model_type": "llama",
+                "quantization": {"bits": 1, "group_size": 128}
+            }"#,
+        );
+        assert!(!is_bonsai_q1(llama_dir.path()).unwrap());
+
+        let (wrong_group_dir, _wrong_group_result) = config_from_raw(
+            r#"{
+                "model_type": "qwen3",
+                "quantization": {"bits": 1, "group_size": 64}
+            }"#,
+        );
+        assert!(!is_bonsai_q1(wrong_group_dir.path()).unwrap());
+
+        let (q4_dir, _q4_result) = config_from_raw(
+            r#"{
+                "model_type": "qwen3",
+                "quantization": {"bits": 4, "group_size": 128}
+            }"#,
+        );
+        assert!(
+            !is_bonsai_q1(q4_dir.path()).unwrap(),
+            "regular Q4 Qwen3 must not be misclassified as Bonsai-Q1"
+        );
+    }
+
+    #[test]
+    fn load_model_rejects_bonsai_q1_without_runtime_support() {
+        let (dir, _result) = config_from_raw(
+            r#"{
+                "model_type": "qwen3",
+                "quantization": {"bits": 1, "group_size": 128}
+            }"#,
+        );
+        match load_model(dir.path()) {
+            Err(err) => assert!(err.to_string().contains("Bonsai-Q1 requires MLX bits=1")),
+            Ok(_) => panic!("Expected unsupported Bonsai-Q1 runtime error"),
+        }
+    }
+
     #[test]
     fn load_tokenizer_missing_tokenizer_json() {
         let dir = tempfile::tempdir().unwrap();
diff --git a/crates/higgs-models/Cargo.toml b/crates/higgs-models/Cargo.toml
index b2007b0a..47a370aa 100644
--- a/crates/higgs-models/Cargo.toml
+++ b/crates/higgs-models/Cargo.toml
@@ -11,8 +11,10 @@ homepage.workspace = true
 workspace = true
 
 [dependencies]
+half = "2.4"
 mlx-rs.workspace = true
 mlx-sys.workspace = true
+safetensors = "0.4"
 serde.workspace = true
 serde_json.workspace = true
 tokenizers.workspace = true
diff --git a/crates/higgs-models/src/bonsai_q1.rs b/crates/higgs-models/src/bonsai_q1.rs
new file mode 100644
index 00000000..93fe10e2
--- /dev/null
+++ b/crates/higgs-models/src/bonsai_q1.rs
@@ -0,0 +1,1204 @@
+//! Bonsai-Q1 target-capable engine: packed 1.25-bpw weight storage.
+//!
+//! Unlike `DiffusionEngine::load_q1`, which dequantizes to fp32 at load (32 GB
+//! residency on 8B), this engine holds MLX's `Q1_0_g128` affine encoding
+//! verbatim: `w[row, col] = scales[row, col/128] * bit(col) + biases[row,
+//! col/128]`. Dequant happens inline inside the MLX quantized matmul kernel
+//! once upstream MLX provides bits=1 affine support.
+//!
+//! Residency: ~1.25 GB for Bonsai-8B-mlx-1bit, ~260 MB for Bonsai-1.7B-mlx-1bit.
+//!
+//! Scope: Rust-side loader and engine implementation. Runtime routing is held
+//! back in `higgs-engine` until the upstream MLX dependency supports bits=1
+//! affine quantization.
+
+#![allow(
+    clippy::too_many_arguments,
+    clippy::too_many_lines,
+    // Quantization math uses small bounded dims (head_dim, GROUP_SIZE=128, vocab) and
+    // bit-packed u32→f32 conversions where precision/sign loss is intentional.
+    clippy::cast_possible_truncation,
+    clippy::cast_possible_wrap,
+    clippy::cast_precision_loss,
+    clippy::cast_sign_loss,
+    clippy::as_conversions,
+    // Dequant kernel + safetensors loader index into manually-bounds-checked slices.
+    clippy::indexing_slicing,
+    // Decode loop reuses names (q, k, v, t0) across rope/sdpa/o_proj stages by design.
+    clippy::shadow_unrelated,
+    clippy::shadow_reuse,
+    clippy::shadow_same,
+    // Loader unwraps on safetensors slices after explicit shape validation; load failure
+    // paths return ShapeMismatch via map_err elsewhere.
+    clippy::unwrap_used,
+    clippy::map_unwrap_or,
+    // YarnRoPE / Q1 / KV abbreviations are domain terms, not items to backtick.
+    clippy::doc_markdown,
+    clippy::doc_lazy_continuation,
+    clippy::missing_const_for_fn,
+    clippy::manual_flatten,
+    clippy::if_then_some_else_none,
+    clippy::suboptimal_flops,
+)]
+
+use half::f16;
+use std::path::Path;
+
+use mlx_rs::{Array, Dtype, error::Exception, fast, ops, ops::indexing::IndexOp};
+use safetensors::SafeTensors;
+
+use crate::{
+    cache::{KeyValueCache, SteppingKeyValueCache},
+    error::ModelError,
+    utils::{cached_scaled_dot_product_attention, create_attention_mask},
+    yarn::{apply_yarn_rope, compute_yarn_freqs, yarn_get_mscale},
+};
+
+/// Load and materialize a Bonsai-Q1 model from `model_dir` onto the GPU.
+///
+/// Adapts [`BonsaiQ1Engine::load`]'s `Result<_, String>` into [`ModelError`] so
+/// the engine surface in `higgs-engine::model_loader` can route it through the
+/// same `EngineError::Model` path used by all other architectures.
+pub fn load_bonsai_q1<P: AsRef<Path>>(model_dir: P) -> Result<BonsaiQ1Gpu, ModelError> {
+    let engine = BonsaiQ1Engine::load(model_dir).map_err(ModelError::ShapeMismatch)?;
+    engine.to_gpu().map_err(ModelError::Mlx)
+}
+
+pub const GROUP_SIZE: usize = 128;
+const BITS: i32 = 1;
+const GROUP_SIZE_I32: i32 = GROUP_SIZE as i32;
+
+/// Packed 1-bit linear layer with affine per-group dequant.
+///
+/// Layout (matches MLX 1-bit `QuantizedLinear`, `PrismML` fork):
+/// - `w_packed`: `[out_features, in_features/32]` u32, bit `col%32` of word
+///   `col/32` is the raw 1-bit weight for column `col`.
+/// - `scales`, `biases`: `[out_features, in_features/128]` f16, one per group
+///   of 128 input columns.
+///
+/// Effective: 1 bit/weight + 32 bits of f16 scale+bias per 128-weight group,
+/// i.e. 1 + 32/128 = **1.25 bpw**.
+pub struct PackedQ1Linear {
+    pub w_packed: Vec<u32>,
+    pub scales: Vec<f16>,
+    pub biases: Vec<f16>,
+    pub out_features: usize,
+    pub in_features: usize,
+}
+
+impl PackedQ1Linear {
+    pub const fn resident_bytes(&self) -> usize {
+        self.w_packed.len() * 4 + self.scales.len() * 2 + self.biases.len() * 2
+    }
+
+    /// Dequantize a single row to fp32 (reference path for correctness tests).
+    ///
+    /// Not used on the hot path; P2 replaces this with a Metal kernel that
+    /// fuses dequant into the matmul.
+    pub fn dequant_row_to_fp32(&self, row: usize, out: &mut [f32]) {
+        debug_assert_eq!(out.len(), self.in_features);
+        let n_groups = self.in_features / GROUP_SIZE;
+        let packed_cols = self.in_features / 32;
+        let w_row = &self.w_packed[row * packed_cols..(row + 1) * packed_cols];
+        let s_row = &self.scales[row * n_groups..(row + 1) * n_groups];
+        let b_row = &self.biases[row * n_groups..(row + 1) * n_groups];
+        for col in 0..self.in_features {
+            let word = w_row[col / 32];
+            let bit = ((word >> (col % 32)) & 1) as f32;
+            let group = col / GROUP_SIZE;
+            out[col] = s_row[group].to_f32().mul_add(bit, b_row[group].to_f32());
+        }
+    }
+}
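A minimal sketch (not part of the patch) of the affine decode contract above, using only this file's own API; the values are chosen so the expected output is exact in f16:

#[test]
fn packed_q1_affine_decode_sketch() {
    // One output row, one 128-column group, every stored bit set to 1.
    let layer = PackedQ1Linear {
        w_packed: vec![u32::MAX; 4], // 128 bits in 128/32 = 4 words
        scales: vec![f16::from_f32(0.5)],
        biases: vec![f16::from_f32(-0.25)],
        out_features: 1,
        in_features: 128,
    };
    let mut out = vec![0.0_f32; 128];
    layer.dequant_row_to_fp32(0, &mut out);
    // bit = 1 everywhere, so w[0, col] = 0.5 * 1 + (-0.25) = 0.25 for all col.
    assert!(out.iter().all(|&v| (v - 0.25).abs() < 1e-6));
    // Storage: 16 bytes of packed bits + 2 + 2 bytes of group metadata
    // = 20 bytes for 128 weights = 1.25 bits per weight.
    assert_eq!(layer.resident_bytes(), 20);
}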
+
+pub struct BonsaiQ1LayerWeights {
+    pub q_proj: PackedQ1Linear,
+    pub k_proj: PackedQ1Linear,
+    pub v_proj: PackedQ1Linear,
+    pub o_proj: PackedQ1Linear,
+    pub gate_proj: PackedQ1Linear,
+    pub up_proj: PackedQ1Linear,
+    pub down_proj: PackedQ1Linear,
+    pub q_norm: Vec<f16>,
+    pub k_norm: Vec<f16>,
+    pub input_norm: Vec<f16>,
+    pub post_attn_norm: Vec<f16>,
+}
+
+impl BonsaiQ1LayerWeights {
+    pub fn resident_bytes(&self) -> usize {
+        self.q_proj.resident_bytes()
+            + self.k_proj.resident_bytes()
+            + self.v_proj.resident_bytes()
+            + self.o_proj.resident_bytes()
+            + self.gate_proj.resident_bytes()
+            + self.up_proj.resident_bytes()
+            + self.down_proj.resident_bytes()
+            + (self.q_norm.len()
+                + self.k_norm.len()
+                + self.input_norm.len()
+                + self.post_attn_norm.len())
+                * 2
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct BonsaiQ1Config {
+    pub hidden: usize,
+    pub layers: usize,
+    pub heads: usize,
+    pub kv_heads: usize,
+    pub head_dim: usize,
+    pub inter: usize,
+    pub vocab: usize,
+    pub rms_norm_eps: f32,
+    pub rope_theta: f64,
+    /// YARN scaling factor if present (Bonsai-8B uses `factor=4.0, original=16384`).
+    pub rope_yarn_factor: Option<f64>,
+    pub rope_original_max_seq: Option<usize>,
+    pub tie_word_embeddings: bool,
+}
+
+pub struct BonsaiQ1Engine {
+    pub config: BonsaiQ1Config,
+    pub layers: Vec<BonsaiQ1LayerWeights>,
+    /// Token embedding stored packed (dequants inline at embed lookup time).
+    pub embed: PackedQ1Linear,
+    /// Untied LM head for 8B (`tie_word_embeddings: false`). None for 1.7B.
+    pub lm_head: Option<PackedQ1Linear>,
+    pub final_norm: Vec<f16>,
+}
+
+impl BonsaiQ1Engine {
+    pub const fn num_layers(&self) -> usize {
+        self.layers.len()
+    }
+
+    pub fn resident_bytes(&self) -> usize {
+        let layer_bytes: usize = self
+            .layers
+            .iter()
+            .map(BonsaiQ1LayerWeights::resident_bytes)
+            .sum();
+        let lm_head_bytes = self
+            .lm_head
+            .as_ref()
+            .map_or(0, PackedQ1Linear::resident_bytes);
+        layer_bytes + self.embed.resident_bytes() + lm_head_bytes + self.final_norm.len() * 2
+    }
+
+    /// Load from a `HuggingFace` directory containing `config.json` +
+    /// `model.safetensors` in MLX 1-bit affine-quant format.
+    #[allow(clippy::too_many_lines)]
+    pub fn load<P: AsRef<Path>>(model_dir: P) -> Result<Self, String> {
+        let dir = model_dir.as_ref();
+
+        let cfg_txt = std::fs::read_to_string(dir.join("config.json"))
+            .map_err(|e| format!("config.json: {e}"))?;
+        let cfg: serde_json::Value =
+            serde_json::from_str(&cfg_txt).map_err(|e| format!("config.json parse: {e}"))?;
+
+        let u64_of = |k: &str| -> Result<u64, String> {
+            cfg[k]
+                .as_u64()
+                .ok_or_else(|| format!("config.json missing u64 '{k}'"))
+        };
+        let hidden = u64_of("hidden_size")? as usize;
+        let heads = u64_of("num_attention_heads")? as usize;
+        let kv_heads = u64_of("num_key_value_heads")? as usize;
+        let head_dim = cfg["head_dim"].as_u64().map_or(128, |v| v as usize);
+        let inter = u64_of("intermediate_size")? as usize;
+        let layers_n = u64_of("num_hidden_layers")?
as usize; + let vocab = u64_of("vocab_size")? as usize; + + let rms_norm_eps = cfg["rms_norm_eps"].as_f64().unwrap_or(1e-6) as f32; + let rope_theta = cfg["rope_theta"].as_f64().unwrap_or(1_000_000.0); + let tie_word_embeddings = cfg["tie_word_embeddings"].as_bool().unwrap_or(false); + + let (rope_yarn_factor, rope_original_max_seq) = cfg + .get("rope_scaling") + .and_then(|rs| { + (rs.get("rope_type").and_then(|v| v.as_str()) == Some("yarn")).then(|| { + let f = rs.get("factor").and_then(serde_json::Value::as_f64); + let o = rs + .get("original_max_position_embeddings") + .and_then(serde_json::Value::as_u64) + .map(|v| v as usize); + (f, o) + }) + }) + .unwrap_or((None, None)); + + let quant = cfg + .get("quantization") + .ok_or("missing quantization block")?; + let q_bits = quant.get("bits").and_then(serde_json::Value::as_u64); + let q_group = quant.get("group_size").and_then(serde_json::Value::as_u64); + if q_bits != Some(1) || q_group != Some(GROUP_SIZE as u64) { + return Err(format!( + "expected quantization {{bits:1, group_size:{GROUP_SIZE}}}, got bits={q_bits:?} \ + group_size={q_group:?}" + )); + } + + let st_path = dir.join("model.safetensors"); + let st_data = std::fs::read(&st_path).map_err(|e| format!("read safetensors: {e}"))?; + let tensors = SafeTensors::deserialize(&st_data) + .map_err(|e| format!("deserialize safetensors: {e}"))?; + + let config = BonsaiQ1Config { + hidden, + layers: layers_n, + heads, + kv_heads, + head_dim, + inter, + vocab, + rms_norm_eps, + rope_theta, + rope_yarn_factor, + rope_original_max_seq, + tie_word_embeddings, + }; + + let q_dim = heads * head_dim; + let kv_dim = kv_heads * head_dim; + + let embed = load_packed( + &tensors, + "model.embed_tokens", + vocab, + hidden, + "embed_tokens", + )?; + let lm_head = if tie_word_embeddings { + None + } else { + Some(load_packed(&tensors, "lm_head", vocab, hidden, "lm_head")?) 
+        };
+        let final_norm = load_f16(&tensors, "model.norm.weight")?;
+        if final_norm.len() != hidden {
+            return Err(format!(
+                "final_norm len {} != hidden {hidden}",
+                final_norm.len()
+            ));
+        }
+
+        let mut layers = Vec::with_capacity(layers_n);
+        for i in 0..layers_n {
+            let p = format!("model.layers.{i}");
+            let attn = format!("{p}.self_attn");
+            let mlp = format!("{p}.mlp");
+
+            let layer = BonsaiQ1LayerWeights {
+                q_proj: load_packed(&tensors, &format!("{attn}.q_proj"), q_dim, hidden, "q_proj")?,
+                k_proj: load_packed(
+                    &tensors,
+                    &format!("{attn}.k_proj"),
+                    kv_dim,
+                    hidden,
+                    "k_proj",
+                )?,
+                v_proj: load_packed(
+                    &tensors,
+                    &format!("{attn}.v_proj"),
+                    kv_dim,
+                    hidden,
+                    "v_proj",
+                )?,
+                o_proj: load_packed(&tensors, &format!("{attn}.o_proj"), hidden, q_dim, "o_proj")?,
+                gate_proj: load_packed(
+                    &tensors,
+                    &format!("{mlp}.gate_proj"),
+                    inter,
+                    hidden,
+                    "gate_proj",
+                )?,
+                up_proj: load_packed(
+                    &tensors,
+                    &format!("{mlp}.up_proj"),
+                    inter,
+                    hidden,
+                    "up_proj",
+                )?,
+                down_proj: load_packed(
+                    &tensors,
+                    &format!("{mlp}.down_proj"),
+                    hidden,
+                    inter,
+                    "down_proj",
+                )?,
+                q_norm: load_f16(&tensors, &format!("{attn}.q_norm.weight"))?,
+                k_norm: load_f16(&tensors, &format!("{attn}.k_norm.weight"))?,
+                input_norm: load_f16(&tensors, &format!("{p}.input_layernorm.weight"))?,
+                post_attn_norm: load_f16(
+                    &tensors,
+                    &format!("{p}.post_attention_layernorm.weight"),
+                )?,
+            };
+            layers.push(layer);
+        }
+
+        let engine = Self {
+            config,
+            layers,
+            embed,
+            lm_head,
+            final_norm,
+        };
+        let resident_mb = engine.resident_bytes() as f64 / (1024.0 * 1024.0);
+        tracing::info!(
+            layers = engine.config.layers,
+            hidden = engine.config.hidden,
+            heads = engine.config.heads,
+            kv_heads = engine.config.kv_heads,
+            head_dim = engine.config.head_dim,
+            inter = engine.config.inter,
+            vocab = engine.config.vocab,
+            tied_embed = engine.config.tie_word_embeddings,
+            packed_resident_mb = format!("{resident_mb:.1}"),
+            "BonsaiQ1Engine::load",
+        );
+        Ok(engine)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// GPU-ready mirror — built once from the packed engine.
+// ---------------------------------------------------------------------------
+
+/// MLX-resident 1-bit linear: weight as uint32 packed, scales/biases as f16,
+/// same shape as `PackedQ1Linear` but ready for `ops::quantized_matmul`.
+pub struct BonsaiQ1GpuLinear {
+    pub w: Array,
+    pub scales: Array,
+    pub biases: Array,
+    pub out_features: i32,
+    pub in_features: i32,
+}
+
+impl BonsaiQ1GpuLinear {
+    fn from_packed(p: &PackedQ1Linear) -> Result<Self, Exception> {
+        let out = i32::try_from(p.out_features)
+            .map_err(|_| Exception::custom("out_features overflows i32"))?;
+        let inf = i32::try_from(p.in_features)
+            .map_err(|_| Exception::custom("in_features overflows i32"))?;
+        let packed_cols = inf / 32;
+        let n_groups = inf / GROUP_SIZE_I32;
+
+        let w = Array::from_slice(&p.w_packed, &[out, packed_cols]);
+        let scales_f32: Vec<f32> = p.scales.iter().map(|h| h.to_f32()).collect();
+        let biases_f32: Vec<f32> = p.biases.iter().map(|h| h.to_f32()).collect();
+        let scales = Array::from_slice(&scales_f32, &[out, n_groups]).as_dtype(Dtype::Float16)?;
+        let biases = Array::from_slice(&biases_f32, &[out, n_groups]).as_dtype(Dtype::Float16)?;
+
+        Ok(Self {
+            w,
+            scales,
+            biases,
+            out_features: out,
+            in_features: inf,
+        })
+    }
+
+    /// `y = x @ dequant(w, scales, biases).T` via fused bits=1 qmm.
+    pub fn forward(&self, x: &Array) -> Result<Array, Exception> {
+        ops::quantized_matmul(
+            x,
+            &self.w,
+            &self.scales,
+            &self.biases,
+            true,
+            GROUP_SIZE_I32,
+            BITS,
+        )
+    }
+}
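A sketch (not part of the patch) of the identity `forward` relies on. It assumes an MLX runtime with bits=1 affine support (the very dependency this file is gated on) and that `ops::matmul` and `Array::as_slice` behave as in current mlx-rs:

// Reference check: fused quantized matmul == matmul against explicit dequant.
fn qmm_matches_dequant_reference(lin: &BonsaiQ1GpuLinear, x: &Array) -> Result<(), Exception> {
    // Dense path: y_ref = x @ dequant(w, s, b).T
    let dense = ops::dequantize(&lin.w, &lin.scales, &lin.biases, GROUP_SIZE_I32, BITS)?;
    let y_ref = ops::matmul(x, &dense.transpose_axes(&[1, 0])?)?.as_dtype(Dtype::Float32)?;
    // Fused path: what BonsaiQ1GpuLinear::forward dispatches.
    let y = lin.forward(x)?.as_dtype(Dtype::Float32)?;
    y_ref.eval()?;
    y.eval()?;
    let (a, b) = (y_ref.as_slice::<f32>(), y.as_slice::<f32>());
    assert!(a.iter().zip(b).all(|(u, v)| (u - v).abs() < 1e-2));
    Ok(())
}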
+
+pub struct BonsaiQ1GpuLayer {
+    pub q_proj: BonsaiQ1GpuLinear,
+    pub k_proj: BonsaiQ1GpuLinear,
+    pub v_proj: BonsaiQ1GpuLinear,
+    pub o_proj: BonsaiQ1GpuLinear,
+    pub gate_proj: BonsaiQ1GpuLinear,
+    pub up_proj: BonsaiQ1GpuLinear,
+    pub down_proj: BonsaiQ1GpuLinear,
+    pub q_norm: Array,
+    pub k_norm: Array,
+    pub input_norm: Array,
+    pub post_attn_norm: Array,
+}
+
+pub struct BonsaiQ1Gpu {
+    pub config: BonsaiQ1Config,
+    pub layers: Vec<BonsaiQ1GpuLayer>,
+    pub embed: BonsaiQ1GpuLinear,
+    pub lm_head: Option<BonsaiQ1GpuLinear>,
+    pub final_norm: Array,
+    /// YARN-scaled `RoPE` frequencies (per `head_dim/2`). None if no YARN.
+    pub yarn_freqs: Option<Array>,
+    pub yarn_mscale: f32,
+    pub attention_scale: f32,
+}
+
+fn f16_vec_to_array(weights: &[f16]) -> Result<Array, Exception> {
+    let f32s: Vec<f32> = weights.iter().map(|h| h.to_f32()).collect();
+    let len =
+        i32::try_from(weights.len()).map_err(|_| Exception::custom("norm len overflows i32"))?;
+    Array::from_slice(&f32s, &[len]).as_dtype(Dtype::Float16)
+}
+
+impl BonsaiQ1Engine {
+    /// Consume the packed engine and materialize MLX arrays.
+    ///
+    /// Frees the `Vec<u32>` / `Vec<f16>` residency once copied to MLX.
+    pub fn to_gpu(self) -> Result<BonsaiQ1Gpu, Exception> {
+        let mut gpu_layers = Vec::with_capacity(self.layers.len());
+        for layer in &self.layers {
+            gpu_layers.push(BonsaiQ1GpuLayer {
+                q_proj: BonsaiQ1GpuLinear::from_packed(&layer.q_proj)?,
+                k_proj: BonsaiQ1GpuLinear::from_packed(&layer.k_proj)?,
+                v_proj: BonsaiQ1GpuLinear::from_packed(&layer.v_proj)?,
+                o_proj: BonsaiQ1GpuLinear::from_packed(&layer.o_proj)?,
+                gate_proj: BonsaiQ1GpuLinear::from_packed(&layer.gate_proj)?,
+                up_proj: BonsaiQ1GpuLinear::from_packed(&layer.up_proj)?,
+                down_proj: BonsaiQ1GpuLinear::from_packed(&layer.down_proj)?,
+                q_norm: f16_vec_to_array(&layer.q_norm)?,
+                k_norm: f16_vec_to_array(&layer.k_norm)?,
+                input_norm: f16_vec_to_array(&layer.input_norm)?,
+                post_attn_norm: f16_vec_to_array(&layer.post_attn_norm)?,
+            });
+        }
+
+        let embed = BonsaiQ1GpuLinear::from_packed(&self.embed)?;
+        let lm_head = self
+            .lm_head
+            .as_ref()
+            .map(BonsaiQ1GpuLinear::from_packed)
+            .transpose()?;
+        let final_norm = f16_vec_to_array(&self.final_norm)?;
+
+        // YARN precompute.
+        let head_dim_i = i32::try_from(self.config.head_dim)
+            .map_err(|_| Exception::custom("head_dim overflows i32"))?;
+        let base = self.config.rope_theta as f32;
+        let (yarn_freqs, yarn_mscale) = match self.config.rope_yarn_factor {
+            Some(factor) if factor > 1.0 => {
+                let orig_seq = self.config.rope_original_max_seq.ok_or_else(|| {
+                    Exception::custom(
+                        "rope_yarn_factor > 1.0 requires \
+                         rope_scaling.original_max_position_embeddings",
+                    )
+                })?;
+                let orig = i32::try_from(orig_seq)
+                    .map_err(|_| Exception::custom("orig_max_seq overflows i32"))?;
+                let factor_f = factor as f32;
+                let freqs = compute_yarn_freqs(head_dim_i, base, factor_f, orig, 32.0, 1.0);
+                (Some(freqs), yarn_get_mscale(factor_f, 1.0))
+            }
+            _ => (None, 1.0),
+        };
+
+        let head_dim_f = head_dim_i as f32;
+        let attention_scale = head_dim_f.sqrt().recip();
+
+        Ok(BonsaiQ1Gpu {
+            config: self.config,
+            layers: gpu_layers,
+            embed,
+            lm_head,
+            final_norm,
+            yarn_freqs,
+            yarn_mscale,
+            attention_scale,
+        })
+    }
+}
+
+impl BonsaiQ1Gpu {
+    pub fn num_layers(&self) -> usize {
+        self.layers.len()
+    }
+
+    /// Gather embedding rows for a token-ID tensor.
+    ///
+    /// Uses MLX dequantize after gathering the selected packed rows. This path
+    /// requires bits=1 affine support in the active MLX runtime.
+    fn embed_rows(&self, ids: &Array) -> Result<Array, Exception> {
+        let shape = ids.shape().to_vec();
+        let flat = ids.flatten(None, None)?;
+        let w = self.embed.w.take_axis(&flat, 0)?;
+        let s = self.embed.scales.take_axis(&flat, 0)?;
+        let b = self.embed.biases.take_axis(&flat, 0)?;
+        let out = ops::dequantize(&w, &s, &b, GROUP_SIZE_I32, BITS)?;
+        let mut ret_shape: Vec<i32> = shape;
+        ret_shape.push(-1);
+        out.reshape(&ret_shape)
+    }
+
+    fn apply_rope(&self, x: &Array, offset: i32) -> Result<Array, Exception> {
+        let head_dim = i32::try_from(self.config.head_dim)
+            .map_err(|_| Exception::custom("head_dim overflows i32"))?;
+        let offset_array = Array::from_int(offset);
+        apply_yarn_rope(
+            x,
+            head_dim,
+            self.config.rope_theta as f32,
+            self.yarn_freqs.as_ref(),
+            self.yarn_mscale,
+            &offset_array,
+            false, // Qwen3 layout
+        )
+    }
+
+    /// Run the decoder trunk and return final-normed hidden `[B, T, hidden]`.
+    /// Shared body for `forward` (last-position logits) and
+    /// `forward_all_logits` (all-position logits, used by spec-decode verify).
+    ///
+    /// Body lives in [`forward_trunk_free`] so `compile_with_state` can wrap
+    /// it via a free-fn pointer (the `Copy + 'static` closure constraint
+    /// forbids capturing `&self`).
+    fn forward_trunk(
+        &self,
+        inputs: &Array,
+        cache: &mut Vec<Option<SteppingKeyValueCache>>,
+    ) -> Result<Array, Exception> {
+        forward_trunk_free(self, cache, inputs)
+    }
+
+    /// Apply LM head (or tied embed) to `[B, T, hidden]` → `[B, T, vocab]`.
+    fn project_logits(&self, h: &Array) -> Result<Array, Exception> {
+        let logits = match &self.lm_head {
+            Some(head) => head.forward(h)?,
+            None => self.embed.forward(h)?,
+        };
+        // Logits are returned as f32 by API contract (callers do
+        // as_slice::<f32> for argmax / softmax). The trunk now stays in fp16
+        // throughout (after the apply_yarn_rope dtype fix), so we cast here
+        // at the boundary.
+        logits.as_dtype(Dtype::Float32)
+    }
+
+    /// Causal forward. Returns logits `[B, 1, vocab]` for the last position
+    /// (mlx_lm convention).
+    pub fn forward(
+        &self,
+        inputs: &Array,
+        cache: &mut Vec<Option<SteppingKeyValueCache>>,
+    ) -> Result<Array, Exception> {
+        let h = self.forward_trunk(inputs, cache)?;
+        let t = *h
+            .shape()
+            .get(1)
+            .ok_or_else(|| Exception::custom("trunk hidden missing T dim"))?;
+        let last = if t > 1 { h.index((.., -1.., ..)) } else { h };
+        self.project_logits(&last)
+    }
+
+    /// Causal forward returning logits at **every** position `[B, T, vocab]`.
+    /// Used by speculative-decode target verify: given the draft prefix,
+    /// obtain one logits row per proposed token in a single forward pass.
+    pub fn forward_all_logits(
+        &self,
+        inputs: &Array,
+        cache: &mut Vec<Option<SteppingKeyValueCache>>,
+    ) -> Result<Array, Exception> {
+        let h = self.forward_trunk(inputs, cache)?;
+        self.project_logits(&h)
+    }
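A sketch (not part of the patch) of the verify loop that motivates `forward_all_logits`: greedy acceptance only; the position bookkeeping is illustrative rather than this engine's actual decode loop. It assumes `tokens` starts with the last committed token followed by the draft tokens:

fn count_accepted_draft_tokens(
    gpu: &BonsaiQ1Gpu,
    cache: &mut Vec<Option<SteppingKeyValueCache>>,
    tokens: &[i32], // [last committed token, draft_1, ..., draft_n]
) -> Result<usize, Exception> {
    let t = i32::try_from(tokens.len()).map_err(|_| Exception::custom("T overflows i32"))?;
    let inputs = Array::from_slice(tokens, &[1, t]);
    let logits = gpu.forward_all_logits(&inputs, cache)?; // [1, T, vocab]
    logits.eval()?;
    let vocab = *logits
        .shape()
        .last()
        .ok_or_else(|| Exception::custom("no vocab dim"))? as usize;
    let rows = logits.as_slice::<f32>();
    let mut accepted = 0;
    // Row k - 1 holds the target's distribution for position k, so draft_k
    // (input index k) is checked against the greedy pick of row k - 1.
    for (k, &draft) in tokens.iter().enumerate().skip(1) {
        let row = &rows[(k - 1) * vocab..k * vocab];
        let greedy = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(j, _)| j as i32);
        if greedy == Some(draft) {
            accepted += 1;
        } else {
            break;
        }
    }
    // The rejected suffix is then rolled back with the cache trim this diff
    // adds to AnyCache further down.
    Ok(accepted)
}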
+
+    /// Profiled variant of `forward`: same result, but attributes per-section
+    /// wall time into `times`. Forces `.eval()` after every section (kills
+    /// lazy batching — that's the point: ratios matter, absolutes don't).
+    ///
+    /// Used by `bench_bonsai_q1_decode_breakdown` to answer the
+    /// dispatch-bound-vs-matmul-bound question for Bonsai-8B AR parity.
+    pub fn forward_profiled(
+        &self,
+        inputs: &Array,
+        cache: &mut Vec<Option<SteppingKeyValueCache>>,
+        times: &mut SectionTimes,
+    ) -> Result<Array, Exception> {
+        let h = self.forward_trunk_profiled(inputs, cache, times)?;
+        let t0 = std::time::Instant::now();
+        let t = *h
+            .shape()
+            .get(1)
+            .ok_or_else(|| Exception::custom("trunk hidden missing T dim"))?;
+        let last = if t > 1 { h.index((.., -1.., ..)) } else { h };
+        let logits = self.project_logits(&last)?;
+        logits.eval()?;
+        times.add("lm_head", t0.elapsed().as_nanos());
+        Ok(logits)
+    }
+
+    /// Profiled mirror of `forward_trunk`. Inserts `eval + record` at each
+    /// semantic section boundary. Sections are grouped by operation type
+    /// (qkv projections together, mlp up+gate together, etc.) — per-layer
+    /// noise is collapsed into section totals across all layers.
+    #[allow(non_snake_case)]
+    fn forward_trunk_profiled(
+        &self,
+        inputs: &Array,
+        cache: &mut Vec<Option<SteppingKeyValueCache>>,
+        times: &mut SectionTimes,
+    ) -> Result<Array, Exception> {
+        use std::time::Instant;
+
+        let shape = inputs.shape();
+        let B = *shape
+            .first()
+            .ok_or_else(|| Exception::custom("inputs must have >= 2 dims"))?;
+        let T = *shape
+            .get(1)
+            .ok_or_else(|| Exception::custom("inputs must have >= 2 dims"))?;
+
+        if cache.is_empty() {
+            *cache = (0..self.layers.len())
+                .map(|_| Some(SteppingKeyValueCache::new()))
+                .collect();
+        } else if cache.len() != self.layers.len() {
+            return Err(Exception::custom(format!(
+                "cache len {} != num_layers {}",
+                cache.len(),
+                self.layers.len()
+            )));
+        }
+
+        // Sync point: make sure prior work isn't folded into embed_rows time.
+        inputs.eval()?;
+
+        let t0 = Instant::now();
+        let mut h = self.embed_rows(inputs)?;
+        h.eval()?;
+        times.add("embed_rows", t0.elapsed().as_nanos());
+
+        let mask = create_attention_mask(&h, cache, None)?;
+
+        let heads = i32::try_from(self.config.heads)
+            .map_err(|_| Exception::custom("heads overflows i32"))?;
+        let kv_heads = i32::try_from(self.config.kv_heads)
+            .map_err(|_| Exception::custom("kv_heads overflows i32"))?;
+        let rms_eps = self.config.rms_norm_eps;
+
+        for (layer, layer_cache) in self.layers.iter().zip(cache.iter_mut()) {
+            let t0 = Instant::now();
+            let normed = fast::rms_norm(&h, &layer.input_norm, rms_eps)?;
+            normed.eval()?;
+            times.add("input_norm", t0.elapsed().as_nanos());
+
+            // qkv projections — 3× quantized_matmul on the same input.
+            let t0 = Instant::now();
+            let q = layer.q_proj.forward(&normed)?;
+            let k = layer.k_proj.forward(&normed)?;
+            let v = layer.v_proj.forward(&normed)?;
+            q.eval()?;
+            k.eval()?;
+            v.eval()?;
+            times.add("qkv_proj", t0.elapsed().as_nanos());
+
+            // Reshape to [B, L, n_heads, head_dim] then transpose to
+            // [B, n_heads, L, head_dim]. Metadata-only; lumped with qk_norm.
+            let q = q
+                .reshape(&[B, T, heads, -1])?
+                .transpose_axes(&[0, 2, 1, 3])?;
+            let k = k
+                .reshape(&[B, T, kv_heads, -1])?
+                .transpose_axes(&[0, 2, 1, 3])?;
+            let v = v
+                .reshape(&[B, T, kv_heads, -1])?
+ .transpose_axes(&[0, 2, 1, 3])?; + + let t0 = Instant::now(); + let q = fast::rms_norm(&q, &layer.q_norm, rms_eps)?; + let k = fast::rms_norm(&k, &layer.k_norm, rms_eps)?; + q.eval()?; + k.eval()?; + times.add("qk_norm", t0.elapsed().as_nanos()); + + let offset = layer_cache.as_ref().map_or(0, KeyValueCache::offset); + let t0 = Instant::now(); + let q = self.apply_rope(&q, offset)?; + let k = self.apply_rope(&k, offset)?; + q.eval()?; + k.eval()?; + times.add("rope", t0.elapsed().as_nanos()); + + let mask_arr = match &mask { + Some(crate::utils::AttentionMask::Array(a)) => Some(a), + _ => None, + }; + let mask_arr_opt: Option<&Array> = mask_arr; + + let t0 = Instant::now(); + let attn_out = match layer_cache.as_mut() { + Some(c) => cached_scaled_dot_product_attention( + q, + c, + k, + v, + self.attention_scale, + mask_arr_opt, + )?, + None => fast::scaled_dot_product_attention( + q, + k, + v, + self.attention_scale, + mask_arr_opt.map(mlx_rs::fast::ScaledDotProductAttentionMask::Array), + None::<&Array>, + )?, + }; + attn_out.eval()?; + times.add("sdpa_kv", t0.elapsed().as_nanos()); + + let attn_out = attn_out + .transpose_axes(&[0, 2, 1, 3])? + .reshape(&[B, T, -1])?; + + let t0 = Instant::now(); + let attn_out = layer.o_proj.forward(&attn_out)?; + attn_out.eval()?; + times.add("o_proj", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + let h_post_attn = h.add(&attn_out)?; + h_post_attn.eval()?; + times.add("residual", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + let normed_post = fast::rms_norm(&h_post_attn, &layer.post_attn_norm, rms_eps)?; + normed_post.eval()?; + times.add("post_attn_norm", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + let gate = layer.gate_proj.forward(&normed_post)?; + let up = layer.up_proj.forward(&normed_post)?; + gate.eval()?; + up.eval()?; + times.add("mlp_up_gate", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + let mlp_hidden = mlx_rs::nn::silu(&gate)?.multiply(&up)?; + mlp_hidden.eval()?; + times.add("silu_mul", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + let mlp_out = layer.down_proj.forward(&mlp_hidden)?; + mlp_out.eval()?; + times.add("mlp_down", t0.elapsed().as_nanos()); + + let t0 = Instant::now(); + h = h_post_attn.add(&mlp_out)?; + h.eval()?; + times.add("residual", t0.elapsed().as_nanos()); + } + + let t0 = Instant::now(); + let out = fast::rms_norm(&h, &self.final_norm, rms_eps)?; + out.eval()?; + times.add("final_norm", t0.elapsed().as_nanos()); + Ok(out) + } +} + +/// Free-function body of the decoder trunk. +/// +/// Lives at module scope (not as a method) so a **function pointer** to +/// [`decode_step_free`] satisfies `compile_with_state`'s +/// `F: Copy + 'static` bound — a closure capturing `&self` would not. +/// All `self.xxx` access is replaced with `gpu.xxx`; `embed_rows`, +/// `apply_rope`, and `project_logits` are called as methods on `gpu` +/// (they are already `&self`-only, so no further plumbing is needed). 
+#[allow(non_snake_case)]
+pub fn forward_trunk_free(
+    gpu: &BonsaiQ1Gpu,
+    cache: &mut Vec<Option<SteppingKeyValueCache>>,
+    inputs: &Array,
+) -> Result<Array, Exception> {
+    let shape = inputs.shape();
+    let B = *shape
+        .first()
+        .ok_or_else(|| Exception::custom("inputs must have >= 2 dims"))?;
+    let T = *shape
+        .get(1)
+        .ok_or_else(|| Exception::custom("inputs must have >= 2 dims"))?;
+
+    if cache.is_empty() {
+        *cache = (0..gpu.layers.len())
+            .map(|_| Some(SteppingKeyValueCache::new()))
+            .collect();
+    } else if cache.len() != gpu.layers.len() {
+        return Err(Exception::custom(format!(
+            "cache len {} != num_layers {}",
+            cache.len(),
+            gpu.layers.len()
+        )));
+    }
+
+    let mut h = gpu.embed_rows(inputs)?; // [B, L, hidden]
+
+    let mask = create_attention_mask(&h, cache, None)?;
+
+    let heads =
+        i32::try_from(gpu.config.heads).map_err(|_| Exception::custom("heads overflows i32"))?;
+    let kv_heads = i32::try_from(gpu.config.kv_heads)
+        .map_err(|_| Exception::custom("kv_heads overflows i32"))?;
+    let rms_eps = gpu.config.rms_norm_eps;
+
+    for (layer, layer_cache) in gpu.layers.iter().zip(cache.iter_mut()) {
+        let normed = fast::rms_norm(&h, &layer.input_norm, rms_eps)?;
+
+        let q = layer.q_proj.forward(&normed)?;
+        let k = layer.k_proj.forward(&normed)?;
+        let v = layer.v_proj.forward(&normed)?;
+
+        let q = q
+            .reshape(&[B, T, heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+        let k = k
+            .reshape(&[B, T, kv_heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+        let v = v
+            .reshape(&[B, T, kv_heads, -1])?
+            .transpose_axes(&[0, 2, 1, 3])?;
+
+        let q = fast::rms_norm(&q, &layer.q_norm, rms_eps)?;
+        let k = fast::rms_norm(&k, &layer.k_norm, rms_eps)?;
+
+        let offset = layer_cache.as_ref().map_or(0, KeyValueCache::offset);
+        let q = gpu.apply_rope(&q, offset)?;
+        let k = gpu.apply_rope(&k, offset)?;
+
+        let mask_arr = match &mask {
+            Some(crate::utils::AttentionMask::Array(a)) => Some(a),
+            _ => None,
+        };
+        let mask_arr_opt: Option<&Array> = mask_arr;
+
+        let attn_out = match layer_cache.as_mut() {
+            Some(c) => {
+                cached_scaled_dot_product_attention(q, c, k, v, gpu.attention_scale, mask_arr_opt)?
+            }
+            None => fast::scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                gpu.attention_scale,
+                mask_arr_opt.map(mlx_rs::fast::ScaledDotProductAttentionMask::Array),
+                None::<&Array>,
+            )?,
+        };
+
+        let attn_out = attn_out
+            .transpose_axes(&[0, 2, 1, 3])?
+            .reshape(&[B, T, -1])?;
+        let attn_out = layer.o_proj.forward(&attn_out)?;
+        let h_post_attn = h.add(&attn_out)?;
+
+        let normed_post = fast::rms_norm(&h_post_attn, &layer.post_attn_norm, rms_eps)?;
+        let gate = layer.gate_proj.forward(&normed_post)?;
+        let up = layer.up_proj.forward(&normed_post)?;
+        let mlp_hidden = mlx_rs::nn::silu(&gate)?.multiply(&up)?;
+        let mlp_out = layer.down_proj.forward(&mlp_hidden)?;
+
+        h = h_post_attn.add(&mlp_out)?;
+    }
+
+    fast::rms_norm(&h, &gpu.final_norm, rms_eps)
+}
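A self-contained illustration (not part of the patch) of why the trunk body must be a free function; the simplified `compile_like` bound below stands in for `compile_with_state`'s real signature. Function pointers are `Copy + 'static`; a closure capturing `&self` is not `'static`:

// Stand-in for the `F: Copy + 'static` bound that compile_with_state imposes.
fn compile_like<F: Copy + 'static>(f: F) -> F {
    f
}

struct Model;
impl Model {
    fn step(&self) {}
}

fn free_step(_m: &Model) {}

fn demo(m: &Model) {
    // A plain `fn` item coerces to a function pointer: Copy + 'static. OK.
    let compiled = compile_like(free_step as fn(&Model));
    compiled(m);
    // By contrast, `let f = || m.step();` captures `&'a Model`, so the
    // closure is not `'static` and `compile_like(f)` fails to compile.
}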
+
+/// Owned state wrapper for [`compile_with_state`]-driven decoding.
+///
+/// `compile_with_state` takes the state by `&mut U` where `U: Updatable`.
+/// Wrapping the model **and** the per-layer KV cache in one owned struct
+/// lets us hand-roll a single `Updatable` impl whose positional iteration
+/// order covers both — safer than fighting lifetimes on `(&mut gpu, cache)`
+/// tuples. See session-25 recap for the design rationale.
+///
+/// Expected construction: after prefill, move `gpu` and the filled cache
+/// vector into this struct, run the decode loop with a compiled step,
+/// then destructure back out when done.
+pub struct BonsaiQ1DecodeState {
+    pub gpu: BonsaiQ1Gpu,
+    pub cache: Vec<Option<SteppingKeyValueCache>>,
+}
+
+/// Number of updatable `Array`s per decoder layer:
+/// `input_norm` (1) + q/k/v proj at (w, s, b) each (9) + `q_norm` + `k_norm`
+/// (2) + o_proj (3) + `post_attn_norm` (1) + gate/up/down proj (9)
+/// = 1 + 9 + 2 + 3 + 1 + 9 = **25**.
+/// Corresponds to the array push order in [`BonsaiQ1DecodeState::updatable_states`].
+const PER_LAYER_UPDATABLE: usize = 25;
+
+impl mlx_rs::utils::Updatable for BonsaiQ1DecodeState {
+    fn updatable_states_len(&self) -> usize {
+        let mut n = 3 // embed (w, scales, biases)
+            + self.gpu.layers.len() * PER_LAYER_UPDATABLE
+            + 1; // final_norm
+        if self.gpu.lm_head.is_some() {
+            n += 3;
+        }
+        if self.gpu.yarn_freqs.is_some() {
+            n += 1;
+        }
+        for slot in &self.cache {
+            if let Some(c) = slot {
+                if c.keys().is_some() {
+                    n += 1;
+                }
+                if c.values().is_some() {
+                    n += 1;
+                }
+            }
+        }
+        n
+    }
+
+    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
+        let mut v: Vec<&Array> = Vec::with_capacity(self.updatable_states_len());
+        v.push(&self.gpu.embed.w);
+        v.push(&self.gpu.embed.scales);
+        v.push(&self.gpu.embed.biases);
+        for layer in &self.gpu.layers {
+            v.push(&layer.input_norm);
+            v.push(&layer.q_proj.w);
+            v.push(&layer.q_proj.scales);
+            v.push(&layer.q_proj.biases);
+            v.push(&layer.k_proj.w);
+            v.push(&layer.k_proj.scales);
+            v.push(&layer.k_proj.biases);
+            v.push(&layer.v_proj.w);
+            v.push(&layer.v_proj.scales);
+            v.push(&layer.v_proj.biases);
+            v.push(&layer.q_norm);
+            v.push(&layer.k_norm);
+            v.push(&layer.o_proj.w);
+            v.push(&layer.o_proj.scales);
+            v.push(&layer.o_proj.biases);
+            v.push(&layer.post_attn_norm);
+            v.push(&layer.gate_proj.w);
+            v.push(&layer.gate_proj.scales);
+            v.push(&layer.gate_proj.biases);
+            v.push(&layer.up_proj.w);
+            v.push(&layer.up_proj.scales);
+            v.push(&layer.up_proj.biases);
+            v.push(&layer.down_proj.w);
+            v.push(&layer.down_proj.scales);
+            v.push(&layer.down_proj.biases);
+        }
+        v.push(&self.gpu.final_norm);
+        if let Some(lm) = self.gpu.lm_head.as_ref() {
+            v.push(&lm.w);
+            v.push(&lm.scales);
+            v.push(&lm.biases);
+        }
+        if let Some(y) = self.gpu.yarn_freqs.as_ref() {
+            v.push(y);
+        }
+        for slot in &self.cache {
+            if let Some(c) = slot {
+                if let Some(k) = c.keys() {
+                    v.push(k);
+                }
+                if let Some(val) = c.values() {
+                    v.push(val);
+                }
+            }
+        }
+        v
+    }
+
+    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
+        let mut v: Vec<&mut Array> = Vec::with_capacity(self.updatable_states_len());
+        v.push(&mut self.gpu.embed.w);
+        v.push(&mut self.gpu.embed.scales);
+        v.push(&mut self.gpu.embed.biases);
+        for layer in &mut self.gpu.layers {
+            v.push(&mut layer.input_norm);
+            v.push(&mut layer.q_proj.w);
+            v.push(&mut layer.q_proj.scales);
+            v.push(&mut layer.q_proj.biases);
+            v.push(&mut layer.k_proj.w);
+            v.push(&mut layer.k_proj.scales);
+            v.push(&mut layer.k_proj.biases);
+            v.push(&mut layer.v_proj.w);
+            v.push(&mut layer.v_proj.scales);
+            v.push(&mut layer.v_proj.biases);
+            v.push(&mut layer.q_norm);
+            v.push(&mut layer.k_norm);
+            v.push(&mut layer.o_proj.w);
+            v.push(&mut layer.o_proj.scales);
+            v.push(&mut layer.o_proj.biases);
+            v.push(&mut layer.post_attn_norm);
+            v.push(&mut layer.gate_proj.w);
+            v.push(&mut layer.gate_proj.scales);
+            v.push(&mut layer.gate_proj.biases);
+            v.push(&mut layer.up_proj.w);
+            v.push(&mut layer.up_proj.scales);
+            v.push(&mut layer.up_proj.biases);
+            v.push(&mut layer.down_proj.w);
+            v.push(&mut layer.down_proj.scales);
+            v.push(&mut layer.down_proj.biases);
+        }
+        v.push(&mut self.gpu.final_norm);
+        if let Some(lm) = self.gpu.lm_head.as_mut() {
+            v.push(&mut lm.w);
+            v.push(&mut lm.scales);
+            v.push(&mut lm.biases);
+        }
+        if let Some(y) = self.gpu.yarn_freqs.as_mut() {
+            v.push(y);
+        }
+        for slot in &mut self.cache {
+            if let Some(c) = slot {
+                let (k_opt, v_opt) = c.key_value_arrays_mut();
+                if let Some(k) = k_opt {
+                    v.push(k);
+                }
+                if let Some(val) = v_opt {
+                    v.push(val);
+                }
+            }
+        }
+        v
+    }
+}
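A sketch test (not part of the patch) for the positional contract above: the hand-rolled length must match what both iterators actually yield, including the optional lm_head / yarn / KV entries. The `make_test_state` constructor is hypothetical and not in this diff:

#[test]
fn updatable_lengths_agree_sketch() {
    use mlx_rs::utils::Updatable;
    // Hypothetical helper: builds a small BonsaiQ1DecodeState (e.g. 2 layers,
    // tied embeddings, no YARN, empty cache). Not part of this diff.
    let mut state = make_test_state();
    let n = state.updatable_states_len();
    assert_eq!(state.updatable_states().into_iter().count(), n);
    assert_eq!(state.updatable_states_mut().into_iter().count(), n);
    // With tied embeddings, no YARN, and an empty cache this reduces to
    // 3 (embed) + layers * 25 + 1 (final_norm).
    assert_eq!(n, 3 + state.gpu.layers.len() * PER_LAYER_UPDATABLE + 1);
}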
+
+/// Free-fn decode step compatible with `compile_with_state`.
+///
+/// `state.cache` **must** be populated by a prefill call before this runs:
+/// compile-wrap is applied only in decode, and shape consistency across
+/// steps (for the MLX per-shape trace cache) requires
+/// [`SteppingKeyValueCache::reserve_max_tokens`] ahead of the first
+/// `update_dense`.
+pub fn decode_step_free(
+    state: &mut BonsaiQ1DecodeState,
+    inputs: &Array,
+) -> Result<Array, Exception> {
+    let h = forward_trunk_free(&state.gpu, &mut state.cache, inputs)?;
+    let t = *h
+        .shape()
+        .get(1)
+        .ok_or_else(|| Exception::custom("trunk hidden missing T dim"))?;
+    let last = if t > 1 { h.index((.., -1.., ..)) } else { h };
+    state.gpu.project_logits(&last)
+}
+
+/// Per-section wall-time accumulator for the Bonsai-Q1 forward pass.
+///
+/// Exists only to attribute the 45 ms/tok Bonsai-8B AR decode cost to
+/// individual sections (embed / norms / qkv / rope / sdpa / o_proj / mlp / lm_head).
+/// Each section's compute is force-`.eval()`'d to prevent MLX lazy batching
+/// from pooling multiple sections into one materialization — ratios between
+/// sections are meaningful even though absolutes will be slower than the
+/// unprofiled path.
+#[derive(Debug, Default, Clone)]
+pub struct SectionTimes {
+    totals: std::collections::BTreeMap<&'static str, (u128, u64)>,
+}
+
+impl SectionTimes {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add(&mut self, name: &'static str, ns: u128) {
+        let e = self.totals.entry(name).or_insert((0, 0));
+        e.0 += ns;
+        e.1 += 1;
+    }
+
+    /// Total across all sections (ns).
+    pub fn total_ns(&self) -> u128 {
+        self.totals.values().map(|(t, _)| *t).sum()
+    }
+
+    /// Section totals: `(name, total_ns, call_count)`, sorted by ns descending.
+    pub fn entries(&self) -> Vec<(&'static str, u128, u64)> {
+        let mut v: Vec<_> = self.totals.iter().map(|(k, (t, n))| (*k, *t, *n)).collect();
+        v.sort_by_key(|b| std::cmp::Reverse(b.1));
+        v
+    }
+}
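A usage sketch (not part of the patch) of the accumulator: time two fake sections, then print the descending breakdown the benches rely on. Only std and this file's API are used:

fn section_times_demo() {
    let mut times = SectionTimes::new();
    let t0 = std::time::Instant::now();
    std::thread::sleep(std::time::Duration::from_millis(2));
    times.add("qkv_proj", t0.elapsed().as_nanos());
    let t0 = std::time::Instant::now();
    std::thread::sleep(std::time::Duration::from_millis(1));
    times.add("sdpa_kv", t0.elapsed().as_nanos());
    let total = times.total_ns() as f64;
    for (name, ns, calls) in times.entries() {
        // Ratios are the useful output; absolutes include the forced evals.
        println!("{name:>12}: {:5.1}% over {calls} call(s)", ns as f64 / total * 100.0);
    }
}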
+
+fn load_packed(
+    tensors: &SafeTensors<'_>,
+    prefix: &str,
+    out_features: usize,
+    in_features: usize,
+    who: &str,
+) -> Result<PackedQ1Linear, String> {
+    if in_features % GROUP_SIZE != 0 {
+        return Err(format!(
+            "{who}: in_features {in_features} not divisible by group_size {GROUP_SIZE}"
+        ));
+    }
+    let packed_cols = in_features / 32;
+    let n_groups = in_features / GROUP_SIZE;
+
+    let w_view = tensors
+        .tensor(&format!("{prefix}.weight"))
+        .map_err(|e| format!("{who}: {prefix}.weight: {e}"))?;
+    let s_view = tensors
+        .tensor(&format!("{prefix}.scales"))
+        .map_err(|e| format!("{who}: {prefix}.scales: {e}"))?;
+    let b_view = tensors
+        .tensor(&format!("{prefix}.biases"))
+        .map_err(|e| format!("{who}: {prefix}.biases: {e}"))?;
+
+    let w_bytes = w_view.data();
+    let s_bytes = s_view.data();
+    let b_bytes = b_view.data();
+
+    let expected_w_bytes = out_features * packed_cols * 4;
+    if w_bytes.len() != expected_w_bytes {
+        return Err(format!(
+            "{who}: weight byte-size mismatch: got {} expected {}",
+            w_bytes.len(),
+            expected_w_bytes,
+        ));
+    }
+    let expected_sb_bytes = out_features * n_groups * 2;
+    if s_bytes.len() != expected_sb_bytes {
+        return Err(format!(
+            "{who}: scales byte-size mismatch: got {} expected {}",
+            s_bytes.len(),
+            expected_sb_bytes,
+        ));
+    }
+    if b_bytes.len() != expected_sb_bytes {
+        return Err(format!(
+            "{who}: biases byte-size mismatch: got {} expected {}",
+            b_bytes.len(),
+            expected_sb_bytes,
+        ));
+    }
+
+    let w_packed: Vec<u32> = w_bytes
+        .chunks_exact(4)
+        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+        .collect();
+    let scales = bytes_to_f16_vec(s_bytes);
+    let biases = bytes_to_f16_vec(b_bytes);
+
+    Ok(PackedQ1Linear {
+        w_packed,
+        scales,
+        biases,
+        out_features,
+        in_features,
+    })
+}
+
+fn load_f16(tensors: &SafeTensors<'_>, name: &str) -> Result<Vec<f16>, String> {
+    let view = tensors.tensor(name).map_err(|e| format!("{name}: {e}"))?;
+    Ok(bytes_to_f16_vec(view.data()))
+}
+
+fn bytes_to_f16_vec(b: &[u8]) -> Vec<f16> {
+    b.chunks_exact(2)
+        .map(|c| f16::from_bits(u16::from_le_bytes([c[0], c[1]])))
+        .collect()
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
diff --git a/crates/higgs-models/src/cache.rs b/crates/higgs-models/src/cache.rs
index 79eabd2a..14f5bed8 100644
--- a/crates/higgs-models/src/cache.rs
+++ b/crates/higgs-models/src/cache.rs
@@ -453,6 +453,14 @@ impl SteppingKeyValueCache {
         self.values.as_ref()
     }
 
+    /// Simultaneous mutable access to the key and value arrays.
+    ///
+    /// Re-borrows both optional fields from a single `&mut` split to satisfy
+    /// the borrow checker when both must yield from one iterator.
+    pub const fn key_value_arrays_mut(&mut self) -> (Option<&mut Array>, Option<&mut Array>) {
+        (self.keys.as_mut(), self.values.as_mut())
+    }
+
     /// Create a pre-filled cache from existing K/V arrays.
/// /// Sets `offset = keys.shape()[2]` so the next `update_dense` triggers a diff --git a/crates/higgs-models/src/deepseek_v2.rs b/crates/higgs-models/src/deepseek_v2.rs index e3b09cad..625264fc 100644 --- a/crates/higgs-models/src/deepseek_v2.rs +++ b/crates/higgs-models/src/deepseek_v2.rs @@ -120,9 +120,14 @@ impl DeepSeekV2ModelArgs { // YaRN RoPE helpers // --------------------------------------------------------------------------- +#[allow( + clippy::as_conversions, + clippy::cast_precision_loss, + clippy::cast_possible_truncation +)] fn yarn_find_correction_dim(num_rotations: f32, dim: i32, base: f32, max_pos: i32) -> f32 { - let dim_f = f32::from(i16::try_from(dim).unwrap_or(i16::MAX)); - let max_pos_f = f32::from(i16::try_from(max_pos).unwrap_or(i16::MAX)); + let dim_f = dim as f32; + let max_pos_f = max_pos as f32; (dim_f * (max_pos_f / (num_rotations * 2.0 * PI)).ln()) / (2.0 * base.ln()) } @@ -154,6 +159,7 @@ fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 { /// Precompute `YaRN`-interpolated `RoPE` frequencies. #[allow( clippy::as_conversions, + clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::indexing_slicing @@ -167,14 +173,14 @@ fn compute_yarn_freqs( beta_slow: f32, ) -> Array { let half_dim = dim / 2; - let dim_f = f32::from(i16::try_from(dim).unwrap_or(i16::MAX)); + let dim_f = dim as f32; // freq_extra = base^(arange(0, dim, 2) / dim) -- standard theta // freq_inter = scaling_factor * freq_extra -- extended theta let mut freq_extra = Vec::with_capacity(half_dim as usize); let mut freq_inter = Vec::with_capacity(half_dim as usize); for i in 0..half_dim { - let exp = f32::from(i16::try_from(2 * i).unwrap_or(0)) / dim_f; + let exp = (2 * i) as f32 / dim_f; let theta = base.powf(exp); freq_extra.push(theta); freq_inter.push(scaling_factor * theta); @@ -183,8 +189,8 @@ fn compute_yarn_freqs( let (low, high) = yarn_find_correction_range(beta_fast, beta_slow, dim, base, orig_max_pos); // Linear ramp mask: 0 at low, 1 at high - let low_f = f32::from(i16::try_from(low).unwrap_or(0)); - let high_f = f32::from(i16::try_from(high).unwrap_or(0)); + let low_f = low as f32; + let high_f = high as f32; let range = if (high_f - low_f).abs() < 0.001 { high_f - low_f + 0.001 } else { @@ -195,7 +201,7 @@ fn compute_yarn_freqs( // freq_mask = 1 - ramp (high mask = use freq_extra, low mask = use freq_inter) let mut freqs = Vec::with_capacity(half_dim as usize); for i in 0..half_dim as usize { - let idx_f = f32::from(i16::try_from(i).unwrap_or(0)); + let idx_f = i as f32; let ramp = ((idx_f - low_f) / range).clamp(0.0, 1.0); let mask = 1.0 - ramp; let inter = freq_inter[i]; @@ -300,6 +306,7 @@ impl DeepSeekV2Attention { // YaRN RoPE #[allow( clippy::as_conversions, + clippy::cast_precision_loss, clippy::cast_possible_truncation, clippy::option_if_let_else )] diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 6aec6383..fe52d245 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -1,3 +1,4 @@ +pub mod bonsai_q1; pub mod cache; pub mod deepseek_v2; pub mod error; @@ -13,9 +14,11 @@ pub mod starcoder2; pub mod transformer; pub mod turboquant; pub mod utils; +pub mod yarn; use std::collections::{HashMap, HashSet}; use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; use mlx_rs::module::ModuleParametersExt; use mlx_rs::ops::indexing::IndexOp; @@ -27,6 +30,8 @@ use serde_json::Value; use crate::error::ModelError; use crate::turboquant::KvCacheConfig; +static 
BONSAI_IGNORED_MASK_WARNED: AtomicBool = AtomicBool::new(false);
+
 // ---------------------------------------------------------------------------
 // SamplingParams -- configurable sampling parameters
 // ---------------------------------------------------------------------------
@@ -91,6 +96,29 @@ pub enum AnyCache {
     Hybrid(Vec<Option<LayerCache>>),
 }
 
+impl AnyCache {
+    /// Trim every layer cache by `count` tokens, discarding the most recent
+    /// entries. Used after speculative-decode verify to roll back rejected
+    /// draft tokens. Hybrid SSM (recurrent) layers are intentionally left
+    /// untouched — their state cannot be trimmed by offset alone.
+    pub fn trim_by(&mut self, count: usize) {
+        match self {
+            Self::KV(layers) => {
+                for layer in layers.iter_mut().flatten() {
+                    layer.trim_by(count);
+                }
+            }
+            Self::Hybrid(layers) => {
+                for layer in layers.iter_mut().flatten() {
+                    if let LayerCache::KV(kv) = layer {
+                        kv.trim_by(count);
+                    }
+                }
+            }
+        }
+    }
+}
+
 /// Unified model wrapper dispatching to the correct architecture.
 pub enum AnyModel {
     /// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3.
@@ -109,6 +137,8 @@ pub enum AnyModel {
     LlavaQwen2(llava_qwen2::LlavaQwen2Model),
     /// DeepSeek-V2 with Multi-head Latent Attention and sparse `MoE`.
     DeepSeekV2(deepseek_v2::DeepSeekV2CausalLM),
+    /// Bonsai-Q1: packed 1.25-bpw Qwen3-shaped target (1.7B / 8B).
+    BonsaiQ1(bonsai_q1::BonsaiQ1Gpu),
 }
 
 fn checked_head_dim(hidden_size: i32, num_attention_heads: i32) -> Result<i32, Exception> {
@@ -190,6 +220,16 @@ impl AnyModel {
             (Self::LlavaQwen2(m), AnyCache::KV(c)) => m.forward_text(inputs, mask, c),
             (Self::DeepSeekV2(m), AnyCache::KV(c)) => m.forward(inputs, mask, c),
             (Self::Qwen3Next(m), AnyCache::Hybrid(c)) => m.forward(inputs, mask, c),
+            // BonsaiQ1 builds its causal mask internally; any externally-provided
+            // mask is ignored (causal-only semantics).
+ (Self::BonsaiQ1(m), AnyCache::KV(c)) => { + if mask.is_some() && !BONSAI_IGNORED_MASK_WARNED.swap(true, Ordering::Relaxed) { + tracing::warn!( + "BonsaiQ1 ignores externally provided masks and builds its own causal mask" + ); + } + m.forward(inputs, c) + } _ => Err(Exception::custom("Model/cache type mismatch")), } } @@ -210,6 +250,7 @@ impl AnyModel { (Self::LlavaQwen2(m), AnyCache::KV(c)) => m.forward_text_hidden(inputs, mask, c), (Self::DeepSeekV2(m), AnyCache::KV(c)) => m.forward_hidden(inputs, mask, c), (Self::Qwen3Next(m), AnyCache::Hybrid(c)) => m.forward_hidden(inputs, mask, c), + (Self::BonsaiQ1(m), AnyCache::KV(c)) => bonsai_q1::forward_trunk_free(m, c, inputs), _ => Err(Exception::custom("Model/cache type mismatch")), } } @@ -324,7 +365,8 @@ impl AnyModel { | Self::Phi3(_) | Self::Starcoder2(_) | Self::LlavaQwen2(_) - | Self::DeepSeekV2(_) => Err(Exception::custom( + | Self::DeepSeekV2(_) + | Self::BonsaiQ1(_) => Err(Exception::custom( "Batched forward only supported for Transformer models", )), } @@ -350,7 +392,8 @@ impl AnyModel { | Self::Phi3(_) | Self::Starcoder2(_) | Self::LlavaQwen2(_) - | Self::DeepSeekV2(_) => None, + | Self::DeepSeekV2(_) + | Self::BonsaiQ1(_) => None, } } @@ -371,7 +414,8 @@ impl AnyModel { | Self::Phi3(_) | Self::Starcoder2(_) | Self::LlavaQwen2(_) - | Self::DeepSeekV2(_) => Err(Exception::custom("MTP not supported for this model")), + | Self::DeepSeekV2(_) + | Self::BonsaiQ1(_) => Err(Exception::custom("MTP not supported for this model")), } } @@ -390,7 +434,8 @@ impl AnyModel { | Self::Phi3(_) | Self::Starcoder2(_) | Self::LlavaQwen2(_) - | Self::DeepSeekV2(_) => Err(Exception::custom("MTP not supported for this model")), + | Self::DeepSeekV2(_) + | Self::BonsaiQ1(_) => Err(Exception::custom("MTP not supported for this model")), } } @@ -412,7 +457,7 @@ impl AnyModel { } /// The model's hidden dimension. 
-    pub const fn hidden_size(&self) -> i32 {
+    pub fn hidden_size(&self) -> i32 {
         match self {
             Self::Transformer(m) => m.args.hidden_size,
             Self::Qwen3Moe(m) => m.args.hidden_size,
@@ -422,6 +467,7 @@
             Self::Starcoder2(m) => m.args.hidden_size,
             Self::LlavaQwen2(m) => m.hidden_size(),
             Self::DeepSeekV2(m) => m.args.hidden_size,
+            Self::BonsaiQ1(m) => i32::try_from(m.config.hidden).unwrap_or(i32::MAX),
         }
     }
 
@@ -458,6 +504,10 @@
                 m.args.num_key_value_heads,
                 m.args.qk_nope_head_dim + m.args.qk_rope_head_dim,
             )),
+            Self::BonsaiQ1(m) => Ok((
+                i32::try_from(m.config.kv_heads).map_err(|e| Exception::custom(e.to_string()))?,
+                i32::try_from(m.config.head_dim).map_err(|e| Exception::custom(e.to_string()))?,
+            )),
         }
     }
 
@@ -466,6 +516,7 @@
         self.make_cache_with_config(KvCacheConfig::default())
     }
 
+    #[allow(clippy::too_many_lines)]
     pub fn make_cache_with_config(
         &self,
         kv_cache_config: KvCacheConfig,
@@ -560,6 +611,16 @@
                     Ok(AnyCache::Hybrid(m.make_cache()))
                 }
             }
+            Self::BonsaiQ1(m) => {
+                if kv_cache_config.is_turboquant() {
+                    return Err(Exception::custom(
+                        "TurboQuant is not supported for BonsaiQ1 (1-bit packed engine)",
+                    ));
+                }
+                let layers =
+                    i32::try_from(m.config.layers).map_err(|e| Exception::custom(e.to_string()))?;
+                Ok(make_kv_cache(layers))
+            }
         }
     }
 
@@ -578,7 +639,8 @@
             | Self::Gemma2(_)
             | Self::Phi3(_)
             | Self::Starcoder2(_)
-            | Self::DeepSeekV2(_) => None,
+            | Self::DeepSeekV2(_)
+            | Self::BonsaiQ1(_) => None,
         }
     }
 
@@ -1189,6 +1251,7 @@ fn remap_quantized_key(key: &str) -> Option<String> {
 #[allow(clippy::panic, clippy::unwrap_used, clippy::indexing_slicing)]
 mod tests {
     use super::*;
+    use crate::cache::KeyValueCache;
 
     fn params(temp: f32, top_p: f32) -> SamplingParams {
         SamplingParams {
@@ -1695,4 +1758,58 @@
         assert!((vals[1] - 2.0).abs() < 1e-5);
         assert!((vals[2] - 4.5).abs() < 1e-5);
     }
+
+    // --- AnyCache::trim_by tests ---
+
+    #[test]
+    fn any_cache_trim_by_kv_dispatches_to_each_layer() {
+        // Two KV layers, both at offset 0; trim_by saturates to 0.
+        // Verifies the dispatcher iterates None and Some(_) layers without panic.
+        let mut cache = AnyCache::KV(vec![
+            Some(cache::SteppingKeyValueCache::new()),
+            None,
+            Some(cache::SteppingKeyValueCache::new()),
+        ]);
+        cache.trim_by(5);
+        if let AnyCache::KV(layers) = &cache {
+            assert_eq!(layers.len(), 3);
+            for layer in layers.iter().flatten() {
+                assert_eq!(layer.offset(), 0);
+            }
+        } else {
+            panic!("expected KV variant");
+        }
+    }
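A sketch test (not part of the patch) tying `trim_by` to the call pattern it exists for: after a verify pass accepts `accepted` of `draft_len` proposed tokens, the rejected suffix is dropped from every KV layer before the next target step:

    #[test]
    fn trim_by_rolls_back_rejected_draft_tokens_sketch() {
        let mut cache = AnyCache::KV(vec![Some(cache::SteppingKeyValueCache::new())]);
        let (draft_len, accepted) = (4_usize, 1_usize);
        // Hypothetical spec-decode call site; accepted <= draft_len by construction.
        cache.trim_by(draft_len - accepted);
        if let AnyCache::KV(layers) = &cache {
            // Fresh cache starts at offset 0, so the trim saturates there.
            assert_eq!(layers.iter().flatten().map(|l| l.offset()).next(), Some(0));
        } else {
            panic!("expected KV variant");
        }
    }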
+        let mut arrays = qwen3_next::ArraysCache::new();
+        arrays.offset = 7;
+        let mut cache = AnyCache::Hybrid(vec![
+            Some(LayerCache::KV(cache::SteppingKeyValueCache::new())),
+            Some(LayerCache::Arrays(arrays)),
+            None,
+        ]);
+        cache.trim_by(3);
+        if let AnyCache::Hybrid(layers) = &cache {
+            assert_eq!(layers.len(), 3);
+            // KV layer trimmed (saturated at 0 since starting offset was 0)
+            if let Some(LayerCache::KV(kv)) = layers.first().and_then(|l| l.as_ref()) {
+                assert_eq!(kv.offset(), 0);
+            } else {
+                panic!("expected first layer to be KV variant");
+            }
+            // Arrays layer offset unchanged (recurrent state, can't trim by offset)
+            if let Some(LayerCache::Arrays(a)) = layers.get(1).and_then(|l| l.as_ref()) {
+                assert_eq!(a.offset, 7, "Arrays layer offset must NOT be trimmed");
+            } else {
+                panic!("expected second layer to be Arrays variant");
+            }
+        } else {
+            panic!("expected Hybrid variant");
+        }
+    }
 }
diff --git a/crates/higgs-models/src/qwen3_next.rs b/crates/higgs-models/src/qwen3_next.rs
index a23d4543..3b0b1723 100644
--- a/crates/higgs-models/src/qwen3_next.rs
+++ b/crates/higgs-models/src/qwen3_next.rs
@@ -1015,6 +1015,7 @@ static DECODE_GEMV_ENABLED: OnceLock<bool> = OnceLock::new();
 static QGEMV_NSG_OVERRIDE: OnceLock<Option<i32>> = OnceLock::new();
 static DENSE_FFN_GEMV_MODE: OnceLock<DenseFfnGemvMode> = OnceLock::new();
 static DENSE_FFN_FUSE_GATE_UP: OnceLock<bool> = OnceLock::new();
+static MOE_FFN_FUSE_GATE_UP: OnceLock<bool> = OnceLock::new();

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum DenseFfnGemvMode {
@@ -1087,6 +1088,10 @@ fn dense_ffn_fuse_gate_up() -> bool {
     })
 }

+fn moe_ffn_fuse_gate_up() -> bool {
+    *MOE_FFN_FUSE_GATE_UP.get_or_init(|| truthy_env_var("HIGGS_MOE_FFN_GATE_UP"))
+}
+
 fn qgemv_config_cache_enabled() -> bool {
     *QGEMV_CONFIG_CACHE_ENABLED.get_or_init(|| truthy_env_var("HIGGS_CACHE_QGEMV_CONFIGS"))
 }
@@ -1696,6 +1701,8 @@ pub(crate) struct SwitchMlpWeights {
     up_proj: QLinear,
     #[param]
     down_proj: QLinear,
+    /// Lazily fused gate+up weights for `MoE` `gather_qmm` (3→2 calls per layer).
+    fused_gate_up: Option<(Array, Array, Array, i32)>,
 }

 impl SwitchMlpWeights {
@@ -1705,6 +1712,7 @@ impl SwitchMlpWeights {
             gate_proj,
             up_proj,
             down_proj,
+            fused_gate_up: None,
         })
     }

@@ -1867,6 +1875,106 @@ impl SwitchMlpWeights {
         // Reshape back to [B, L, top_k, D]
         out_unsorted.reshape(&[b, l, top_k, d])
     }
+
+    /// Like `forward_gather_global_sort` but fuses gate+up into a single
+    /// `gather_qmm` call (3→2 per layer). Lazy-inits fused weights on first call.
+    /// Production routing gates this behind `HIGGS_MOE_FFN_GATE_UP` because the
+    /// fused cache duplicates the resident gate/up tensors.
+    pub(crate) fn forward_gather_fused(
+        &mut self,
+        x: &Array,
+        indices: &Array,
+    ) -> Result<Array, Exception> {
+        // Lazy-init: concatenate gate+up weights along axis 1 (intermediate dim).
+        // MoE weights are [num_experts, intermediate_packed, hidden].
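+        // Worked example, using the sizes from the parity test below (E = 8
+        // experts, I = 64 intermediate, Dp = packed hidden width): gate and up
+        // weights are each [8, 64, Dp], so concatenating on axis 1 yields
+        // [8, 128, Dp]; one gather_qmm then emits [N, 1, 128] per sorted token
+        // and split_axis(&[64]) recovers the gate and up halves.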
+ if self.fused_gate_up.is_none() { + let intermediate = *self + .gate_proj + .weight + .shape() + .get(1) + .ok_or_else(|| Exception::custom("gate_proj weight missing dim 1"))?; + let fw = ops::concatenate_axis(&[&*self.gate_proj.weight, &*self.up_proj.weight], 1)?; + let fs = ops::concatenate_axis(&[&*self.gate_proj.scales, &*self.up_proj.scales], 1)?; + let fb = ops::concatenate_axis(&[&*self.gate_proj.biases, &*self.up_proj.biases], 1)?; + fw.eval()?; + fs.eval()?; + fb.eval()?; + self.fused_gate_up = Some((fw, fs, fb, intermediate)); + } + let (fw, fs, fb, intermediate) = self + .fused_gate_up + .as_ref() + .ok_or_else(|| Exception::custom("fused_gate_up missing after init"))?; + + // --- Global sort (same as forward_gather_global_sort) --- + let x_shape = x.shape(); + let err = || Exception::custom("forward_gather_fused input must be [B, L, D]"); + let b = *x_shape.first().ok_or_else(err)?; + let l = *x_shape.get(1).ok_or_else(err)?; + let d = *x_shape.get(2).ok_or_else(err)?; + let top_k = *indices + .shape() + .last() + .ok_or_else(|| Exception::custom("indices must have last dim"))?; + + let idx_flat = indices.flatten(None, None)?; + let order = ops::argsort_axis(&idx_flat, 0)?; + let inv_order = ops::argsort_axis(&order, 0)?; + + let top_k_u32 = + u32::try_from(top_k).map_err(|_| Exception::custom("top_k must fit in u32"))?; + let top_k_arr = Array::from_slice(&[top_k_u32], &[1]); + let token_idx = order.floor_divide(&top_k_arr)?; + + let x_flat = x.reshape(&[b * l, 1, d])?; + let x_sorted = x_flat.take_axis(&token_idx, 0)?; + let idx_sorted = idx_flat.take_axis(&order, 0)?; + + // --- Fused gate+up: ONE gather_qmm instead of TWO --- + let fused_out = gather_qmm( + &x_sorted, + fw, + fs, + fb, + &idx_sorted, + true, + self.gate_proj.group_size, + self.gate_proj.bits, + true, + )?; + // Split at intermediate boundary → gate_out, up_out + let parts = fused_out.split_axis(&[*intermediate], Some(-1))?; + let gate_out = parts + .first() + .ok_or_else(|| Exception::custom("fused split failed"))?; + let up_out = parts + .get(1) + .ok_or_else(|| Exception::custom("fused split failed"))?; + let activated = swiglu(gate_out, up_out)?; + + // --- down_proj: unchanged --- + let down_out = gather_qmm( + &activated, + &self.down_proj.weight, + &self.down_proj.scales, + &self.down_proj.biases, + &idx_sorted, + true, + self.down_proj.group_size, + self.down_proj.bits, + true, + )?; + + // down_out: [N, 1, D] -> squeeze M -> [N, D] + let out_flat = down_out.squeeze_axes(&[-2])?; + + // --- Unsort: restore original token order --- + let out_unsorted = out_flat.take_axis(&inv_order, 0)?; + + // Reshape back to [B, L, top_k, D] + out_unsorted.reshape(&[b, l, top_k, d]) + } } // --------------------------------------------------------------------------- @@ -2667,14 +2775,6 @@ impl FfnBlock { .gate .as_ref() .ok_or_else(|| Exception::custom("MoE gate missing"))?; - let switch_ref = self - .switch_mlp - .as_ref() - .ok_or_else(|| Exception::custom("MoE switch_mlp missing"))?; - let se_ref = self - .shared_expert - .as_ref() - .ok_or_else(|| Exception::custom("MoE shared_expert missing"))?; let seg_ref = self .shared_expert_gate .as_ref() @@ -2698,13 +2798,26 @@ impl FfnBlock { raw_scores }; - let y = switch_ref.forward_gather_global_sort(x, &inds)?; + let switch_ref = self + .switch_mlp + .as_mut() + .ok_or_else(|| Exception::custom("MoE switch_mlp missing"))?; + let y = if moe_ffn_fuse_gate_up() { + switch_ref.forward_gather_fused(x, &inds)? + } else { + switch_ref.forward_gather_global_sort(x, &inds)? 
+        };
         let expert_sum = y
             .multiply(&scores.expand_dims(-1)?)?
             .sum_axes(&[-2], false)?;

+        let se_ref = self
+            .shared_expert
+            .as_ref()
+            .ok_or_else(|| Exception::custom("MoE shared_expert missing"))?;
         let shared_y = se_ref.forward(x)?;
+
         let shared_gate_val = nn::sigmoid(&seg_ref.forward(x)?)?;
         let shared_out = shared_y.multiply(&shared_gate_val)?;
@@ -3722,14 +3835,29 @@ fn load_qwen3_5_moe_text_config_args<P: AsRef<Path>>(
             .or_insert(serde_json::Value::from(0));
     }

-    // When HIGGS_SEPARATE_GDN_PROJ is set, construct the model with separate
-    // GDN projection fields so the direct weight loader can match them.
-    // Otherwise, construct with fused fields (weights are rearranged at load time).
-    let use_separate = std::env::var("HIGGS_SEPARATE_GDN_PROJ").is_ok();
+    // When HIGGS_SEPARATE_GDN_PROJ is set, or when per-layer GDN BA quantization
+    // disagrees on bit-width / group_size between in_proj_a and in_proj_b (common
+    // in Unsloth dynamic quants), construct the model with separate GDN
+    // projection fields so the direct weight loader can match them. Otherwise,
+    // construct with fused fields (weights are rearranged at load time).
+    let mixed_ba_layers = qwen3_5_mixed_ba_quantization_layers(&config, text_config);
+    let config_requests_separate = map
+        .get("use_separate_gdn_projections")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(false);
+    let use_separate = config_requests_separate
+        || std::env::var("HIGGS_SEPARATE_GDN_PROJ").is_ok()
+        || !mixed_ba_layers.is_empty();
     map.insert(
         "use_separate_gdn_projections".to_owned(),
         serde_json::Value::from(use_separate),
     );
+    if !mixed_ba_layers.is_empty() {
+        tracing::info!(
+            layers = ?mixed_ba_layers,
+            "Detected mixed-bit GDN BA projections; using separate GDN projections"
+        );
+    }

     // Detect per-layer gate quantization override from top-level quantization config
     if let Some(gate_q) = gate_quantization_override(&config) {
@@ -3739,6 +3867,59 @@
     Ok(serde_json::from_value(obj)?)
 }

+/// Parse a `{group_size, bits}` quantization spec from a JSON node.
+fn qwen3_5_quantization_config(value: &serde_json::Value) -> Option<QuantizationConfig> {
+    Some(QuantizationConfig {
+        group_size: i32::try_from(value.get("group_size")?.as_i64()?).ok()?,
+        bits: i32::try_from(value.get("bits")?.as_i64()?).ok()?,
+    })
+}
+
+/// Scan the per-layer `quantization` map and return layer indices where the GDN
+/// `in_proj_a` and `in_proj_b` projections disagree on bit-width or group size.
+/// Such layers cannot be fused into a single `in_proj_ba` matrix without
+/// dequantizing, so the loader must fall back to separate GDN projections.
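+///
+/// Example, mirroring the unit tests below: with a top-level default of
+/// `{"group_size": 64, "bits": 2}` and an override
+/// `"model.layers.1.linear_attn.in_proj_a": {"group_size": 64, "bits": 5}`
+/// but no matching `in_proj_b` override, layer 1 is reported as mixed.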
+fn qwen3_5_mixed_ba_quantization_layers(
+    config: &serde_json::Value,
+    text_config: &serde_json::Value,
+) -> Vec<i32> {
+    let Some(quant) = config.get("quantization") else {
+        return Vec::new();
+    };
+    let Some(default_quant) = qwen3_5_quantization_config(quant) else {
+        return Vec::new();
+    };
+    let Some(num_hidden_layers) = text_config
+        .get("num_hidden_layers")
+        .and_then(serde_json::Value::as_i64)
+        .and_then(|n| i32::try_from(n).ok())
+    else {
+        return Vec::new();
+    };
+
+    (0..num_hidden_layers)
+        .filter(|layer_idx| {
+            let prefixes = [
+                format!("language_model.model.layers.{layer_idx}.linear_attn"),
+                format!("model.layers.{layer_idx}.linear_attn"),
+            ];
+            let projection_quantization = |projection: &str| {
+                prefixes
+                    .iter()
+                    .find_map(|prefix| {
+                        quant
+                            .get(format!("{prefix}.{projection}"))
+                            .and_then(qwen3_5_quantization_config)
+                    })
+                    .unwrap_or_else(|| default_quant.clone())
+            };
+            let a_quant = projection_quantization("in_proj_a");
+            let b_quant = projection_quantization("in_proj_b");
+            a_quant.bits != b_quant.bits || a_quant.group_size != b_quant.group_size
+        })
+        .collect()
+}
+
 /// Load a Qwen3.5 dense model (VLM wrapper around `Qwen3Next` architecture).
 ///
 /// Reads `text_config` for model args, strips `language_model.` prefix from
@@ -3766,15 +3947,7 @@ pub fn load_qwen3_5_model<P: AsRef<Path>>(model_dir: P) -> Result<Qwen3NextCausalLM, ModelError> {
         head_v_dim: args.linear_value_head_dim,
     };
     gdn_dims.validate()?;
-    let mut model = Qwen3NextCausalLM::new(args.clone())?;
-
     // Load weights with GDN projection rearrangement: flat (qkv,z,b,a)
     // → per-head-grouped (qkvz,ba) for fused 2-dispatch forward path.
-    // Respect use_separate_gdn_projections config flag or HIGGS_SEPARATE_GDN_PROJ env var.
-    let use_separate =
+    // Respects use_separate_gdn_projections (set by HIGGS_SEPARATE_GDN_PROJ env
+    // var or mixed-bit BA detection in load_qwen3_5_moe_text_config_args), and
+    // falls back to separate projections at runtime if fusion finds a
+    // shape-incompatible BA pair.
+    let model = load_qwen3_5_model_with_gdn_fallback(model_path, args, &gdn_dims)?;
+
+    tracing::info!("Qwen3.5-MoE model loaded successfully");
+    Ok(model)
+}
+
+/// Build a `Qwen3NextCausalLM` and load weights, choosing fused or separate GDN
+/// projections. When the config (or env var) requests separate projections, use
+/// the direct loader. Otherwise try the fused loader; if it reports a mixed-bit
+/// `in_proj_ba` shape mismatch, rebuild the model with separate projections and
+/// retry via the direct loader.
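+///
+/// Decision sketch (mirrors the control flow below):
+/// - config or env var forces separate -> build separate, direct loader
+/// - fused load succeeds -> fused projections (2 dispatches per layer)
+/// - fused load hits the mixed-bit BA shape mismatch -> rebuild separate,
+///   retry via the direct loader
+/// - any other error -> propagate unchanged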
+fn load_qwen3_5_model_with_gdn_fallback(
+    model_path: &Path,
+    mut args: Qwen3NextModelArgs,
+    gdn_dims: &GdnDims,
+) -> Result<Qwen3NextCausalLM, ModelError> {
+    let force_separate = args.use_separate_gdn_projections
+        || std::env::var("HIGGS_SEPARATE_GDN_PROJ").is_ok();
-    if use_separate {
+    if force_separate {
+        args.use_separate_gdn_projections = true;
+        let mut model = Qwen3NextCausalLM::new(args)?;
         load_qwen3_5_moe_weights_direct(&mut model, model_path)?;
         tracing::info!("Using SEPARATE GDN projections (4 dispatches per layer)");
-    } else {
-        load_qwen3_5_moe_weights_fused(&mut model, model_path, &gdn_dims)?;
-        tracing::info!("Using FUSED GDN projections (2 dispatches per layer)");
+        return Ok(model);
     }
-    tracing::info!("Qwen3.5-MoE model loaded successfully");
-    Ok(model)
+    let mut fused_model = Qwen3NextCausalLM::new(args.clone())?;
+    match load_qwen3_5_moe_weights_fused(&mut fused_model, model_path, gdn_dims) {
+        Ok(()) => {
+            tracing::info!("Using FUSED GDN projections (2 dispatches per layer)");
+            Ok(fused_model)
+        }
+        Err(err) if is_mixed_bit_gdn_ba_fusion_error(&err) => {
+            tracing::warn!(
+                error = %err,
+                "Detected mixed-bit GDN BA projection shapes; retrying with separate GDN projections"
+            );
+            args.use_separate_gdn_projections = true;
+            let mut separate_model = Qwen3NextCausalLM::new(args)?;
+            load_qwen3_5_moe_weights_direct(&mut separate_model, model_path)?;
+            tracing::info!(
+                "Using SEPARATE GDN projections (4 dispatches per layer, mixed-bit fallback)"
+            );
+            Ok(separate_model)
+        }
+        Err(err) => Err(err),
+    }
+}
+
+/// Returns true when the supplied error is the mixed-bit BA fusion error raised
+/// by [`load_qwen3_5_moe_weights_fused`] when `in_proj_a` and `in_proj_b` have
+/// incompatible packed inner shapes.
+fn is_mixed_bit_gdn_ba_fusion_error(err: &ModelError) -> bool {
+    matches!(
+        err,
+        ModelError::ShapeMismatch(message)
+            if message.contains("in_proj_ba")
+                && message.contains("requires separate GDN projections")
+    )
+}

 /// GDN dimension info extracted from model args before move.
@@ -3918,6 +4139,25 @@ fn concat_and_permute(a: &Array, b: &Array, perm: &[i32]) -> Result<Array, Exception> {
+fn can_concatenate_axis0_shapes(a_shape: &[i32], b_shape: &[i32]) -> bool {
+    a_shape.len() == b_shape.len()
+        && a_shape
+            .iter()
+            .zip(b_shape.iter())
+            .enumerate()
+            .all(|(axis, (lhs, rhs))| axis == 0 || lhs == rhs)
+}
+
+fn can_concatenate_axis0(a: &Array, b: &Array) -> bool {
+    let a_shape = a.shape();
+    let b_shape = b.shape();
+    can_concatenate_axis0_shapes(a_shape, b_shape)
+}
+
 /// Load Qwen3.5-MoE weights with GDN projection fusion.
 ///
 /// Direct weight loader: strip `language_model.` prefix, no rearrangement.
@@ -4052,6 +4292,13 @@ fn load_qwen3_5_moe_weights_fused(
                 format!("Incomplete GDN projection pair for key: {combined_key}"),
             )));
         };
+        if combined_key.contains("in_proj_ba") && !can_concatenate_axis0(a, b) {
+            return Err(crate::error::ModelError::ShapeMismatch(format!(
+                "Mixed-bit BA fusion requires separate GDN projections for key {combined_key}: {:?} vs {:?}",
+                a.shape(),
+                b.shape()
+            )));
+        }
         let Some(param) = params.get_mut(combined_key.as_str()) else {
             return Err(crate::error::ModelError::Io(std::io::Error::other(
                 format!("Fused target key not found in model params: {combined_key}"),
@@ -4947,6 +5194,81 @@ mod tests {
         );
     }

+    #[test]
+    fn test_moe_gate_up_fusion_parity() {
+        // Fused gate+up (2 gather_qmm) must match unfused (3 gather_qmm).
+        // Uses random weights + distinct per-token inputs to stress sort/unsort.
+        let num_experts = 8;
+        let hidden = 128;
+        let intermediate = 64;
+        let top_k = 3;
+        let b = 1;
+        let l = 16;
+
+        let mut block = SwitchMlpWeights::new(64, 4).unwrap();
+
+        let gate_w = mlx_rs::random::uniform::<_, f32>(
+            -1.0,
+            1.0,
+            &[num_experts, intermediate, hidden],
+            None,
+        )
+        .unwrap();
+        let (gw, gs, gb) = quantize_weights(&gate_w, 64, 4);
+        *block.gate_proj.weight = gw;
+        *block.gate_proj.scales = gs;
+        *block.gate_proj.biases = gb;
+
+        let up_w = mlx_rs::random::uniform::<_, f32>(
+            -1.0,
+            1.0,
+            &[num_experts, intermediate, hidden],
+            None,
+        )
+        .unwrap();
+        let (uw, us, ub) = quantize_weights(&up_w, 64, 4);
+        *block.up_proj.weight = uw;
+        *block.up_proj.scales = us;
+        *block.up_proj.biases = ub;
+
+        let down_w = mlx_rs::random::uniform::<_, f32>(
+            -1.0,
+            1.0,
+            &[num_experts, hidden, intermediate],
+            None,
+        )
+        .unwrap();
+        let (dw, ds, db) = quantize_weights(&down_w, 64, 4);
+        *block.down_proj.weight = dw;
+        *block.down_proj.scales = ds;
+        *block.down_proj.biases = db;
+
+        let x = mlx_rs::random::uniform::<_, f32>(-1.0, 1.0, &[b, l, hidden], None).unwrap();
+        let idx_data: Vec<u32> = (0..(b * l * top_k) as u32)
+            .map(|i| i % num_experts as u32)
+            .collect();
+        let indices = Array::from_slice(&idx_data, &[b, l, top_k]);
+        x.eval().unwrap();
+        indices.eval().unwrap();
+
+        // Reference: unfused 3-call path
+        let reference = block.forward_gather_global_sort(&x, &indices).unwrap();
+        // Fused: 2-call path
+        let fused = block.forward_gather_fused(&x, &indices).unwrap();
+        reference.eval().unwrap();
+        fused.eval().unwrap();
+
+        assert_eq!(reference.shape(), fused.shape());
+        assert_eq!(fused.shape(), &[b, l, top_k, hidden]);
+
+        let diff = reference.subtract(&fused).unwrap().abs().unwrap();
+        let max_diff: f32 = diff.max(None).unwrap().item();
+        assert!(
+            max_diff < 1e-5,
+            "fused gate+up differs from unfused by {max_diff}"
+        );
+    }
+
     #[test]
     fn test_switch_mlp_forward_gather_shapes() {
         // Verify forward_gather produces the correct output shape with the
@@ -6345,6 +6667,7 @@ mod tests {
             gate_proj: make_switch_ql(d, d_inter),
             up_proj: make_switch_ql(d, d_inter),
             down_proj: make_switch_ql(d_inter, d),
+            fused_gate_up: None,
         },
         shared_expert: Qwen3NextMLP {
             gate_proj: make_ql(d, shared_inter * 2, gs, bits),
@@ -12825,6 +13148,132 @@ mod tests {
         assert_eq!(args.mtp_num_hidden_layers, 1);
     }

+    #[test]
+    fn test_load_qwen35_mixed_ba_quantization_forces_separate_gdn() {
+        let dir = tempfile::tempdir().unwrap();
+        let config = format!(
+            r#"{{
+    "text_config": {},
+    "tie_word_embeddings": false,
+    "quantization": {{
+        "group_size": 64,
+        "bits": 2,
+        "mode": "affine",
+        "language_model.model.layers.1.linear_attn.in_proj_a": {{
+            "group_size": 64,
+            "bits": 5,
+            "mode": "affine"
+        }}
+    }}
+}}"#,
+            qwen35_dense_text_config()
+        );
+        std::fs::write(dir.path().join("config.json"), config).unwrap();
+
+        let args = load_qwen3_5_moe_text_config_args(dir.path()).unwrap();
+
+        assert!(
+            args.use_separate_gdn_projections,
+            "mixed-bit in_proj_a/in_proj_b must force separate GDN projections"
+        );
+    }
+
+    #[test]
+    fn test_load_qwen35_mixed_ba_quantization_supports_unprefixed_layer_keys() {
+        let dir = tempfile::tempdir().unwrap();
+        let config = format!(
+            r#"{{
+    "text_config": {},
+    "tie_word_embeddings": false,
+    "quantization": {{
+        "group_size": 64,
+        "bits": 2,
+        "mode": "affine",
+        "model.layers.1.linear_attn.in_proj_a": {{
+            "group_size": 64,
+            "bits": 5,
+            "mode": "affine"
+        }}
+    }}
+}}"#,
+            qwen35_dense_text_config()
+        );
+        std::fs::write(dir.path().join("config.json"), config).unwrap();
+
+        let args =
load_qwen3_5_moe_text_config_args(dir.path()).unwrap(); + + assert!( + args.use_separate_gdn_projections, + "unprefixed mixed-bit in_proj_a/in_proj_b must force separate GDN projections" + ); + } + + #[test] + fn test_load_qwen35_matching_ba_quantization_keeps_fused_gdn() { + let dir = tempfile::tempdir().unwrap(); + let config = format!( + r#"{{ + "text_config": {}, + "tie_word_embeddings": false, + "quantization": {{ + "group_size": 64, + "bits": 2, + "mode": "affine", + "language_model.model.layers.1.linear_attn.in_proj_a": {{ + "group_size": 64, + "bits": 5, + "mode": "affine" + }}, + "language_model.model.layers.1.linear_attn.in_proj_b": {{ + "group_size": 64, + "bits": 5, + "mode": "affine" + }} + }} + }}"#, + qwen35_dense_text_config() + ); + std::fs::write(dir.path().join("config.json"), config).unwrap(); + + let args = load_qwen3_5_moe_text_config_args(dir.path()).unwrap(); + + assert!( + !args.use_separate_gdn_projections, + "matching BA overrides should keep the fused GDN loader path" + ); + } + + #[test] + fn test_load_qwen35_explicit_separate_gdn_config_is_preserved() { + let dir = tempfile::tempdir().unwrap(); + let mut text_config = qwen35_dense_text_config().trim_end_matches('}').to_owned(); + text_config.push_str( + r#", + "use_separate_gdn_projections": true + }"#, + ); + write_qwen35_config(dir.path(), &text_config); + + let args = load_qwen3_5_moe_text_config_args(dir.path()).unwrap(); + + assert!( + args.use_separate_gdn_projections, + "explicit use_separate_gdn_projections=true must not be overwritten" + ); + } + + #[test] + fn test_can_concatenate_axis0_detects_quantized_inner_shape_mismatch() { + assert!( + !can_concatenate_axis0_shapes(&[48, 320], &[48, 800]), + "different packed inner dims must block BA fusion" + ); + assert!( + can_concatenate_axis0_shapes(&[48, 320], &[96, 320]), + "axis-0 size may differ because fusion concatenates rows" + ); + } + /// GQA ratio: `num_v_heads` must be divisible by `num_k_heads`. /// This validates the assumption used in test/bench GDN recurrence loops. #[test] diff --git a/crates/higgs-models/src/yarn.rs b/crates/higgs-models/src/yarn.rs new file mode 100644 index 00000000..5081b8a0 --- /dev/null +++ b/crates/higgs-models/src/yarn.rs @@ -0,0 +1,223 @@ +//! YaRN RoPE helpers shared across models. +//! +//! Extracted verbatim from the original site in `deepseek_v2.rs`. The +//! `apply_yarn_rope` wrapper adds a `traditional` flag so Qwen3-family models +//! (Bonsai) can reuse the same freq precomputation with `traditional=false`, +//! while DeepSeek stays on `traditional=true`. + +#![allow(clippy::doc_markdown)] // YaRN, RoPE, etc. are domain terms, not items. 
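+
+// Numeric sketch of the scaling below: yarn_get_mscale(scale, mscale) evaluates
+// to 0.1 * mscale * ln(scale) + 1.0 for scale > 1 (and 1.0 otherwise), so e.g.
+// scale = 4.0 with mscale = 1.0 gives 1.0 + 0.1 * ln 4 ≈ 1.139, which lines up
+// with the "mscale ≈ 1.14" noted in apply_yarn_rope's dtype comment.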
+
+use std::f32::consts::PI;
+
+use mlx_rs::{Array, error::Exception, fast};
+
+#[allow(
+    clippy::as_conversions,
+    clippy::cast_precision_loss,
+    clippy::cast_possible_truncation
+)]
+fn yarn_find_correction_dim(num_rotations: f32, dim: i32, base: f32, max_pos: i32) -> f32 {
+    let dim_f = dim as f32;
+    let max_pos_f = max_pos as f32;
+    (dim_f * (max_pos_f / (num_rotations * 2.0 * PI)).ln()) / (2.0 * base.ln())
+}
+
+#[allow(
+    clippy::as_conversions,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss
+)]
+fn yarn_find_correction_range(
+    low_rot: f32,
+    high_rot: f32,
+    dim: i32,
+    base: f32,
+    max_pos: i32,
+) -> (i32, i32) {
+    let low = yarn_find_correction_dim(low_rot, dim, base, max_pos).floor() as i32;
+    let high = yarn_find_correction_dim(high_rot, dim, base, max_pos).ceil() as i32;
+    (low.max(0), high.min(dim - 1))
+}
+
+pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
+    if scale <= 1.0 {
+        1.0
+    } else {
+        (0.1 * mscale).mul_add(scale.ln(), 1.0)
+    }
+}
+
+#[allow(
+    clippy::as_conversions,
+    clippy::cast_precision_loss,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss,
+    clippy::indexing_slicing
+)]
+pub(crate) fn compute_yarn_freqs(
+    dim: i32,
+    base: f32,
+    scaling_factor: f32,
+    orig_max_pos: i32,
+    beta_fast: f32,
+    beta_slow: f32,
+) -> Array {
+    let half_dim = dim / 2;
+    let dim_f = dim as f32;
+
+    let mut freq_extra = Vec::with_capacity(half_dim as usize);
+    let mut freq_inter = Vec::with_capacity(half_dim as usize);
+    for i in 0..half_dim {
+        let exp = (2 * i) as f32 / dim_f;
+        let theta = base.powf(exp);
+        freq_extra.push(theta);
+        freq_inter.push(scaling_factor * theta);
+    }
+
+    let (low, high) = yarn_find_correction_range(beta_fast, beta_slow, dim, base, orig_max_pos);
+
+    let low_f = low as f32;
+    let high_f = high as f32;
+    let range = if (high_f - low_f).abs() < 0.001 {
+        high_f - low_f + 0.001
+    } else {
+        high_f - low_f
+    };
+
+    let mut freqs = Vec::with_capacity(half_dim as usize);
+    for i in 0..half_dim as usize {
+        let idx_f = i as f32;
+        let ramp = ((idx_f - low_f) / range).clamp(0.0, 1.0);
+        let mask = 1.0 - ramp;
+        let inter = freq_inter[i];
+        let extra = freq_extra[i];
+        let denom = inter * mask + extra * (1.0 - mask);
+        freqs.push((inter * extra) / denom);
+    }
+
+    Array::from_slice(&freqs, &[half_dim])
+}
+
+/// Apply YaRN-scaled RoPE.
+///
+/// When `mscale != 1.0`, inputs are pre-scaled before rotation (matches the
+/// DeepSeek reference). `traditional=false` matches the Qwen3 / LLaMA rope
+/// layout; `traditional=true` matches DeepSeek's packed complex layout.
+///
+/// `offset` is a scalar `Array` (not an `i32`) so the value is not baked into
+/// compiled traces — required for `compile_with_state` wrapping of decode.
+pub(crate) fn apply_yarn_rope(
+    x: &Array,
+    dim: i32,
+    base: f32,
+    yarn_freqs: Option<&Array>,
+    mscale: f32,
+    offset: &Array,
+    traditional: bool,
+) -> Result<Array, Exception> {
+    let x_scaled = if (mscale - 1.0).abs() > f32::EPSILON {
+        // Match x's dtype to avoid silent upcast (fp16 → f32) that bleeds into
+        // the entire attention path (rope, sdpa, o_proj inputs). For Bonsai
+        // with rope_yarn_factor>1, mscale ≈ 1.14, so this branch fires every
+        // rope call; without the cast the whole decode runs in f32 and pays
+        // ~28 ms/step on 8B. See bisect_decode v6 vs v7.
+        let scalar = Array::from_f32(mscale).as_dtype(x.dtype())?;
+        x.multiply(&scalar)?
+    } else {
+        x.clone()
+    };
+    yarn_freqs.map_or_else(
+        || {
+            fast::rope_dynamic(
+                &x_scaled,
+                dim,
+                traditional,
+                base,
+                1.0,
+                offset,
+                None::<&Array>,
+            )
+        },
+        |freqs| {
+            fast::rope_dynamic(
+                &x_scaled,
+                dim,
+                traditional,
+                None::<f32>,
+                1.0,
+                offset,
+                Some(freqs),
+            )
+        },
+    )
+}
+
+#[cfg(test)]
+#[allow(
+    clippy::panic,
+    clippy::unwrap_used,
+    clippy::expect_used,
+    clippy::indexing_slicing,
+    clippy::as_conversions,
+    clippy::cast_precision_loss,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss,
+    clippy::cast_lossless
+)]
+mod tests {
+    use super::*;
+    use mlx_rs::random;
+
+    /// Parity: dynamic-offset rope (via `apply_yarn_rope`) must match
+    /// static-offset `fast::rope` for every offset in 0..64, both with and
+    /// without precomputed YaRN freqs. Guards the B1 step 1 migration to
+    /// `fast::rope_dynamic` (prerequisite for `compile_with_state`-wrapped
+    /// decode — the static `offset: i32` was being baked into the compile
+    /// trace, forcing a recompile every step).
+    #[test]
+    #[ignore = "passes targeted (`cargo test yarn::`) but fails when run after other MLX tests in the same process — global Metal/RNG state contamination, pre-existing harness limitation"]
+    fn rope_dynamic_matches_static_offset_0_to_64() {
+        random::seed(71).unwrap();
+        // [B=2, H=4, T=1, head_dim=16] — matches decode shape (T=1).
+        let head_dim: i32 = 16;
+        let base: f32 = 10_000.0;
+        let x = random::uniform::<_, f32>(0.0, 1.0, &[2, 4, 1, head_dim], None).unwrap();
+
+        // Case A: no yarn_freqs (base path).
+        for offset in 0_i32..64 {
+            let off_arr = Array::from_int(offset);
+            let got = apply_yarn_rope(&x, head_dim, base, None, 1.0, &off_arr, false).unwrap();
+            let want = fast::rope(&x, head_dim, false, base, 1.0, offset, None::<&Array>).unwrap();
+            let diff = (&got - &want)
+                .abs()
+                .unwrap()
+                .max(None)
+                .unwrap()
+                .item::<f32>();
+            assert!(
+                diff < 1e-5,
+                "offset={offset} no-freqs: max_diff={diff} >= 1e-5"
+            );
+        }
+
+        // Case B: with precomputed yarn_freqs (Bonsai path).
+        let freqs = compute_yarn_freqs(head_dim, base, 1.0, 2048, 32.0, 1.0);
+        for offset in 0_i32..64 {
+            let off_arr = Array::from_int(offset);
+            let got =
+                apply_yarn_rope(&x, head_dim, base, Some(&freqs), 1.0, &off_arr, false).unwrap();
+            let want =
+                fast::rope(&x, head_dim, false, None::<f32>, 1.0, offset, Some(&freqs)).unwrap();
+            let diff = (&got - &want)
+                .abs()
+                .unwrap()
+                .max(None)
+                .unwrap()
+                .item::<f32>();
+            assert!(
+                diff < 1e-5,
+                "offset={offset} with-freqs: max_diff={diff} >= 1e-5"
+            );
+        }
+    }
+}
diff --git a/docs/BONSAI_Q1.md b/docs/BONSAI_Q1.md
new file mode 100644
index 00000000..7211d7d8
--- /dev/null
+++ b/docs/BONSAI_Q1.md
@@ -0,0 +1,18 @@
+# Bonsai-Q1
+
+Bonsai-Q1 checkpoints are Qwen3-shaped models with MLX 1-bit affine
+quantization metadata:
+
+- `model_type = "qwen3"`
+- `quantization.bits = 1`
+- `quantization.group_size = 128`
+
+The Higgs workspace stays on the pinned upstream `oxideai/mlx-rs` dependency.
+That upstream revision does not yet include the MLX bits=1 affine Metal kernels,
+so `higgs-engine` detects Bonsai-Q1 configs and returns an explicit unsupported
+model error instead of routing them into the regular Qwen3 transformer loader.
+
+The packed loader and engine live in `crates/higgs-models/src/bonsai_q1.rs` so
+the Rust-side code can be reviewed independently. Runtime enablement should wait
+until bits=1 affine quantization support lands upstream in the MLX dependency
+chain.
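+
+A minimal sketch of the `config.json` metadata the detection keys on (only the
+fields listed above are shown; a real checkpoint carries many more):
+
+```json
+{
+  "model_type": "qwen3",
+  "quantization": {
+    "bits": 1,
+    "group_size": 128
+  }
+}
+```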