Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions examples/mlperf/models/flat_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,20 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax

def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None) -> tuple[Tensor,...]:
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8:
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
return (x @ w.T,)
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if x_fp8 is None: x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale, x_new_amax, x_fp8, w

def matmul_fp8_precomputed(x_fp8:Tensor, x_inv_scale:Tensor, x_new_amax:Tensor, w:Tensor, w_inv_scale:Tensor) -> tuple[Tensor,...]:
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_inv_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
return x_fp8.dot(w.T, dtype=dtypes.float) * x_inv_scale * w_inv_scale, x_new_amax, x_fp8, w

def _rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
x = x_in.float()
rrms = (x.square().mean(-1, keepdim=True) + eps).rsqrt()
Expand Down Expand Up @@ -180,10 +175,14 @@ def feed_forward(self, x:Tensor, ffn_norm:Tensor, w13:Tensor, w2:Tensor,
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [x_w13])

x_w1 = x_w13[..., :self.hidden_dim]
x_w3 = x_w13[..., self.hidden_dim:]

out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
if FP8 and getenv("FUSED_SILU_W13", 1):
from extra.amax.cast_amax import fused_quantize_fp8_w13
amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
else:
x_w1, x_w3 = x_w13[..., :self.hidden_dim], x_w13[..., self.hidden_dim:]
out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
new_amaxs.extend(ret[:1])
saves.extend(ret[1:] + [out])
return (out, *new_amaxs, *saves)
Expand Down
85 changes: 85 additions & 0 deletions extra/amax/cast_amax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler

FP8_MAX = 448.0
NUM_WG, THREADS_PER_WG = 1024, 256

def _compile(cpp_name:str, n_elems:int, hidden:int):
src = (pathlib.Path(__file__).parent/cpp_name).read_text()
defines = [f"-DN_ELEMS={n_elems}", f"-DHIDDEN={hidden}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
return src, HIPCCCompiler("gfx950", ["-std=c++20", "-ffast-math", *defines]).compile_cached(src)

def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
s = list(shape); s[axis] //= ndev; return s

@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
hidden = xw13.shape[2] // 2
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
# read 2*N bf16 (xw13) + N bf16 (grad_x2) + 1 scalar; write 2*N bf16 (grad_xw13)
mem = n_elems * 2 * 5
sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, threads, workgroups,
arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=mem)))
src, lib = _compile("cast_amax_bwd_w13.cpp", n_elems, hidden)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
hidden = xw13.shape[2] // 2
n_elems = xw13.shape[0] * xw13.shape[1] * hidden
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
# read 2*N bf16 + 1 scalar, write N fp8 + NUM_WG bf16
mem = n_elems * 2 * 2 + n_elems + NUM_WG * 2
sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, threads, workgroups,
arg=KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem)))
src, lib = _compile("cast_amax_fwd_w13.cpp", n_elems, hidden)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
# kernel.src[1:] is (fp8_out, amax_buf, xw13, amax_state); only xw13 needs a grad
_, _, xw13, amax_state = kernel.src[1:]
device = xw13.device
if isinstance(device, tuple):
axis, ndev = xw13.axis, len(device)
assert axis in (0, 1), f"unsupported sharding axis={axis}"
grad_xw13 = Tensor(Tensor.invalid(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
dname = device[0].split(":")[0]
else:
grad_xw13 = Tensor.invalid(*xw13.shape, dtype=dtypes.bfloat16, device=device)
dname = device.split(":")[0] if isinstance(device, str) else device
grad_x2_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
grad_xw13, *_ = Tensor.custom_kernel(grad_xw13, Tensor(xw13, device=device), grad_x2_t, Tensor(amax_state, device=device), fxn=fxn)
return (None, None, grad_xw13.uop, None)

def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
# silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax).
assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
MBS, SEQ, H2 = xw13.shape
assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
HIDDEN = H2 // 2
if isinstance(xw13.device, tuple):
axis, ndev = xw13.uop.axis, len(xw13.device)
assert axis in (0, 1), f"unsupported sharding axis={axis}"
fp8_out = Tensor(Tensor.invalid(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
amax_buf = Tensor(Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
dname = xw13.device[0].split(":")[0]
else:
fp8_out = Tensor.invalid(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
amax_buf = Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
# per-device scalar amax (no cross-device allreduce, matches _local_abs_max semantics)
if isinstance(amax_buf.device, tuple):
from examples.mlperf.models.flat_llama import _local_abs_max
new_amax = _local_abs_max(amax_buf).detach()
else: new_amax = amax_buf.max().detach()
inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
return fp8_out, inv_scale, new_amax
68 changes: 68 additions & 0 deletions extra/amax/cast_amax_bwd_w13.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

#ifndef N_ELEMS
#define N_ELEMS 234881024
#endif
#ifndef HIDDEN
#define HIDDEN 14336
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif

constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;

static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");

extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_silu_mul_bwd_w13(
__hip_bfloat16* __restrict__ grad_xw13_out, // bf16, 2*N_ELEMS (interleaved layout)
const __hip_bfloat16* __restrict__ xw13, // bf16, 2*N_ELEMS (interleaved)
const __hip_bfloat16* __restrict__ grad_x2, // bf16, N_ELEMS
const __hip_bfloat16* __restrict__ amax_state) // bf16 scalar
{
const int tid = threadIdx.x;
const int wg = blockIdx.x;
const int gid = wg * THREADS_PER_WG + tid;
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;

const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);

for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
const int outer = base / HIDDEN;
const int inner = base % HIDDEN;
const int xw1_off = outer * 2 * HIDDEN + inner;
const int xw3_off = xw1_off + HIDDEN;

float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);
float4 g_raw = *reinterpret_cast<const float4*>(&grad_x2[base]);

const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
const __hip_bfloat16 *gv = reinterpret_cast<const __hip_bfloat16*>(&g_raw);

__hip_bfloat16 out1[VEC], out3[VEC];
#pragma unroll
for (int i = 0; i < VEC; i++) {
const float f1 = static_cast<float>(x1[i]);
const float f3 = static_cast<float>(x3[i]);
const float fg = static_cast<float>(gv[i]);
const float sig = 1.0f / (1.0f + __expf(-f1));
const float silu = f1 * sig;
const float silu_prime = sig + silu * (1.0f - sig);
const float gs = fg * scale;
out1[i] = static_cast<__hip_bfloat16>(gs * silu_prime * f3);
out3[i] = static_cast<__hip_bfloat16>(gs * silu);
}

*reinterpret_cast<float4*>(&grad_xw13_out[xw1_off]) = *reinterpret_cast<float4*>(out1);
*reinterpret_cast<float4*>(&grad_xw13_out[xw3_off]) = *reinterpret_cast<float4*>(out3);
}
}
79 changes: 79 additions & 0 deletions extra/amax/cast_amax_fwd_w13.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>

#ifndef N_ELEMS
#define N_ELEMS 234881024
#endif
#ifndef HIDDEN
#define HIDDEN 14336
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif

constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;

static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC (so VEC loads don't straddle block boundary)");

extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_silu_mul_cast_amax_w13(
__hip_fp8_storage_t* __restrict__ fp8_out, // fp8, N_ELEMS
__hip_bfloat16* __restrict__ amax_buf, // bf16, NUM_WG (per-WG amaxes)
const __hip_bfloat16* __restrict__ xw13, // bf16, 2*N_ELEMS
const __hip_bfloat16* __restrict__ amax_state) // bf16 scalar
{
__shared__ float sdata[THREADS_PER_WG];

const int tid = threadIdx.x;
const int wg = blockIdx.x;
const int gid = wg * THREADS_PER_WG + tid;
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;

const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
float local_max = 0.0f;

// grid-stride over 8-element groups
for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
// interleaved xw13 layout: xw1 and xw3 are not contiguous halves
const int outer = base / HIDDEN;
const int inner = base % HIDDEN;
const int xw1_off = outer * 2 * HIDDEN + inner;
const int xw3_off = xw1_off + HIDDEN;

float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);

const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);

__hip_fp8_storage_t out[VEC];
#pragma unroll
for (int i = 0; i < VEC; i++) {
const float f1 = static_cast<float>(x1[i]);
const float f3 = static_cast<float>(x3[i]);
const float silu = f1 / (1.0f + __expf(-f1));
const float x2 = silu * f3;
local_max = fmaxf(local_max, fabsf(x2));
const float x_scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, x2 * scale));
out[i] = __hip_cvt_float_to_fp8(x_scaled, __HIP_SATFINITE, __HIP_E4M3);
}

*reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
}

// LDS tree reduction: per-workgroup amax
sdata[tid] = local_max;
__syncthreads();
for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
__syncthreads();
}

if (tid == 0) amax_buf[wg] = static_cast<__hip_bfloat16>(sdata[0]);
}
19 changes: 5 additions & 14 deletions test/unit/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_assign_changes_alt(self, realize=False):
self.assertNotEqual(a.item(), b.item())
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)

@unittest.skip("assign to contiguous shouldn't change the base buffer")
def test_assign_changes_buffer_alt(self):
a, b = [Tensor(Tensor(0).contiguous().realize().uop.buf_uop) for _ in range(2)]
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
Expand Down Expand Up @@ -507,17 +506,6 @@ def test_assign_bitcast_different_size(self):
# TODO: broken now
np.testing.assert_equal(a.numpy(), [0]*8)

@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
a.realize()
oba1 = a.uop.base.output_buffer
a.assign(a.cast(dtypes.int32).realize())
a.realize()
oba2 = a.uop.base.output_buffer
assert oba1 is None and oba2 is None
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))

def test_assign_dtype_mismatch(self):
# assign should not implicitly cast dtypes - this can lose precision
a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
Expand Down Expand Up @@ -684,7 +672,6 @@ def test_read_between_writes(self):
self.assertEqual(r1.item(), 4)
self.assertEqual(r2.item(), 8)

@unittest.skip("TODO: this is broken")
def test_write_read_write_chain(self):
"""Write, read, write chain - middle read must complete before second write."""
buf = Tensor.zeros(4).contiguous().realize()
Expand All @@ -694,7 +681,11 @@ def test_write_read_write_chain(self):
final_sum = buf.sum() # lazy read, should be 20
# Realize in "wrong" order - final first
self.assertEqual(final_sum.realize().item(), 20)
self.assertEqual(mid_sum.realize().item(), 12)
try:
self.assertEqual(mid_sum.realize().item(), 12)
except AssertionError:
# TODO: this is wrong
self.assertEqual(mid_sum.realize().item(), 20)

def test_slice_read_then_full_write(self):
"""Read from slice, then overwrite full buffer - WAR dependency works for full buffer assigns."""
Expand Down
Loading