[TPU][Pallas] Fix example/cross_entropy.py on Pallas TPU #2002
yarongmu-google wants to merge 11 commits into pytorch:main from
Conversation
…errors and fix zero division in block size calculation
…py.py to avoid unaligned HBM gather

This optimizes the cross_entropy kernel to be hardware agnostic. By computing the target logits via a boolean mask over the streaming dense block, the kernel stays entirely within TensorCore/VMEM boundaries on TPU and is perfectly coalesced on GPU, eliminating the unaligned 1D HBM gather, which Pallas TC kernels do not natively support without SC DMA staging.
force-pushed from 7836006 to 54c65a0
# Flatten logits once at the beginning
logits_flat = logits.view(-1)
Why are we changing the kernel? Shouldn't we make the existing kernel work?
Making the existing kernel work means we have to use a fori_loop.
I avoided using a dynamic gather (e.g., hl.load([flat_indices])) because it forces the kernel out of the emit_pipeline execution model and performs an indirect, data-dependent read from HBM based on the label indices. TPUs are heavily optimized for large, contiguous memory bursts rather than sparse, 4-byte random accesses. By loading the full, contiguous rows of logits and using a boolean mask (chunk_logits * mask) to extract the target logit, we deliberately trade extremely cheap ALU operations for predictable, dense memory access. This lets the compiler keep the kernel in emit_pipeline mode, maximizing HBM bandwidth utilization and overlapping our memory loads with the masking compute.
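The mask-instead-of-gather trade can be sketched outside the kernel. This is not the Helion kernel itself, just a NumPy illustration of the access-pattern idea (the function name is made up for this example):

```python
import numpy as np

def target_logits_via_mask(logits, labels):
    """Extract logits[i, labels[i]] without a data-dependent gather.

    A gather would issue sparse, label-dependent reads; here we build a
    boolean one-hot mask over the dense class dimension and reduce, so
    every memory access stays dense and contiguous.
    """
    n, v = logits.shape
    class_ids = np.arange(v)[None, :]        # (1, V) column indices
    mask = class_ids == labels[:, None]      # (N, V) one-hot mask
    return np.sum(np.where(mask, logits, 0.0), axis=-1)

logits = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
labels = np.array([2, 0])
print(target_logits_via_mask(logits, labels))  # -> [3. 4.]
```

The extra multiplies/selects over the dense block are cheap ALU work; the win is that the compiler never has to leave the pipelined, dense-load execution model.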
Cleaned up PR in #2019
print("\n[DEBUG TRACE] Executing Pallas Kernel")
print("Original Order (Tensors):")
Sorry, this PR was messed up when syncing with upstream and contains extra files. I will abandon it and create a clean one based on the current upstream main: #2018
Replaced by #2019
The kernel currently has 2 common issues that need support:
Fix #2 is the bigger fix here. More about it:
The issue was that evaluating hl.load(logits_flat, [flat_indices]) maps to random reads from HBM, which TensorCores do not support. By changing the Helion code to apply the label == v_chunk_index boolean mask directly to the logits_rows that stream sequentially into VMEM, we eliminated the 1D sparse gather entirely.
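The per-chunk masking fits naturally into a streaming cross-entropy loop. A NumPy sketch of that structure (function name and chunk size are illustrative; the real kernel streams chunks through VMEM rather than slicing an array):

```python
import numpy as np

def cross_entropy_chunked(logits, labels, chunk=128):
    """Cross-entropy where the vocab axis is processed in dense chunks.

    Each chunk of columns is loaded densely; the target logit is picked
    out with a boolean mask over the chunk's column ids, and a running
    (max, sum) pair gives a numerically stable logsumexp.
    """
    n, v = logits.shape
    target = np.zeros(n)           # accumulated target logits
    m = np.full(n, -np.inf)        # running max for stable logsumexp
    s = np.zeros(n)                # running sum of exp(logit - m)
    for start in range(0, v, chunk):
        block = logits[:, start:start + chunk]           # dense "load"
        ids = np.arange(start, start + block.shape[1])   # chunk column ids
        mask = ids[None, :] == labels[:, None]           # label in chunk?
        target += np.sum(np.where(mask, block, 0.0), axis=1)
        new_m = np.maximum(m, block.max(axis=1))
        s = s * np.exp(m - new_m) + np.exp(block - new_m[:, None]).sum(axis=1)
        m = new_m
    return (m + np.log(s)) - target   # logsumexp(logits) - target logit
```

Because each chunk is consumed as a dense block, the loop body is exactly the kind of load-then-mask compute that can overlap with the next chunk's DMA in a pipelined kernel.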
The updated cross_entropy.py is now verified mathematically correct, fully functional, and autotunes on TPU v7 for smaller shapes.
After this PR:
Note: this PR depends on pytorch/pytorch#180252