In the `length_normed_log_probs` function, the original code drops the last logit position and the first token of each sequence, then gathers from the logits the log-probability of each output token:
```python
logits_tensor = logits_tensor[..., :-1, :].contiguous()
sequence_ids = sequence_ids[..., 1:].contiguous()
attention_mask = attention_mask[..., 1:].contiguous() if attention_mask is not None else None
log_probs = F.log_softmax(logits_tensor, dim=-1)
selected_log_probs = log_probs.gather(2, sequence_ids.unsqueeze(-1)).squeeze(-1)
```
Here, `logits_tensor` and `sequence_ids` come from `model.generate()`:
```python
outputs = accelerator.unwrap_model(model).generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_new_tokens=GENERATE_MAX_NEW_TOKENS,
    do_sample=True,
    num_return_sequences=num_candidates,
    return_dict_in_generate=True,
    output_scores=True,
    temperature=1.5,
    output_logits=True,
    stop_strings=policy_head_stopping_criteria,
    tokenizer=tokenizer,
)
generated_sequences = outputs.sequences[:, inputs['input_ids'].size(1):]
generated_sequences_mask = torch.zeros_like(generated_sequences)
generated_sequences_mask[generated_sequences != tokenizer.pad_token_id] = 1
generated_texts = tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)
logits = torch.stack(outputs.logits, dim=1)
```
However, once the prompt tokens are sliced off `outputs.sequences`, `generated_sequences` has the same length as the stacked `logits`, and `transformers` already aligns the two: the distribution for the first token in `generated_sequences` is the first entry of `logits`, and so on. Shifting them again pairs each token with the previous step's distribution and never scores the first generated token. This can be verified by setting `do_sample=False` and comparing `generated_sequences` with `logits.argmax(dim=-1)`:
```python
>>> generated_sequences
tensor([[  5331, 235303, 235256,   2409,    573,   1758,    576,  33593,   1960,
            692,   8989,    697, 235251,   3306, 256008]], device='cuda:0')
>>> logits.argmax(dim=-1)
tensor([[  5331, 235303, 235256,   2409,    573,   1758,    576,  33593,   1960,
            692,   8989,    697, 235251,   3306, 256008]], device='cuda:0')
```
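The argmax check above can be reproduced without a model. The sketch below uses small synthetic stand-ins for the `generate()` outputs (hypothetical values, not from the run above), where `logits[:, t]` is the distribution that produced `generated[:, t]`, and compares the shifted gather from the original code with an unshifted gather.

```python
import torch
import torch.nn.functional as F

# Synthetic stand-ins for generate() outputs (hypothetical values):
# generated[:, t] is the t-th new token, logits[:, t] is the step-t distribution.
vocab_size = 8
generated = torch.tensor([[3, 1, 4, 1, 5]])
logits = torch.full((1, generated.size(1), vocab_size), -10.0)
# Put a large logit on the token that was actually emitted at each step.
logits.scatter_(2, generated.unsqueeze(-1), 10.0)

# Same length, and greedy argmax reproduces the sequence (the do_sample=False check).
assert logits.size(1) == generated.size(1)
assert torch.equal(logits.argmax(dim=-1), generated)

log_probs = F.log_softmax(logits, dim=-1)

# Original (shifted) gather: pairs logits[:, t] with generated[:, t + 1].
shifted = log_probs[..., :-1, :].gather(2, generated[..., 1:].unsqueeze(-1)).squeeze(-1)
# Unshifted gather: pairs logits[:, t] with generated[:, t].
aligned = log_probs.gather(2, generated.unsqueeze(-1)).squeeze(-1)

print(shifted)  # about -20 everywhere: each token is scored under the previous step
print(aligned)  # about 0 everywhere: each token is scored under its own step
```

The shifted version also returns only four scores for five generated tokens, because the first token is never paired with any distribution.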
Therefore, `length_normed_log_probs` should not shift the tensors at all:
```python
# No shift: logits_tensor[:, t] already scores sequence_ids[:, t].
logits_tensor = logits_tensor.contiguous()
sequence_ids = sequence_ids.contiguous()
attention_mask = attention_mask.contiguous() if attention_mask is not None else None
```
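Putting the unshifted lines together with the unchanged `log_softmax`/`gather` step, a minimal sketch of the corrected function might look like the following. The source only shows the slicing lines, so the signature and the normalization (mean log-probability over unmasked tokens) are assumptions, not the original implementation.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def length_normed_log_probs(
    logits_tensor: torch.Tensor,                    # (batch, gen_len, vocab), e.g. torch.stack(outputs.logits, dim=1)
    sequence_ids: torch.Tensor,                     # (batch, gen_len), generated tokens only
    attention_mask: Optional[torch.Tensor] = None,  # (batch, gen_len), 1 = real token, 0 = padding
) -> torch.Tensor:
    """Assumed behavior: mean log-probability per sequence, using generate() alignment (no shift)."""
    logits_tensor = logits_tensor.contiguous()
    sequence_ids = sequence_ids.contiguous()
    # Upcast (e.g. from bfloat16) before log_softmax for numerical stability.
    log_probs = F.log_softmax(logits_tensor.float(), dim=-1)
    # logits_tensor[:, t] is the distribution for sequence_ids[:, t].
    selected_log_probs = log_probs.gather(2, sequence_ids.unsqueeze(-1)).squeeze(-1)
    if attention_mask is None:
        return selected_log_probs.mean(dim=-1)
    mask = attention_mask.to(selected_log_probs.dtype)
    # Sum over real tokens only, then divide by the number of real tokens.
    return (selected_log_probs * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
```

With the tensors from the issue, this would be called as `length_normed_log_probs(logits, generated_sequences, generated_sequences_mask)`.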
Note: if the logits are instead obtained by passing the input tokens directly through `model.forward()`, the logits at the last input position match the first entry of the logits returned by `model.generate()`. This confirms that the t-th entry of `outputs.logits` corresponds to the t-th generated token, so no further shift is needed:
```python
>>> model(**inputs, return_dict=True).logits[:, -1, :]
tensor([[ 1.7969, 17.7500, 13.4375,  ..., 11.6250, 12.4375, 11.2500]],
       device='cuda:0', dtype=torch.bfloat16)
>>> torch.stack(outputs.logits, dim=1)[:, 0, :]
tensor([[ 1.7969, 17.7500, 13.4375,  ..., 11.6250, 12.4375, 11.2500]],
       device='cuda:0')
```
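The relationship in the note generalizes: for a causal LM and an unpadded prompt, generation step t corresponds to forward-pass position `prompt_len - 1 + t`. That is why the `[..., :-1, :]` / `[..., 1:]` shift is right for logits from one forward pass over prompt plus completion, but wrong for `outputs.logits`. The sketch below checks this with a tiny hypothetical `ToyCausalLM` and a hand-written greedy loop standing in for `model.generate()`; neither is the `transformers` implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for a causal LM: logits[:, t] depend only on tokens[:, : t + 1]."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Running mean of embeddings keeps the model causal: position t sees only tokens <= t.
        h = self.embed(input_ids).cumsum(dim=1)
        counts = torch.arange(1, input_ids.size(1) + 1).view(1, -1, 1)
        return self.head(h / counts)  # (batch, seq_len, vocab)


@torch.no_grad()
def greedy_generate(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int):
    """Minimal greedy loop that records per-step logits, loosely mimicking output_logits=True."""
    sequences, step_logits = input_ids, []
    for _ in range(max_new_tokens):
        next_logits = model(sequences)[:, -1, :]  # distribution for the next token
        step_logits.append(next_logits)
        next_token = next_logits.argmax(dim=-1, keepdim=True)
        sequences = torch.cat([sequences, next_token], dim=1)
    return sequences, step_logits


model = ToyCausalLM().eval()
prompt = torch.tensor([[1, 2, 3]])
sequences, step_logits = greedy_generate(model, prompt, max_new_tokens=4)
generated = sequences[:, prompt.size(1):]
logits = torch.stack(step_logits, dim=1)  # (1, 4, vocab): same length as generated

with torch.no_grad():
    full_logits = model(sequences)          # one forward pass over prompt + generated
    prompt_last = model(prompt)[:, -1]      # forward logits at the last prompt position

# Step 0 of the generation logits equals the forward logits at the last prompt position.
assert torch.allclose(logits[:, 0], prompt_last, atol=1e-6)
# Generation step t equals forward position prompt_len - 1 + t.
assert torch.allclose(logits, full_logits[:, prompt.size(1) - 1 : -1], atol=1e-6)
# And, as in the issue, the per-step argmax reproduces the greedy tokens.
assert torch.equal(logits.argmax(dim=-1), generated)
```

In other words, the shifted slice `full_logits[:, prompt_len - 1 : -1]` of a forward pass is already what `generate()` returns, so applying the shift to `outputs.logits` shifts twice.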