diff --git a/examples/classification.py b/examples/classification.py index 4ffdbc7a..882fc876 100644 --- a/examples/classification.py +++ b/examples/classification.py @@ -82,6 +82,13 @@ def _make_dataset(training: bool) -> kd.data.Pipeline: _LABEL_FIELD = "label" # pylint: disable=invalid-name tokenizer = gm.text.Gemma3Tokenizer() + yes_tokens = tokenizer.encode("Yes", add_special_tokens=False) + no_tokens = tokenizer.encode("No", add_special_tokens=False) + + if len(yes_tokens) != 1 or len(no_tokens) != 1: + raise ValueError( + "'Yes' and 'No' must map to a single token for classification." + ) return kd.data.py.Tfds( name="glue/cola", @@ -96,7 +103,7 @@ def _make_dataset(training: bool) -> kd.data.Pipeline: gm.data.FormatText( key=_INPUT_FIELD, template="""user - Please classify whether the following sentence is grammaticaly correct, please answer only with Yes or No. + Please classify whether the following sentence is grammatically correct, please answer only with Yes or No. Sentence: {text} model""", ), @@ -110,18 +117,17 @@ def _make_dataset(training: bool) -> kd.data.Pipeline: max_length=128, ), # Process the label - gm.data.MapInts( - key=_LABEL_FIELD, - # Rather than predicting the token 0 and 1, we are using the - # token 1294 and 3553 which respectivelly correspond to "No" and - # "Yes". We do this because those token already contain semantic - # information, so even zero-shot prediction without any - # finetuning has better than random performances. - old_to_new={ - 0: 1294, # Token -> "No" - 1: 3553, # Token -> "Yes" - }, - ), + gm.data.MapInts( + key=_LABEL_FIELD, + # Rather than predicting tokens 0 and 1, we map labels to the + # tokenizer-derived token IDs for "No" and "Yes". These tokens + # contain semantic information, which improves zero-shot + # performance even without finetuning. + old_to_new={ + 0: no_tokens[0], # "No" + 1: yes_tokens[0], # "Yes" + }, +), kd.data.Rearrange( key=_LABEL_FIELD, pattern="... -> ... 1", # For shape compatibility with the loss.