v0.2.20: Fused NN Kernels + Flash Attention 3 SM120 + FP8 Block-Scale MMA #193
Merged
Add GPU-accelerated 1D convolution to replace the CPU fallback in the Whisper ASR encoder.

Changes:
- Add native/ops/conv/conv1d_kernels.cuh: F32/BF16/F16 kernels
- Add native/ops/conv/conv1d.cu: Dispatcher with dtype validation
- Add native/bindings/nn/conv.cpp: pybind11 bindings
- Add src/pygpukit/ops/conv.py: Python API with CPU fallback
- Update Whisper encoder to use native conv1d

Performance: Eliminates GPU->CPU->GPU roundtrip per audio frame.
Correctness: Max diff vs NumPy reference < 5e-7.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
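A minimal sketch of the direct F32 path this kernel family implements (the signature, layout, and names are illustrative, not the exact code in conv1d_kernels.cuh):

```cuda
// Hypothetical sketch of a direct 1D convolution (F32), one output element per thread.
// x: [C_in, L_in], w: [C_out, C_in, K], y: [C_out, L_out], row-major.
__global__ void conv1d_f32_naive(const float* x, const float* w, const float* bias,
                                 float* y, int c_in, int l_in, int c_out, int l_out,
                                 int k, int stride, int padding) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= c_out * l_out) return;
    int oc = idx / l_out;          // output channel
    int op = idx % l_out;          // output position
    float acc = bias ? bias[oc] : 0.0f;
    for (int ic = 0; ic < c_in; ++ic) {
        for (int kk = 0; kk < k; ++kk) {
            int ip = op * stride + kk - padding;   // input position
            if (ip >= 0 && ip < l_in)
                acc += x[ic * l_in + ip] * w[(oc * c_in + ic) * k + kk];
        }
    }
    y[oc * l_out + op] = acc;
}
```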
Add security example with Meta's Llama Guard 3 model for content moderation:
- MLCommons hazard taxonomy (S1-S14 categories)
- User input and agent response classification
- Interactive and batch classification modes
- Greedy decoding for deterministic safety classification

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add LLaMA 4 model implementation with native CUDA kernels
- Update CMakeLists.txt and bindings for LLaMA 4 ops

Note: LLaMA 4 kernels are monolithic and need refactoring to follow the modular nn/ structure (see issue).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implement FA3 with WMMA tensor core acceleration:
- WMMA-based score computation (Q @ K^T)
- WMMA-based output computation (P @ V)
- Vectorized memory loads (float4)
- Warp-level softmax with shuffle reductions

Benchmark results (RTX 5090, 32 heads, head_dim=128):
- seq_len=128: FA3 1.02x vs SDPA
- seq_len=512: FA3 1.03x vs SDPA
- seq_len=1024: FA3 0.99x vs SDPA
- seq_len=2048: FA3 1.01x vs SDPA

All correctness tests pass (mean relative error < 2%).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
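The WMMA score path follows roughly this pattern: BF16 fragments for a 16x16 tile of Q @ K^T, accumulated in FP32 over head_dim. This is an illustrative sketch of the technique, not the kernel's exact tiling:

```cuda
#include <mma.h>
#include <cuda_bf16.h>
using namespace nvcuda;

// One warp computes a 16x16 tile of S = Q @ K^T, accumulating in FP32.
// q_tile: 16 query rows [16 x head_dim], row-major
// k_tile: 16 key rows   [16 x head_dim], row-major (read as col-major K^T)
__device__ void score_tile_wmma(const __nv_bfloat16* q_tile, const __nv_bfloat16* k_tile,
                                float* s_tile, int head_dim, int lds) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;
    wmma::fill_fragment(acc, 0.0f);
    for (int k = 0; k < head_dim; k += 16) {              // march over head_dim in 16-wide steps
        wmma::load_matrix_sync(a, q_tile + k, head_dim);  // Q[0:16, k:k+16]
        wmma::load_matrix_sync(b, k_tile + k, head_dim);  // K[0:16, k:k+16] read as K^T columns
        wmma::mma_sync(acc, a, b, acc);
    }
    wmma::store_matrix_sync(s_tile, acc, lds, wmma::mem_row_major);
}
```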
Add reusable TMA (Tensor Memory Accelerator) utilities:
- tma_utils.cuh: CUtensorMap descriptor creation, async copy ops
  - barrier_init/arrive/wait for mbarrier synchronization
  - tma_load_2d/3d for async global->shared transfers
  - Support for BF16, FP16, FP32 data types
  - 128B swizzle for bank-conflict-free access
- warp_scheduler.cuh: Producer/consumer warp specialization
  - WarpRole enum and detection helpers
  - Warpgroup utilities for WGMMA
  - Named barriers for SM90+
  - FA3Config/GemmConfig presets
- pipeline.cuh: Multi-stage async pipeline management
  - Pipeline<N> template for N-stage buffering
  - DualBufferPipeline optimized 2-stage
  - PipelineBuffer shared memory manager

These utilities enable TMA-based optimization for:
- Flash Attention 3
- Persistent GEMM
- Any kernel needing async global->shared transfers

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
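A rough sketch of the producer/consumer split these helpers support; the WarpRole enum and the 4-producer count here are illustrative stand-ins for the actual warp_scheduler.cuh API:

```cuda
// Illustrative warp-specialization pattern (not the exact API in warp_scheduler.cuh).
enum class WarpRole { Producer, Consumer };

__device__ inline WarpRole warp_role(int num_producer_warps) {
    int warp_id = threadIdx.x / 32;
    return warp_id < num_producer_warps ? WarpRole::Producer : WarpRole::Consumer;
}

__global__ void warp_specialized_kernel() {
    constexpr int kProducerWarps = 4;  // e.g. 4 producer + 8 consumer warps
    if (warp_role(kProducerWarps) == WarpRole::Producer) {
        // Producer warps: issue async global->shared copies (TMA) for the next stage.
    } else {
        // Consumer warps: wait on the stage's mbarrier, then run WMMA compute.
    }
}
```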
Add flash_attention_3_tma.cuh with:
- TmaSharedMemory: Multi-stage K/V buffers with mbarrier
- TmaFA3Config: Warp-specialized configuration (4 producer, 8 consumer)
- Producer functions: TMA async bulk tensor loads
- Consumer functions: WMMA-based score and output computation
- 4-stage pipeline for K/V prefetching

Architecture:
- Producer warps (0-3): Issue TMA loads for K/V tiles
- Consumer warps (4-11): Compute attention scores and output
- mbarrier synchronization between stages

NOTE: Requires Python bindings to create CUtensorMap descriptors. This is a WIP - the kernel compiles but is not yet callable from Python.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add TMA FA3 environment control (PYGPUKIT_FA3_TMA)
- Create TMA descriptor launcher function for Q/K/V tensors
- Integrate TMA path into sdpa_causal_dispatch before regular FA3
- Fix TMA kernel to use 3D loads for 3D tensor descriptors
- Add benchmark script for TMA vs baseline comparison

Benchmark results (RTX 5090, SM 120a):
- [32, 512, 128]: Baseline 2090us, TMA 2170us (0.96x)
- [32, 1024, 128]: Baseline 7175us, TMA 7187us (1.00x)
- [32, 2048, 128]: Baseline 27165us, TMA 27125us (1.00x)
- [32, 4096, 128]: Baseline 93848us, TMA 93444us (1.00x)

Correctness: PASS (results match baseline)

Note: TMA kernel is functional but not yet optimized for speedup. Future work: warp specialization tuning, swizzle patterns.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
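The environment toggle is presumably a plain host-side getenv check along these lines (illustrative sketch; the real gating sits in the sdpa_causal_dispatch path):

```cuda
#include <cstdlib>
#include <cstring>

// Gate the TMA FA3 path on PYGPUKIT_FA3_TMA (the "any non-zero value enables it"
// convention shown here is an assumption, not the documented semantics).
static bool fa3_tma_enabled() {
    const char* v = std::getenv("PYGPUKIT_FA3_TMA");
    return v != nullptr && std::strcmp(v, "0") != 0;
}
```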
Bug: TMA FA3 kernel hung at 256+ blocks due to a __syncthreads() inside a consumer-only code path. Producer warps never reached the sync.

Fix:
- Split consumer_compute_output() into two functions:
  - convert_scores_to_probs(): ALL threads participate (has syncs)
  - consumer_compute_output_matmul(): consumers only (no syncs)
- Reduce TILE_Q 64->32 and NUM_STAGES 4->2 for the 99KB smem limit
- Use a union for smem_scores/smem_probs to save 8KB

Benchmark (RTX 5090, 32 heads):
- seq_len=512: 6.6ms, 0.65 TFLOPS
- seq_len=1024: 25.8ms, 0.66 TFLOPS
- seq_len=2048: 99.2ms, 0.69 TFLOPS
- seq_len=4096: 387.5ms, 0.71 TFLOPS

Correctness: PASS (matches FA3 baseline)

Next: Parallelize softmax across query positions for 8-32x speedup.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
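The failure mode is the classic divergent-barrier deadlock: __syncthreads() is a block-wide barrier, so every thread must reach it. A minimal illustration (not the actual kernel code):

```cuda
// Illustrative reproduction of the hang.
__global__ void divergent_sync_bug(int num_producer_warps) {
    int warp_id = threadIdx.x / 32;
    if (warp_id >= num_producer_warps) {
        // Consumer-only branch: producer warps never arrive here,
        // so this barrier waits forever.
        __syncthreads();   // BUG: deadlock
    }
    // Fix applied in this commit: keep barriers on paths taken by ALL threads
    // (convert_scores_to_probs) and keep consumer-only work barrier-free
    // (consumer_compute_output_matmul).
}
```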
Changes:
1. Warp-parallel softmax: each consumer warp handles different q rows
   - 8 warps process 8 rows simultaneously (was: all warps on same row)
   - Purely warp-synchronous with no __syncthreads() inside
2. Fix consumer warp indexing bug in matmul functions:
   - consumer_compute_scores: use consumer_warp_idx (0-7), not global warp_id (4-11)
   - consumer_compute_output_matmul: same fix
   - Ensures all tiles are computed (was missing tiles 0-3)
3. Direct BF16 softmax output:
   - Softmax writes BF16 directly to smem_probs
   - Eliminates convert_scores_to_probs function call
   - Saves 2 __syncthreads() per iteration

Sync point analysis (after optimization): 5 syncs per iteration (was 7):
1. After barrier_wait (TMA data visible)
2. After Q@K (scores ready for causal mask)
3. After causal mask (scores ready for softmax)
4. After softmax (probs ready for P@V)
5. End of iteration (next TMA)

Benchmark (RTX 5090, 32 heads):
- Performance: ~0.65-0.71 TFLOPS (similar to baseline)
- Correctness: PASS

Note: Performance unchanged suggests the bottleneck is elsewhere (WMMA efficiency, memory bandwidth, or 1 block/SM occupancy). Next optimization: wgmma instructions for SM120a.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
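A warp-synchronous row softmax in the spirit of change 1: one warp owns one row and reduces max and sum with shuffles, so no __syncthreads() is needed inside. This sketch does a plain per-tile softmax with illustrative names; the real kernel additionally carries running max/sum across KV tiles (online softmax):

```cuda
#include <cuda_bf16.h>

// One warp computes softmax over one row of `tile_kv` scores and writes BF16 probs.
__device__ void warp_row_softmax(const float* row_scores, __nv_bfloat16* row_probs, int tile_kv) {
    int lane = threadIdx.x % 32;

    // 1) Row max via warp shuffle reduction.
    float m = -INFINITY;
    for (int j = lane; j < tile_kv; j += 32) m = fmaxf(m, row_scores[j]);
    for (int off = 16; off > 0; off >>= 1) m = fmaxf(m, __shfl_xor_sync(0xffffffff, m, off));

    // 2) Row sum of exp(score - max).
    float s = 0.0f;
    for (int j = lane; j < tile_kv; j += 32) s += __expf(row_scores[j] - m);
    for (int off = 16; off > 0; off >>= 1) s += __shfl_xor_sync(0xffffffff, s, off);

    // 3) Normalize and emit BF16 (fully-masked rows would need s == 0 handling, omitted here).
    float inv = 1.0f / s;
    for (int j = lane; j < tile_kv; j += 32)
        row_probs[j] = __float2bfloat16(__expf(row_scores[j] - m) * inv);
}
```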
Root cause: The union between smem_scores (float) and smem_probs (bf16) caused a race condition when multiple warps processed different Q rows in parallel. Warp B writing to smem_probs[row_B] could corrupt smem_scores[row_A] that Warp A was still reading.

Fix: Two-phase softmax approach
- Phase 1: ALL warps read scores, compute probs, store to REGISTERS
- Phase 2: After __syncthreads(), ALL warps write probs to smem_probs

Also includes:
- TMA descriptor cache for reduced host-side overhead (99.4% hit rate)
- cudaEvent-based kernel timing for accurate benchmarks
- Proper handling of fully-masked rows (causal attention edge case)

Benchmark results (RTX 5090, SM120a):
- seq_len=1024: 51.21 TFLOPS (kernel-only)
- seq_len=2048: 59.86 TFLOPS (kernel-only)
- Correctness: PASS (max_diff=0.0)
- Determinism: PASS (all runs identical)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
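The two-phase pattern, roughly (illustrative sketch; the helper signature, register staging with a small local array, and the tile_kv <= 128 bound are assumptions):

```cuda
#include <cuda_bf16.h>

// Two-phase softmax write-back: probs are staged in registers so no warp
// overwrites the aliased shared-memory scores another warp is still reading.
__device__ void two_phase_prob_writeback(const float* smem_scores, __nv_bfloat16* smem_probs,
                                         int row, int tile_kv, float row_max, float row_inv_sum) {
    constexpr int kMaxPerLane = 4;   // supports tile_kv up to 128 (tile_kv / 32 per lane)
    int lane = threadIdx.x % 32;
    __nv_bfloat16 staged[kMaxPerLane];

    // Phase 1: every warp reads its row's scores and converts to probs in registers.
    for (int i = 0, j = lane; j < tile_kv; ++i, j += 32)
        staged[i] = __float2bfloat16(__expf(smem_scores[row * tile_kv + j] - row_max) * row_inv_sum);

    __syncthreads();   // all reads of smem_scores are now complete, block-wide

    // Phase 2: safe to overwrite the union'd buffer with BF16 probs.
    for (int i = 0, j = lane; j < tile_kv; ++i, j += 32)
        smem_probs[row * tile_kv + j] = staged[i];
}
```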
Phase 1 implementation is identical to the FA3 TMA structure. This establishes the baseline for NVFP4 integration in Phase 2/3.

Benchmark results (RTX 5090, seq_len=1024, 32 heads, 128 head_dim):
- Kernel-only: 335.6 us (51.19 TFLOPS)
- E2E cached: 368.1 us (46.67 TFLOPS)
- Correctness: PASS (max diff = 0 vs FA3)

Files added:
- flash_attention_4_sm120.cuh: FA4 kernel with config structs for all phases
- benchmark_fa4_sm120.py: Benchmark script with correctness verification
- fa4_sm120_research.md: SM100 vs SM120 architecture research

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Phase 2 validates the NVFP4 GEMM path for attention scores.

Benchmark results (RTX 5090, seq_len=1024, single head):
- NVFP4 Q@K^T: 394.0 us (0.68 TFLOPS)
- Correctness: 21% rel_diff vs NumPy (acceptable for 4-bit)

Key finding: NVFP4 GEMM is optimized for large K (LLM weights), not attention's small K=128 (head_dim). CUTLASS uses K=256 tiles.

For comparison:
- Full FA3 TMA (32 heads): 330.9 us (51.92 TFLOPS)

NVFP4's benefit in attention comes from memory bandwidth (4x smaller loads), not compute throughput. Full integration requires PTX inline assembly for mma.sync.aligned.block_scale.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Phase 3 Results (RTX 5090, seq_len=1024):
- P@V (K=seq_len=1024): 94.7 us (2.84 TFLOPS)
- Q@K^T (K=head_dim=128): 353.3 us (0.76 TFLOPS)
- Larger K speedup: 3.73x (better tile utilization)

Key Findings:
1. NVFP4 CUTLASS GEMM uses K=256 tile size, suboptimal for head_dim=128
2. P (softmax output) CANNOT use NVFP4 directly:
   - Softmax values ~1/seq_len = 0.001
   - NVFP4 smallest positive = 0.25
   - All P values quantize to 0 (100% error)

Recommended FA4 Architecture:
- Q, K, V: pre-quantize to NVFP4 (static weights OK)
- P: keep in BF16 (dynamic, small values)
- Q@K^T: use mma.sync.block_scale (NVFP4)
- P@V: use mma.sync (BF16) or mixed precision

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Complete analysis of FA4 (Flash Attention 4) feasibility for RTX 5090.

Key Findings:
1. SM120 uses mma.sync.block_scale, NOT tcgen05.mma (datacenter)
2. NVFP4 GEMM is optimized for K=256 tiles, suboptimal for head_dim=128
3. P (softmax output) CANNOT use NVFP4:
   - Softmax values ~0.001 << NVFP4 minimum 0.25
   - All P values quantize to 0 (100% error)

Recommendation: Do NOT proceed with FA4 NVFP4 for SM120. FA3 TMA (51.97 TFLOPS) is already optimal for GeForce Blackwell.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add 5 SM120 config versions for FA3 TMA attention tuning:
- V0: Baseline (TILE_Q=32, TILE_KV=64, 4+8 warps) - 63.61 TFLOPS
- V1: Smaller tiles (TILE_KV=32) - 53.11 TFLOPS
- V2: 3-stage pipeline (TILE_KV=32) - 52.86 TFLOPS
- V3: More compute warps (2+10) - 64.01 TFLOPS
- V4: Most compute warps (4+12) - 66.62 TFLOPS (+4.7%)

Environment variable PYGPUKIT_FA3_SM120_VERSION (0-4) selects the config. Version 4 achieves the best performance with 16 total warps.

Benchmark results (RTX 5090, seq_len=4096, heads=32, head_dim=128):
- V0 (baseline): 63.61 TFLOPS
- V4 (4+12 warps): 66.62 TFLOPS

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
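A sketch of how such a version switch is typically wired. Only the tile sizes and warp splits stated in the list above are taken from the source; the struct shape, stage counts for V0/V1/V3/V4, the warp splits for V1/V2, and the default version are assumptions:

```cuda
#include <cstdlib>

// Illustrative tuning-config table selected by PYGPUKIT_FA3_SM120_VERSION (0-4).
struct Fa3Sm120Config {
    int tile_q;
    int tile_kv;
    int producer_warps;
    int consumer_warps;
    int num_stages;
};

inline Fa3Sm120Config fa3_sm120_config() {
    static const Fa3Sm120Config kVersions[5] = {
        {32, 64, 4,  8, 2},   // V0 baseline
        {32, 32, 4,  8, 2},   // V1 smaller KV tiles (warp split assumed)
        {32, 32, 4,  8, 3},   // V2 3-stage pipeline (warp split assumed)
        {32, 64, 2, 10, 2},   // V3 more compute warps
        {32, 64, 4, 12, 2},   // V4 most compute warps (fastest measured)
    };
    const char* v = std::getenv("PYGPUKIT_FA3_SM120_VERSION");
    int idx = v ? std::atoi(v) : 0;          // default shown here is an assumption
    if (idx < 0 || idx > 4) idx = 0;
    return kVersions[idx];
}
```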
Add detailed documentation explaining why all 6 __syncthreads() per KV tile are required and cannot be reduced:
1. After barrier_wait - mbarrier is per-thread, need block sync
2. After compute_scores - scores must complete before mask
3. After mask - mask must complete before softmax reads
4. After softmax phase1 - union race condition prevention
5. After softmax phase2 - probs must complete before P@V
6. End of loop - prevents cross-iteration TMA/read race

Attempted sync reduction failed due to:
- Removing the sync after barrier_wait causes thread divergence races
- Removing the end-of-loop sync causes prefetch/read stage conflicts

Current performance: 64.6 TFLOPS (SM120, seq_len=4096)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
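Schematically, the per-KV-tile loop with the six barriers looks like the skeleton below; the helpers are empty stubs standing in for the kernel's actual producer/consumer functions, and the prefetch distance is illustrative:

```cuda
// Stubs standing in for the real helpers in flash_attention_3_tma.cuh.
constexpr int NUM_STAGES = 2;
__device__ void producer_issue_tma_loads(int) {}
__device__ void barrier_wait(int) {}
__device__ void consumer_compute_scores(int) {}
__device__ void apply_causal_mask(int) {}
__device__ void softmax_phase1_to_registers() {}
__device__ void softmax_phase2_write_probs() {}
__device__ void consumer_compute_output_matmul(int) {}

__global__ void fa3_kv_loop_skeleton(int num_kv_tiles) {
    for (int kv_tile = 0; kv_tile < num_kv_tiles; ++kv_tile) {
        int stage = kv_tile % NUM_STAGES;

        producer_issue_tma_loads(kv_tile + 1);   // producers prefetch a future tile
        barrier_wait(stage);                     // K/V for this tile have landed
        __syncthreads();  // (1) mbarrier wait is per-thread; publish data block-wide

        consumer_compute_scores(stage);          // S = Q @ K^T (WMMA)
        __syncthreads();  // (2) scores complete before masking

        apply_causal_mask(kv_tile);
        __syncthreads();  // (3) mask complete before softmax reads

        softmax_phase1_to_registers();           // probs staged in registers
        __syncthreads();  // (4) all smem_scores reads done (union with smem_probs)

        softmax_phase2_write_probs();            // BF16 probs to shared memory
        __syncthreads();  // (5) probs complete before P @ V

        consumer_compute_output_matmul(stage);   // O += P @ V (WMMA)
        __syncthreads();  // (6) finish reads before the next TMA overwrites this stage
    }
}
```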
Implement FP8 E4M3 block-scale MMA using native PTX inline assembly for SM120. Fragment layouts derived from CUTLASS CuTe mma_traits_sm80.hpp analysis.

Key implementation details:
- PTX instruction: mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0
- A fragment: 4 registers (16 FP8 E4M3 elements per thread)
- B fragment: 2 registers (8 FP8 E4M3 elements per thread)
- C/D fragment: 4 FP32 registers (16x8 output tile)
- Scale factors: UE8M0 format (8-bit unsigned exponent)

CuTe Layout Analysis:
- ALayout: (T32,V16) -> (M16,K32), t0=lane/8, t1=lane%8
- BLayout: (T32,V8) -> (K32,N8), non-contiguous byte access
- CLayout: (T32,V4) -> (M16,N8), d[v] = C[4*t0+v, t1]

Test result: PASS on RTX 5090 (SM 120)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implements FA3 with FP8 E4M3 Q@K^T using SM120's block-scale MMA instruction for ~50% memory bandwidth reduction vs BF16.

Key implementation details:
- FP8 E4M3 quantization with per-head global UE8M0 scaling
- mma.sync.aligned.kind::mxf8f6f4.block_scale.m16n8k32.f32.e4m3.e4m3
- B fragment loading: n_idx=lane_id/4, k_base=(lane_id%4)*8
- SM80_16x8_Row C fragment layout for correct output mapping
- BF16 P@V with WMMA for precision (FP8 V gave ~18% error)

Validation results (vs BF16 FA3 reference):
- Prefill (128 tokens): 1.97% error, 0.9999 correlation - PASS
- Prefill (512 tokens): 1.58% error, 0.9999 correlation - PASS
- Decode (single token): 0% error, perfect correlation - PASS

New files:
- native/ops/nn/attention/flash_attention_3_fp8_sm120.cuh
- native/ops/matmul/gemm/fp8_block_scale/test_mma_direct.cuh

Python API: sdpa_causal_fp8(), fa3_fp8_available(), test_fp8_mma_direct()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
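A sketch of the per-head quantization step described above: pick a power-of-two scale (UE8M0 stores an unsigned 8-bit exponent, i.e. a power of two) from the head's absolute maximum, then cast the scaled values to FP8 E4M3 via cuda_fp8.h. The kernel shape and names are illustrative, not the shipped code; 448 is the largest finite E4M3 value:

```cuda
#include <cuda_fp8.h>
#include <math.h>

// Quantize one head's values to FP8 E4M3 with a single power-of-two (UE8M0-style) scale.
// amax is the head's precomputed absolute maximum.
__global__ void quantize_head_e4m3(const float* x, __nv_fp8_e4m3* x_fp8,
                                   float amax, float* scale_out, int n) {
    // Smallest power of two that maps amax into the E4M3 range [-448, 448].
    float scale = exp2f(ceilf(log2f(fmaxf(amax, 1e-12f) / 448.0f)));
    if (blockIdx.x == 0 && threadIdx.x == 0) *scale_out = scale;  // kept for dequant after the MMA

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x_fp8[i] = __nv_fp8_e4m3(x[i] / scale);
}
```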
Add high-performance fused kernels to reduce memory bandwidth and kernel launch overhead in LLM inference pipelines.

New kernels:
- rmsnorm_residual: y = rmsnorm(x + residual) * gamma
- swiglu: y = silu(gate) * up (used in Qwen, LLaMA3, Mistral FFN)
- geglu: y = gelu(gate) * up

Benchmark results (RTX 5090):
- SwiGLU: 2.38-14.25x speedup vs separate ops
- RMSNorm+Residual: 2.03-12.37x speedup
- GeGLU: 2.40-13.10x speedup

Larger batch sizes show greater speedups due to memory bandwidth savings from eliminating intermediate buffers.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
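For illustration, a fused SwiGLU elementwise kernel along these lines reads gate and up once and writes y once, with no intermediate silu(gate) buffer (BF16 sketch; the shipped kernel's signature and dtype handling may differ):

```cuda
#include <cuda_bf16.h>

// Fused SwiGLU: y[i] = silu(gate[i]) * up[i], computed in FP32 and stored as BF16.
// Fusing avoids materializing silu(gate) in global memory.
__global__ void swiglu_fused_bf16(const __nv_bfloat16* gate, const __nv_bfloat16* up,
                                  __nv_bfloat16* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float g = __bfloat162float(gate[i]);
    float u = __bfloat162float(up[i]);
    float silu = g / (1.0f + __expf(-g));   // silu(g) = g * sigmoid(g)
    y[i] = __float2bfloat16(silu * u);
}
```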
- Fix import sorting (I001)
- Fix unused loop variable (B007) by renaming to _features
- Fix loop variable binding (B023) by using default args
- Remove unused mode argument in open() (UP015)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary

Highlights:
- Fused NN Kernels
- Flash Attention 3 SM120
- FP8 Block-Scale MMA: mma.sync for FP8 with per-block scaling

Test plan

🤖 Generated with Claude Code