
[TPU][Pallas]Fix example/cross_entropy.py on Pallas TPU#2002

Closed
yarongmu-google wants to merge 11 commits intopytorch:mainfrom
yarongmu-google:fix-pallas-dtype-mapping

Conversation

@yarongmu-google yarongmu-google commented Apr 10, 2026

The kernel currently has two common issues that need fixing:

  1. Long (int64) types are not supported in Pallas/Mosaic (XLA does support them, but Helion does not go through XLA).
  2. Directly indexing into vectors on HBM.

Issue 2 is the bigger fix here. More about it:

The issue was that evaluating hl.load(logits_flat, [flat_indices]) maps to random reads from HBM, which TensorCores do not support. By changing the Helion code to apply the label == v_chunk_index boolean mask directly across the logits rows that stream sequentially into VMEM, we eliminate the 1D sparse gather entirely.
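To make the two formulations concrete, here is a minimal plain-Python sketch (not Helion/Pallas code; the shapes and values are illustrative) showing that reducing each dense row under a `label == column` mask produces the same target logits as the flat gather it replaces:

```python
# Hypothetical illustration: target-logit extraction via flat gather vs. boolean mask.
logits = [
    [2.0, 0.5, -1.0, 3.0],
    [0.1, 4.0,  2.5, 0.0],
]
labels = [3, 1]
n_cols = len(logits[0])

# Gather formulation: data-dependent random access into a flattened buffer
# (this is what maps to sparse HBM reads that Pallas TC kernels cannot do).
logits_flat = [x for row in logits for x in row]
gathered = [logits_flat[r * n_cols + labels[r]] for r in range(len(labels))]

# Mask formulation: stream each dense, contiguous row and reduce it under
# the (column == label) mask -- cheap ALU work, fully predictable accesses.
masked = [
    sum(x if c == labels[r] else 0.0 for c, x in enumerate(row))
    for r, row in enumerate(logits)
]

assert gathered == masked  # both yield [3.0, 4.0]
```

The mask version touches every element of the row, but since the rows are being streamed into VMEM anyway for the log-sum-exp reduction, the extra multiplies are effectively free.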

The updated cross_entropy.py is now verified mathematically correct, fully functional, and autotunes on TPU v7 for smaller shapes.

After this PR:

=================================================================
Benchmark Results
=================================================================
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.3826       1.10x          
torch                0.4208       1.00x (ref)    
=================================================================

Note: this PR depends on pytorch/pytorch#180252

norx1991 and others added 7 commits April 2, 2026 19:06
…errors and fix zero division in block size calculation
…py.py to avoid unaligned HBM gather

This makes the cross_entropy kernel hardware agnostic. By computing the target logits via a boolean mask over the streaming dense block, it stays entirely within TensorCore/VMEM boundaries on TPU and remains perfectly coalesced on GPU, eliminating the unaligned 1D HBM gather that Pallas TC kernels do not natively support without SC DMA staging.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 10, 2026
@yarongmu-google yarongmu-google force-pushed the fix-pallas-dtype-mapping branch from 7836006 to 54c65a0 Compare April 13, 2026 22:24
Comment thread: examples/cross_entropy.py
Comment on lines -50 to -52
# Flatten logits once at the beginning
logits_flat = logits.view(-1)

Contributor

Why are we changing the kernel? Shouldn't we make the existing kernel work?

Collaborator Author

Making the existing kernel work means we would have to use a fori_loop.

I avoided the dynamic gather (e.g., hl.load([flat_indices])) because it forces the kernel out of the emit_pipeline execution model and performs an indirect, data-dependent read from HBM based on the label indices. TPUs are heavily optimized for large, contiguous memory bursts rather than sparse, 4-byte random accesses. By loading the full, contiguous rows of logits and using a boolean mask (chunk_logits * mask) to extract the target logit, we deliberately trade extremely cheap ALU operations for predictable, dense memory access. This allows the compiler to keep the kernel in emit_pipeline mode, maximizing HBM bandwidth utilization and overlapping our memory loads with the masking compute.
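The chunked masking pattern described above can be sketched as follows. This is a hypothetical plain-Python stand-in for the Helion kernel (the function name, chunk size, and values are illustrative, not from the PR); the point is that the loop structure only ever reads dense, contiguous chunks, with the mask doing the selection:

```python
# Hypothetical sketch: accumulate the target logit while streaming a row
# in dense, fixed-size chunks -- no data-dependent indexing anywhere.
def target_logit_chunked(row, label, chunk=2):
    acc = 0.0
    for start in range(0, len(row), chunk):
        # Dense, contiguous load of one chunk (the pipelined VMEM stream).
        for offset, x in enumerate(row[start:start + chunk]):
            # (absolute column index == label) is the boolean mask;
            # selection costs one compare + select per element, all ALU work.
            acc += x if (start + offset) == label else 0.0
    return acc

row = [0.1, 4.0, 2.5, 0.0]
assert target_logit_chunked(row, 1) == 4.0
```

Because every iteration touches a statically known, contiguous slice, a pipelining compiler can prefetch the next chunk while the current one is being masked and reduced, which is exactly the overlap emit_pipeline provides.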

Cleaned up PR in #2019

Comment on lines +460 to +461
print("\n[DEBUG TRACE] Executing Pallas Kernel")
print("Original Order (Tensors):")
Contributor

?

Collaborator Author

@yarongmu-google yarongmu-google Apr 14, 2026

Sorry, this PR got messed up when syncing with upstream and contains extra files. I will abandon this and create a clean one based on the current upstream main: #2018

@yarongmu-google yarongmu-google marked this pull request as draft April 14, 2026 21:18
@yarongmu-google
Collaborator Author

yarongmu-google commented Apr 14, 2026

Replaced by #2019

@yarongmu-google yarongmu-google closed this by deleting the head repository Apr 14, 2026