Add fused Triton kernel for VQ-codebook EMA update #1

Open
EdoardoBotta wants to merge 1 commit into main from claude/triton-ema-update-kernel-dSKPp

Replaces the five-step PyTorch implementation of VQQuantizer._ema_update
with a single Triton kernel pass (rectokens/kernels/ema_update.py) that:

  • Scatter-accumulates cluster_size and embed_sum directly from (x, codes)
    without materialising the (B, K) one-hot matrix, saving O(B·K) memory.
  • Fuses the active-only EMA blend, codebook refresh, dead-code counter
    increment, and dead-code replacement into one GPU kernel launch per
    codebook level.
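
The semantics being fused can be written as a plain-NumPy reference (the function name and exact active-entry policy here are illustrative; the PR's kernel additionally handles the dead-code counter and replacement):

```python
import numpy as np

def ema_update_reference(x, codes, embed, embed_avg, cluster_size,
                         decay=0.99, eps=1e-5):
    """NumPy sketch of the fused EMA codebook update.

    x: (B, D) encoder outputs; codes: (B,) assigned codebook indices;
    embed / embed_avg: (K, D); cluster_size: (K,). Dead-code handling
    is omitted for brevity.
    """
    K, D = embed.shape
    # Scatter-accumulate counts and feature sums directly from (x, codes),
    # never materialising the (B, K) one-hot matrix.
    counts = np.zeros(K, dtype=x.dtype)
    np.add.at(counts, codes, 1.0)
    sums = np.zeros((K, D), dtype=x.dtype)
    np.add.at(sums, codes, x)
    # Blend only entries that received at least one sample this step.
    active = counts > 0
    cluster_size[active] = decay * cluster_size[active] + (1 - decay) * counts[active]
    embed_avg[active] = decay * embed_avg[active] + (1 - decay) * sums[active]
    # Codebook refresh: normalised running mean.
    embed[active] = embed_avg[active] / (cluster_size[active] + eps)[:, None]
    return embed, embed_avg, cluster_size
```

The scatter-accumulate (`np.add.at`) is what the kernel replaces the one-hot matmul with: memory scales as O(K·D) state instead of O(B·K) scratch.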

A grid of shape (K,), one program per codebook entry, avoids inter-block
atomics while keeping D entirely in registers (D is tl.constexpr; Triton
recompiles one variant per unique embedding dimension).
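
Schematically, one program per codebook entry looks like the following sketch (not the PR's actual kernel: pointer names are illustrative, the dead-code path is omitted, and it needs a CUDA device to launch; BLOCK_D must be a power of two, so real code pads D and masks):

```python
import triton
import triton.language as tl

@triton.jit
def _ema_update_sketch(cluster_ptr, embed_avg_ptr, embed_ptr,
                       counts_ptr, sums_ptr, decay, eps,
                       D: tl.constexpr, BLOCK_D: tl.constexpr):
    # One program instance per codebook entry: launched with grid = (K,).
    k = tl.program_id(axis=0)
    d = tl.arange(0, BLOCK_D)   # BLOCK_D = next power of two >= D
    mask = d < D
    count = tl.load(counts_ptr + k)
    if count > 0:
        # Active entry: EMA-blend the running statistics, then refresh.
        cs = decay * tl.load(cluster_ptr + k) + (1 - decay) * count
        tl.store(cluster_ptr + k, cs)
        avg = tl.load(embed_avg_ptr + k * D + d, mask=mask)
        s = tl.load(sums_ptr + k * D + d, mask=mask)
        avg = decay * avg + (1 - decay) * s
        tl.store(embed_avg_ptr + k * D + d, avg, mask=mask)
        tl.store(embed_ptr + k * D + d, avg / (cs + eps), mask=mask)
```

Because each program touches only row k, no two programs write the same memory and no atomics are needed; the whole D-length row stays in registers for the blend and refresh.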

A Python dispatch layer (rectokens/ops/ema_update.py) selects the Triton
path on CUDA and falls back to the original pure-PyTorch logic on CPU,
preserving full correctness on CPU-only machines.
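
The dispatch pattern is simple device sniffing; a minimal sketch (names and placeholder bodies are illustrative, not the PR's API):

```python
def _ema_update_torch(x, *state):
    # Placeholder for the original five-step pure-PyTorch implementation.
    return "pytorch"

def _ema_update_triton(x, *state):
    # Placeholder for the fused Triton kernel launch.
    return "triton"

def ema_update(x, *state):
    # CUDA tensors take the fused Triton path; everything else falls back
    # to the reference PyTorch logic, so CPU-only machines stay correct.
    if getattr(x, "is_cuda", False):
        return _ema_update_triton(x, *state)
    return _ema_update_torch(x, *state)
```

Routing on the input tensor's device (rather than a global flag) keeps the two paths numerically comparable in tests, since both can be called on the same data.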

Tests in tests/test_ema_update_kernel.py cover: all-active, partially-inactive,
dead-code restart, single sample, non-power-of-two batch, large codebook,
large D, decay edge cases (0 and 1), and counter-reset semantics.
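
The decay edge cases are worth spelling out: for an EMA blend of the form decay·old + (1 − decay)·batch, decay = 1 freezes the running statistics and decay = 0 overwrites them with the current batch. A one-line illustration:

```python
def ema_blend(old, batch, decay):
    # The EMA update applied to cluster_size and embed_sum.
    return decay * old + (1 - decay) * batch

# decay = 1.0 keeps the running statistic unchanged.
assert ema_blend(5.0, 9.0, 1.0) == 5.0
# decay = 0.0 replaces it with the current batch statistic.
assert ema_blend(5.0, 9.0, 0.0) == 9.0
```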

https://claude.ai/code/session_018aemrE9iqLcVU8V84wHHDj
