From ec74c67093040d17de8e55fcfbfacdb38029bfd4 Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Fri, 29 May 2026 15:40:48 +0100 Subject: [PATCH] fix(gqa): replace float division with integer division in GQA reshape In grouped query attention (GQA), the number of query heads per KV head was computed as int(kg / self.num_kv_heads) using float division + int(). When kg is not exactly divisible by num_kv_heads, float division produces a non-integer result that int() silently truncates. This yields an incorrect reshape dimension with no error, causing silent shape corruption or an unexpected crash in non-standard head configurations. Replaced all 8 occurrences across 4 files with integer division (//), which is semantically correct for integer tensor dimensions and makes the intent explicit. Affected files: - gemma/gm/nn/_modules.py (lines 244, 288) - gemma/gm/nn/gemma4/_modules.py (lines 326, 365) - gemma/gm/nn/gemma3n/_modules.py (lines 342, 387) - gemma/research/t5gemma/modules.py (lines 240, 272) Fixes: https://github.com/google-deepmind/gemma/issues/641 --- gemma/gm/nn/_modules.py | 4 ++-- gemma/gm/nn/gemma3n/_modules.py | 4 ++-- gemma/gm/nn/gemma4/_modules.py | 4 ++-- gemma/research/t5gemma/modules.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gemma/gm/nn/_modules.py b/gemma/gm/nn/_modules.py index b1bbc789..0afc1190 100644 --- a/gemma/gm/nn/_modules.py +++ b/gemma/gm/nn/_modules.py @@ -241,7 +241,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -285,7 +285,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma3n/_modules.py b/gemma/gm/nn/gemma3n/_modules.py index c6c3f562..b7cc7ea7 100644 --- a/gemma/gm/nn/gemma3n/_modules.py +++ b/gemma/gm/nn/gemma3n/_modules.py @@ -339,7 +339,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -384,7 +384,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma4/_modules.py b/gemma/gm/nn/gemma4/_modules.py index cb8cf375..95df7061 100644 --- a/gemma/gm/nn/gemma4/_modules.py +++ b/gemma/gm/nn/gemma4/_modules.py @@ -323,7 +323,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_proj.shape query_proj = query_proj.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_proj, key_proj) b, t, k, g, s = logits.shape @@ -362,7 +362,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/research/t5gemma/modules.py b/gemma/research/t5gemma/modules.py index c2d34588..32596122 100644 --- a/gemma/research/t5gemma/modules.py +++ b/gemma/research/t5gemma/modules.py @@ -237,7 +237,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -269,7 +269,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape