From 15a81f8fd285080a1811e763f1908db82e770354 Mon Sep 17 00:00:00 2001 From: Charles Hong Date: Tue, 9 Jun 2026 23:03:35 -0700 Subject: [PATCH] Fix Megablox GMM no-op exploit: use realistic non-underflowing inputs The 11p_Megablox_GMM baseline scaled inputs by limit=1/(M*K) ~= 7.45e-9 (uniform), so each output element was ~2e-13 and underflowed toward zero in bf16. The reference output was therefore ~0 everywhere, which let a no-op kernel (returning zeros, or skipping the grouped matmul entirely) trivially pass np.allclose(atol=1e-2) and report an enormous, meaningless speedup (observed in practice: a degenerate kernel timing ~0.1 ms / ~27x). This replaces the inputs with small-normal weights/activations (~0.02 scale, output max ~0.14 / mean ~0.02 -- bf16-representable, no K=4096 overflow) and a simulated top-k router with mild popularity bias (router_bias_scale=0.15) so group sizes are realistic and non-uniform. max_expert_size is a static jit arg (added to CONFIG so harnesses that read static_argnums lower the reference correctly). With these inputs a no-op kernel now fails correctness. Scope: this fixes Megablox's input-underflow mechanism specifically. A separate and distinct no-op class -- KernelBench-derived baselines (18k-50k) that initialize weights/bias to jnp.zeros, yielding an all-zero reference -- is not addressed here and is left to a follow-up. Co-Authored-By: Claude Opus 4.8 --- .../benchmark/11p_Megablox_GMM/baseline.py | 61 ++++++++++++++++--- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/JAXBench/benchmark/11p_Megablox_GMM/baseline.py b/JAXBench/benchmark/11p_Megablox_GMM/baseline.py index f0649b3..56e8b57 100644 --- a/JAXBench/benchmark/11p_Megablox_GMM/baseline.py +++ b/JAXBench/benchmark/11p_Megablox_GMM/baseline.py @@ -4,6 +4,7 @@ and multiply with that expert's weight matrix. Core primitive for MoE layers. From JAX experimental pallas ops (reference_gmm). +Jit-compatible: uses static-shape slicing with masking on group_sizes. """ import jax @@ -18,25 +19,69 @@ 'emb_dim': 4096, 'moe_mlp_dim': 1536, 'seq_len': 4096, + # Realistic top-k router with mild expert popularity bias. + # bias_scale=0.15 produces max/mean ~1.9, CV ~0.30, no empty groups. + 'router_bias_scale': 0.15, + # max_expert_size is a static jit arg (compile-time upper bound on the + # largest group). 1024 = 4 * M/G is generous: covers our simulated router + # (TPU-observed max ~520 at bias_scale=0.15) with ~2x margin, and would + # still hold up to bias_scale ~0.4 if we ever wanted a more imbalanced + # workload. Matches how production capacity-factor fast-paths over-size + # the per-expert tile. + 'max_expert_size': 1024, + 'static_argnums': (3,), # max_expert_size must be static for jit } +def _simulate_router(key, S, G, top_k, bias_scale): + """Simulate top-k token-choice routing with expert popularity bias. + + Returns (assignments, group_sizes) where assignments has shape (S*top_k,) + and is sorted by expert id, suitable as an Megablox GMM input contract + (lhs rows pre-grouped by expert). + """ + k_bias, k_noise = jax.random.split(key, 2) + expert_bias = jax.random.normal(k_bias, (G,)) * bias_scale + per_token_noise = jax.random.normal(k_noise, (S, G)) + router_logits = expert_bias[None, :] + per_token_noise + _, topk_idx = jax.lax.top_k(router_logits, top_k) # (S, top_k) + assignments = topk_idx.reshape(-1) # (S*top_k,) + sort_perm = jnp.argsort(assignments, stable=True) + sorted_assignments = assignments[sort_perm] + group_sizes = jnp.bincount(sorted_assignments, length=G).astype(jnp.int32) + return sort_perm, group_sizes + + def create_inputs(dtype=jnp.bfloat16): key = jax.random.key(42) - k1, k2 = jax.random.split(key, 2) + k_router, k_lhs, k_rhs = jax.random.split(key, 3) G = CONFIG['num_experts'] top_k = CONFIG['num_experts_per_tok'] K = CONFIG['emb_dim'] N = CONFIG['moe_mlp_dim'] S = CONFIG['seq_len'] M = S * top_k - limit = 1 / (M * K) - lhs = jax.random.uniform(k1, (M, K), dtype=dtype, minval=-limit, maxval=limit) - lhs = lhs.astype(jnp.bfloat16).astype(dtype) - rhs = jax.random.uniform(k2, (G, K, N), dtype=dtype, minval=-limit, maxval=limit) - rhs = rhs.astype(jnp.bfloat16).astype(dtype) - max_expert_size = M // G - group_sizes = jnp.full((G,), max_expert_size, dtype=jnp.int32) + max_expert_size = CONFIG['max_expert_size'] + # Small-normal weights/activations (~0.02 scale): large enough that + # matmul outputs are bf16-representable but small enough to avoid + # overflow when accumulated across K=4096. Previous version used + # `1/(M*K)` as a uniform limit, which underflowed to zero in bf16 and + # let no-op kernels trivially pass np.allclose against an all-zero + # reference. + lhs_unsorted = jax.random.normal(k_lhs, (M, K), dtype=dtype) * 0.02 + rhs = jax.random.normal(k_rhs, (G, K, N), dtype=dtype) * 0.02 + sort_perm, group_sizes = _simulate_router( + k_router, S, G, top_k, CONFIG['router_bias_scale'] + ) + # Sort lhs rows by expert id so contiguous rows belong to the same expert + # (Megablox's input contract). + lhs = lhs_unsorted[sort_perm] + # Runtime check: assert max group size fits the static upper bound. + assert int(group_sizes.max()) <= max_expert_size, ( + f"Simulated max group size {int(group_sizes.max())} exceeds " + f"static max_expert_size={max_expert_size}; raise CONFIG['max_expert_size'] " + f"or reduce router_bias_scale." + ) return lhs, rhs, group_sizes, max_expert_size