diff --git a/gemma/gm/nn/_modules.py b/gemma/gm/nn/_modules.py index b1bbc789..4c7648ce 100644 --- a/gemma/gm/nn/_modules.py +++ b/gemma/gm/nn/_modules.py @@ -269,7 +269,7 @@ def __call__( attn_mask *= sliding_mask elif self.attn_type != AttentionType.GLOBAL: raise ValueError( - 'Attn_type must be either AttentionType.GLOBAL or' + 'Attn_type must be either AttentionType.LOCAL_SLIDING or' f' AttentionType.GLOBAL not {self.attn_type}' )