Review notes (patch cleaned up during review):
- Removed the leftover "[DEBUG TRACE]" / "[DEBUG PALLAS CALL]" print
  instrumentation from helion/runtime/__init__.py. The three hunks that
  contained only that debug code were dropped, and the offsets of the two
  surviving hunks in that file were recomputed (+749 -> +727, +907 -> +872).
- All other changes are preserved as authored. The 64-bit -> 32-bit dtype
  clamping in the TPU lowering maps (float64->float32, int64/uint64->int32,
  bitwidth = min(bitwidth, 32)) is deliberate: JAX on TPU disables x64 by
  default.
- NOTE(review): context-line indentation below was reconstructed from a
  whitespace-mangled copy of this patch; if `git apply` reports context
  mismatches, regenerate the diff against the source tree.

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3884781da..b8e5111aa 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -176,7 +176,7 @@ jobs:
           --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
           --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
           --pre \
-          'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict'
+          'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict'
         # Install Bazel
         if ! command -v bazel &> /dev/null; then
           sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel
diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py
index a0ec8fa53..0115780d5 100644
--- a/examples/cross_entropy.py
+++ b/examples/cross_entropy.py
@@ -47,30 +47,34 @@ def cross_entropy(
     n, v = logits.shape
     losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
 
-    # Flatten logits once at the beginning
-    logits_flat = logits.view(-1)
-
     for tile_n in hl.tile(n):
         # Get data for this tile
         labels_tile = labels[tile_n]  # [tile_size]
-        base_indices_tile = tile_n.index * v  # [tile_size]
-        # Compute the actual flat indices by adding the label offset
-        flat_indices = base_indices_tile + labels_tile
+        logits_at_target = hl.zeros([tile_n], dtype=logits.dtype)
+        max_logits_acc = hl.full([tile_n], float("-inf"), dtype=logits.dtype)
+
+        # First pass: find max and target logits
+        for v_chunk in hl.tile(v):
+            chunk_logits = logits[tile_n, v_chunk]
 
-        # Load the logits at the target indices
-        logits_at_target = hl.load(logits_flat, [flat_indices])
+            # Extract target using a chunked mask
+            mask = (v_chunk.index[None, :] == labels_tile[:, None]).to(logits.dtype)
+            logits_at_target += torch.sum(chunk_logits * mask, dim=-1)
 
-        # Compute log_softmax for numerical stability
-        # Load the full rows for this tile
-        logits_rows = logits[tile_n, :]  # [tile_size, V]
+            # Update max
+            max_logits_acc = torch.maximum(
+                max_logits_acc, torch.amax(chunk_logits, dim=-1)
+            )
 
-        # Compute log-sum-exp
-        max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
-        shifted = logits_rows - max_logits
-        exp_shifted = torch.exp(shifted)
-        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
-        log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))
+        # Second pass: sum exp
+        sum_exp_acc = hl.zeros([tile_n], dtype=logits.dtype)
+        for v_chunk in hl.tile(v):
+            chunk_logits = logits[tile_n, v_chunk]
+            shifted = chunk_logits - max_logits_acc[:, None]
+            sum_exp_acc += torch.sum(torch.exp(shifted), dim=-1)
+
+        log_sum_exp = max_logits_acc + torch.log(sum_exp_acc)
 
         # Cross entropy loss: log_sum_exp - logit_at_target
         losses[tile_n] = log_sum_exp - logits_at_target
@@ -91,11 +95,14 @@ def main() -> None:
     batch_size, seq_len, vocab_size = 8, 2048, 131072
     n = batch_size * seq_len
     logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32)
-    labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.long)
+    labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.int32)
+
+    def baseline_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.cross_entropy(logits, labels.long())
 
     run_example(
         cross_entropy,
-        torch.nn.functional.cross_entropy,
+        baseline_ce,
         (logits, labels),
         kernel_name="helion",
         baseline_name="torch",
diff --git a/examples/flex_attention.py b/examples/flex_attention.py
index 61cafc31f..2d41730eb 100644
--- a/examples/flex_attention.py
+++ b/examples/flex_attention.py
@@ -84,11 +84,9 @@ def helion_flex_attention_kernel(
 
         # iterate through full tiles (no mask needed)
         if block_mask_full_kv_indices is not None:
-            sparse_num_blocks = (
-                block_mask_full_kv_num_blocks[  # pyrefly: ignore[unsupported-operation]
-                    b_idx, h_idx, sparse_row
-                ]
-            )
+            sparse_num_blocks = block_mask_full_kv_num_blocks[  # pyrefly: ignore[unsupported-operation]
+                b_idx, h_idx, sparse_row
+            ]
 
             for block_idx in hl.tile(sparse_num_blocks, block_size=1):
                 start_n = block_mask_full_kv_indices[
diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index 8f713dfa3..6febc399b 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -212,9 +212,7 @@ def grouped_gemm_jagged_persistent(
                 b_blk = hl.load(
                     B,
                     [k_idx, col_idx],
-                    extra_mask=cols_valid[
-                        None, :
-                    ],  # pyrefly: ignore[bad-index]
+                    extra_mask=cols_valid[None, :],  # pyrefly: ignore[bad-index]
                 )
 
                 # Perform tile-level matrix multiplication and accumulate
diff --git a/helion/_compiler/_dynamo/higher_order_ops.py b/helion/_compiler/_dynamo/higher_order_ops.py
index 8abeb1050..a17bbcb98 100644
--- a/helion/_compiler/_dynamo/higher_order_ops.py
+++ b/helion/_compiler/_dynamo/higher_order_ops.py
@@ -279,9 +279,7 @@ def helion_kernel_wrapper_functional_dense(
     kernel_outputs = helion_kernel_wrapper_mutation(
         kernel_idx=kernel_idx,
         constant_args=constant_args,
-        tensor_args={
-            k: cloned.get(k, v) for k, v in tensor_args.items()
-        },  # pyrefly: ignore[bad-argument-type]
+        tensor_args={k: cloned.get(k, v) for k, v in tensor_args.items()},  # pyrefly: ignore[bad-argument-type]
         output_spec=output_spec,
     )
     return (kernel_outputs, cloned)
diff --git a/helion/_compiler/_inductor/template_buffer.py b/helion/_compiler/_inductor/template_buffer.py
index d2dd90fc9..af7ed51f2 100644
--- a/helion/_compiler/_inductor/template_buffer.py
+++ b/helion/_compiler/_inductor/template_buffer.py
@@ -478,13 +478,11 @@ def create(
         if not any(isinstance(leaf, torch.Tensor) for leaf in flat):
             return buf, ()
 
-        result = (
-            TemplateBuffer.build_multi_outputs(  # pyrefly: ignore[missing-attribute]
-                buf,
-                structured_outputs,
-                direct_alias_at_leaf=direct_aliases,
-                on_tensor_leaf=on_tensor_leaf,
-            )
+        result = TemplateBuffer.build_multi_outputs(  # pyrefly: ignore[missing-attribute]
+            buf,
+            structured_outputs,
+            direct_alias_at_leaf=direct_aliases,
+            on_tensor_leaf=on_tensor_leaf,
         )
         return buf, result
 
diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py
index 254be96a4..97b294640 100644
--- a/helion/_compiler/backend.py
+++ b/helion/_compiler/backend.py
@@ -881,15 +881,18 @@ def tunable_fragments(self) -> dict[str, ConfigSpecFragment]:
 _TORCH_TO_JAX_DTYPE: dict[str, str] = {
     "torch.float16": "jnp.float16",
     "torch.float32": "jnp.float32",
-    "torch.float64": "jnp.float64",
+    "torch.float64": "jnp.float32",
     "torch.bfloat16": "jnp.bfloat16",
+    "torch.float8_e4m3fn": "jnp.float8_e4m3fn",
+    "torch.float8_e5m2": "jnp.float8_e5m2",
     "torch.int8": "jnp.int8",
     "torch.int16": "jnp.int16",
     "torch.int32": "jnp.int32",
-    "torch.int64": "jnp.int64",
+    "torch.int64": "jnp.int32",
+    "torch.long": "jnp.int32",
     "torch.uint8": "jnp.uint8",
     "torch.uint32": "jnp.uint32",
-    "torch.uint64": "jnp.uint64",
+    "torch.uint64": "jnp.uint32",
     "torch.bool": "jnp.bool_",
     "torch.complex64": "jnp.complex64",
     "torch.complex128": "jnp.complex128",
@@ -1177,6 +1180,7 @@ def _get_pallas_required_alignment(
             tensor_ndim (int): Amount of dimensions for the tensor.
            bitwidth (int): Bitwidth of tensor elements
         """
+        bitwidth = min(bitwidth, 32)
         if dim_from_end == 0:  # Last dimension
             if tensor_ndim <= 1:
                 return 128 * (32 // bitwidth)
@@ -1388,7 +1392,7 @@ def _compute_block_spec_info(
             # If not, fall-back to no tiling for the entire kernel
             dim_size = tensor.shape[d]
             dim_from_end = tensor.ndim - 1 - d
-            bitwidth = tensor.dtype.itemsize * 8
+            bitwidth = min(tensor.dtype.itemsize * 8, 32)
             required_alignment = self._get_pallas_required_alignment(
                 dim_from_end, tensor.ndim, bitwidth
             )
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 827481454..83561a415 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -1024,6 +1024,11 @@ def _print_FloorDiv(self, expr: sympy.Expr) -> str:
         # pyrefly: ignore [missing-attribute]
         return f"({self._print(lhs)} // {self._print(rhs)})"
 
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
+        lhs, rhs = expr.args
+        # pyrefly: ignore [missing-attribute]
+        return f"({self._print(lhs)} % {self._print(rhs)})"
+
 
 def pallas_texpr(expr: sympy.Expr) -> str:
     return HelionPallasPrinter().doprint(expr)
diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py
index e3e8eccd5..55d6508a5 100644
--- a/helion/language/memory_ops.py
+++ b/helion/language/memory_ops.py
@@ -230,7 +230,7 @@ def _pallas_index_str(
             dim_map.setdefault(tensor_dim, block_id)
         elif isinstance(idx, int):
             parts.append(str(idx))
-        elif isinstance(idx, torch.SymInt):
+        elif isinstance(idx, (torch.SymInt, torch.Tensor)):
             ast_subscripts = state.ast_args[1]
             assert isinstance(ast_subscripts, list)
             ast_idx = ast_subscripts[i]
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 7bda0d78c..e74f47225 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -727,11 +727,15 @@ def default_pallas_pipeline_launcher(
     _jnp_dtype_map: dict[str, object] = {
         "jnp.float32": jnp.float32,
         "jnp.float16": jnp.float16,
+        "jnp.float64": jnp.float64,
         "jnp.bfloat16": jnp.bfloat16,
         "jnp.int32": jnp.int32,
         "jnp.int16": jnp.int16,
         "jnp.int8": jnp.int8,
+        "jnp.int64": jnp.int64,
         "jnp.uint8": jnp.uint8,
+        "jnp.uint16": jnp.uint16,
+        "jnp.uint32": jnp.uint32,
         "jnp.bool_": jnp.bool_,
     }
     scratch_shapes = []
@@ -868,11 +872,17 @@ def default_pallas_fori_launcher(
     _jnp_dtype_map: dict[str, object] = {
         "jnp.float32": jnp.float32,
         "jnp.float16": jnp.float16,
+        "jnp.float64": jnp.float64,
         "jnp.bfloat16": jnp.bfloat16,
+        "jnp.float8_e4m3fn": getattr(jnp, "float8_e4m3fn", jnp.float32),
+        "jnp.float8_e5m2": getattr(jnp, "float8_e5m2", jnp.float32),
         "jnp.int32": jnp.int32,
         "jnp.int16": jnp.int16,
         "jnp.int8": jnp.int8,
+        "jnp.int64": jnp.int64,
         "jnp.uint8": jnp.uint8,
+        "jnp.uint16": jnp.uint16,
+        "jnp.uint32": jnp.uint32,
         "jnp.bool_": jnp.bool_,
     }
     scratch_shapes = []