Skip to content

Discuss the backward impl of mHC #2

@Da1sypetals

Description

@Da1sypetals

I have long been wondering whether the backward pass of the Sinkhorn step can be computed like this:

triton code
"""
Corrected Sinkhorn Backward Pass (fix.md Scheme 2)

When R satisfies R @ 1 = r and R^T @ 1 = c (r, c != 1),
the corrected linear system for beta is:
    (diag(c) - R^T diag(r)^{-1} R) beta = s_c - R^T diag(r)^{-1} s_r

where s_r = (R ⊙ dR) @ 1,  s_c = (R ⊙ dR)^T @ 1.

This file implements the corrected kernel and compares with:
1. PyTorch autograd
2. The original triton_reduce.py (assumes r=1, c=1)
"""

from icecream import ic
import torch
import triton
import triton.language as tl
from tqdm import trange


# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: int | None):
    """Return a ``size``-byte scratch buffer on the CUDA device.

    This is the callback handed to ``triton.set_allocator``. ``alignment`` and
    ``stream`` are accepted because the callback receives them, but they are
    not consulted: the buffer is a plain ``torch.empty`` int8 tensor.
    """
    scratch = torch.empty(size, dtype=torch.int8, device="cuda")
    return scratch


# Register the allocator at import time; this is a module-level side effect
# that happens as soon as the file is imported.
triton.set_allocator(alloc_fn)

# NOTE(review): `dtype` is not referenced by any code visible in this file;
# confirm whether external importers rely on it before removing.
dtype = torch.float64
# Small constant added to denominators inside the Triton kernels. Wrapped in
# tl.constexpr so that @triton.jit functions can read it as a global.
EPS = tl.constexpr(1e-10)
# The same value as a plain Python float, for host-side (PyTorch) code.
eps = 1e-10


# ─────────────────────────────────────────────────────────────────────────────
# Forward: standard Sinkhorn (same as triton_reduce.py)
# ─────────────────────────────────────────────────────────────────────────────
def sinkhorn_forward(M, iters=20, eps=1e-10):
    """Run ``iters`` rounds of Sinkhorn normalization on ``exp(M)``.

    Each round first divides by the sums over dim -2 and then by the sums
    over dim -1, so the dim -1 normalization is always the last operation
    applied.

    Args:
        M: tensor whose last two dimensions form the matrices to normalize.
        iters: number of (dim -2, dim -1) normalization rounds. With
            ``iters=0`` the returned ``R`` is the same object as ``P``.
        eps: lower bound applied to every normalizing sum before dividing,
            so an all-zero slice (e.g. after ``exp`` underflow) yields zeros
            instead of NaN. Defaults to the value the module has always used.

    Returns:
        ``(R, P)`` where ``P = exp(M)`` and ``R`` is the normalized result.
    """
    P = torch.exp(M)
    R = P
    for _ in range(iters):
        R = R / R.sum(-2, keepdim=True).clamp(min=eps)
        R = R / R.sum(-1, keepdim=True).clamp(min=eps)
    return R, P


# ─────────────────────────────────────────────────────────────────────────────
# Original matvec_S: S = I - R^T R  (assumes r=1, c=1)
# ─────────────────────────────────────────────────────────────────────────────
@triton.jit
def matvec_S(R, x):
    """
    Apply S = I - R^T R to x as two chained matrix-vector products, so the
    product R^T R is never formed.

    R: (tilesize, n, n)
    x: (tilesize, n, 1)
    returns: (tilesize, n, 1)
    """
    # y = R @ x: lay x out along the last axis, multiply, reduce that axis.
    x_row = x.permute(0, 2, 1)
    y = tl.sum(R * x_row, axis=-1).expand_dims(-1)
    # z = R^T @ y: the same pattern applied to the transposed tile.
    y_row = y.permute(0, 2, 1)
    z = tl.sum(R.permute(0, 2, 1) * y_row, axis=-1).expand_dims(-1)
    return x - z


# ─────────────────────────────────────────────────────────────────────────────
# Corrected matvec_S: S = diag(c) - R^T diag(r)^{-1} R
# ─────────────────────────────────────────────────────────────────────────────
@triton.jit
def matvec_S_corrected(R, x, r, c):
    """
    Apply S = diag(c) - R^T diag(r)^{-1} R to x without forming S, i.e.
    return c * x - R^T ((R x) / r).

    R: (tilesize, n, n)
    x: (tilesize, n, 1)
    r: (tilesize, n, 1)   row sums of R; used directly as a divisor with no
                          guard, so the caller must ensure it is non-zero
    c: (tilesize, n, 1)   col sums of R
    returns: (tilesize, n, 1)
    """
    # y = diag(r)^{-1} (R @ x): reduce R * x^T over the last axis, then divide
    # element-wise by the row sums.
    x_row = x.permute(0, 2, 1)
    y = tl.sum(R * x_row, axis=-1).expand_dims(-1) / r
    # z = R^T @ y, using the same multiply-and-reduce pattern on R^T.
    y_row = y.permute(0, 2, 1)
    z = tl.sum(R.permute(0, 2, 1) * y_row, axis=-1).expand_dims(-1)
    # diag(c) @ x is an element-wise product.
    return c * x - z


# ─────────────────────────────────────────────────────────────────────────────
# Corrected backward kernel
# ─────────────────────────────────────────────────────────────────────────────
@triton.autotune(
    configs=[
        triton.Config({"tilesize": tilesize}, num_stages=1, num_warps=num_warps)
        for tilesize in [1, 2, 4, 8, 16, 32, 64]
        for num_warps in [1, 2, 4, 8]
    ],
    key=[],
)
@triton.jit
def sinkhorn_bwd_corrected_kernel(
    seqlen,
    out,
    dout,
    res,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    dout_stride_0,
    dout_stride_1,
    dout_stride_2,
    res_stride_0,
    res_stride_1,
    res_stride_2,
    n_stream: tl.constexpr,
    tilesize: tl.constexpr,
):
    """
    Sinkhorn backward for a tile of `tilesize` matrices, using the actual row
    sums r and column sums c of R instead of assuming they equal 1.

    Solves (diag(c) - R^T diag(r)^{-1} R) beta = s_c - R^T diag(r)^{-1} s_r
    with `n_stream` conjugate-gradient steps, then writes
    (dR - (u 1^T + 1 v^T)) * R to `res`, where v = beta and
    u = diag(r)^{-1} (s_r - R beta).

    seqlen:      leading extent of out / dout / res
    out:         R, described as (seqlen, n_stream, n_stream)
    dout:        dR, same shape
    res:         output buffer, same shape
    *_stride_k:  stride of the corresponding tensor along dimension k
    n_stream:    matrix side length n (compile-time constant)
    tilesize:    matrices handled per program (chosen by the autotuner)

    Every division by the row sums goes through the same guarded value
    r + EPS. Using one divisor for the right-hand side, the operator and the
    recovery of u keeps the three formulas mutually consistent, and a row of R
    that sums to zero produces zeros rather than 0/0 = NaN.
    """
    out_desc = tl.make_tensor_descriptor(
        out,
        shape=[seqlen, n_stream, n_stream],
        strides=[out_stride_0, out_stride_1, out_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )
    dout_desc = tl.make_tensor_descriptor(
        dout,
        shape=[seqlen, n_stream, n_stream],
        strides=[dout_stride_0, dout_stride_1, dout_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )
    res_desc = tl.make_tensor_descriptor(
        res,
        shape=[seqlen, n_stream, n_stream],
        strides=[res_stride_0, res_stride_1, res_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )

    # Each program owns `tilesize` consecutive matrices along the first axis.
    seq_off = tl.program_id(0) * tilesize

    R = out_desc.load([seq_off, 0, 0])  # (tilesize, n, n)
    RT = R.permute(0, 2, 1)
    dR = dout_desc.load([seq_off, 0, 0])

    # ── Compute row sums r and col sums c of R ──────────────────────────────
    # r: (tilesize, n, 1),  c: (tilesize, n, 1)
    r = tl.sum(R, axis=-1).expand_dims(-1)  # row sums
    c = tl.sum(R, axis=-2).expand_dims(-1)  # col sums
    # Single guarded divisor shared by steps 2, 3 and 4 below.
    r_safe = r + EPS

    # ── Step 1: s_r = (R ⊙ dR) @ 1,  s_c = (R ⊙ dR)^T @ 1 ────────────────
    RdR = R * dR
    s_r = tl.sum(RdR, axis=-1).expand_dims(-1)  # (tilesize, n, 1)
    s_c = tl.sum(RdR, axis=-2).expand_dims(-1)  # (tilesize, n, 1)

    # ── Step 2: b = s_c - R^T diag(r)^{-1} s_r ─────────────────────────────
    s_r_scaled = s_r / r_safe  # diag(r)^{-1} s_r
    RT_sr = tl.sum(RT * s_r_scaled.permute(0, 2, 1), axis=-1).expand_dims(-1)
    b = s_c - RT_sr  # (tilesize, n, 1)

    # ── Step 3: CG to solve (diag(c) - R^T diag(r)^{-1} R) beta = b ────────
    # Starting from x = 0 the initial residual is b itself.
    x = tl.zeros((tilesize, n_stream, 1), dtype=R.dtype)
    r_cg = b  # residual (using r_cg to avoid name clash with row sums r)
    p = r_cg
    r_normsq = tl.sum(r_cg * r_cg, axis=1, keep_dims=True)

    for _ in range(n_stream):
        Sp = matvec_S_corrected(R, p, r_safe, c)
        pSp = tl.sum(p * Sp, axis=1, keep_dims=True)
        # EPS keeps the step finite once the residual (and p^T S p) reach ~0.
        alpha = r_normsq / (pSp + EPS)

        x += alpha * p
        r_cg -= alpha * Sp

        r_new_normsq = tl.sum(r_cg * r_cg, axis=1, keep_dims=True)
        beta_cg = r_new_normsq / (r_normsq + EPS)

        p = r_cg + beta_cg * p
        r_normsq = r_new_normsq

    # ── Step 4: u = diag(r)^{-1} (s_r - R beta),  v = beta ─────────────────
    # R @ x: (tilesize, n, n) x (tilesize, n, 1) -> (tilesize, n, 1)
    Rx = tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
    u = (s_r - Rx) / r_safe  # (tilesize, n, 1)
    v = x  # (tilesize, n, 1)

    # ── Step 5: M_ij = u_i + v_j ────────────────────────────────────────────
    vt = v.reshape(tilesize, 1, n_stream)
    M_mat = u + vt  # broadcast -> (tilesize, n, n)

    # ── Step 6: grad = (dR - M) ⊙ R ─────────────────────────────────────────
    res_tile = (dR - M_mat) * R

    res_desc.store([seq_off, 0, 0], res_tile)


def sinkhorn_bwd_corrected(
    out: torch.Tensor,
    dout: torch.Tensor,
    repeat: int = 1,
    warmup: bool = True,
):
    """Launch the corrected Sinkhorn backward kernel and time it.

    Args:
        out: forward result R; dimension 0 is passed to the kernel as
            ``seqlen`` and dimension 1 as ``n_stream``.
        dout: gradient with respect to ``out``.
        repeat: number of timed launches. Timing lines are printed only when
            this is greater than 1.
        warmup: when true, run four untimed launches first, synchronizing
            after each one.

    Returns:
        A new tensor shaped like ``out`` that the kernel wrote its result to.
    """
    seqlen, n_stream = out.size(0), out.size(1)
    res = torch.empty_like(out)

    def grid(META):
        # One program per `tilesize` matrices along dimension 0.
        return (triton.cdiv(seqlen, META["tilesize"]), 1, 1)

    def launch():
        # First three strides of out, dout and res, in that order.
        strides = [t.stride(d) for t in (out, dout, res) for d in range(3)]
        sinkhorn_bwd_corrected_kernel[grid](seqlen, out, dout, res, *strides, n_stream)

    if warmup:
        for _ in trange(4, desc="warmup"):
            launch()
            torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    tic.record()
    for _ in range(repeat):
        launch()
    toc.record()
    torch.cuda.synchronize()

    elapsed_ms = tic.elapsed_time(toc)
    if repeat > 1:
        per_iter_ms = elapsed_ms / repeat
        print(f"[corrected] Kernel execution time ({repeat=}): {elapsed_ms:.3f} ms")
        print(f"[corrected] Average time per iteration: {per_iter_ms:.3f} ms")

    return res


# ─────────────────────────────────────────────────────────────────────────────
# Original backward kernel (from triton_reduce.py, verbatim)
# ─────────────────────────────────────────────────────────────────────────────
@triton.autotune(
    configs=[
        triton.Config({"tilesize": tilesize}, num_stages=1, num_warps=num_warps)
        for tilesize in [1, 2, 4, 8, 16, 32, 64]
        for num_warps in [1, 2, 4, 8]
    ],
    key=[],
)
@triton.jit
def sinkhorn_bwd_original_kernel(
    seqlen,
    out,
    dout,
    res,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    dout_stride_0,
    dout_stride_1,
    dout_stride_2,
    res_stride_0,
    res_stride_1,
    res_stride_2,
    n_stream: tl.constexpr,
    tilesize: tl.constexpr,
):
    """
    Baseline Sinkhorn backward for a tile of `tilesize` matrices. It uses the
    operator S = I - R^T R (see matvec_S), i.e. it treats the row and column
    sums of R as 1 and never divides by them.

    Solves S x = s_c - R^T s_r with `n_stream` conjugate-gradient steps and
    writes (dR - (u 1^T + 1 v^T)) * R to `res`, where v = x and u = s_r - R x.

    seqlen:      leading extent of out / dout / res
    out:         R, described as (seqlen, n_stream, n_stream)
    dout:        dR, same shape
    res:         output buffer, same shape
    *_stride_k:  stride of the corresponding tensor along dimension k
    n_stream:    matrix side length n (compile-time constant)
    tilesize:    matrices handled per program (chosen by the autotuner)
    """
    out_desc = tl.make_tensor_descriptor(
        out,
        shape=[seqlen, n_stream, n_stream],
        strides=[out_stride_0, out_stride_1, out_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )
    dout_desc = tl.make_tensor_descriptor(
        dout,
        shape=[seqlen, n_stream, n_stream],
        strides=[dout_stride_0, dout_stride_1, dout_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )
    res_desc = tl.make_tensor_descriptor(
        res,
        shape=[seqlen, n_stream, n_stream],
        strides=[res_stride_0, res_stride_1, res_stride_2],
        block_shape=[tilesize, n_stream, n_stream],
    )

    # Each program owns `tilesize` consecutive matrices along the first axis.
    seq_off = tl.program_id(0) * tilesize

    R = out_desc.load([seq_off, 0, 0])
    RT = R.permute(0, 2, 1)
    dR = dout_desc.load([seq_off, 0, 0])

    # s_r / s_c: row-wise and column-wise sums of R ⊙ dR, each (tilesize, n, 1).
    RdR = R * dR
    s_r = tl.sum(RdR, axis=-1).expand_dims(-1)
    s_c = tl.sum(RdR, axis=-2).expand_dims(-1)

    # Right-hand side b = s_c - R^T s_r.
    b = s_c - tl.sum(RT * s_r.permute(0, 2, 1), axis=-1).expand_dims(-1)

    # Conjugate gradient from x = 0, so the initial residual is b.
    # Note: `r` here is the CG residual, not the row sums of R.
    x = tl.zeros((tilesize, n_stream, 1), dtype=R.dtype)
    r = b
    p = r
    r_normsq = tl.sum(r * r, axis=1, keep_dims=True)

    for _ in range(n_stream):
        Sp = matvec_S(R, p)
        pSp = tl.sum(p * Sp, axis=1, keep_dims=True)
        # EPS keeps the step finite when p^T S p is ~0.
        alpha = r_normsq / (pSp + EPS)

        x += alpha * p
        r -= alpha * Sp

        r_new_normsq = tl.sum(r * r, axis=1, keep_dims=True)
        beta = r_new_normsq / (r_normsq + EPS)

        p = r + beta * p
        r_normsq = r_new_normsq

    # u = s_r - R x, v = x.
    u = s_r - tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
    v = x

    # M_ij = u_i + v_j via broadcasting (tilesize, n, 1) + (tilesize, 1, n).
    vt = v.reshape(tilesize, 1, n_stream)
    M_mat = u + vt

    res_tile = (dR - M_mat) * R
    res_desc.store([seq_off, 0, 0], res_tile)


def sinkhorn_bwd_original(
    out: torch.Tensor,
    dout: torch.Tensor,
    repeat: int = 1,
    warmup: bool = True,
):
    """Launch the baseline Sinkhorn backward kernel and time it.

    Args:
        out: forward result R; dimension 0 is passed to the kernel as
            ``seqlen`` and dimension 1 as ``n_stream``.
        dout: gradient with respect to ``out``.
        repeat: number of timed launches. Timing lines are printed only when
            this is greater than 1.
        warmup: when true, run four untimed launches first, synchronizing
            after each one.

    Returns:
        A new tensor shaped like ``out`` that the kernel wrote its result to.
    """
    seqlen, n_stream = out.size(0), out.size(1)
    res = torch.empty_like(out)

    def grid(META):
        # One program per `tilesize` matrices along dimension 0.
        return (triton.cdiv(seqlen, META["tilesize"]), 1, 1)

    def launch():
        # First three strides of out, dout and res, in that order.
        strides = [t.stride(d) for t in (out, dout, res) for d in range(3)]
        sinkhorn_bwd_original_kernel[grid](seqlen, out, dout, res, *strides, n_stream)

    if warmup:
        for _ in trange(4, desc="warmup (original)"):
            launch()
            torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    tic.record()
    for _ in range(repeat):
        launch()
    toc.record()
    torch.cuda.synchronize()

    elapsed_ms = tic.elapsed_time(toc)
    if repeat > 1:
        per_iter_ms = elapsed_ms / repeat
        print(f"[original]     Kernel execution time ({repeat=}): {elapsed_ms:.3f} ms")
        print(f"[original]     Average time per iteration: {per_iter_ms:.3f} ms")

    return res


# ─────────────────────────────────────────────────────────────────────────────
# Main: accuracy comparison
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Compare both Triton backward kernels against PyTorch autograd.

    Samples random inputs on the CUDA device, runs ``sinkhorn_forward``,
    obtains the reference gradient through autograd, runs the original and
    the corrected kernels on the same ``R`` and upstream gradient, prints
    error statistics, and fails if the corrected kernel's worst per-matrix
    mean absolute error is not below 5e-4.

    Raises:
        AssertionError: if the corrected gradient is too far from autograd
            (including when the error is NaN).
    """
    seqlen = 65536
    n_stream = 8
    iters = 10
    repeat = 512

    print(f"\nRunning {iters} iters...")

    device = torch.device("cuda")
    dist = torch.distributions.uniform.Uniform(0.0, 4.0)
    M = dist.sample((seqlen, n_stream, n_stream)).to(device)
    M.requires_grad_(True)

    # ── Forward ──────────────────────────────────────────────────────────────
    R, _ = sinkhorn_forward(M, iters)
    loss_weight = torch.randn_like(R)

    # ── Method A: Autograd (ground truth) ────────────────────────────────────
    loss_a = (R * loss_weight).sum()
    loss_a.backward()
    grad_autograd = M.grad.detach().clone()

    # The loss is sum(R * loss_weight), so dL/dR is loss_weight itself.
    grad_R = loss_weight

    # ── Method B: Original triton_reduce.py (assumes r=1, c=1) ───────────────
    grad_original = sinkhorn_bwd_original(R.detach(), grad_R, repeat=repeat, warmup=True)

    # ── Method C: Corrected (fix.md scheme 2, uses actual r, c) ─────────────
    grad_corrected = sinkhorn_bwd_corrected(R.detach(), grad_R, repeat=repeat, warmup=True)

    # ── Accuracy comparison ───────────────────────────────────────────────────
    def print_metrics(name, g1, g2):
        """Print worst-case error statistics of g2 against g1; return the
        largest per-matrix mean absolute error as a Python float."""
        abs_diff = (g1 - g2).abs()
        rel_diff = abs_diff / (g1.abs() + 1e-12)

        # Reduce on the tensor instead of calling Python max() on a list:
        # Python's max() is order-dependent when NaN is present and can hide
        # it, whereas the tensor reduction propagates NaN into the result.
        worst_mae = abs_diff.mean(dim=(-1, -2)).max().item()
        worst_abs = abs_diff.max().item()
        worst_mean_rel = rel_diff.mean(dim=(-1, -2)).max().item()
        worst_rel = rel_diff.max().item()

        print(f"\n[{name}]")
        print(f"Max MAE = {worst_mae}")
        print(f"Max max_abs_diff = {worst_abs}")
        print(f"Max mean_rel_diff = {worst_mean_rel}")
        print(f"Max max_rel_diff = {worst_rel}")
        return worst_mae

    print("\nComparison of gradients dL/dM")
    print("--------------------------------")
    print_metrics("Original (r=1,c=1 assumption)", grad_autograd, grad_original)
    mae_gen = print_metrics("Corrected (fix.md scheme 2)", grad_autograd, grad_corrected)

    print("\nGrad (autograd) sample:\n", grad_autograd[0])
    print("\nGrad (original) sample:\n", grad_original[0])
    print("\nGrad (corrected) sample:\n", grad_corrected[0])
    print(grad_autograd.dtype)
    print(grad_corrected.dtype)

    # ── Sanity check: row/col sums of R ──────────────────────────────────────
    r_sums = R.sum(-1)  # (seqlen, n)
    c_sums = R.sum(-2)  # (seqlen, n)
    print(f"\nR row sums  - mean={r_sums.mean():.4f}, std={r_sums.std():.4f}")
    print(f"R col sums  - mean={c_sums.mean():.4f}, std={c_sums.std():.4f}")
    print("(If sums ≈ 1.0, original and generalized should agree)")

    # Explicit raise rather than `assert` so the gate still runs under
    # `python -O`; `not (x < t)` is also true for NaN.
    if not mae_gen < 5e-4:
        raise AssertionError(f"Corrected MAE too large: {mae_gen:.3e}")
    print("\n✓ Accuracy check passed.")


# Run the accuracy/timing comparison only when executed as a script, so that
# importing this module does not start it.
if __name__ == "__main__":
    main()
driver code
"""
Compare the backward gradients of two Sinkhorn backward implementations:
1. tl_impl.py - TileLang implementation (layer-by-layer backprop)
2. tt_impl.py - Triton implementation (CG solver for corrected linear system)

Ground truth: PyTorch autograd
"""

import argparse
import torch
import sys

from tl_impl import _mhc_sinkhorn_bwd, _mhc_sinkhorn_fwd
from tt_impl import sinkhorn_bwd_corrected, sinkhorn_forward


def sinkhorn_forward_pytorch(M, iters=10, eps=1e-10):
    """Reference Sinkhorn forward that starts from a row-wise softmax.

    ``softmax(M, dim=-1) + eps`` is first normalized along dim -2. Each of
    the remaining ``iters - 1`` rounds then normalizes along dim -1 followed
    by dim -2. ``eps`` is added to every normalizing sum before dividing.

    Returns:
        The normalized tensor, with the same shape as ``M``.
    """

    def normalize(mat, dim):
        return mat / (mat.sum(dim=dim, keepdim=True) + eps)

    R = normalize(torch.softmax(M, dim=-1) + eps, -2)
    for _ in range(iters - 1):
        R = normalize(normalize(R, -1), -2)
    return R


def main():
    """Compare the TileLang and Triton Sinkhorn backward passes to autograd.

    The TileLang gradient is checked against autograd through
    ``sinkhorn_forward_pytorch`` (softmax-based forward); the Triton gradient
    is checked against autograd through ``sinkhorn_forward`` (exp-based
    forward). ``--iter`` sets the number of Sinkhorn iterations for all of
    them. Error statistics and the first token's gradients are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=10, help="Number of Sinkhorn iterations")
    args = parser.parse_args()

    hidden_size = 8
    num_tokens = 4096
    token_block_size = 4
    repeat = args.iter
    eps = 1e-10

    device = torch.device("cuda")
    torch.manual_seed(42)

    shape = (num_tokens, hidden_size, hidden_size)
    tensor_opts = {"dtype": torch.float32, "device": device}

    M = torch.randn(shape, **tensor_opts)
    grad_output = torch.randn(shape, **tensor_opts)

    print(f"Configuration: num_tokens={num_tokens}, hidden_size={hidden_size}, repeat={repeat}, eps={eps}")
    print(f"token_block_size={token_block_size}")
    print()

    def autograd_grad(forward):
        """Differentiate sum(forward(leaf) * grad_output) w.r.t. a fresh copy
        of M; return (forward output, gradient)."""
        leaf = M.detach().clone().requires_grad_(True)
        result = forward(leaf)
        (result * grad_output).sum().backward()
        return result, leaf.grad.detach().clone()

    def report(title, got, ref):
        """Print mean/max absolute and relative error of `got` against `ref`."""
        print(title)
        abs_diff = (got - ref).abs()
        rel_diff = abs_diff / (ref.abs() + 1e-12)
        stats = (
            ("MAE  ", abs_diff.mean()),
            ("MaxAE", abs_diff.max()),
            ("MRE  ", rel_diff.mean()),
            ("MaxRE", rel_diff.max()),
        )
        for label, value in stats:
            print(f"  {label}= {value.item():.6e}")
        print()

    # PyTorch autograd (softmax forward)
    _, grad_autograd = autograd_grad(
        lambda leaf: sinkhorn_forward_pytorch(leaf, iters=repeat, eps=eps)
    )

    # TileLang
    fwd_kernel = _mhc_sinkhorn_fwd(hidden_size, token_block_size, repeat, eps)
    R_tl = torch.empty(shape, **tensor_opts)
    fwd_kernel(M.detach(), R_tl)
    torch.cuda.synchronize()

    bwd_kernel = _mhc_sinkhorn_bwd(hidden_size, token_block_size, repeat, eps)
    grad_tl = torch.empty(shape, **tensor_opts)
    bwd_kernel(grad_output, M.detach(), grad_tl)
    torch.cuda.synchronize()

    # Triton
    R_tt, grad_autograd_tt = autograd_grad(
        lambda leaf: sinkhorn_forward(leaf, iters=repeat)[0]
    )

    grad_tt = sinkhorn_bwd_corrected(R_tt.detach(), grad_output, repeat=1, warmup=False)
    torch.cuda.synchronize()

    # Comparison 1: TileLang vs PyTorch softmax
    report("=== TileLang vs PyTorch (softmax forward) ===", grad_tl, grad_autograd)

    # Comparison 2: Triton vs PyTorch exp
    report("=== Triton vs PyTorch (exp forward) ===", grad_tt, grad_autograd_tt)

    print("Sample gradient (token 0):")
    print(f"  Autograd (softmax):\n{grad_autograd[0]}")
    print(f"  TileLang:          \n{grad_tl[0]}")
    print(f"  Autograd (exp):    \n{grad_autograd_tt[0]}")
    print(f"  Triton:            \n{grad_tt[0]}")


# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()

However, this relies heavily on the forward pass having converged, which may not be guaranteed during training.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions