[Pallas] Reject 64-bit input tensors and fix tiling ZeroDivisionError#1950
Conversation
The changes lgtm, although I'm not sure we really want to do this. One could argue that if the user wrote int64/float64, we're expected to uphold that contract. Might it be better to just throw an error, explicitly calling out that 64-bit types are not supported?
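The explicit rejection suggested here could look roughly like the following sketch. The helper name, error message, and string-based dtype model are all illustrative, not the project's actual API:

```python
# Hypothetical guard that rejects 64-bit element types up front instead
# of silently narrowing them. Dtypes are modeled as plain strings here;
# a real implementation would inspect the tensor arguments' dtypes.
_REJECTED_DTYPES = {"int64", "uint64", "float64"}

def check_dtypes(*dtype_names: str) -> None:
    for name in dtype_names:
        if name in _REJECTED_DTYPES:
            raise TypeError(
                f"64-bit dtype {name!r} is not supported on this backend; "
                "cast the input to a 32-bit dtype first"
            )

check_dtypes("int32", "float32")  # 32-bit dtypes pass silently
```

Raising eagerly keeps the failure visible at the call site, rather than surfacing as a surprising precision loss later.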
Thanks for the quick reply. Several kernels use x64 (e.g., cross_entropy, as listed in #2010), so throwing errors means we would have to either create an almost-duplicate example or decide not to support it. Or do you think we should guard this under a flag?
I personally think we should just use a mechanism similar to this one for these tests/examples, e.g. we can do …

My mental model is that if someone requested int64/float64 dtypes, they must've done so for a reason, and they believe they need the extra bitwidth. For those use cases, falling back to int32 is likely not really useful.
@AmesingFlank Thanks, this is a good point. I checked cross_entropy, and there doesn't seem to be any real need for long there; it looks like it was mainly mirroring the standard PyTorch API. For cases like that, using LONG_INT_TYPE seems reasonable. For kernels where the data range or precision actually matters, I agree we should raise an explicit error instead of silently narrowing.

So overall I agree that defining LONG_INT_TYPE is the better step forward here: it keeps the choice explicit in the example/test code and has a much smaller blast radius than changing runtime behavior. I will update this PR to reflect this.
XLA on TPU rejects 64-bit element types. Override `PallasBackend.dtype_str` to narrow int64→int32 and float64→float32 via a new `pallas_narrow_dtype` helper, so all generated code (`convert_element_type`, `jnp.full`, reductions) emits 32-bit types.
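The narrowing helper named in this commit message could look roughly like the sketch below. `pallas_narrow_dtype` is the name from the commit; the exact mapping and string-based signature are assumptions:

```python
# Hypothetical sketch of pallas_narrow_dtype: map 64-bit dtype strings
# to their 32-bit counterparts and pass everything else through, so all
# generated dtype strings stay TPU-compatible.
_NARROW = {"int64": "int32", "uint64": "uint32", "float64": "float32"}

def pallas_narrow_dtype(dtype_str: str) -> str:
    return _NARROW.get(dtype_str, dtype_str)
```

(As the discussion above concluded, this silent-narrowing approach was ultimately replaced by an explicit error plus the LONG_INT_TYPE constant.)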
Summary
TPU does not natively support 64-bit element types. This PR:

- Runtime (`runtime/__init__.py`): `_pallas_check_dtypes` raises `TypeError` if any tensor arg uses int64, uint64, or float64. It runs before the cache check in all three launchers, so it cannot be bypassed.
- Block spec tiling (`backend.py`): cap the bitwidth at 32 in `_get_pallas_required_alignment` to prevent a `ZeroDivisionError` when 64-bit dtypes appear at compile time (`128 * (32 // 64)` = 0).
- `LONG_INT_TYPE` (`_testing.py`): new constant (int32 on Pallas, int64 elsewhere) for examples/tests that use long integers but don't need the 64-bit range. Updated the `cross_entropy` example and test to use it.

Note: internal dtype promotion (e.g. `torch.sum` on int32 → int64, `argmax` → int64) is handled by JAX, which silently narrows without `jax_enable_x64`.
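The block spec tiling fix can be sketched with the arithmetic from the summary: the required alignment is derived as `128 * (32 // bitwidth)`, which collapses to 0 for 64-bit types and later divides by zero. The function name below is illustrative; the cap-at-32 fix is the one described above:

```python
# Hypothetical sketch of the tiling fix: without the cap, a 64-bit
# element type yields 128 * (32 // 64) == 0, and that zero alignment
# later triggers a ZeroDivisionError during block spec computation.
def required_alignment(bitwidth: int) -> int:
    capped = min(bitwidth, 32)   # the fix: never let 32 // bitwidth hit 0
    return 128 * (32 // capped)

assert required_alignment(64) == 128   # was 0 before the cap
assert required_alignment(32) == 128
assert required_alignment(16) == 256
assert required_alignment(8) == 512
```

With the runtime check in place, 64-bit dtypes should normally be rejected before reaching this code; the cap is defense in depth so compile-time paths never see a zero alignment.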