Claude/cutlass kernel implementation 6z7ep #4

Open
tomasruizt wants to merge 12 commits into main from claude/cutlass-kernel-implementation-6z7ep

@tomasruizt
Owner

No description provided.

tomasruizt and others added 12 commits March 9, 2026 19:28
Uses CUTLASS GEMM (tensor cores, SM80+) for the matmul stage,
followed by a custom Gumbel-max sampling kernel. The two-stage
reduction pattern matches the existing Triton and CUDA variants.

Registered as the "fused-cutlass" provider in get_sampler().

https://claude.ai/code/session_01N4JDFDxqBmCbnVW3RMpSyW
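The two-stage Gumbel-max reduction described above can be sketched in plain Python. This is only an illustration of the pattern, not the CUDA kernel from this PR: the function name, the tile_size parameter, and the rng argument are hypothetical, and a real kernel would run the tiles in parallel.

```python
import math
import random


def gumbel_max_two_stage(logits, temperature, tile_size, rng):
    """Sample one index from softmax(logits / temperature) in two stages."""
    partials = []
    # Stage 1: each tile adds Gumbel noise to its scaled logits and keeps
    # only its local (perturbed value, index) winner.
    for start in range(0, len(logits), tile_size):
        best = (-math.inf, -1)
        for i in range(start, min(start + tile_size, len(logits))):
            u = rng.random()
            while u == 0.0:  # avoid log(0); random() returns values in [0, 1)
                u = rng.random()
            gumbel = -math.log(-math.log(u))
            best = max(best, (logits[i] / temperature + gumbel, i))
        partials.append(best)
    # Stage 2: reduce the per-tile winners to one global winner.
    return max(partials)[1]


# Hypothetical usage: five logits split into tiles of two.
sample = gumbel_max_two_stage([0.1, 0.7, -0.3, 1.2, 0.0], 1.0, 2, random.Random(0))
```

Because max is associative, the sampled index does not depend on the tile size as long as each element receives the same noise.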
Rewrite the GEMM from CUTLASS 2.x (device::Gemm, SM80) to 3.x
(GemmUniversal with CollectiveBuilder, SM90). This enables Sm90EVT
epilogues in subsequent steps. Key changes:

- Use CollectiveBuilder for automatic mainloop/epilogue configuration
  with TMA and warp specialization on SM90.
- AlignmentC=AlignmentD=1 to support arbitrary H (including H=1).
- Construct CuTe strides manually (make_cute_packed_stride unavailable
  in pip-installed CUTLASS).
- Require compute_90a gencode for TMA/WGMMA instructions.
- Fix cudaMemcpy GPU-to-CPU sync in sampling kernel: pass temperature
  as device pointer, read via __ldg() on GPU.
- Add nvidia-cutlass, cuda-bench, cupti-python to Modal image.
- Increase Modal speed test timeout for JIT compilation.

Verified on Modal H100: fused-cutlass 0.482ms vs fused-triton 0.426ms
(V=151936, D=4096, H=4).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
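As a rough illustration of what constructing packed strides manually involves, the sketch below computes (row, column, batch) strides for a densely packed row-major tensor. It is a Python stand-in, not the CuTe code from this commit: the function names are hypothetical, and the choice of a batch stride of 0 for a single batch is an assumption about the usual convention.

```python
def packed_row_major_stride(rows, cols, batches=1):
    """Return (row_stride, col_stride, batch_stride) in elements.

    Assumes a densely packed row-major layout: consecutive columns are
    adjacent in memory and each row starts `cols` elements after the last.
    """
    # Assumption: with a single batch the batch stride is never used, so 0.
    batch_stride = rows * cols if batches > 1 else 0
    return (cols, 1, batch_stride)


def element_offset(coord, stride):
    """Linear offset of a (row, col, batch) coordinate under `stride`."""
    return sum(c * s for c, s in zip(coord, stride))


# Hypothetical usage: a 4 x 8 matrix repeated over 3 batches.
stride = packed_row_major_stride(4, 8, batches=3)
offset = element_offset((2, 5, 1), stride)
```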
Wires up Sm90EVT with Sm90Compute<plus>, Sm90ScalarBroadcast, and
Sm90AccFetch. Adds a test_evt_add1 function and Modal test script.
Key findings: EVTs require TmaWarpSpecialized schedule, 16B-aligned
stores (AlignmentD=4 for float32), and a smaller CTA tile (64x64x64)
to satisfy the EPI_TILE_M >= MMA_TILE_M constraint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
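The EVT in this commit is a small expression tree evaluated per accumulator element. The Python sketch below mimics that shape with closures; acc_fetch, scalar_broadcast, and compute are hypothetical stand-ins for Sm90AccFetch, Sm90ScalarBroadcast, and Sm90Compute<plus>, not CUTLASS APIs. The alignment helper only restates the arithmetic behind AlignmentD=4 for float32.

```python
import operator


def acc_fetch():
    """Leaf node: yields the accumulator value itself."""
    return lambda acc: acc


def scalar_broadcast(value):
    """Leaf node: yields the same scalar for every element."""
    return lambda acc: value


def compute(op, *children):
    """Inner node: applies `op` to the results of its children."""
    return lambda acc: op(*(child(acc) for child in children))


def apply_epilogue(tree, accumulators):
    """Evaluate the tree on every element of a 2D accumulator tile."""
    return [[tree(a) for a in row] for row in accumulators]


def alignment_in_elements(element_bytes, store_bytes=16):
    """Number of elements that fill one aligned store of `store_bytes`."""
    if store_bytes % element_bytes != 0:
        raise ValueError("element size must divide the store width")
    return store_bytes // element_bytes


# D = acc + 1, the same shape as the add-one test described above.
add1 = compute(operator.add, acc_fetch(), scalar_broadcast(1.0))
tile = apply_epilogue(add1, [[1.0, 2.0], [3.0, 4.0]])
```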
Step 4 of the CUTLASS FMMS kernel plan: validate Sm90RowReduction with
cutlass::maximum to reduce across M (vocab) producing one max per H column.
EVT tree uses Sm90Compute<Identity> as root (stores full logits to D) with
an inner Sm90RowReduction that writes the reduced max to a separate buffer.
Verified on Modal H100 with max_err=0.0014 at V=151936, D=4096, H=4.

Also adds enable_cuda_jit_cache() to persist JIT-compiled CUDA extensions
on the Modal volume, skipping expensive nvcc recompilation across runs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
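A reference for the reduction in this step could look like the sketch below: the identity root copies the logits to D unchanged while a side reduction keeps one maximum per H column. The function name and the list-of-lists layout are assumptions made for illustration; the real epilogue works on register fragments.

```python
import math


def epilogue_with_column_max(acc):
    """Return (D, col_max) for an accumulator laid out as acc[v][h].

    D is an unchanged copy of the logits (identity root); col_max[h] is
    the maximum over the vocab dimension for column h.
    """
    num_cols = len(acc[0])
    col_max = [-math.inf] * num_cols
    d = []
    for row in acc:
        d.append(list(row))  # identity: store logits as-is
        for h, value in enumerate(row):
            col_max[h] = max(col_max[h], value)
    return d, col_max


# Hypothetical usage: V=3 vocab rows, H=2 columns.
d, col_max = epilogue_with_column_max([[0.5, -1.0], [2.0, 0.0], [1.5, 3.0]])
```

A check like the max_err figure quoted above would then compare col_max against the values written by the kernel.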
Check for an existing .so before calling torch's load() to report
"likely cached" vs "likely compiled" with wall-clock time. Also
enable the JIT cache in modal_speed_test.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
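The cache check described above might look roughly like this. It is a minimal sketch under assumptions: load_with_cache_report and its parameters are hypothetical names, and the loader callable stands in for the real call to torch.utils.cpp_extension.load so the example runs without PyTorch.

```python
import glob
import os
import time


def load_with_cache_report(name, build_dir, loader):
    """Call `loader()` and report whether a compiled .so already existed.

    The check is only a heuristic: an existing .so suggests the build was
    cached, but the loader may still decide to recompile.
    """
    had_shared_object = bool(glob.glob(os.path.join(build_dir, "*.so")))
    start = time.perf_counter()
    module = loader()
    elapsed = time.perf_counter() - start
    status = "likely cached" if had_shared_object else "likely compiled"
    print(f"{name}: {status} in {elapsed:.2f}s")
    return module, status, elapsed


# In real use the loader would wrap the JIT build, for example:
#   loader = lambda: torch.utils.cpp_extension.load(
#       name=name, sources=sources, build_directory=build_dir)
```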
Implement Sm90RowArgmax custom EVT visitor that tracks (value, index) pairs
during the GEMM epilogue, producing per-CTA-tile argmax without an intermediate
[V, H] logits buffer. Key fix: use sm90_partition_for_epilogue<true> for M
coordinates since visit() receives accumulator-layout fragments, not D-layout.
Remove atomic counters by having each CTA write its own tile result directly.

Also add chi-squared sampling distribution test for fused-cutlass, pad D for
TMA alignment, and add Make targets for CUTLASS testing and benchmarking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
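A chi-squared distribution test of the kind mentioned above can be sketched as follows. This is not the test from the PR: the helper names are hypothetical, the sampler is a plain-Python Gumbel-max stand-in for the kernel, and the pass threshold is left to the caller.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def chi_squared_statistic(counts, probs):
    """Pearson statistic: sum of (observed - expected)^2 / expected."""
    n = sum(counts)
    return sum((c - n * p) ** 2 / (n * p) for c, p in zip(counts, probs))


def gumbel_max_sample(logits, rng):
    """Draw one index from softmax(logits) with the Gumbel-max trick."""
    best_value, best_index = -math.inf, -1
    for i, logit in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        value = logit - math.log(-math.log(u))
        if value > best_value:
            best_value, best_index = value, i
    return best_index


def sampling_statistic(logits, num_samples, rng):
    """Sample repeatedly and compare the counts with softmax(logits)."""
    counts = [0] * len(logits)
    for _ in range(num_samples):
        counts[gumbel_max_sample(logits, rng)] += 1
    return chi_squared_statistic(counts, softmax(logits))
```

The statistic is compared against a chi-squared critical value with len(logits) - 1 degrees of freedom; a sampler that draws from the wrong distribution produces a statistic that grows with the number of samples.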
Add an inv_temperature parameter to Sm90RowArgmax, scaling accumulator
values in visit() before the argmax comparison. The scaled path is unified
into the existing test_row_argmax function, with a default of 1.0 for
backward compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fuses temperature scaling, Philox RNG Gumbel noise, and per-tile argmax
into the GEMM epilogue. Uses curand_Philox4x32_10 directly (not curand_init)
for lightweight per-element RNG keyed on global (m, n) coordinates.
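The idea of keying the noise on global (m, n) coordinates can be shown with a counter-based stand-in. The sketch below uses a BLAKE2b hash instead of Philox purely so that it runs with the Python standard library; the function names and the seed parameter are hypothetical, and the bit patterns will not match curand.

```python
import hashlib
import math
import struct


def uniform_from_coords(seed, m, n):
    """Deterministic uniform in the open interval (0, 1) for element (m, n).

    Stand-in for a counter-based generator: the output depends only on
    (seed, m, n), never on which thread or tile visits the element.
    """
    key = struct.pack("<QQQ", seed, m, n)
    digest = hashlib.blake2b(key, digest_size=8).digest()
    bits = struct.unpack("<Q", digest)[0] >> 12  # keep 52 random bits
    return (bits + 0.5) / (1 << 52)  # the half-step offset excludes 0 and 1


def gumbel_noise(seed, m, n):
    """Gumbel(0, 1) noise for element (m, n)."""
    return -math.log(-math.log(uniform_from_coords(seed, m, n)))


# Hypothetical usage: noise for vocab row 12345 and column 2.
noise = gumbel_noise(seed=42, m=12345, n=2)
```

Because each element's noise is a pure function of its coordinates, the result does not change with the tile shape or the schedule.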

Key changes:
- inv_temperature passed as GPU scalar tensor (no CPU-GPU sync)
- TmaWarpSpecializedCooperative epilogue for 128x128x64 tiles
- Header content hash in JIT extension name for cache invalidation
- New fused-cutlass-evt provider and modal speed test --clear-cache flag

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
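Putting a header content hash into the JIT extension name could be done along these lines. The sketch is an assumption about the approach, not the code from this PR: extension_name, its parameters, and the choice of SHA-256 truncated to 8 hex characters are all hypothetical.

```python
import hashlib


def extension_name(base, header_paths, digest_chars=8):
    """Derive an extension name that changes when any header changes.

    Files are hashed in sorted order so the name does not depend on the
    order in which the paths were passed.
    """
    hasher = hashlib.sha256()
    for path in sorted(header_paths):
        with open(path, "rb") as handle:
            hasher.update(handle.read())
    # Hex digits and underscores keep the result a valid module name.
    return f"{base}_{hasher.hexdigest()[:digest_chars]}"
```

A changed header then yields a new name, so a stale cached build under the old name is not picked up.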
tomasruizt force-pushed the claude/cutlass-kernel-implementation-6z7ep branch from f2b55d2 to 3e59b77 on March 9, 2026 18:38
2 participants