Claude/cutlass kernel implementation 6z7ep #4
Open
tomasruizt wants to merge 12 commits into main from
Conversation
Uses a CUTLASS GEMM (tensor cores, SM80+) for the matmul stage, followed by a custom Gumbel-max sampling kernel. The two-stage reduction pattern matches the existing Triton and CUDA variants. Registered as the "fused-cutlass" provider in get_sampler(). https://claude.ai/code/session_01N4JDFDxqBmCbnVW3RMpSyW
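For reference, the Gumbel-max trick the sampling kernel relies on: adding i.i.d. Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax is equivalent to sampling from the softmax distribution. A minimal NumPy sketch of the math (shapes and names are illustrative, not the kernel's):

```python
import numpy as np

def gumbel_max_sample(logits, temperature, rng):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = rng.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    # argmax over the vocab axis of scaled logits + noise is
    # a sample from softmax(logits / temperature).
    return np.argmax(logits / temperature + gumbel, axis=0)

rng = np.random.default_rng(0)
logits = rng.standard_normal((1024, 4))  # a small [V, H] for illustration
tokens = gumbel_max_sample(logits, temperature=1.0, rng=rng)
print(tokens.shape)  # one sampled vocab index per H column
```

The fused kernels compute the same quantity, but fold the noise addition and argmax into the reduction over V instead of materializing the noisy logits.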
Rewrite the GEMM from CUTLASS 2.x (device::Gemm, SM80) to 3.x (GemmUniversal with CollectiveBuilder, SM90). This enables Sm90EVT epilogues in subsequent steps. Key changes:
- Use CollectiveBuilder for automatic mainloop/epilogue configuration with TMA and warp specialization on SM90.
- AlignmentC=AlignmentD=1 to support arbitrary H (including H=1).
- Construct CuTe strides manually (make_cute_packed_stride is unavailable in pip-installed CUTLASS).
- Require compute_90a gencode for TMA/WGMMA instructions.
- Fix a cudaMemcpy GPU-to-CPU sync in the sampling kernel: pass temperature as a device pointer and read it via __ldg() on the GPU.
- Add nvidia-cutlass, cuda-bench, and cupti-python to the Modal image.
- Increase the Modal speed test timeout to accommodate JIT compilation.
Verified on Modal H100: fused-cutlass 0.482ms vs fused-triton 0.426ms (V=151936, D=4096, H=4). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
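On the manual stride construction: my understanding is that for a row-major [rows, cols] operand with a batch mode, the packed stride CUTLASS 3.x expects is (leading, 1, rows*cols), which is what make_cute_packed_stride would have produced. A sketch of just that arithmetic (function name is hypothetical):

```python
def packed_stride_row_major(rows, cols, batches):
    """Packed strides for a row-major [rows, cols] tensor with a trailing
    batch mode, mirroring (as an assumption) what CUTLASS's
    make_cute_packed_stride yields for cute::Stride<int64, _1, int64>:
    (leading-dim stride, unit stride, batch stride)."""
    return (cols, 1, rows * cols)

# e.g. A is [M, K] row-major with L batches; strides are elements, not bytes
M, K, L = 128, 4096, 1
print(packed_stride_row_major(M, K, L))  # (4096, 1, 524288)
```

The batch stride is included even for L=1 because the GemmUniversal problem shape is always rank-4 (M, N, K, L).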
Wires up Sm90EVT with Sm90Compute<plus>, Sm90ScalarBroadcast, and Sm90AccFetch. Adds a test_evt_add1 function and a Modal test script. Key findings: EVTs require the TmaWarpSpecialized schedule, 16B-aligned stores (AlignmentD=4 for float32), and a smaller CTA tile (64x64x64) to satisfy the EPI_TILE_M >= MMA_TILE_M constraint. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
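Semantically, this EVT tree is just D = acc + alpha applied elementwise during the epilogue. A NumPy reference of what test_evt_add1 validates (a sketch of the semantics, not the visitor machinery):

```python
import numpy as np

def evt_add_scalar(acc, alpha):
    # Sm90AccFetch: read the accumulator fragment,
    # Sm90ScalarBroadcast: broadcast alpha across the tile,
    # Sm90Compute<plus>: elementwise sum, stored to D.
    return acc + np.float32(alpha)

acc = np.arange(6, dtype=np.float32).reshape(2, 3)
print(evt_add_scalar(acc, 1.0))  # each accumulator element plus 1
```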
Step 4 of the CUTLASS FMMS kernel plan: validate Sm90RowReduction with cutlass::maximum to reduce across M (vocab) producing one max per H column. EVT tree uses Sm90Compute<Identity> as root (stores full logits to D) with an inner Sm90RowReduction that writes the reduced max to a separate buffer. Verified on Modal H100 with max_err=0.0014 at V=151936, D=4096, H=4. Also adds enable_cuda_jit_cache() to persist JIT-compiled CUDA extensions on the Modal volume, skipping expensive nvcc recompilation across runs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Check for an existing .so before calling torch's load() to report "likely cached" vs "likely compiled" with wall-clock time. Also enable the JIT cache in modal_speed_test.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
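The cache check amounts to probing for the built shared object in the extension build directory before the load and timing the call. A sketch of that pattern (the directory layout and the `load_fn` indirection are assumptions, not the repo's code):

```python
import os
import time

def load_with_cache_report(name, load_fn, ext_dir):
    """Probe for a compiled .so before invoking the JIT load, so the log
    can say "likely cached" vs "likely compiled" with wall-clock time.
    Assumes torch-style layout: <ext_dir>/<name>/<name>.so; load_fn wraps
    the actual torch.utils.cpp_extension.load() call."""
    so_path = os.path.join(ext_dir, name, f"{name}.so")
    cached = os.path.exists(so_path)
    t0 = time.perf_counter()
    module = load_fn()
    elapsed = time.perf_counter() - t0
    print(f"{name}: {'likely cached' if cached else 'likely compiled'} "
          f"({elapsed:.1f}s)")
    return module
```

"Likely", because the probe cannot distinguish a stale .so that torch decides to rebuild anyway from a genuine cache hit.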
Implement Sm90RowArgmax custom EVT visitor that tracks (value, index) pairs during the GEMM epilogue, producing per-CTA-tile argmax without an intermediate [V, H] logits buffer. Key fix: use sm90_partition_for_epilogue<true> for M coordinates since visit() receives accumulator-layout fragments, not D-layout. Remove atomic counters by having each CTA write its own tile result directly. Also add chi-squared sampling distribution test for fused-cutlass, pad D for TMA alignment, and add Make targets for CUTLASS testing and benchmarking. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
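The two-stage scheme described above, as a NumPy reference: each tile writes its own (max value, global row index) pair to a per-tile slot with no atomics, and a second pass reduces over tiles. The sketch materializes the logits only to define the reference answer; the point of the real epilogue visitor is that the [V, H] buffer never exists:

```python
import numpy as np

def tiled_argmax(logits, tile_m=128):
    """Per-tile (value, index) argmax followed by a cross-tile reduction,
    mirroring the per-CTA-tile scheme. Tile size and variable names are
    illustrative."""
    V, H = logits.shape
    n_tiles = (V + tile_m - 1) // tile_m
    tile_vals = np.full((n_tiles, H), -np.inf, dtype=logits.dtype)
    tile_idx = np.zeros((n_tiles, H), dtype=np.int64)
    for t in range(n_tiles):                       # one "CTA" per tile
        tile = logits[t * tile_m:(t + 1) * tile_m]
        local = tile.argmax(axis=0)
        tile_vals[t] = tile[local, np.arange(H)]
        tile_idx[t] = local + t * tile_m           # translate to global M
    best_tile = tile_vals.argmax(axis=0)           # second-stage reduction
    return tile_idx[best_tile, np.arange(H)]

logits = np.random.default_rng(2).standard_normal((1000, 4)).astype(np.float32)
print(np.array_equal(tiled_argmax(logits), logits.argmax(axis=0)))  # True
```

The M-coordinate translation in the loop is the analogue of the sm90_partition_for_epilogue<true> fix: indices must be computed in the layout the fragments actually arrive in, then mapped to global coordinates.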
Add inv_temperature parameter to Sm90RowArgmax, scaling accumulator values in visit() before argmax comparison. Unified into the existing test_row_argmax function with a default of 1.0 for backward compat. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fuses temperature scaling, Philox RNG Gumbel noise, and per-tile argmax into the GEMM epilogue. Uses curand_Philox4x32_10 directly (not curand_init) for lightweight per-element RNG keyed on global (m, n) coordinates. Key changes:
- inv_temperature passed as a GPU scalar tensor (no CPU-GPU sync)
- TmaWarpSpecializedCooperative epilogue for 128x128x64 tiles
- Header content hash in the JIT extension name for cache invalidation
- New fused-cutlass-evt provider and a modal speed test --clear-cache flag
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
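What makes the per-element keying work is that Philox is a counter-based RNG: the noise at element (m, n) is a pure function of (seed, m, n), so any thread or CTA that visits that element derives the same value with no per-thread RNG state. A Python stand-in for that property (blake2b standing in for Philox4x32-10; function name is hypothetical):

```python
import hashlib
import math

def gumbel_noise_at(seed, m, n):
    """Counter-based Gumbel noise keyed on the global (m, n) coordinate.
    hashlib.blake2b is only a stand-in for the kernel's Philox4x32-10;
    the point is the statelessness, not the exact generator."""
    h = hashlib.blake2b(f"{seed},{m},{n}".encode(), digest_size=8).digest()
    # map 64 bits to a uniform in the open interval (0, 1)
    u = (int.from_bytes(h, "little") + 1) / (2**64 + 1)
    return -math.log(-math.log(u))  # Gumbel(0, 1)

# deterministic: same coordinate gives same noise, regardless of call order
print(gumbel_noise_at(0, 5, 2) == gumbel_noise_at(0, 5, 2))  # True
```

This is why the fused epilogue can skip curand_init's state setup entirely: there is no sequence to advance, only a counter to evaluate.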
Force-pushed f2b55d2 to 3e59b77