-
Notifications
You must be signed in to change notification settings - Fork 144
[TPU][Pallas]Fix example/cross_entropy.py on Pallas TPU #2002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e40e4e7
97b35b3
fdbd20b
8201686
9878a8f
d0157fd
a978600
54a1bdc
7c6c52e
bc936aa
54c65a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -455,6 +455,28 @@ | |
| out_ref[...] = in_ref[...] # type: ignore[index] | ||
| original_order[orig_pos] = out_ref | ||
| extra_refs = refs[n_tensor_inputs + len(_output_indices) :] | ||
|
|
||
| # --- DEBUG TRACE --- | ||
| print("\n[DEBUG TRACE] Executing Pallas Kernel") | ||
| print("Original Order (Tensors):") | ||
|
Comment on lines
+460
to
+461
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, this PR was messed up when syncing with upstream and contains extra files. I will abandon this and create a clean one based on the current upstream main: #2018 |
||
| for i, ref in enumerate(original_order): | ||
| if hasattr(ref, "shape"): | ||
| print( | ||
| f" Arg {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}" | ||
| ) | ||
| else: | ||
| print(f" Arg {i}: Value={ref}") | ||
| print("Extra Refs (Scratches/VMEM):") | ||
| for i, ref in enumerate(extra_refs): | ||
| if hasattr(ref, "shape"): | ||
| print( | ||
| f" Scratch {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}" | ||
| ) | ||
| else: | ||
| print(f" Scratch {i}: Value={ref} (Likely Semaphore)") | ||
| print("---------------------\n") | ||
| # ------------------- | ||
|
|
||
| pallas_kernel(*original_order, *extra_refs) # type: ignore[operator] | ||
|
|
||
| return reordered_kernel | ||
|
|
@@ -727,11 +749,15 @@ | |
| _jnp_dtype_map: dict[str, object] = { | ||
| "jnp.float32": jnp.float32, | ||
| "jnp.float16": jnp.float16, | ||
| "jnp.float64": jnp.float64, | ||
| "jnp.bfloat16": jnp.bfloat16, | ||
| "jnp.int32": jnp.int32, | ||
| "jnp.int16": jnp.int16, | ||
| "jnp.int8": jnp.int8, | ||
| "jnp.int64": jnp.int64, | ||
| "jnp.uint8": jnp.uint8, | ||
| "jnp.uint16": jnp.uint16, | ||
| "jnp.uint32": jnp.uint32, | ||
| "jnp.bool_": jnp.bool_, | ||
| } | ||
| scratch_shapes = [] | ||
|
|
@@ -801,6 +827,19 @@ | |
| if _pallas_interpret_flag(): | ||
| pallas_call_kwargs["interpret"] = True | ||
|
|
||
| # --- DEBUG PALLAS CALL --- | ||
| print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}") | ||
| for k, v in pallas_call_kwargs.items(): | ||
| if k == "grid_spec": | ||
| print(f" {k}:") | ||
| print(f" in_specs: {v.in_specs}") | ||
| print(f" out_specs: {v.out_specs}") | ||
| print(f" scratch_shapes: {v.scratch_shapes}") | ||
| else: | ||
| print(f" {k}: {v}") | ||
| print("--------------------------\n") | ||
| # ------------------------- | ||
|
|
||
| jit_fn = pl.pallas_call( | ||
| reordered_kernel, # pyrefly: ignore[bad-argument-type] | ||
| **pallas_call_kwargs, # type: ignore[arg-type] | ||
|
|
@@ -868,11 +907,17 @@ | |
| _jnp_dtype_map: dict[str, object] = { | ||
| "jnp.float32": jnp.float32, | ||
| "jnp.float16": jnp.float16, | ||
| "jnp.float64": jnp.float64, | ||
| "jnp.bfloat16": jnp.bfloat16, | ||
| "jnp.float8_e4m3fn": getattr(jnp, "float8_e4m3fn", jnp.float32), | ||
| "jnp.float8_e5m2": getattr(jnp, "float8_e5m2", jnp.float32), | ||
| "jnp.int32": jnp.int32, | ||
| "jnp.int16": jnp.int16, | ||
| "jnp.int8": jnp.int8, | ||
| "jnp.int64": jnp.int64, | ||
| "jnp.uint8": jnp.uint8, | ||
| "jnp.uint16": jnp.uint16, | ||
| "jnp.uint32": jnp.uint32, | ||
| "jnp.bool_": jnp.bool_, | ||
| } | ||
| scratch_shapes = [] | ||
|
|
@@ -941,6 +986,19 @@ | |
| if _pallas_interpret_flag(): | ||
| pallas_call_kwargs["interpret"] = True | ||
|
|
||
| # --- DEBUG PALLAS CALL --- | ||
| print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}") | ||
| for k, v in pallas_call_kwargs.items(): | ||
| if k == "grid_spec": | ||
| print(f" {k}:") | ||
| print(f" in_specs: {v.in_specs}") | ||
| print(f" out_specs: {v.out_specs}") | ||
| print(f" scratch_shapes: {v.scratch_shapes}") | ||
| else: | ||
| print(f" {k}: {v}") | ||
| print("--------------------------\n") | ||
| # ------------------------- | ||
|
|
||
| jit_fn = pl.pallas_call( | ||
| reordered_kernel, # pyrefly: ignore[bad-argument-type] | ||
| **pallas_call_kwargs, # type: ignore[arg-type] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we changing the kernel? Shouldn't we make the existing kernel work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making the existing kernel work means we have to use a fori_loop.
I avoided it with a dynamic gather (e.g.,
hl.load([flat_indices])) because it forces the kernel out of the `emit_pipeline` execution model and does an indirect, data-dependent read from HBM based on the label indices. TPUs are heavily optimized for large, contiguous memory bursts rather than sparse, 4-byte random accesses. By loading the full, contiguous rows of logits and using a boolean mask (chunk_logits * mask) to extract the target logit, we deliberately trade extremely cheap ALU operations for predictable, dense memory access. This allows the compiler to keep the kernel in `emit_pipeline` mode, maximizing HBM bandwidth utilization and overlapping our memory loads with the masking compute. Cleaned up PR in #2019