Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/mlperf/models/flat_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ def attention(self, x:Tensor, freqs_cis:Tensor, attention_norm:Tensor, wqkv:Tens
x = x * attention_norm

xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)
if FP8 and getenv("FUSED_NORM_MUL_QUANTIZE", 1):
from extra.amax.cast_amax import fused_mul_quantize_fp8
amax_s = amax_xqkv if amax_xqkv is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x.device)
x_fp8, x_inv_scale, new_amax_xqkv = fused_mul_quantize_fp8(x, attention_norm, amax_s, FP8_DTYPE)
xqkv, *ret = matmul(None, wqkv, w_inv_scale=s_qkv, x_fp8=x_fp8, x_scale=x_inv_scale, x_new_amax=new_amax_xqkv)
else:
x = x * attention_norm
xqkv, *ret = matmul(x, wqkv, amax_x=amax_xqkv, w_inv_scale=s_qkv)

new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [xqkv])
xqkv = xqkv.reshape(bsz, seqlen, self.n_kv_heads, self.n_rep + 2, self.head_dim)
Expand Down
64 changes: 56 additions & 8 deletions extra/amax/cast_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@ def _compile(cpp_name:str, n_elems:int, hidden:int):
def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
s = list(shape); s[axis] //= ndev; return s

def _scalar_amax(amax_buf:Tensor) -> Tensor:
if isinstance(amax_buf.device, tuple):
from examples.mlperf.models.flat_llama import _local_abs_max
return _local_abs_max(amax_buf).detach()
return amax_buf.max().detach()

# ** fused silu*mul -> fp8 cast + amax (w13 layout)

@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
hidden = xw13.shape[2] // 2
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
# read 2*N bf16 (xw13) + N bf16 (grad_x2) + 1 scalar; write 2*N bf16 (grad_xw13)
mem = n_elems * 2 * 5
sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, threads, workgroups,
arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=mem)))
Expand All @@ -33,7 +40,6 @@ def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:
hidden = xw13.shape[2] // 2
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
# read 2*N bf16 + 1 scalar, write N fp8 + NUM_WG bf16
mem = n_elems * 2 * 2 + n_elems + NUM_WG * 2
sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, threads, workgroups,
arg=KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem)))
Expand All @@ -42,7 +48,6 @@ def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
# kernel.src[1:] is (fp8_out, amax_buf, xw13, amax_state); only xw13 needs a grad
_, _, xw13, amax_state = kernel.src[1:]
device = xw13.device
if isinstance(device, tuple):
Expand Down Expand Up @@ -76,10 +81,53 @@ def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[T
dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
# per-device scalar amax (no cross-device allreduce, matches _local_abs_max semantics)
if isinstance(amax_buf.device, tuple):
from examples.mlperf.models.flat_llama import _local_abs_max
new_amax = _local_abs_max(amax_buf).detach()
else: new_amax = amax_buf.max().detach()
inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
return fp8_out, inv_scale, _scalar_amax(amax_buf)

# ** fused (x * weight) -> fp8 cast + amax (norm-mul-quantize)

@functools.cache
def _custom_mul_quantize_fp8(fp8_out:UOp, amax_buf:UOp, x:UOp, weight:UOp, amax_state:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp wrapping the fused (x * weight) -> fp8 cast + amax HIP kernel.

  Cached per unique argument signature. Outputs are fp8_out (same shape as x) and
  amax_buf (NUM_WG bf16 per-workgroup partial amaxes); amax_state supplies the
  delayed-scaling amax the kernel quantizes with.
  """
  MBS, SEQ, HIDDEN = x.shape
  n_elems = MBS * SEQ * HIDDEN
  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # bytes: read N bf16 (x) + HIDDEN bf16 (weight); write N fp8 + NUM_WG bf16 (partial amaxes)
  mem = n_elems * 2 + HIDDEN * 2 + n_elems + NUM_WG * 2
  sink = UOp.sink(fp8_out.base, amax_buf.base, x.base, weight.base, amax_state.base, threads, workgroups,
    arg=KernelInfo(f"fused_mul_quantize_fp8_{n_elems}_h{HIDDEN}", estimates=Estimates(ops=3*n_elems, mem=mem)))
  # compile the template kernel with N_ELEMS/HIDDEN baked in as macros
  src, lib = _compile("fused_mul_quantize_fp8.cpp", n_elems, HIDDEN)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
    UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

def _fused_mul_quantize_fp8_bwd(gradient:UOp, kernel:UOp):
  """Backward for the fused mul-quantize kernel.

  kernel.src[1:] is (fp8_out, amax_buf, x, weight, amax_state); only x and weight
  receive gradients. The forward kernel scaled by FP8_MAX / (amax_state + 1e-8),
  so the incoming gradient is multiplied by the same factor before the product rule.
  """
  # inputs: (fp8_out, amax_buf, x, weight, amax_state); grads for x and weight
  _, _, x_u, weight_u, amax_state_u = kernel.src[1:]
  device = x_u.device
  grad_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
  x_t, weight_t = Tensor(x_u, device=device), Tensor(weight_u, device=device)
  scale = FP8_MAX / (Tensor(amax_state_u, device=device).float() + 1e-8)
  grad_scaled = grad_t.float() * scale
  # grad_x stays bf16 to avoid CSE materializing a (MBS, SEQ, HIDDEN) fp32 intermediate
  grad_x = (grad_scaled * weight_t.float()).cast(dtypes.bfloat16)
  # weight is broadcast over the leading dims in the forward, so its grad reduces over them
  grad_weight = (grad_scaled * x_t.float()).sum(axis=(0, 1)).cast(dtypes.bfloat16)
  return (None, None, grad_x.uop, grad_weight.uop, None)

def fused_mul_quantize_fp8(x:Tensor, weight:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  """(x * weight) -> fp8 + amax, delayed scaling. Returns (fp8, inv_scale, new_amax).

  x: bf16 (MBS, SEQ, HIDDEN); weight: bf16 broadcast over the leading dims
  (last dim must match x). new_amax is this pass's abs-max of x*weight, for
  updating the delayed-scaling history.

  BUGFIX: the kernel quantizes with scale = FP8_MAX / (amax_state + 1e-8)
  (delayed scaling), so the returned inv_scale must be the reciprocal of THAT
  scale — derived from amax_state, not from new_amax. Deriving it from new_amax
  dequantized with a factor the data was never quantized with; this now matches
  fused_quantize_fp8_w13 and _fused_mul_quantize_fp8_bwd.
  """
  assert x.dtype == dtypes.bfloat16 and weight.dtype == dtypes.bfloat16
  assert x.shape[-1] == weight.shape[-1], f"HIDDEN mismatch: x={x.shape}, weight={weight.shape}"
  MBS, SEQ, HIDDEN = x.shape
  if isinstance(x.device, tuple):
    # sharded path: allocate outputs pre-sharded on the same axis as x
    axis, ndev = x.uop.axis, len(x.device)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    fp8_out = Tensor(Tensor.invalids(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=x.device).uop.multi(axis), device=x.device)
    amax_buf = Tensor(Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=x.device).uop.multi(0), device=x.device)
    dname = x.device[0].split(":")[0]
  else:
    fp8_out = Tensor.invalids(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=x.device)
    amax_buf = Tensor.invalids(NUM_WG, dtype=dtypes.bfloat16, device=x.device)
    dname = x.device.split(":")[0] if isinstance(x.device, str) else x.device
  fxn = functools.partial(_custom_mul_quantize_fp8, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, x, weight, amax_state, fxn=fxn, grad_fxn=_fused_mul_quantize_fp8_bwd)
  new_amax = _scalar_amax(amax_buf)
  # invert the scale the kernel actually used (amax_state), same form as fused_quantize_fp8_w13
  inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
  return fp8_out, inv_scale, new_amax
71 changes: 71 additions & 0 deletions extra/amax/fused_mul_quantize_fp8.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>

#ifndef N_ELEMS
#define N_ELEMS 67108864
#endif
#ifndef HIDDEN
#define HIDDEN 4096
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif

// vector width: 8 bf16 elements = one 16-byte (float4) load
constexpr int VEC = 8;
// max finite magnitude representable in fp8 E4M3
constexpr float FP8_MAX = 448.0f;

static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");

// Fused (x * weight) -> fp8 quantize with amax tracking (delayed scaling):
// quantizes with the scale derived from amax_state (previous amax) while
// recording this pass's abs-max of the UNSCALED product, one partial per WG.
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_mul_quantize_fp8(
    __hip_fp8_storage_t* __restrict__ fp8_out,      // fp8, N_ELEMS
    __hip_bfloat16* __restrict__ amax_buf,          // bf16, NUM_WG (per-WG partial amax)
    const __hip_bfloat16* __restrict__ x,           // bf16, N_ELEMS
    const __hip_bfloat16* __restrict__ weight,      // bf16, HIDDEN (per-hidden scale)
    const __hip_bfloat16* __restrict__ amax_state)  // bf16 scalar (delayed amax)
{
  __shared__ float sdata[THREADS_PER_WG];

  const int tid = threadIdx.x;
  const int wg = blockIdx.x;
  const int gid = wg * THREADS_PER_WG + tid;
  const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;  // grid-wide stride, in elements

  const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
  float local_max = 0.0f;

  // grid-stride loop, VEC elements per thread per step
  for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
    const int h = base % HIDDEN;  // 0..HIDDEN-VEC, VEC-aligned (base is VEC-aligned and VEC divides HIDDEN)
    float4 x_raw = *reinterpret_cast<const float4*>(&x[base]);    // 8 bf16 in one 16B load
    float4 w_raw = *reinterpret_cast<const float4*>(&weight[h]);

    const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw);
    const __hip_bfloat16 *wi = reinterpret_cast<const __hip_bfloat16*>(&w_raw);

    __hip_fp8_storage_t out[VEC];
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      const float val = static_cast<float>(xi[i]) * static_cast<float>(wi[i]);
      local_max = fmaxf(local_max, fabsf(val));  // amax tracks the unscaled product
      const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, val * scale));  // clamp before convert
      out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3);
    }

    // 8 fp8 bytes written as one 8-byte store
    *reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
  }

  // LDS tree-reduce per-WG amax
  sdata[tid] = local_max;
  __syncthreads();
  for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
    if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
    __syncthreads();
  }

  if (tid == 0) amax_buf[wg] = static_cast<__hip_bfloat16>(sdata[0]);
}
8 changes: 5 additions & 3 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def ptrdtype(self) -> PtrDType:
def _shape(self) -> tuple[sint, ...]|None:
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None

Expand Down Expand Up @@ -241,7 +241,9 @@ def _shape(self) -> tuple[sint, ...]|None:
return self.src[0].shape[len(self.src[1:]):]

# some ops init the shape
case Ops.CONST | Ops.VCONST | Ops.DEFINE_VAR | Ops.BIND: return ()
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
# TODO: VCONST should have the shape of the arg
case Ops.VCONST: return ()
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.CUSTOM_FUNCTION: return None
Expand Down
1 change: 0 additions & 1 deletion tinygrad/viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
Expand Down
Loading