diff --git a/gemma/gm/nn/_modules.py b/gemma/gm/nn/_modules.py index b1bbc789..0afc1190 100644 --- a/gemma/gm/nn/_modules.py +++ b/gemma/gm/nn/_modules.py @@ -241,7 +241,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -285,7 +285,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma3n/_modules.py b/gemma/gm/nn/gemma3n/_modules.py index c6c3f562..b7cc7ea7 100644 --- a/gemma/gm/nn/gemma3n/_modules.py +++ b/gemma/gm/nn/gemma3n/_modules.py @@ -339,7 +339,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -384,7 +384,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/gm/nn/gemma4/_modules.py b/gemma/gm/nn/gemma4/_modules.py index cb8cf375..95df7061 100644 --- a/gemma/gm/nn/gemma4/_modules.py +++ b/gemma/gm/nn/gemma4/_modules.py @@ -323,7 +323,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_proj.shape query_proj = query_proj.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_proj, key_proj) b, t, k, g, s = logits.shape @@ -362,7 +362,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape diff --git a/gemma/research/t5gemma/modules.py b/gemma/research/t5gemma/modules.py index c2d34588..32596122 100644 --- a/gemma/research/t5gemma/modules.py +++ b/gemma/research/t5gemma/modules.py @@ -237,7 +237,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = query_scaled.shape query_scaled = query_scaled.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) b, t, k, g, s = logits.shape @@ -269,7 +269,7 @@ def __call__( # Reshape matrices to enable einsums over groups. b, t, kg, h = probs.shape probs = probs.reshape( - (b, t, self.num_kv_heads, int(kg / self.num_kv_heads), h) + (b, t, self.num_kv_heads, kg // self.num_kv_heads, h) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) b, t, k, g, h = encoded.shape