"""
Corrected Sinkhorn Backward Pass (fix.md Scheme 2)
When R satisfies R @ 1 = r and R^T @ 1 = c (with r and c not necessarily all-ones vectors),
the corrected linear system for beta is:
(diag(c) - R^T diag(r)^{-1} R) beta = s_c - R^T diag(r)^{-1} s_r
where s_r = (R ⊙ dR) @ 1, s_c = (R ⊙ dR)^T @ 1.
This file implements the corrected kernel and compares with:
1. PyTorch autograd
2. The original triton_reduce.py (assumes r=1, c=1)
"""
from icecream import ic
import torch
import triton
import triton.language as tl
from tqdm import trange
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: int | None):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
dtype = torch.float64
EPS = tl.constexpr(1e-10)
eps = 1e-10
# ─────────────────────────────────────────────────────────────────────────────
# Forward: standard Sinkhorn (same as triton_reduce.py)
# ─────────────────────────────────────────────────────────────────────────────
def sinkhorn_forward(M, iters=20):
P = torch.exp(M)
R = P
for _ in range(iters):
R = R / R.sum(-2, keepdim=True).clamp(min=eps)
R = R / R.sum(-1, keepdim=True).clamp(min=eps)
return R, P
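# ─────────────────────────────────────────────────────────────────────────────
# Reference (not used by the kernels below): a minimal pure-PyTorch sketch of
# the corrected system from the module docstring, handy when debugging small
# cases. `sinkhorn_bwd_reference` is a hypothetical helper (not part of
# triton_reduce.py); it solves the dense system with a pseudo-inverse instead
# of CG, so it is only intended for small n.
# ─────────────────────────────────────────────────────────────────────────────
def sinkhorn_bwd_reference(R, dR):
    # R, dR: (seqlen, n, n)
    r = R.sum(-1, keepdim=True)                            # row sums, (seqlen, n, 1)
    c = R.sum(-2).unsqueeze(-1)                            # col sums, (seqlen, n, 1)
    RdR = R * dR
    s_r = RdR.sum(-1, keepdim=True)                        # (R ⊙ dR) @ 1
    s_c = RdR.sum(-2).unsqueeze(-1)                        # (R ⊙ dR)^T @ 1
    RT_Drinv = R.transpose(-1, -2) / r.transpose(-1, -2)   # R^T diag(r)^{-1}
    S = torch.diag_embed(c.squeeze(-1)) - RT_Drinv @ R     # diag(c) - R^T diag(r)^{-1} R
    b = s_c - RT_Drinv @ s_r
    # S has a nullspace (shift u by +const, v by -const), so use the pseudo-inverse.
    beta = torch.linalg.pinv(S) @ b
    u = (s_r - R @ beta) / r
    M_mat = u + beta.transpose(-1, -2)                     # M_ij = u_i + v_j
    return (dR - M_mat) * R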
# ─────────────────────────────────────────────────────────────────────────────
# Original matvec_S: S = I - R^T R (assumes r=1, c=1)
# ─────────────────────────────────────────────────────────────────────────────
@triton.jit
def matvec_S(R, x):
"""
S = I - R^T R, perform S @ x WITHOUT materializing RTR.
Computes: x - R^T (R x) using two matvecs.
R: (tilesize, n, n)
x: (tilesize, n, 1)
returns: (tilesize, n, 1)
"""
Rx = tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
RT = R.permute(0, 2, 1)
RTRx = tl.sum(RT * Rx.permute(0, 2, 1), axis=-1).expand_dims(-1)
return x - RTRx
# ─────────────────────────────────────────────────────────────────────────────
# Corrected matvec_S: S = diag(c) - R^T diag(r)^{-1} R
# ─────────────────────────────────────────────────────────────────────────────
@triton.jit
def matvec_S_corrected(R, x, r, c):
"""
Corrected S = diag(c) - R^T diag(r)^{-1} R.
Computes: diag(c) x - R^T (diag(r)^{-1} (R x))
R: (tilesize, n, n)
x: (tilesize, n, 1)
r: (tilesize, n, 1) row sums of R
c: (tilesize, n, 1) col sums of R
returns: (tilesize, n, 1)
"""
# R @ x -> (tilesize, n, 1)
Rx = tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
# diag(r)^{-1} @ Rx -> element-wise divide
Rx_scaled = Rx / r
# R^T @ Rx_scaled -> (tilesize, n, 1)
RT = R.permute(0, 2, 1)
RTRx = tl.sum(RT * Rx_scaled.permute(0, 2, 1), axis=-1).expand_dims(-1)
# diag(c) @ x - R^T diag(r)^{-1} R @ x
return c * x - RTRx
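# ─────────────────────────────────────────────────────────────────────────────
# Note: both matvec helpers above express the batched mat-vec R @ x as a
# broadcast multiply plus a reduction instead of tl.dot. The function below is
# a hypothetical PyTorch sanity check of that identity (assumed shapes; it is
# never called by the kernels).
# ─────────────────────────────────────────────────────────────────────────────
def _check_broadcast_matvec():
    R_t = torch.randn(4, 8, 8, dtype=torch.float64)
    x_t = torch.randn(4, 8, 1, dtype=torch.float64)
    # (R * x^T).sum(-1) computes R @ x without an explicit matmul
    lhs = (R_t * x_t.transpose(-1, -2)).sum(-1, keepdim=True)
    assert torch.allclose(lhs, R_t @ x_t)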
# ─────────────────────────────────────────────────────────────────────────────
# Corrected backward kernel
# ─────────────────────────────────────────────────────────────────────────────
@triton.autotune(
configs=[
triton.Config({"tilesize": tilesize}, num_stages=1, num_warps=num_warps)
for tilesize in [1, 2, 4, 8, 16, 32, 64]
for num_warps in [1, 2, 4, 8]
],
key=[],
)
@triton.jit
def sinkhorn_bwd_corrected_kernel(
seqlen,
out,
dout,
res,
out_stride_0,
out_stride_1,
out_stride_2,
dout_stride_0,
dout_stride_1,
dout_stride_2,
res_stride_0,
res_stride_1,
res_stride_2,
n_stream: tl.constexpr,
tilesize: tl.constexpr,
):
out_desc = tl.make_tensor_descriptor(
out,
shape=[seqlen, n_stream, n_stream],
strides=[out_stride_0, out_stride_1, out_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
dout_desc = tl.make_tensor_descriptor(
dout,
shape=[seqlen, n_stream, n_stream],
strides=[dout_stride_0, dout_stride_1, dout_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
res_desc = tl.make_tensor_descriptor(
res,
shape=[seqlen, n_stream, n_stream],
strides=[res_stride_0, res_stride_1, res_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
seq_off = tl.program_id(0) * tilesize
R = out_desc.load([seq_off, 0, 0]) # (tilesize, n, n)
RT = R.permute(0, 2, 1)
dR = dout_desc.load([seq_off, 0, 0])
# ── Compute row sums r and col sums c of R ──────────────────────────────
# r: (tilesize, n, 1), c: (tilesize, n, 1)
r = tl.sum(R, axis=-1).expand_dims(-1) # row sums
c = tl.sum(R, axis=-2).expand_dims(-1) # col sums
# ── Step 1: s_r = (R ⊙ dR) @ 1, s_c = (R ⊙ dR)^T @ 1 ────────────────
RdR = R * dR
s_r = tl.sum(RdR, axis=-1).expand_dims(-1) # (tilesize, n, 1)
s_c = tl.sum(RdR, axis=-2).expand_dims(-1) # (tilesize, n, 1)
# ── Step 2: b = s_c - R^T diag(r)^{-1} s_r ─────────────────────────────
s_r_scaled = s_r / (r + EPS) # diag(r)^{-1} s_r
RT_sr = tl.sum(RT * s_r_scaled.permute(0, 2, 1), axis=-1).expand_dims(-1)
b = s_c - RT_sr # (tilesize, n, 1)
# ── Step 3: CG to solve (diag(c) - R^T diag(r)^{-1} R) beta = b ────────
x = tl.zeros((tilesize, n_stream, 1), dtype=R.dtype)
r_cg = b # residual (using r_cg to avoid name clash with row sums r)
p = r_cg
r_normsq = tl.sum(r_cg * r_cg, axis=1, keep_dims=True)
for _ in range(n_stream):
Sp = matvec_S_corrected(R, p, r, c)
pSp = tl.sum(p * Sp, axis=1, keep_dims=True)
alpha = r_normsq / (pSp + EPS)
x += alpha * p
r_cg -= alpha * Sp
r_new_normsq = tl.sum(r_cg * r_cg, axis=1, keep_dims=True)
beta_cg = r_new_normsq / (r_normsq + EPS)
p = r_cg + beta_cg * p
r_normsq = r_new_normsq
# ── Step 4: u = diag(r)^{-1} (s_r - R beta), v = beta ─────────────────
# R @ x: (tilesize, n, n) x (tilesize, n, 1) -> (tilesize, n, 1)
Rx = tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
    u = (s_r - Rx) / (r + EPS) # (tilesize, n, 1), guard against zero row sums as in Step 2
v = x # (tilesize, n, 1)
# ── Step 5: M_ij = u_i + v_j ────────────────────────────────────────────
vt = v.reshape(tilesize, 1, n_stream)
M_mat = u + vt # broadcast -> (tilesize, n, n)
# ── Step 6: grad = (dR - M) ⊙ R ─────────────────────────────────────────
res_tile = (dR - M_mat) * R
res_desc.store([seq_off, 0, 0], res_tile)
def sinkhorn_bwd_corrected(
out: torch.Tensor,
dout: torch.Tensor,
repeat: int = 1,
warmup: bool = True,
):
seqlen = out.size(0)
n_stream = out.size(1)
res = torch.empty_like(out)
def grid(META):
return (triton.cdiv(seqlen, META["tilesize"]), 1, 1)
def _run():
sinkhorn_bwd_corrected_kernel[grid](
seqlen,
out,
dout,
res,
out.stride(0),
out.stride(1),
out.stride(2),
dout.stride(0),
dout.stride(1),
dout.stride(2),
res.stride(0),
res.stride(1),
res.stride(2),
n_stream,
)
if warmup:
for _ in trange(4, desc="warmup"):
_run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(repeat):
_run()
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
if repeat > 1:
print(f"[corrected] Kernel execution time ({repeat=}): {elapsed_ms:.3f} ms")
print(f"[corrected] Average time per iteration: {elapsed_ms / repeat:.3f} ms")
return res
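# Example usage of the corrected driver (hypothetical names/shapes; main()
# below runs the full comparison end to end):
#
#     R, _ = sinkhorn_forward(M, iters=10)            # M: (seqlen, n, n) on CUDA
#     grad_M = sinkhorn_bwd_corrected(R.detach(), dL_dR, repeat=1)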
# ─────────────────────────────────────────────────────────────────────────────
# Original backward kernel (from triton_reduce.py, verbatim)
# ─────────────────────────────────────────────────────────────────────────────
@triton.autotune(
configs=[
triton.Config({"tilesize": tilesize}, num_stages=1, num_warps=num_warps)
for tilesize in [1, 2, 4, 8, 16, 32, 64]
for num_warps in [1, 2, 4, 8]
],
key=[],
)
@triton.jit
def sinkhorn_bwd_original_kernel(
seqlen,
out,
dout,
res,
out_stride_0,
out_stride_1,
out_stride_2,
dout_stride_0,
dout_stride_1,
dout_stride_2,
res_stride_0,
res_stride_1,
res_stride_2,
n_stream: tl.constexpr,
tilesize: tl.constexpr,
):
out_desc = tl.make_tensor_descriptor(
out,
shape=[seqlen, n_stream, n_stream],
strides=[out_stride_0, out_stride_1, out_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
dout_desc = tl.make_tensor_descriptor(
dout,
shape=[seqlen, n_stream, n_stream],
strides=[dout_stride_0, dout_stride_1, dout_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
res_desc = tl.make_tensor_descriptor(
res,
shape=[seqlen, n_stream, n_stream],
strides=[res_stride_0, res_stride_1, res_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
seq_off = tl.program_id(0) * tilesize
R = out_desc.load([seq_off, 0, 0])
RT = R.permute(0, 2, 1)
dR = dout_desc.load([seq_off, 0, 0])
RdR = R * dR
s_r = tl.sum(RdR, axis=-1).expand_dims(-1)
s_c = tl.sum(RdR, axis=-2).expand_dims(-1)
b = s_c - tl.sum(RT * s_r.permute(0, 2, 1), axis=-1).expand_dims(-1)
x = tl.zeros((tilesize, n_stream, 1), dtype=R.dtype)
r = b
p = r
r_normsq = tl.sum(r * r, axis=1, keep_dims=True)
for _ in range(n_stream):
Sp = matvec_S(R, p)
pSp = tl.sum(p * Sp, axis=1, keep_dims=True)
alpha = r_normsq / (pSp + EPS)
x += alpha * p
r -= alpha * Sp
r_new_normsq = tl.sum(r * r, axis=1, keep_dims=True)
beta = r_new_normsq / (r_normsq + EPS)
p = r + beta * p
r_normsq = r_new_normsq
u = s_r - tl.sum(R * x.permute(0, 2, 1), axis=-1).expand_dims(-1)
v = x
vt = v.reshape(tilesize, 1, n_stream)
M_mat = u + vt
res_tile = (dR - M_mat) * R
res_desc.store([seq_off, 0, 0], res_tile)
def sinkhorn_bwd_original(
out: torch.Tensor,
dout: torch.Tensor,
repeat: int = 1,
warmup: bool = True,
):
seqlen = out.size(0)
n_stream = out.size(1)
res = torch.empty_like(out)
def grid(META):
return (triton.cdiv(seqlen, META["tilesize"]), 1, 1)
def _run():
sinkhorn_bwd_original_kernel[grid](
seqlen,
out,
dout,
res,
out.stride(0),
out.stride(1),
out.stride(2),
dout.stride(0),
dout.stride(1),
dout.stride(2),
res.stride(0),
res.stride(1),
res.stride(2),
n_stream,
)
if warmup:
for _ in trange(4, desc="warmup (original)"):
_run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(repeat):
_run()
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
if repeat > 1:
print(f"[original] Kernel execution time ({repeat=}): {elapsed_ms:.3f} ms")
print(f"[original] Average time per iteration: {elapsed_ms / repeat:.3f} ms")
return res
# ─────────────────────────────────────────────────────────────────────────────
# Main: accuracy comparison
# ─────────────────────────────────────────────────────────────────────────────
def main():
seqlen = 65536
n_stream = 8
iters = 10
repeat = 512
print(f"\nRunning {iters} iters...")
device = torch.device("cuda")
dist = torch.distributions.uniform.Uniform(0.0, 4.0)
M = dist.sample((seqlen, n_stream, n_stream)).to(device)
M.requires_grad_(True)
# ── Forward ──────────────────────────────────────────────────────────────
R, P = sinkhorn_forward(M, iters)
loss_weight = torch.randn_like(R)
# ── Method A: Autograd (ground truth) ────────────────────────────────────
loss_a = (R * loss_weight).sum()
loss_a.backward()
grad_autograd = M.grad.detach().clone()
grad_R = loss_weight # dL/dR
# ── Method B: Original triton_reduce.py (assumes r=1, c=1) ───────────────
grad_original = sinkhorn_bwd_original(R.detach(), grad_R, repeat=repeat, warmup=True)
# ── Method C: Corrected (fix.md scheme 2, uses actual r, c) ─────────────
grad_corrected = sinkhorn_bwd_corrected(R.detach(), grad_R, repeat=repeat, warmup=True)
# ── Accuracy comparison ───────────────────────────────────────────────────
def format_list(ls):
return [f"{x:.2e}" for x in ls]
def print_metrics(name, g1, g2):
abs_diff = (g1 - g2).abs()
rel_diff = abs_diff / (g1.abs() + 1e-12)
MAE = abs_diff.mean(dim=(-1, -2)).tolist()
max_abs_diff = abs_diff.reshape(seqlen, -1).max(-1).values.tolist()
mean_rel_diff = rel_diff.mean(dim=(-1, -2)).tolist()
max_rel_diff = rel_diff.reshape(seqlen, -1).max(-1).values.tolist()
print(f"\n[{name}]")
print(f"Max MAE = {max(MAE)}")
print(f"Max max_abs_diff = {max(max_abs_diff)}")
print(f"Max mean_rel_diff = {max(mean_rel_diff)}")
print(f"Max max_rel_diff = {max(max_rel_diff)}")
return max(MAE)
print("\nComparison of gradients dL/dM")
print("--------------------------------")
mae_orig = print_metrics("Original (r=1,c=1 assumption)", grad_autograd, grad_original)
mae_gen = print_metrics("Corrected (fix.md scheme 2)", grad_autograd, grad_corrected)
print("\nGrad (autograd) sample:\n", grad_autograd[0])
print("\nGrad (original) sample:\n", grad_original[0])
print("\nGrad (corrected) sample:\n", grad_corrected[0])
print(grad_autograd.dtype)
print(grad_corrected.dtype)
# ── Sanity check: row/col sums of R ──────────────────────────────────────
r_sums = R.sum(-1) # (seqlen, n)
c_sums = R.sum(-2) # (seqlen, n)
print(f"\nR row sums - mean={r_sums.mean():.4f}, std={r_sums.std():.4f}")
print(f"R col sums - mean={c_sums.mean():.4f}, std={c_sums.std():.4f}")
print("(If sums ≈ 1.0, original and generalized should agree)")
assert mae_gen < 5e-4, f"Corrected MAE too large: {mae_gen:.3e}"
print("\n✓ Accuracy check passed.")
if __name__ == "__main__":
main()
I have long been wondering whether the backward of the Sinkhorn step can be computed like this:
triton code
driver code
However, this relies heavily on the convergence of the forward pass, which may not be guaranteed during training.
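A minimal sketch of the kind of convergence check this suggests, assuming R comes from sinkhorn_forward above (the helper name and tolerance are illustrative, not from the original code):

def sinkhorn_converged(R, tol=1e-3):
    """Return True if R is close enough to doubly stochastic to trust the
    r = 1, c = 1 assumption of the original backward kernel."""
    row_err = (R.sum(-1) - 1.0).abs().max()
    col_err = (R.sum(-2) - 1.0).abs().max()
    return bool(max(row_err.item(), col_err.item()) < tol)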