"""FlyDSL fused split-GDR update-forward kernel (ksplit2).

Builds a single-wavefront-style GPU kernel that fuses, per timestep:
  1) splitting a packed mixed QKV tensor into Q, K, V slices,
  2) a sigmoid-gated delta-rule recurrent state update, and
  3) the output projection o = <h, q>.

"ksplit2" mapping: every V column is served by a *pair* of adjacent lanes —
lane 2i owns state rows k in [0, K/2), lane 2i+1 owns k in [K/2, K) — and the
two partial dot products are combined with one `rocdl.ds_bpermute` exchange.
"""

from flydsl.dialects.ext import flir, arith, gpu
from flydsl.dialects.ext.python_control_flow import range_constexpr
from flydsl.dialects.ext import memref, rocdl
from flydsl.runtime.device import get_rocm_arch
from flydsl.utils import SmemAllocator
import _mlir.extras.types as T


KERNEL_NAME = "fused_split_gdr_update_ksplit2_flyc_kernel"
SOFTPLUS_BETA = 1.0
SOFTPLUS_THRESHOLD = 20.0


def build_fused_split_gdr_update_ksplit2_flyc_module(
    B: int,
    T_seq: int,
    H: int,
    HV: int,
    K: int,
    V: int,
    N_STATE: int,
    key_dim: int | None = None,
    value_dim: int | None = None,
    dtype_str: str = "f32",
    BV: int = 64,
    softplus_beta: float = SOFTPLUS_BETA,
    softplus_threshold: float = SOFTPLUS_THRESHOLD,
    use_qk_l2norm_in_kernel: bool = False,
):
    """Build the split-GDR ksplit2 FlyDSL module.

    Args:
        B: batch size.
        T_seq: sequence length per batch element.
        H: number of QK heads.
        HV: number of V heads (must be a multiple of H).
        K: QK head dimension (must be divisible by 4; the swizzled state
           layout stores K in groups of 4).
        V: V head dimension.
        N_STATE: number of recurrent-state slots in ``initial_state_source``.
        key_dim: packed Q (and K) width inside mixed_qkv; defaults to H * K.
        value_dim: packed V width inside mixed_qkv; defaults to HV * V.
        dtype_str: element type of activations, "f32" or "bf16"; the
            recurrent state is always f32.
        BV: nominal block width; actual block size is min(BV, 2 * V) threads
            (two lanes per V column).
        softplus_beta / softplus_threshold: parameters of the numerically
            guarded softplus used for the gate.
        use_qk_l2norm_in_kernel: if True, L2-normalize K and Q in-kernel.

    Returns:
        An instantiated ``flir.MlirModule`` exposing ``__call__`` that
        launches the kernel.
    """
    if dtype_str not in {"f32", "bf16"}:
        raise NotImplementedError("Supported dtype_str values are 'f32' and 'bf16'.")
    if HV % H != 0:
        raise ValueError("HV must be divisible by H for hv->h mapping.")
    if K % 4 != 0:
        raise ValueError("K must be divisible by 4 for swizzled state layout.")
    if BV <= 0 or BV % 2 != 0:
        raise ValueError("BV must be positive and divisible by 2 for ksplit2 mapping.")

    # Two lanes cooperate on each V column, so a block never needs more than
    # 2 * V threads.
    block_threads = min(BV, 2 * V)
    if block_threads <= 0 or block_threads % 2 != 0:
        raise ValueError("block_threads must be positive and divisible by 2.")

    if key_dim is None:
        key_dim = H * K
    if value_dim is None:
        value_dim = HV * V
    # mixed_qkv dim axis is packed as [Q(key_dim) | K(key_dim) | V(value_dim)].
    mixed_dim = 2 * key_dim + value_dim

    gpu_arch = get_rocm_arch()
    compute_type = T.f32()
    elem_type = T.f32() if dtype_str == "f32" else T.bf16()
    state_elem_type = T.f32()
    scale = K ** (-0.5)  # attention-style 1/sqrt(K) query scaling
    BT = B * T_seq
    allocator = SmemAllocator(None, arch=gpu_arch)
    _state = {}
    _asv = arith.as_value
    _extf = flir.arith.extf

    def _load_mixed_qkv_fp32(mixed_qkv, i_n, col, t_idx):
        # Load one element of mixed_qkv[(batch, dim, time)], widening bf16 -> f32.
        val = memref.load(mixed_qkv, [_asv(i_n), _asv(col), _asv(t_idx)])
        if dtype_str != "f32":
            val = _extf(compute_type, _asv(val))
        return val

    def _load_state_lane(initial_state_source, state_idx, i_hv, kg, iv_safe, lane_comp):
        # Read one f32 of the swizzled state (N_STATE, HV, K/4, V, 4).
        return memref.load(
            initial_state_source,
            [
                _asv(state_idx),
                _asv(i_hv),
                _asv(kg),
                _asv(iv_safe),
                _asv(arith.index(lane_comp)),
            ],
        )

    def _store_state_lane(
        initial_state_source, value, state_idx, i_hv, kg, iv_safe, lane_comp
    ):
        # Write one f32 of the swizzled state (N_STATE, HV, K/4, V, 4).
        memref.store(
            _asv(value),
            initial_state_source,
            [
                _asv(state_idx),
                _asv(i_hv),
                _asv(kg),
                _asv(iv_safe),
                _asv(arith.index(lane_comp)),
            ],
        )

    def _add4(v0, v1, v2, v3, fm_fast):
        # Pairwise 4-way add tree (2 levels) with fast-math flags.
        sum01 = flir.arith.AddFOp(_asv(v0), _asv(v1), fastmath=fm_fast).result
        sum23 = flir.arith.AddFOp(_asv(v2), _asv(v3), fastmath=fm_fast).result
        return flir.arith.AddFOp(_asv(sum01), _asv(sum23), fastmath=fm_fast).result

    def _dot4(a0, a1, a2, a3, b0, b1, b2, b3, fm_fast):
        # 4-element dot product: a0*b0 + a1*b1 + a2*b2 + a3*b3.
        prod0 = flir.arith.MulFOp(_asv(a0), _asv(b0), fastmath=fm_fast).result
        prod1 = flir.arith.MulFOp(_asv(a1), _asv(b1), fastmath=fm_fast).result
        prod2 = flir.arith.MulFOp(_asv(a2), _asv(b2), fastmath=fm_fast).result
        prod3 = flir.arith.MulFOp(_asv(a3), _asv(b3), fastmath=fm_fast).result
        return _add4(prod0, prod1, prod2, prod3, fm_fast)

    def _sumsq4(v0, v1, v2, v3, fm_fast):
        # Sum of squares of 4 elements (used for L2-norm reductions).
        sq0 = flir.arith.MulFOp(_asv(v0), _asv(v0), fastmath=fm_fast).result
        sq1 = flir.arith.MulFOp(_asv(v1), _asv(v1), fastmath=fm_fast).result
        sq2 = flir.arith.MulFOp(_asv(v2), _asv(v2), fastmath=fm_fast).result
        sq3 = flir.arith.MulFOp(_asv(v3), _asv(v3), fastmath=fm_fast).result
        return _add4(sq0, sq1, sq2, sq3, fm_fast)

    def _smem_load4(smem_kq, idx0):
        # Load smem_kq[idx0 .. idx0+3] as four scalars.
        idx1 = arith.as_value(
            flir.arith.AddIOp(arith.as_value(idx0), arith.as_value(arith.index(1))).result
        )
        idx2 = arith.as_value(
            flir.arith.AddIOp(arith.as_value(idx0), arith.as_value(arith.index(2))).result
        )
        idx3 = arith.as_value(
            flir.arith.AddIOp(arith.as_value(idx0), arith.as_value(arith.index(3))).result
        )
        v0 = smem_kq.load([arith.as_value(idx0)])
        v1 = smem_kq.load([arith.as_value(idx1)])
        v2 = smem_kq.load([arith.as_value(idx2)])
        v3 = smem_kq.load([arith.as_value(idx3)])
        return v0, v1, v2, v3

    class _FusedSplitGdrUpdateKSplit2(flir.MlirModule):
        GPU_MODULE_NAME = f"fused_split_gdr_update_ksplit2_flyc_{dtype_str}"
        # NOTE(review): the target attribute's angle-bracket parameters were
        # lost in extraction; reconstructed as a chip-parameterized rocdl
        # target — confirm against the original build.
        GPU_MODULE_TARGETS = [f'#rocdl.target<chip = "{gpu_arch}">']

        def init_gpu_module(self):
            # One K-wide f32 staging buffer in LDS, shared for K then Q.
            _state["smem_kq"] = allocator.allocate_array(compute_type, K)
            allocator.finalize()

        @flir.kernel
        def fused_split_gdr_update_ksplit2_flyc_kernel(
            self: flir.T.i64,
            mixed_qkv: lambda: T.memref(B, mixed_dim, T_seq, elem_type),
            A_log: lambda: T.memref(HV, T.f32()),
            a: lambda: T.memref(BT, HV, elem_type),
            dt_bias: lambda: T.memref(HV, elem_type),
            b_gate: lambda: T.memref(BT, HV, elem_type),
            initial_state_source: lambda: T.memref(N_STATE, HV, K // 4, V, 4, state_elem_type),
            initial_state_indices: lambda: T.memref(B, T.i32()),
            o: lambda: T.memref(B, T_seq, HV, V, elem_type),
        ):
            # Kernel math per timestep:
            #   1) g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b_gate)
            #   2) v <- (v - <h, k> * exp(g)) * beta
            #   3) h <- h * exp(g) + k * v
            #   4) o <- <h, q>
            i_v_tile = flir.const_index(flir.block_idx("y"))
            i_nh = flir.const_index(flir.block_idx("z"))
            c_hv = arith.as_value(arith.index(HV))
            # blockIdx.z flattens (batch, v-head): i_n = z / HV, i_hv = z % HV.
            i_n = arith.as_value(flir.arith.DivUIOp(arith.as_value(i_nh), c_hv).result)
            i_hv = arith.as_value(flir.arith.RemUIOp(arith.as_value(i_nh), c_hv).result)
            c_hv_per_h = arith.as_value(arith.index(HV // H))
            i_h = arith.as_value(
                flir.arith.DivUIOp(arith.as_value(i_hv), c_hv_per_h).result
            )

            tid = flir.const_index(flir.thread_idx("x"))
            c_bv = arith.as_value(arith.index(block_threads // 2))
            iv_base = arith.as_value(
                flir.arith.MulIOp(arith.as_value(i_v_tile), c_bv).result
            )
            lane_i32 = arith.as_value(
                flir.arith.IndexCastOp(T.i32(), arith.as_value(tid)).result
            )
            c_one_i32 = arith.constant(1, type=T.i32())
            # ksplit2: two adjacent lanes share one V column (v_idx = lane >> 1).
            v_idx_i32 = flir.arith.ShRUIOp(
                arith.as_value(lane_i32), arith.as_value(c_one_i32)
            ).result
            v_idx = arith.as_value(
                flir.arith.IndexCastOp(T.index(), arith.as_value(v_idx_i32)).result
            )
            iv_global = arith.as_value(
                flir.arith.AddIOp(arith.as_value(iv_base), arith.as_value(v_idx)).result
            )
            c_V = arith.as_value(arith.index(V))
            c_zero_idx = arith.as_value(arith.index(0))
            iv_valid = arith.CmpIOp(
                arith.CmpIPredicate.ult,
                arith.as_value(iv_global),
                arith.as_value(c_V),
            )
            # Clamp out-of-range columns to 0 so loads stay in bounds; their
            # results are masked out before any store.
            iv_safe = arith.select(
                arith.as_value(iv_valid),
                arith.as_value(iv_global),
                arith.as_value(c_zero_idx),
            )
            lane_lsb = arith.andi(arith.as_value(lane_i32), arith.as_value(c_one_i32))
            c_zero_i32 = arith.constant(0, type=T.i32())
            c_two_i32 = arith.constant(2, type=T.i32())
            # Even lanes are "primary": they own the K lower half and write o.
            is_primary_lane = arith.CmpIOp(
                arith.CmpIPredicate.eq,
                arith.as_value(lane_lsb),
                arith.as_value(c_zero_i32),
            )
            # ds_bpermute addresses by byte: partner = (lane ^ 1) * 4.
            partner_lane_i32 = arith.xori(arith.as_value(lane_i32), arith.as_value(c_one_i32))
            partner_lane_bytes = flir.arith.ShLIOp(
                arith.as_value(partner_lane_i32), arith.as_value(c_two_i32)
            ).result

            comp = compute_type
            # exp/log are built from exp2/log2: e^x = 2^(x*log2 e), ln y = log2(y)*ln 2.
            c_log2e = arith.constant(1.4426950408889634, type=comp)
            c_ln2 = arith.constant(0.6931471805599453, type=comp)
            c_scale = arith.constant(scale, type=comp)
            c_neg_one = arith.constant(-1.0, type=comp)
            c_one = arith.constant(1.0, type=comp)
            c_softplus_beta = arith.constant(softplus_beta, type=comp)
            c_inv_softplus_beta = arith.constant(1.0 / softplus_beta, type=comp)
            c_softplus_threshold = arith.constant(softplus_threshold, type=comp)
            c_eps = arith.constant(1e-6, type=comp)
            fm_fast = flir.arith.FastMathFlags.fast
            c_zero_f32 = arith.constant(0.0, type=comp)
            c_T = arith.as_value(arith.index(T_seq))
            c_K_idx = arith.as_value(arith.index(K))
            c_V_idx = arith.as_value(arith.index(V))
            c_key_dim = arith.as_value(arith.index(key_dim))
            c_v_base = arith.as_value(arith.index(2 * key_dim))
            # Offsets of this head's Q / K / V slices along mixed_qkv's dim axis.
            c_q_dim_off = arith.as_value(
                flir.arith.MulIOp(arith.as_value(i_h), arith.as_value(c_K_idx)).result
            )
            c_k_dim_off = arith.as_value(
                flir.arith.AddIOp(arith.as_value(c_key_dim), arith.as_value(c_q_dim_off)).result
            )
            c_v_h_off = arith.as_value(
                flir.arith.MulIOp(arith.as_value(i_hv), arith.as_value(c_V_idx)).result
            )
            c_v_dim_off = arith.as_value(
                flir.arith.AddIOp(arith.as_value(c_v_base), arith.as_value(c_v_h_off)).result
            )
            c_k_half = arith.as_value(arith.index(K // 2))
            # Primary lane owns k in [0, K/2); partner owns [K/2, K).
            k_base = arith.select(
                arith.as_value(is_primary_lane),
                arith.as_value(arith.index(0)),
                arith.as_value(c_k_half),
            )
            kg_start = arith.as_value(
                flir.arith.DivUIOp(arith.as_value(k_base), arith.as_value(arith.index(4))).result
            )

            base_ptr = allocator.get_base()
            smem_kq = _state["smem_kq"](base_ptr)

            # Negative state index means "no initial state" for this batch row.
            state_idx_i32 = memref.load(initial_state_indices, [arith.as_value(i_n)])
            state_idx_nonneg = arith.CmpIOp(
                arith.CmpIPredicate.sge,
                arith.as_value(state_idx_i32),
                arith.as_value(c_zero_i32),
            )
            state_idx_i32_safe = arith.select(
                arith.as_value(state_idx_nonneg),
                arith.as_value(state_idx_i32),
                arith.as_value(c_zero_i32),
            )
            state_idx = arith.as_value(
                flir.arith.IndexCastOp(T.index(), arith.as_value(state_idx_i32_safe)).result
            )

            # Per-lane state registers: K/8 groups of 4 "planes" (swizzled
            # lane components), zero-initialized when there is no prior state.
            h_vec_x = [c_zero_f32 for _ in range_constexpr(K // 8)]
            h_vec_y = [c_zero_f32 for _ in range_constexpr(K // 8)]
            h_vec_z = [c_zero_f32 for _ in range_constexpr(K // 8)]
            h_vec_w = [c_zero_f32 for _ in range_constexpr(K // 8)]
            h_planes = [h_vec_x, h_vec_y, h_vec_z, h_vec_w]
            if state_idx_nonneg:
                if iv_valid:
                    for kg_local in range_constexpr(K // 8):
                        kg = arith.as_value(
                            flir.arith.AddIOp(
                                arith.as_value(kg_start), arith.as_value(arith.index(kg_local))
                            ).result
                        )
                        for lane_comp in range_constexpr(4):
                            h_planes[lane_comp][kg_local] = _load_state_lane(
                                initial_state_source,
                                state_idx,
                                i_hv,
                                kg,
                                iv_safe,
                                lane_comp,
                            )

            # Loop-invariant gate precomputation: -exp(A_log[hv]) and dt_bias[hv].
            a_log_val = memref.load(A_log, [arith.as_value(i_hv)])
            dt_bias_val = memref.load(dt_bias, [arith.as_value(i_hv)])
            if dtype_str != "f32":
                dt_bias_val = flir.arith.extf(comp, arith.as_value(dt_bias_val))
            neg_exp_a = flir.arith.MulFOp(
                arith.as_value(c_neg_one),
                arith.as_value(
                    flir.math.exp2(
                        arith.as_value(
                            flir.arith.MulFOp(
                                arith.as_value(a_log_val), arith.as_value(c_log2e), fastmath=fm_fast
                            ).result
                        ),
                        fastmath=fm_fast,
                    )
                ),
                fastmath=fm_fast,
            ).result

            m5 = flir.arith.MulIOp(arith.as_value(i_n), c_T).result
            for t in range_constexpr(T_seq):
                t_idx = arith.as_value(arith.index(t))
                row_ab = arith.as_value(
                    flir.arith.AddIOp(arith.as_value(m5), t_idx).result
                )

                # Phase 0: load current V lane and compute scalar gates (g, beta).
                k_inv_norm = c_one
                v_col = arith.as_value(
                    flir.arith.AddIOp(arith.as_value(c_v_dim_off), arith.as_value(iv_safe)).result
                )
                v_val = _load_mixed_qkv_fp32(mixed_qkv, i_n, v_col, t_idx)
                v_val = arith.select(
                    arith.as_value(iv_valid),
                    arith.as_value(v_val),
                    arith.as_value(c_zero_f32),
                )

                a_val = memref.load(a, [arith.as_value(row_ab), arith.as_value(i_hv)])
                if dtype_str != "f32":
                    a_val = flir.arith.extf(comp, arith.as_value(a_val))
                # Guarded softplus: softplus(x) = ln(1+e^(beta*x))/beta for
                # beta*x <= threshold, else x (avoids exp overflow; the
                # unselected inf result is discarded by the select below).
                x = flir.arith.AddFOp(
                    arith.as_value(a_val), arith.as_value(dt_bias_val), fastmath=fm_fast
                ).result
                beta_x = flir.arith.MulFOp(
                    arith.as_value(c_softplus_beta), arith.as_value(x), fastmath=fm_fast
                ).result
                x_log2e = flir.arith.MulFOp(
                    arith.as_value(beta_x), arith.as_value(c_log2e), fastmath=fm_fast
                ).result
                exp_x = flir.math.exp2(arith.as_value(x_log2e), fastmath=fm_fast)
                one_plus_exp_x = flir.arith.AddFOp(
                    arith.as_value(c_one), arith.as_value(exp_x), fastmath=fm_fast
                ).result
                log2_one_plus_exp_x = flir.math.log2(
                    arith.as_value(one_plus_exp_x), fastmath=fm_fast
                )
                ln_one_plus_exp_x = flir.arith.MulFOp(
                    arith.as_value(log2_one_plus_exp_x),
                    arith.as_value(c_ln2),
                    fastmath=fm_fast,
                ).result
                softplus_x = flir.arith.MulFOp(
                    arith.as_value(c_inv_softplus_beta),
                    arith.as_value(ln_one_plus_exp_x),
                    fastmath=fm_fast,
                ).result
                use_exp_branch = arith.CmpFOp(
                    arith.CmpFPredicate.OLE,
                    arith.as_value(beta_x),
                    arith.as_value(c_softplus_threshold),
                ).result
                softplus_x = arith.select(
                    use_exp_branch, arith.as_value(softplus_x), arith.as_value(x)
                )
                exp_a_log_softplus = flir.arith.MulFOp(
                    arith.as_value(neg_exp_a), arith.as_value(softplus_x), fastmath=fm_fast
                ).result
                g_val = exp_a_log_softplus

                # beta = sigmoid(b_gate) = 1 / (1 + e^(-b)).
                b_val = memref.load(b_gate, [arith.as_value(row_ab), arith.as_value(i_hv)])
                if dtype_str != "f32":
                    b_val = flir.arith.extf(comp, arith.as_value(b_val))
                neg_b = flir.arith.MulFOp(
                    arith.as_value(c_neg_one), arith.as_value(b_val), fastmath=fm_fast
                ).result
                neg_b_log2e = flir.arith.MulFOp(
                    arith.as_value(neg_b), arith.as_value(c_log2e), fastmath=fm_fast
                ).result
                exp_neg_b = flir.math.exp2(arith.as_value(neg_b_log2e), fastmath=fm_fast)
                one_plus_exp_neg_b = flir.arith.AddFOp(
                    arith.as_value(c_one), arith.as_value(exp_neg_b), fastmath=fm_fast
                ).result
                beta_val = flir.arith.DivFOp(
                    arith.as_value(c_one),
                    arith.as_value(one_plus_exp_neg_b),
                    fastmath=fm_fast,
                ).result

                g_log2e = flir.arith.MulFOp(
                    arith.as_value(g_val), arith.as_value(c_log2e), fastmath=fm_fast
                ).result
                exp_g = flir.math.exp2(arith.as_value(g_log2e), fastmath=fm_fast)

                # Phase 1: cooperative K load -> reduce (h·k) and optional K norm.
                # Each thread stages two consecutive K elements into LDS.
                k_pair_base = arith.as_value(
                    flir.arith.MulIOp(arith.as_value(tid), arith.as_value(arith.index(2))).result
                )
                k_pair_valid0 = arith.CmpIOp(
                    arith.CmpIPredicate.ult,
                    arith.as_value(k_pair_base),
                    arith.as_value(c_K_idx),
                )
                if k_pair_valid0:
                    k_col0 = arith.as_value(
                        flir.arith.AddIOp(arith.as_value(c_k_dim_off), arith.as_value(k_pair_base)).result
                    )
                    k0 = _load_mixed_qkv_fp32(mixed_qkv, i_n, k_col0, t_idx)
                    smem_kq.store(k0, [arith.as_value(k_pair_base)])
                k_pair_base1 = arith.as_value(
                    flir.arith.AddIOp(arith.as_value(k_pair_base), arith.as_value(arith.index(1))).result
                )
                k_pair_valid1 = arith.CmpIOp(
                    arith.CmpIPredicate.ult,
                    arith.as_value(k_pair_base1),
                    arith.as_value(c_K_idx),
                )
                if k_pair_valid1:
                    k_col1 = arith.as_value(
                        flir.arith.AddIOp(arith.as_value(c_k_dim_off), arith.as_value(k_pair_base1)).result
                    )
                    k1 = _load_mixed_qkv_fp32(mixed_qkv, i_n, k_col1, t_idx)
                    smem_kq.store(k1, [arith.as_value(k_pair_base1)])
                gpu.barrier()

                acc_hk = arith.constant(0.0, type=comp)
                k_sq_local = arith.constant(0.0, type=comp)
                # Keep this lane's half of K in registers for Phase 2, so the
                # LDS buffer can be reused for Q in Phase 3.
                k_vecs = [[] for _ in range_constexpr(4)]
                for kg_local in range_constexpr(K // 8):
                    k_idx0 = arith.as_value(
                        flir.arith.AddIOp(
                            arith.as_value(k_base),
                            arith.as_value(arith.index(kg_local * 4)),
                        ).result
                    )
                    k0, k1, k2, k3 = _smem_load4(smem_kq, k_idx0)
                    k_vals = [k0, k1, k2, k3]
                    for lane_comp in range_constexpr(4):
                        k_vecs[lane_comp].append(k_vals[lane_comp])
                    sum0123 = _dot4(
                        h_planes[0][kg_local],
                        h_planes[1][kg_local],
                        h_planes[2][kg_local],
                        h_planes[3][kg_local],
                        k0,
                        k1,
                        k2,
                        k3,
                        fm_fast,
                    )
                    acc_hk = flir.arith.AddFOp(
                        arith.as_value(acc_hk), arith.as_value(sum0123), fastmath=fm_fast
                    ).result
                    if use_qk_l2norm_in_kernel:
                        sq0123 = _sumsq4(k0, k1, k2, k3, fm_fast)
                        k_sq_local = flir.arith.AddFOp(
                            arith.as_value(k_sq_local), arith.as_value(sq0123), fastmath=fm_fast
                        ).result
                # Combine the two half-K partial sums across the lane pair.
                acc_hk_peer = rocdl.ds_bpermute(
                    T.i32(),
                    arith.as_value(partner_lane_bytes),
                    arith.as_value(flir.arith.bitcast(T.i32(), arith.as_value(acc_hk))),
                )
                acc_hk_peer = flir.arith.bitcast(comp, arith.as_value(acc_hk_peer))
                acc_hk = flir.arith.AddFOp(
                    arith.as_value(acc_hk), arith.as_value(acc_hk_peer), fastmath=fm_fast
                ).result
                if use_qk_l2norm_in_kernel:
                    k_sq_peer = rocdl.ds_bpermute(
                        T.i32(),
                        arith.as_value(partner_lane_bytes),
                        arith.as_value(flir.arith.bitcast(T.i32(), arith.as_value(k_sq_local))),
                    )
                    k_sq_peer = flir.arith.bitcast(comp, arith.as_value(k_sq_peer))
                    k_sq = flir.arith.AddFOp(
                        arith.as_value(k_sq_local), arith.as_value(k_sq_peer), fastmath=fm_fast
                    ).result
                    k_inv_norm = flir.math.rsqrt(
                        arith.as_value(
                            flir.arith.AddFOp(
                                arith.as_value(k_sq), arith.as_value(c_eps), fastmath=fm_fast
                            ).result
                        ),
                        fastmath=fm_fast,
                    )
                decayed_dot = flir.arith.MulFOp(
                    arith.as_value(acc_hk), arith.as_value(exp_g), fastmath=fm_fast
                ).result
                if use_qk_l2norm_in_kernel:
                    decayed_dot = flir.arith.MulFOp(
                        arith.as_value(decayed_dot), arith.as_value(k_inv_norm), fastmath=fm_fast
                    ).result
                # NOTE(review): SubFOp intentionally left without fastmath
                # here to match the original op stream — every sibling op uses
                # fm_fast; confirm whether this asymmetry is deliberate.
                v_val = flir.arith.SubFOp(
                    arith.as_value(v_val), arith.as_value(decayed_dot)
                ).result
                v_val = flir.arith.MulFOp(
                    arith.as_value(v_val), arith.as_value(beta_val), fastmath=fm_fast
                ).result

                # Phase 2: recurrent state update h = h*exp(g) + k*v
                # (register-only; reads k from k_vecs, not LDS).
                for kg_local in range_constexpr(K // 8):
                    k_vals = [
                        k_vecs[0][kg_local],
                        k_vecs[1][kg_local],
                        k_vecs[2][kg_local],
                        k_vecs[3][kg_local],
                    ]
                    if use_qk_l2norm_in_kernel:
                        for lane_comp in range_constexpr(4):
                            k_vals[lane_comp] = flir.arith.MulFOp(
                                arith.as_value(k_vals[lane_comp]),
                                arith.as_value(k_inv_norm),
                                fastmath=fm_fast,
                            ).result
                    for lane_comp in range_constexpr(4):
                        kv = flir.arith.MulFOp(
                            arith.as_value(k_vals[lane_comp]),
                            arith.as_value(v_val),
                            fastmath=fm_fast,
                        ).result
                        h_decay = flir.arith.MulFOp(
                            arith.as_value(h_planes[lane_comp][kg_local]),
                            arith.as_value(exp_g),
                            fastmath=fm_fast,
                        ).result
                        h_planes[lane_comp][kg_local] = flir.arith.AddFOp(
                            arith.as_value(h_decay),
                            arith.as_value(kv),
                            fastmath=fm_fast,
                        ).result

                # Phase 3: cooperative Q load -> reduce (h·q) and optional Q norm.
                # BUGFIX: barrier before overwriting the shared K buffer with Q.
                # Phase 1 reads smem_kq after its own barrier, but nothing
                # ordered those reads before the Q stores below; with more than
                # one wavefront per block a fast wave could clobber K values a
                # slower wave was still reading. The companion HIP kernel has
                # the matching __syncthreads() at exactly this point.
                gpu.barrier()
                if k_pair_valid0:
                    q_col0 = arith.as_value(
                        flir.arith.AddIOp(arith.as_value(c_q_dim_off), arith.as_value(k_pair_base)).result
                    )
                    q0 = _load_mixed_qkv_fp32(mixed_qkv, i_n, q_col0, t_idx)
                    smem_kq.store(q0, [arith.as_value(k_pair_base)])
                if k_pair_valid1:
                    q_col1 = arith.as_value(
                        flir.arith.AddIOp(arith.as_value(c_q_dim_off), arith.as_value(k_pair_base1)).result
                    )
                    q1 = _load_mixed_qkv_fp32(mixed_qkv, i_n, q_col1, t_idx)
                    smem_kq.store(q1, [arith.as_value(k_pair_base1)])
                gpu.barrier()

                o_acc = arith.constant(0.0, type=comp)
                q_sq_local = arith.constant(0.0, type=comp)
                for kg_local in range_constexpr(K // 8):
                    q_idx0 = arith.as_value(
                        flir.arith.AddIOp(
                            arith.as_value(k_base),
                            arith.as_value(arith.index(kg_local * 4)),
                        ).result
                    )
                    q0, q1, q2, q3 = _smem_load4(smem_kq, q_idx0)
                    # Apply 1/sqrt(K) scaling to q before the dot product.
                    q0s = flir.arith.MulFOp(
                        arith.as_value(q0), arith.as_value(c_scale), fastmath=fm_fast
                    ).result
                    q1s = flir.arith.MulFOp(
                        arith.as_value(q1), arith.as_value(c_scale), fastmath=fm_fast
                    ).result
                    q2s = flir.arith.MulFOp(
                        arith.as_value(q2), arith.as_value(c_scale), fastmath=fm_fast
                    ).result
                    q3s = flir.arith.MulFOp(
                        arith.as_value(q3), arith.as_value(c_scale), fastmath=fm_fast
                    ).result
                    sum0123 = _dot4(
                        h_planes[0][kg_local],
                        h_planes[1][kg_local],
                        h_planes[2][kg_local],
                        h_planes[3][kg_local],
                        q0s,
                        q1s,
                        q2s,
                        q3s,
                        fm_fast,
                    )
                    o_acc = flir.arith.AddFOp(
                        arith.as_value(o_acc), arith.as_value(sum0123), fastmath=fm_fast
                    ).result
                    if use_qk_l2norm_in_kernel:
                        # Norm uses the unscaled q values.
                        sq0123 = _sumsq4(q0, q1, q2, q3, fm_fast)
                        q_sq_local = flir.arith.AddFOp(
                            arith.as_value(q_sq_local), arith.as_value(sq0123), fastmath=fm_fast
                        ).result
                o_acc_peer = rocdl.ds_bpermute(
                    T.i32(),
                    arith.as_value(partner_lane_bytes),
                    arith.as_value(flir.arith.bitcast(T.i32(), arith.as_value(o_acc))),
                )
                o_acc_peer = flir.arith.bitcast(comp, arith.as_value(o_acc_peer))
                o_acc = flir.arith.AddFOp(
                    arith.as_value(o_acc), arith.as_value(o_acc_peer), fastmath=fm_fast
                ).result
                q_inv_norm = c_one
                if use_qk_l2norm_in_kernel:
                    q_sq_peer = rocdl.ds_bpermute(
                        T.i32(),
                        arith.as_value(partner_lane_bytes),
                        arith.as_value(flir.arith.bitcast(T.i32(), arith.as_value(q_sq_local))),
                    )
                    q_sq_peer = flir.arith.bitcast(comp, arith.as_value(q_sq_peer))
                    q_sq = flir.arith.AddFOp(
                        arith.as_value(q_sq_local), arith.as_value(q_sq_peer), fastmath=fm_fast
                    ).result
                    q_inv_norm = flir.math.rsqrt(
                        arith.as_value(
                            flir.arith.AddFOp(
                                arith.as_value(q_sq), arith.as_value(c_eps), fastmath=fm_fast
                            ).result
                        ),
                        fastmath=fm_fast,
                    )
                o_acc = flir.arith.MulFOp(
                    arith.as_value(o_acc), arith.as_value(q_inv_norm), fastmath=fm_fast
                ).result
                out_elem = (
                    o_acc
                    if dtype_str == "f32"
                    else flir.arith.truncf(elem_type, arith.as_value(o_acc))
                )
                # Phase 4: write output for primary lane only (both lanes of a
                # pair hold the same reduced o_acc after the bpermute add).
                if iv_valid:
                    if is_primary_lane:
                        memref.store(
                            arith.as_value(out_elem),
                            o,
                            [
                                arith.as_value(i_n),
                                arith.as_value(t_idx),
                                arith.as_value(i_hv),
                                arith.as_value(iv_safe),
                            ],
                        )

            # Persist the final recurrent state back to its swizzled slot.
            if iv_valid:
                if state_idx_nonneg:
                    for kg_local in range_constexpr(K // 8):
                        kg = arith.as_value(
                            flir.arith.AddIOp(
                                arith.as_value(kg_start), arith.as_value(arith.index(kg_local))
                            ).result
                        )
                        for lane_comp in range_constexpr(4):
                            _store_state_lane(
                                initial_state_source,
                                h_planes[lane_comp][kg_local],
                                state_idx,
                                i_hv,
                                kg,
                                iv_safe,
                                lane_comp,
                            )

        @flir.jit
        def __call__(
            self: flir.T.i64,
            mixed_qkv: lambda: T.memref(B, mixed_dim, T_seq, elem_type),
            A_log: lambda: T.memref(HV, T.f32()),
            a: lambda: T.memref(BT, HV, elem_type),
            dt_bias: lambda: T.memref(HV, elem_type),
            b_gate: lambda: T.memref(BT, HV, elem_type),
            initial_state_source: lambda: T.memref(N_STATE, HV, K // 4, V, 4, state_elem_type),
            initial_state_indices: lambda: T.memref(B, T.i32()),
            o: lambda: T.memref(B, T_seq, HV, V, elem_type),
        ):
            # Grid: y tiles V in chunks of block_threads/2 columns (two lanes
            # per column), z flattens (batch, v-head); one block per tile.
            c1 = flir.const_index(1)
            gx = flir.const_index(1)
            gy = flir.const_index((V + (block_threads // 2) - 1) // (block_threads // 2))
            gz = flir.const_index(B * HV)
            bx = flir.const_index(block_threads)
            flir.gpu_ext.LaunchFuncOp(
                [self.GPU_MODULE_NAME, KERNEL_NAME],
                grid_size=(gx, gy, gz),
                block_size=(bx, c1, c1),
                kernel_operands=[
                    mixed_qkv,
                    A_log,
                    a,
                    dt_bias,
                    b_gate,
                    initial_state_source,
                    initial_state_indices,
                    o,
                ],
            )

    return _FusedSplitGdrUpdateKSplit2()
Splitting mixed QKV tensor into Q, K, V + * 2. Sigmoid gating delta rule state update + * 3. Output computation + * + * Hardware target: AMD MI series GPU + * - 80 CUs + * - 4 SIMDs per CU + * - 64 threads per wavefront + * - Max 8 wavefronts per CU (= 2 WF/SIMD) + * - 512 VGPRs per SIMD + * + * ===================================================================== + * Swizzled State Layout: (N_states, HV, K/4, V, 4) + * ===================================================================== + * + * This layout enables BOTH cross-thread coalescing AND per-thread float4: + * + * Original (K, V): + * Thread 0 accesses: [k=0,v=0], [k=1,v=0], [k=2,v=0]... + * → stride = V floats = 512 bytes between consecutive k ❌ + * + * Swizzled (K/4, V, 4): + * Memory: [v=0,k=0:3][v=1,k=0:3][v=2,k=0:3]...[v=63,k=0:3]... + * Thread 0 loads float4 from [kg=0,v=0,0:3] → k=0,1,2,3 ✅ + * Thread 1 loads float4 from [kg=0,v=1,0:3] → k=0,1,2,3 ✅ + * All 64 threads access consecutive 1024 bytes ✅ + * + * Benefits: + * - Cross-thread coalescing: 64 threads × 16 bytes = 1024 bytes + * - Per-thread vectorization: float4 load = 4 floats in 1 instruction + * - 4x fewer memory load instructions + * + * ===================================================================== + */ + +#include +#include +#include + +#include +#include + +#include +#include + +// --------------------------------------------------------------------------- +// Helper device functions +// --------------------------------------------------------------------------- + +__device__ __forceinline__ float hip_softplus( + float x, float beta, float inv_beta, float threshold) +{ + float beta_x = beta * x; + return (beta_x <= threshold) ? 
inv_beta * logf(1.0f + expf(beta_x)) : x; +} + +__device__ __forceinline__ float hip_sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +// --------------------------------------------------------------------------- +// Vec2 bf16 load helper - loads 2 bf16 values as one uint32_t +// --------------------------------------------------------------------------- + +__device__ __forceinline__ void load_bf16x2( + float* __restrict__ smem, + const __hip_bfloat16* __restrict__ src, + int base, int K, int BK) +{ + if (base + 1 < K) { + uint32_t packed = *reinterpret_cast(src + base); + const __hip_bfloat16* v = reinterpret_cast(&packed); + smem[base] = static_cast(v[0]); + smem[base + 1] = static_cast(v[1]); + } else if (base < K) { + smem[base] = static_cast(src[base]); + if (base + 1 < BK) smem[base + 1] = 0.0f; + } else { + if (base < BK) smem[base] = 0.0f; + if (base + 1 < BK) smem[base + 1] = 0.0f; + } +} + +// --------------------------------------------------------------------------- +// Kernel: fused_split_gdr_update_kernel +// --------------------------------------------------------------------------- +// +// Template parameters: +// BK - tile size along K dimension (= head_dim, power of 2, must be divisible by 4) +// BV - block size / V-tile width (= 64, matches wavefront size) +// +// Grid: (NK, cdiv(V, BV), B * HV) NK always 1 +// Block: BV threads = 1 wavefront +// +// State layout: SWIZZLED (N_states, HV, K/4, V, 4) +// - Enables float4 vectorized loads with cross-thread coalescing +// --------------------------------------------------------------------------- + +template +__global__ __launch_bounds__(BV) +void fused_split_gdr_update_kernel( + // ----- Input: mixed QKV tensor (B, dim, T) ----- + const __hip_bfloat16* __restrict__ mixed_qkv, + // ----- Gating inputs ----- + const float* __restrict__ A_log, // (HV,) + const __hip_bfloat16* __restrict__ a, // (B*T, HV) + const __hip_bfloat16* __restrict__ dt_bias, // (HV,) + float softplus_beta, + float 
softplus_threshold, + const __hip_bfloat16* __restrict__ b_gate, // (B*T, HV) + // ----- Output ----- + __hip_bfloat16* __restrict__ o, // (B, T, HV, V) + // ----- State: SWIZZLED (N_states, HV, K/4, V, 4) ----- + float* __restrict__ h0_source, + const int32_t* __restrict__ h0_indices, // (B,) + // ----- Dimensions ----- + int T, + int key_dim, + int value_dim, + // Strides for mixed_qkv + int64_t stride_x_batch, + int64_t stride_x_dim, + int64_t stride_x_seq, + // Strides for output + int64_t stride_o_batch, + int64_t stride_o_seq, + int64_t stride_o_head, + int64_t stride_o_dim, + // Meta dimensions + int B, + int H, // num_heads_qk + int HV, // num_heads_v + int K, // head_dim (key) + int V_dim, // head_dim (value) + float scale +) { + // Compile-time check: BK must be divisible by 4 for float4 + static_assert(BK % 4 == 0, "BK must be divisible by 4 for float4 vectorization"); + constexpr int BK4 = BK / 4; // Number of float4 groups + + // ================================================================ + // 1. Thread / block indexing + // ================================================================ + const int i_v = blockIdx.y; // V-tile index (each tile = BV columns) + const int i_nh = blockIdx.z; // flattened (batch, head_v) + const int i_n = i_nh / HV; // batch index + const int i_hv = i_nh % HV; // value-head index + + const int GROUP_SIZE = HV / H; + const int i_h = i_hv / GROUP_SIZE; // corresponding QK head + + const int lane = threadIdx.x; // [0, BV) = [0, 64) + + // Global V column this thread owns + const int v_col = i_v * BV + lane; + const bool valid_v = (v_col < V_dim); + + // ================================================================ + // 2. LDS for cooperative K / Q vector sharing + // ================================================================ + __shared__ float smem[BK]; // BK floats for K or Q + + // ================================================================ + // 3. 
Per-thread hidden-state column (BK registers, stored as float4) + // ================================================================ + float4 h_vec[BK4]; + #pragma unroll + for (int i = 0; i < BK4; i++) { + h_vec[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } + + // ================================================================ + // 4. Load initial state from SWIZZLED layout: h0[idx, hv, kg, v_col, 0:4] + // Layout: (N_states, HV, K/4, V, 4) + // Address: base + idx*(HV*K4*V*4) + hv*(K4*V*4) + kg*(V*4) + v*4 + k_inner + // + // For float4 load: base + kg*(V*4) + v_col*4 + // All 64 threads loading kg=0: addresses are consecutive (coalesced!) + // ================================================================ + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + const float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + ); + + // float4 vectorized loads - each thread loads from consecutive address + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + if (kg < K4) { + // h0_base[kg * V_dim + v_col] is a float4 + h_vec[kg] = h0_base[kg * V_dim + v_col]; + } + } + } + } + + // ================================================================ + // 5. Pre-compute loop-invariant gating values + // ================================================================ + const float neg_exp_A_log = -expf(A_log[i_hv]); + const float b_dt_bias = static_cast(dt_bias[i_hv]); + const float inv_softplus_beta = 1.0f / softplus_beta; + + // ================================================================ + // 6. 
Dimension offsets inside the mixed_qkv dim axis + // ================================================================ + const int q_dim_off = i_h * K; // Q start in dim + const int k_dim_off = key_dim + i_h * K; // K start in dim + const int v_dim_off = 2 * key_dim + i_hv * V_dim; // V start in dim + + // ================================================================ + // 7. Base pointers (per-batch) + // ================================================================ + const __hip_bfloat16* x_base = + mixed_qkv + static_cast(i_n) * stride_x_batch; + + const __hip_bfloat16* p_a = + a + static_cast(i_n) * T * HV + i_hv; + const __hip_bfloat16* p_b = + b_gate + static_cast(i_n) * T * HV + i_hv; + + __hip_bfloat16* p_o = + o + static_cast(i_n) * stride_o_batch + + static_cast(i_hv) * stride_o_head; + + // Check if vec2 load is possible (decode: stride_x_dim == 1) + const bool use_vec2 = (stride_x_dim == 1); + + // ================================================================ + // 8. Main timestep loop + // ================================================================ + for (int t = 0; t < T; t++) { + const __hip_bfloat16* x_t = + x_base + static_cast(t) * stride_x_seq; + + // ------------------------------------------------------------ + // 8a. Cooperatively load K[BK] into LDS + // ------------------------------------------------------------ + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + k_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(k_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // ------------------------------------------------------------ + // 8b. 
Load this thread's V element + // ------------------------------------------------------------ + float v_local = 0.0f; + if (valid_v) { + v_local = static_cast( + x_t[static_cast(v_dim_off + v_col) * stride_x_dim]); + } + + // ------------------------------------------------------------ + // 8c. Load gating scalars for this timestep + // ------------------------------------------------------------ + const float a_t = static_cast(p_a[static_cast(t) * HV]); + const float b_t = static_cast(p_b[static_cast(t) * HV]); + + // g = -exp(A_log) * softplus(a_t + dt_bias) + const float sp = hip_softplus( + a_t + b_dt_bias, softplus_beta, inv_softplus_beta, softplus_threshold); + const float g = neg_exp_A_log * sp; + const float exp_g = expf(g); + + // beta = sigmoid(b_t) + const float beta = hip_sigmoid(b_t); + + // ------------------------------------------------------------ + // 8d. K-vector L2 norm (warp-level reduction via __shfl_xor) + // ------------------------------------------------------------ + float k_inv_norm = 1.0f; + if constexpr (USE_QK_L2NORM) { + float k_sq = 0.0f; + for (int i = lane; i < BK; i += BV) { + k_sq += smem[i] * smem[i]; + } + for (int off = 32; off >= 1; off >>= 1) { + k_sq += __shfl_xor(k_sq, off, 64); + } + k_inv_norm = rsqrtf(k_sq + 1e-6f); + } + + // ------------------------------------------------------------ + // 8e. Decay: h *= exp(g) + // ------------------------------------------------------------ + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + h_vec[kg].x *= exp_g; + h_vec[kg].y *= exp_g; + h_vec[kg].z *= exp_g; + h_vec[kg].w *= exp_g; + } + + // ------------------------------------------------------------ + // 8f. 
Delta rule: v -= dot(h, k_normed) + // ------------------------------------------------------------ + { + float dot_hk = 0.0f; + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + const int k0 = kg * 4; + dot_hk += h_vec[kg].x * smem[k0]; + dot_hk += h_vec[kg].y * smem[k0 + 1]; + dot_hk += h_vec[kg].z * smem[k0 + 2]; + dot_hk += h_vec[kg].w * smem[k0 + 3]; + } + v_local -= dot_hk * k_inv_norm; + } + + // ------------------------------------------------------------ + // 8g. Beta gating: v *= beta + // ------------------------------------------------------------ + v_local *= beta; + + // ------------------------------------------------------------ + // 8h. State update: h[i] += k_normed[i] * v + // ------------------------------------------------------------ + { + const float kv = k_inv_norm * v_local; + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + const int k0 = kg * 4; + h_vec[kg].x += smem[k0] * kv; + h_vec[kg].y += smem[k0 + 1] * kv; + h_vec[kg].z += smem[k0 + 2] * kv; + h_vec[kg].w += smem[k0 + 3] * kv; + } + } + + // ------------------------------------------------------------ + // 8i. Load Q[BK] into LDS + // ------------------------------------------------------------ + __syncthreads(); + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + q_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(q_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // ------------------------------------------------------------ + // 8j. 
Q-vector L2 norm + // ------------------------------------------------------------ + float q_inv_norm = 1.0f; + if constexpr (USE_QK_L2NORM) { + float q_sq = 0.0f; + for (int i = lane; i < BK; i += BV) { + q_sq += smem[i] * smem[i]; + } + for (int off = 32; off >= 1; off >>= 1) { + q_sq += __shfl_xor(q_sq, off, 64); + } + q_inv_norm = rsqrtf(q_sq + 1e-6f); + } + + // ------------------------------------------------------------ + // 8k. Output: o = dot(h, q_normed) * scale + // ------------------------------------------------------------ + { + float o_local = 0.0f; + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + const int k0 = kg * 4; + o_local += h_vec[kg].x * smem[k0]; + o_local += h_vec[kg].y * smem[k0 + 1]; + o_local += h_vec[kg].z * smem[k0 + 2]; + o_local += h_vec[kg].w * smem[k0 + 3]; + } + o_local *= q_inv_norm * scale; + + if (valid_v) { + p_o[static_cast(t) * stride_o_seq + + static_cast(v_col) * stride_o_dim] = + static_cast<__hip_bfloat16>(o_local); + } + } + } // end timestep loop + + // ================================================================ + // 9. Store final state to SWIZZLED layout + // ================================================================ + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + ); + + // float4 vectorized stores + #pragma unroll + for (int kg = 0; kg < BK4; kg++) { + if (kg < K4) { + h0_base[kg * V_dim + v_col] = h_vec[kg]; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Kernel: fused_split_gdr_update_kernel_ksplit2 +// +// K-split=2: each V column served by 2 adjacent lanes. +// lane 2i: h[0 : BK/2-1, v_col=i] +// lane 2i+1: h[BK/2 : BK-1, v_col=i] +// Reduction: single __shfl_xor(..., 1, 64). 
+// +// Grid: (1, cdiv(V, BV/2), B * HV) +// Block: BV threads = 1 wavefront +// State layout: SWIZZLED (N_states, HV, K/4, V, 4) +// --------------------------------------------------------------------------- + +template +__global__ +__launch_bounds__(BV) +void fused_split_gdr_update_kernel_ksplit2( + const __hip_bfloat16* __restrict__ mixed_qkv, + const float* __restrict__ A_log, + const __hip_bfloat16* __restrict__ a, + const __hip_bfloat16* __restrict__ dt_bias, + float softplus_beta, + float softplus_threshold, + const __hip_bfloat16* __restrict__ b_gate, + __hip_bfloat16* __restrict__ o, + float* __restrict__ h0_source, + const int32_t* __restrict__ h0_indices, + int T, + int key_dim, + int value_dim, + int64_t stride_x_batch, + int64_t stride_x_dim, + int64_t stride_x_seq, + int64_t stride_o_batch, + int64_t stride_o_seq, + int64_t stride_o_head, + int64_t stride_o_dim, + int B, + int H, + int HV, + int K, + int V_dim, + float scale +) { + static_assert(BK % 4 == 0, "BK must be divisible by 4"); + static_assert(BK >= 8, "BK must be >= 8 for K_SPLIT=2"); + + constexpr int BK_HALF = BK / 2; + constexpr int BK4_HALF = BK_HALF / 4; + constexpr int BV_OUT = BV / 2; + + const int i_v = blockIdx.y; + const int i_nh = blockIdx.z; + const int i_n = i_nh / HV; + const int i_hv = i_nh % HV; + + const int GROUP_SIZE = HV / H; + const int i_h = i_hv / GROUP_SIZE; + + const int lane = threadIdx.x; + + const int v_idx = lane >> 1; + const int k_split = lane & 1; + const int v_col = i_v * BV_OUT + v_idx; + const bool valid_v = (v_col < V_dim); + const int k_start = k_split * BK_HALF; + + __shared__ float smem[BK]; + + float4 h_vec[BK4_HALF]; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + h_vec[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } + + // Load initial state from SWIZZLED layout + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + const float4* h0_base = reinterpret_cast( + 
h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + ); + const int kg_start = k_start / 4; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + const int kg = kg_start + i; + if (kg < K4) { + h_vec[i] = h0_base[kg * V_dim + v_col]; + } + } + } + } + + const float neg_exp_A_log = -expf(A_log[i_hv]); + const float b_dt_bias = static_cast(dt_bias[i_hv]); + const float inv_softplus_beta = 1.0f / softplus_beta; + + const int q_dim_off = i_h * K; + const int k_dim_off = key_dim + i_h * K; + const int v_dim_off = 2 * key_dim + i_hv * V_dim; + + const __hip_bfloat16* x_base = + mixed_qkv + static_cast(i_n) * stride_x_batch; + + const __hip_bfloat16* p_a = + a + static_cast(i_n) * T * HV + i_hv; + const __hip_bfloat16* p_b = + b_gate + static_cast(i_n) * T * HV + i_hv; + + __hip_bfloat16* p_o = + o + static_cast(i_n) * stride_o_batch + + static_cast(i_hv) * stride_o_head; + + const bool use_vec2 = (stride_x_dim == 1); + + // Main timestep loop + for (int t = 0; t < T; t++) { + const __hip_bfloat16* x_t = + x_base + static_cast(t) * stride_x_seq; + + // Cooperatively load K[BK] into LDS (all 64 threads participate) + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + k_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? 
static_cast( + x_t[static_cast(k_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // Load V element (both partner threads load same value) + float v_local = 0.0f; + if (valid_v) { + v_local = static_cast( + x_t[static_cast(v_dim_off + v_col) * stride_x_dim]); + } + + // Gating scalars + const float a_t = static_cast(p_a[static_cast(t) * HV]); + const float b_t = static_cast(p_b[static_cast(t) * HV]); + + const float sp = hip_softplus( + a_t + b_dt_bias, softplus_beta, inv_softplus_beta, softplus_threshold); + const float g = neg_exp_A_log * sp; + const float exp_g = expf(g); + const float beta = hip_sigmoid(b_t); + + // Phase 1: Fused K-L2-norm + Delta dot (on UN-DECAYED h_vec) + float k_inv_norm = 1.0f; + { + float dot_partial = 0.0f; + float k_sq = 0.0f; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + const int k0 = k_start + i * 4; + const float s0 = smem[k0], s1 = smem[k0 + 1]; + const float s2 = smem[k0 + 2], s3 = smem[k0 + 3]; + dot_partial += h_vec[i].x * s0 + h_vec[i].y * s1 + + h_vec[i].z * s2 + h_vec[i].w * s3; + if constexpr (USE_QK_L2NORM) { + k_sq += s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3; + } + } + dot_partial += __shfl_xor(dot_partial, 1, 64); + if constexpr (USE_QK_L2NORM) { + k_sq += __shfl_xor(k_sq, 1, 64); + k_inv_norm = rsqrtf(k_sq + 1e-6f); + } + v_local -= dot_partial * exp_g * k_inv_norm; + } + + v_local *= beta; + + // Phase 2: Fused Decay + State update + { + const float kv = k_inv_norm * v_local; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + const int k0 = k_start + i * 4; + h_vec[i].x = h_vec[i].x * exp_g + smem[k0] * kv; + h_vec[i].y = h_vec[i].y * exp_g + smem[k0 + 1] * kv; + h_vec[i].z = h_vec[i].z * exp_g + smem[k0 + 2] * kv; + h_vec[i].w = h_vec[i].w * exp_g + smem[k0 + 3] * kv; + } + } + + // Load Q into LDS + __syncthreads(); + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + q_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += 
BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(q_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // Phase 3: Fused Q-L2-norm + Output dot + { + float out_partial = 0.0f; + float q_sq = 0.0f; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + const int k0 = k_start + i * 4; + const float s0 = smem[k0], s1 = smem[k0 + 1]; + const float s2 = smem[k0 + 2], s3 = smem[k0 + 3]; + out_partial += h_vec[i].x * s0 + h_vec[i].y * s1 + + h_vec[i].z * s2 + h_vec[i].w * s3; + if constexpr (USE_QK_L2NORM) { + q_sq += s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3; + } + } + out_partial += __shfl_xor(out_partial, 1, 64); + float q_inv_norm = 1.0f; + if constexpr (USE_QK_L2NORM) { + q_sq += __shfl_xor(q_sq, 1, 64); + q_inv_norm = rsqrtf(q_sq + 1e-6f); + } + float o_local = out_partial * q_inv_norm * scale; + + if (k_split == 0 && valid_v) { + p_o[static_cast(t) * stride_o_seq + + static_cast(v_col) * stride_o_dim] = + static_cast<__hip_bfloat16>(o_local); + } + } + } // end timestep loop + + // Store final state to SWIZZLED layout + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + ); + const int kg_start = k_start / 4; + #pragma unroll + for (int i = 0; i < BK4_HALF; i++) { + const int kg = kg_start + i; + if (kg < K4) { + h0_base[kg * V_dim + v_col] = h_vec[i]; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Kernel: fused_split_gdr_update_kernel_ksplit4 +// +// K-split=4: each V column served by 4 adjacent lanes. +// lane 4i+j: h[j*BK/4 : (j+1)*BK/4-1, v_col=i] (j=0,1,2,3) +// Reduction: __shfl_xor(...,1) then __shfl_xor(...,2). 
+// +// Grid: (1, cdiv(V, BV/4), B * HV) +// Block: BV threads = 1 wavefront +// State layout: SWIZZLED (N_states, HV, K/4, V, 4) +// --------------------------------------------------------------------------- + +template +__global__ __launch_bounds__(BV) +void fused_split_gdr_update_kernel_ksplit4( + const __hip_bfloat16* __restrict__ mixed_qkv, + const float* __restrict__ A_log, + const __hip_bfloat16* __restrict__ a, + const __hip_bfloat16* __restrict__ dt_bias, + float softplus_beta, + float softplus_threshold, + const __hip_bfloat16* __restrict__ b_gate, + __hip_bfloat16* __restrict__ o, + float* __restrict__ h0_source, + const int32_t* __restrict__ h0_indices, + int T, + int key_dim, + int value_dim, + int64_t stride_x_batch, + int64_t stride_x_dim, + int64_t stride_x_seq, + int64_t stride_o_batch, + int64_t stride_o_seq, + int64_t stride_o_head, + int64_t stride_o_dim, + int B, + int H, + int HV, + int K, + int V_dim, + float scale +) { + static_assert(BK % 4 == 0, "BK must be divisible by 4"); + static_assert(BK >= 16, "BK must be >= 16 for K_SPLIT=4"); + + constexpr int BK_QTR = BK / 4; + constexpr int BK4_QTR = BK_QTR / 4; + constexpr int BV_OUT = BV / 4; + + const int i_v = blockIdx.y; + const int i_nh = blockIdx.z; + const int i_n = i_nh / HV; + const int i_hv = i_nh % HV; + const int i_h = i_hv / (HV / H); + const int lane = threadIdx.x; + + const int v_idx = lane >> 2; + const int k_split = lane & 3; + const int v_col = i_v * BV_OUT + v_idx; + const bool valid_v = (v_col < V_dim); + const int k_start = k_split * BK_QTR; + + __shared__ float smem[BK]; + + float4 h_vec[BK4_QTR]; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + h_vec[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } + + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + const float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + 
); + const int kg_start = k_start / 4; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + const int kg = kg_start + i; + if (kg < K4) { + h_vec[i] = h0_base[kg * V_dim + v_col]; + } + } + } + } + + const float neg_exp_A_log = -expf(A_log[i_hv]); + const float b_dt_bias = static_cast(dt_bias[i_hv]); + const float inv_softplus_beta = 1.0f / softplus_beta; + + const int q_dim_off = i_h * K; + const int k_dim_off = key_dim + i_h * K; + const int v_dim_off = 2 * key_dim + i_hv * V_dim; + + const __hip_bfloat16* x_base = + mixed_qkv + static_cast(i_n) * stride_x_batch; + const __hip_bfloat16* p_a = + a + static_cast(i_n) * T * HV + i_hv; + const __hip_bfloat16* p_b = + b_gate + static_cast(i_n) * T * HV + i_hv; + __hip_bfloat16* p_o = + o + static_cast(i_n) * stride_o_batch + + static_cast(i_hv) * stride_o_head; + + const bool use_vec2 = (stride_x_dim == 1); + + for (int t = 0; t < T; t++) { + const __hip_bfloat16* x_t = + x_base + static_cast(t) * stride_x_seq; + + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + k_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(k_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + float v_local = 0.0f; + if (valid_v) { + v_local = static_cast( + x_t[static_cast(v_dim_off + v_col) * stride_x_dim]); + } + + const float a_t = static_cast(p_a[static_cast(t) * HV]); + const float b_t = static_cast(p_b[static_cast(t) * HV]); + + const float sp = hip_softplus( + a_t + b_dt_bias, softplus_beta, inv_softplus_beta, softplus_threshold); + const float g = neg_exp_A_log * sp; + const float exp_g = expf(g); + const float beta = hip_sigmoid(b_t); + + // Phase 1: Fused K-L2-norm + Delta dot (on UN-DECAYED h_vec) + // dot(h*exp_g, k) == exp_g * dot(h, k), so we skip the + // separate decay pass and fold it into the scalar factor. 
+ float k_inv_norm = 1.0f; + { + float dot_partial = 0.0f; + float k_sq = 0.0f; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + const int k0 = k_start + i * 4; + const float s0 = smem[k0], s1 = smem[k0 + 1]; + const float s2 = smem[k0 + 2], s3 = smem[k0 + 3]; + dot_partial += h_vec[i].x * s0 + h_vec[i].y * s1 + + h_vec[i].z * s2 + h_vec[i].w * s3; + if constexpr (USE_QK_L2NORM) { + k_sq += s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3; + } + } + dot_partial += __shfl_xor(dot_partial, 1, 64); + dot_partial += __shfl_xor(dot_partial, 2, 64); + if constexpr (USE_QK_L2NORM) { + k_sq += __shfl_xor(k_sq, 1, 64); + k_sq += __shfl_xor(k_sq, 2, 64); + k_inv_norm = rsqrtf(k_sq + 1e-6f); + } + v_local -= dot_partial * exp_g * k_inv_norm; + } + + v_local *= beta; + + // Phase 2: Fused Decay + State update + // h = h * exp_g + k * (k_inv_norm * v) + { + const float kv = k_inv_norm * v_local; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + const int k0 = k_start + i * 4; + h_vec[i].x = h_vec[i].x * exp_g + smem[k0] * kv; + h_vec[i].y = h_vec[i].y * exp_g + smem[k0 + 1] * kv; + h_vec[i].z = h_vec[i].z * exp_g + smem[k0 + 2] * kv; + h_vec[i].w = h_vec[i].w * exp_g + smem[k0 + 3] * kv; + } + } + + // Load Q into smem + __syncthreads(); + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + q_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? 
static_cast( + x_t[static_cast(q_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // Phase 3: Fused Q-L2-norm + Output dot + { + float out_partial = 0.0f; + float q_sq = 0.0f; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + const int k0 = k_start + i * 4; + const float s0 = smem[k0], s1 = smem[k0 + 1]; + const float s2 = smem[k0 + 2], s3 = smem[k0 + 3]; + out_partial += h_vec[i].x * s0 + h_vec[i].y * s1 + + h_vec[i].z * s2 + h_vec[i].w * s3; + if constexpr (USE_QK_L2NORM) { + q_sq += s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3; + } + } + out_partial += __shfl_xor(out_partial, 1, 64); + out_partial += __shfl_xor(out_partial, 2, 64); + float q_inv_norm = 1.0f; + if constexpr (USE_QK_L2NORM) { + q_sq += __shfl_xor(q_sq, 1, 64); + q_sq += __shfl_xor(q_sq, 2, 64); + q_inv_norm = rsqrtf(q_sq + 1e-6f); + } + float o_local = out_partial * q_inv_norm * scale; + + if (k_split == 0 && valid_v) { + p_o[static_cast(t) * stride_o_seq + + static_cast(v_col) * stride_o_dim] = + static_cast<__hip_bfloat16>(o_local); + } + } + } + + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const int K4 = K / 4; + float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K4 * V_dim * 4 + + static_cast(i_hv) * K4 * V_dim * 4 + ); + const int kg_start = k_start / 4; + #pragma unroll + for (int i = 0; i < BK4_QTR; i++) { + const int kg = kg_start + i; + if (kg < K4) { + h0_base[kg * V_dim + v_col] = h_vec[i]; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Kernel: fused_split_gdr_update_kernel_vsplit +// +// V-split: each thread handles 4 contiguous V columns and BK/8 K elements. 
+// Thread layout: 8 K-threads x 8 V-threads = 64 threads (1 wavefront) +// k_thread = lane & 7 (0..7) +// v_thread = lane >> 3 (0..7) +// Each thread owns h[k_chunk, v_group:v_group+4] = float4[BK/8] +// +// No LDS: K/Q loaded directly from global memory; L1 cache provides +// natural broadcast to 8 V-threads sharing the same K addresses. +// +// Grid: (1, cdiv(V, 32), B * HV) +// Block: 64 threads = 1 wavefront +// State layout: Standard (N_states, HV, K, V) = (N_states, HV, K, V/4, 4) +// - float4 = 4 contiguous V values for one K position (zero-copy view) +// --------------------------------------------------------------------------- + +template +__global__ __launch_bounds__(BV) +void fused_split_gdr_update_kernel_vsplit( + const __hip_bfloat16* __restrict__ mixed_qkv, + const float* __restrict__ A_log, + const __hip_bfloat16* __restrict__ a, + const __hip_bfloat16* __restrict__ dt_bias, + float softplus_beta, + float softplus_threshold, + const __hip_bfloat16* __restrict__ b_gate, + __hip_bfloat16* __restrict__ o, + float* __restrict__ h0_source, + const int32_t* __restrict__ h0_indices, + int T, + int key_dim, + int value_dim, + int64_t stride_x_batch, + int64_t stride_x_dim, + int64_t stride_x_seq, + int64_t stride_o_batch, + int64_t stride_o_seq, + int64_t stride_o_head, + int64_t stride_o_dim, + int B, + int H, + int HV, + int K, + int V_dim, + float scale +) { + static_assert(BV == 64, "vsplit requires BV=64"); + static_assert(BK >= 8 && BK % 8 == 0, "BK must be >= 8 and divisible by 8"); + + constexpr int K_THREADS = 8; + constexpr int V_THREADS = BV / K_THREADS; // 8 + constexpr int V_PER_THREAD = 4; + constexpr int K_PER_THREAD = BK / K_THREADS; // BK=128 -> 16 + constexpr int BV_OUT = V_THREADS * V_PER_THREAD; // 32 + + const int i_v = blockIdx.y; + const int i_nh = blockIdx.z; + const int i_n = i_nh / HV; + const int i_hv = i_nh % HV; + + const int GROUP_SIZE = HV / H; + const int i_h = i_hv / GROUP_SIZE; + + const int lane = threadIdx.x; + const int 
k_thread = lane & (K_THREADS - 1); + const int v_thread = lane >> 3; + + const int k_start = k_thread * K_PER_THREAD; + const int v_group = i_v * V_THREADS + v_thread; + const int v_col_start = v_group * V_PER_THREAD; + const bool valid_v = (v_col_start + V_PER_THREAD - 1 < V_dim); + + // h_vec[i] = (v0, v1, v2, v3) for K position k_start+i + float4 h_vec[K_PER_THREAD]; + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + h_vec[i] = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } + + // Load initial state from standard layout [N, HV, K, V] + // Interpreted as [N, HV, K, V/4, 4] — float4 = 4 consecutive V values + const int V4 = V_dim / 4; + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + const float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K * V_dim + + static_cast(i_hv) * K * V_dim + ); + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + const int k = k_start + i; + if (k < K) { + h_vec[i] = h0_base[static_cast(k) * V4 + v_group]; + } + } + } + } + + const float neg_exp_A_log = -expf(A_log[i_hv]); + const float b_dt_bias = static_cast(dt_bias[i_hv]); + const float inv_softplus_beta = 1.0f / softplus_beta; + + const int q_dim_off = i_h * K; + const int k_dim_off = key_dim + i_h * K; + const int v_dim_off = 2 * key_dim + i_hv * V_dim; + + const __hip_bfloat16* x_base = + mixed_qkv + static_cast(i_n) * stride_x_batch; + + const __hip_bfloat16* p_a = + a + static_cast(i_n) * T * HV + i_hv; + const __hip_bfloat16* p_b = + b_gate + static_cast(i_n) * T * HV + i_hv; + + __hip_bfloat16* p_o = + o + static_cast(i_n) * stride_o_batch + + static_cast(i_hv) * stride_o_head; + + __shared__ float smem[BK]; + const bool use_vec2 = (stride_x_dim == 1); + + for (int t = 0; t < T; t++) { + const __hip_bfloat16* x_t = + x_base + static_cast(t) * stride_x_seq; + + // Cooperatively load K[BK] into LDS (all 64 threads) + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + 
load_bf16x2(smem, x_t + k_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(k_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // Load 4 V values for this thread's V-group + float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f; + if (valid_v) { + v0 = static_cast( + x_t[static_cast(v_dim_off + v_col_start) * stride_x_dim]); + v1 = static_cast( + x_t[static_cast(v_dim_off + v_col_start + 1) * stride_x_dim]); + v2 = static_cast( + x_t[static_cast(v_dim_off + v_col_start + 2) * stride_x_dim]); + v3 = static_cast( + x_t[static_cast(v_dim_off + v_col_start + 3) * stride_x_dim]); + } + + const float a_t = static_cast(p_a[static_cast(t) * HV]); + const float b_t = static_cast(p_b[static_cast(t) * HV]); + + const float sp = hip_softplus( + a_t + b_dt_bias, softplus_beta, inv_softplus_beta, softplus_threshold); + const float g = neg_exp_A_log * sp; + const float exp_g = expf(g); + const float beta = hip_sigmoid(b_t); + + // Phase 1: K-L2-norm + Delta dot (K in smem, 1 value per iteration) + float k_inv_norm = 1.0f; + { + float d0 = 0.0f, d1 = 0.0f, d2 = 0.0f, d3 = 0.0f; + float k_sq = 0.0f; + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + const float kv = smem[k_start + i]; + d0 = fmaf(h_vec[i].x, kv, d0); + d1 = fmaf(h_vec[i].y, kv, d1); + d2 = fmaf(h_vec[i].z, kv, d2); + d3 = fmaf(h_vec[i].w, kv, d3); + if constexpr (USE_QK_L2NORM) { + k_sq = fmaf(kv, kv, k_sq); + } + } + d0 += __shfl_xor(d0, 1, 64); d0 += __shfl_xor(d0, 2, 64); d0 += __shfl_xor(d0, 4, 64); + d1 += __shfl_xor(d1, 1, 64); d1 += __shfl_xor(d1, 2, 64); d1 += __shfl_xor(d1, 4, 64); + d2 += __shfl_xor(d2, 1, 64); d2 += __shfl_xor(d2, 2, 64); d2 += __shfl_xor(d2, 4, 64); + d3 += __shfl_xor(d3, 1, 64); d3 += __shfl_xor(d3, 2, 64); d3 += __shfl_xor(d3, 4, 64); + if constexpr (USE_QK_L2NORM) { + k_sq += __shfl_xor(k_sq, 1, 64); + k_sq += __shfl_xor(k_sq, 2, 64); + k_sq += __shfl_xor(k_sq, 4, 64); + 
k_inv_norm = rsqrtf(k_sq + 1e-6f); + } + const float factor = exp_g * k_inv_norm; + v0 -= d0 * factor; + v1 -= d1 * factor; + v2 -= d2 * factor; + v3 -= d3 * factor; + } + + v0 *= beta; v1 *= beta; v2 *= beta; v3 *= beta; + + // Phase 2: Decay + State update (K still in smem) + { + const float kv0 = k_inv_norm * v0, kv1 = k_inv_norm * v1; + const float kv2 = k_inv_norm * v2, kv3 = k_inv_norm * v3; + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + const float kval = smem[k_start + i]; + h_vec[i].x = fmaf(h_vec[i].x, exp_g, kval * kv0); + h_vec[i].y = fmaf(h_vec[i].y, exp_g, kval * kv1); + h_vec[i].z = fmaf(h_vec[i].z, exp_g, kval * kv2); + h_vec[i].w = fmaf(h_vec[i].w, exp_g, kval * kv3); + } + } + + // Load Q into LDS (reuse smem) + __syncthreads(); + if (use_vec2) { + const int base = 2 * lane; + if (base < BK) { + load_bf16x2(smem, x_t + q_dim_off, base, K, BK); + } + } else { + for (int i = lane; i < BK; i += BV) { + smem[i] = (i < K) + ? static_cast( + x_t[static_cast(q_dim_off + i) * stride_x_dim]) + : 0.0f; + } + } + __syncthreads(); + + // Phase 3: Q-L2-norm + Output dot + { + float o0 = 0.0f, o1 = 0.0f, o2 = 0.0f, o3 = 0.0f; + float q_sq = 0.0f; + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + const float qv = smem[k_start + i]; + o0 = fmaf(h_vec[i].x, qv, o0); + o1 = fmaf(h_vec[i].y, qv, o1); + o2 = fmaf(h_vec[i].z, qv, o2); + o3 = fmaf(h_vec[i].w, qv, o3); + if constexpr (USE_QK_L2NORM) { + q_sq = fmaf(qv, qv, q_sq); + } + } + o0 += __shfl_xor(o0, 1, 64); o0 += __shfl_xor(o0, 2, 64); o0 += __shfl_xor(o0, 4, 64); + o1 += __shfl_xor(o1, 1, 64); o1 += __shfl_xor(o1, 2, 64); o1 += __shfl_xor(o1, 4, 64); + o2 += __shfl_xor(o2, 1, 64); o2 += __shfl_xor(o2, 2, 64); o2 += __shfl_xor(o2, 4, 64); + o3 += __shfl_xor(o3, 1, 64); o3 += __shfl_xor(o3, 2, 64); o3 += __shfl_xor(o3, 4, 64); + float q_inv_norm = 1.0f; + if constexpr (USE_QK_L2NORM) { + q_sq += __shfl_xor(q_sq, 1, 64); + q_sq += __shfl_xor(q_sq, 2, 64); + q_sq += __shfl_xor(q_sq, 4, 
64); + q_inv_norm = rsqrtf(q_sq + 1e-6f); + } + if (k_thread == 0 && valid_v) { + const float s = q_inv_norm * scale; + const int64_t o_off = static_cast(t) * stride_o_seq; + p_o[o_off + static_cast(v_col_start) * stride_o_dim] = + static_cast<__hip_bfloat16>(o0 * s); + p_o[o_off + static_cast(v_col_start + 1) * stride_o_dim] = + static_cast<__hip_bfloat16>(o1 * s); + p_o[o_off + static_cast(v_col_start + 2) * stride_o_dim] = + static_cast<__hip_bfloat16>(o2 * s); + p_o[o_off + static_cast(v_col_start + 3) * stride_o_dim] = + static_cast<__hip_bfloat16>(o3 * s); + } + } + } // end timestep loop + + // Store final state to standard layout [N, HV, K, V] + if constexpr (USE_INITIAL_STATE) { + const int32_t idx = h0_indices[i_n]; + if (idx >= 0 && valid_v) { + float4* h0_base = reinterpret_cast( + h0_source + + static_cast(idx) * HV * K * V_dim + + static_cast(i_hv) * K * V_dim + ); + #pragma unroll + for (int i = 0; i < K_PER_THREAD; i++) { + const int k = k_start + i; + if (k < K) { + h0_base[static_cast(k) * V4 + v_group] = h_vec[i]; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Host wrappers: fused_split_gdr_update_ksplit2 / ksplit4 / vsplit +// --------------------------------------------------------------------------- + +// Common launch macro for ksplit wrappers +#define LAUNCH_KS(KNAME, BK_CT, USE_INIT, USE_L2) \ + hipLaunchKernelGGL(( KNAME) \ + , dim3(grid), dim3(block), 0, stream, \ + reinterpret_cast(mixed_qkv.data_ptr()), \ + A_log.data_ptr(), \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(dt_bias.data_ptr()), \ + softplus_beta, softplus_threshold, \ + reinterpret_cast(b_gate.data_ptr()), \ + reinterpret_cast<__hip_bfloat16*>(o.data_ptr()), \ + use_initial_state \ + ? 
initial_state_source.data_ptr() : nullptr, \ + initial_state_indices.data_ptr(), \ + T, key_dim, value_dim, \ + stride_x_batch, stride_x_dim, stride_x_seq, \ + stride_o_batch, stride_o_seq, stride_o_head, stride_o_dim, \ + B, H, HV, K, V, scale \ + ) + +#define DISPATCH_KS_BOOL(KNAME, BK_CT) \ + if (use_initial_state && use_qk_l2norm_in_kernel) { \ + LAUNCH_KS(KNAME, BK_CT, true, true); \ + } else if (use_initial_state && !use_qk_l2norm_in_kernel) { \ + LAUNCH_KS(KNAME, BK_CT, true, false); \ + } else if (!use_initial_state && use_qk_l2norm_in_kernel) { \ + LAUNCH_KS(KNAME, BK_CT, false, true); \ + } else { \ + LAUNCH_KS(KNAME, BK_CT, false, false); \ + } + +#define DISPATCH_KS_BK(KNAME) \ + if (bk_runtime == 128) { DISPATCH_KS_BOOL(KNAME, 128); } \ + else if (bk_runtime == 64) { DISPATCH_KS_BOOL(KNAME, 64); } \ + else if (bk_runtime == 256) { DISPATCH_KS_BOOL(KNAME, 256); } \ + else if (bk_runtime == 32) { DISPATCH_KS_BOOL(KNAME, 32); } \ + else { TORCH_CHECK(false, "Unsupported BK: ", bk_runtime); } + +#define DEFINE_KSPLIT_WRAPPER(FUNC_NAME, KERNEL_NAME, KS_VAL) \ +torch::Tensor FUNC_NAME( \ + torch::Tensor mixed_qkv, torch::Tensor A_log, torch::Tensor a, \ + torch::Tensor dt_bias, torch::Tensor b_gate, \ + torch::Tensor initial_state_source, torch::Tensor initial_state_indices, \ + int key_dim, int value_dim, int num_heads_qk, int num_heads_v, \ + int head_dim, float softplus_beta, float softplus_threshold, float scale, \ + bool use_qk_l2norm_in_kernel, c10::optional output \ +) { \ + TORCH_CHECK(mixed_qkv.dim() == 3, "mixed_qkv must be 3-D (B, dim, T)"); \ + const int B = mixed_qkv.size(0), dim = mixed_qkv.size(1), \ + T = mixed_qkv.size(2); \ + const int H = num_heads_qk, HV = num_heads_v; \ + const int K = head_dim, V = head_dim; \ + TORCH_CHECK(dim == 2 * key_dim + value_dim, "mixed_qkv dim mismatch"); \ + TORCH_CHECK(K % 4 == 0, "head_dim must be divisible by 4"); \ + int bk_runtime = 1; \ + while (bk_runtime < K) bk_runtime <<= 1; \ + TORCH_CHECK((K + 
bk_runtime - 1) / bk_runtime == 1, "NK > 1 unsupported");\ + torch::Tensor o; \ + if (output.has_value()) { \ + o = output.value(); \ + TORCH_CHECK(o.size(0)==B && o.size(1)==T && o.size(2)==HV && o.size(3)==V,\ + "Output shape mismatch"); \ + } else { o = torch::empty({B, T, HV, V}, mixed_qkv.options()); } \ + bool use_initial_state = initial_state_source.defined() && \ + initial_state_source.numel() > 0; \ + constexpr int BV_VAL = 64; \ + constexpr int BV_OUT = BV_VAL / KS_VAL; \ + dim3 grid(1, (V + BV_OUT - 1) / BV_OUT, B * HV); \ + dim3 block(BV_VAL); \ + const int64_t stride_x_batch = mixed_qkv.stride(0); \ + const int64_t stride_x_dim = mixed_qkv.stride(1); \ + const int64_t stride_x_seq = mixed_qkv.stride(2); \ + const int64_t stride_o_batch = o.stride(0); \ + const int64_t stride_o_seq = o.stride(1); \ + const int64_t stride_o_head = o.stride(2); \ + const int64_t stride_o_dim = o.stride(3); \ + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); \ + DISPATCH_KS_BK(KERNEL_NAME); \ + return o; \ +} + +DEFINE_KSPLIT_WRAPPER(fused_split_gdr_update_ksplit2, + fused_split_gdr_update_kernel_ksplit2, 2) +DEFINE_KSPLIT_WRAPPER(fused_split_gdr_update_ksplit4, + fused_split_gdr_update_kernel_ksplit4, 4) +DEFINE_KSPLIT_WRAPPER(fused_split_gdr_update_vsplit, + fused_split_gdr_update_kernel_vsplit, 2) + +#undef DEFINE_KSPLIT_WRAPPER +#undef DISPATCH_KS_BK +#undef DISPATCH_KS_BOOL +#undef LAUNCH_KS + +// --------------------------------------------------------------------------- +// Host wrapper: fused_split_gdr_update +// --------------------------------------------------------------------------- +// +// NOTE: State tensor must be in SWIZZLED layout (N_states, HV, K/4, V, 4) +// Use state.reshape(N, HV, K//4, 4, V).permute(0,1,2,4,3).reshape(N, HV, K//4, V, 4) +// to convert from standard (N, HV, K, V) layout. 
+ +torch::Tensor fused_split_gdr_update( + torch::Tensor mixed_qkv, // (B, dim, T), bfloat16 + torch::Tensor A_log, // (HV,), float32 + torch::Tensor a, // (B*T, HV), bfloat16 + torch::Tensor dt_bias, // (HV,), bfloat16 + torch::Tensor b_gate, // (B*T, HV), bfloat16 + torch::Tensor initial_state_source, // (N_states, HV, K/4, V, 4), float32 - SWIZZLED! + torch::Tensor initial_state_indices, // (B,), int32 + int key_dim, + int value_dim, + int num_heads_qk, + int num_heads_v, + int head_dim, + float softplus_beta, + float softplus_threshold, + float scale, + bool use_qk_l2norm_in_kernel, + c10::optional output // optional pre-allocated output +) { + // --- Dimension extraction --- + TORCH_CHECK(mixed_qkv.dim() == 3, "mixed_qkv must be 3-D (B, dim, T)"); + const int B = mixed_qkv.size(0); + const int dim = mixed_qkv.size(1); + const int T = mixed_qkv.size(2); + + const int H = num_heads_qk; + const int HV = num_heads_v; + const int K = head_dim; + const int V = head_dim; + + TORCH_CHECK(dim == 2 * key_dim + value_dim, + "mixed_qkv dim mismatch: expected ", 2 * key_dim + value_dim, + " got ", dim); + TORCH_CHECK(K % 4 == 0, "head_dim must be divisible by 4 for swizzled layout"); + + // --- BK: next power of 2 >= K (compile-time for template) --- + int bk_runtime = 1; + while (bk_runtime < K) bk_runtime <<= 1; + + const int NK = (K + bk_runtime - 1) / bk_runtime; + TORCH_CHECK(NK == 1, "NK > 1 is not supported yet"); + + // --- Output allocation --- + torch::Tensor o; + if (output.has_value()) { + o = output.value(); + TORCH_CHECK(o.size(0) == B && o.size(1) == T && + o.size(2) == HV && o.size(3) == V, + "Output shape mismatch"); + } else { + o = torch::empty({B, T, HV, V}, mixed_qkv.options()); + } + + // --- Determine whether initial state is used --- + bool use_initial_state = initial_state_source.defined() && + initial_state_source.numel() > 0; + + // --- Configuration: BV=64 (1 wavefront per block) --- + constexpr int BV_VAL = 64; + + // --- Grid: (NK, cdiv(V, 
BV), B * HV) --- + const int grid_x = NK; + const int grid_y = (V + BV_VAL - 1) / BV_VAL; + const int grid_z = B * HV; + dim3 grid(grid_x, grid_y, grid_z); + + // Block = BV threads = 1 wavefront + dim3 block(BV_VAL); + + // --- Strides --- + const int64_t stride_x_batch = mixed_qkv.stride(0); + const int64_t stride_x_dim = mixed_qkv.stride(1); + const int64_t stride_x_seq = mixed_qkv.stride(2); + const int64_t stride_o_batch = o.stride(0); + const int64_t stride_o_seq = o.stride(1); + const int64_t stride_o_head = o.stride(2); + const int64_t stride_o_dim = o.stride(3); + + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + // --- Kernel dispatch --- + #define LAUNCH_KERNEL(BK_CT, USE_INIT, USE_L2) \ + hipLaunchKernelGGL(( fused_split_gdr_update_kernel) \ + , dim3(grid), dim3(block), 0, stream, \ + reinterpret_cast(mixed_qkv.data_ptr()), \ + A_log.data_ptr(), \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(dt_bias.data_ptr()), \ + softplus_beta, \ + softplus_threshold, \ + reinterpret_cast(b_gate.data_ptr()), \ + reinterpret_cast<__hip_bfloat16*>(o.data_ptr()), \ + use_initial_state \ + ? 
initial_state_source.data_ptr() : nullptr, \ + initial_state_indices.data_ptr(), \ + T, key_dim, value_dim, \ + stride_x_batch, stride_x_dim, stride_x_seq, \ + stride_o_batch, stride_o_seq, stride_o_head, stride_o_dim, \ + B, H, HV, K, V, scale \ + ) + + #define DISPATCH_BK(BK_CT) \ + if (use_initial_state && use_qk_l2norm_in_kernel) { \ + LAUNCH_KERNEL(BK_CT, true, true); \ + } else if (use_initial_state && !use_qk_l2norm_in_kernel) { \ + LAUNCH_KERNEL(BK_CT, true, false); \ + } else if (!use_initial_state && use_qk_l2norm_in_kernel) { \ + LAUNCH_KERNEL(BK_CT, false, true); \ + } else { \ + LAUNCH_KERNEL(BK_CT, false, false); \ + } + + if (bk_runtime == 128) { + DISPATCH_BK(128); + } else if (bk_runtime == 64) { + DISPATCH_BK(64); + } else if (bk_runtime == 256) { + DISPATCH_BK(256); + } else if (bk_runtime == 32) { + DISPATCH_BK(32); + } else { + TORCH_CHECK(false, "Unsupported BK value: ", bk_runtime, + ". Supported: 32, 64, 128, 256"); + } + + #undef DISPATCH_BK + #undef LAUNCH_KERNEL + + return o; +} + +// --------------------------------------------------------------------------- +// PyBind11 / Torch extension binding +// --------------------------------------------------------------------------- + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "fused_split_gdr_update", + &fused_split_gdr_update, + "Fused Split GDR (Gating Delta Rule) decode update (HIP)\n" + "NOTE: State must be in SWIZZLED layout (N_states, HV, K/4, V, 4)", + py::arg("mixed_qkv"), + py::arg("A_log"), + py::arg("a"), + py::arg("dt_bias"), + py::arg("b_gate"), + py::arg("initial_state_source"), + py::arg("initial_state_indices"), + py::arg("key_dim"), + py::arg("value_dim"), + py::arg("num_heads_qk"), + py::arg("num_heads_v"), + py::arg("head_dim"), + py::arg("softplus_beta") = 1.0f, + py::arg("softplus_threshold") = 20.0f, + py::arg("scale") = -1.0f, + py::arg("use_qk_l2norm_in_kernel") = true, + py::arg("output") = c10::nullopt + ); + m.def( + "fused_split_gdr_update_ksplit2", + 
&fused_split_gdr_update_ksplit2, + "Fused Split GDR decode update with K-split=2 (HIP)\n" + "K dimension split across 2 threads per V column.\n" + "NOTE: State must be in SWIZZLED layout (N_states, HV, K/4, V, 4)", + py::arg("mixed_qkv"), + py::arg("A_log"), + py::arg("a"), + py::arg("dt_bias"), + py::arg("b_gate"), + py::arg("initial_state_source"), + py::arg("initial_state_indices"), + py::arg("key_dim"), + py::arg("value_dim"), + py::arg("num_heads_qk"), + py::arg("num_heads_v"), + py::arg("head_dim"), + py::arg("softplus_beta") = 1.0f, + py::arg("softplus_threshold") = 20.0f, + py::arg("scale") = -1.0f, + py::arg("use_qk_l2norm_in_kernel") = true, + py::arg("output") = c10::nullopt + ); + m.def( + "fused_split_gdr_update_ksplit4", + &fused_split_gdr_update_ksplit4, + "Fused Split GDR decode update with K-split=4 (HIP)\n" + "K dimension split across 4 threads per V column.\n" + "NOTE: State must be in SWIZZLED layout (N_states, HV, K/4, V, 4)", + py::arg("mixed_qkv"), + py::arg("A_log"), + py::arg("a"), + py::arg("dt_bias"), + py::arg("b_gate"), + py::arg("initial_state_source"), + py::arg("initial_state_indices"), + py::arg("key_dim"), + py::arg("value_dim"), + py::arg("num_heads_qk"), + py::arg("num_heads_v"), + py::arg("head_dim"), + py::arg("softplus_beta") = 1.0f, + py::arg("softplus_threshold") = 20.0f, + py::arg("scale") = -1.0f, + py::arg("use_qk_l2norm_in_kernel") = true, + py::arg("output") = c10::nullopt + ); + m.def( + "fused_split_gdr_update_vsplit", + &fused_split_gdr_update_vsplit, + "Fused Split GDR decode update with V-split (HIP)\n" + "8x8 thread layout: 8 K-threads x 8 V-threads, no LDS.\n" + "NOTE: State must be in VSPLIT layout (N_states, HV, V/4, K, 4)", + py::arg("mixed_qkv"), + py::arg("A_log"), + py::arg("a"), + py::arg("dt_bias"), + py::arg("b_gate"), + py::arg("initial_state_source"), + py::arg("initial_state_indices"), + py::arg("key_dim"), + py::arg("value_dim"), + py::arg("num_heads_qk"), + py::arg("num_heads_v"), + 
py::arg("head_dim"), + py::arg("softplus_beta") = 1.0f, + py::arg("softplus_threshold") = 20.0f, + py::arg("scale") = -1.0f, + py::arg("use_qk_l2norm_in_kernel") = true, + py::arg("output") = c10::nullopt + ); +} diff --git a/kernels/split_gdr_triton.py b/kernels/split_gdr_triton.py new file mode 100644 index 00000000..ea10c423 --- /dev/null +++ b/kernels/split_gdr_triton.py @@ -0,0 +1,212 @@ +"""Triton fused sigmoid-gating delta-rule update kernel.""" + +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit(do_not_specialize=["T"]) +def fused_sigmoid_gating_delta_rule_update_kernel( + A_log, + a, + dt_bias, + softplus_beta, + softplus_threshold, + q, + k, + v, + b, + o, + h0_source, + h0_indices, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_KDA: tl.constexpr, +): + """Fused kernel combining sigmoid gating and recurrent delta-rule update.""" + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all_tokens = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all_tokens = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + p_b = b + bos * HV + i_hv + p_o = o + ((i_k * all_tokens + bos) * HV + i_hv) * V + o_v + + p_A_log = A_log + i_hv + if IS_KDA: + p_a = a + (bos * HV + i_hv) * K + o_k + p_dt_bias = dt_bias + i_hv * K + o_k + else: + p_a = a + bos * HV + i_hv + p_dt_bias = dt_bias + i_hv + + mask_k = o_k < K + mask_v = o_v < V + 
mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: + p_h0 = ( + h0_source + + idx * HV * K * V + + i_hv * K * V + + o_k[:, None] * V + + o_v[None, :] + ) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b).to(tl.float32) + + b_A_log = tl.load(p_A_log).to(tl.float32) + b_a = tl.load(p_a).to(tl.float32) + b_dt_bias = tl.load(p_dt_bias).to(tl.float32) + + x = b_a + b_dt_bias + beta_x = softplus_beta * x + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6)) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6)) + + b_q = b_q * scale + + if IS_KDA: + b_h *= tl.exp(b_g[:, None]) + else: + b_h *= tl.exp(b_g) + + b_v -= tl.sum(b_h * b_k[:, None], 0) + b_v *= b_beta + b_h += b_k[:, None] * b_v[None, :] + + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_b += HV + p_a += HV + + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n) + if idx >= 0: + p_h0 = ( + h0_source + + idx * HV * K * V + + i_hv * K * V + + o_k[:, None] * V + + o_v[None, :] + ) + tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + + +def fused_sigmoid_gating_delta_rule_update( + o: torch.Tensor, + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + softplus_beta: float, + softplus_threshold: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b: torch.Tensor, + initial_state_source: torch.Tensor, + 
initial_state_indices: torch.Tensor, + scale: Optional[float] = None, + use_qk_l2norm_in_kernel: bool = True, + cu_seqlens: Optional[torch.Tensor] = None, + is_kda: bool = False, +): + """Launch fused Triton implementation for sigmoid-gating delta-rule update.""" + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + + num_stages = 3 + num_warps = 1 + + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + + grid = (NK, NV, N * HV) + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + o=o, + h0_source=initial_state_source, + h0_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state_source is not None, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + IS_VARLEN=cu_seqlens is not None, + IS_KDA=is_kda, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/tests/kernels/run_test_split_gdr_ksplit2.sh b/tests/kernels/run_test_split_gdr_ksplit2.sh new file mode 100755 index 00000000..700a3537 --- /dev/null +++ b/tests/kernels/run_test_split_gdr_ksplit2.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." 
&& pwd)" +cd "${REPO_ROOT}" + +DEFAULT_TARGET="tests/kernels/test_fused_split_gdr_update_ksplit2_flyc.py::test_split_gdr_ksplit2_correctness_and_perf" +TARGET="${1:-${DEFAULT_TARGET}}" + +if [[ $# -gt 0 ]]; then + shift +fi + +# Default behavior: +# - kill stale pytest processes for this test file +# - keep a stable torch extension cache for faster reruns +# - allow opt-in isolated cache via ISOLATE_EXT=1 +if [[ "${CLEAN_STALE_PYTEST:-1}" == "1" ]]; then + pkill -f "pytest tests/kernels/test_fused_split_gdr_update_ksplit2_flyc.py" 2>/dev/null || true +fi + +if [[ "${ISOLATE_EXT:-0}" == "1" ]]; then + EXT_DIR="/tmp/torch_ext_split_gdr_$(date +%s)" +else + EXT_DIR="${TORCH_EXTENSIONS_DIR:-/tmp/torch_ext_split_gdr}" +fi + +FAULT_TIMEOUT="${FAULT_TIMEOUT:-180}" + +echo "[run_test_split_gdr_ksplit2] repo: ${REPO_ROOT}" +echo "[run_test_split_gdr_ksplit2] target: ${TARGET}" +echo "[run_test_split_gdr_ksplit2] TORCH_EXTENSIONS_DIR: ${EXT_DIR}" +echo "[run_test_split_gdr_ksplit2] faulthandler_timeout: ${FAULT_TIMEOUT}" + +exec env \ + TORCH_EXTENSIONS_DIR="${EXT_DIR}" \ + PYTHONFAULTHANDLER=1 \ + HIP_VISIBLE_DEVICES=7 rocprofv3 --hip-runtime-trace --kernel-trace --output-format csv pftrace -d trace -o flydsl_gdr_ksplit2_0311 --stats -- python3 -m pytest "${TARGET}" -v -s -o "faulthandler_timeout=${FAULT_TIMEOUT}" "$@" + # HIP_VISIBLE_DEVICES=7 rocprofv3 --hip-runtime-trace --kernel-trace --output-format csv pftrace -d trace -o flydsl_gdr_ksplit2_0311 --stats -- python3 -m pytest "${TARGET}" -v -s -o "faulthandler_timeout=${FAULT_TIMEOUT}" "$@" + diff --git a/tests/kernels/test_fused_split_gdr_update_ksplit2_flyc.py b/tests/kernels/test_fused_split_gdr_update_ksplit2_flyc.py new file mode 100644 index 00000000..fab2d64a --- /dev/null +++ b/tests/kernels/test_fused_split_gdr_update_ksplit2_flyc.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python3 +"""Correctness and smoke perf tests for FlyDSL split-GDR ksplit2 kernel.""" + +import math +import os +import sys +from pathlib import 
Path + +import pytest +import torch +from torch.utils.cpp_extension import load + +_repo = Path(__file__).resolve().parents[2] +_embedded = _repo / "build" / "python_packages" / "rocdsl" +if _embedded.exists(): + os.environ.setdefault("ROCDSL_USE_EMBEDDED_MLIR", "1") + sys.path.insert(0, str(_embedded)) +_src_py = _repo / "python" +if _src_py.exists(): + sys.path.insert(0, str(_src_py)) +sys.path.insert(0, str(_repo)) + +import flydsl + +from kernels.fused_split_gdr_update_ksplit2_flyc import ( + build_fused_split_gdr_update_ksplit2_flyc_module, +) +from kernels.split_gdr_triton import fused_sigmoid_gating_delta_rule_update + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available.", allow_module_level=True) + + +_HIP_SRC = Path("/sgl-workspace/sglang/python/sglang/srt/layers/attention/fla/split_gdr_decode_hip.hip") +_split_gdr_hip = None + + +def _get_hip_module(): + """JIT-compile and cache HIP split-GDR extension.""" + global _split_gdr_hip + if _split_gdr_hip is None: + if not _HIP_SRC.exists(): + pytest.skip(f"HIP source not found: {_HIP_SRC}", allow_module_level=True) + _split_gdr_hip = load( + name="split_gdr_hip_flydsl", + sources=[str(_HIP_SRC)], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3"], + verbose=False, + ) + return _split_gdr_hip + + +# --------------------------------------------------------------------------- +# State layout conversion utilities +# --------------------------------------------------------------------------- + +def to_swizzled_layout(state: torch.Tensor) -> torch.Tensor: + """ + Convert state from standard (N, HV, K, V) to swizzled (N, HV, K/4, V, 4) layout. 
+
+    Swizzled layout enables float4 vectorized loads with cross-thread coalescing:
+    - Each thread loads 4 consecutive K values as float4
+    - All 64 threads access consecutive addresses (1024 bytes per load)
+
+    Memory layout transformation:
+      Standard: h[n, hv, k, v] at address n*HV*K*V + hv*K*V + k*V + v
+      Swizzled: h[n, hv, kg, v, k4] at address n*HV*KG*V*4 + hv*KG*V*4 + kg*V*4 + v*4 + k4
+      where kg = k // 4, k4 = k % 4
+    """
+    N, HV, K, V = state.shape
+    assert K % 4 == 0, f"K ({K}) must be divisible by 4 for swizzled layout"
+
+    # Reshape: (N, HV, K, V) -> (N, HV, K/4, 4, V)
+    state = state.reshape(N, HV, K // 4, 4, V)
+    # Permute: (N, HV, K/4, 4, V) -> (N, HV, K/4, V, 4)
+    state = state.permute(0, 1, 2, 4, 3)
+    # Make contiguous
+    return state.contiguous()
+
+
+def from_swizzled_layout(state: torch.Tensor) -> torch.Tensor:
+    """
+    Convert state from swizzled (N, HV, K/4, V, 4) to standard (N, HV, K, V) layout.
+
+    Inverse of ``to_swizzled_layout``: undoes the permute/reshape pair and
+    returns a contiguous tensor in the standard layout.
+    """
+    N, HV, K4, V, four = state.shape
+    assert four == 4, f"Last dimension must be 4, got {four}"
+    K = K4 * 4
+
+    # Permute: (N, HV, K/4, V, 4) -> (N, HV, K/4, 4, V)
+    state = state.permute(0, 1, 2, 4, 3)
+    # Reshape: (N, HV, K/4, 4, V) -> (N, HV, K, V)
+    state = state.reshape(N, HV, K, V)
+    # Make contiguous
+    return state.contiguous()
+
+
+# ---------------------------------------------------------------------------
+# Pure PyTorch CPU reference implementation
+# ---------------------------------------------------------------------------
+
+def split_gdr_reference(
+    mixed_qkv: torch.Tensor,
+    A_log: torch.Tensor,
+    a: torch.Tensor,
+    dt_bias: torch.Tensor,
+    b: torch.Tensor,
+    initial_state_source: torch.Tensor,
+    initial_state_indices: torch.Tensor,
+    key_dim: int,
+    value_dim: int,
+    num_heads_qk: int,
+    num_heads_v: int,
+    head_dim: int,
+    softplus_beta: float = 1.0,
+    softplus_threshold: float = 20.0,
+    # annotation corrected: default is None, so the type is Optional
+    scale: float | None = None,
+    use_qk_l2norm_in_kernel: bool = True,
+):
+    """
+    Pure PyTorch reference for fused_split_gdr_update_kernel_v3.
+ + Mirrors the Triton kernel logic step-by-step on CPU in float32. + + Args: + mixed_qkv: (B, dim, T), bfloat16 — concatenated Q, K, V along dim axis. + Q region: [0, key_dim), K region: [key_dim, 2*key_dim), + V region: [2*key_dim, 2*key_dim + value_dim). + A_log: (HV,), float32 + a: (B*T, HV), bfloat16 — time-variant gating parameter + dt_bias: (HV,), bfloat16 + b: (B*T, HV), bfloat16 — beta gating parameter + initial_state_source: (N_states, HV, K, V), float32 + initial_state_indices: (B,), int32 + key_dim: total key dimension = num_heads_qk * head_dim + value_dim: total value dimension = num_heads_v * head_dim + num_heads_qk: number of QK heads (H) + num_heads_v: number of V heads (HV) + head_dim: per-head dimension (K = V = head_dim) + + Returns: + output: (B, T, HV, V), same dtype as mixed_qkv + """ + B, dim, T = mixed_qkv.shape + H = num_heads_qk + HV = num_heads_v + K = head_dim + V = head_dim + GROUP_SIZE = HV // H + + if scale is None: + scale = K ** -0.5 + + # Work in float32 on CPU + mixed_qkv_f = mixed_qkv.float().cpu() + A_log_f = A_log.float().cpu() + dt_bias_f = dt_bias.float().cpu() + a_f = a.float().cpu() # (B*T, HV) + b_f = b.float().cpu() # (B*T, HV) + + # Reshape a, b to (B, T, HV) + a_f = a_f.view(B, T, HV) + b_f = b_f.view(B, T, HV) + + # Clone initial states — indexed by initial_state_indices + # h: (B, HV, K, V) in float32 + h = torch.zeros(B, HV, K, V, dtype=torch.float32) + indices = initial_state_indices.cpu() + for n in range(B): + idx = indices[n].item() + if idx >= 0: + h[n] = initial_state_source[idx].float().cpu() + + # Split mixed_qkv along dim axis + # Q: (B, key_dim, T), K_tensor: (B, key_dim, T), V_tensor: (B, value_dim, T) + Q_all = mixed_qkv_f[:, :key_dim, :] # (B, key_dim, T) + K_all = mixed_qkv_f[:, key_dim:2*key_dim, :] # (B, key_dim, T) + V_all = mixed_qkv_f[:, 2*key_dim:, :] # (B, value_dim, T) + + output = torch.zeros(B, T, HV, V, dtype=torch.float32) + + for t in range(T): + for hv in range(HV): + i_h = hv // 
GROUP_SIZE # corresponding QK head + + # Extract per-head Q, K, V for this timestep + q_vec = Q_all[:, i_h * K:(i_h + 1) * K, t] # (B, K) + k_vec = K_all[:, i_h * K:(i_h + 1) * K, t] # (B, K) + v_vec = V_all[:, hv * V:(hv + 1) * V, t] # (B, V) + + # Gating parameters for this timestep and head + a_t = a_f[:, t, hv] # (B,) + b_t = b_f[:, t, hv] # (B,) + + # g = -exp(A_log[hv]) * softplus(a_t + dt_bias[hv]) + x = a_t + dt_bias_f[hv] # (B,) + beta_x = softplus_beta * x + softplus_x = torch.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * torch.log(1.0 + torch.exp(beta_x)), + x, + ) + g = -torch.exp(A_log_f[hv]) * softplus_x # (B,) + + # beta = sigmoid(b_t) + beta = torch.sigmoid(b_t) # (B,) + + # L2 normalization + if use_qk_l2norm_in_kernel: + q_vec = q_vec / (torch.sqrt(torch.sum(q_vec * q_vec, dim=-1, keepdim=True) + 1e-6)) + k_vec = k_vec / (torch.sqrt(torch.sum(k_vec * k_vec, dim=-1, keepdim=True) + 1e-6)) + + # Scale query + q_vec = q_vec * scale # (B, K) + + # h *= exp(g) — decay + h[:, hv, :, :] *= torch.exp(g).unsqueeze(-1).unsqueeze(-1) # (B, K, V) + + # v -= sum(h * k[:, :, None], dim=K) — delta rule + v_vec = v_vec - torch.einsum('bkv,bk->bv', h[:, hv, :, :], k_vec) # (B, V) + + # v *= beta — beta gating + v_vec = v_vec * beta.unsqueeze(-1) # (B, V) + + # h += k[:, :, None] * v[:, None, :] — state update + h[:, hv, :, :] += torch.einsum('bk,bv->bkv', k_vec, v_vec) # (B, K, V) + + # o = sum(h * q[:, :, None], dim=K) — output + o_vec = torch.einsum('bkv,bk->bv', h[:, hv, :, :], q_vec) # (B, V) + output[:, t, hv, :] = o_vec + + # Write final state back to initial_state_source (in-place, on CPU copy) + for n in range(B): + idx = indices[n].item() + if idx >= 0: + initial_state_source[idx] = h[n].to(initial_state_source.dtype).to(initial_state_source.device) + + return output.to(mixed_qkv.dtype).to(mixed_qkv.device) + + +def create_inputs( + batch_size: int, + seqlen: int, + num_heads_qk: int, + num_heads_v: int, + head_dim: int, + device: str, 
+ dtype: torch.dtype, +): + """Create test inputs (identical to TestFusedSplitGDRUpdateOpt.create_inputs).""" + key_dim = num_heads_qk * head_dim + value_dim = num_heads_v * head_dim + dim = 2 * key_dim + value_dim + + # mixed_qkv: (batch, dim, seqlen) + mixed_qkv = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype) + + # Gating parameters — A_log must be float32 + A_log = torch.randn(num_heads_v, device=device, dtype=torch.float32) + dt_bias = torch.randn(num_heads_v, device=device, dtype=dtype) + + # Time-variant gating: (batch * seqlen, num_heads_v) + a = torch.randn(batch_size * seqlen, num_heads_v, device=device, dtype=dtype) + b = torch.randn(batch_size * seqlen, num_heads_v, device=device, dtype=dtype) + + # SSM state must be float32 + # Shape: (batch + padding, num_heads_v, head_dim, head_dim) + ssm_state = torch.randn( + batch_size + 10, num_heads_v, head_dim, head_dim, + device=device, dtype=torch.float32, + ) + ssm_state_indices = torch.arange(batch_size, device=device, dtype=torch.int32) + + return { + "mixed_qkv": mixed_qkv, + "A_log": A_log, + "a": a, + "dt_bias": dt_bias, + "b": b, + "ssm_state": ssm_state, + "ssm_state_indices": ssm_state_indices, + "key_dim": key_dim, + "value_dim": value_dim, + } + +@pytest.mark.parametrize("batch_size", [64]) +@pytest.mark.parametrize("seqlen", [1]) +@pytest.mark.parametrize("num_heads_qk", [4]) +@pytest.mark.parametrize("num_heads_v", [8]) +@pytest.mark.parametrize("head_dim", [128]) +def test_split_gdr_ksplit2_correctness_and_perf( + ctx, + batch_size, + seqlen, + num_heads_qk, + num_heads_v, + head_dim, +): + """Test correctness of ksplit2 kernel against torch CPU reference.""" + torch.manual_seed(42) + device = "cuda" + dtype = torch.bfloat16 + + inputs = create_inputs( + batch_size, seqlen, num_heads_qk, num_heads_v, head_dim, device, dtype, + ) + + key_dim = inputs["key_dim"] + value_dim = inputs["value_dim"] + + softplus_beta = 1.0 + softplus_threshold = 20.0 + scale = head_dim ** -0.5 + 
use_qk_l2norm_in_kernel = True + + common_scalar_args = { + "key_dim": key_dim, + "value_dim": value_dim, + "num_heads_qk": num_heads_qk, + "num_heads_v": num_heads_v, + "head_dim": head_dim, + "softplus_beta": softplus_beta, + "softplus_threshold": softplus_threshold, + "scale": scale, + "use_qk_l2norm_in_kernel": use_qk_l2norm_in_kernel, + } + common_tensor_args = { + "mixed_qkv": inputs["mixed_qkv"], + "A_log": inputs["A_log"], + "a": inputs["a"], + "dt_bias": inputs["dt_bias"], + "b_gate": inputs["b"], + "initial_state_indices": inputs["ssm_state_indices"], + } + + # ---- Reference: pure PyTorch CPU ---- + ssm_state_ref = inputs["ssm_state"].clone() + output_ref = split_gdr_reference( + mixed_qkv=inputs["mixed_qkv"], + A_log=inputs["A_log"], + a=inputs["a"], + dt_bias=inputs["dt_bias"], + b=inputs["b"], + initial_state_source=ssm_state_ref, + initial_state_indices=inputs["ssm_state_indices"], + key_dim=key_dim, + value_dim=value_dim, + num_heads_qk=num_heads_qk, + num_heads_v=num_heads_v, + head_dim=head_dim, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + # ---- ksplit2 kernel under test ---- + hip_mod = _get_hip_module() + ssm_state_hip = inputs["ssm_state"].clone() + ssm_state_swizzled = to_swizzled_layout(ssm_state_hip) + + output_hip = hip_mod.fused_split_gdr_update_ksplit2( + **common_tensor_args, + initial_state_source=ssm_state_swizzled, + **common_scalar_args, + ) + + ssm_state_hip_final = from_swizzled_layout(ssm_state_swizzled) + + # ---- FlyDSL build module precision check ---- + fly_module = build_fused_split_gdr_update_ksplit2_flyc_module( + B=batch_size, + T_seq=seqlen, + H=num_heads_qk, + HV=num_heads_v, + K=head_dim, + V=head_dim, + N_STATE=inputs["ssm_state"].shape[0], + dtype_str=("f32" if dtype == torch.float32 else "bf16"), + BV=64, + use_qk_l2norm_in_kernel=common_scalar_args["use_qk_l2norm_in_kernel"], + ) + fly_exe = 
flydsl.compile(fly_module) + ssm_state_fly = inputs["ssm_state"].clone() + ssm_state_swizzled_fly = to_swizzled_layout(ssm_state_fly) + output_fly = torch.empty_like(output_hip) + fly_exe( + common_tensor_args["mixed_qkv"], + common_tensor_args["A_log"], + common_tensor_args["a"], + common_tensor_args["dt_bias"], + common_tensor_args["b_gate"], + ssm_state_swizzled_fly, + common_tensor_args["initial_state_indices"], + output_fly, + ) + torch.cuda.synchronize() + ssm_state_fly_final = from_swizzled_layout(ssm_state_swizzled_fly) + + # ---- Triton fused_sigmoid_gating_delta_rule_update precision check ---- + q_triton = ( + inputs["mixed_qkv"][:, :key_dim, :] + .permute(0, 2, 1) + .reshape(batch_size, seqlen, num_heads_qk, head_dim) + .contiguous() + ) + k_triton = ( + inputs["mixed_qkv"][:, key_dim:2 * key_dim, :] + .permute(0, 2, 1) + .reshape(batch_size, seqlen, num_heads_qk, head_dim) + .contiguous() + ) + v_triton = ( + inputs["mixed_qkv"][:, 2 * key_dim:, :] + .permute(0, 2, 1) + .reshape(batch_size, seqlen, num_heads_v, head_dim) + .contiguous() + ) + a_triton = inputs["a"].reshape(batch_size, seqlen, num_heads_v).contiguous() + b_triton = inputs["b"].reshape(batch_size, seqlen, num_heads_v).contiguous() + ssm_state_triton = inputs["ssm_state"].clone() + output_triton = torch.empty_like(output_hip) + fused_sigmoid_gating_delta_rule_update( + output_triton, + A_log=inputs["A_log"], + a=a_triton, + dt_bias=inputs["dt_bias"], + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q_triton, + k=k_triton, + v=v_triton, + b=b_triton, + initial_state_source=ssm_state_triton, + initial_state_indices=common_tensor_args["initial_state_indices"], + scale=scale, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + cu_seqlens=None, + ) + torch.cuda.synchronize() + + stats_tensors = { + "output_ref": output_ref, + "output_hip": output_hip, + "output_fly": output_fly, + "output_triton": output_triton, + "ssm_state_ref": ssm_state_ref, + 
"ssm_state_hip_final": ssm_state_hip_final, + "ssm_state_fly_final": ssm_state_fly_final, + "ssm_state_triton": ssm_state_triton, + } + for name, tensor in stats_tensors.items(): + tensor_f32 = tensor.float() + print(f"{name} min/max: {tensor_f32.min().item():.6f} / {tensor_f32.max().item():.6f}") + + diff_checks = [ + ("Output max diff (hip vs ref)", output_ref, output_hip, "Hip output vs ref"), + ("State max diff (hip vs ref)", ssm_state_ref, ssm_state_hip_final, "Hip state vs ref"), + ("Output max diff (fly vs ref)", output_ref, output_fly, "Fly output vs ref"), + ("State max diff (fly vs ref)", ssm_state_ref, ssm_state_fly_final, "Fly state vs ref"), + ("Output max diff (fly vs hip)", output_hip, output_fly, "Fly output vs hip"), + ("State max diff (fly vs hip)", ssm_state_hip_final, ssm_state_fly_final, "Fly state vs hip"), + ("Output max diff (triton vs ref)", output_ref, output_triton, "Triton output vs ref"), + ("State max diff (triton vs ref)", ssm_state_ref, ssm_state_triton, "Triton state vs ref"), + ] + diff_results = {} + for display_name, lhs, rhs, err_name in diff_checks: + diff = (lhs - rhs).abs().max().item() + diff_results[display_name] = (diff, err_name) + + print(f"\n{'='*70}") + print(f"Split GDR ksplit2 Correctness: batch={batch_size}, seqlen={seqlen}") + print(f" heads_qk={num_heads_qk}, heads_v={num_heads_v}, head_dim={head_dim}") + print(f"{'='*70}") + for display_name, _lhs, _rhs, _err_name in diff_checks: + print(f" {display_name}: {diff_results[display_name][0]:.6f}") + print(f"{'='*70}") + + for display_name, _lhs, _rhs, _err_name in diff_checks: + diff, err_name = diff_results[display_name] + assert diff < 1e-3, f"{err_name} diff too large: {diff}" + + # ---- Performance check: HIP vs FlyDSL vs Triton ---- + warmup = 10 + num_iters = 1000 + state_swz_template = to_swizzled_layout(inputs["ssm_state"]) + out_template = torch.empty_like(output_hip) + + def _benchmark_us(run_fn): + for _ in range(warmup): + run_fn() + 
torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + for _ in range(num_iters): + run_fn() + end_evt.record() + torch.cuda.synchronize() + return (start_evt.elapsed_time(end_evt) * 1000.0) / num_iters + + def _run_once(backend: str): + if backend == "fly": + state_swz = state_swz_template.clone() + out = out_template.clone() + fly_exe( + common_tensor_args["mixed_qkv"], + common_tensor_args["A_log"], + common_tensor_args["a"], + common_tensor_args["dt_bias"], + common_tensor_args["b_gate"], + state_swz, + common_tensor_args["initial_state_indices"], + out, + ) + return + + if backend == "hip": + state_swz = state_swz_template.clone() + _ = hip_mod.fused_split_gdr_update_ksplit2( + **common_tensor_args, + initial_state_source=state_swz, + **common_scalar_args, + ) + return + + if backend == "triton": + state_triton = inputs["ssm_state"].clone() + out = out_template.clone() + fused_sigmoid_gating_delta_rule_update( + out, + A_log=inputs["A_log"], + a=a_triton, + dt_bias=inputs["dt_bias"], + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q_triton, + k=k_triton, + v=v_triton, + b=b_triton, + initial_state_source=state_triton, + initial_state_indices=common_tensor_args["initial_state_indices"], + scale=scale, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + cu_seqlens=None, + ) + return + + raise ValueError(f"Unknown backend: {backend}") + + fly_us = _benchmark_us(lambda: _run_once("fly")) + hip_us = _benchmark_us(lambda: _run_once("hip")) + triton_us = _benchmark_us(lambda: _run_once("triton")) + speed_ratio = hip_us / fly_us if fly_us > 0 else float("inf") + hip_vs_triton = hip_us / triton_us if triton_us > 0 else float("inf") + print(f" Perf warmup/loop: {warmup}/{num_iters}") + print(f" FlyDSL time: {fly_us:.2f} us") + print(f" HIP time: {hip_us:.2f} us") + print(f" Triton time: {triton_us:.2f} us") + print(f" HIP/FlyDSL: {speed_ratio:.3f}x") 
+ print(f" HIP/Triton: {hip_vs_triton:.3f}x") + print(f" PASS — ksplit2 correctness test passed!")