From 0560fa7b0f04bced77a20cfa6434e71125e1ff84 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 22 Apr 2026 11:15:02 +0800 Subject: [PATCH 1/2] add shape to range/special (#15862) --- tinygrad/uop/ops.py | 8 +++++--- tinygrad/viz/serve.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index ae2125b5a60cb..9e220b24762bb 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -211,8 +211,8 @@ def ptrdtype(self) -> PtrDType: def _shape(self) -> tuple[sint, ...]|None: match self.op: # late ops don't have shape - case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \ + case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ + Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \ Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION: return None @@ -241,7 +241,9 @@ def _shape(self) -> tuple[sint, ...]|None: return self.src[0].shape[len(self.src[1:]):] # some ops init the shape - case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return () + case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return () + # TODO: VCONST should have the shape of the arg + case Ops.VCONST: return () case Ops.BUFFER: return (self.arg,) case Ops.BUFFER_VIEW: return (self.arg[0],) case Ops.CUSTOM_FUNCTION: return None diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index d2a056cb9b66b..4978423cfd035 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -113,7 +113,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: # always exclude DEVICE/CONST/UNIQUE if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u) if u.op 
is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u) - if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u) if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u) # exclude RESHAPE/EXPAND that only serve to broadcast a CONST if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u) From 87378331e816d8c9ed71328569efa08f6bcb50fc Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 22 Apr 2026 11:58:37 +0800 Subject: [PATCH 2/2] llama: fused mul quantize fp8 (#15863) --- examples/mlperf/models/flat_llama.py | 11 +++++++++-- extra/amax/cast_amax.py | 64 +++++++++++++++++++++--- extra/amax/fused_mul_quantize_fp8.cpp | 71 +++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 10 deletions(-) create mode 100644 extra/amax/fused_mul_quantize_fp8.cpp diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index 1c478e182d7f7..508f4ad580a64 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -140,6 +140,13 @@ def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tens - x = x * attention_norm - xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv) + if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1): + from extra.amax.cast_amax import fused_mul_quantize_fp8 + amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device) + x_fp8, x_inv_scale, new_amax_xqkv = fused_mul_quantize_fp8(x, attention_norm, amax_s, FP8_DTYPE) + xqkv, *ret = matmul(None, wqkv, w_inv_scale=s_qkv, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax_xqkv) + else: + x = x * attention_norm + xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv) + new_amaxs.extend(ret[:1]) saves.extend(ret[1:] + [xqkv]) xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim) diff --git
a/extra/amax/cast_amax.py b/extra/amax/cast_amax.py index 6369b18144dbd..098337ba4654d 100644 --- a/extra/amax/cast_amax.py +++ b/extra/amax/cast_amax.py @@ -15,12 +15,19 @@ def _compile(cpp_name:str, n_elems:int, hidden:int): def _shard_shape(shape:tuple, axis:int, ndev:int) -> list: s = list(shape); s[axis] //= ndev; return s +def _scalar_amax(amax_buf:Tensor) -> Tensor: + if isinstance(amax_buf.device, tuple): + from examples.mlperf.models.flat_llama import _local_abs_max + return _local_abs_max(amax_buf).detach() + return amax_buf.max().detach() + +# ** fused silu*mul -> fp8 cast + amax (w13 layout) + @functools.cache def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp: hidden = xw13.shape[2] // 2 n_elems = xw13.shape[0] * xw13.shape[1] * hidden threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0") - # read 2*N bf16 (xw13) + N bf16 (grad_x2) + 1 scalar; write 2*N bf16 (grad_xw13) mem = n_elems * 2 * 5 sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, threads, workgroups, arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=mem))) @@ -33,7 +40,6 @@ def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state: hidden = xw13.shape[2] // 2 n_elems = xw13.shape[0] * xw13.shape[1] * hidden threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0") - # read 2*N bf16 + 1 scalar, write N fp8 + NUM_WG bf16 mem = n_elems * 2 * 2 + n_elems + NUM_WG * 2 sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, threads, workgroups, arg=KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem))) @@ -42,7 +48,6 @@ def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state: UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))) def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp): - # kernel.src[1:] is (fp8_out, 
amax_buf, xw13, amax_state); only xw13 needs a grad _, _, xw13, amax_state = kernel.src[1:] device = xw13.device if isinstance(device, tuple): @@ -76,10 +81,53 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname) fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13) - # per-device scalar amax (no cross-device allreduce, matches _local_abs_max semantics) - if isinstance(amax_buf.device, tuple): - from examples.mlperf.models.flat_llama import _local_abs_max - new_amax = _local_abs_max(amax_buf).detach() - else: new_amax = amax_buf.max().detach() inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal() + return fp8_out, inv_scale, _scalar_amax(amax_buf) + +# ** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize) + +@functools.cache +def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp: + MBS, SEQ, HIDDEN = x.shape + n_elems = MBS * SEQ * HIDDEN + threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0") + mem = n_elems * 2 + HIDDEN * 2 + n_elems + NUM_WG * 2 + sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, threads, workgroups, + arg=KernelInfo(f"fused_mul_quantize_fp8_{n_elems}_h{HIDDEN}", estimates=Estimates(ops=3*n_elems, mem=mem))) + src, lib = _compile("fused_mul_quantize_fp8.cpp", n_elems, HIDDEN) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), + UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib))) + +def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp): + # inputs: (fp8_out, amax_buf, x, weight, amax_state); grads for x and weight + _, _, x_u, weight_u, amax_state_u = kernel.src[1:] + device = x_u.device + grad_t = 
Tensor(gradient, device=device).cast(dtypes.bfloat16) + x_t, weight_t = Tensor(x_u, device=device), Tensor(weight_u, device=device) + scale = FP8_MAX / (Tensor(amax_state_u, device=device).float() + 1e-8) + grad_scaled = grad_t.float() * scale + # grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate + grad_x = (grad_scaled * weight_t.float()).cast(dtypes.bfloat16) + grad_weight = (grad_scaled * x_t.float()).sum(axis=(0, 1)).cast(dtypes.bfloat16) + return (None, None, grad_x.uop, grad_weight.uop, None) + +def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]: + # (x * weight) -> fp8 + amax, delayed scaling. Returns (fp8, inv_scale, new_amax). + assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16 + assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}" + MBS, SEQ, HIDDEN = x.shape + if isinstance(x.device, tuple): + axis, ndev = x.uop.axis, len(x.device) + assert axis in (0, 1), f"unsupported sharding axis={axis}" + fp8_out = Tensor(Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=x.device).uop.multi(axis), device=x.device) + amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=x.device).uop.multi(0), device=x.device) + dname = x.device[0].split(":")[0] + else: + fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=x.device) + amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=x.device) + dname = x.device.split(":")[0] if isinstance(x.device, str) else x.device + fxn = functools.partial(_custom_mul_quantize_fp8, dname=dname) + fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_mul_quantize_fp8_bwd) + new_amax = _scalar_amax(amax_buf) + inv_scale = (amax_state.float() + 1e-8) / FP8_MAX  # must invert the kernel's quantize scale, which uses amax_state (not new_amax) return fp8_out, inv_scale, new_amax diff --git a/extra/amax/fused_mul_quantize_fp8.cpp
b/extra/amax/fused_mul_quantize_fp8.cpp new file mode 100644 index 0000000000000..d460942dca4a7 --- /dev/null +++ b/extra/amax/fused_mul_quantize_fp8.cpp @@ -0,0 +1,71 @@ +#include <hip/hip_runtime.h> +#include <hip/hip_bf16.h> +#include <hip/hip_fp8.h> + +#ifndef N_ELEMS +#define N_ELEMS 67108864 +#endif +#ifndef HIDDEN +#define HIDDEN 4096 +#endif +#ifndef NUM_WG +#define NUM_WG 1024 +#endif +#ifndef THREADS_PER_WG +#define THREADS_PER_WG 256 +#endif + +constexpr int VEC = 8; +constexpr float FP8_MAX = 448.0f; + +static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC"); +static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC"); + +extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void +fused_mul_quantize_fp8( + __hip_fp8_storage_t* __restrict__ fp8_out, // fp8, N_ELEMS + __hip_bfloat16* __restrict__ amax_buf, // bf16, NUM_WG + const __hip_bfloat16* __restrict__ x, // bf16, N_ELEMS + const __hip_bfloat16* __restrict__ weight, // bf16, HIDDEN (per-hidden scale) + const __hip_bfloat16* __restrict__ amax_state) // bf16 scalar +{ + __shared__ float sdata[THREADS_PER_WG]; + + const int tid = threadIdx.x; + const int wg = blockIdx.x; + const int gid = wg * THREADS_PER_WG + tid; + const int stride_elems = NUM_WG * THREADS_PER_WG * VEC; + + const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f); + float local_max = 0.0f; + + for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) { + const int h = base % HIDDEN; // 0..HIDDEN-VEC, 8-aligned (since base is 8-aligned and VEC divides HIDDEN) + float4 x_raw = *reinterpret_cast<const float4*>(&x[base]); + float4 w_raw = *reinterpret_cast<const float4*>(&weight[h]); + + const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw); + const __hip_bfloat16 *wi = reinterpret_cast<const __hip_bfloat16*>(&w_raw); + + __hip_fp8_storage_t out[VEC]; + #pragma unroll + for (int i = 0; i < VEC; i++) { + const float val = static_cast<float>(xi[i]) * static_cast<float>(wi[i]); + local_max = fmaxf(local_max, fabsf(val)); + const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, val * scale)); + out[i] =
__hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3); + } + + *reinterpret_cast<uint2*>(&fp8_out[base]) = *reinterpret_cast<const uint2*>(out); + } + + // LDS tree-reduce per-WG amax + sdata[tid] = local_max; + __syncthreads(); + for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) { + if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]); + __syncthreads(); + } + + if (tid == 0) amax_buf[wg] = static_cast<__hip_bfloat16>(sdata[0]); +}