[TPU][Pallas]Fix example/cross_entropy.py on Pallas TPU#2019

Open
yarongmu-google wants to merge 6 commits intopytorch:mainfrom
yarongmu-google:fix-pallas-dtype-mapping-clean
Conversation

@yarongmu-google
Collaborator

The kernel currently hits two common issues that need support:

  1. Long types are not supported in Pallas/Mosaic (XLA does support them, but Helion doesn't go through XLA).
  2. Directly indexing into vectors on HBM.

#2 is the bigger fix here. More about it:

The issue was that evaluating hl.load(logits_flat, [flat_indices]) maps to random reads from HBM, which TensorCores do not support. By changing the Helion code to apply the label == v_chunk_index boolean mask directly across the logits rows as they stream sequentially into VMEM, we eliminate the 1D sparse gather entirely.
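The masked rewrite can be sketched in plain NumPy (shapes and names here are illustrative, not the Helion kernel itself) to show why the two formulations compute the same target logits:

```python
import numpy as np

# Hypothetical shapes; any (N, V) works.
N, V = 4, 8
rng = np.random.default_rng(0)
logits = rng.standard_normal((N, V)).astype(np.float32)
labels = rng.integers(0, V, size=N)

# Gather formulation (what hl.load(logits_flat, [flat_indices]) expresses):
# random reads into a flattened buffer, which Pallas TensorCore kernels
# do not support.
flat_indices = np.arange(N) * V + labels
target_gather = logits.reshape(-1)[flat_indices]

# Masked formulation: compare each column index against the label and
# reduce. Every element of the dense block is touched in order, so the
# access pattern is a plain streaming read instead of a sparse gather.
col_index = np.arange(V)                 # plays the role of v_chunk_index
mask = labels[:, None] == col_index      # (N, V) boolean, one True per row
target_masked = np.where(mask, logits, 0.0).sum(axis=1)

assert np.allclose(target_gather, target_masked)
```

Because exactly one position per row is True, the masked sum reproduces the gathered value while only ever reading the dense block sequentially.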

The updated cross_entropy.py is now verified mathematically correct, fully functional, and autotunes on the TPU v7s on smaller shapes.

After this PR:

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.3826       1.10x          
torch                0.4208       1.00x (ref)    
=================================================================

Note: this PR depends on pytorch/pytorch#180252

@norx1991
Contributor

FYI, there is #1950 for the long type part.

@yarongmu-google
Collaborator Author

> FYI, there is #1950 for the long type part.

Thanks. Any idea why those type-mapping classes were not updated? Were they actually not needed?

@norx1991
Contributor

norx1991 commented Apr 15, 2026

> FYI, there is #1950 for the long type part.
>
> Thanks. Any idea why those type-mapping classes were not updated? Were they actually not needed?

The idea is that if a long type is not really needed on TPU, the user can use the newly added LONG_INT_TYPE. If it is really needed, then casting will not help, so the type mapping does not need to be updated; we reject 64-bit data types directly.
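A minimal host-side sketch of that contract (the names and the cast are illustrative; LONG_INT_TYPE itself comes from #1950): when label values fit in 32 bits, the caller narrows them before the kernel ever sees a rejected 64-bit dtype.

```python
import numpy as np

# Labels often arrive as int64 ("long"), which Pallas/Mosaic rejects.
# When the values fit, the host code can narrow them up front; when they
# truly need 64 bits, no cast can help and rejection is the right outcome.
labels = np.array([3, 17, 999], dtype=np.int64)
assert labels.max() <= np.iinfo(np.int32).max

labels_i32 = labels.astype(np.int32)  # safe: values verified to fit
```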

@yarongmu-google yarongmu-google marked this pull request as ready for review April 15, 2026 18:10
…errors and fix zero division in block size calculation
…py.py to avoid unaligned HBM gather

This optimizes the cross_entropy kernel to be hardware agnostic. By computing the target logits via a boolean mask over the streaming dense block, it stays entirely within TensorCore/VMEM boundaries on TPU and remains perfectly coalesced on GPU, eliminating the unaligned 1D HBM gather that Pallas TC kernels do not natively support without SC DMA staging.
@yarongmu-google yarongmu-google force-pushed the fix-pallas-dtype-mapping-clean branch from 0374f54 to 0c1a4b3 Compare April 16, 2026 00:21
@AmesingFlank
Contributor

I'd recommend creating a new example instead of modifying the existing one. Ideally, we would make the compiler smart enough that, even for the original example, it could generate masked indexing that works on TPU, so there's value in keeping that example around.

@yarongmu-google
Collaborator Author

> FYI, there is #1950 for the long type part.
>
> Thanks. Any idea why those type-mapping classes were not updated? Were they actually not needed?
>
> The idea is that if a long type is not really needed on TPU, the user can use the newly added LONG_INT_TYPE. If it is really needed, then casting will not help, so the type mapping does not need to be updated; we reject 64-bit data types directly.

Gotcha. Reverted.

@yarongmu-google yarongmu-google force-pushed the fix-pallas-dtype-mapping-clean branch from b4c4ad8 to 29ec3ba Compare April 16, 2026 00:49
Labels

CLA Signed This label is managed by the Meta Open Source bot.
