From c1f62e9a251010f6b85da6834439a3adcda3a076 Mon Sep 17 00:00:00 2001
From: Charles Hong
Date: Wed, 10 Jun 2026 00:04:21 -0700
Subject: [PATCH] Fix no-op exploit in 19 KernelBench baselines: zero-init weights -> random

18k-50k KernelBench-derived baselines initialized their weight/bias
tensors to jnp.zeros in create_inputs. With zero weights, `x @ W (+ b)`
is identically zero and independent of the input, so the reference
output is a trivial constant (all-zero, or a fixed activation thereof).
Any kernel returning that constant -- including a no-op that skips the
matmul/conv entirely -- passes np.allclose, so these benchmarks could
report large meaningless speedups without computing the operator.

This replaces the zero-init weights/biases with small-normal random
values (~0.02 scale: input-dependent, bf16-representable, no overflow).
Only create_inputs is changed; the workload/op is untouched. After the
fix, a no-op (all-zero output) fails correctness on all 19.

Scope: 19 of the affected baselines are fully fixed by non-zero weights.
Five others whose *output* is intrinsically small regardless of weights
-- the softmax-terminated 38k/43k/50k (row outputs ~1/N) and the
structurally degenerate 25k (GroupNorm->Mean) and 42k
(Max-Subtract-GELU) -- are NOT addressed here; they need a tolerance or
operator change and are left to a follow-up. Megablox (11p) has a
distinct input-underflow variant fixed separately.

Co-Authored-By: Claude Opus 4.8 --- JAXBench/benchmark/18k_Conv2D_ReLU_BiasAdd/baseline.py | 8 +++++--- .../19k_Matmul_Subtract_Multiply_ReLU/baseline.py | 6 ++++-- .../benchmark/20k_Gemm_Multiply_LeakyReLU/baseline.py | 6 ++++-- .../benchmark/22k_Conv2d_InstanceNorm_Divide/baseline.py | 6 ++++-- .../baseline.py | 6 ++++-- JAXBench/benchmark/27k_Matmul_Mish_Mish/baseline.py | 6 ++++-- .../benchmark/29k_Matmul_Swish_Sum_GroupNorm/baseline.py | 6 ++++-- .../benchmark/30k_Matmul_Scaling_ResidualAdd/baseline.py | 6 ++++-- JAXBench/benchmark/33k_Conv3d_Mish_Tanh/baseline.py | 6 ++++-- .../benchmark/35k_Gemm_Scaling_Hardtanh_GELU/baseline.py | 6 ++++-- JAXBench/benchmark/37k_Matmul_Swish_Scaling/baseline.py | 6 ++++-- .../benchmark/39k_Conv2d_GELU_GlobalAvgPool/baseline.py | 6 ++++-- .../benchmark/40k_Gemm_GroupNorm_Min_BiasAdd/baseline.py | 8 +++++--- JAXBench/benchmark/41k_Gemm_Add_ReLU/baseline.py | 6 ++++-- JAXBench/benchmark/44k_Matmul_Divide_GELU/baseline.py | 6 ++++-- .../45k_Gemm_GroupNorm_Swish_Multiply_Swish/baseline.py | 8 +++++--- .../47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/baseline.py | 8 +++++--- .../48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/baseline.py | 8 +++++--- .../49k_Matmul_AvgPool_GELU_Scale_Max/baseline.py | 6 ++++-- 19 files changed, 81 insertions(+), 43 deletions(-) diff --git a/JAXBench/benchmark/18k_Conv2D_ReLU_BiasAdd/baseline.py b/JAXBench/benchmark/18k_Conv2D_ReLU_BiasAdd/baseline.py index 6a6ce25..383d5d0 100644 --- a/JAXBench/benchmark/18k_Conv2D_ReLU_BiasAdd/baseline.py +++ b/JAXBench/benchmark/18k_Conv2D_ReLU_BiasAdd/baseline.py @@ -14,13 +14,15 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb, kc = jax.random.split(rand_key, 3) k1, k2 = jax.random.split(key) batch_size, in_channels, out_channels, kernel_size = 128, 64, 128, 3 height = width = 128 x = jax.random.uniform(k1, (batch_size, in_channels, height, width), dtype=dtype) - 
weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) - conv_bias = jnp.zeros(out_channels, dtype=dtype) - bias = jnp.zeros((out_channels, 1, 1), dtype=dtype) + weight = jax.random.normal(ka, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) * 0.02 + conv_bias = jax.random.normal(kb, out_channels, dtype=dtype) * 0.02 + bias = jax.random.normal(kc, (out_channels, 1, 1), dtype=dtype) * 0.02 return x, weight, conv_bias, bias diff --git a/JAXBench/benchmark/19k_Matmul_Subtract_Multiply_ReLU/baseline.py b/JAXBench/benchmark/19k_Matmul_Subtract_Multiply_ReLU/baseline.py index e7185b6..fbcd348 100644 --- a/JAXBench/benchmark/19k_Matmul_Subtract_Multiply_ReLU/baseline.py +++ b/JAXBench/benchmark/19k_Matmul_Subtract_Multiply_ReLU/baseline.py @@ -15,9 +15,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/20k_Gemm_Multiply_LeakyReLU/baseline.py b/JAXBench/benchmark/20k_Gemm_Multiply_LeakyReLU/baseline.py index d9b4474..a004262 100644 --- a/JAXBench/benchmark/20k_Gemm_Multiply_LeakyReLU/baseline.py +++ b/JAXBench/benchmark/20k_Gemm_Multiply_LeakyReLU/baseline.py @@ -15,9 +15,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = 
jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/22k_Conv2d_InstanceNorm_Divide/baseline.py b/JAXBench/benchmark/22k_Conv2d_InstanceNorm_Divide/baseline.py index 0774bb5..4fca411 100644 --- a/JAXBench/benchmark/22k_Conv2d_InstanceNorm_Divide/baseline.py +++ b/JAXBench/benchmark/22k_Conv2d_InstanceNorm_Divide/baseline.py @@ -15,11 +15,13 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) batch_size, in_channels, out_channels, kernel_size = 128, 64, 128, 3 height = width = 128 x = jax.random.uniform(key, (batch_size, in_channels, height, width), dtype=dtype) - weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) - conv_bias = jnp.zeros(out_channels, dtype=dtype) + weight = jax.random.normal(ka, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) * 0.02 + conv_bias = jax.random.normal(kb, out_channels, dtype=dtype) * 0.02 in_weight = jnp.ones(out_channels, dtype=dtype) in_bias = jnp.zeros(out_channels, dtype=dtype) return x, weight, conv_bias, in_weight, in_bias diff --git a/JAXBench/benchmark/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/baseline.py b/JAXBench/benchmark/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/baseline.py index c1ccdfe..16ed3ce 100644 --- a/JAXBench/benchmark/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/baseline.py +++ b/JAXBench/benchmark/23k_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp/baseline.py @@ -13,9 +13,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + 
bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/27k_Matmul_Mish_Mish/baseline.py b/JAXBench/benchmark/27k_Matmul_Mish_Mish/baseline.py index bfbbf75..17dc573 100644 --- a/JAXBench/benchmark/27k_Matmul_Mish_Mish/baseline.py +++ b/JAXBench/benchmark/27k_Matmul_Mish_Mish/baseline.py @@ -13,9 +13,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/29k_Matmul_Swish_Sum_GroupNorm/baseline.py b/JAXBench/benchmark/29k_Matmul_Swish_Sum_GroupNorm/baseline.py index 49ae776..bf9cee8 100644 --- a/JAXBench/benchmark/29k_Matmul_Swish_Sum_GroupNorm/baseline.py +++ b/JAXBench/benchmark/29k_Matmul_Swish_Sum_GroupNorm/baseline.py @@ -14,10 +14,12 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) batch_size, in_features, out_features, num_groups = 8192, 4096, 4096, 64 x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) - weight = jnp.zeros((in_features, out_features), dtype=dtype) - bias = jnp.zeros(out_features, dtype=dtype) + weight = jax.random.normal(ka, (in_features, out_features), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, out_features, dtype=dtype) * 0.02 gn_weight = jnp.ones(out_features, dtype=dtype) gn_bias = jnp.zeros(out_features, dtype=dtype) return x, weight, bias, gn_weight, gn_bias diff --git a/JAXBench/benchmark/30k_Matmul_Scaling_ResidualAdd/baseline.py 
b/JAXBench/benchmark/30k_Matmul_Scaling_ResidualAdd/baseline.py index 4241ce8..495e6e4 100644 --- a/JAXBench/benchmark/30k_Matmul_Scaling_ResidualAdd/baseline.py +++ b/JAXBench/benchmark/30k_Matmul_Scaling_ResidualAdd/baseline.py @@ -14,9 +14,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (16384, 4096), dtype=dtype) - weight = jnp.zeros((4096, 4096), dtype=dtype) - bias = jnp.zeros(4096, dtype=dtype) + weight = jax.random.normal(ka, (4096, 4096), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 4096, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/33k_Conv3d_Mish_Tanh/baseline.py b/JAXBench/benchmark/33k_Conv3d_Mish_Tanh/baseline.py index 7b3e12f..ff64b25 100644 --- a/JAXBench/benchmark/33k_Conv3d_Mish_Tanh/baseline.py +++ b/JAXBench/benchmark/33k_Conv3d_Mish_Tanh/baseline.py @@ -14,11 +14,13 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) batch_size, in_channels, out_channels, kernel_size = 16, 32, 64, 3 D, H, W = 32, 64, 64 x = jax.random.uniform(key, (batch_size, in_channels, D, H, W), dtype=dtype) - weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) - bias = jnp.zeros(out_channels, dtype=dtype) + weight = jax.random.normal(ka, (out_channels, in_channels, kernel_size, kernel_size, kernel_size), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, out_channels, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/35k_Gemm_Scaling_Hardtanh_GELU/baseline.py b/JAXBench/benchmark/35k_Gemm_Scaling_Hardtanh_GELU/baseline.py index 31a591e..014b471 100644 --- a/JAXBench/benchmark/35k_Gemm_Scaling_Hardtanh_GELU/baseline.py +++ 
b/JAXBench/benchmark/35k_Gemm_Scaling_Hardtanh_GELU/baseline.py @@ -16,9 +16,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/37k_Matmul_Swish_Scaling/baseline.py b/JAXBench/benchmark/37k_Matmul_Swish_Scaling/baseline.py index a8e49df..57d7a1c 100644 --- a/JAXBench/benchmark/37k_Matmul_Swish_Scaling/baseline.py +++ b/JAXBench/benchmark/37k_Matmul_Swish_Scaling/baseline.py @@ -14,9 +14,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/39k_Conv2d_GELU_GlobalAvgPool/baseline.py b/JAXBench/benchmark/39k_Conv2d_GELU_GlobalAvgPool/baseline.py index 45465ec..e44643f 100644 --- a/JAXBench/benchmark/39k_Conv2d_GELU_GlobalAvgPool/baseline.py +++ b/JAXBench/benchmark/39k_Conv2d_GELU_GlobalAvgPool/baseline.py @@ -14,11 +14,13 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) batch_size, in_channels, out_channels, kernel_size = 128, 8, 64, 3 height, width = 256, 256 x = jax.random.uniform(key, (batch_size, in_channels, height, width), 
dtype=dtype) - weight = jnp.zeros((out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) - bias = jnp.zeros(out_channels, dtype=dtype) + weight = jax.random.normal(ka, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, out_channels, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/40k_Gemm_GroupNorm_Min_BiasAdd/baseline.py b/JAXBench/benchmark/40k_Gemm_GroupNorm_Min_BiasAdd/baseline.py index 1c01586..8593fcf 100644 --- a/JAXBench/benchmark/40k_Gemm_GroupNorm_Min_BiasAdd/baseline.py +++ b/JAXBench/benchmark/40k_Gemm_GroupNorm_Min_BiasAdd/baseline.py @@ -14,13 +14,15 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb, kc = jax.random.split(rand_key, 3) batch_size, in_features, out_features, num_groups = 4096, 8192, 8192, 512 x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) - weight = jnp.zeros((out_features, in_features), dtype=dtype) - linear_bias = jnp.zeros(out_features, dtype=dtype) + weight = jax.random.normal(ka, (out_features, in_features), dtype=dtype) * 0.02 + linear_bias = jax.random.normal(kb, out_features, dtype=dtype) * 0.02 gn_weight = jnp.ones(out_features, dtype=dtype) gn_bias = jnp.zeros(out_features, dtype=dtype) - bias = jnp.zeros((1, out_features, 1, 1), dtype=dtype) + bias = jax.random.normal(kc, (1, out_features, 1, 1), dtype=dtype) * 0.02 return x, weight, linear_bias, gn_weight, gn_bias, bias diff --git a/JAXBench/benchmark/41k_Gemm_Add_ReLU/baseline.py b/JAXBench/benchmark/41k_Gemm_Add_ReLU/baseline.py index 8fc33fa..a7a8311 100644 --- a/JAXBench/benchmark/41k_Gemm_Add_ReLU/baseline.py +++ b/JAXBench/benchmark/41k_Gemm_Add_ReLU/baseline.py @@ -13,9 +13,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = 
jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/44k_Matmul_Divide_GELU/baseline.py b/JAXBench/benchmark/44k_Matmul_Divide_GELU/baseline.py index d87be67..0928aef 100644 --- a/JAXBench/benchmark/44k_Matmul_Divide_GELU/baseline.py +++ b/JAXBench/benchmark/44k_Matmul_Divide_GELU/baseline.py @@ -14,9 +14,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias diff --git a/JAXBench/benchmark/45k_Gemm_GroupNorm_Swish_Multiply_Swish/baseline.py b/JAXBench/benchmark/45k_Gemm_GroupNorm_Swish_Multiply_Swish/baseline.py index 8259d85..20d131c 100644 --- a/JAXBench/benchmark/45k_Gemm_GroupNorm_Swish_Multiply_Swish/baseline.py +++ b/JAXBench/benchmark/45k_Gemm_GroupNorm_Swish_Multiply_Swish/baseline.py @@ -14,13 +14,15 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb, kc = jax.random.split(rand_key, 3) batch_size, in_features, out_features, num_groups = 4096, 8192, 8192, 256 x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) - gemm_weight = jnp.zeros((out_features, in_features), dtype=dtype) - gemm_bias = jnp.zeros(out_features, dtype=dtype) + gemm_weight = jax.random.normal(ka, (out_features, in_features), dtype=dtype) * 0.02 + gemm_bias = 
jax.random.normal(kb, out_features, dtype=dtype) * 0.02 gn_weight = jnp.ones(out_features, dtype=dtype) gn_bias = jnp.zeros(out_features, dtype=dtype) - multiply_weight = jnp.zeros(out_features, dtype=dtype) + multiply_weight = jax.random.normal(kc, out_features, dtype=dtype) * 0.02 return x, gemm_weight, gemm_bias, gn_weight, gn_bias, multiply_weight diff --git a/JAXBench/benchmark/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/baseline.py b/JAXBench/benchmark/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/baseline.py index 159113e..e45c570 100644 --- a/JAXBench/benchmark/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/baseline.py +++ b/JAXBench/benchmark/47k_Matmul_Add_Swish_Tanh_GELU_Hardtanh/baseline.py @@ -13,10 +13,12 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb, kc = jax.random.split(rand_key, 3) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) - add_value = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 + add_value = jax.random.normal(kc, 8192, dtype=dtype) * 0.02 return x, weight, bias, add_value diff --git a/JAXBench/benchmark/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/baseline.py b/JAXBench/benchmark/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/baseline.py index 7d60f67..0c053c4 100644 --- a/JAXBench/benchmark/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/baseline.py +++ b/JAXBench/benchmark/48k_Matmul_BatchNorm_BiasAdd_Divide_Swish/baseline.py @@ -16,15 +16,17 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb, kc = jax.random.split(rand_key, 3) batch_size, in_features, out_features = 4096, 8192, 8192 x = jax.random.uniform(key, (batch_size, in_features), dtype=dtype) - weight 
= jnp.zeros((in_features, out_features), dtype=dtype) - linear_bias = jnp.zeros(out_features, dtype=dtype) + weight = jax.random.normal(ka, (in_features, out_features), dtype=dtype) * 0.02 + linear_bias = jax.random.normal(kb, out_features, dtype=dtype) * 0.02 bn_scale = jnp.ones(out_features, dtype=dtype) bn_bias = jnp.zeros(out_features, dtype=dtype) bn_mean = jnp.zeros(out_features, dtype=dtype) bn_var = jnp.ones(out_features, dtype=dtype) - bias = jnp.zeros((1,), dtype=dtype) + bias = jax.random.normal(kc, (1,), dtype=dtype) * 0.02 return x, weight, linear_bias, bn_scale, bn_bias, bn_mean, bn_var, bias diff --git a/JAXBench/benchmark/49k_Matmul_AvgPool_GELU_Scale_Max/baseline.py b/JAXBench/benchmark/49k_Matmul_AvgPool_GELU_Scale_Max/baseline.py index 0ad64e3..a6c17a3 100644 --- a/JAXBench/benchmark/49k_Matmul_AvgPool_GELU_Scale_Max/baseline.py +++ b/JAXBench/benchmark/49k_Matmul_AvgPool_GELU_Scale_Max/baseline.py @@ -15,9 +15,11 @@ def create_inputs(dtype=jnp.float32): """Create all inputs including weights.""" key = jax.random.key(0) + rand_key = jax.random.key(0xBADC0DE) + ka, kb = jax.random.split(rand_key, 2) x = jax.random.uniform(key, (4096, 8192), dtype=dtype) - weight = jnp.zeros((8192, 8192), dtype=dtype) - bias = jnp.zeros(8192, dtype=dtype) + weight = jax.random.normal(ka, (8192, 8192), dtype=dtype) * 0.02 + bias = jax.random.normal(kb, 8192, dtype=dtype) * 0.02 return x, weight, bias