
TEMP: CI validation for dflash-engine-glue (close after CI completes)#20

Closed
dusterbloom wants to merge 6 commits into main from dusterbloom/dflash-engine-glue

Conversation

@dusterbloom
Owner

Purpose

Temporary PR for CI signal only. PR #18 (the real one) targets the stacked
base dusterbloom/dflash-baseline; the CI workflow only triggers on PRs to
main, so #18 has no CI runs. This PR exists so we can verify CI is green on
the bit-exact-AR-parity fix (e23415da).

Will be closed without merging as soon as CI completes.

What's in this diff vs main

Combined contents of the stacked branches:

  • PR-6a (dusterbloom/dflash-baseline, PR #17 — feat(dflash): DFlash drafter
    foundation — module + tap APIs + dispatch — already CI-green): DFlash
    drafter foundation, tap APIs, GDN tape replay infrastructure
  • PR-6b (dusterbloom/dflash-engine-glue, PR #18 — feat(engine): DFlash
    draft-verify loop in SimpleEngine — draft):
    SimpleEngine::generate_dflash_inner MLX draft-verify loop
  • NEW: e23415da — switch verify to the tape-replay path. The previous glue
    used forward_with_taps plus a full rerun on partial accept, which routed
    through the regular gated_delta_kernel_ffi, whose S>1 recurrence is not
    bit-exact with K sequential S=1 calls. Switching to forward_with_taps_tape
    + replay_tape_rollback (backed by gated_delta_kernel_ffi_with_tape, which
    IS bit-exact) restores the canonical regression bench: the first 200
    generated tokens are byte-identical to the AR baseline at T=0.

Local CI-equivalent gates (already passing)

  • cargo fmt --all -- --check — clean
  • cargo clippy --all-targets --all-features — clean
  • cargo test --all-features -- --test-threads=1 — 335 passed, 0 failed

E2E validation

prompt: "The capital of France is" (T=0, seed=0, 200 tokens)
AR baseline:        " Paris.\nA. True\nB. False\nAnswer:..."
DFlash tape replay: " Paris.\nA. True\nB. False\nAnswer:..."
diff: BYTE-IDENTICAL

Test plan

🤖 Generated with Claude Code

dusterbloom and others added 6 commits May 5, 2026 12:33
Adds `crates/higgs-models/src/dflash.rs` from feat/magic-canvas — the
0.5B drafter that produces 16 draft tokens per round via a single
non-causal forward pass on hidden states tapped from 5 target-model
layers.

Architecture (8 decoder layers, dual-stream attention) is verbatim from
the magic-canvas baseline `c1f85ade` (final stable state, before WIP
ANE work). Wire-up into `SimpleEngine` lands in the follow-up commit.

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Two call sites converted with
    `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * Workspace clippy (nursery: `as_conversions`,
    `cast_possible_truncation`, `doc_markdown`, `assigning_clones`,
    `explicit_iter_loop`, `unnecessary_cast`, `shadow_unrelated`,
    `redundant_pattern_matching`, `missing_const_for_fn`) — all 30
    errors fixed in-place: `i32::try_from` for tensor-shape casts,
    `clone_from` for in-place clones, `filter_map(Option::as_mut)` for
    `iter().filter_map(if-let)` patterns, backticks on doc items.
    No file-level allows.
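The `rollback(i32)` → `trim_by(usize)` call-site conversion described above can be sketched as a one-liner (the function name here is illustrative; the real call sites live inside `dflash.rs`):

```rust
/// Sketch of the rollback(i32) -> trim_by(usize) adaptation: convert a signed
/// rollback amount to an unsigned trim count without `as`-cast sign loss.
fn rollback_to_trim(rollback: i32) -> usize {
    // unsigned_abs() handles i32::MIN safely; try_into only fails on targets
    // where usize is narrower than u32, where we saturate rather than panic.
    rollback.unsigned_abs().try_into().unwrap_or(usize::MAX)
}
```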

The original DFlash test suite (~3.8K lines, 30+ end-to-end tests)
depends on tap APIs (`forward_with_taps_tape`, `replay_tape_rollback`,
`forward_all_logits_from_hidden`) and `crate::diffusion::accept_prefix`
that aren't on `origin/main` yet. Tests are deferred to a follow-up PR
alongside the qwen3_next tap-API surface — there's a comment block at
the bottom of `dflash.rs` flagging this.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the model-side surface that the DFlash drafter speculates against —
hidden-state taps during forward, GDN innovation tape for cheap rollback,
and helpers for embedding lookup + lm_head application in isolation.

Methods added on `Qwen3NextCausalLM`:

  * `forward_with_taps` — forward returning logits AND vec of hidden
    states at specified target layers; the drafter conditions on these.

  * `forward_with_taps_stateless` — same, but does NOT mutate the
    recurrent (GDN) state. Used during verify when state advancement is
    handled separately.

  * `forward_with_taps_tape` — forward that records each GDN layer's
    innovation into a `GdnLayerTape`. Enables ~5ms replay vs ~30ms
    rerun for partial-accept rollback.

  * `replay_tape_rollback` — restore GDN state to a tape position
    without re-running the full model.

  * `embed_token_ids` — apply the embedding layer alone (drafter input).

  * `forward_all_logits_from_hidden` — apply lm_head alone (target's
    verification of drafter outputs).

  * `project_logits` (private helper) — lm_head with origin/main's
    available projection paths only (ANE + dense_lm_head fields don't
    exist here yet; ported in PR-8).

Methods added on `GatedDeltaNet`:

  * `forward_stateless` — GDN forward without state mutation.

  * `forward_with_tape` — GDN forward that captures the per-step
    innovation into the tape.

  * `replay_from_tape` — apply a tape to recompute SSM state to a target
    position. Annotated `#[allow(dead_code)]` until the engine glue
    drives it (next commit).

New public type `GdnLayerTape` exposes the per-layer innovation record.
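The tape-replay idea can be illustrated with a toy sketch (invented types, not the real `GdnLayerTape` API): record one additive state innovation per generated step, then restore the recurrent state at step k by replaying the first k innovations instead of re-running the model.

```rust
/// Toy per-layer innovation tape (illustrative only).
struct TapeSketch {
    base: Vec<f32>,             // recurrent-state snapshot at tape start
    innovations: Vec<Vec<f32>>, // one additive delta per generated position
}

impl TapeSketch {
    fn record(&mut self, delta: Vec<f32>) {
        self.innovations.push(delta);
    }

    /// Strictly sequential replay: the same additions are applied in the same
    /// order every time, so the restored state is bit-exact across calls.
    fn replay_to(&self, k: usize) -> Vec<f32> {
        let mut state = self.base.clone();
        for delta in self.innovations.iter().take(k) {
            for (s, d) in state.iter_mut().zip(delta) {
                *s += d;
            }
        }
        state
    }
}
```

The cost of a rollback is then proportional to the tape length, not to a full model forward — the ~5ms vs ~30ms gap cited above.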

Metal kernel infrastructure ported alongside:

  * `tape_replay_kernel_ffi` + `TAPE_REPLAY_KERNEL` static + Metal source
  * `gated_delta_kernel_ffi_with_tape` + matching kernel
  * `gated_delta_kernel_ffi_stateless` (thin wrapper over existing FFI;
    discards the new state, matches caller semantics in `forward_stateless`)

Adaptations from feat/magic-canvas → origin/main:

  * `SteppingKeyValueCache::rollback(i32)` was renamed `trim_by(usize)`
    on origin/main (PR panbanda#143). Call site in `replay_tape_rollback`
    converted with `unsigned_abs().try_into().unwrap_or(usize::MAX)`.

  * `lm_head_ane`, `dense_lm_head`, `ane_handle`, `ane_kernels` fields
    don't exist on this branch — ANE-feature paths stripped to the
    plain Metal/MLX path. Fields ported in PR-8.

  * Error handler uses `thread_local! RefCell<Option<String>>` instead
    of feat/magic-canvas's `Mutex<Option<String>>` — matches the
    branch's existing FFI error pattern.

Senior-Rust hygiene:

  * No file-level blanket allows added.
  * Function-scoped `#[allow(...)]` on the four genuine numerical
    kernel functions (`forward_stateless`, `forward_with_tape`,
    `forward_with_taps_tape`, `replay_tape_rollback`), each with a
    one-line justification comment.
  * `unwrap_used` never allowed — refactored to `?` propagation or
    `expect("reason")` at the two call sites.
  * Mechanical clippy refactors throughout: `find_map` for
    `filter_map(..).next()`, `clone_from` for `assigning_clones`,
    `if let` for single-pattern `match`, backticks for `doc_markdown`.

Verification on origin/main:
  * `cargo check -p higgs-models` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 330/330 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ader

Surfaces the qwen3_next tap APIs through the polymorphic `AnyModel`
enum so engine code can call them without matching variants directly,
adds the greedy speculative-decode acceptance helper, and exposes a
`load_dflash_drafter` entry point on the engine's model_loader.

`AnyModel` (in `higgs-models/src/lib.rs`):

  * `forward_with_taps` — dispatches Qwen3Next + Hybrid; errors otherwise
  * `forward_with_taps_tape` — same, returns `TapsTapeOutput` (logits +
    tap hiddens + per-layer GDN tape) via a public type alias to
    placate `clippy::type_complexity`
  * `embed_token_ids` — Qwen3Next-only
  * `forward_all_logits_from_hidden` — Qwen3Next-only

  All non-Qwen3Next arms enumerate every variant explicitly to satisfy
  `clippy::wildcard_enum_match_arm` (no `_ =>` catch-alls).
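  The explicit-enumeration dispatch pattern looks roughly like this (variant
  names are illustrative stand-ins, not the real AnyModel enum):

```rust
// Hypothetical stand-in for the AnyModel dispatch pattern.
enum AnyModelSketch {
    Qwen3Next(String),
    Hybrid(String),
    Llama(String),
}

impl AnyModelSketch {
    fn embed_token_ids(&self) -> Result<&str, String> {
        match self {
            Self::Qwen3Next(name) => Ok(name.as_str()),
            // Every non-Qwen3Next variant enumerated explicitly — no `_ =>`
            // catch-all, so adding a variant forces this match to be revisited.
            Self::Hybrid(_) | Self::Llama(_) => {
                Err("embed_token_ids: unsupported model".to_string())
            }
        }
    }
}
```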

`AnyCache`:

  * `as_hybrid` / `as_hybrid_mut` — borrow the inner hybrid layer-cache
    slice/vec for engine glue that needs to inspect GDN state. Returns
    `Result<_, Exception>` rather than panicking when called on a `KV`
    cache, so the verify path in `SimpleEngine::generate_dflash_inner`
    can propagate via `?`.

`dflash::accept_prefix`:

  * Greedy speculative-decode acceptance: longest-matching prefix of
    `draft` against `verify_argmax`, plus one bonus token at the
    diverge point (or after the last accept).
  * 5 unit tests covering full match, first-token reject, partial
    match, empty draft, and the debug-only length assertion.
  * Inlined here rather than ported from `feat/magic-canvas:diffusion.rs`
    to avoid pulling in the 9970-line diffusion module for a 16-line
    helper.
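A minimal sketch of the greedy acceptance rule (assumed semantics, not the actual `higgs_models::dflash::accept_prefix` source): `verify_argmax[i]` is the target's prediction after consuming `draft[..i]`; accept the longest matching prefix, then take one bonus token from the verifier at the divergence point.

```rust
/// Greedy speculative-decode acceptance sketch. Returns the accepted tokens:
/// the matching draft prefix plus one bonus token from the verifier.
fn accept_prefix(draft: &[u32], verify_argmax: &[u32]) -> Vec<u32> {
    // debug-only length check, mirroring the assertion mentioned above
    debug_assert_eq!(verify_argmax.len(), draft.len() + 1);
    let matched = draft
        .iter()
        .zip(verify_argmax)
        .take_while(|(d, v)| d == v)
        .count();
    let mut accepted = draft[..matched].to_vec();
    // bonus token: the verifier's prediction at the diverge point,
    // or after the last accepted token on a full match
    accepted.push(verify_argmax[matched]);
    accepted
}
```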

`engine::model_loader::load_dflash_drafter`:

  * Thin `Result` adapter over `higgs_models::dflash::load_dflash_drafter`,
    converting `ModelError` → `EngineError`. The `SimpleEngine::load_with_dflash`
    call site lands in the next commit.

Verification on origin/main:
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs-models --lib` — 335/335 pass (5 new accept_prefix tests)
  * `cargo test -p higgs-engine --lib` — 228/228 pass
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass

The remaining piece — `SimpleEngine::generate_dflash_inner` (the
draft-verify loop wired into `generate_inner`) — lands as a follow-up
commit. It needs end-to-end verification against a real DFlash drafter
checkpoint (Carnice-9B + 0.5B drafter); shipping it without that
runtime test would risk silent correctness regressions in the verify
path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_match

CI's clippy is one minor version ahead of my local toolchain and flags
the `if rollback > 0 { ... }` body inside the `Some(LayerCache::KV(kv))`
match arm. Two call sites:

  * `dflash.rs` — `GdnStateBackup::restore_and_rollback`
  * `qwen3_next.rs` — `Qwen3NextCausalLM::replay_tape_rollback`

Convert to a match guard and add an explicit no-op arm for the
guard-fails-and-`None` case so the match is exhaustive without a
wildcard. No behaviour change.
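A before/after sketch of that fix, with simplified stand-in types (the real code matches `Option<LayerCache>` variants):

```rust
// Hypothetical stand-in for the cache enum at the two call sites.
enum CacheSketch {
    Kv(usize),
    Hybrid,
}

fn trimmed_len(cache: &CacheSketch, rollback: i32) -> usize {
    match cache {
        // Match guard replaces an `if rollback > 0 { ... }` inside the arm body.
        CacheSketch::Kv(len) if rollback > 0 => {
            len.saturating_sub(usize::try_from(rollback.unsigned_abs()).unwrap_or(usize::MAX))
        }
        // Explicit no-op arms keep the match exhaustive without a `_ =>` wildcard.
        CacheSketch::Kv(len) => *len,
        CacheSketch::Hybrid => 0,
    }
}
```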

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…loop

Wires DFlash speculative decoding into `SimpleEngine`. When a drafter is
loaded and no constraint or multimodal input is active, `generate_inner`
dispatches to `generate_dflash_inner` instead of the AR path:

  1. Prefill the target with the prompt, capturing tap hidden states
     at the drafter's `target_layer_ids`.
  2. Sample first token from prefill logits.
  3. Per round:
     a. Embed `[anchor, mask, mask, ...]` into the target's embed space
        via `AnyModel::embed_token_ids`.
     b. Drafter forward (MLX path) on `(noise_embedding, current_taps)`
        producing 16 candidate hidden states.
     c. Project candidate hiddens through target's lm_head via
        `AnyModel::forward_all_logits_from_hidden`, argmax to drafts.
     d. Verify input `[anchor, draft_0..draft_14]` through target's
        `forward_with_taps`, which advances both KV and GDN state.
     e. Save GDN state (`GdnStateBackup::save`) before verify so we
        can roll back on partial accept.
     f. Accept prefix via greedy `dflash::accept_prefix`.
     g. On partial accept: restore GDN, rerun accepted tokens to
        re-advance recurrent state and refresh tap hiddens.
  4. Continue until EOS, stop sequence, or max_tokens.
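The round structure above can be simulated end-to-end with integer stand-ins (`toy_target` / `toy_drafter` are invented for illustration, not engine APIs): greedy prefix-accept plus one bonus token keeps the output identical to the AR baseline even with an imperfect drafter.

```rust
fn toy_target(prev: u32) -> u32 {
    (prev * 7 + 3) % 100 // deterministic stand-in for the target's argmax
}

fn toy_drafter(anchor: u32, block: usize) -> Vec<u32> {
    // imperfect drafter: correct except it corrupts every third draft token
    let mut out = Vec::with_capacity(block);
    let mut t = anchor;
    for i in 0..block {
        t = toy_target(t);
        out.push(if i % 3 == 2 { t ^ 1 } else { t });
    }
    out
}

fn run_dflash_rounds(anchor0: u32, max_tokens: usize, block: usize) -> Vec<u32> {
    let mut seq = Vec::new();
    let mut anchor = anchor0;
    while seq.len() < max_tokens {
        // steps 3a-3c: draft `block` candidate tokens from the anchor
        let draft = toy_drafter(anchor, block);
        // step 3d: verify pass — target's prediction at each draft position
        let mut verify = vec![toy_target(anchor)];
        for &d in &draft {
            verify.push(toy_target(d));
        }
        // step 3f: greedy prefix accept + one bonus token at the divergence
        let matched = draft.iter().zip(&verify).take_while(|(d, v)| d == v).count();
        let mut accepted = draft[..matched].to_vec();
        accepted.push(verify[matched]);
        // step 3g: a real engine rolls back KV/GDN state here on partial accept
        anchor = *accepted.last().expect("always at least the bonus token");
        seq.extend(accepted);
    }
    seq.truncate(max_tokens);
    seq
}
```

With the drafter corrupting every third token, each round accepts two drafts plus the bonus token, and the final sequence still matches the pure AR chain token-for-token — the property the e23415da bench checks at the byte level.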

`DFlashState` (engine-private) holds the drafter, tap layer indices,
block size, and mask token id. Loaded by `SimpleEngine::load_with_dflash`
via either an explicit `--draft-model` path or the `HIGGS_DFLASH_PATH`
env var; thin-wrapped from `load`.

Adaptations from feat/magic-canvas a7e2737 → origin/main:

  * `cpu_engine` (CPU BLAS drafter) and `ane_executor` (ANE+CPU hybrid)
    fields STRIPPED. The ANE/CPU dispatch tri-branch in the original
    is collapsed to the unconditional MLX path. Will return in PR-6c
    once `dflash_cpu.rs` and `dflash_ane.rs` are ported.

  * `accept_prefix` imported from `higgs_models::dflash` (shipped in
    PR-6a) rather than `crate::diffusion::accept_prefix`.

  * `cache.as_hybrid()? / as_hybrid_mut()?` for `GdnStateBackup`
    instead of direct enum match (uses the `AnyCache` accessors from
    PR-6a; propagate via `?` if cache type doesn't match).

  * Bounds-checked slicing throughout — `verify_tokens.get(..n)` /
    `accepted.last().ok_or_else(...)` rather than `[..n]` /
    `.last().unwrap()`. Senior-Rust hygiene: no `unwrap()` on
    `Result`/`Option`, no `as` for sign-changing casts.

  * `i32::try_from` / `usize::try_from` for tensor-shape arithmetic
    where the source used `as` casts (cleaner for the
    `clippy::cast_sign_loss` direction).

Verification:

  * `cargo check -p higgs-engine` — clean
  * `cargo clippy --all-targets --all-features -- -D warnings` — clean
  * `cargo fmt --check` — clean
  * `cargo test -p higgs --lib -- --test-threads=1` — 449/449 pass
  * `cargo test -p higgs-engine --lib` — 228/228 pass
  * `cargo test -p higgs-models --lib` — 335/335 pass

Marked draft because the verify-path semantics need end-to-end
validation against a real Carnice-9B + 0.5B drafter pair before this
should land. Compile-time + clippy-time correctness are confirmed; the
cross-token verify-loop logic — particularly the GDN replay path on
partial accept — is the kind of code where silent bugs would only
show as quality regressions on long generations.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous engine glue used `forward_with_taps` for the verify pass and
a full `forward_with_taps` rerun on partial accept. Both paths route GDN
layers through `gated_delta_kernel_ffi`, whose S>1 recurrence is not
bit-exact with K sequential S=1 calls — argmax matches when logit gaps
are wide, flips on close calls. With the 4B Qwen3.5 target+drafter pair
on origin/main, DFlash diverged from AR baseline within ~5 generated
tokens (e.g. iter 3 verify_argmax[1] = '正确' instead of AR's ' True').

Switch the verify pass to `forward_with_taps_tape` (records per-position
GDN innovation tapes) and the partial-accept rollback to
`replay_tape_rollback`. Both route through
`gated_delta_kernel_ffi_with_tape`, which is bit-exact at S>1 because
innovations are recorded per position and the kernel does strict
sequential recurrence with no parallel-scan FP non-associativity. The
tape-recording infrastructure was already on this branch from 62a7352
(port of feat/magic-canvas d6daf3e); this commit wires the engine to
use it.
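A minimal demonstration of the floating-point non-associativity that the divergence is attributed to — a parallel-scan (S>1) recurrence and K sequential S=1 steps can sum the same terms in different orders and land on different bits:

```rust
/// The same three f32 terms summed in two association orders.
/// FP addition is not associative, so the results differ.
fn sum_two_orders() -> (f32, f32) {
    let (a, b, c) = (1.0e8_f32, -1.0e8_f32, 1.0_f32);
    ((a + b) + c, a + (b + c)) // 1.0 vs 0.0: b + c rounds back to -1.0e8
}
```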

Validated against mlx-community/Qwen3.5-4B-MLX-4bit target +
z-lab/Qwen3.5-4B-DFlash drafter at T=0, seed=0:

  prompt: "The capital of France is"
  AR baseline (200 tok):     " Paris.\nA. True\nB. False\nAnswer:..."
  DFlash tape replay (200):  " Paris.\nA. True\nB. False\nAnswer:..."
  diff: BYTE-IDENTICAL

Changes:
- crates/higgs-models/src/lib.rs: add AnyModel::replay_tape_rollback
  dispatcher (Qwen3Next + Hybrid cache).
- crates/higgs-engine/src/simple.rs: replace verify forward and the
  partial-accept rollback branch in generate_dflash_inner; remove
  now-unused GdnStateBackup import. Add HIGGS_DFLASH_TRACE-gated
  per-iter trace useful for future debugging.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@dusterbloom
Owner Author

CI passed (6/6 green). Closing — purpose served. Real PR is #18.

@dusterbloom dusterbloom closed this May 6, 2026