Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ jobs:
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--pre \
'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict'
'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict'
# Install Bazel
if ! command -v bazel &> /dev/null; then
sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel
Expand Down
45 changes: 26 additions & 19 deletions examples/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,34 @@ def cross_entropy(
n, v = logits.shape
losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)

# Flatten logits once at the beginning
logits_flat = logits.view(-1)

Comment on lines -50 to -52
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing the kernel? Shouldn't we make the existing kernel work?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making the existing kernel work means we have to use a fori_loop.

I avoided it with a dynamic gather (e.g., hl.load([flat_indices])) because that forces the kernel out of the emit_pipeline execution model and performs an indirect, data-dependent read from HBM based on the label indices. TPUs are heavily optimized for large, contiguous memory bursts rather than sparse, 4-byte random accesses. By loading the full, contiguous rows of logits and using a boolean mask (chunk_logits * mask) to extract the target logit, we deliberately trade extremely cheap ALU operations for predictable, dense memory access. This allows the compiler to keep the kernel in emit_pipeline mode, maximizing HBM bandwidth utilization and overlapping our memory loads with the masking compute.

Cleaned up PR in #2019

for tile_n in hl.tile(n):
# Get data for this tile
labels_tile = labels[tile_n] # [tile_size]
base_indices_tile = tile_n.index * v # [tile_size]

# Compute the actual flat indices by adding the label offset
flat_indices = base_indices_tile + labels_tile
logits_at_target = hl.zeros([tile_n], dtype=logits.dtype)
max_logits_acc = hl.full([tile_n], float("-inf"), dtype=logits.dtype)

# First pass: find max and target logits
for v_chunk in hl.tile(v):
chunk_logits = logits[tile_n, v_chunk]

# Load the logits at the target indices
logits_at_target = hl.load(logits_flat, [flat_indices])
# Extract target using a chunked mask
mask = (v_chunk.index[None, :] == labels_tile[:, None]).to(logits.dtype)
logits_at_target += torch.sum(chunk_logits * mask, dim=-1)

# Compute log_softmax for numerical stability
# Load the full rows for this tile
logits_rows = logits[tile_n, :] # [tile_size, V]
# Update max
max_logits_acc = torch.maximum(
max_logits_acc, torch.amax(chunk_logits, dim=-1)
)

# Compute log-sum-exp
max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
shifted = logits_rows - max_logits
exp_shifted = torch.exp(shifted)
sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
log_sum_exp = max_logits.squeeze(-1) + torch.log(sum_exp.squeeze(-1))
# Second pass: sum exp
sum_exp_acc = hl.zeros([tile_n], dtype=logits.dtype)
for v_chunk in hl.tile(v):
chunk_logits = logits[tile_n, v_chunk]
shifted = chunk_logits - max_logits_acc[:, None]
sum_exp_acc += torch.sum(torch.exp(shifted), dim=-1)

log_sum_exp = max_logits_acc + torch.log(sum_exp_acc)

# Cross entropy loss: log_sum_exp - logit_at_target
losses[tile_n] = log_sum_exp - logits_at_target
Expand All @@ -91,11 +95,14 @@ def main() -> None:
batch_size, seq_len, vocab_size = 8, 2048, 131072
n = batch_size * seq_len
logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32)
labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.long)
labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.int32)

def baseline_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.cross_entropy(logits, labels.long())

run_example(
cross_entropy,
torch.nn.functional.cross_entropy,
baseline_ce,
(logits, labels),
kernel_name="helion",
baseline_name="torch",
Expand Down
8 changes: 3 additions & 5 deletions examples/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,9 @@ def helion_flex_attention_kernel(

# iterate through full tiles (no mask needed)
if block_mask_full_kv_indices is not None:
sparse_num_blocks = (
block_mask_full_kv_num_blocks[ # pyrefly: ignore[unsupported-operation]
b_idx, h_idx, sparse_row
]
)
sparse_num_blocks = block_mask_full_kv_num_blocks[ # pyrefly: ignore[unsupported-operation]
b_idx, h_idx, sparse_row
]

for block_idx in hl.tile(sparse_num_blocks, block_size=1):
start_n = block_mask_full_kv_indices[
Expand Down
4 changes: 1 addition & 3 deletions examples/grouped_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,7 @@ def grouped_gemm_jagged_persistent(
b_blk = hl.load(
B,
[k_idx, col_idx],
extra_mask=cols_valid[
None, :
], # pyrefly: ignore[bad-index]
extra_mask=cols_valid[None, :], # pyrefly: ignore[bad-index]
)

# Perform tile-level matrix multiplication and accumulate
Expand Down
4 changes: 1 addition & 3 deletions helion/_compiler/_dynamo/higher_order_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,7 @@ def helion_kernel_wrapper_functional_dense(
kernel_outputs = helion_kernel_wrapper_mutation(
kernel_idx=kernel_idx,
constant_args=constant_args,
tensor_args={
k: cloned.get(k, v) for k, v in tensor_args.items()
}, # pyrefly: ignore[bad-argument-type]
tensor_args={k: cloned.get(k, v) for k, v in tensor_args.items()}, # pyrefly: ignore[bad-argument-type]
output_spec=output_spec,
)
return (kernel_outputs, cloned)
Expand Down
12 changes: 5 additions & 7 deletions helion/_compiler/_inductor/template_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,11 @@ def create(
if not any(isinstance(leaf, torch.Tensor) for leaf in flat):
return buf, ()

result = (
TemplateBuffer.build_multi_outputs( # pyrefly: ignore[missing-attribute]
buf,
structured_outputs,
direct_alias_at_leaf=direct_aliases,
on_tensor_leaf=on_tensor_leaf,
)
result = TemplateBuffer.build_multi_outputs( # pyrefly: ignore[missing-attribute]
buf,
structured_outputs,
direct_alias_at_leaf=direct_aliases,
on_tensor_leaf=on_tensor_leaf,
)
return buf, result

Expand Down
12 changes: 8 additions & 4 deletions helion/_compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,15 +881,18 @@ def tunable_fragments(self) -> dict[str, ConfigSpecFragment]:
_TORCH_TO_JAX_DTYPE: dict[str, str] = {
"torch.float16": "jnp.float16",
"torch.float32": "jnp.float32",
"torch.float64": "jnp.float64",
"torch.float64": "jnp.float32",
"torch.bfloat16": "jnp.bfloat16",
"torch.float8_e4m3fn": "jnp.float8_e4m3fn",
"torch.float8_e5m2": "jnp.float8_e5m2",
"torch.int8": "jnp.int8",
"torch.int16": "jnp.int16",
"torch.int32": "jnp.int32",
"torch.int64": "jnp.int64",
"torch.int64": "jnp.int32",
"torch.long": "jnp.int32",
"torch.uint8": "jnp.uint8",
"torch.uint32": "jnp.uint32",
"torch.uint64": "jnp.uint64",
"torch.uint64": "jnp.uint32",
"torch.bool": "jnp.bool_",
"torch.complex64": "jnp.complex64",
"torch.complex128": "jnp.complex128",
Expand Down Expand Up @@ -1177,6 +1180,7 @@ def _get_pallas_required_alignment(
tensor_ndim (int): Amount of dimensions for the tensor.
bitwidth (int): Bitwidth of tensor elements
"""
bitwidth = min(bitwidth, 32)
if dim_from_end == 0: # Last dimension
if tensor_ndim <= 1:
return 128 * (32 // bitwidth)
Expand Down Expand Up @@ -1388,7 +1392,7 @@ def _compute_block_spec_info(
# If not, fall-back to no tiling for the entire kernel
dim_size = tensor.shape[d]
dim_from_end = tensor.ndim - 1 - d
bitwidth = tensor.dtype.itemsize * 8
bitwidth = min(tensor.dtype.itemsize * 8, 32)
required_alignment = self._get_pallas_required_alignment(
dim_from_end, tensor.ndim, bitwidth
)
Expand Down
5 changes: 5 additions & 0 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ def _print_FloorDiv(self, expr: sympy.Expr) -> str:
# pyrefly: ignore [missing-attribute]
return f"({self._print(lhs)} // {self._print(rhs)})"

def _print_PythonMod(self, expr: sympy.Expr) -> str:
lhs, rhs = expr.args
# pyrefly: ignore [missing-attribute]
return f"({self._print(lhs)} % {self._print(rhs)})"


def pallas_texpr(expr: sympy.Expr) -> str:
return HelionPallasPrinter().doprint(expr)
2 changes: 1 addition & 1 deletion helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _pallas_index_str(
dim_map.setdefault(tensor_dim, block_id)
elif isinstance(idx, int):
parts.append(str(idx))
elif isinstance(idx, torch.SymInt):
elif isinstance(idx, (torch.SymInt, torch.Tensor)):
ast_subscripts = state.ast_args[1]
assert isinstance(ast_subscripts, list)
ast_idx = ast_subscripts[i]
Expand Down
58 changes: 58 additions & 0 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,28 @@
out_ref[...] = in_ref[...] # type: ignore[index]
original_order[orig_pos] = out_ref
extra_refs = refs[n_tensor_inputs + len(_output_indices) :]

# --- DEBUG TRACE ---
print("\n[DEBUG TRACE] Executing Pallas Kernel")
print("Original Order (Tensors):")
Comment on lines +460 to +461
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Copy Markdown
Collaborator Author

@yarongmu-google yarongmu-google Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this PR was messed up when syncing with upstream and contains extra files. I will abandon this and create a clean one based on the current upstream main: #2018

for i, ref in enumerate(original_order):
if hasattr(ref, "shape"):
print(
f" Arg {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}"
)
else:
print(f" Arg {i}: Value={ref}")
print("Extra Refs (Scratches/VMEM):")
for i, ref in enumerate(extra_refs):
if hasattr(ref, "shape"):
print(
f" Scratch {i}: Shape={ref.shape}, Dtype={getattr(ref, 'dtype', 'unknown')}"
)
else:
print(f" Scratch {i}: Value={ref} (Likely Semaphore)")
print("---------------------\n")
# -------------------

pallas_kernel(*original_order, *extra_refs) # type: ignore[operator]

return reordered_kernel
Expand Down Expand Up @@ -727,11 +749,15 @@
_jnp_dtype_map: dict[str, object] = {
"jnp.float32": jnp.float32,
"jnp.float16": jnp.float16,
"jnp.float64": jnp.float64,
"jnp.bfloat16": jnp.bfloat16,
"jnp.int32": jnp.int32,
"jnp.int16": jnp.int16,
"jnp.int8": jnp.int8,
"jnp.int64": jnp.int64,
"jnp.uint8": jnp.uint8,
"jnp.uint16": jnp.uint16,
"jnp.uint32": jnp.uint32,
"jnp.bool_": jnp.bool_,
}
scratch_shapes = []
Expand Down Expand Up @@ -801,6 +827,19 @@
if _pallas_interpret_flag():
pallas_call_kwargs["interpret"] = True

# --- DEBUG PALLAS CALL ---
print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}")
for k, v in pallas_call_kwargs.items():
if k == "grid_spec":
print(f" {k}:")
print(f" in_specs: {v.in_specs}")

Check failure on line 835 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `in_specs`
print(f" out_specs: {v.out_specs}")

Check failure on line 836 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `out_specs`
print(f" scratch_shapes: {v.scratch_shapes}")

Check failure on line 837 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `scratch_shapes`
else:
print(f" {k}: {v}")
print("--------------------------\n")
# -------------------------

jit_fn = pl.pallas_call(
reordered_kernel, # pyrefly: ignore[bad-argument-type]
**pallas_call_kwargs, # type: ignore[arg-type]
Expand Down Expand Up @@ -868,11 +907,17 @@
_jnp_dtype_map: dict[str, object] = {
"jnp.float32": jnp.float32,
"jnp.float16": jnp.float16,
"jnp.float64": jnp.float64,
"jnp.bfloat16": jnp.bfloat16,
"jnp.float8_e4m3fn": getattr(jnp, "float8_e4m3fn", jnp.float32),
"jnp.float8_e5m2": getattr(jnp, "float8_e5m2", jnp.float32),
"jnp.int32": jnp.int32,
"jnp.int16": jnp.int16,
"jnp.int8": jnp.int8,
"jnp.int64": jnp.int64,
"jnp.uint8": jnp.uint8,
"jnp.uint16": jnp.uint16,
"jnp.uint32": jnp.uint32,
"jnp.bool_": jnp.bool_,
}
scratch_shapes = []
Expand Down Expand Up @@ -941,6 +986,19 @@
if _pallas_interpret_flag():
pallas_call_kwargs["interpret"] = True

# --- DEBUG PALLAS CALL ---
print(f"\n[DEBUG PALLAS CALL] {'Pipeline/Fori Launcher'}")
for k, v in pallas_call_kwargs.items():
if k == "grid_spec":
print(f" {k}:")
print(f" in_specs: {v.in_specs}")

Check failure on line 994 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `in_specs`
print(f" out_specs: {v.out_specs}")

Check failure on line 995 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `out_specs`
print(f" scratch_shapes: {v.scratch_shapes}")

Check failure on line 996 in helion/runtime/__init__.py

View workflow job for this annotation

GitHub Actions / lint (3.12)

Pyrefly missing-attribute

Object of class `object` has no attribute `scratch_shapes`
else:
print(f" {k}: {v}")
print("--------------------------\n")
# -------------------------

jit_fn = pl.pallas_call(
reordered_kernel, # pyrefly: ignore[bad-argument-type]
**pallas_call_kwargs, # type: ignore[arg-type]
Expand Down
Loading