[fast_math] Add bfloat16_t PTX specializations for fast_exp and fast_tanh #3242

Open
VittoriaLanzo wants to merge 2 commits into NVIDIA:main from VittoriaLanzo:fast-math/bf16-ptx-specializations
@VittoriaLanzo VittoriaLanzo commented May 16, 2026

The problem

include/cutlass/fast_math.h has PTX-accelerated fast_tanh and fast_exp specializations for half_t (fp16), but none for bfloat16_t. Every BF16 activation that routes through fast_exp or fast_tanh on SM90+ hardware falls through to a float round-trip:

cvt.f32.bf16 + tanh.approx.f32 + cvt.rn.bf16.f32   # 3 instructions per element

tanh.approx.bf16x2 costs 0.5 instructions per element (two elements per instruction, SM90+/CUDA 12+). On Hopper and Blackwell, any path through fast_tanh_op or fast_exp_op therefore pays a theoretical 6× per-element instruction overhead (3 vs. 0.5) compared to the half_t path.
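To make the three-step round-trip concrete, here is a host-only C++ sketch of the same widen, compute, narrow sequence on raw bf16 bit patterns. The helper names (bf16_bits_to_float, float_to_bf16_bits, tanh_round_trip) are hypothetical and not part of CUTLASS; the narrowing step assumes round-to-nearest-even, which is what cvt.rn.bf16.f32 denotes.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Widen: a bf16 value is the upper 16 bits of an IEEE-754 binary32 value.
inline float bf16_bits_to_float(std::uint16_t bits) {
  std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

// Narrow with round-to-nearest-even (assumed to match cvt.rn.bf16.f32).
inline std::uint16_t float_to_bf16_bits(float f) {
  std::uint32_t wide;
  std::memcpy(&wide, &f, sizeof(wide));
  if ((wide & 0x7fffffffu) > 0x7f800000u) {
    // NaN: set a mantissa bit so truncation cannot turn it into infinity.
    return static_cast<std::uint16_t>((wide >> 16) | 0x0040u);
  }
  std::uint32_t rounding_bias = 0x7fffu + ((wide >> 16) & 1u);
  return static_cast<std::uint16_t>((wide + rounding_bias) >> 16);
}

// One logical step per instruction in the fallback sequence.
inline std::uint16_t tanh_round_trip(std::uint16_t x_bits) {
  float widened = bf16_bits_to_float(x_bits);   // cvt.f32.bf16
  float result = std::tanh(widened);            // stands in for tanh.approx.f32
  return float_to_bf16_bits(result);            // cvt.rn.bf16.f32
}
```

For example, tanh_round_trip(0x3f80) takes the bf16 encoding of 1.0 through all three steps. std::tanh is used here only because the approximate instruction has no host equivalent.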

Changes

  1. fast_exp(bfloat16_t x): scalar, CUTLASS_HOST_DEVICE, SM80+/CUDA 11+, uses ::hexp(__nv_bfloat16)
  2. fast_tanh(bfloat16_t x): scalar, CUTLASS_HOST_DEVICE, SM90+/CUDA 12+, uses tanh.approx.bf16 PTX
  3. fast_exp_op<Array<bfloat16_t, N>>: CUTLASS_DEVICE, SM80+/CUDA 11+, ::h2exp(__nv_bfloat162) for N/2 pairs, scalar residual for odd N
  4. fast_tanh_op<Array<bfloat16_t, N>>: CUTLASS_DEVICE, SM90+/CUDA 12+, tanh.approx.bf16x2 PTX for N/2 pairs, tanh.approx.bf16 for odd-N residual
  5. #include "cutlass/bfloat16.h" added to fast_math.h, required by the new specializations

Each specialization is inserted immediately after the corresponding half_t block. The half_t code is unchanged.
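The "N/2 pairs plus odd-N residual" structure of the two array specializations can be sketched on the host as follows. This is an illustration of the loop shape only, not the code in the patch: op_pair and op_scalar are hypothetical stand-ins for the packed and scalar instructions (they just flip sign bits), and apply_pairwise is an invented name.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for the packed and scalar instructions.
// They negate each bf16 lane by flipping its sign bit, which is easy to verify.
inline std::uint32_t op_pair(std::uint32_t packed) { return packed ^ 0x80008000u; }
inline std::uint16_t op_scalar(std::uint16_t x) {
  return static_cast<std::uint16_t>(x ^ 0x8000u);
}

template <std::size_t N>
std::array<std::uint16_t, N> apply_pairwise(std::array<std::uint16_t, N> const &in) {
  std::array<std::uint16_t, N> out{};
  // N/2 packed operations, each covering two adjacent elements.
  for (std::size_t i = 0; i < N / 2; ++i) {
    std::uint32_t packed = static_cast<std::uint32_t>(in[2 * i]) |
                           (static_cast<std::uint32_t>(in[2 * i + 1]) << 16);
    std::uint32_t result = op_pair(packed);
    out[2 * i] = static_cast<std::uint16_t>(result & 0xffffu);
    out[2 * i + 1] = static_cast<std::uint16_t>(result >> 16);
  }
  // Odd N leaves exactly one trailing element for the scalar path.
  if (N % 2 != 0) {
    out[N - 1] = op_scalar(in[N - 1]);
  }
  return out;
}
```

With N = 3 the loop runs once and the residual branch handles the third element; with even N the residual branch never executes.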

Instruction count

Theoretical (ISA-derived), per element:

| Operation | Before (float round-trip) | After (PTX direct) | Reduction |
|---|---|---|---|
| fast_tanh scalar | 3 | 1 | 3× |
| fast_tanh_op array | 3 | 0.5 | 6× |
| fast_exp scalar | 3 | ~1 † | ~3× † |
| fast_exp_op array | 3 | ~0.5 | ~6× |

† ex2.approx.bf16 / ex2.approx.bf16x2 are native on SM90+/CUDA 12.1+. On earlier toolchains hexp / h2exp expand to float round-trips and there is no speedup. The tanh figures are CUDA-version-independent: tanh.approx.bf16x2 is the only SM90 mechanism and is verified in the PTX output below.
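The per-element figures follow from simple arithmetic: instructions issued divided by elements covered per instruction. A small sketch of that calculation, with hypothetical helper names, assuming the instruction counts stated above:

```cpp
#include <cassert>

// Instructions per element when one instruction covers several elements.
constexpr double per_element(int instructions, int elements_per_instruction) {
  return static_cast<double>(instructions) / elements_per_instruction;
}

// Ratio of the old per-element cost to the new one.
constexpr double reduction(double before, double after) {
  return before / after;
}

// Float round-trip: 3 instructions, 1 element each.
constexpr double kBefore = per_element(3, 1);
// Packed bf16x2: 1 instruction, 2 elements.
constexpr double kAfterPacked = per_element(1, 2);
// Scalar bf16: 1 instruction, 1 element.
constexpr double kAfterScalar = per_element(1, 1);
```

These are instruction counts only. They say nothing about issue rate or latency of the individual instructions, which is why measured throughput can differ from the theoretical ratio.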

Benchmark script

Standalone script — not committed to the repository. Compile and run on SM90+ hardware to measure actual throughput. The PTX path uses the patched specializations; the fallback path simulates pre-patch behaviour with explicit float casts.

// tools/bench_bf16_activations.cu
//
// Compile (SM90+, tanh + exp):
//   nvcc -arch=sm_90 -O3 -I include -o bench_bf16_activations \
//        tools/bench_bf16_activations.cu
//
// Compile (SM80+, exp only — tanh guard won't fire):
//   nvcc -arch=sm_80 -O3 -I include -o bench_bf16_activations \
//        tools/bench_bf16_activations.cu
//
// Run:
//   ./bench_bf16_activations [N_elements] [warmup_reps] [bench_reps]
//   ./bench_bf16_activations 1048576 10 100

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include "cutlass/bfloat16.h"
#include "cutlass/array.h"
#include "cutlass/fast_math.h"

static constexpr int VEC = 8;
using Vec = cutlass::Array<cutlass::bfloat16_t, VEC>;

// PTX paths (patched specializations fire on SM90+/SM80+)
__global__ void tanh_ptx(Vec const *in, Vec *out, int n) {
  cutlass::fast_tanh_op<Vec> op;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = op(in[i]);
}

__global__ void exp_ptx(Vec const *in, Vec *out, int n) {
  cutlass::fast_exp_op<Vec> op;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = op(in[i]);
}

// Float round-trip paths (pre-patch behaviour — explicit float casts)
__global__ void tanh_f32(Vec const *in, Vec *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    Vec r;
    CUTLASS_PRAGMA_UNROLL
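    for (int j = 0; j < VEC; ++j)
      // fast_tanh(float) is what the pre-patch path reaches after widening (tanh.approx.f32);
      // ::tanhf would time the accurate libm-style routine instead
      r[j] = cutlass::bfloat16_t(cutlass::fast_tanh(float(in[i][j])));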
    out[i] = r;
  }
}

__global__ void exp_f32(Vec const *in, Vec *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    Vec r;
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < VEC; ++j)
      r[j] = cutlass::bfloat16_t(::expf(float(in[i][j])));
    out[i] = r;
  }
}

static float time_kernel(void (*k)(Vec const *, Vec *, int),
                         Vec const *in, Vec *out, int n, int reps) {
  dim3 blk(256), grd((n + 255) / 256);
  cudaEvent_t a, b;
  cudaEventCreate(&a);
  cudaEventCreate(&b);
  cudaEventRecord(a);
  for (int i = 0; i < reps; ++i) k<<<grd, blk>>>(in, out, n);
  cudaEventRecord(b);
  cudaEventSynchronize(b);
  float ms;
  cudaEventElapsedTime(&ms, a, b);
  cudaEventDestroy(a);
  cudaEventDestroy(b);
  return ms / reps;
}

int main(int argc, char **argv) {
  int N      = argc > 1 ? atoi(argv[1]) : 1 << 20;
  int warmup = argc > 2 ? atoi(argv[2]) : 10;
  int reps   = argc > 3 ? atoi(argv[3]) : 100;
  int n_vecs = N / VEC;

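  Vec *d_in, *d_out;
  cudaMalloc(&d_in,  n_vecs * sizeof(Vec));
  cudaMalloc(&d_out, n_vecs * sizeof(Vec));
  // Zero the input so no kernel reads uninitialized memory (all-zero bits are bf16 0.0)
  cudaMemset(d_in, 0, n_vecs * sizeof(Vec));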

  dim3 blk(256), grd((n_vecs + 255) / 256);
  for (int i = 0; i < warmup; ++i) {
    tanh_ptx<<<grd, blk>>>(d_in, d_out, n_vecs);
    tanh_f32<<<grd, blk>>>(d_in, d_out, n_vecs);
    exp_ptx <<<grd, blk>>>(d_in, d_out, n_vecs);
    exp_f32 <<<grd, blk>>>(d_in, d_out, n_vecs);
  }
  cudaDeviceSynchronize();

  float tp_ms = time_kernel(tanh_ptx, d_in, d_out, n_vecs, reps);
  float tf_ms = time_kernel(tanh_f32, d_in, d_out, n_vecs, reps);
  float ep_ms = time_kernel(exp_ptx,  d_in, d_out, n_vecs, reps);
  float ef_ms = time_kernel(exp_f32,  d_in, d_out, n_vecs, reps);

  int dev; cudaGetDevice(&dev);
  cudaDeviceProp prop; cudaGetDeviceProperties(&prop, dev);
  printf("GPU:  %s\n", prop.name);
  printf("N:    %d elements  (Vec=%d, kernels operate on %d vectors)\n\n",
         N, VEC, n_vecs);
  printf("%-36s  %8s  %8s  %8s\n", "operation", "PTX (ms)", "f32 (ms)", "speedup");
  printf("%-36s  %8.3f  %8.3f  %7.2fx\n",
         "fast_tanh_op<Array<bfloat16_t,8>>", tp_ms, tf_ms, tf_ms / tp_ms);
  printf("%-36s  %8.3f  %8.3f  %7.2fx\n",
         "fast_exp_op<Array<bfloat16_t,8>>",  ep_ms, ef_ms, ef_ms / ep_ms);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
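The script reports milliseconds per launch. Since an element-wise kernel like this is often limited by memory traffic rather than arithmetic, it can help to convert the timings into throughput before comparing paths. A minimal host-side sketch, assuming each bf16 element is read once and written once; the helper names are hypothetical and not part of the script:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Elements processed per second for one kernel launch.
inline double elements_per_second(std::size_t n_elements, double ms_per_launch) {
  return static_cast<double>(n_elements) / (ms_per_launch * 1e-3);
}

// Effective bandwidth in GB/s.
// Assumption: 2 bytes read plus 2 bytes written per bf16 element.
inline double effective_gb_per_s(std::size_t n_elements, double ms_per_launch) {
  double bytes_moved = 2.0 * 2.0 * static_cast<double>(n_elements);
  return bytes_moved / (ms_per_launch * 1e-3) / 1e9;
}
```

If both the PTX and the float paths land close to the device's memory bandwidth, the measured speedup will be much smaller than the instruction-count ratio, independent of how good the specializations are.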

Design choices

reinterpret_cast<__nv_bfloat16 const &>(x.storage) in the scalar fast_exp, not x.to_nv_bfloat16(). bfloat16_t::to_nv_bfloat16() is CUTLASS_DEVICE-only. Scalar fast_exp is CUTLASS_HOST_DEVICE — the reinterpret_cast produces the identical bit pattern and compiles on both host and device.

bfloat16_t::bitcast(bits) in the scalar fast_tanh. half_t uses "=h"(x.raw()) as a PTX output operand directly; bfloat16_t::raw() returns uint16_t by value, so that form doesn't compile. Local uint16_t bits + bitcast is the idiomatic equivalent and bitcast is CUTLASS_HOST_DEVICE.

uint16_t *out_raw in the odd-N residuals, not proxy assignment. reference::operator=(bfloat16_t x) in the sub-word Array specialization does reinterpret_cast<uint32_t const &>(x) — a 4-byte read from a 2-byte object. The existing fast_tanh_op<Array<half_t,N>> already avoids this by going through raw uint16_t * pointers; I mirror that pattern rather than going through the proxy.
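The hazard described here is a 4-byte read from a 2-byte object. The sketch below shows one way to write a single 16-bit lane into word-sized storage without ever touching more than two bytes of the source. It uses std::memcpy rather than the raw uint16_t * pointers the patch mirrors from the half_t code, and store_lane / load_lane are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Write one 16-bit lane into storage that is declared as 32-bit words.
// Only sizeof(bits) == 2 bytes of the source are ever read.
inline void store_lane(std::uint32_t *storage, std::size_t lane, std::uint16_t bits) {
  unsigned char *bytes = reinterpret_cast<unsigned char *>(storage);
  std::memcpy(bytes + lane * sizeof(std::uint16_t), &bits, sizeof(bits));
}

// Read one 16-bit lane back out of the same storage.
inline std::uint16_t load_lane(std::uint32_t const *storage, std::size_t lane) {
  unsigned char const *bytes = reinterpret_cast<unsigned char const *>(storage);
  std::uint16_t bits;
  std::memcpy(&bits, bytes + lane * sizeof(std::uint16_t), sizeof(bits));
  return bits;
}
```

Lane indices count 16-bit elements, so lane 2 is the low-address half of the second 32-bit word regardless of endianness.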

SM80+/CUDA 11+ for exp, SM90+/CUDA 12+ for tanh. h2exp(__nv_bfloat162) is available from Ampere (SM80, same as __nv_bfloat162 support). tanh.approx.bf16 uses hardware acceleration from Hopper (SM90, PTX ISA 8.0 — confirmed in CUDA Math API docs). Separate #if guards per operation (not one combined guard) preserve the SM80 exp path on A100/A800 where the SM90 tanh guard would not fire.
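The effect of keeping one guard per operation can be sketched as below. __CUDA_ARCH__ and __CUDACC_VER_MAJOR__ are the real NVCC macros; the Path enum and the two functions are hypothetical and only report which branch the preprocessor selected. The guard expressions in the patch itself may be written differently.

```cpp
#include <cassert>

enum class Path { FloatFallback, NativeBf16 };

// exp guard: SM80+ and CUDA 11+.
constexpr Path exp_path() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && \
    defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 11)
  return Path::NativeBf16;
#else
  return Path::FloatFallback;
#endif
}

// tanh guard: SM90+ and CUDA 12+, evaluated independently of the exp guard.
constexpr Path tanh_path() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
    defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)
  return Path::NativeBf16;
#else
  return Path::FloatFallback;
#endif
}
```

Compiled for sm_80, exp_path() selects the native branch while tanh_path() still falls back. A single combined SM90 guard would have pushed both onto the float path. On the host, where __CUDA_ARCH__ is undefined, both return FloatFallback.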

::h2exp not h2expbf16. The CUDA Math API overloads h2exp for both __half2 and __nv_bfloat162. There is no separate h2expbf16 in the API.

Correctness

  • All four specializations fall through to the float path on host and below the arch threshold — identical to pre-patch behaviour.
  • __nv_bfloat162 reinterpret on Array<bfloat16_t,N>: Array<T,N,false> (defined in array_subbyte.h) uses uint32_t storage when sizeof_bits<T> * N is divisible by 32 — for 16-bit types, any even N. alignof(uint32_t) = 4 satisfies __nv_bfloat162's 4-byte alignment requirement. The pair loop only runs for even N ≥ 2; the odd-N residual uses a scalar path with no __nv_bfloat162 reinterpret.
  • N=0: N/2 == 0 and N%2 == 0; neither loop nor residual executes.
  • ::h2exp on __nv_bfloat162 is overloaded from the same function name as ::h2exp(__half2) — no separate bfloat16 variant exists or is needed.
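The alignment argument in the second bullet can be checked at compile time. The sketch below models the storage-selection rule as described above (uint32_t when the total bit width is a multiple of 32). The narrower fallbacks and the StorageFor name are assumptions for illustration, not a copy of array_subbyte.h.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Storage word chosen from the total bit width of the array.
// Assumption: widths that are not a multiple of 32 fall back to 16-bit or 8-bit words.
template <int kSizeBits>
using StorageFor = typename std::conditional<
    (kSizeBits % 32) != 0,
    typename std::conditional<(kSizeBits % 16) != 0, std::uint8_t, std::uint16_t>::type,
    std::uint32_t>::type;

// Total bits for an array of N 16-bit elements.
constexpr int bf16_array_bits(int n) { return 16 * n; }

// A packed two-element view is only used when the storage word is 32 bits wide.
template <int N>
constexpr bool pair_view_allowed() {
  return std::is_same<StorageFor<16 * N>, std::uint32_t>::value;
}
```

For 16-bit elements this makes every even N eligible for the packed view and every odd N ineligible, which matches the split between the pair loop and the scalar residual.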

Tests

New suite Fast_math_bfloat16 in test/unit/epilogue/thread/activation.cu with 11 test cases: 9 host tests (15 assertions) and 2 device tests. The bfloat16_t overloads did not exist pre-patch, so calls silently converted to float via bfloat16_t::operator float() instead of dispatching a bfloat16_t specialization.

| Test | Kind | Covers |
|---|---|---|
| fast_exp_zero | host | fast_exp(0.0f) → 1.0f ± 1 ULP |
| fast_exp_one | host | fast_exp(1.0f) → expf(1.0f) ± 2 ULP |
| fast_exp_neg_one | host | fast_exp(-1.0f) → expf(-1.0f) ± 2 ULP |
| fast_tanh_zero | host | fast_tanh(0.0f) → 0.0f ± 1 ULP |
| fast_tanh_one | host | fast_tanh(1.0f) → tanhf(1.0f) ± 2 ULP |
| fast_exp_op_array4_host | host | fast_exp_op<BF16x4> on {0,1,-1,2}, element-wise expf reference ± 2 ULP |
| fast_tanh_op_array4_host | host | fast_tanh_op<BF16x4> on {0,1,-1,2}, element-wise tanhf reference ± 2 ULP |
| fast_exp_op_array1_odd_residual | host | odd-N residual: N=1, exp path (host float fallback) |
| fast_tanh_op_array1_odd_residual | host | odd-N residual: N=1, tanh path (host float fallback) |
| device_fast_exp_op_array8_sm80 | device | fast_exp_op<BF16x8> on 128 elements, expf reference ± 2%; skips below SM80 |
| device_fast_tanh_op_array8_sm90 | device | fast_tanh_op<BF16x8> on 128 elements, tanhf reference ± 2% + 1e-4 abs floor; skips below SM90 |
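The host tolerances above are expressed in ULPs. How the suite computes them is not shown here, so the following is only one common way to measure ULP distance between two bf16 bit patterns, with hypothetical function names. It does not handle NaN inputs.

```cpp
#include <cassert>
#include <cstdint>

// Map a sign-magnitude bf16 bit pattern onto a monotonically ordered integer line.
// +0 and -0 both map to 0.
inline std::int32_t bf16_ordered(std::uint16_t bits) {
  std::int32_t magnitude = static_cast<std::int32_t>(bits & 0x7fffu);
  return (bits & 0x8000u) ? -magnitude : magnitude;
}

// Number of representable bf16 steps between a and b.
inline std::int32_t bf16_ulp_distance(std::uint16_t a, std::uint16_t b) {
  std::int32_t d = bf16_ordered(a) - bf16_ordered(b);
  return d < 0 ? -d : d;
}

// True when a and b differ by at most max_ulps representable steps.
inline bool within_ulps(std::uint16_t a, std::uint16_t b, std::int32_t max_ulps) {
  return bf16_ulp_distance(a, b) <= max_ulps;
}
```

With only 8 significand bits, one bf16 ULP near 1.0 is 2^-7, so a ± 2 ULP window is already a fairly loose bound compared with float.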

Host tests exercise the float fallback path (__CUDA_ARCH__ not defined at host compile time). Device tests compile to SM90 and use GTEST_SKIP() at runtime when the detected SM is below the required threshold — confirmed on P100 (SM 6.0), see evidence below.
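The runtime skip decision reduces to a comparison on the device's compute capability. A minimal sketch of that predicate follows; in a real test the major and minor values would come from cudaGetDeviceProperties, and the function names here are hypothetical.

```cpp
#include <cassert>

// Fold compute capability major.minor into the usual two-digit SM number.
constexpr int sm_version(int major, int minor) {
  return major * 10 + minor;
}

// True when the detected device is below the threshold the test requires.
constexpr bool should_skip(int major, int minor, int required_sm) {
  return sm_version(major, minor) < required_sm;
}
```

A test body would call should_skip(prop.major, prop.minor, 90) and invoke GTEST_SKIP() when it returns true, so a binary compiled for sm_90 never launches a kernel on an older device.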

Evidence

Verified on Kaggle (P100, SM 6.0, CUDA 12 toolchain, nvcc -arch=sm_90), kernel vittorialanzo/cutlass-bf16-fast-math-test v9. The Kaggle kernel runs standalone C++ verification programs that exercise the same algorithms as the GTest suite; they are not the CUTLASS GTest suite itself (which will run in CUTLASS CI against SM90+ hardware). Host unit tests run on CPU (float fallback path); PTX emission via --ptx is compile-time only.

Host unit tests (float fallback path)

=== Fast_math_bfloat16 host tests ===

  PASS  fast_exp_zero
  PASS  fast_exp_one
  PASS  fast_exp_neg_one
  PASS  fast_tanh_zero
  PASS  fast_tanh_one
  PASS  fast_exp_op_array4[0]
  PASS  fast_exp_op_array4[1]
  PASS  fast_exp_op_array4[2]
  PASS  fast_exp_op_array4[3]
  PASS  fast_tanh_op_array4[0]
  PASS  fast_tanh_op_array4[1]
  PASS  fast_tanh_op_array4[2]
  PASS  fast_tanh_op_array4[3]
  PASS  fast_exp_op_array1_odd_residual
  PASS  fast_tanh_op_array1_odd_residual

Results: 15 passed, 0 failed

SM90 PTX check (nvcc -arch=sm_90 --ptx)

  FOUND   'tanh.approx.bf16x2'  (array tanh x2)
  FOUND   'tanh.approx.bf16'  (scalar tanh)

No cvt.f32.bf16 in the PTX output. cvt.rn.bf16.f32 appears in the PTX file but only in the exp_k kernel (h2exp internals) — not in the tanh path.

The PTX file contains both tanh_k and exp_k kernels. Tanh-kernel lines only (tanh_k, N=8):

tanh.approx.bf16x2 %r6, %r7;
tanh.approx.bf16x2 %r8, %r9;
tanh.approx.bf16x2 %r10, %r11;
tanh.approx.bf16x2 %r12, %r13;
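A check like the one above can be automated by scanning the emitted PTX text. One caveat is that a plain substring search for tanh.approx.bf16 also matches tanh.approx.bf16x2, so the sketch below requires whitespace or end of input after the mnemonic. count_mnemonic is a hypothetical helper, not the script used for the results above.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Count occurrences of a PTX mnemonic, requiring a delimiter after the match
// so that "tanh.approx.bf16" is not counted inside "tanh.approx.bf16x2".
inline std::size_t count_mnemonic(std::string const &ptx, std::string const &mnemonic) {
  std::size_t count = 0;
  std::size_t pos = 0;
  while ((pos = ptx.find(mnemonic, pos)) != std::string::npos) {
    std::size_t end = pos + mnemonic.size();
    bool at_boundary = end == ptx.size() || ptx[end] == ' ' ||
                       ptx[end] == '\t' || ptx[end] == '\n';
    if (at_boundary) {
      ++count;
    }
    pos = end;
  }
  return count;
}
```

Run over the four tanh_k lines listed above, this would report four packed instructions and zero scalar ones, and zero cvt.f32.bf16 conversions.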

Device test graceful skip (SM 6.0 — P100)

=== Device: Tesla P100-PCIE-16GB (SM 6.0) ===

SKIP: SM90+ required for tanh.approx.bf16x2 path (this device: SM 6.0)
  fast_exp_op path requires SM80+ (CUDA 11+)
  fast_tanh_op path requires SM90+ (CUDA 12+)

Checklist

  • Tests added (11 test cases — 9 host tests with 15 assertions + 2 device tests with GTEST_SKIP() guards; bfloat16_t overload absent pre-patch so calls silently converted to float via implicit conversion rather than dispatching the new specialization)
  • No new dependencies
  • half_t specializations and existing tests unchanged
  • #include "cutlass/bfloat16.h" added to fast_math.h (required by new specializations)
  • PTX output verification (tanh.approx.bf16x2 present, cvt.f32.bf16 absent) — confirmed on Kaggle (P100, SM 6.0), nvcc -arch=sm_90

Agentic workflow

This contribution was produced with a multi-agent pipeline under my direction and supervision. I approved each stage at human checkpoints; the pipeline handled planning, implementation, testing, and adversarial review.

Stage 0 — Performance assessment (identifies the missing bf16 fast-math specializations):

flowchart TD
    S([STARTUP\nroster check · pre-flight]):::green --> O[ORCHESTRATOR\nmain thread · routing only]:::blue
    O --> R[RECON AGENT\nstatic analysis · 5 passes · hotspot index]:::blue
    R -->|HOTSPOT_INDEX| D[DISPATCHER\nmechanical signal routing]:::yellow
    D -->|ROUTING_MANIFEST| C[COMPLEXITY AGENT\nnested loops · O n2]:::red
    D -->|ROUTING_MANIFEST| M[MEMORY AGENT\nalloc in loops · GC pressure]:::red
    D -->|ROUTING_MANIFEST| IO[IO AGENT\nN+1 · blocking async · locks]:::red
    D -->|ROUTING_MANIFEST| DS[DATASTRUCTURES AGENT\nwrong DS · linear search]:::red
    D -->|ROUTING_MANIFEST| CA[CACHE AGENT\nredundant compute · regex in loop]:::red
    C -->|FINDINGS| SC[SCORING COLLECTOR\nmerge · deduplicate · priority_score · top 15]:::yellow
    M -->|FINDINGS| SC
    IO -->|FINDINGS| SC
    DS -->|FINDINGS| SC
    CA -->|FINDINGS| SC
    SC -->|SCORED_FINDINGS| PQ[PRE-QUALIFICATION GATE\ngit log · gh pr · grep\neliminate owned or intentional findings]:::purple
    PQ -->|cleared_findings| TR[TOT ROOT\nBranch A: correctness bugs\nBranch B: high-value speed\nBranch C: risk / low-priority]:::purple
    TR -->|branches A · B · C| AS[ASSESSMENT SYNTHESIZER\nread code · confirm or discard each finding]:::purple
    AS -->|ASSESSMENT_REPORT| HR([HUMAN REVIEW GATE\npipeline halts · human decides]):::green

    classDef green fill:#1a4a1a,stroke:#4CAF50,color:#4CAF50
    classDef blue fill:#1a2a4a,stroke:#5b9bd5,color:#9cc4e8
    classDef yellow fill:#3a3000,stroke:#d4a017,color:#d4c87a
    classDef red fill:#3a0000,stroke:#cc3333,color:#ff8888
    classDef purple fill:#2a1a4a,stroke:#9966cc,color:#cc99ff

Stage 1-3 — Contribution pipeline (plans, implements, reviews, delivers):

flowchart TD
    subgraph L1["Layer 1 — Planning"]
        PE[pattern-extractor] --> CE[coverage-engineer\nassessment mode]
        CE --> CP[contribution-planner]
        CP --> AR[architect-reviewer]
        AR -->|OBJECTIONS_RAISED| CP
        AR -->|NO_OBJECTIONS| impl
    end

    subgraph L2["Layer 2 — Implementation + Adversarial Loop"]
        impl([enter]) --> CA1[code-author\nimplementation pass]
        CA1 --> CA2[code-author\nregression tests]
        CA2 --> RE[readability-editor]
        RE --> VT[verbosity-trimmer]
        VT --> SL[scope-linter]
        SL --> SR[senior-reviewer\nATTACK]
        SL --> RT[red-team\nATTACK]
        SR --> SE[senior-engineer]
        RT --> SE
        SE --> TA[test-author]
        TA --> HM[hostile-maintainer\nRE_REVIEW]
        TA --> RT2[red-team\nRE_REVIEW]
        HM --> SE2[senior-engineer\nrevision]
        RT2 --> SE2
        SE2 --> MG[merge-gate]
        MG -->|APPROVE| pr
        MG -->|REVISE| SE
    end

    subgraph L3["Layer 3 — PR Delivery"]
        pr([enter]) --> PW[pr-writer]
        PW --> PRA[pr-adversary]
        PRA -->|REVISIONS_REQUESTED| PW
        PRA -->|PR_READY| LA[license-auditor]
        LA --> done([human approval])
    end

VittoriaLanzo and others added 2 commits May 15, 2026 23:04
Two new GTest cases in Fast_math_bfloat16:
- device_fast_exp_op_array8_sm80: runs fast_exp_op<Array<bfloat16_t,8>>
  on device via test_Epilogue_thread_activation; skips below SM80
- device_fast_tanh_op_array8_sm90: runs fast_tanh_op<Array<bfloat16_t,8>>
  on device; skips below SM90 (tanh.approx.bf16x2 path)
@VittoriaLanzo VittoriaLanzo marked this pull request as ready for review May 16, 2026 06:36