dusterbloom: feat(engine): DFlash draft-verify loop in SimpleEngine [DRAFT] #18
Open
dusterbloom wants to merge 2 commits into dusterbloom/dflash-baseline from
Conversation
…loop
Wires DFlash speculative decoding into `SimpleEngine`. When a drafter is
loaded and no constraint or multimodal input is active, `generate_inner`
dispatches to `generate_dflash_inner` instead of the AR path:
1. Prefill the target with the prompt, capturing tap hidden states
at the drafter's `target_layer_ids`.
2. Sample first token from prefill logits.
3. Per round:
a. Embed `[anchor, mask, mask, ...]` into the target's embed space
via `AnyModel::embed_token_ids`.
b. Drafter forward (MLX path) on `(noise_embedding, current_taps)`
producing 16 candidate hidden states.
c. Project candidate hiddens through target's lm_head via
`AnyModel::forward_all_logits_from_hidden`, argmax to drafts.
d. Verify input `[anchor, draft_0..draft_14]` through target's
`forward_with_taps`, which advances both KV and GDN state.
e. Save GDN state (`GdnStateBackup::save`) before verify so we
can roll back on partial accept.
f. Accept prefix via greedy `dflash::accept_prefix`.
g. On partial accept: restore GDN, rerun accepted tokens to
re-advance recurrent state and refresh tap hiddens.
4. Continue until EOS, stop sequence, or max_tokens.
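A condensed sketch of the round loop follows. Method names are taken from the description above, but the signatures, the helper functions (`sample_argmax`, `argmax_rows`, `emit`, `is_eos`), and the error handling are assumptions for illustration, not the crate's real API; the real code also handles stop sequences and KV bookkeeping.

```rust
// Sketch only — assumed signatures, not the actual higgs-engine code.
let mut taps = prefill_taps;                          // tap hiddens captured in step 1
let mut anchor = sample_argmax(&prefill_logits);      // step 2
loop {
    // (a) embed [anchor, mask, mask, ...] into the target's embed space
    let mut ids = vec![anchor];
    ids.resize(block_size, mask_token_id);
    let noise = target.embed_token_ids(&ids)?;

    // (b) drafter forward (MLX path) on the noise embedding + current tap hiddens
    let cand_hidden = drafter.forward(&noise, &taps)?;

    // (c) project candidate hiddens through the target's lm_head, argmax -> drafts
    let drafts = argmax_rows(&target.forward_all_logits_from_hidden(&cand_hidden)?);

    // (e) snapshot GDN recurrent state before the verify forward mutates it
    let backup = GdnStateBackup::save(cache.as_hybrid()?);

    // (d) one verify forward over [anchor, draft_0, ...]; advances KV + GDN state
    //     and yields fresh tap hiddens
    let verify_ids: Vec<u32> = std::iter::once(anchor).chain(drafts.iter().copied()).collect();
    let (verify_logits, new_taps) =
        target.forward_with_taps(&verify_ids, &mut cache, &target_layer_ids)?;

    // (f) greedy accept: longest prefix where the target's argmax equals the draft
    let accepted = dflash::accept_prefix(&verify_logits, &drafts);
    emit(&accepted);

    taps = if accepted.len() == drafts.len() {
        new_taps // full accept: the verify pass already left state where we need it
    } else {
        // (g) partial accept: restore GDN, rerun only the accepted tokens so
        //     recurrent state and tap hiddens match the kept prefix (KV trim elided)
        backup.restore_and_rollback(cache.as_hybrid_mut()?)?;
        target.forward_with_taps(&accepted, &mut cache, &target_layer_ids)?.1
    };

    anchor = *accepted.last().ok_or_else(|| anyhow::anyhow!("empty accept"))?;
    if is_eos(anchor) { break; } // plus stop-sequence and max_tokens checks
}
```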
`DFlashState` (engine-private) holds the drafter, tap layer indices,
block size, and mask token id. Loaded by `SimpleEngine::load_with_dflash`
via either an explicit `--draft-model` path or the `HIGGS_DFLASH_PATH`
env var; thin-wrapped from `load`.
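For orientation, the engine-private state is roughly this shape; only the four fields named above are implied by the PR, and the `DFlashDrafter` type name is a placeholder:

```rust
// Illustrative shape only — the real definition is in crates/higgs-engine/src/simple.rs.
struct DFlashState {
    drafter: DFlashDrafter,       // loaded draft model (MLX path; CPU/ANE return in PR-6c)
    target_layer_ids: Vec<usize>, // target layers whose tap hidden states feed the drafter
    block_size: usize,            // draft tokens produced per round
    mask_token_id: u32,           // token id embedded at the masked draft positions
}
```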
Adaptations from feat/magic-canvas a7e2737 → origin/main:
* `cpu_engine` (CPU BLAS drafter) and `ane_executor` (ANE+CPU hybrid)
fields STRIPPED. The ANE/CPU dispatch tri-branch in the original
is collapsed to the unconditional MLX path. Will return in PR-6c
once `dflash_cpu.rs` and `dflash_ane.rs` are ported.
* `accept_prefix` imported from `higgs_models::dflash` (shipped in
PR-6a) rather than `crate::diffusion::accept_prefix`.
* `cache.as_hybrid()? / as_hybrid_mut()?` for `GdnStateBackup`
instead of direct enum match (uses the `AnyCache` accessors from
PR-6a; propagate via `?` if cache type doesn't match).
* Bounds-checked slicing throughout — `verify_tokens.get(..n)` /
`accepted.last().ok_or_else(...)` rather than `[..n]` /
`.last().unwrap()`. Senior-Rust hygiene: no `unwrap()` on
`Result`/`Option`, no `as` for sign-changing casts.
* `i32::try_from` / `usize::try_from` for tensor-shape arithmetic
where the source used `as` casts (cleaner for the
`clippy::cast_sign_loss` direction).
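For concreteness, the substitutions above reduce to standard-library patterns like the following; the error type is shown as `anyhow` purely for illustration and the variable names echo the identifiers quoted above:

```rust
// Fallible slice instead of a panicking range index.
let head = verify_tokens
    .get(..n)
    .ok_or_else(|| anyhow::anyhow!("verify buffer shorter than {n}"))?;

// No `.unwrap()` on Option: surface an error if the accepted prefix is empty.
let last = accepted
    .last()
    .copied()
    .ok_or_else(|| anyhow::anyhow!("accept_prefix returned no tokens"))?;

// Checked conversions instead of sign-changing `as` casts in shape arithmetic.
let prompt_len_i32 = i32::try_from(prompt_len)?;
let draft_count = usize::try_from(n_drafted)?;
```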
Verification:
* `cargo check -p higgs-engine` — clean
* `cargo clippy --all-targets --all-features -- -D warnings` — clean
* `cargo fmt --check` — clean
* `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
* `cargo test -p higgs-engine --lib` — 228/228 pass
* `cargo test -p higgs-models --lib` — 335/335 pass
Marked draft because the verify-path semantics need end-to-end
validation against a real Carnice-9B + 0.5B drafter pair before this
should land. Compile-time + clippy-time correctness are confirmed; the
cross-token verify-loop logic — particularly the GDN replay path on
partial accept — is the kind of code where silent bugs would only
show as quality regressions on long generations.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous engine glue used `forward_with_taps` for the verify pass and a
full `forward_with_taps` rerun on partial accept. Both paths route GDN layers
through `gated_delta_kernel_ffi`, whose S>1 recurrence is not bit-exact with K
sequential S=1 calls — argmax matches when logit gaps are wide, flips on close
calls. With the 4B Qwen3.5 target+drafter pair on origin/main, DFlash diverged
from the AR baseline within ~5 generated tokens (e.g. iter 3 verify_argmax[1]
= '正确' ("correct") instead of AR's ' True').

Switch the verify pass to `forward_with_taps_tape` (records per-position GDN
innovation tapes) and the partial-accept rollback to `replay_tape_rollback`.
Both route through `gated_delta_kernel_ffi_with_tape`, which is bit-exact at
S>1 because innovations are recorded per position and the kernel does strict
sequential recurrence with no parallel-scan FP non-associativity. The
tape-recording infrastructure was already on this branch from 62a7352 (port of
feat/magic-canvas d6daf3e); this commit wires the engine to use it.

Validated against mlx-community/Qwen3.5-4B-MLX-4bit target +
z-lab/Qwen3.5-4B-DFlash drafter at T=0, seed=0:

    prompt: "The capital of France is"
    AR baseline (200 tok):    " Paris.\nA. True\nB. False\nAnswer:..."
    DFlash tape replay (200): " Paris.\nA. True\nB. False\nAnswer:..."
    diff: BYTE-IDENTICAL

Changes:
- crates/higgs-models/src/lib.rs: add `AnyModel::replay_tape_rollback`
  dispatcher (Qwen3Next + Hybrid cache).
- crates/higgs-engine/src/simple.rs: replace the verify forward and the
  partial-accept rollback branch in `generate_dflash_inner`; remove the
  now-unused `GdnStateBackup` import. Add a `HIGGS_DFLASH_TRACE`-gated
  per-iter trace useful for future debugging.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
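The bit-exactness argument comes down to floating-point addition not being associative. A standalone illustration of the effect (nothing here is project code): summing the same `f32` values in two orders typically disagrees in the low-order bits, which is all it takes to flip an argmax between two nearly tied logits.

```rust
// A parallel-scan S>1 kernel reassociates the recurrence relative to K
// sequential S=1 steps; strict sequential recurrence (or the tape replay)
// keeps one canonical evaluation order.
fn main() {
    let xs: Vec<f32> = (1..=100_000).map(|i| 1.0 / i as f32).collect();
    let forward: f32 = xs.iter().sum();
    let backward: f32 = xs.iter().rev().sum();
    println!("forward  = {forward:.9} ({:#010x})", forward.to_bits());
    println!("backward = {backward:.9} ({:#010x})", backward.to_bits());
    // Typical output shows the two sums differing by a few ULPs.
}
```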
Summary
Stacked on top of #17 (PR-6a). Wires DFlash speculative decoding into
`SimpleEngine` via a new `generate_dflash_inner` method that drives the
draft-verify loop end-to-end against the model+cache surface PR-6a
established. This is PR-6b of the magic-canvas split.

Marked DRAFT because the verify-path semantics — particularly the GDN replay
on partial accept — need end-to-end validation against a real Carnice-9B +
0.5B drafter pair before merge. Compile-time and clippy-time correctness are
confirmed; cross-token verify-loop logic is exactly the surface where silent
quality regressions hide.
What's in this PR
Single commit, single file:
- `f3577aa6`, `crates/higgs-engine/src/simple.rs`: struct `DFlashState` + `SimpleEngine::load_with_dflash` + dispatch into `generate_inner` + `generate_dflash_inner` (the draft-verify loop)

What's NOT in this PR
- `dflash_cpu.rs` (CPU BLAS drafter) — pulls in BLAS helpers from `diffusion.rs`. The MLX drafter forward path lands here; CPU/ANE backends in PR-6c.
- `dflash_ane.rs` (ANE-accelerated drafter) — feature-gated; PR-8 territory.
- (`feat/magic-canvas`) — depends on this glue being live, will follow.

Adaptations from `feat/magic-canvas` (a7e2737c)

- The original dispatched `ane_executor.forward(...)` / `cpu_engine.forward(...)` / `drafter.forward(...)` based on optional fields. Here it's the unconditional MLX path. PR-6c will reintroduce the CPU/ANE dispatch when those modules are ported.
- `accept_prefix` from `higgs_models::dflash` (shipped in PR-6a) rather than `crate::diffusion::accept_prefix`.
- `cache.as_hybrid()? / as_hybrid_mut()?` for `GdnStateBackup::save` / `restore_and_rollback` — uses the `AnyCache` accessors from PR-6a, propagates a clean `Exception` if cache type mismatches.
- `load_with_dflash` is 5-arg (origin/main's `load` already grew `tuning: MlxRuntimeTuning` and `raise_wired_limit: bool` post-`feat/magic-canvas`).

Senior-Rust hygiene
- `unwrap()` on `Result`/`Option` refactored throughout: `accepted.last().ok_or_else(...)`, `verify_tokens.get(..n).ok_or_else(...)`, `i32::try_from(prompt_len)?`, `f64::from(u32::try_from(tokens.len()).unwrap_or(u32::MAX))`, etc.
- No `[..n]` panics in the verify loop.
- `#[allow(clippy::too_many_lines, clippy::cast_possible_truncation, clippy::cast_sign_loss, clippy::significant_drop_tightening)]` on `generate_dflash_inner` only, with one-line documented justification (long verify loop with unavoidable shape arithmetic and lock-held drafter forward); placement sketched below.
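Placement sketch for that lint waiver (the real item is a method on `SimpleEngine`; the signature is elided and the justification comment paraphrases the bullet above):

```rust
#[allow(
    clippy::too_many_lines,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::significant_drop_tightening
)]
// Justification: long verify loop with unavoidable shape arithmetic and a
// lock-held drafter forward.
fn generate_dflash_inner(/* args elided in this sketch */) {
    // ...
}
```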
Test plan

- `cargo check -p higgs-engine` — clean
- `cargo clippy --all-targets --all-features -- -D warnings` — clean (rustc 1.95.0, matches CI)
- `cargo fmt --check` — clean
- `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
- `cargo test -p higgs-engine --lib` — 228/228 pass
- `cargo test -p higgs-models --lib` — 335/335 pass
- `Qwen3.5-9B` target + `z-lab/Qwen3.5-9B-DFlash` drafter via `higgs serve --draft-model <path>`, generate 200 tokens at T=0, confirm coherent output and ≥21 tps decode (target ≈22 per the magic-canvas baseline).
- `a7e2737c` originally fixed — same canonical regression bench).

🤖 Generated with Claude Code