diff --git a/gemma/gm/text/_chat_sampler.py b/gemma/gm/text/_chat_sampler.py index 12f6d04d..bbb24522 100644 --- a/gemma/gm/text/_chat_sampler.py +++ b/gemma/gm/text/_chat_sampler.py @@ -272,9 +272,9 @@ def _remove_eos_token( return dataclasses.replace( state, step=state.step - 1, - # done is True and last_token is EOS => False + # done is True and last_token is EOS => False (un-done it) # Otherwise, keep the same. - done=state.done ^ (state.last_token == tokenizer.special_tokens.EOS), + done=state.done & ~(state.last_token == tokenizer.special_tokens.EOS), last_token_pos=state.last_token_pos - 1, cache=cache_info.cache, ) diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py index 5b8ae163..710e321d 100644 --- a/gemma/gm/text/_sampler.py +++ b/gemma/gm/text/_sampler.py @@ -577,7 +577,7 @@ def _normalize_token(tokenizer, token: str | int) -> int: token_id = tokenizer.encode(token) if len(token_id) != 1: raise ValueError( - 'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' + f'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' ' map to single token ids in the vocab.' ) (token_id,) = token_id