Add fused Triton kernel for VQ-codebook EMA update #1
EdoardoBotta wants to merge 1 commit into main from
Replaces the five-step PyTorch implementation of VQQuantizer._ema_update with a single Triton kernel pass (rectokens/kernels/ema_update.py) that:

* Scatter-accumulates cluster_size and embed_sum directly from (x, codes) without materialising the (B, K) one-hot matrix, saving O(B·K) memory.
* Fuses the active-only EMA blend, codebook refresh, dead-code counter increment, and dead-code replacement into one GPU kernel launch per codebook level.

The grid shape is (K,), one thread block per codebook entry, which avoids inter-block atomics while keeping D entirely in registers (D is tl.constexpr, so Triton recompiles one variant per unique embedding dimension).

A Python dispatch layer (rectokens/ops/ema_update.py) selects the Triton path on CUDA and falls back to the original pure-PyTorch logic on CPU, preserving full correctness on CPU-only machines.

Tests in tests/test_ema_update_kernel.py cover: all-active, partially-inactive, dead-code restart, single sample, non-power-of-two batch, large codebook, large D, decay edge cases (0 and 1), and counter-reset semantics.

https://claude.ai/code/session_018aemrE9iqLcVU8V84wHHDj
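For reviewers, the statistics the fused kernel computes can be sketched in NumPy as a reference: scatter-accumulate the per-code counts and sums directly from (x, codes) with no (B, K) one-hot matrix, then blend and refresh only the codes that received samples. The function name, the eps smoothing term, and the buffer layout here are illustrative assumptions rather than the actual rectokens API, and the dead-code counter/restart path is omitted for brevity:

```python
import numpy as np

def ema_update_reference(codebook, ema_size, ema_sum, x, codes,
                         decay=0.99, eps=1e-5):
    """NumPy sketch of the EMA codebook update (hypothetical signature).

    codebook: (K, D) current codebook entries
    ema_size: (K,)   EMA of per-code assignment counts
    ema_sum:  (K, D) EMA of per-code summed inputs
    x:        (B, D) encoder outputs
    codes:    (B,)   nearest-codebook indices
    """
    K, _ = codebook.shape

    # Scatter-accumulate counts and sums straight from (x, codes),
    # never materialising the (B, K) one-hot matrix.
    cluster_size = np.bincount(codes, minlength=K).astype(x.dtype)
    embed_sum = np.zeros_like(codebook)
    np.add.at(embed_sum, codes, x)

    # Active-only blend: codes with no samples this step are untouched.
    active = cluster_size > 0
    ema_size[active] = decay * ema_size[active] + (1 - decay) * cluster_size[active]
    ema_sum[active] = decay * ema_sum[active] + (1 - decay) * embed_sum[active]

    # Refresh the codebook from the smoothed statistics.
    codebook[active] = ema_sum[active] / (ema_size[active][:, None] + eps)
    return codebook, ema_size, ema_sum
```

In the Triton version each thread block owns one row of these (K, ·) buffers, so the active-only blend and refresh above become per-block register arithmetic with no inter-block synchronisation.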