feat(dflash): MoE 35B-A3B support + DDTree CUDA graph reuse #39
Open
dusterbloom wants to merge 11 commits into Luce-Org:main from
Conversation
Add MoE tensor fields to TargetLayer (ffn_gate_inp, ffn_up_exps, ffn_gate_exps, ffn_down_exps, shared expert tensors) and MoE hparams to TargetWeights (n_expert, n_expert_used, expert_ff_dim, shared_ff_dim). Update load_target_gguf() to accept both qwen35 (dense) and qwen35moe architectures with separate validation paths. Add smoke_load_moe_target test that loads Qwen3.6-35B-A3B and validates all 40 layers, 256 experts, 10 full-attn + 30 delta-net layers. No regression: 27B loader still passes smoke_load_target.
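For orientation, a rough sketch of what the new fields could look like. The names follow the commit text; the surrounding layout, and the individual `ffn_*_shexp` weight tensors (which the commit only calls "shared expert tensors"), are assumptions rather than the actual diff in `dflash/src/internal.h`.

```cpp
// Sketch only — field names follow the commit message; the real declarations may differ.
struct TargetLayer {
    // ... existing dense attention / FFN tensors ...

    // MoE router + stacked expert weights (one 3-D tensor per projection)
    ggml_tensor * ffn_gate_inp  = nullptr;  // router: [n_embd, n_expert]
    ggml_tensor * ffn_up_exps   = nullptr;  // e.g. [n_embd, expert_ff_dim, n_expert]
    ggml_tensor * ffn_gate_exps = nullptr;
    ggml_tensor * ffn_down_exps = nullptr;  // e.g. [expert_ff_dim, n_embd, n_expert]

    // shared expert path (individual names assumed)
    ggml_tensor * ffn_gate_inp_shexp = nullptr;  // sigmoid gate for the shared expert
    ggml_tensor * ffn_up_shexp       = nullptr;
    ggml_tensor * ffn_gate_shexp     = nullptr;
    ggml_tensor * ffn_down_shexp     = nullptr;
};

struct TargetWeights {
    // ... existing hparams ...
    int32_t n_expert      = 0;  // 256 for 35B-A3B
    int32_t n_expert_used = 0;  // 8
    int32_t expert_ff_dim = 0;
    int32_t shared_ff_dim = 0;
};
```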
…cycle 2) Add build_moe_ffn() implementing full qwen35moe FFN path:
- Softmax gating over 256 experts, top-8 selection
- Per-expert SwiGLU via ggml_mul_mat_id
- Weight normalization and aggregation
- Shared expert path with sigmoid gating (ffn_gate_inp_shexp)
Tested with smoke_moe_ffn on Qwen3.6-35B-A3B: valid output, no NaN/Inf, correct shape [2048, 1].
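A minimal ggml sketch of that gating + expert path, modeled on the way llama.cpp builds its MoE FFN. Tensor names follow the commit text and the sketch above; the exact reshapes, variable names, and context plumbing are assumptions, not the actual `build_moe_ffn()` body.

```cpp
// Illustrative only: top-8-of-256 routed SwiGLU plus a sigmoid-gated shared expert.
// `cur` is [n_embd, n_tokens]; layer tensors `l.*` are as sketched above.
ggml_tensor * logits   = ggml_mul_mat(ctx, l.ffn_gate_inp, cur);              // [n_expert, n_tokens]
ggml_tensor * probs    = ggml_soft_max(ctx, logits);
ggml_tensor * selected = ggml_top_k(ctx, probs, n_expert_used);               // [n_expert_used, n_tokens]

ggml_tensor * weights  = ggml_get_rows(ctx,
        ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected);        // [1, n_expert_used, n_tokens]
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_div(ctx, weights, ggml_sum_rows(ctx, weights));                // renormalize over the top-8
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);

// per-expert SwiGLU via mul_mat_id over the stacked expert weights
ggml_tensor * x    = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
ggml_tensor * up   = ggml_mul_mat_id(ctx, l.ffn_up_exps,   x, selected);
ggml_tensor * gate = ggml_mul_mat_id(ctx, l.ffn_gate_exps, x, selected);
ggml_tensor * par  = ggml_mul(ctx, ggml_silu(ctx, gate), up);
ggml_tensor * experts = ggml_mul_mat_id(ctx, l.ffn_down_exps, par, selected); // [n_embd, n_expert_used, n_tokens]
experts = ggml_mul(ctx, experts, weights);

// aggregate: sum the weighted expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
    ggml_tensor * e = ggml_view_2d(ctx, experts, n_embd, n_tokens,
                                   experts->nb[2], i*experts->nb[1]);
    moe_out = moe_out ? ggml_add(ctx, moe_out, e) : e;
}

// shared expert: dense SwiGLU scaled by sigmoid(ffn_gate_inp_shexp · x)
ggml_tensor * sh_prob = ggml_sigmoid(ctx, ggml_mul_mat(ctx, l.ffn_gate_inp_shexp, cur)); // [1, n_tokens]
ggml_tensor * sh_par  = ggml_mul(ctx,
        ggml_silu(ctx, ggml_mul_mat(ctx, l.ffn_gate_shexp, cur)),
        ggml_mul_mat(ctx, l.ffn_up_shexp, cur));
ggml_tensor * sh_out  = ggml_mul_mat(ctx, l.ffn_down_shexp, sh_par);                     // [n_embd, n_tokens]
moe_out = ggml_add(ctx, moe_out, ggml_mul(ctx, sh_out, sh_prob));
```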
…rd test (cycle 3) Replace all q35:: namespace constants with runtime reads from TargetWeights so the same graph builder handles both 64-layer 27B and 40-layer 35B-A3B MoE. Dynamic CAPTURE_LAYERS computation, MoE FFN branch, and dynamic cache sizing. Full forward smoke test passes for both models with no regressions.
…ycle 4) Add DraftHparams struct with config.json parsing for layer count, hidden size, attention dims, and YaRN RoPE scaling params. Parameterize draft loader and graph builder to handle both 5-layer 27B and 8-layer 35B-A3B drafts. YaRN RoPE with factor=64, beta_fast=32, beta_slow=1 supported. Both draft models pass forward smoke tests with no regressions.
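A hypothetical outline of the struct the commit describes; only the YaRN parameter values and the layer counts come from the text, the field names and grouping are illustrative.

```cpp
// Illustrative only — the real DraftHparams in the draft loader may name or group
// these fields differently.
struct DraftHparams {
    int32_t n_layer   = 0;   // 5 for the 27B draft, 8 for the 35B-A3B draft
    int32_t n_embd    = 0;
    int32_t n_head    = 0;
    int32_t n_head_kv = 0;
    int32_t head_dim  = 0;

    // YaRN RoPE scaling, parsed from config.json "rope_scaling"
    float yarn_factor    = 1.0f;   // 64 for the MoE draft
    float yarn_beta_fast = 32.0f;
    float yarn_beta_slow = 1.0f;
};
```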
…le 5) Replace all DFLASH27B_TARGET_HIDDEN/VOCAB/DRAFT_BLOCK_SIZE/N_TARGET_LAYERS macro usages in test_dflash.cpp and smoke_draft_graph.cpp with runtime reads from loaded model weights. Enables the speculative decoding loop to run with both 64-layer 27B and 40-layer 35B-A3B MoE models.
Reshape sh_gate/sh_up to 2D and sh_down to 2D before shared expert gating broadcast, fixing ggml_can_repeat assertion when n_tokens > 1. Chain speculative decoding: 78 tok/s, DDTree: 14 tok/s on RTX 3090.
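Roughly what that fix amounts to, as a sketch under assumed variable names (`sh_gate`, `sh_up`, `sh_down`, `sh_prob`), not the actual diff: `ggml_mul` only broadcasts when `ggml_can_repeat(b, a)` holds, which fails once the shared-expert activations carry a singleton middle dimension and n_tokens > 1, so they are flattened to 2-D before the gate multiply.

```cpp
// Flatten the shared-expert activations to 2-D so the [1, n_tokens] sigmoid gate can
// broadcast over them (ggml_can_repeat requires every dimension of the larger tensor
// to be a multiple of the corresponding dimension of the smaller one).
sh_gate = ggml_reshape_2d(ctx, sh_gate, shared_ff_dim, n_tokens);
sh_up   = ggml_reshape_2d(ctx, sh_up,   shared_ff_dim, n_tokens);
sh_down = ggml_reshape_2d(ctx, sh_down, n_embd,        n_tokens);  // shared-expert output
sh_down = ggml_mul(ctx, sh_down, sh_prob);                         // sh_prob: [1, n_tokens], broadcasts
```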
Remove unnecessary ggml_repeat in shared expert gating (use ggml_mul broadcast instead). Add ggml_gallocr_reserve for graph buffer reuse and parameterize test_generate for both model sizes.
Benchmarks on RTX 3090 (target-only decode):
- 27B Q4_K: 35.3 tok/s (llama.cpp: 36.5, gap: -3.3%)
- 35B-A3B: 64.5 tok/s (llama.cpp: 85.0, gap: -24.1%)
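For reference, the usual `ggml_gallocr_reserve` pattern this refers to, shown as generic ggml usage rather than the dflash code; `build_decode_graph` is a hypothetical helper standing in for the project's graph builder.

```cpp
// Generic pattern: reserve the allocator once for the largest graph shape, then every
// per-step ggml_gallocr_alloc_graph call reuses that buffer instead of reallocating it.
ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));

ggml_cgraph * gf_max = build_decode_graph(/*n_tokens=*/max_batch);  // hypothetical builder
ggml_gallocr_reserve(galloc, gf_max);                               // size buffers for the worst case

for (int step = 0; step < n_steps; ++step) {
    ggml_cgraph * gf = build_decode_graph(n_tokens);
    ggml_gallocr_alloc_graph(galloc, gf);      // fits in the reserved buffer, no reallocation
    ggml_backend_graph_compute(backend, gf);
}
```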
The MoE draft model (factor=64 YaRN) was using attn_factor=1/(64^2)=1/4096 as the flash attention scale, making attention 4096x too weak. The Python reference uses standard 1/sqrt(head_dim) — YaRN correction belongs only in the RoPE cos/sin multipliers (ggml_rope_ext mscale param), not the attention scale. Also fixed the RoPE mscale from 1/factor^2 to the correct YaRN formula: 1/(0.1*ln(factor)+1) = 0.706 for factor=64.
HumanEval DDTree benchmark (RTX 3090, budget=22):
- MoE 35B-A3B: 19.1 -> 53.0 tok/s (2.8x improvement)
- 27B: 81.2 tok/s (no change, factor=1 unaffected)
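In code terms the fix splits the two scales roughly as below; the numeric values come from the commit text, while the `ggml_rope_ext` / `ggml_flash_attn_ext` calls are illustrative of where each scale belongs, not the exact dflash call sites.

```cpp
#include <cmath>

const float factor     = 64.0f;
const float mscale     = 1.0f / (0.1f * logf(factor) + 1.0f);   // ≈ 0.706, YaRN magnitude scale
const float attn_scale = 1.0f / sqrtf((float) head_dim);        // plain softmax scale, no YaRN term

// YaRN enters only through RoPE's cos/sin magnitude (the attn_factor/mscale argument):
q = ggml_rope_ext(ctx, q, pos, nullptr, rot_dims, GGML_ROPE_TYPE_NEOX,
                  n_ctx_orig, freq_base, /*freq_scale=*/1.0f/factor,
                  /*ext_factor=*/1.0f, /*attn_factor=*/mscale,
                  /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);

// Flash attention keeps the standard 1/sqrt(head_dim) scale; the old 1/factor^2 value
// made the scores 4096x too small.
kq = ggml_flash_attn_ext(ctx, q, k, v, mask, attn_scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);
```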
Enable CUDA graphs and rewrite test_generate with fixed-graph architecture:
- K/V written to fixed scratch slot (max_ctx-1), copied to correct position after compute so graph structure never changes between decode steps
- F16 attention mask input for variable-length causal attention
- ggml_argmax in graph eliminates GPU→CPU logits transfer per step
- CUDA graph replay eliminates ~1000 kernel launches per decode step
Results on RTX 3090 (Qwen3.6-35B-A3B Q2_K):
- MoE AR: 64.5 → 143.8 tok/s (+123%, now 1.7x faster than llama.cpp)
- 27B AR: 35.3 → 41.9 tok/s (+19%)
- 27B DDTree: 85.4 → 83.3 tok/s (no regression)
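Conceptually the decode loop looks something like the sketch below; the `set_*`, `read_*`, and `copy_kv_slot` helpers are invented for illustration, and only `ggml_backend_graph_compute` is real ggml API.

```cpp
// Sketch of the fixed-graph decode loop: the graph writes K/V to one scratch slot and
// ends in ggml_argmax, so its structure (and the captured CUDA graph) is identical
// every step; only tensor data changes between steps.
int32_t decode_step_loop(ggml_backend_t backend, ggml_cgraph * gf,
                         int32_t tok, int n_past, int n_gen, int max_ctx) {
    const int scratch_pos = max_ctx - 1;             // fixed K/V write slot
    for (int step = 0; step < n_gen; ++step) {
        set_token_input(tok);                        // hypothetical: writes the token-id input
        set_f16_causal_mask(n_past, scratch_pos);    // hypothetical: fills the F16 attention mask
        ggml_backend_graph_compute(backend, gf);     // CUDA graph replay after the first capture
        tok = read_argmax_output();                  // hypothetical: one int32 back from the GPU
        copy_kv_slot(scratch_pos, n_past);           // hypothetical: commit scratch K/V to position n_past
        n_past++;
    }
    return tok;
}
```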
Add build_target_step_tree_reusable() with fixed kv_start and n_tokens so CUDA graphs can replay across DDTree decode steps. K/V and target features are written to scratch slots (max_ctx - budget - 1 .. max_ctx - 2) and copied to committed positions after verify.
Results on RTX 3090:
- MoE DDTree: 53.8 → 55.8 tok/s (+3.7%, limited by 22% acceptance)
- 27B DDTree: 83.3 → 80.8 tok/s (no regression, within noise)
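A sketch of the slot layout described above; `copy_kv_slot`, `copy_target_features`, and the `accepted_path` indexing are illustrative assumptions, not the actual verify code.

```cpp
// The verify graph always writes its `budget` candidate positions to the same scratch
// range, so kv_start/n_tokens never change and the CUDA graph can replay. After verify,
// only the accepted tokens are copied down to committed positions.
const int scratch_lo = max_ctx - budget - 1;   // scratch slots: scratch_lo .. max_ctx - 2
for (int i = 0; i < n_accepted; ++i) {
    copy_kv_slot(/*from=*/scratch_lo + accepted_path[i], /*to=*/n_past + i);
    copy_target_features(scratch_lo + accepted_path[i], n_past + i);
}
n_past += n_accepted;
```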
# Conflicts:
#   dflash/deps/llama.cpp
#   dflash/src/internal.h
#   dflash/src/qwen35_target_graph.cpp
#   dflash/test/test_dflash.cpp
Summary
Adds Qwen3.5/3.6 35B-A3B MoE target support to the dflash spec-decode path, plus performance work on the DDTree verify graph and MoE AR.
MoE 35B-A3B (5 cycles)
Perf / fixes
Merge with main
This branch was synced with `origin/main` (32 upstream commits, including layer-segmented prefill + sliding-window FA, Blackwell/NVFP4 megakernel, and `--fa-window` CLI). Conflicts in `internal.h`, `qwen35_target_graph.cpp`, and `test_dflash.cpp` were resolved by:
Verified end-to-end on Qwen3.6-27B-Q4_K_XL + dflash-3.6 drafter:
Side note: comparison with the buun-llama-cpp fork
While testing this branch we ran the same prompt on https://github.com/spiritbuun/buun-llama-cpp (`Qwen3.6-27B-DFlash-GGUF` linear-chain spec-decode in upstream llama.cpp) for cross-implementation calibration:
Two takeaways relevant to lucebox:
Drafter weights are a real lever. Holding the runtime constant (buun chain), spiritbuun's drafter delivers +7.6pt accept and ~+50% tps over z-lab's at F16. spiritbuun appears to have re-trained / fine-tuned on top of the z-lab release rather than just quantising it.
DDTree budget=16 is faster than the default budget=22 on this prompt (109 vs 97 tps) — fewer redundant tree branches, slightly higher per-step accept (48.1% vs 45.5%). Worth considering as the default for short-context code-shaped prompts. Budgets below block_size (16) crash with a ggml shape assertion in test_dflash.
We attempted to add a linear-chain mode in lucebox via `--fast-rollback` (no `--ddtree`), but it consistently produced ~42% accept on the same drafter — substantially worse than buun's chain at 78.5% with the same weights. We've left that investigation on a separate branch (`session-debug-2026-04-26` on the fork), along with a new `test_chunked_vs_seq.cpp` regression that exercises `build_delta_net_chunked` against `ggml_gated_delta_net` and a scalar C++ reference. The test currently fails at n_tokens=16, with all three paths disagreeing with one another — so the discrepancy is not uniquely a lucebox bug, but the test is a useful starting point for future GDA correctness work.
Test plan