feat: CUDA/NVIDIA port — Qwen3.5-397B on single GPU at 5.35 tok/s (5.86 peak) #7
ssubbotin wants to merge 24 commits into danveloper:main
Conversation
Complete CUDA inference engine that runs the full 397B parameter MoE model on a single RTX 4090 (24GB VRAM) + 64GB RAM + NVMe SSD.

Key components:
- cuda_infer/infer.cu: Full inference engine (~1400 lines). Model loading (mmap + GPU upload), 60-layer forward pass, GatedDeltaNet linear attention, full attention with KV cache, MoE routing + expert SSD streaming, tokenizer integration.
- cuda_infer/kernels.cuh: 15 CUDA kernels ported from Metal. FMA-optimized 4-bit dequant matvec, SwiGLU, RMS norm, attention (Q@K^T, softmax, scores@V), GatedDeltaNet recurrence, conv1d, MoE combine+residual.
- bench_transfer.cu: Transfer path benchmarks. Measured GDS (5.3ms), pread+cudaMemcpy (8.3ms), warm cache (2.7ms) per layer for K=4 experts.

Performance: 2.45 tok/s (RTX 4090, Samsung 990 EVO Plus, PCIe 4.0 x4)
Comparison: requires only 64GB RAM vs 256-384GB for llama.cpp/KTransformers

NVIDIA GPUDirect Storage (GDS) enables direct NVMe-to-GPU DMA, providing a 37% speedup over the traditional pread+cudaMemcpy path.
Correction: works on 16GB RAM + 16GB VRAM. The actual running process uses 6 GB RAM and 6 GB VRAM to run the 397B MoE model.
Add --serve PORT mode to the CUDA inference engine.

Implements:
- POST /v1/chat/completions with SSE streaming (token-by-token)
- GET /v1/models (OpenAI model list)
- GET /health (status check)
- CORS headers for browser clients

ChatML tokenization for user messages, state reset between requests. Tested at 2.68 tok/s streaming via curl.
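The token-by-token streaming above follows the OpenAI SSE chunk format. A minimal sketch of what each streamed event looks like (field names follow the OpenAI streaming spec; the `sse_chunk` helper and the `chatcmpl-local` id are illustrative, not the actual C++ implementation):

```python
import json, time

def sse_chunk(token, model="qwen3.5-397b", finish_reason=None):
    """Build one OpenAI-style chat.completion.chunk SSE event.

    Hypothetical helper mirroring what --serve streams per token:
    a "data: <json>\n\n" frame with a delta carrying the new text.
    """
    payload = {
        "id": "chatcmpl-local",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            # the final chunk carries an empty delta plus finish_reason
            "delta": {} if finish_reason else {"content": token},
            "finish_reason": finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"

# The stream ends with the sentinel the OpenAI spec requires:
DONE = "data: [DONE]\n\n"
```

A client like curl or the OpenAI SDK consumes these frames one per generated token, which is what "token-by-token" streaming means here.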
Add tool/function calling to the HTTP server:
- Accept "tools" array in /v1/chat/completions requests
- Inject tool definitions into prompt using Qwen Hermes format
- Parse <tool_call> tags from model output
- Return OpenAI-compatible tool_calls SSE chunks
- Handle tool results via role="tool" messages
- Build full ChatML conversation from messages array
Tested: model correctly calls get_weather({"location": "Tokyo"})
when given the tool definition and asked about weather.
Known issues: model doesn't stop after tool call, special tokens
leak into content stream. Will fix in follow-up.
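The `<tool_call>` parsing step can be sketched as follows. This is a Python illustration of the Hermes-format extraction described above, not the engine's C++ code; the `parse_tool_calls` name and return shape are assumptions:

```python
import json, re

# Hermes format: the model wraps a JSON blob in <tool_call> ... </tool_call>
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_tool_calls(text):
    """Extract <tool_call> JSON blobs from model output.

    Returns (clean_text, calls) where calls is a list of
    {"name": ..., "arguments": {...}} dicts and clean_text is the
    output with the tool-call tags stripped.
    """
    calls = []
    for m in TOOL_CALL_RE.finditer(text):
        try:
            calls.append(json.loads(m.group(1)))
        except json.JSONDecodeError:
            pass  # malformed blob: ignore it, leave text untouched
    clean = TOOL_CALL_RE.sub("", text).strip()
    return clean, calls
```

The extracted dicts map directly onto OpenAI-compatible `tool_calls` chunks (name + JSON-encoded arguments).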
- Stop generation immediately after </tool_call> is detected (was continuing to generate 200 tokens after the tool call)
- Filter special tokens by ID (151643-151654) and by decoded text (<|im_end|>, <|im_start|>, <|endoftext|>, <think>/</think>)
- Stop on <|im_end|> in decoded text (model generates these as regular tokens, not just special token IDs)
- Clean output: "Hello there, friend!" with finish_reason="stop"
- Tool calls: immediate stop with finish_reason="tool_calls"
…ments

Update cuda_infer/README.md with:
- HTTP server usage (--serve PORT)
- Tool calling examples with curl
- Sending tool results back (multi-turn tool use)
- Claude Code integration via litellm proxy
- OpenAI Python SDK, aider, continue.dev examples
- Custom system prompt (~/.flash-moe/system.md)
- Corrected RAM requirements: 16GB min, 32GB recommended (process uses only 5.5GB; GDS bypasses RAM for expert data)
Add POST /v1/messages endpoint implementing the Anthropic Messages API with SSE streaming, eliminating the need for a litellm proxy.

Supports:
- message_start/content_block_start/content_block_delta/content_block_stop/message_delta/message_stop event sequence
- Text content blocks with text_delta streaming
- Tool use: tool_use content blocks with input_json_delta
- stop_reason: "end_turn" for normal completion, "tool_use" for tool calls
- System prompt as top-level field
- Array content blocks (text + tool_result)
- Anthropic tool format (input_schema)

Both APIs now available simultaneously:
  POST /v1/chat/completions (OpenAI format)
  POST /v1/messages (Anthropic format)

Tested: basic chat and tool calling both produce correct Anthropic SSE event streams at 2.6-2.8 tok/s.
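The Anthropic event sequence above can be sketched as a generator. Event names follow the Anthropic Messages streaming spec; the payloads are trimmed to the essentials and the `anthropic_events` helper is illustrative:

```python
import json

def anthropic_events(text_tokens, stop_reason="end_turn"):
    """Yield the Anthropic Messages SSE event sequence for a plain
    text response: message_start, one text content block streamed
    as text_delta events, then message_delta with the stop_reason
    and a final message_stop."""
    def ev(name, data):
        return f"event: {name}\ndata: {json.dumps(data)}\n\n"

    yield ev("message_start", {"type": "message_start",
                               "message": {"role": "assistant", "content": []}})
    yield ev("content_block_start", {"type": "content_block_start", "index": 0,
                                     "content_block": {"type": "text", "text": ""}})
    for tok in text_tokens:
        yield ev("content_block_delta", {"type": "content_block_delta", "index": 0,
                                         "delta": {"type": "text_delta", "text": tok}})
    yield ev("content_block_stop", {"type": "content_block_stop", "index": 0})
    yield ev("message_delta", {"type": "message_delta",
                               "delta": {"stop_reason": stop_reason}})
    yield ev("message_stop", {"type": "message_stop"})
```

Tool use follows the same skeleton with a `tool_use` content block streamed as `input_json_delta` events and `stop_reason="tool_use"`.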
System prompt pre-caching:
- Tokenize and prefill system prompt at server startup (~4s)
- Snapshot all 60 layers of KV cache + delta-net + conv state
- Restore from snapshot on each request instead of resetting to zero
- Saves ~4s per request (no more re-prefilling system prompt)

Fixed special token IDs for this model (MLX 4-bit quantization):
- <|endoftext|> = 248044 (was 151643)
- <|im_start|> = 248045 (was 151644)
- <|im_end|> = 248046 (was 151645)
- <think>/</think> = 248068/248069

Prompt builders now only generate user turn content since the system prompt is already in the KV cache from the snapshot.

Custom system prompt: ~/.flash-moe/system.md (loaded at startup)
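The snapshot/restore idea can be sketched in a few lines. Plain Python lists stand in for the GPU-side KV, delta-net, and conv buffers; the `LayerState` class and field names are illustrative, not the engine's actual structures:

```python
import copy

class LayerState:
    """Per-layer mutable state: KV cache entries, delta-net recurrent
    state, conv1d tail. Lists stand in for GPU buffers in this sketch."""
    def __init__(self):
        self.kv = []                  # grows by one entry per token
        self.delta_state = [0.0] * 4  # recurrent state (shape illustrative)
        self.conv_tail = [0.0] * 3    # conv1d rolling window

def snapshot(layers):
    # Deep-copy every layer's state right after the system-prompt prefill.
    return copy.deepcopy(layers)

def restore(snap):
    # Each new request starts from the snapshot instead of zeroed state,
    # which is what skips the ~4s system-prompt re-prefill.
    return copy.deepcopy(snap)
```

On the GPU this is a device-to-device buffer copy rather than a Python deep copy, but the lifecycle is the same: prefill once at startup, snapshot, restore per request.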
Keep KV cache and attention state across requests in the same session:
- Pass "session_id" in the request body to maintain conversation state
- Same session_id: continue from where the last response ended (no re-prefill)
- Different/no session_id: restore from the system prompt snapshot (new conversation)
- Single active session at a time (one GPU = one conversation)
- Also supports the x-session-id header for the Anthropic endpoint

Tested: Turn 1 "My name is Alice" → Turn 2 (same session) "What is my name?" → "Your name is Alice." New session → "I don't know your name yet!"

Also fixed special token IDs for the MLX 4-bit model: <|endoftext|>=248044, <|im_start|>=248045, <|im_end|>=248046
Add detailed per-layer timing when the --timing flag is used: norm, attn, oproj, route, shared, io, expert, combine.

Measured on RTX 4090 + Samsung 990 EVO Plus (PCIe 4.0 x4), in ms/layer:
  norm=0.02 attn=0.28 oproj=0.02 route=0.04 shared=0.04 io=5.79 expert=0.13 combine=0.01

Key finding: 87% of per-layer time is SSD I/O (5.8ms). GPU compute is only 0.5ms; pipelining across layers would save at most 8%, not worth the complexity.
GDS bypasses the OS page cache, leaving 58GB of RAM unused. pread populates the page cache, so hot experts stay in RAM (~3ms) instead of always hitting the SSD (~5.3ms via GDS).

Measured improvement with warm cache:
  pread + page cache: 2.52 tok/s (best burst: 4.56 tok/s)
  GDS direct:         2.41 tok/s (constant, no cache benefit)

GDS is still available via the ENABLE_GDS=1 env var for systems with less than 32GB RAM where the page cache isn't beneficial. The page cache grows to ~50GB during sustained generation, caching roughly half the 203GB expert data and accelerating repeat accesses.
LRU cache of recently-used experts in GPU VRAM. Uses ~17GB of the 24GB RTX 4090 VRAM (remaining after model weights + scratch buffers). Holds ~2,500 experts; after a few requests, ~95% of expert accesses hit the cache and skip the SSD/page cache entirely.

Three-tier caching hierarchy:
1. VRAM cache (~17GB): instant access, LRU eviction
2. OS page cache (~50GB): pread populates it, ~10 GB/s
3. NVMe SSD: cold misses only, ~5-7 GB/s

Performance progression in server mode:
  Request 1 (cold): 2.49 tok/s
  Request 2 (warm): 3.22 tok/s (+29%)
  Request 3:        3.24 tok/s (+30%)
  Request 4 (hot):  3.55 tok/s (+43%)

Cache misses use an async D2D copy to fill the VRAM slot in the background while the expert forward pass runs from the temp buffer. Set DISABLE_VRAM_CACHE=1 to disable (saves 17GB VRAM for other uses).
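The tier-1 lookup logic can be sketched with an `OrderedDict` standing in for the VRAM slot table. This is a CPU illustration of the LRU policy, not the CUDA code; capacity and the `ssd_read` callback are illustrative:

```python
from collections import OrderedDict

class VramExpertCache:
    """LRU cache of experts in VRAM, sketched with an OrderedDict.
    Misses fall through to the lower tiers (page cache / NVMe),
    modelled here as a caller-supplied read function."""
    def __init__(self, capacity):
        self.capacity = capacity      # ~2,500 slots on the 4090
        self.slots = OrderedDict()    # expert_id -> weights
        self.hits = self.misses = 0

    def get(self, expert_id, ssd_read):
        if expert_id in self.slots:
            self.hits += 1
            self.slots.move_to_end(expert_id)   # mark most recently used
            return self.slots[expert_id]
        self.misses += 1
        data = ssd_read(expert_id)              # tier 2/3: page cache or NVMe
        if len(self.slots) >= self.capacity:
            self.slots.popitem(last=False)      # evict least recently used
        self.slots[expert_id] = data
        return data
```

In the real engine the miss path additionally fills the VRAM slot with an async device-to-device copy while the expert forward pass runs from the temp buffer.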
Three optimizations combined:

1. Frequency-weighted VRAM cache eviction:
   - Eviction score = access_count * FREQ_WEIGHT + last_used
   - Hot experts (high access_count) survive topic changes
   - Pure LRU peak: 4.74 tok/s → freq-weighted peak: 5.86 tok/s

2. uint4 vectorized loads in the dequant kernel:
   - Load 128 bits (4 × uint32 = 32 nibbles) per instruction
   - #pragma unroll over 4 words for better instruction scheduling
   - __ldg() intrinsic for read-through L1 cache on weights/scales

3. Eliminated all runtime divisions and branches:
   - All /8 /64 /4 *8 → bit shifts (>>3 >>6 >>2 <<3)
   - Removed if-branch in the launch helper (vec4 always used)
   - More consistent execution: 5.12-5.86 range vs 5.01-6.30

Performance progression:
  Original (GDS):           2.45 tok/s
  + page cache:             2.52 tok/s (+3%)
  + VRAM cache (pure LRU):  3.55 tok/s (+45%)
  + freq-weighted LRU:      4.74 tok/s peak
  + vec4 + shifts + __ldg:  5.35 tok/s avg, 5.86 peak (+118%)

Now 23% faster than the Apple Silicon version (4.36 tok/s).
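The frequency-weighted eviction score is small enough to sketch directly. The slot-table shape and `FREQ_WEIGHT` value here are illustrative (a later commit reports any W>=1 performs within 2%):

```python
FREQ_WEIGHT = 10  # the W parameter; any W>=1 was measured to work

def eviction_victim(slots):
    """Return the slot id with the lowest eviction score, where
    score = access_count * FREQ_WEIGHT + last_used. A frequently
    accessed expert keeps a high score even when it was not used
    recently, so it survives topic changes that pure LRU would
    evict it on."""
    return min(slots,
               key=lambda s: slots[s]["count"] * FREQ_WEIGHT + slots[s]["last_used"])
```

Pure LRU is the special case of ignoring `count` entirely and evicting by `last_used` alone.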
Paper (paper/flash_moe_cuda.tex):
- Expanded Related Work to 17 references (PowerInfer, Pre-gated MoE, DeepSpeed-MoE, S-LoRA, LRFU, ARC, Mixtral, DeepSeek-V3, etc.)
- Positioned against PowerInfer hot/cold partitioning
- Clarified "sustained" → "steady-state" with cold-start numbers
- Labeled RTX 2080 Ti virtualized storage as non-comparable
- Paper now 7 pages, IEEE two-column format

Review (paper/flash_moe_cuda_review.md):
- Full 5-reviewer peer review with editorial decision
- Revision roadmap with 7 required + 7 suggested items

Code (cuda_infer/infer.cu):
- Added expert logging for profiling (EXPERT_LOG env var)
…urve, S5 kernel metrics
R4: Expert profiling expanded to 1,290 tokens across 3 diverse prompts
(science, code, creative). 309,600 routing decisions confirm:
26.6% temporal locality, 0.8% cross-layer correlation (stable).
S1: W parameter sensitivity — tested W=0,1,5,10,20,50.
All W>=1 within 2% of each other (4.80-4.94 tok/s).
Not sensitive — any W>=1 works.
S3: Working set curve (cache hit rate vs size) from the 1,290-token data.
Static top-N: 500 experts=20%, 2500=48.6%.
Runtime LRU achieves 95% because active working set is smaller.
S5: CUDA kernel metrics from ncu profiling:
28% DRAM throughput, 16-56% occupancy, 37 regs/thread.
Also added context-length degradation data (2.55→1.86 tok/s
over 10 sequential requests with growing context).
Paper now 8 pages with all reviewer-requested data.
All model constants now guarded with #ifndef, allowing override via -D flags at compile time. Expert offsets computed from dimensions instead of hardcoded.

Added configure.py: reads the model_weights.json config section and generates the correct nvcc -D flags or a per-model Makefile.

Workflow for any MoE model:
  python3 configure.py --manifest model_weights.json --print-cmd
  # outputs: nvcc -DHIDDEN_DIM=3072 -DNUM_LAYERS=48 ...

Default build (no -D flags) targets Qwen3.5-397B-A17B. Each model gets its own binary with exact-sized arrays; no wasted memory from MAX_LAYERS or runtime indirection.
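The core of such a configure step is a small manifest-to-flags mapping. A minimal sketch, assuming the manifest's `config` section uses HF-style keys like `hidden_size` and `num_hidden_layers` (the actual key names in model_weights.json may differ):

```python
import json

def nvcc_flags(manifest_path):
    """Read the config section of a model manifest and emit the
    nvcc -D defines that override the #ifndef-guarded constants.
    Key names are assumptions about the manifest layout."""
    with open(manifest_path) as f:
        cfg = json.load(f)["config"]
    mapping = {
        "HIDDEN_DIM": cfg["hidden_size"],
        "NUM_LAYERS": cfg["num_hidden_layers"],
        "NUM_EXPERTS": cfg["num_experts"],
    }
    return " ".join(f"-D{k}={v}" for k, v in mapping.items())
```

Because the defines are baked in at compile time, each model's binary gets exact-sized static arrays with no runtime indirection, which is the design choice the commit describes.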
Add dequant_matvec_q4k kernel for the GGML Q4_K quantization format, enabling direct use of GGUF model files without format conversion.

Q4_K format: 256-element super-blocks with packed 6-bit scales, fp16 super-block scale/min, 4-bit quantized values.

Optimizations applied:
- Precompute all 8 scale/min pairs (no branch in inner loop)
- uint32 loads for the qs array (4 bytes = 8 nibbles per load)
- FMA optimization: fma(nibble, ds*x, -ms*x)
- __ldg() for read-through L1 cache
- All divisions replaced with bit shifts
- Full #pragma unroll

Benchmark vs MLX affine 4-bit (RTX 4090):
  gate/up [1024, 4096]:   1.06x (near parity)
  routing [512, 4096]:    1.08x (near parity)
  lm_head [248320, 4096]: 1.34x
  down [4096, 1024]:      1.70x (narrow input, few blocks/row)

Net impact: ~5% throughput reduction vs the MLX format. GGUF users skip the 209GB safetensors download.
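The fma(nibble, ds*x, -ms*x) trick above rests on a simple algebraic identity: dequantizing w = q*ds - ms and then multiplying by x is the same as one FMA when ds*x and -ms*x are precomputed per sub-block. A CPU sketch verifying the identity (function names are illustrative; the real kernel operates on packed nibbles):

```python
def dequant_then_mul(q, x, ds, ms):
    """Reference path: dequantize w = q*ds - ms, then multiply by x."""
    return (q * ds - ms) * x

def fused_fma_term(q, x, ds, ms):
    """Kernel path: fma(q, ds*x, -ms*x). The two products ds*x and
    -ms*x are hoisted out of the inner loop and computed once per
    sub-block, so each nibble costs exactly one FMA."""
    return q * (ds * x) + (-ms * x)
```

The same identity is what the MLX affine 4-bit kernel uses; Q4_K just derives ds and ms from the packed 6-bit sub-block scales and the fp16 super-block scale/min.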
How many experts are you using compared with llama.cpp?
repack_experts.py no longer has hardcoded sizes for Qwen3.5-397B. Component sizes, expert count, and layer count are auto-detected from expert_index.json at runtime.

Works for any MoE model:
  python3 build_expert_index.py --model /path/to/safetensors --output index.json
  python3 repack_experts.py --index index.json

Tested formats:
  Qwen3.5-397B-A17B: 512 experts, 7,077,888 bytes/expert
  Qwen3.5-122B-A10B: 256 experts, different dimensions (auto-detected)
Added Section 4.2: llama.cpp vs Flash-MoE on identical RTX 4090 + 64GB RAM hardware. Same model (Qwen3.5-397B at 4-bit), same prompt:

  Flash-MoE CUDA (warm): 5.35 tok/s, 5.5 GB RAM
  llama.cpp -ngl 99:     OOM (228GB > 24GB VRAM)
  llama.cpp -ngl 0:      <0.05 tok/s (2h+ for 20 tokens, 54GB RAM)

The comparison demonstrates Flash-MoE's fundamental advantage: expert-level streaming with VRAM caching vs whole-model mmap. When the model doesn't fit in RAM, llama.cpp falls back to OS paging, which thrashes catastrophically. Flash-MoE streams only the active experts (~27MB/layer) and caches hot ones in VRAM.
- Add 5-run measurements with std dev (5.57 ± 0.12 tok/s, n=15)
- Update warm-up table with honest diverse-prompt data
- Add Limitations subsection (batched serving, warm-up, W, multi-GPU)
- Cite all 17 references in text (Mixtral, DeepSeek-V3, MoE-Gen, FloE)
- Add AI disclosure statement
- Add measurement methodology note
- Reduce em dash density, add Table 2 footnote

Co-Authored-By: Sergey Subbotin <ssubbotin@gmail.com>
Both Flash-MoE and llama.cpp activate the same number of experts — K=4 out of 512 per layer (plus 1 shared expert). This is determined by the model's router, not the inference engine. The difference is how those experts are loaded:
Flash-MoE only reads the 4 active experts each layer needs. llama.cpp memory-maps the entire 228 GB GGUF file, so the OS has to page in/out continuously with only 64 GB of physical RAM — that's why it thrashes and gets <0.05 tok/s. The VRAM expert cache (~17 GB, ~2565 expert slots) means ~95% of expert accesses hit GPU memory at 1008 GB/s after warm-up, which is where the 5.57 tok/s comes from.
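The ~27MB/layer figure quoted in the comparison follows directly from the per-expert size reported in the repack commit (7,077,888 bytes) times K=4 routed experts. A quick arithmetic check:

```python
# Sanity-check the per-layer streaming volume: K=4 routed experts
# at 7,077,888 bytes each (the size reported for Qwen3.5-397B-A17B).
BYTES_PER_EXPERT = 7_077_888
K = 4  # experts selected by the router per layer (plus 1 shared expert)

per_layer_bytes = K * BYTES_PER_EXPERT
per_layer_mib = per_layer_bytes / (1 << 20)
# → 27.0 MiB per layer, matching the "~27MB/layer" figure above
```

Across all 60 layers that is ~1.6 GiB of expert reads per token on a fully cold cache, which is why the VRAM hit rate dominates throughput.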
- Switch from twocolumn to single-column (tables no longer overflow)
- Replace all 11 prose em dashes with commas, parentheses, semicolons, or colons for cleaner academic style

Co-Authored-By: Sergey Subbotin <ssubbotin@gmail.com>
Looks great! I definitely know nothing about this, but if 10 experts were active, would the speed be around 1 tok/s? Also, it seems highly dependent on SSD speed, though in that case it would actually hit GPU VRAM speed first. Theoretically, I think it's possible to store some experts, or even a diff, assuming it can predict and prepare for the next layer? Again, I have no clue and no knowledge around this, and you did great.
Summary
Complete CUDA inference engine that runs Qwen3.5-397B-A17B on a single NVIDIA GPU, streaming 209GB of expert weights from NVMe SSD at 5.35 tokens/second (RTX 4090, 5.86 peak).
Port of the Metal/Apple Silicon engine to x86/NVIDIA hardware with significant enhancements:
- OpenAI (/v1/chat/completions) and Anthropic (/v1/messages) APIs
- <tool_call> parsing and OpenAI/Anthropic response formats

Multi-Hardware Benchmarks
VRAM Cache Warm-Up (RTX 4090)
Comparison with Other Solutions
*KTransformers numbers for Qwen3-235B; 397B single-GPU not published.
Key Architecture Decisions
__ldg(). All divisions eliminated. Consistent 5+ tok/s.

Features
- HTTP server (--serve PORT): OpenAI + Anthropic SSE streaming
- Tool calling: <tool_call> parsing, OpenAI tool_calls / Anthropic tool_use
- session_id maintains conversation state
- Anthropic endpoint /v1/messages, just set ANTHROPIC_BASE_URL
- --timing shows phase breakdown

Files
- cuda_infer/infer.cu -- Complete engine + HTTP server (~2000 lines)
- cuda_infer/kernels.cuh -- 15 CUDA kernels (~570 lines)
- cuda_infer/README.md -- Full documentation
- bench_transfer.cu -- Transfer path benchmarks

Test Plan