From da56ea06d8e3138ca9c7bb67d46b2018ec61b085 Mon Sep 17 00:00:00 2001 From: tomohiro86 Date: Mon, 4 May 2026 09:44:46 +0900 Subject: [PATCH] fix: use integer division (//) instead of float division (/) in GQA reshape Replaces `int(kg / self.num_kv_heads)` with `kg // self.num_kv_heads` in all GQA reshape operations across 4 modules. Using float division with int() round-trips the group count through a float and truncates it; `//` keeps the computation in exact integer arithmetic and states the intent directly. For the positive head counts used here both forms produce the same value, so when kg is not evenly divisible by num_kv_heads the size mismatch is still left for the reshape call to reject. Fixes #641. --- gemma/gm/nn/_modules.py | 4 ++-- gemma/gm/nn/gemma3n/_modules.py | 4 ++-- gemma/gm/nn/gemma4/_modules.py | 4 ++-- gemma/research/t5gemma/modules.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gemma/gm/nn/_modules.py b/gemma/gm/nn/_modules.py index b1bbc789..0afc1190 100644 --- a/gemma/gm/nn/_modules.py +++ b/gemma/gm/nn/_modules.py @@ -241,7 +241,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -285,7 +285,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma3n/_modules.py b/gemma/gm/nn/gemma3n/_modules.py index c6c3f562..b7cc7ea7 100644 --- a/gemma/gm/nn/gemma3n/_modules.py +++ b/gemma/gm/nn/gemma3n/_modules.py @@ -339,7 +339,7 @@ def __call__( # Reshape matrices to enable einsums over groups. 
b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -384,7 +384,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma4/_modules.py b/gemma/gm/nn/gemma4/_modules.py index 087f4acb..b7d6525f 100644 --- a/gemma/gm/nn/gemma4/_modules.py +++ b/gemma/gm/nn/gemma4/_modules.py @@ -321,7 +321,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_proj.shape query_proj = query_proj.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_proj, key_proj) b, t, k, g, s = logits.shape @@ -360,7 +360,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/research/t5gemma/modules.py b/gemma/research/t5gemma/modules.py index c2d34588..32596122 100644 --- a/gemma/research/t5gemma/modules.py +++ b/gemma/research/t5gemma/modules.py @@ -237,7 +237,7 @@ def __call__( # Reshape matrices to enable einsums over groups. 
b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -269,7 +269,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape