In the length_normed_log_probs function, the log_prob of the output token was obtained incorrectly #27

@TraceIvan

In the length_normed_log_probs function, the original code drops the last logit position and the first token of each sequence, and then gathers the log-probability of each output token from the shifted logits:

logits_tensor = logits_tensor[..., :-1, :].contiguous()
sequence_ids = sequence_ids[..., 1:].contiguous()
attention_mask = attention_mask[..., 1:].contiguous() if attention_mask is not None else None
log_probs = F.log_softmax(logits_tensor, dim=-1)
selected_log_probs = log_probs.gather(2, sequence_ids.unsqueeze(-1)).squeeze(-1)

Here, logits_tensor and sequence_ids come from model.generate():

outputs = accelerator.unwrap_model(model).generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_new_tokens=GENERATE_MAX_NEW_TOKENS,
    do_sample=True,
    num_return_sequences=num_candidates,
    return_dict_in_generate=True,
    output_scores=True,
    temperature=1.5,
    output_logits=True,
    stop_strings=policy_head_stopping_criteria,
    tokenizer=tokenizer,
)

generated_sequences = outputs.sequences[:, inputs['input_ids'].size(1):]
generated_sequences_mask = torch.zeros_like(generated_sequences)
generated_sequences_mask[generated_sequences != tokenizer.pad_token_id] = 1
generated_texts = tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)

logits = torch.stack(outputs.logits, dim=1)

However, once outputs.sequences is sliced to remove the input tokens, generated_sequences has the same length as logits, and transformers has already aligned the two: the probability distribution that produced the i-th token of generated_sequences is the i-th entry of logits. Applying the shift therefore pairs each logit with the token that follows the one it produced. The alignment can be verified by comparing generated_sequences with logits.argmax(dim=-1) after setting do_sample=False:

>>> generated_sequences
tensor([[  5331, 235303, 235256,   2409,    573,   1758,    576,  33593,   1960,
            692,   8989,    697, 235251,   3306, 256008]], device='cuda:0')
>>> logits.argmax(dim=-1)
tensor([[  5331, 235303, 235256,   2409,    573,   1758,    576,  33593,   1960,
            692,   8989,    697, 235251,   3306, 256008]], device='cuda:0')
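
The same check can be mimicked on toy data without a model. The sketch below is only an illustration: the helper names (argmax, match_rate) and the 4-token vocabulary are made up for this example, and plain lists stand in for tensors. It compares the aligned pairing with the shifted pairing used by the original slicing.

```python
def argmax(row):
    """Index of the largest value in a list of floats."""
    return max(range(len(row)), key=row.__getitem__)


def match_rate(step_logits, token_ids, shift):
    """Fraction of positions where argmax(logits) equals the paired token.

    shift=False pairs step_logits[t] with token_ids[t].
    shift=True mimics the original slicing: logits[:-1] with tokens[1:].
    """
    if shift:
        step_logits, token_ids = step_logits[:-1], token_ids[1:]
    hits = sum(argmax(row) == tok for row, tok in zip(step_logits, token_ids))
    return hits / len(token_ids)


# Toy greedy generation over a 4-token vocabulary: one row per decoding step.
step_logits = [
    [0.1, 2.0, 0.3, 0.0],  # step 0 picks token 1
    [0.2, 0.1, 0.0, 3.0],  # step 1 picks token 3
    [1.5, 0.2, 0.1, 0.0],  # step 2 picks token 0
    [0.0, 0.3, 2.2, 0.1],  # step 3 picks token 2
]
generated = [argmax(row) for row in step_logits]

aligned = match_rate(step_logits, generated, shift=False)
shifted = match_rate(step_logits, generated, shift=True)
print(aligned, shifted)
```

With greedy decoding the aligned pairing matches at every step, while the shifted pairing compares each row with the next step's token and, on this toy data, never matches.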

Therefore, length_normed_log_probs should use the tensors as they are, without shifting:

logits_tensor = logits_tensor.contiguous()
sequence_ids = sequence_ids.contiguous()
attention_mask = attention_mask.contiguous() if attention_mask is not None else None
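
For reference, here is a framework-free sketch of the unshifted computation. It is not the repository's implementation: the name length_normed_log_prob, the list-based inputs, and the choice to average over unmasked tokens are assumptions made for illustration. The tensor version would keep the F.log_softmax and gather calls from the original snippet.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    peak = max(row)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_z for x in row]


def length_normed_log_prob(step_logits, token_ids, mask=None):
    """Mean log-probability of token_ids under per-step logits, with no shift.

    step_logits[t] is taken to be the distribution that produced token_ids[t],
    which is how generate() returns them. mask[t] == 0 skips padding positions.
    Averaging over unmasked tokens is an assumption about the normalisation.
    """
    if mask is None:
        mask = [1] * len(token_ids)
    total = 0.0
    for row, tok, keep in zip(step_logits, token_ids, mask):
        if keep:
            total += log_softmax(row)[tok]
    return total / max(sum(mask), 1)


# Two real tokens followed by one padding position (hypothetical toy values).
step_logits = [[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
token_ids = [1, 0, 0]
mask = [1, 1, 0]
score = length_normed_log_prob(step_logits, token_ids, mask)
print(score)
```

Both unmasked steps are uniform over two tokens, so the score equals -log(2); the padded step is ignored by the mask.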

Note: If the input tokens are passed directly through model.forward(), the logits at the last input position match the first logits returned by model.generate(). This confirms that each entry of the logits returned by model.generate() corresponds exactly to the output token at the same index:

>>> model(**inputs, return_dict=True).logits[:,-1,:]
tensor([[ 1.7969, 17.7500, 13.4375,  ..., 11.6250, 12.4375, 11.2500]],
       device='cuda:0', dtype=torch.bfloat16)
>>> torch.stack(outputs.logits, dim=1)[:,0,:]
tensor([[ 1.7969, 17.7500, 13.4375,  ..., 11.6250, 12.4375, 11.2500]],
       device='cuda:0')
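
The relationship between the two kinds of logits can be sketched with a toy deterministic "model". Everything here (next_logits, forward, greedy_generate, the 5-token vocabulary) is a hypothetical stand-in rather than the transformers API. It shows that the shift is needed only when the logits come from a forward pass over prompt plus completion, not when they come from generation steps.

```python
VOCAB = 5


def next_logits(prefix):
    """Toy stand-in for a causal LM: logits for the token following `prefix`."""
    seed = sum((i + 1) * tok for i, tok in enumerate(prefix))
    return [float((seed * (v + 3) + v) % 7) for v in range(VOCAB)]


def forward(tokens):
    """Like model(...).logits: row i predicts the token at position i + 1."""
    return [next_logits(tokens[: i + 1]) for i in range(len(tokens))]


def greedy_generate(prompt, steps):
    """Like generate(output_logits=True): one logits row per new token."""
    tokens, step_logits = list(prompt), []
    for _ in range(steps):
        row = next_logits(tokens)
        step_logits.append(row)
        tokens.append(max(range(VOCAB), key=row.__getitem__))
    return tokens[len(prompt):], step_logits


prompt = [2, 4, 1]
new_tokens, gen_logits = greedy_generate(prompt, steps=4)

# The last prompt position of a forward pass equals the first generation step.
first_step_matches = forward(prompt)[-1] == gen_logits[0]

# A forward pass over prompt + completion needs the shift to line up with the
# completion tokens; the per-step logits from generation are already aligned.
full = forward(prompt + new_tokens)
shifted_forward = full[len(prompt) - 1 : -1]
rows_match = shifted_forward == gen_logits
print(first_step_matches, rows_match)
```

In other words, the [..., :-1, :] and [..., 1:] slicing is the right pattern for forward-pass logits over a full sequence, but it is off by one when applied to the stacked outputs.logits.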
