From e40e4e7eaf2e89f2bf8196405eb6cc57753efdd6 Mon Sep 17 00:00:00 2001 From: Yifei Xu <63565283+norx1991@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:45:47 -0700 Subject: [PATCH 1/5] [Pallas] Add test for Pallas OOB slice when reduction_loops doesn't divide dim (#1937) --- .github/workflows/test.yml | 2 +- test/test_examples.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3884781da..b8e5111aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -176,7 +176,7 @@ jobs: --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ --pre \ - 'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' + 'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' # Install Bazel if ! command -v bazel &> /dev/null; then sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel diff --git a/test/test_examples.py b/test/test_examples.py index 1ad829450..1017bc59b 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1041,6 +1041,31 @@ def test_layernorm_no_bias(self): num_stages=3, ) + @xfailIfPallas( + "Out-of-bounds slice when reduction_loops doesn't evenly divide the " + "reduction dimension (e.g. reduction_loops=32 on dim=48 generates " + "pl.ds(32, 32) which exceeds bounds)" + ) + def test_layernorm_reduction_not_divisible(self): + """Reduction loop OOB when reduction_loops doesn't divide the reduction dim.""" + batch_size = 4 + dim = 48 # not divisible by reduction_loops=32 + x = torch.randn([batch_size, dim], device=DEVICE, dtype=HALF_DTYPE) + weight = torch.randn([dim], device=DEVICE, dtype=HALF_DTYPE) + bias = torch.randn([dim], device=DEVICE, dtype=HALF_DTYPE) + + args = (x, [dim], weight, bias, 1e-5) + expected_out = torch.nn.functional.layer_norm(*args) + + check_example( + "layer_norm", + args, + (expected_out, None, None), + fn_name="layer_norm_fwd", + block_size=1, + reduction_loops=32, + ) + @xfailIfCute("CuTe LayerNorm backward example still returns incorrect results") @xfailIfPallas("InductorLoweringError") @skipIfA10G("accuracy check fails on A10G GPUs") From fdbd20b3a95ca5904ee5c344e7cb25c23395bdb4 Mon Sep 17 00:00:00 2001 From: Yifei Xu <63565283+norx1991@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:45:47 -0700 Subject: [PATCH 2/5] [Pallas] Add test for Pallas OOB slice when reduction_loops doesn't divide dim (#1937) --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3884781da..b8e5111aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -176,7 +176,7 @@ jobs: --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ --pre \ - 'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' + 'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' # Install Bazel if ! command -v bazel &> /dev/null; then sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel From 9878a8f11f04bd5110d4f2375a37b55feadfe6ed Mon Sep 17 00:00:00 2001 From: Yarong Mu Date: Thu, 9 Apr 2026 15:20:43 -0700 Subject: [PATCH 3/5] fix(pallas): add mapping for 64-bit dtypes to 32-bit to avoid Pallas errors and fix zero division in block size calculation --- helion/_compiler/backend.py | 12 ++++-- helion/_compiler/device_function.py | 5 +++ helion/language/memory_ops.py | 2 +- helion/runtime/__init__.py | 58 +++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py index 2d45f826b..71a47aabf 100644 --- a/helion/_compiler/backend.py +++ b/helion/_compiler/backend.py @@ -858,15 +858,18 @@ def tunable_fragments(self) -> dict[str, ConfigSpecFragment]: _TORCH_TO_JAX_DTYPE: dict[str, str] = { "torch.float16": "jnp.float16", "torch.float32": "jnp.float32", - "torch.float64": "jnp.float64", + "torch.float64": "jnp.float32", "torch.bfloat16": "jnp.bfloat16", + "torch.float8_e4m3fn": "jnp.float8_e4m3fn", + "torch.float8_e5m2": "jnp.float8_e5m2", "torch.int8": "jnp.int8", "torch.int16": "jnp.int16", "torch.int32": "jnp.int32", - "torch.int64": "jnp.int64", + "torch.int64": "jnp.int32", + "torch.long": "jnp.int32", "torch.uint8": "jnp.uint8", "torch.uint32": "jnp.uint32", - "torch.uint64": "jnp.uint64", + "torch.uint64": "jnp.uint32", "torch.bool": "jnp.bool_", "torch.complex64": "jnp.complex64", "torch.complex128": "jnp.complex128", @@ -1173,6 +1176,7 @@ def adjust_block_size_constraints( # Tiling size for 1D arrays. Mosaic lowering enforces that rank-1 # BlockSpec block shapes are a multiple of 128 * (32 // bitwidth). + min_element_bits = min(min_element_bits, 32) tiling_1d = 128 * (32 // min_element_bits) # Map block_id -> minimum dim_from_end across all tensors @@ -1352,7 +1356,7 @@ def _compute_block_spec_info( # back to no tiling for the entire kernel. dim_size = tensor.shape[d] if tensor.ndim == 1 and isinstance(dim_size, int): - bitwidth = tensor.dtype.itemsize * 8 + bitwidth = min(tensor.dtype.itemsize * 8, 32) tiling_1d = 128 * (32 // bitwidth) if bs != dim_size and bs % tiling_1d != 0: return self._no_tiling_block_spec_info(sorted_args) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index e0bd2c99a..28633e5c9 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -991,6 +991,11 @@ def _print_FloorDiv(self, expr: sympy.Expr) -> str: # pyrefly: ignore [missing-attribute] return f"({self._print(lhs)} // {self._print(rhs)})" + def _print_PythonMod(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # pyrefly: ignore [missing-attribute] + return f"({self._print(lhs)} % {self._print(rhs)})" + def pallas_texpr(expr: sympy.Expr) -> str: return HelionPallasPrinter().doprint(expr) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index cadda0e2f..b338f9224 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -237,7 +237,7 @@ def _pallas_index_str( dim_map.setdefault(tensor_dim, block_id) elif isinstance(idx, int): parts.append(str(idx)) - elif isinstance(idx, torch.SymInt): + elif isinstance(idx, (torch.SymInt, torch.Tensor)): ast_subscripts = state.ast_args[1] assert isinstance(ast_subscripts, list) ast_idx = ast_subscripts[i] diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index a25ada2fc..843d7651e 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -444,6 +444,28 @@ def reordered_kernel(*refs: object) -> None: out_ref[...] = in_ref[...] # type: ignore[index] original_order[orig_pos] = out_ref extra_refs = refs[n_tensor_inputs + len(_output_indices) :] + + # --- DEBUG TRACE --- + print("\n[DEBUG TRACE] Executing Pallas Kernel") + print("Original Order (Tensors):") + for i, ref in enumerate(original_order): + if hasattr(ref, "shape"): + print( + f" Arg {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}" + ) + else: + print(f" Arg {i}: Value={ref}") + print("Extra Refs (Scratches/VMEM):") + for i, ref in enumerate(extra_refs): + if hasattr(ref, "shape"): + print( + f" Scratch {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}" + ) + else: + print(f" Scratch {i}: Value={ref} (Likely Semaphore)") + print("---------------------\n") + # ------------------- + pallas_kernel(*original_order, *extra_refs) # type: ignore[operator] return reordered_kernel @@ -702,11 +724,15 @@ def default_pallas_pipeline_launcher( _jnp_dtype_map: dict[str, object] = { "jnp.float32": jnp.float32, "jnp.float16": jnp.float16, + "jnp.float64": jnp.float64, "jnp.bfloat16": jnp.bfloat16, "jnp.int32": jnp.int32, "jnp.int16": jnp.int16, "jnp.int8": jnp.int8, + "jnp.int64": jnp.int64, "jnp.uint8": jnp.uint8, + "jnp.uint16": jnp.uint16, + "jnp.uint32": jnp.uint32, "jnp.bool_": jnp.bool_, } scratch_shapes = [] @@ -776,6 +802,19 @@ def default_pallas_pipeline_launcher( if _pallas_interpret_flag(): pallas_call_kwargs["interpret"] = True + # --- DEBUG PALLAS CALL --- + print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}") + for k, v in pallas_call_kwargs.items(): + if k == "grid_spec": + print(f" {k}:") + print(f" in_specs: {v.in_specs}") + print(f" out_specs: {v.out_specs}") + print(f" scratch_shapes: {v.scratch_shapes}") + else: + print(f" {k}: {v}") + print("--------------------------\n") + # ------------------------- + jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] **pallas_call_kwargs, # type: ignore[arg-type] @@ -843,11 +882,17 @@ def default_pallas_fori_launcher( _jnp_dtype_map: dict[str, object] = { "jnp.float32": jnp.float32, "jnp.float16": jnp.float16, + "jnp.float64": jnp.float64, "jnp.bfloat16": jnp.bfloat16, + "jnp.float8_e4m3fn": getattr(jnp, "float8_e4m3fn", jnp.float32), + "jnp.float8_e5m2": getattr(jnp, "float8_e5m2", jnp.float32), "jnp.int32": jnp.int32, "jnp.int16": jnp.int16, "jnp.int8": jnp.int8, + "jnp.int64": jnp.int64, "jnp.uint8": jnp.uint8, + "jnp.uint16": jnp.uint16, + "jnp.uint32": jnp.uint32, "jnp.bool_": jnp.bool_, } scratch_shapes = [] @@ -916,6 +961,19 @@ def default_pallas_fori_launcher( if _pallas_interpret_flag(): pallas_call_kwargs["interpret"] = True + # --- DEBUG PALLAS CALL --- + print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}") + for k, v in pallas_call_kwargs.items(): + if k == "grid_spec": + print(f" {k}:") + print(f" in_specs: {v.in_specs}") + print(f" out_specs: {v.out_specs}") + print(f" scratch_shapes: {v.scratch_shapes}") + else: + print(f" {k}: {v}") + print("--------------------------\n") + # ------------------------- + jit_fn = pl.pallas_call( reordered_kernel, # pyrefly: ignore[bad-argument-type] **pallas_call_kwargs, # type: ignore[arg-type] From d0157fd303d31ad05a84eaa7c119a43d8493970f Mon Sep 17 00:00:00 2001 From: Yarong Mu Date: Thu, 9 Apr 2026 17:43:33 -0700 Subject: [PATCH 4/5] perf(examples): extract target logits via boolean mask in cross_entropy.py to avoid unaligned HBM gather This optimizes the cross_entropy kernel to be hardware agnostic. By calculating the target logits via a boolean mask over the streaming dense block, it stays entirely within TensorCore/VMEM boundaries on TPU and perfectly coalesced on GPU, eliminating the unaligned 1D HBM gather which Pallas TC kernels do not natively support without SC DMA staging. --- examples/cross_entropy.py | 49 +++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index a0ec8fa53..434df03dd 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -47,30 +47,32 @@ def cross_entropy( n, v = logits.shape losses = torch.zeros([n], dtype=logits.dtype, device=logits.device) - # Flatten logits once at the beginning - logits_flat = logits.view(-1) - for tile_n in hl.tile(n): # Get data for this tile labels_tile = labels[tile_n] # [tile_size] - base_indices_tile = tile_n.index * v # [tile_size] - - # Compute the actual flat indices by adding the label offset - flat_indices = base_indices_tile + labels_tile - - # Load the logits at the target indices - logits_at_target = hl.load(logits_flat, [flat_indices]) - # Compute log_softmax for numerical stability - # Load the full rows for this tile - logits_rows = logits[tile_n, :] # [tile_size, V] - - # Compute log-sum-exp - max_logits = torch.amax(logits_rows, dim=-1, keepdim=True) - shifted = logits_rows - max_logits - exp_shifted = torch.exp(shifted) - sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True) - log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1)) + logits_at_target = hl.zeros([tile_n], dtype=logits.dtype) + max_logits_acc = hl.full([tile_n], float("-inf"), dtype=logits.dtype) + + # First pass: find max and target logits + for v_chunk in hl.tile(v): + chunk_logits = logits[tile_n, v_chunk] + + # Extract target using a chunked mask + mask = (v_chunk.index[None, :] == labels_tile[:, None]).to(logits.dtype) + logits_at_target += torch.sum(chunk_logits * mask, dim=-1) + + # Update max + max_logits_acc = torch.maximum(max_logits_acc, torch.amax(chunk_logits, dim=-1)) + + # Second pass: sum exp + sum_exp_acc = hl.zeros([tile_n], dtype=logits.dtype) + for v_chunk in hl.tile(v): + chunk_logits = logits[tile_n, v_chunk] + shifted = chunk_logits - max_logits_acc[:, None] + sum_exp_acc += torch.sum(torch.exp(shifted), dim=-1) + + log_sum_exp = max_logits_acc + torch.log(sum_exp_acc) # Cross entropy loss: log_sum_exp - logit_at_target losses[tile_n] = log_sum_exp - logits_at_target @@ -91,11 +93,14 @@ def main() -> None: batch_size, seq_len, vocab_size = 8, 2048, 131072 n = batch_size * seq_len logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32) - labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.long) + labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.int32) + + def baseline_ce(logits, labels): + return torch.nn.functional.cross_entropy(logits, labels.long()) run_example( cross_entropy, - torch.nn.functional.cross_entropy, + baseline_ce, (logits, labels), kernel_name="helion", baseline_name="torch", From a978600cdfb9ddc4509de4d23c5c0a6003255373 Mon Sep 17 00:00:00 2001 From: Yarong Mu Date: Thu, 9 Apr 2026 17:46:10 -0700 Subject: [PATCH 5/5] style: apply ruff and pyrefly auto-formatting across project files --- examples/cross_entropy.py | 16 +++++++++------- examples/flex_attention.py | 8 +++----- examples/grouped_gemm.py | 4 +--- helion/_compiler/_dynamo/higher_order_ops.py | 4 +--- helion/_compiler/_inductor/template_buffer.py | 12 +++++------- 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index 434df03dd..0115780d5 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -53,25 +53,27 @@ def cross_entropy( logits_at_target = hl.zeros([tile_n], dtype=logits.dtype) max_logits_acc = hl.full([tile_n], float("-inf"), dtype=logits.dtype) - + # First pass: find max and target logits for v_chunk in hl.tile(v): chunk_logits = logits[tile_n, v_chunk] - + # Extract target using a chunked mask mask = (v_chunk.index[None, :] == labels_tile[:, None]).to(logits.dtype) logits_at_target += torch.sum(chunk_logits * mask, dim=-1) - + # Update max - max_logits_acc = torch.maximum(max_logits_acc, torch.amax(chunk_logits, dim=-1)) - + max_logits_acc = torch.maximum( + max_logits_acc, torch.amax(chunk_logits, dim=-1) + ) + # Second pass: sum exp sum_exp_acc = hl.zeros([tile_n], dtype=logits.dtype) for v_chunk in hl.tile(v): chunk_logits = logits[tile_n, v_chunk] shifted = chunk_logits - max_logits_acc[:, None] sum_exp_acc += torch.sum(torch.exp(shifted), dim=-1) - + log_sum_exp = max_logits_acc + torch.log(sum_exp_acc) # Cross entropy loss: log_sum_exp - logit_at_target @@ -95,7 +97,7 @@ def main() -> None: logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32) labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.int32) - def baseline_ce(logits, labels): + def baseline_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return torch.nn.functional.cross_entropy(logits, labels.long()) run_example( diff --git a/examples/flex_attention.py b/examples/flex_attention.py index 61cafc31f..2d41730eb 100644 --- a/examples/flex_attention.py +++ b/examples/flex_attention.py @@ -84,11 +84,9 @@ def helion_flex_attention_kernel( # iterate through full tiles (no mask needed) if block_mask_full_kv_indices is not None: - sparse_num_blocks = ( - block_mask_full_kv_num_blocks[ # pyrefly: ignore[unsupported-operation] - b_idx, h_idx, sparse_row - ] - ) + sparse_num_blocks = block_mask_full_kv_num_blocks[ # pyrefly: ignore[unsupported-operation] + b_idx, h_idx, sparse_row + ] for block_idx in hl.tile(sparse_num_blocks, block_size=1): start_n = block_mask_full_kv_indices[ diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py index 8f713dfa3..6febc399b 100644 --- a/examples/grouped_gemm.py +++ b/examples/grouped_gemm.py @@ -212,9 +212,7 @@ def grouped_gemm_jagged_persistent( b_blk = hl.load( B, [k_idx, col_idx], - extra_mask=cols_valid[ - None, : - ], # pyrefly: ignore[bad-index] + extra_mask=cols_valid[None, :], # pyrefly: ignore[bad-index] ) # Perform tile-level matrix multiplication and accumulate diff --git a/helion/_compiler/_dynamo/higher_order_ops.py b/helion/_compiler/_dynamo/higher_order_ops.py index 8abeb1050..a17bbcb98 100644 --- a/helion/_compiler/_dynamo/higher_order_ops.py +++ b/helion/_compiler/_dynamo/higher_order_ops.py @@ -279,9 +279,7 @@ def helion_kernel_wrapper_functional_dense( kernel_outputs = helion_kernel_wrapper_mutation( kernel_idx=kernel_idx, constant_args=constant_args, - tensor_args={ - k: cloned.get(k, v) for k, v in tensor_args.items() - }, # pyrefly: ignore[bad-argument-type] + tensor_args={k: cloned.get(k, v) for k, v in tensor_args.items()}, # pyrefly: ignore[bad-argument-type] output_spec=output_spec, ) return (kernel_outputs, cloned) diff --git a/helion/_compiler/_inductor/template_buffer.py b/helion/_compiler/_inductor/template_buffer.py index d2dd90fc9..af7ed51f2 100644 --- a/helion/_compiler/_inductor/template_buffer.py +++ b/helion/_compiler/_inductor/template_buffer.py @@ -478,13 +478,11 @@ def create( if not any(isinstance(leaf, torch.Tensor) for leaf in flat): return buf, () - result = ( - TemplateBuffer.build_multi_outputs( # pyrefly: ignore[missing-attribute] - buf, - structured_outputs, - direct_alias_at_leaf=direct_aliases, - on_tensor_leaf=on_tensor_leaf, - ) + result = TemplateBuffer.build_multi_outputs( # pyrefly: ignore[missing-attribute] + buf, + structured_outputs, + direct_alias_at_leaf=direct_aliases, + on_tensor_leaf=on_tensor_leaf, ) return buf, result