dusterbloom: feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#17
Open
dusterbloom wants to merge 4 commits intomainfrom
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.
Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.
Adaptations from feat/magic-canvas → origin/main:
* `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
on origin/main (PR panbanda#143). Two call sites converted with
`unsigned_abs().try_into().unwrap_or(usize::MAX)`.
* Workspace clippy lints (`as_conversions`,
`cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
`explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
`redundant_pattern_matching`, `missing_const_for_fn`): all 30
errors fixed in place, using `i32::try_from` for tensor-shape casts,
`clone_from` for in-place clones, `filter_map(Option::as_mut)` in place of
`iter().filter_map(if-let)` patterns, and backticks on doc items.
No file-level allows.
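The `rollback(i32)` to `trim_by(usize)` conversion quoted in the first bullet can be looked at in isolation. The following is a minimal sketch that uses only the expression from that bullet; the helper name `rollback_to_trim` is hypothetical and does not appear in the PR.

```rust
use std::convert::TryInto;

/// Hypothetical helper (not in the PR): converts a signed rollback count
/// into the unsigned count that a `trim_by(usize)`-style API expects.
fn rollback_to_trim(rollback: i32) -> usize {
    // `unsigned_abs` returns the magnitude as `u32` and cannot overflow,
    // even for `i32::MIN`. `try_into` converts to `usize`, and `unwrap_or`
    // saturates on targets where `usize` is narrower than `u32`.
    rollback.unsigned_abs().try_into().unwrap_or(usize::MAX)
}
```

A negative input is mapped to its magnitude rather than rejected, so a caller that wants "negative means nothing to trim" would need its own check before converting.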
The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.
Verification on origin/main:
* `cargo check -p higgs-models` — clean
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 330/330 pass
* `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.
Methods added on `Qwen3NextCausalLM`:
* `forward_with_taps` — forward returning logits AND vec of hidden
states at specified target layers; the drafter conditions on these.
* `forward_with_taps_stateless` — same, but does NOT mutate the
recurrent (GDN) state. Used during verify when state advancement is
handled separately.
* `forward_with_taps_tape` — forward that records each GDN layer's
innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
rerun for partial-accept rollback.
* `replay_tape_rollback` — restore GDN state to a tape position
without re-running the full model.
* `embed_token_ids` — apply the embedding layer alone (drafter input).
* `forward_all_logits_from_hidden` — apply lm_head alone (target's
verification of drafter outputs).
* `project_logits` (private helper) — lm_head with origin/main's
available projection paths only (ANE + dense_lm_head fields don't
exist here yet; ported in PR-8).
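The tap idea behind `forward_with_taps` can be illustrated with a toy model. This is only a sketch under stated assumptions: layers are plain functions over `Vec<f32>`, there is no lm_head or cache, and every name here (`ToyLayer`, `forward_with_taps_toy`, `add_one`, `double`) is hypothetical rather than taken from `Qwen3NextCausalLM`.

```rust
/// Toy stand-in for a decoder layer: maps a hidden vector to a new one.
type ToyLayer = fn(&[f32]) -> Vec<f32>;

/// Runs `hidden` through `layers` in order and clones the hidden state
/// produced by each layer whose index appears in `tap_layers`.
/// Returns (final hidden state, tapped hidden states in layer order).
fn forward_with_taps_toy(
    layers: &[ToyLayer],
    mut hidden: Vec<f32>,
    tap_layers: &[usize],
) -> (Vec<f32>, Vec<Vec<f32>>) {
    let mut taps = Vec::with_capacity(tap_layers.len());
    for (idx, layer) in layers.iter().enumerate() {
        hidden = layer(&hidden);
        if tap_layers.contains(&idx) {
            // A drafter would condition on these intermediate states.
            taps.push(hidden.clone());
        }
    }
    (hidden, taps)
}

/// Two trivial layers used to exercise the sketch.
fn add_one(h: &[f32]) -> Vec<f32> {
    h.iter().map(|x| x + 1.0).collect()
}

fn double(h: &[f32]) -> Vec<f32> {
    h.iter().map(|x| x * 2.0).collect()
}
```

With layers `[add_one, double, add_one]`, input `[1.0]`, and taps at layers 0 and 2, the sketch returns a final state of `[5.0]` and taps `[[2.0], [5.0]]`.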
Methods added on `GatedDeltaNet`:
* `forward_stateless` — GDN forward without state mutation.
* `forward_with_tape` — GDN forward that captures the per-step
innovation into the tape.
* `replay_from_tape` — apply a tape to recompute SSM state to a target
position. Annotated `#[allow(dead_code)]` until the engine glue
drives it (next commit).
New public type `GdnLayerTape` exposes the per-layer innovation record.
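To make the tape-and-replay idea concrete, here is a toy scalar sketch. It assumes a simple recurrence of the form `s <- decay * s + innovation`; the real GDN state is a tensor updated by a Metal kernel, and `ToyTape`, `forward_with_tape_toy`, and `replay_from_tape_toy` are hypothetical names, not the PR's `GdnLayerTape` or its methods.

```rust
/// Toy per-step record: the decay applied to the state and the
/// innovation added to it. Hypothetical; not the PR's `GdnLayerTape`.
#[derive(Debug, Clone, Default)]
struct ToyTape {
    steps: Vec<(f32, f32)>, // (decay, innovation)
}

/// Scalar recurrence `s <- decay * s + innovation`, where the innovation
/// is `beta * (x - decay * s)`. Every step is recorded into `tape`.
fn forward_with_tape_toy(
    state: f32,
    inputs: &[f32],
    decay: f32,
    beta: f32,
    tape: &mut ToyTape,
) -> f32 {
    let mut s = state;
    for &x in inputs {
        let innovation = beta * (x - decay * s);
        tape.steps.push((decay, innovation));
        s = decay * s + innovation;
    }
    s
}

/// Rebuilds the state after the first `accepted` steps from the saved
/// pre-forward state and the tape, without touching the inputs again.
fn replay_from_tape_toy(saved: f32, tape: &ToyTape, accepted: usize) -> f32 {
    tape.steps
        .iter()
        .take(accepted)
        .fold(saved, |s, &(decay, innovation)| decay * s + innovation)
}
```

The point of the design is that replay only needs the cheap state update per accepted step, not the full layer computation that produced each innovation, which is consistent with the replay-versus-rerun trade-off described above.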
Metal kernel infrastructure ported alongside:
* `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
* `gated_delta_kernel_ffi_with_tape` + matching kernel
* `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
discards the new state, matches caller semantics in `forward_stateless`)
Adaptations from feat/magic-canvas → origin/main:
* `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.
* `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
don't exist on this branch — ANE-feature paths stripped to the
plain Metal/MLX path. Fields ported in PR-8.
* Error handler uses `thread_local! RefCell<Option<String>>` instead
of feat/magic-canvas's `Mutex<Option<String>>` — matches the
branch's existing FFI error pattern.
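The `thread_local! RefCell<Option<String>>` error slot mentioned in the last bullet follows a common standard-library pattern. The sketch below shows that pattern in general terms; `LAST_ERROR`, `record_error`, and `take_error` are hypothetical names, and the assumption that an FFI callback writes the message is illustrative only.

```rust
use std::cell::RefCell;

thread_local! {
    /// Last error message reported on this thread, if any.
    static LAST_ERROR: RefCell<Option<String>> = RefCell::new(None);
}

/// Would be called from a (hypothetical) FFI error callback.
fn record_error(msg: &str) {
    LAST_ERROR.with(|slot| *slot.borrow_mut() = Some(msg.to_owned()));
}

/// Takes the pending error, leaving the slot empty for the next call.
fn take_error() -> Option<String> {
    LAST_ERROR.with(|slot| slot.borrow_mut().take())
}
```

Compared with a `Mutex<Option<String>>`, a thread-local slot needs no locking and keeps one thread's error from being observed by another, on the assumption that the callback runs on the same thread as the call that triggered it.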
Senior-Rust hygiene:
* No file-level blanket allows added.
* Function-scoped `#[allow(...)]` on the four genuine numerical
kernel functions (`forward_stateless`, `forward_with_tape`,
`forward_with_taps_tape`, `replay_tape_rollback`), each with a
one-line justification comment.
* `unwrap_used` never allowed — refactored to `?` propagation or
`expect("reason")` at the two call sites.
* Mechanical clippy refactors throughout: `find_map` for
`filter_map(..).next()`, `clone_from` for `assigning_clones`,
`if let` for single-pattern `match`, backticks for `doc_markdown`.
Verification on origin/main:
* `cargo check -p higgs-models` — clean
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 330/330 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader
Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.
`AnyModel` (in `higgs-models/src/lib.rs`):
* `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
* `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
tap hiddens + per-layer GDN tape) via a public type alias to
placate `clippy::type_complexity`
* `embed_token_ids` — Qwen3Next-only
* `forward_all_logits_from_hidden` — Qwen3Next-only
All non-Qwen3Next arms enumerate every variant explicitly to satisfy
`clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).
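The dispatch style described here, with every unsupported variant named instead of a `_ =>` arm, looks roughly like the following. This is a sketch with a hypothetical `ToyModel` enum; the variant names other than `Qwen3Next` and `Hybrid` are invented for illustration, and the payload is a plain `u32` rather than a model.

```rust
/// Hypothetical stand-in for a polymorphic model enum.
enum ToyModel {
    Qwen3Next(u32),
    Hybrid(u32),
    Llama,
    Mistral,
}

impl ToyModel {
    /// Supported variants forward to their inner value. Unsupported
    /// variants are listed by name, so adding a new variant fails to
    /// compile here until someone decides how it should behave.
    fn forward_with_taps(&self) -> Result<u32, String> {
        match self {
            ToyModel::Qwen3Next(m) | ToyModel::Hybrid(m) => Ok(*m),
            ToyModel::Llama | ToyModel::Mistral => {
                Err("forward_with_taps is not supported for this model".to_owned())
            }
        }
    }
}
```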
`AnyCache`:
* `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
slice/vec for engine glue that needs to inspect GDN state. Returns
`Result<_, Exception>` rather than panicking when called on a `KV`
cache, so the verify path in `SimpleEngine::generate_dflash_inner`
can propagate via `?`.
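A fallible accessor of this shape lets callers use `?` instead of handling a panic. The sketch below assumes a hypothetical `ToyCache` enum whose hybrid variant holds one `f32` per layer; the real `AnyCache` returns `Result<_, Exception>` and holds layer caches, neither of which is modeled here.

```rust
/// Hypothetical stand-in for a cache enum; the KV payload is omitted.
enum ToyCache {
    Kv,
    Hybrid(Vec<f32>),
}

impl ToyCache {
    /// Borrows the hybrid layer states, or reports an error for a KV cache.
    fn as_hybrid(&self) -> Result<&[f32], String> {
        match self {
            ToyCache::Hybrid(layers) => Ok(layers.as_slice()),
            ToyCache::Kv => Err("expected a hybrid cache, found KV".to_owned()),
        }
    }

    /// Mutable counterpart of `as_hybrid`.
    fn as_hybrid_mut(&mut self) -> Result<&mut Vec<f32>, String> {
        match self {
            ToyCache::Hybrid(layers) => Ok(layers),
            ToyCache::Kv => Err("expected a hybrid cache, found KV".to_owned()),
        }
    }
}

/// Caller that propagates the wrong-cache error with `?`.
fn first_layer_state(cache: &ToyCache) -> Result<f32, String> {
    let layers = cache.as_hybrid()?;
    layers
        .first()
        .copied()
        .ok_or_else(|| "hybrid cache has no layers".to_owned())
}
```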
`dflash::accept_prefix`:
* Greedy speculative-decode acceptance: longest-matching prefix of
`draft` against `verify_argmax`, plus one bonus token at the
diverge point (or after the last accept).
* 5 unit tests covering full match, first-token reject, partial
match, empty draft, and the debug-only length assertion.
* Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
to avoid pulling in the 9970-line diffusion module for a 16-line
helper.
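The acceptance rule described above can be sketched in a few lines. The signature and return type below are assumptions, not the PR's actual `accept_prefix`: the sketch takes `verify_argmax` to hold the target model's argmax at each draft position plus one trailing entry (so its length is `draft.len() + 1`), and returns the accepted tokens followed by the bonus token.

```rust
/// Greedy acceptance sketch (hypothetical signature).
/// Assumes `verify_argmax.len() == draft.len() + 1`.
fn accept_prefix_sketch(draft: &[u32], verify_argmax: &[u32]) -> Vec<u32> {
    debug_assert_eq!(verify_argmax.len(), draft.len() + 1);
    // Count leading positions where the draft agrees with the target.
    let accepted = draft
        .iter()
        .zip(verify_argmax)
        .take_while(|(d, v)| d == v)
        .count();
    // Keep the agreed prefix, then append the target's own token at the
    // diverge point (or the position after the last accepted token).
    let mut out = draft[..accepted].to_vec();
    out.push(verify_argmax[accepted]);
    out
}
```

For a draft `[1, 2, 3]` verified as `[1, 2, 7, 4]`, the sketch accepts `[1, 2]` and appends the bonus token `7`. Every round therefore emits at least one token, even when the first draft token is rejected.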
`engine::model_loader::load_dflash_drafter`:
* Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
call site lands in the next commit.
Verification on origin/main:
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
* `cargo test -p higgs-engine --lib` — 228/228 pass
* `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match
CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:
* `dflash.rs` — `GdnStateBackup::restore_and_rollback`
* `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`
Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.
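The shape of that refactor, moving the inner `if` into a match guard and spelling out the no-op cases, is roughly as follows. This is a sketch with hypothetical types (`ToyKv`, `ToyLayerCache`, `rollback_layer`, `kv_len`); it does not reproduce the code in `dflash.rs` or `qwen3_next.rs`.

```rust
/// Hypothetical KV cache that only tracks its length.
struct ToyKv {
    len: usize,
}

impl ToyKv {
    fn trim_by(&mut self, n: usize) {
        self.len = self.len.saturating_sub(n);
    }
}

/// Hypothetical per-layer cache with a KV and a non-KV variant.
enum ToyLayerCache {
    Kv(ToyKv),
    Gdn,
}

/// Trims a KV layer when `rollback > 0`; all other cases are explicit no-ops.
fn rollback_layer(layer: Option<&mut ToyLayerCache>, rollback: usize) {
    match layer {
        // The former `if rollback > 0 { ... }` body, now a match guard.
        Some(ToyLayerCache::Kv(kv)) if rollback > 0 => kv.trim_by(rollback),
        // Guard failed, non-KV layer, or no layer: nothing to do.
        // Listing these keeps the match exhaustive without `_ =>`.
        Some(ToyLayerCache::Kv(_) | ToyLayerCache::Gdn) | None => {}
    }
}

/// Small accessor used to inspect the result.
fn kv_len(layer: &ToyLayerCache) -> Option<usize> {
    match layer {
        ToyLayerCache::Kv(kv) => Some(kv.len),
        ToyLayerCache::Gdn => None,
    }
}
```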
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Lands the DFlash block-diffusion speculative-decoding foundation ported from `feat/magic-canvas`. This is PR-6a of the magic-canvas split: the module + model surgery + dispatch surface. The engine-level draft-verify loop (`SimpleEngine::generate_dflash_inner`) is deferred to PR-6b because it needs end-to-end runtime verification against a real DFlash drafter checkpoint that this PR can't exercise from CI.

What's in this PR
Three commits, all on a clean branch off `origin/main`:
* `ad7edea1`: `crates/higgs-models/src/dflash.rs`, the 0.5B drafter (config, dual-stream attention, GDN-state save/restore, KV-only rollback, drafter loader)
* `62a73522`: `forward_with_taps`, `forward_with_taps_stateless`, `forward_with_taps_tape`, `replay_tape_rollback`, `embed_token_ids`, `forward_all_logits_from_hidden`, `project_logits` on `Qwen3NextCausalLM`; `forward_stateless`, `forward_with_tape`, `replay_from_tape` on `GatedDeltaNet`; `pub struct GdnLayerTape`; Metal kernel FFI (`tape_replay_kernel_ffi`, `gated_delta_kernel_ffi_with_tape`, `gated_delta_kernel_ffi_stateless`)
* `1c3c0d37`: `AnyModel::{forward_with_taps, forward_with_taps_tape, embed_token_ids, forward_all_logits_from_hidden}` dispatchers; `AnyCache::{as_hybrid, as_hybrid_mut}`; `dflash::accept_prefix` (16-line greedy spec-decode acceptance + 5 unit tests); `engine::model_loader::load_dflash_drafter`
Total: ~2.4K lines net new, no public-API regressions on `origin/main`.

What's NOT in this PR (deferred)
* `SimpleEngine::generate_dflash_inner`: the draft-verify loop. The `feat/magic-canvas` glue assumed a cpu_engine + ANE executor that we're stripping (defer to PR-8 ANE work) and a struct shape that has since evolved on `origin/main`. Best done with a Carnice-9B + 0.5B drafter checkpoint loaded so we can verify the verify-path correctness, not just the compile.
* `dflash_cpu.rs` (CPU BLAS drafter): depends on 7 BLAS helpers from `diffusion.rs` (~9970 lines, mostly unrelated). Defer to PR-6b together with the engine glue.
* `dflash_ane.rs` (ANE-accelerated drafter): feature-gated; PR-8 territory.
* DFlash test suite (from `feat/magic-canvas`): depends on the engine glue being live. Will follow.

Adaptations from `feat/magic-canvas` → `origin/main`
* `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)` on origin/main (PR panbanda/higgs#143, "feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback"). Call sites in `dflash.rs` and `replay_tape_rollback` converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.
* `Qwen3NextCausalLM`'s `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields don't exist on this branch. ANE-feature paths in `project_logits` and `forward_with_tape` stripped to plain Metal/MLX. Fields ported in PR-8.
* Error handler uses `thread_local! RefCell<Option<String>>` (matching this branch's existing FFI pattern) rather than `feat/magic-canvas`'s `Mutex<Option<String>>`.

Senior-Rust hygiene
* `#![allow(clippy::items_after_test_module)]` preserved.
* Function-scoped `#[allow(...)]` only on four genuinely-numerical kernel functions (`forward_stateless`, `forward_with_tape`, `forward_with_taps_tape`, `replay_tape_rollback`), each with a one-line documented justification (Metal-kernel dispatch, tensor-shape indices, hot-path casts).
* `unwrap_used` never allowed; refactored to `?` propagation throughout.
* Non-Qwen3Next arms on `AnyModel` enumerate variants explicitly (no `_ =>` catch-alls).
* `clippy::type_complexity` resolved with `pub type TapsTapeOutput`.

Test plan
* `cargo check -p higgs-models`: clean
* `cargo clippy --all-targets --all-features -- -D warnings`: clean
* `cargo fmt --check`: clean
* `cargo test -p higgs-models --lib`: 335/335 pass (5 new `accept_prefix` tests)
* `cargo test -p higgs-engine --lib`: 228/228 pass
* `cargo test -p higgs --lib -- --test-threads=1`: 449/449 pass

Context
This is part of the `feat/magic-canvas` PR split. PRs already shipped against panbanda upstream:

🤖 Generated with Claude Code