Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions JAXBench/benchmark/11p_Megablox_GMM/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
and multiply with that expert's weight matrix. Core primitive for MoE layers.
From JAX experimental pallas ops (reference_gmm).

Jit-compatible: uses static-shape slicing with masking on group_sizes.
"""

import jax
Expand All @@ -18,25 +19,69 @@
'emb_dim': 4096,
'moe_mlp_dim': 1536,
'seq_len': 4096,
# Realistic top-k router with mild expert popularity bias.
# bias_scale=0.15 produces max/mean ~1.9, CV ~0.30, no empty groups.
'router_bias_scale': 0.15,
# max_expert_size is a static jit arg (compile-time upper bound on the
# largest group). 1024 = 4 * M/G is generous: covers our simulated router
# (TPU-observed max ~520 at bias_scale=0.15) with ~2x margin, and would
# still hold up to bias_scale ~0.4 if we ever wanted a more imbalanced
# workload. Matches how production capacity-factor fast-paths over-size
# the per-expert tile.
'max_expert_size': 1024,
'static_argnums': (3,), # max_expert_size must be static for jit
}


def _simulate_router(key, S, G, top_k, bias_scale):
    """Simulate top-k token-choice routing with expert popularity bias.

    Router logits are the sum of a per-expert bias (standard normal scaled by
    ``bias_scale``) and independent standard-normal noise for every
    (token, expert) pair. Each token selects its ``top_k`` highest-scoring
    experts.

    Args:
        key: JAX PRNG key; split internally into a bias key and a noise key.
        S: Number of tokens.
        G: Number of experts.
        top_k: Number of experts selected per token.
        bias_scale: Multiplier applied to the per-expert bias draw.

    Returns:
        A ``(sort_perm, group_sizes)`` tuple.

        ``sort_perm`` has shape ``(S * top_k,)`` and is the stable argsort of
        the flattened expert assignments: gathering rows with it places rows
        of the same expert contiguously, in increasing expert id, which is
        the row layout a grouped matmul expects for its lhs.

        ``group_sizes`` has shape ``(G,)`` and dtype int32; entry ``g`` is the
        number of assignments routed to expert ``g``.
    """
    k_bias, k_noise = jax.random.split(key, 2)
    expert_bias = jax.random.normal(k_bias, (G,)) * bias_scale
    per_token_noise = jax.random.normal(k_noise, (S, G))
    router_logits = expert_bias[None, :] + per_token_noise
    _, topk_idx = jax.lax.top_k(router_logits, top_k)  # (S, top_k)
    assignments = topk_idx.reshape(-1)  # (S*top_k,)
    sort_perm = jnp.argsort(assignments, stable=True)
    # Counting is order-independent, so there is no need to gather the
    # assignments into sorted order before the bincount.
    group_sizes = jnp.bincount(assignments, length=G).astype(jnp.int32)
    return sort_perm, group_sizes


def create_inputs(dtype=jnp.bfloat16):
    """Build the grouped-matmul benchmark inputs from ``CONFIG``.

    Activations and weights are drawn from a normal distribution scaled by
    0.02, and the activation rows are permuted so that rows routed to the
    same expert are contiguous, using a simulated top-k router.

    Args:
        dtype: Element type used for ``lhs`` and ``rhs``.

    Returns:
        A ``(lhs, rhs, group_sizes, max_expert_size)`` tuple where

        * ``lhs`` has shape ``(M, K)`` with ``M = seq_len * num_experts_per_tok``
          and ``K = emb_dim``; rows are grouped by expert id.
        * ``rhs`` has shape ``(G, K, N)`` with ``G = num_experts`` and
          ``N = moe_mlp_dim``.
        * ``group_sizes`` has shape ``(G,)``, dtype int32, and gives the
          number of ``lhs`` rows belonging to each expert.
        * ``max_expert_size`` is ``CONFIG['max_expert_size']``, the static
          upper bound on any single group size.

    Raises:
        AssertionError: If the simulated router produces a group larger than
            ``CONFIG['max_expert_size']``.
    """
    key = jax.random.key(42)
    k_router, k_lhs, k_rhs = jax.random.split(key, 3)
    G = CONFIG['num_experts']
    top_k = CONFIG['num_experts_per_tok']
    K = CONFIG['emb_dim']
    N = CONFIG['moe_mlp_dim']
    S = CONFIG['seq_len']
    M = S * top_k
    max_expert_size = CONFIG['max_expert_size']
    # Small-normal weights/activations (~0.02 scale): large enough that
    # matmul outputs are bf16-representable but small enough to avoid
    # overflow when accumulated across K=4096. Previous version used
    # `1/(M*K)` as a uniform limit, which underflowed to zero in bf16 and
    # let no-op kernels trivially pass np.allclose against an all-zero
    # reference.
    lhs_unsorted = jax.random.normal(k_lhs, (M, K), dtype=dtype) * 0.02
    rhs = jax.random.normal(k_rhs, (G, K, N), dtype=dtype) * 0.02
    sort_perm, group_sizes = _simulate_router(
        k_router, S, G, top_k, CONFIG['router_bias_scale']
    )
    # Sort lhs rows by expert id so contiguous rows belong to the same expert
    # (Megablox's input contract).
    lhs = lhs_unsorted[sort_perm]
    # Runtime check: the largest group must fit the static upper bound.
    # Pull the max to the host once instead of once per use.
    largest_group = int(group_sizes.max())
    assert largest_group <= max_expert_size, (
        f"Simulated max group size {largest_group} exceeds "
        f"static max_expert_size={max_expert_size}; raise CONFIG['max_expert_size'] "
        f"or reduce router_bias_scale."
    )
    return lhs, rhs, group_sizes, max_expert_size


Expand Down