[Pallas] Reject 64-bit input tensors and fix tiling ZeroDivisionError #1950

Merged
norx1991 merged 3 commits into main from yifeixu/fix-int64-tiling-zerodiv
Apr 15, 2026

Conversation

@norx1991 (Contributor) commented Apr 4, 2026

Summary

TPU does not natively support 64-bit element types. This PR:

  1. Runtime (runtime/__init__.py): _pallas_check_dtypes raises TypeError if any tensor arg uses int64, uint64, or float64. Runs before the cache check in all three launchers so it cannot be bypassed.

  2. Block spec tiling (backend.py): Cap bitwidth to 32 in _get_pallas_required_alignment to prevent ZeroDivisionError when 64-bit dtypes appear at compile time (128 * (32 // 64) = 0).

  3. LONG_INT_TYPE (_testing.py): New constant (int32 on Pallas, int64 elsewhere) for examples/tests that use long integers but don't need 64-bit range. Updated cross_entropy example and test to use it.

Note: internal dtype promotion (e.g. torch.sum on int32 → int64, argmax → int64) is handled by JAX, which silently narrows without jax_enable_x64.

@meta-cla bot added the CLA Signed label Apr 4, 2026
@norx1991 force-pushed the yifeixu/fix-int64-tiling-zerodiv branch 2 times, most recently from 2971771 to 8bffc25 on April 8, 2026 23:22
@norx1991 changed the title from "[Pallas] Fix ZeroDivisionError in block spec for int64 1D tensors" to "[Pallas] Fix int64 tensor handling on TPU" Apr 8, 2026
@norx1991 force-pushed the branch from 8bffc25 to 2a79843 on April 8, 2026 23:30
@norx1991 changed the title to "[Pallas] Fix 64-bit dtype handling on TPU" Apr 8, 2026
@norx1991 force-pushed the branch 4 times, most recently from 5531d8c to 096411a on April 13, 2026 18:31
@norx1991 changed the title to "[Pallas] Work around unsupported 64-bit dtypes on TPU by narrowing to 32-bit" Apr 13, 2026
@norx1991 force-pushed the branch 4 times, most recently from 7b49429 to 9dacbfa on April 13, 2026 20:46
@norx1991 changed the title to "[Pallas] Narrow 64-bit dtypes to 32-bit in codegen" Apr 13, 2026
@norx1991 force-pushed the branch from 9dacbfa to b9cd8a2 on April 13, 2026 20:52
@norx1991 force-pushed the branch 2 times, most recently from 8a9ef0e to bc3a376 on April 13, 2026 21:05
@norx1991 marked this pull request as ready for review April 13, 2026 21:09
@norx1991 requested review from AmesingFlank, jansel and oulgen, then removed the request for oulgen, April 13, 2026 21:17
@AmesingFlank (Contributor)

The changes lgtm, although I'm not sure if we really wanna do this. One could argue that if the user wrote int64/float64, we're expected to uphold that contract. Might be better to just throw an error, explicitly calling out that 64-bit types are not supported?

@norx1991 (Contributor, Author)

Thanks for the quick reply. Several kernels use x64 (e.g., cross_entropy, as listed in #2010), so throwing an error means we would have to either create an almost-duplicate example or decide not to support those kernels. Or do you think we should guard this behind a flag?

@AmesingFlank (Contributor)

I personally think we should just use a mechanism similar to this one for these tests/examples. E.g. we can define LONG_INT_TYPE and set it to int64 elsewhere and int32 for Pallas.

My mental model is that if someone requested int64/float64 dtypes, they must've done so for a reason, and they believe that they need the extra bitwidth. It's likely for these use cases, falling back to int32 is not really useful.
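The suggested pattern is a one-line conditional alias, along these lines. The backend probe `_running_on_pallas` below is hypothetical; the merged PR defines `LONG_INT_TYPE` in `_testing.py` with its own detection logic:

```python
import torch

def _running_on_pallas() -> bool:
    # Hypothetical backend probe; the real check in _testing.py may differ,
    # e.g. inspecting the active Helion backend. Hard-coded False here so
    # the sketch runs anywhere.
    return False

# int32 satisfies the TPU restriction for examples/tests whose values fit
# in 32 bits; int64 keeps standard PyTorch semantics on other backends.
LONG_INT_TYPE = torch.int32 if _running_on_pallas() else torch.int64

# Usage as in the updated cross_entropy example: class labels that do not
# need the full 64-bit range.
labels = torch.randint(0, 10, (8,), dtype=LONG_INT_TYPE)
```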

@norx1991 (Contributor, Author)

@AmesingFlank Thanks, this is a good point. I checked cross_entropy, and there doesn’t seem to be any real need for long there; it looks like it was mainly mirroring the standard PyTorch API. For cases like that, using LONG_INT_TYPE seems reasonable.

For kernels where the data range or precision actually matters, I agree we should raise an explicit error instead of silently narrowing.

So overall I agree that defining LONG_INT_TYPE is the better step forward here: it keeps the choice explicit in the example/test code and has a much smaller blast radius than changing runtime behavior.

I will update this PR to reflect this.

@norx1991 force-pushed the yifeixu/fix-int64-tiling-zerodiv branch from bc3a376 to f322049 on April 13, 2026 22:22
@norx1991 changed the title from "[Pallas] Narrow 64-bit dtypes to 32-bit in codegen" to "[Pallas] Handle 64-bit dtypes on TPU: narrow in codegen, reject at runtime" Apr 13, 2026
@norx1991 force-pushed the branch 6 times, most recently from 6c2e985 to 39cf2ce on April 13, 2026 22:56
@norx1991 changed the title to "[Pallas] Reject 64-bit dtypes on TPU" Apr 13, 2026
@norx1991 force-pushed the branch 4 times, most recently from 2aef0e5 to 06f8157 on April 13, 2026 23:20
@norx1991 changed the title to "[Pallas] Reject 64-bit input tensors and fix tiling ZeroDivisionError" Apr 13, 2026
XLA on TPU rejects 64-bit element types. Override
PallasBackend.dtype_str to narrow int64→int32 and float64→float32
via a new pallas_narrow_dtype helper, so all generated code
(convert_element_type, jnp.full, reductions) emits 32-bit types.
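A rough sketch of the helper this commit message describes. The dtype mapping follows the message; the standalone `dtype_str` signature below is an assumption, since the real method is an override on `PallasBackend`:

```python
import torch

# Per the commit message: narrow int64 -> int32 and float64 -> float32.
_NARROW_MAP = {
    torch.int64: torch.int32,
    torch.float64: torch.float32,
}

def pallas_narrow_dtype(dtype: torch.dtype) -> torch.dtype:
    """Map 64-bit dtypes to their 32-bit counterparts; pass others through."""
    return _NARROW_MAP.get(dtype, dtype)

def dtype_str(dtype: torch.dtype) -> str:
    """Sketch of the PallasBackend.dtype_str override: emit the narrowed
    dtype's name so generated convert_element_type / jnp.full / reduction
    calls all use 32-bit types."""
    return str(pallas_narrow_dtype(dtype)).removeprefix("torch.")
```

Note that the thread below settles on rejecting 64-bit inputs at runtime instead of silently narrowing them, so a helper like this would only affect internally promoted dtypes, not user-supplied tensors.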
@norx1991 force-pushed the yifeixu/fix-int64-tiling-zerodiv branch from 06f8157 to 080dda0 on April 13, 2026 23:28
@norx1991 (Contributor, Author)

Updated.

@norx1991 norx1991 merged commit 458449b into main Apr 15, 2026
42 of 50 checks passed