[fast_math] Add bfloat16_t PTX specializations for fast_exp and fast_tanh #3242
Open
VittoriaLanzo wants to merge 2 commits into
Two new GTest cases in Fast_math_bfloat16:
- device_fast_exp_op_array8_sm80: runs fast_exp_op<Array<bfloat16_t,8>> on device via test_Epilogue_thread_activation; skips below SM80
- device_fast_tanh_op_array8_sm90: runs fast_tanh_op<Array<bfloat16_t,8>> on device; skips below SM90 (tanh.approx.bf16x2 path)
The problem
`include/cutlass/fast_math.h` has PTX-accelerated `fast_tanh` and `fast_exp` specializations for `half_t` (fp16), but none for `bfloat16_t`. Every BF16 activation that routes through `fast_exp` or `fast_tanh` on SM90+ hardware falls through to a float round-trip. By contrast, `tanh.approx.bf16x2` costs 0.5 instructions per element (two elements per instruction, SM90+/CUDA 12+). On Hopper and Blackwell, any path through `fast_tanh_op` or `fast_exp_op` therefore pays a theoretical 6× instruction overhead compared to the `half_t` path.

Changes
- `fast_exp(bfloat16_t x)`: scalar, `CUTLASS_HOST_DEVICE`, SM80+/CUDA 11+, uses `::hexp(__nv_bfloat16)`
- `fast_tanh(bfloat16_t x)`: scalar, `CUTLASS_HOST_DEVICE`, SM90+/CUDA 12+, uses `tanh.approx.bf16` PTX
- `fast_exp_op<Array<bfloat16_t, N>>`: `CUTLASS_DEVICE`, SM80+/CUDA 11+, `::h2exp(__nv_bfloat162)` for N/2 pairs, scalar residual for odd N
- `fast_tanh_op<Array<bfloat16_t, N>>`: `CUTLASS_DEVICE`, SM90+/CUDA 12+, `tanh.approx.bf16x2` PTX for N/2 pairs, `tanh.approx.bf16` for odd-N residual
- `#include "cutlass/bfloat16.h"` added to `fast_math.h`: required by the new specializations

Each specialization is inserted immediately after the corresponding `half_t` block. The `half_t` code is unchanged.

Instruction count
Theoretical (ISA-derived), per element:
Rows: `fast_tanh` scalar · `fast_tanh_op` array · `fast_exp` scalar · `fast_exp_op` array
† `ex2.approx.bf16` / `ex2.approx.bf16x2` is native on SM90+/CUDA 12.1+. On earlier toolchains `hexp`/`h2exp` expands to float round-trips and there is no speedup. The tanh figures are CUDA-version-independent: `tanh.approx.bf16x2` is the only SM90 mechanism and is verified in the PTX output below.

Benchmark script
Standalone script — not committed to the repository. Compile and run on SM90+ hardware to measure actual throughput. The PTX path uses the patched specializations; the fallback path simulates pre-patch behaviour with explicit float casts.
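For orientation, the float round-trip that the fallback path relies on can be emulated on the host. The sketch below is an illustration under stated assumptions, not the benchmark script: `bf16_bits`, `bf16_to_float`, `float_to_bf16`, and `tanh_via_float` are hypothetical names, bf16 values are modelled as raw 16-bit integers rather than `bfloat16_t`, and NaN handling is omitted.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: just the raw 16 bits.
using bf16_bits = std::uint16_t;

// Widen bf16 to float: bf16 is the upper 16 bits of an IEEE-754 binary32.
inline float bf16_to_float(bf16_bits b) {
  std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

// Narrow float to bf16 with round-to-nearest-even (NaN handling omitted).
inline bf16_bits float_to_bf16(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  std::uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<bf16_bits>((u + rounding_bias) >> 16);
}

// The float round-trip: widen, compute in fp32, narrow again.
// Three steps per element, versus one packed instruction per two elements.
inline bf16_bits tanh_via_float(bf16_bits x) {
  return float_to_bf16(std::tanh(bf16_to_float(x)));
}
```

Each element goes through a widening conversion, an fp32 operation, and a narrowing conversion, which is the per-element cost the packed `bf16x2` instruction avoids.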
Design choices
- `reinterpret_cast<__nv_bfloat16 const &>(x.storage)` in the scalar `fast_exp`, not `x.to_nv_bfloat16()`. `bfloat16_t::to_nv_bfloat16()` is `CUTLASS_DEVICE`-only. Scalar `fast_exp` is `CUTLASS_HOST_DEVICE`; the `reinterpret_cast` produces the identical bit pattern and compiles on both host and device.
- `bfloat16_t::bitcast(bits)` in the scalar `fast_tanh`. `half_t` uses `"=h"(x.raw())` as a PTX output operand directly; `bfloat16_t::raw()` returns `uint16_t` by value, so that form doesn't compile. Local `uint16_t bits` + `bitcast` is the idiomatic equivalent and `bitcast` is `CUTLASS_HOST_DEVICE`.
- `uint16_t *out_raw` in the odd-N residuals, not proxy assignment. `reference::operator=(bfloat16_t x)` in the sub-word Array specialization does `reinterpret_cast<uint32_t const &>(x)`, a 4-byte read from a 2-byte object. The existing `fast_tanh_op<Array<half_t,N>>` already avoids this by going through raw `uint16_t *` pointers; I mirror that pattern rather than going through the proxy.
- SM80+/CUDA 11+ for exp, SM90+/CUDA 12+ for tanh. `h2exp(__nv_bfloat162)` is available from Ampere (SM80, same as `__nv_bfloat162` support). `tanh.approx.bf16` uses hardware acceleration from Hopper (SM90, PTX ISA 8.0, confirmed in CUDA Math API docs). Separate `#if` guards per operation (not one combined guard) preserve the SM80 exp path on A100/A800 where the SM90 tanh guard would not fire.
- `::h2exp` not `h2expbf16`. The CUDA Math API overloads `h2exp` for both `__half2` and `__nv_bfloat162`. There is no separate `h2expbf16` in the API.

Correctness
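To make the loop structure this section reasons about concrete, here is a minimal host-side sketch of the pair-plus-residual pattern. It is not the PR's code: `packed_op`, `scalar_op`, `OpCounts`, and `apply` are hypothetical names, elements are plain `uint16_t` values in a `std::array`, and the stand-in operations simply add 1 to the raw bits in place of `::h2exp` or the `tanh.approx` PTX.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a packed two-element instruction:
// adds 1 to each 16-bit half of a 32-bit word.
inline std::uint32_t packed_op(std::uint32_t pair) {
  std::uint32_t lo = ((pair & 0xFFFFu) + 1u) & 0xFFFFu;
  std::uint32_t hi = ((pair >> 16) + 1u) & 0xFFFFu;
  return (hi << 16) | lo;
}

// Hypothetical stand-in for the scalar instruction used on the residual.
inline std::uint16_t scalar_op(std::uint16_t x) {
  return static_cast<std::uint16_t>(x + 1u);
}

// How many packed and scalar operations were issued.
struct OpCounts { int packed; int scalar; };

// Process N elements as N/2 packed pairs, then at most one scalar residual.
template <std::size_t N>
OpCounts apply(const std::array<std::uint16_t, N>& in,
               std::array<std::uint16_t, N>& out) {
  OpCounts counts{0, 0};
  for (std::size_t i = 0; i < N / 2; ++i) {
    std::uint32_t pair = static_cast<std::uint32_t>(in[2 * i]) |
                         (static_cast<std::uint32_t>(in[2 * i + 1]) << 16);
    std::uint32_t r = packed_op(pair);
    out[2 * i]     = static_cast<std::uint16_t>(r & 0xFFFFu);
    out[2 * i + 1] = static_cast<std::uint16_t>(r >> 16);
    ++counts.packed;
  }
  if (N % 2) {  // odd N: the last element takes the scalar path
    out[N - 1] = scalar_op(in[N - 1]);
    ++counts.scalar;
  }
  return counts;
}
```

With N = 5 this issues two packed operations and one scalar operation; with N = 0 neither the loop nor the residual branch runs.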
- `__nv_bfloat162` reinterpret on `Array<bfloat16_t,N>`: `Array<T,N,false>` (defined in `array_subbyte.h`) uses `uint32_t` storage when `sizeof_bits<T> * N` is divisible by 32, which for 16-bit types means any even N. `alignof(uint32_t) = 4` satisfies `__nv_bfloat162`'s 4-byte alignment requirement. The pair loop only runs for even N ≥ 2; the odd-N residual uses a scalar path with no `__nv_bfloat162` reinterpret.
- For N = 0, `N/2 == 0` and `N%2 == 0`; neither loop nor residual executes.
- `::h2exp` on `__nv_bfloat162` is overloaded from the same function name as `::h2exp(__half2)`; no separate bfloat16 variant exists or is needed.

Tests
New suite `Fast_math_bfloat16` in `test/unit/epilogue/thread/activation.cu`: 11 test cases, 9 host tests (15 assertions) + 2 device tests. The `bfloat16_t` overload did not exist pre-patch; host builds silently convert to float via `bfloat16_t::operator float()` rather than dispatching the new specialization.

- `fast_exp_zero`: `fast_exp(0.0f)` → 1.0f ± 1 ULP
- `fast_exp_one`: `fast_exp(1.0f)` → expf(1.0f) ± 2 ULP
- `fast_exp_neg_one`: `fast_exp(-1.0f)` → expf(-1.0f) ± 2 ULP
- `fast_tanh_zero`: `fast_tanh(0.0f)` → 0.0f ± 1 ULP
- `fast_tanh_one`: `fast_tanh(1.0f)` → tanhf(1.0f) ± 2 ULP
- `fast_exp_op_array4_host`: `fast_exp_op<BF16x4>` on {0,1,-1,2}, element-wise expf reference ± 2 ULP
- `fast_tanh_op_array4_host`: `fast_tanh_op<BF16x4>` on {0,1,-1,2}, element-wise tanhf reference ± 2 ULP
- `fast_exp_op_array1_odd_residual`
- `fast_tanh_op_array1_odd_residual`
- `device_fast_exp_op_array8_sm80`: `fast_exp_op<BF16x8>` on 128 elements, expf reference ± 2%; skips below SM80
- `device_fast_tanh_op_array8_sm90`: `fast_tanh_op<BF16x8>` on 128 elements, tanhf reference ± 2% + 1e-4 abs floor; skips below SM90

Host tests exercise the float fallback path (`__CUDA_ARCH__` not defined at host compile time). Device tests compile to SM90 and use `GTEST_SKIP()` at runtime when the detected SM is below the required threshold, confirmed on P100 (SM 6.0); see evidence below.

Evidence
Verified on Kaggle (P100, SM 6.0, CUDA 12 toolchain, `nvcc -arch=sm_90`), kernel `vittorialanzo/cutlass-bf16-fast-math-test` v9. The Kaggle kernel runs standalone C++ verification programs that exercise the same algorithms as the GTest suite; they are not the CUTLASS GTest suite itself (which will run in CUTLASS CI against SM90+ hardware). Host unit tests run on CPU (float fallback path); PTX emission via `--ptx` is compile-time only.

Host unit tests (float fallback path)
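As a rough illustration of what a "± N ULP" host check and a "± 2% + abs floor" device check involve, here is a minimal sketch. It is not the PR's test code or the Kaggle verification program: `float_to_bf16`, `ordered`, `bf16_ulp_distance`, `within_ulps`, and `within_relative` are hypothetical helpers, inputs are assumed finite, and reading "± 2% + 1e-4 abs floor" as a sum of the two terms is an assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Narrow float to raw bf16 bits with round-to-nearest-even (finite inputs only).
inline std::uint16_t float_to_bf16(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  std::uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<std::uint16_t>((u + rounding_bias) >> 16);
}

// Map sign-magnitude bf16 bits to an ordered integer so that adjacent
// representable values differ by exactly 1; +0 and -0 both map to 0.
inline std::int32_t ordered(std::uint16_t bits) {
  return (bits & 0x8000u) ? -static_cast<std::int32_t>(bits & 0x7FFFu)
                          : static_cast<std::int32_t>(bits);
}

// Number of representable bf16 steps between two bit patterns.
inline int bf16_ulp_distance(std::uint16_t a, std::uint16_t b) {
  return std::abs(ordered(a) - ordered(b));
}

// Host-style check: result and reference agree within max_ulps bf16 steps.
inline bool within_ulps(float got, float reference, int max_ulps) {
  return bf16_ulp_distance(float_to_bf16(got), float_to_bf16(reference)) <= max_ulps;
}

// Device-style check. ASSUMPTION: relative tolerance plus absolute floor, summed.
inline bool within_relative(float got, float reference,
                            float rel = 0.02f, float abs_floor = 1e-4f) {
  return std::fabs(got - reference) <= rel * std::fabs(reference) + abs_floor;
}
```

The absolute floor matters near zero, where a purely relative tolerance collapses to nothing; that is presumably why the tanh device test carries one and the exp test does not.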
SM90 PTX check (`nvcc -arch=sm_90 --ptx`)

No `cvt.f32.bf16` in the PTX output. `cvt.rn.bf16.f32` appears in the PTX file but only in the `exp_k` kernel (`h2exp` internals), not in the tanh path.

The PTX file contains both `tanh_k` and `exp_k` kernels. Tanh-kernel lines only (`tanh_k`, N=8):

Device test graceful skip (SM 6.0, P100)
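The skip decision itself reduces to comparing the device's compute capability against each test's threshold. The sketch below is a hedged illustration, not the PR's test code: `sm_number` and `should_skip` are hypothetical helpers, and the major/minor inputs are assumed to come from a runtime query such as the `major` and `minor` fields of `cudaDeviceProp`, with the real test calling `GTEST_SKIP()` when the predicate is true.

```cpp
#include <cassert>

// Hypothetical helper: fold a compute capability (major, minor) into the
// two-digit SM number used in the test names, e.g. (9, 0) -> 90.
constexpr int sm_number(int major, int minor) { return major * 10 + minor; }

// True when a test that requires `required_sm` should be skipped on this device.
constexpr bool should_skip(int major, int minor, int required_sm) {
  return sm_number(major, minor) < required_sm;
}

// P100 (SM 6.0) skips both device tests.
static_assert(should_skip(6, 0, 80), "SM 6.0 is below the SM80 exp threshold");
static_assert(should_skip(6, 0, 90), "SM 6.0 is below the SM90 tanh threshold");
// An SM80 device runs the exp test but skips the tanh test.
static_assert(!should_skip(8, 0, 80), "SM80 meets the exp threshold");
static_assert(should_skip(8, 0, 90), "SM80 is below the tanh threshold");
// An SM90 device runs both.
static_assert(!should_skip(9, 0, 90), "SM90 meets the tanh threshold");
```

Keeping the two thresholds separate mirrors the per-operation guards in the header: an SM80 device still exercises the exp path even though the tanh test is skipped.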
Checklist
- `GTEST_SKIP()` guards; `bfloat16_t` overload absent pre-patch so calls silently converted to float via implicit conversion rather than dispatching the new specialization)
- `half_t` specializations and existing tests unchanged
- `#include "cutlass/bfloat16.h"` added to `fast_math.h` (required by new specializations)
- (`tanh.approx.bf16x2` present, `cvt.f32.bf16` absent), confirmed on Kaggle (P100, SM 6.0), `nvcc -arch=sm_90`

Agentic workflow
This contribution was produced with a multi-agent pipeline under my direction and supervision. I approved each stage at human checkpoints; the pipeline handled planning, implementation, testing, and adversarial review.
Stage 0 — Performance assessment (identifies the missing bf16 fast-math specializations):
flowchart TD
    S([STARTUP\nroster check · pre-flight]):::green --> O[ORCHESTRATOR\nmain thread · routing only]:::blue
    O --> R[RECON AGENT\nstatic analysis · 5 passes · hotspot index]:::blue
    R -->|HOTSPOT_INDEX| D[DISPATCHER\nmechanical signal routing]:::yellow
    D -->|ROUTING_MANIFEST| C[COMPLEXITY AGENT\nnested loops · O n2]:::red
    D -->|ROUTING_MANIFEST| M[MEMORY AGENT\nalloc in loops · GC pressure]:::red
    D -->|ROUTING_MANIFEST| IO[IO AGENT\nN+1 · blocking async · locks]:::red
    D -->|ROUTING_MANIFEST| DS[DATASTRUCTURES AGENT\nwrong DS · linear search]:::red
    D -->|ROUTING_MANIFEST| CA[CACHE AGENT\nredundant compute · regex in loop]:::red
    C -->|FINDINGS| SC[SCORING COLLECTOR\nmerge · deduplicate · priority_score · top 15]:::yellow
    M -->|FINDINGS| SC
    IO -->|FINDINGS| SC
    DS -->|FINDINGS| SC
    CA -->|FINDINGS| SC
    SC -->|SCORED_FINDINGS| PQ[PRE-QUALIFICATION GATE\ngit log · gh pr · grep\neliminate owned or intentional findings]:::purple
    PQ -->|cleared_findings| TR[TOT ROOT\nBranch A: correctness bugs\nBranch B: high-value speed\nBranch C: risk / low-priority]:::purple
    TR -->|branches A · B · C| AS[ASSESSMENT SYNTHESIZER\nread code · confirm or discard each finding]:::purple
    AS -->|ASSESSMENT_REPORT| HR([HUMAN REVIEW GATE\npipeline halts · human decides]):::green
    classDef green fill:#1a4a1a,stroke:#4CAF50,color:#4CAF50
    classDef blue fill:#1a2a4a,stroke:#5b9bd5,color:#9cc4e8
    classDef yellow fill:#3a3000,stroke:#d4a017,color:#d4c87a
    classDef red fill:#3a0000,stroke:#cc3333,color:#ff8888
    classDef purple fill:#2a1a4a,stroke:#9966cc,color:#cc99ff

Stage 1-3 — Contribution pipeline (plans, implements, reviews, delivers):
flowchart TD
    subgraph L1["Layer 1 — Planning"]
        PE[pattern-extractor] --> CE[coverage-engineer\nassessment mode]
        CE --> CP[contribution-planner]
        CP --> AR[architect-reviewer]
        AR -->|OBJECTIONS_RAISED| CP
        AR -->|NO_OBJECTIONS| impl
    end
    subgraph L2["Layer 2 — Implementation + Adversarial Loop"]
        impl([enter]) --> CA1[code-author\nimplementation pass]
        CA1 --> CA2[code-author\nregression tests]
        CA2 --> RE[readability-editor]
        RE --> VT[verbosity-trimmer]
        VT --> SL[scope-linter]
        SL --> SR[senior-reviewer\nATTACK]
        SL --> RT[red-team\nATTACK]
        SR --> SE[senior-engineer]
        RT --> SE
        SE --> TA[test-author]
        TA --> HM[hostile-maintainer\nRE_REVIEW]
        TA --> RT2[red-team\nRE_REVIEW]
        HM --> SE2[senior-engineer\nrevision]
        RT2 --> SE2
        SE2 --> MG[merge-gate]
        MG -->|APPROVE| pr
        MG -->|REVISE| SE
    end
    subgraph L3["Layer 3 — PR Delivery"]
        pr([enter]) --> PW[pr-writer]
        PW --> PRA[pr-adversary]
        PRA -->|REVISIONS_REQUESTED| PW
        PRA -->|PR_READY| LA[license-auditor]
        LA --> done([human approval])
    end