dusterbloom: feat(dflash): DFlash drafter foundation — module + tap APIs + dispatch#17

Open
dusterbloom wants to merge 4 commits into main from dusterbloom/dflash-baseline


Summary

Lands the DFlash block-diffusion speculative-decoding foundation ported from feat/magic-canvas. This is PR-6a of the magic-canvas split — the module + model surgery + dispatch surface; the engine-level draft-verify loop (SimpleEngine::generate_dflash_inner) is deferred to PR-6b because it needs end-to-end runtime verification against a real DFlash drafter checkpoint that this PR can't exercise from CI.

What's in this PR

Three feature commits, all on a clean branch off origin/main (the fourth commit on the branch is a clippy-only follow-up with no behaviour change):

| Commit | Adds | Net lines |
| --- | --- | --- |
| ad7edea1 | `crates/higgs-models/src/dflash.rs`: the 0.5B drafter (config, dual-stream attention, GDN-state save/restore, KV-only rollback, drafter loader) | +561 |
| 62a73522 | `forward_with_taps`, `forward_with_taps_stateless`, `forward_with_taps_tape`, `replay_tape_rollback`, `embed_token_ids`, `forward_all_logits_from_hidden`, `project_logits` on `Qwen3NextCausalLM`; `forward_stateless`, `forward_with_tape`, `replay_from_tape` on `GatedDeltaNet`; `pub struct GdnLayerTape`; Metal kernel FFI (`tape_replay_kernel_ffi`, `gated_delta_kernel_ffi_with_tape`, `gated_delta_kernel_ffi_stateless`) | +1657 |
| 1c3c0d37 | `AnyModel::{forward_with_taps, forward_with_taps_tape, embed_token_ids, forward_all_logits_from_hidden}` dispatchers; `AnyCache::{as_hybrid, as_hybrid_mut}`; `dflash::accept_prefix` (16-line greedy spec-decode acceptance + 5 unit tests); `engine::model_loader::load_dflash_drafter` | +193 |

Total: ~2.4K lines net new, no public-API regressions on origin/main.

What's NOT in this PR (deferred)

  • SimpleEngine::generate_dflash_inner — the draft-verify loop. The feat/magic-canvas glue assumed a cpu_engine + ANE executor that we're stripping (defer to PR-8 ANE work) and a struct shape that has since evolved on origin/main. Best done with a Carnice-9B + 0.5B drafter checkpoint loaded so we can verify the verify-path correctness, not just the compile.
  • dflash_cpu.rs (CPU BLAS drafter) — depends on 7 BLAS helpers from diffusion.rs (~9970 lines, mostly unrelated). Defer to PR-6b together with the engine glue.
  • dflash_ane.rs (ANE-accelerated drafter) — feature-gated; PR-8 territory.
  • DFlash test suite (~3.8K lines on feat/magic-canvas) — depends on the engine glue being live. Will follow.

Adaptations from feat/magic-canvas → origin/main

  • SteppingKeyValueCache::rollback(i32) was renamed trim_by(usize) on origin/main (panbanda/higgs#143, "feat(cache): AnyCache::trim_by dispatcher for spec-decode rollback"). Call sites in dflash.rs and replay_tape_rollback converted with unsigned_abs().try_into().unwrap_or(usize::MAX).
  • Qwen3NextCausalLM's lm_head_ane, dense_lm_head, ane_handle, ane_kernels fields don't exist on this branch — ANE-feature paths in project_logits and forward_with_tape stripped to plain Metal/MLX. Fields ported in PR-8.
  • FFI error handler uses thread_local! RefCell<Option<String>> (matching this branch's existing FFI pattern) rather than feat/magic-canvas's Mutex<Option<String>>.
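
The `rollback(i32)` to `trim_by(usize)` conversion from the first bullet can be sketched in isolation. This is a minimal illustration rather than the code in `dflash.rs`: the helper name `rollback_to_trim` is hypothetical, and only the `unsigned_abs().try_into().unwrap_or(usize::MAX)` chain is taken from this PR.

```rust
use std::convert::TryInto; // already in the 2021 prelude; explicit for older editions

/// Hypothetical helper: turn a signed rollback count into the unsigned
/// amount that a `trim_by(usize)`-style API expects.
fn rollback_to_trim(rollback: i32) -> usize {
    // `unsigned_abs` returns the `u32` magnitude and cannot overflow, even
    // for `i32::MIN`. `try_into` can only fail where `usize` is narrower
    // than 32 bits; in that case the count saturates to `usize::MAX`.
    rollback.unsigned_abs().try_into().unwrap_or(usize::MAX)
}
```

One consequence of this chain is that a negative rollback is mapped to its magnitude instead of being rejected, so callers that treat negative values as "no rollback" would need to check the sign first.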

Senior-Rust hygiene

  • No file-level blanket allows added. Origin/main's pre-existing #![allow(clippy::items_after_test_module)] preserved.
  • Function-scoped #[allow(...)] only on four genuinely-numerical kernel functions (forward_stateless, forward_with_tape, forward_with_taps_tape, replay_tape_rollback), each with a one-line documented justification (Metal-kernel dispatch, tensor-shape indices, hot-path casts).
  • unwrap_used never allowed; refactored to ? propagation, or expect("reason") at the two call sites where propagation does not apply.
  • All non-Qwen3Next match arms in AnyModel enumerate variants explicitly (no _ => catch-alls).
  • clippy::type_complexity resolved with pub type TapsTapeOutput.

Test plan

  • cargo check -p higgs-models — clean
  • cargo clippy --all-targets --all-features -- -D warnings — clean
  • cargo fmt --check — clean
  • cargo test -p higgs-models --lib — 335/335 pass (5 new accept_prefix tests)
  • cargo test -p higgs-engine --lib — 228/228 pass
  • cargo test -p higgs --lib -- --test-threads=1 — 449/449 pass
  • Engine-level: defer to PR-6b with real drafter checkpoint

Context

This is part of the feat/magic-canvas PR split. PRs already shipped against panbanda upstream:

🤖 Generated with Claude Code

dusterbloom and others added 4 commits May 5, 2026 12:33
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.

Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Two call sites converted with
    `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * Workspace clippy (nursery: `as_conversions`,
    `cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
    `explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
    `redundant_pattern_matching`, `missing_const_for_fn`) — all 30
    errors fixed in-place: `i32::try_from` for tensor-shape casts,
    `clone_from` for in-place clones, `filter_map(Option::as_mut)` for
    `iter().filter_map(if-let)` patterns, backticks on doc items.
    No file-level allows.
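
Two of the fixes listed above, sketched on stand-in data. The function names and the `Option<u32>` slots are hypothetical; only the `i32::try_from` and `filter_map(Option::as_mut)` idioms are taken from the commit message.

```rust
use std::convert::TryFrom; // already in the 2021 prelude; explicit for older editions
use std::num::TryFromIntError;

/// Checked replacement for `dim as i32`: fails instead of silently
/// truncating when the dimension does not fit in an `i32`.
fn dim_to_i32(dim: usize) -> Result<i32, TryFromIntError> {
    i32::try_from(dim)
}

/// Visit only the populated slots mutably. `filter_map(Option::as_mut)`
/// replaces a closure that does `if let Some(v) = slot { Some(v) } else { None }`.
fn bump_present(slots: &mut [Option<u32>]) {
    for value in slots.iter_mut().filter_map(Option::as_mut) {
        *value += 1;
    }
}
```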

The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.

Methods added on `Qwen3NextCausalLM`:

  * `forward_with_taps` — forward returning logits AND vec of hidden
    states at specified target layers; the drafter conditions on these.

  * `forward_with_taps_stateless` — same, but does NOT mutate the
    recurrent (GDN) state. Used during verify when state advancement is
    handled separately.

  * `forward_with_taps_tape` — forward that records each GDN layer's
    innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
    rerun for partial-accept rollback.

  * `replay_tape_rollback` — restore GDN state to a tape position
    without re-running the full model.

  * `embed_token_ids` — apply the embedding layer alone (drafter input).

  * `forward_all_logits_from_hidden` — apply lm_head alone (target's
    verification of drafter outputs).

  * `project_logits` (private helper) — lm_head with origin/main's
    available projection paths only (ANE + dense_lm_head fields don't
    exist here yet; ported in PR-8).

Methods added on `GatedDeltaNet`:

  * `forward_stateless` — GDN forward without state mutation.

  * `forward_with_tape` — GDN forward that captures the per-step
    innovation into the tape.

  * `replay_from_tape` — apply a tape to recompute SSM state to a target
    position. Annotated `#[allow(dead_code)]` until the engine glue
    drives it (next commit).

New public type `GdnLayerTape` exposes the per-layer innovation record.
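
The record-then-replay idea behind the tape can be illustrated with a toy scalar recurrence. Everything here is an assumption made for illustration: the real GDN state is a tensor updated by a Metal kernel, and `ToyTape`, its `(decay, innovation)` layout, and the update rule `s = decay * s + innovation` are hypothetical stand-ins, not the contents of `GdnLayerTape`.

```rust
/// Toy stand-in for a per-layer tape: one (decay, innovation) pair per step.
#[derive(Debug, Default)]
struct ToyTape {
    steps: Vec<(f32, f32)>,
}

/// Advance the toy state over `inputs`, recording every step on the tape.
fn forward_with_tape(state: &mut f32, inputs: &[(f32, f32)], tape: &mut ToyTape) {
    for &(decay, innovation) in inputs {
        *state = decay * *state + innovation;
        tape.steps.push((decay, innovation));
    }
}

/// Recompute the state after only the first `keep` recorded steps, starting
/// from the state saved before the forward pass. No model rerun is needed.
fn replay_from_tape(start: f32, tape: &ToyTape, keep: usize) -> f32 {
    tape.steps
        .iter()
        .take(keep)
        .fold(start, |s, &(decay, innovation)| decay * s + innovation)
}
```

In this toy, a partial accept of `keep` tokens costs `keep` cheap updates instead of a second forward pass, which is the trade the tape is meant to make.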

Metal kernel infrastructure ported alongside:

  * `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
  * `gated_delta_kernel_ffi_with_tape` + matching kernel
  * `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
    discards the new state, matches caller semantics in `forward_stateless`)

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
    converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
    don't exist on this branch — ANE-feature paths stripped to the
    plain Metal/MLX path. Fields ported in PR-8.

  * Error handler uses `thread_local! RefCell<Option<String>>` instead
    of feat/magic-canvas's `Mutex<Option<String>>` — matches the
    branch's existing FFI error pattern.
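
A minimal sketch of the `thread_local! RefCell<Option<String>>` pattern, assuming a record-then-take protocol around each FFI call. The names `LAST_FFI_ERROR`, `record_ffi_error`, and `take_ffi_error` are hypothetical and not taken from the branch.

```rust
use std::cell::RefCell;

thread_local! {
    // One slot per thread, so no lock is needed and errors raised on one
    // thread are never observed by another.
    static LAST_FFI_ERROR: RefCell<Option<String>> = RefCell::new(None);
}

/// Called from the error callback handed to the FFI layer.
fn record_ffi_error(msg: &str) {
    LAST_FFI_ERROR.with(|slot| *slot.borrow_mut() = Some(msg.to_owned()));
}

/// Called after the FFI call returns; clears the slot so a stale message
/// cannot be attributed to a later call.
fn take_ffi_error() -> Option<String> {
    LAST_FFI_ERROR.with(|slot| slot.borrow_mut().take())
}
```

Compared with a `Mutex<Option<String>>`, this only works if the callback runs on the same thread that made the FFI call, which is an assumption of this sketch.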

Senior-Rust hygiene:

  * No file-level blanket allows added.
  * Function-scoped `#[allow(...)]` on the four genuine numerical
    kernel functions (`forward_stateless`, `forward_with_tape`,
    `forward_with_taps_tape`, `replay_tape_rollback`), each with a
    one-line justification comment.
  * `unwrap_used` never allowed — refactored to `?` propagation or
    `expect("reason")` at the two call sites.
  * Mechanical clippy refactors throughout: `find_map` for
    `filter_map(..).next()`, `clone_from` for `assigning_clones`,
    `if let` for single-pattern `match`, backticks for `doc_markdown`.
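
For reference, two of the mechanical refactors named above, shown on stand-in data. The `Snapshot` type and both function names are hypothetical.

```rust
/// `find_map` replaces `filter_map(..).next()`: stop at the first hit.
fn first_even_doubled(xs: &[i32]) -> Option<i32> {
    xs.iter()
        .find_map(|&x| if x % 2 == 0 { Some(x * 2) } else { None })
}

#[derive(Debug)]
struct Snapshot {
    tokens: Vec<u32>,
}

/// `clone_from` replaces `cached.tokens = latest.tokens.clone()` and lets
/// the destination `Vec` reuse its existing allocation where possible.
fn refresh(cached: &mut Snapshot, latest: &Snapshot) {
    cached.tokens.clone_from(&latest.tokens);
}
```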

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader

Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.

`AnyModel` (in `higgs-models/src/lib.rs`):

  * `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
  * `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
    tap hiddens + per-layer GDN tape) via a public type alias to
    placate `clippy::type_complexity`
  * `embed_token_ids` — Qwen3Next-only
  * `forward_all_logits_from_hidden` — Qwen3Next-only

  All non-Qwen3Next arms enumerate every variant explicitly to satisfy
  `clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).
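
The dispatch shape can be sketched as follows. This is an illustration under assumptions: the variants `Llama` and `Gemma`, the stub model, the `String` error type, and the placeholder return value are all hypothetical; only the "name every variant, no `_ =>` arm" rule is taken from this commit.

```rust
/// Stub standing in for the real model; it echoes the ids as a placeholder
/// for embedding rows.
struct Qwen3NextStub;

impl Qwen3NextStub {
    fn embed_token_ids(&self, ids: &[u32]) -> Vec<u32> {
        ids.to_vec()
    }
}

/// Hypothetical variant set; the real `AnyModel` has its own list.
enum AnyModelSketch {
    Qwen3Next(Qwen3NextStub),
    Llama,
    Gemma,
}

impl AnyModelSketch {
    fn embed_token_ids(&self, ids: &[u32]) -> Result<Vec<u32>, String> {
        match self {
            Self::Qwen3Next(model) => Ok(model.embed_token_ids(ids)),
            // Every unsupported variant is named. Adding a new variant then
            // fails to compile here instead of silently taking the error path.
            Self::Llama | Self::Gemma => {
                Err("embed_token_ids is only supported for Qwen3Next".to_owned())
            }
        }
    }
}
```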

`AnyCache`:

  * `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
    slice/vec for engine glue that needs to inspect GDN state. Returns
    `Result<_, Exception>` rather than panicking when called on a `KV`
    cache, so the verify path in `SimpleEngine::generate_dflash_inner`
    can propagate via `?`.
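
A rough sketch of why the `Result` return matters for `?` propagation. The enum layouts, the `String` error type (standing in for `Exception`), and `count_gdn_layers` are hypothetical.

```rust
#[derive(Debug, PartialEq)]
enum LayerCacheStub {
    KV,
    Gdn,
}

enum AnyCacheSketch {
    KV,
    Hybrid(Vec<LayerCacheStub>),
}

impl AnyCacheSketch {
    /// Borrow the hybrid layer caches, or report a mismatch as an error
    /// instead of panicking.
    fn as_hybrid(&self) -> Result<&[LayerCacheStub], String> {
        match self {
            Self::Hybrid(layers) => Ok(layers.as_slice()),
            Self::KV => Err("expected a hybrid cache, found a KV cache".to_owned()),
        }
    }
}

/// Caller-side view: the mismatch surfaces through `?`.
fn count_gdn_layers(cache: &AnyCacheSketch) -> Result<usize, String> {
    let layers = cache.as_hybrid()?;
    Ok(layers.iter().filter(|l| **l == LayerCacheStub::Gdn).count())
}
```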

`dflash::accept_prefix`:

  * Greedy speculative-decode acceptance: longest-matching prefix of
    `draft` against `verify_argmax`, plus one bonus token at the
    diverge point (or after the last accept).
  * 5 unit tests covering full match, first-token reject, partial
    match, empty draft, and the debug-only length assertion.
  * Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
    to avoid pulling in the 9970-line diffusion module for a 16-line
    helper.
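
The acceptance rule described above can be sketched as follows. This is not the function in `dflash.rs`: the name `accept_prefix_sketch`, the `u32` token type, the `Vec` return value, and the assumption that `verify_argmax` holds exactly one more entry than `draft` are all illustrative choices.

```rust
/// Greedy acceptance: keep the longest prefix of `draft` that matches the
/// target's argmax, then append one bonus token from the target.
fn accept_prefix_sketch(draft: &[u32], verify_argmax: &[u32]) -> Vec<u32> {
    // Assumed contract: the target scores every draft position plus the
    // position after the last draft token.
    debug_assert_eq!(verify_argmax.len(), draft.len() + 1);

    let matched = draft
        .iter()
        .zip(verify_argmax)
        .take_while(|(d, v)| d == v)
        .count();

    let mut accepted = draft[..matched].to_vec();
    // The bonus token is the target's own choice at the diverge point, or
    // at the position after the last accepted token on a full match.
    if let Some(&bonus) = verify_argmax.get(matched) {
        accepted.push(bonus);
    }
    accepted
}
```

Under these assumptions every round emits at least one token, because the bonus token is always available even when the first draft token is rejected.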

`engine::model_loader::load_dflash_drafter`:

  * Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
    converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
    call site lands in the next commit.
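
A minimal sketch of such an adapter, with stand-in types. The shapes of `ModelError`, `EngineError`, and `DrafterStub`, and the stub loader's behaviour, are assumptions; only the `ModelError` to `EngineError` conversion is taken from this commit.

```rust
#[derive(Debug)]
struct ModelError(String);

#[derive(Debug)]
enum EngineError {
    Model(String),
}

impl From<ModelError> for EngineError {
    fn from(err: ModelError) -> Self {
        Self::Model(err.0)
    }
}

struct DrafterStub {
    path: String,
}

/// Stand-in for the model-crate loader: rejects an empty path.
fn models_load_dflash_drafter(path: &str) -> Result<DrafterStub, ModelError> {
    if path.is_empty() {
        return Err(ModelError("empty drafter path".to_owned()));
    }
    Ok(DrafterStub { path: path.to_owned() })
}

/// Engine-side adapter: same call, with the error mapped at the boundary.
fn load_dflash_drafter(path: &str) -> Result<DrafterStub, EngineError> {
    models_load_dflash_drafter(path).map_err(EngineError::from)
}
```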

Verification on origin/main:
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
  * `cargo test -p higgs-engine --lib` — 228/228 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match

CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:

  * `dflash.rs` — `GdnStateBackup::restore_and_rollback`
  * `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`

Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.
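
The shape of the change, sketched on stand-in types. `KvCacheStub`, the two-variant `LayerCache`, and `rollback_layer` are hypothetical; the real call sites have more context around the match.

```rust
#[derive(Debug)]
struct KvCacheStub {
    len: usize,
}

impl KvCacheStub {
    fn trim_by(&mut self, n: usize) {
        self.len = self.len.saturating_sub(n);
    }
}

#[derive(Debug)]
enum LayerCache {
    KV(KvCacheStub),
    Gdn,
}

fn rollback_layer(layer: Option<&mut LayerCache>, rollback: usize) {
    match layer {
        // Previously: `Some(LayerCache::KV(kv)) => { if rollback > 0 { ... } }`.
        Some(LayerCache::KV(kv)) if rollback > 0 => kv.trim_by(rollback),
        // Explicit no-op arm covering the failed guard, the other variant,
        // and `None`, so the match stays exhaustive without a `_` arm.
        Some(LayerCache::KV(_) | LayerCache::Gdn) | None => {}
    }
}
```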

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>