
Conversation

@m96-chan
Owner

Summary

  • Fused NN Kernels: Add high-performance fused kernels (RMSNorm+Residual, SwiGLU, GeGLU) with 2-14x speedup
  • Flash Attention 3 SM120: TMA-enabled Flash Attention 3 for Blackwell GPUs with producer/consumer warp architecture
  • FP8 Block-Scale MMA: Native PTX inline assembly for FP8 block-scale matrix multiply-accumulate

Highlights

Fused NN Kernels

| Kernel | Batch | Speedup |
| --- | --- | --- |
| SwiGLU | 1 | 2.38x |
| SwiGLU | 32 | 14.25x |
| RMSNorm+Residual | 128 | 12.37x |
| GeGLU | 32 | 13.10x |
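
For reference, a minimal NumPy sketch of what the fused kernels above compute (illustrative only; the native kernels run fused on GPU tensors, and whether GeGLU uses the exact or tanh GELU is not stated here):

```python
import numpy as np

def rmsnorm_residual(x, residual, gamma, eps=1e-6):
    # Fused: y = rmsnorm(x + residual) * gamma, with no intermediate buffer
    h = x + residual
    rms = np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)
    return (h / rms) * gamma

def swiglu(gate, up):
    # Fused: y = silu(gate) * up, where silu(x) = x * sigmoid(x)
    return (gate / (1.0 + np.exp(-gate))) * up

def geglu(gate, up):
    # Fused: y = gelu(gate) * up (tanh approximation shown; kernel variant assumed)
    g = 0.5 * gate * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (gate + 0.044715 * gate**3)))
    return g * up
```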

Flash Attention 3 SM120

  • TMA (Tensor Memory Accelerator) for efficient global memory access
  • Producer/consumer warp architecture for overlapped compute and memory ops
  • Tunable configurations for different sequence lengths

FP8 Block-Scale MMA

  • Native PTX mma.sync for FP8 with per-block scaling
  • Enables future W8A8 quantized inference paths

Test plan

  • Fused kernel correctness tests pass
  • Fused kernel benchmarks show expected speedups
  • Build succeeds for SM120a
  • FA3 correctness verification (known FP8 precision issues being investigated)

🤖 Generated with Claude Code

m96-chan and others added 23 commits January 6, 2026 20:48
Add GPU-accelerated 1D convolution to replace CPU fallback in Whisper ASR encoder.

Changes:
- Add native/ops/conv/conv1d_kernels.cuh: F32/BF16/F16 kernels
- Add native/ops/conv/conv1d.cu: Dispatcher with dtype validation
- Add native/bindings/nn/conv.cpp: pybind11 bindings
- Add src/pygpukit/ops/conv.py: Python API with CPU fallback
- Update Whisper encoder to use native conv1d

Performance: Eliminates GPU->CPU->GPU roundtrip per audio frame.
Correctness: Max diff vs NumPy reference < 5e-7.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
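
As a companion to the "Max diff vs NumPy reference < 5e-7" check above, here is a sketch of such a NumPy conv1d reference; the pygpukit call in the trailing comment is hypothetical and only illustrates how the comparison would run:

```python
import numpy as np

def conv1d_ref(x, weight, bias, stride=1, padding=0):
    """NumPy cross-correlation reference. x: [C_in, T], weight: [C_out, C_in, K], bias: [C_out]."""
    c_in, t = x.shape
    c_out, _, k = weight.shape
    xp = np.pad(x, ((0, 0), (padding, padding)))
    t_out = (t + 2 * padding - k) // stride + 1
    y = np.empty((c_out, t_out), dtype=np.float32)
    for oc in range(c_out):
        for ot in range(t_out):
            window = xp[:, ot * stride : ot * stride + k]
            y[oc, ot] = np.sum(window * weight[oc]) + bias[oc]
    return y

# Hypothetical comparison against the native op (exact API name not confirmed):
# y_gpu = pygpukit.ops.conv.conv1d(x, weight, bias, stride=1, padding=1)
# assert np.max(np.abs(np.asarray(y_gpu) - conv1d_ref(x, weight, bias, 1, 1))) < 5e-7
```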
Add security example with Meta's Llama Guard 3 model for content moderation:
- MLCommons hazard taxonomy (S1-S14 categories)
- User input and agent response classification
- Interactive and batch classification modes
- Greedy decoding for deterministic safety classification

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
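
For context on the classification modes above, Llama Guard style models return a short text verdict that the example must parse. A minimal parser sketch, assuming the common "safe" / "unsafe" plus category-list output convention (verify against the model card before relying on it):

```python
def parse_guard_verdict(text: str) -> tuple[bool, list[str]]:
    """Parse a verdict like 'safe' or 'unsafe\nS1,S10' into (is_safe, categories)."""
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    if not lines or lines[0].lower() == "safe":
        return True, []
    categories = lines[1].split(",") if len(lines) > 1 else []
    return False, [c.strip() for c in categories]

print(parse_guard_verdict("unsafe\nS1,S10"))  # (False, ['S1', 'S10'])
```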
- Add LLaMA 4 model implementation with native CUDA kernels
- Update CMakeLists.txt and bindings for LLaMA 4 ops

Note: LLaMA 4 kernels are monolithic and need refactoring
to follow modular nn/ structure (see issue)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implement FA3 with WMMA tensor core acceleration:
- WMMA-based score computation (Q @ K^T)
- WMMA-based output computation (P @ V)
- Vectorized memory loads (float4)
- Warp-level softmax with shuffle reductions

Benchmark results (RTX 5090, 32 heads, head_dim=128):
- seq_len=128:  FA3 1.02x vs SDPA
- seq_len=512:  FA3 1.03x vs SDPA
- seq_len=1024: FA3 0.99x vs SDPA
- seq_len=2048: FA3 1.01x vs SDPA

All correctness tests pass (mean relative error < 2%).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
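
The correctness tests above compare FA3 output against plain scaled-dot-product attention. A minimal causal SDPA reference in NumPy, for orientation (shapes illustrative):

```python
import numpy as np

def sdpa_causal_ref(q, k, v):
    """q, k, v: [heads, seq_len, head_dim] -> attention output of the same shape."""
    h, s, d = q.shape
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)        # [h, s, s]
    mask = np.triu(np.ones((s, s), dtype=bool), 1)        # mask out future positions
    scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)          # numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v
```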
Add reusable TMA (Tensor Memory Accelerator) utilities:

- tma_utils.cuh: CUtensorMap descriptor creation, async copy ops
  - barrier_init/arrive/wait for mbarrier synchronization
  - tma_load_2d/3d for async global->shared transfers
  - Support for BF16, FP16, FP32 data types
  - 128B swizzle for bank-conflict-free access

- warp_scheduler.cuh: Producer/consumer warp specialization
  - WarpRole enum and detection helpers
  - Warpgroup utilities for WGMMA
  - Named barriers for SM90+
  - FA3Config/GemmConfig presets

- pipeline.cuh: Multi-stage async pipeline management
  - Pipeline<N> template for N-stage buffering
  - DualBufferPipeline optimized 2-stage
  - PipelineBuffer shared memory manager

These utilities enable TMA-based optimization for:
- Flash Attention 3
- Persistent GEMM
- Any kernel needing async global->shared transfers

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
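
To illustrate the N-stage buffering idea behind Pipeline<N>, a small Python sketch of the slot rotation only (the native class handles the actual barrier waits; names here are illustrative):

```python
NUM_STAGES = 4  # mirrors a Pipeline<4> instantiation

def ring_schedule(num_tiles: int, num_stages: int = NUM_STAGES):
    """Yield (tile, slot, prior_tile) where prior_tile must be consumed before the slot is reused."""
    for tile in range(num_tiles):
        slot = tile % num_stages      # producer writes this tile into this slot
        prior = tile - num_stages     # tile that previously occupied the slot (negative: none)
        yield tile, slot, prior

for tile, slot, prior in ring_schedule(8):
    print(f"tile {tile}: slot {slot}" + (f", wait for tile {prior}" if prior >= 0 else ""))
```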
Add flash_attention_3_tma.cuh with:

- TmaSharedMemory: Multi-stage K/V buffers with mbarrier
- TmaFA3Config: Warp-specialized configuration (4 producer, 8 consumer)
- Producer functions: TMA async bulk tensor loads
- Consumer functions: WMMA-based score and output computation
- 4-stage pipeline for K/V prefetching

Architecture:
- Producer warps (0-3): Issue TMA loads for K/V tiles
- Consumer warps (4-11): Compute attention scores and output
- mbarrier synchronization between stages

NOTE: Requires Python bindings to create CUtensorMap descriptors.
This is a WIP: the kernel compiles but is not yet callable from Python.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add TMA FA3 environment control (PYGPUKIT_FA3_TMA)
- Create TMA descriptor launcher function for Q/K/V tensors
- Integrate TMA path into sdpa_causal_dispatch before regular FA3
- Fix TMA kernel to use 3D loads for 3D tensor descriptors
- Add benchmark script for TMA vs baseline comparison

Benchmark results (RTX 5090, SM 120a):
- [32, 512, 128]:  Baseline 2090us, TMA 2170us (0.96x)
- [32, 1024, 128]: Baseline 7175us, TMA 7187us (1.00x)
- [32, 2048, 128]: Baseline 27165us, TMA 27125us (1.00x)
- [32, 4096, 128]: Baseline 93848us, TMA 93444us (1.00x)

Correctness: PASS (results match baseline)

Note: The TMA kernel is functional but does not yet deliver a speedup over the baseline.
Future work: warp specialization tuning, swizzle patterns.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Bug: TMA FA3 kernel hung at 256+ blocks due to __syncthreads()
inside consumer-only code path. Producer warps never reached sync.

Fix:
- Split consumer_compute_output() into two functions:
  - convert_scores_to_probs(): ALL threads participate (has syncs)
  - consumer_compute_output_matmul(): consumers only (no syncs)
- Reduce TILE_Q 64->32 and NUM_STAGES 4->2 to fit the 99KB shared memory limit
- Use union for smem_scores/smem_probs to save 8KB

Benchmark (RTX 5090, 32 heads):
- seq_len=512:  6.6ms, 0.65 TFLOPS
- seq_len=1024: 25.8ms, 0.66 TFLOPS
- seq_len=2048: 99.2ms, 0.69 TFLOPS
- seq_len=4096: 387.5ms, 0.71 TFLOPS

Correctness: PASS (matches FA3 baseline)

Next: Parallelize softmax across query positions for 8-32x speedup

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Changes:
1. Warp-parallel softmax: Each consumer warp handles different q rows
   - 8 warps process 8 rows simultaneously (was: all warps on same row)
   - Purely warp-synchronous with no __syncthreads() inside

2. Fix consumer warp indexing bug in matmul functions:
   - consumer_compute_scores: use consumer_warp_idx (0-7) not global warp_id (4-11)
   - consumer_compute_output_matmul: same fix
   - Ensures all tiles are computed (was missing tiles 0-3)

3. Direct BF16 softmax output:
   - Softmax writes BF16 directly to smem_probs
   - Eliminates convert_scores_to_probs function call
   - Saves 2 __syncthreads() per iteration

Sync point analysis (after optimization):
- 5 syncs per iteration (was 7):
  1. After barrier_wait (TMA data visible)
  2. After Q@K (scores ready for causal mask)
  3. After causal mask (scores ready for softmax)
  4. After softmax (probs ready for P@V)
  5. End of iteration (next TMA)

Benchmark (RTX 5090, 32 heads):
- Performance: ~0.65-0.71 TFLOPS (similar to baseline)
- Correctness: PASS

Note: The unchanged performance suggests the bottleneck lies elsewhere
(WMMA efficiency, memory bandwidth, or 1-block-per-SM occupancy).
Next optimization: wgmma instructions for SM120a.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
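
A NumPy sketch of the row-parallel assignment in point 1: with 8 consumer warps and a 32-row score tile, warp w owns rows w, w+8, w+16, w+24, and each row's softmax needs no cross-warp synchronization (values illustrative):

```python
import numpy as np

NUM_CONSUMER_WARPS = 8
TILE_Q = 32

scores = np.random.randn(TILE_Q, 64).astype(np.float32)   # one Q-tile of attention scores
probs = np.empty_like(scores)

for warp in range(NUM_CONSUMER_WARPS):
    for row in range(warp, TILE_Q, NUM_CONSUMER_WARPS):    # rows owned by this warp
        r = scores[row] - scores[row].max()                # per-row max (warp shuffle reduction on GPU)
        e = np.exp(r)
        probs[row] = e / e.sum()                           # per-row sum (warp shuffle reduction on GPU)
```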
Root cause: Union between smem_scores (float) and smem_probs (bf16)
caused a race condition when multiple warps processed different Q rows
in parallel. Warp B writing to smem_probs[row_B] could corrupt
smem_scores[row_A] that Warp A was still reading.

Fix: Two-phase softmax approach
- Phase 1: ALL warps read scores, compute probs, store to REGISTERS
- Phase 2: After __syncthreads(), ALL warps write probs to smem_probs

Also includes:
- TMA descriptor cache for reduced host-side overhead (99.4% hit rate)
- cudaEvent-based kernel timing for accurate benchmarks
- Proper handling of fully-masked rows (causal attention edge case)

Benchmark results (RTX 5090, SM120a):
- seq_len=1024: 51.21 TFLOPS (kernel-only)
- seq_len=2048: 59.86 TFLOPS (kernel-only)
- Correctness: PASS (max_diff=0.0)
- Determinism: PASS (all runs identical)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
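
The TMA descriptor cache mentioned above is host-side memoization; a minimal sketch of the idea (the keying scheme, the builder stub, and all names are assumptions, not the actual implementation):

```python
_tma_desc_cache: dict[tuple, tuple] = {}

def build_descriptor(ptr, shape, strides, dtype):
    # Stand-in for the native CUtensorMap encode call (hypothetical).
    return ("CUtensorMap", ptr, shape, strides, dtype)

def get_tma_descriptor(ptr: int, shape: tuple, strides: tuple, dtype: str):
    """Reuse a descriptor across launches on the same tensor to cut host-side overhead."""
    key = (ptr, shape, strides, dtype)
    desc = _tma_desc_cache.get(key)
    if desc is None:
        desc = build_descriptor(ptr, shape, strides, dtype)
        _tma_desc_cache[key] = desc
    return desc
```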
The Phase 1 implementation is structurally identical to the FA3 TMA kernel.
This establishes the baseline for NVFP4 integration in Phase 2/3.

Benchmark results (RTX 5090, seq_len=1024, 32 heads, 128 head_dim):
- Kernel-only: 335.6 us (51.19 TFLOPS)
- E2E cached: 368.1 us (46.67 TFLOPS)
- Correctness: PASS (max diff = 0 vs FA3)

Files added:
- flash_attention_4_sm120.cuh: FA4 kernel with config structs for all phases
- benchmark_fa4_sm120.py: Benchmark script with correctness verification
- fa4_sm120_research.md: SM100 vs SM120 architecture research

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Phase 2 validates the NVFP4 GEMM path for attention scores.

Benchmark results (RTX 5090, seq_len=1024, single head):
- NVFP4 Q@K^T: 394.0 us (0.68 TFLOPS)
- Correctness: 21% rel_diff vs NumPy (acceptable for 4-bit)

Key finding: NVFP4 GEMM optimized for large K (LLM weights),
not attention's small K=128 (head_dim). CUTLASS uses K=256 tiles.

For comparison:
- Full FA3 TMA (32 heads): 330.9 us (51.92 TFLOPS)

NVFP4 benefit in attention comes from memory bandwidth (4x smaller
loads), not compute throughput. Full integration requires PTX
inline assembly for mma.sync.aligned.block_scale.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Phase 3 Results (RTX 5090, seq_len=1024):
- P@V (K=seq_len=1024): 94.7 us (2.84 TFLOPS)
- Q@K^T (K=head_dim=128): 353.3 us (0.76 TFLOPS)
- Larger K speedup: 3.73x (better tile utilization)

Key Findings:
1. NVFP4 CUTLASS GEMM uses K=256 tile size, suboptimal for head_dim=128
2. P (softmax output) CANNOT use NVFP4 directly:
   - Softmax values ~1/seq_len = 0.001
   - NVFP4 smallest positive = 0.25
   - All P values quantize to 0 (100% error)

Recommended FA4 Architecture:
- Q, K, V: pre-quantize to NVFP4 (static weights OK)
- P: keep in BF16 (dynamic, small values)
- Q@K^T: use mma.sync.block_scale (NVFP4)
- P@V: use mma.sync (BF16) or mixed precision

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
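
A quick numerical illustration of finding 2 above, using the 0.25 minimum stated there (the point is that typical softmax probabilities round to zero under direct NVFP4 quantization):

```python
import numpy as np

seq_len = 1024
p = np.full(seq_len, 1.0 / seq_len)        # typical attention probabilities, ~0.001

NVFP4_MIN_POSITIVE = 0.25                  # smallest positive magnitude per the finding above
quantized = np.where(p < NVFP4_MIN_POSITIVE / 2, 0.0, NVFP4_MIN_POSITIVE)  # round-to-nearest sketch

print(quantized.sum())                     # 0.0 -> the whole P row collapses to zero
```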
Complete analysis of FA4 (Flash Attention 4) feasibility for RTX 5090.

Key Findings:
1. SM120 uses mma.sync.block_scale, NOT tcgen05.mma (datacenter)
2. NVFP4 GEMM optimized for K=256 tiles, suboptimal for head_dim=128
3. P (softmax output) CANNOT use NVFP4:
   - Softmax values ~0.001 << NVFP4 minimum 0.25
   - All P values quantize to 0 (100% error)

Recommendation: Do NOT proceed with FA4 NVFP4 for SM120.
FA3 TMA (51.97 TFLOPS) is already optimal for GeForce Blackwell.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add 5 SM120 config versions for FA3 TMA attention tuning:
- V0: Baseline (TILE_Q=32, TILE_KV=64, 4+8 warps) - 63.61 TFLOPS
- V1: Smaller tiles (TILE_KV=32) - 53.11 TFLOPS
- V2: 3-stage pipeline (TILE_KV=32) - 52.86 TFLOPS
- V3: More compute warps (2+10) - 64.01 TFLOPS
- V4: Most compute warps (4+12) - 66.62 TFLOPS (+4.7%)

Environment variable PYGPUKIT_FA3_SM120_VERSION (0-4) selects config.
Version 4 achieves best performance with 16 total warps.

Benchmark results (RTX 5090, seq_len=4096, heads=32, head_dim=128):
- V0 (baseline): 63.61 TFLOPS
- V4 (4+12 warps): 66.62 TFLOPS

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
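
A sketch of how the version selection above could look on the host side; the environment variable name and the per-version tile/warp parameters come from this commit, while the stage counts, the default version, and the function itself are assumptions:

```python
import os

# (tile_q, tile_kv, num_stages, producer_warps, consumer_warps); fields not stated
# for a version are assumed to match the V0 baseline (TILE_Q=32, 2-stage pipeline).
FA3_SM120_CONFIGS = {
    0: (32, 64, 2, 4, 8),    # V0 baseline
    1: (32, 32, 2, 4, 8),    # V1 smaller KV tile
    2: (32, 32, 3, 4, 8),    # V2 3-stage pipeline
    3: (32, 64, 2, 2, 10),   # V3 more compute warps
    4: (32, 64, 2, 4, 12),   # V4 most compute warps (best measured)
}

def select_fa3_sm120_config():
    version = int(os.environ.get("PYGPUKIT_FA3_SM120_VERSION", "0"))  # default assumed
    return FA3_SM120_CONFIGS.get(version, FA3_SM120_CONFIGS[0])
```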
Add detailed documentation explaining why all 6 __syncthreads() per KV
tile are required and cannot be reduced:

1. After barrier_wait - mbarrier is per-thread, need block sync
2. After compute_scores - scores must complete before mask
3. After mask - mask must complete before softmax reads
4. After softmax phase1 - union race condition prevention
5. After softmax phase2 - probs must complete before P@V
6. End of loop - prevents cross-iteration TMA/read race

Attempted sync reduction failed due to:
- Removing sync after barrier_wait causes thread divergence races
- Removing end-of-loop sync causes prefetch/read stage conflicts

Current performance: 64.6 TFLOPS (SM120, seq_len=4096)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implement FP8 E4M3 block-scale MMA using native PTX inline assembly for SM120.
Fragment layouts derived from CUTLASS CuTe mma_traits_sm80.hpp analysis.

Key implementation details:
- PTX instruction: mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0
- A fragment: 4 registers (16 FP8 E4M3 elements each)
- B fragment: 2 registers (8 FP8 E4M3 elements each)
- C/D fragment: 4 FP32 registers (16x8 output tile)
- Scale factors: UE8M0 format (8-bit unsigned exponent)

CuTe Layout Analysis:
- ALayout: (T32,V16) -> (M16,K32), t0=lane/8, t1=lane%8
- BLayout: (T32,V8) -> (K32,N8), non-contiguous byte access
- CLayout: (T32,V4) -> (M16,N8), d[v] = C[4*t0+v, t1]

Test result: PASS on RTX 5090 (SM 120)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
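
The C/D fragment layout above can be sanity-checked in isolation: assuming the same lane decomposition as the A layout (t0 = lane/8, t1 = lane%8), the mapping d[v] = C[4*t0+v, t1] should cover the 16x8 accumulator tile exactly once across 32 lanes and 4 registers:

```python
import numpy as np

coverage = np.zeros((16, 8), dtype=int)      # M16 x N8 accumulator tile
for lane in range(32):
    t0, t1 = lane // 8, lane % 8             # lane decomposition as stated for ALayout
    for v in range(4):                       # 4 FP32 accumulator registers per lane
        coverage[4 * t0 + v, t1] += 1

assert (coverage == 1).all()                 # every output element owned by exactly one (lane, v)
```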
Implements FA3 with FP8 E4M3 Q@K^T using SM120's block-scale MMA
instruction for ~50% memory bandwidth reduction vs BF16.

Key implementation details:
- FP8 E4M3 quantization with per-head global UE8M0 scaling
- mma.sync.aligned.kind::mxf8f6f4.block_scale.m16n8k32.f32.e4m3.e4m3
- B fragment loading: n_idx=lane_id/4, k_base=(lane_id%4)*8
- SM80_16x8_Row C fragment layout for correct output mapping
- BF16 P@V with WMMA for precision (FP8 V gave ~18% error)

Validation results (vs BF16 FA3 reference):
- Prefill (128 tokens): 1.97% error, 0.9999 correlation - PASS
- Prefill (512 tokens): 1.58% error, 0.9999 correlation - PASS
- Decode (single token): 0% error, perfect correlation - PASS

New files:
- native/ops/nn/attention/flash_attention_3_fp8_sm120.cuh
- native/ops/matmul/gemm/fp8_block_scale/test_mma_direct.cuh

Python API: sdpa_causal_fp8(), fa3_fp8_available(), test_fp8_mma_direct()

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
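
A sketch of the per-head power-of-two scaling described above: choose a scale (UE8M0 stores only an exponent) so the head's max magnitude fits within FP8 E4M3 range, whose largest finite value is 448. The rounding policy here is an assumption; the kernel's exact choice may differ:

```python
import numpy as np

E4M3_MAX = 448.0   # largest finite FP8 E4M3 magnitude

def per_head_pow2_scale(head: np.ndarray) -> float:
    """Power-of-two scale mapping the head's amax into E4M3 range (rounding policy assumed)."""
    amax = float(np.abs(head).max())
    if amax == 0.0:
        return 1.0
    exp = int(np.ceil(np.log2(amax / E4M3_MAX)))   # smallest 2^exp with amax / 2^exp <= 448
    return float(2.0 ** exp)

q_head = np.random.randn(1024, 128).astype(np.float32) * 10
scale = per_head_pow2_scale(q_head)
assert np.abs(q_head / scale).max() <= E4M3_MAX
```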
Add high-performance fused kernels to reduce memory bandwidth and
kernel launch overhead in LLM inference pipelines.

New kernels:
- rmsnorm_residual: y = rmsnorm(x + residual) * gamma
- swiglu: y = silu(gate) * up (used in Qwen, LLaMA3, Mistral FFN)
- geglu: y = gelu(gate) * up

Benchmark results (RTX 5090):
- SwiGLU: 2.38-14.25x speedup vs separate ops
- RMSNorm+Residual: 2.03-12.37x speedup
- GeGLU: 2.40-13.10x speedup

Larger batch sizes show greater speedups due to memory bandwidth
savings from eliminating intermediate buffers.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix import sorting (I001)
- Fix unused loop variable (B007) by renaming to _features
- Fix loop variable binding (B023) by using default args
- Remove redundant mode argument in open() (UP015)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@m96-chan m96-chan merged commit 224e6bb into main Jan 26, 2026
25 of 26 checks passed