[TPU][Pallas]Fix example/cross_entropy.py on Pallas TPU#2019
yarongmu-google wants to merge 6 commits into pytorch:main
Conversation
FYI, there is #1950 for the long type part.
Thanks. Any idea why those type mapping classes were not updated? Were those actually not needed?
The idea is that if a long type is not really needed on TPU, the user can use the newly added LONG_INT_TYPE. If it is really needed, then casting will not help, so the type mapping does not need to be updated; we would reject 64-bit data types directly.
…errors and fix zero division in block size calculation
…py.py to avoid unaligned HBM gather. This makes the cross_entropy kernel hardware agnostic: by computing the target logits via a boolean mask over the streaming dense block, it stays entirely within TensorCore/VMEM boundaries on TPU and remains perfectly coalesced on GPU, eliminating the unaligned 1D HBM gather that Pallas TC kernels do not natively support without SC DMA staging.
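The mask-over-dense-block trick described in this commit can be sketched in plain NumPy (the function name and block size below are my own; the actual kernel is written in Helion/Pallas, not NumPy):

```python
import numpy as np

def target_logits_via_mask(logits, labels, block_v=128):
    # Hypothetical NumPy model of the masked-reduction trick: instead of
    # gathering logits[i, labels[i]] with random reads, stream dense vocab
    # blocks and select the target entry with a boolean mask.
    n, v = logits.shape
    out = np.zeros(n, dtype=logits.dtype)
    for start in range(0, v, block_v):
        block = logits[:, start:start + block_v]           # dense, streamed block
        cols = np.arange(start, start + block.shape[1])    # vocab ids of this block
        mask = labels[:, None] == cols[None, :]            # one-hot within block
        # masked sum picks the target logit when it lands in this block
        out += np.sum(np.where(mask, block, 0.0), axis=1)
    return out
```

The result equals the unaligned gather `logits[np.arange(n), labels]`, but every memory access is a contiguous block load followed by a vectorized compare-and-reduce, which maps cleanly onto both TPU TensorCore/VMEM and GPU coalesced loads.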
0374f54 to 0c1a4b3
I'd recommend creating a new example instead of modifying the existing one. Ideally, we would make the compiler smart enough that even for the original example it could generate masked indexing that works on TPU, so there's value in keeping that example around.
Gotcha. Reverted.
b4c4ad8 to 29ec3ba
The kernels currently have 2 common issues that need support:
Issue #2 is the bigger fix here. More about it:
The issue was that evaluating hl.load(logits_flat, [flat_indices]) maps to random reads from HBM, which TensorCores do not support. By changing the Helion code to apply the label == v_chunk_index boolean mask directly across the logits rows that are already streaming sequentially into VMEM, we eliminated the 1D sparse gather entirely.
The updated cross_entropy.py is now verified mathematically correct, fully functional, and autotunes on the TPU v7s for smaller shapes.
After this PR:
Note: this PR depends on pytorch/pytorch#180252