diff --git a/src/decoding.cc b/src/decoding.cc
index 84f39ac37..4c9e314cf 100644
--- a/src/decoding.cc
+++ b/src/decoding.cc
@@ -173,6 +173,39 @@ namespace ctranslate2 {
     return logits;
   }
 
+  static std::vector<StorageView> build_logits_for_beamsearch(StorageView& history,
+                                                              const dim_t batch,
+                                                              const dim_t beam) {
+    if (!history)
+      return {};
+
+    const dim_t beam_size = history.dim(1);
+    merge_batch_beam(history);
+
+    // get target beam logits
+    const dim_t flat_index = batch * beam_size + beam;
+    ops::Slide slide_axis0(0, flat_index, 1);
+    StorageView hyp_logits;
+    slide_axis0(history, hyp_logits);
+    hyp_logits.squeeze(0);
+
+    // save token logits to std::vector
+    const dim_t seq_len = hyp_logits.dim(0);
+    std::vector<StorageView> logits;
+    logits.reserve(seq_len);
+    for (dim_t t = 0; t < seq_len; ++t) {
+      ops::Slide slide(0, t, 1);
+      StorageView tmp(hyp_logits.dtype(), hyp_logits.device());
+      slide(hyp_logits, tmp);
+      logits.emplace_back(std::move(tmp.squeeze(0)));
+    }
+
+    // restore original shape
+    split_batch_beam(history, beam_size);
+
+    return logits;
+  }
+
   static float compute_coverage_penalty(const std::vector<std::vector<float>>& attention,
                                         const float beta) {
     float penalty = 0;
@@ -482,6 +515,7 @@ namespace ctranslate2 {
     }
 
     const bool use_hard_prefix = prefix_ids && !bias_towards_prefix;
+    StorageView alive_logits(dtype, device);
     StorageView logits(dtype, device);
     StorageView alive_seq(topk_ids.dtype());
     StorageView alive_attention;
@@ -525,12 +559,23 @@ namespace ctranslate2 {
       }
       disable_tokens.apply();
 
-      std::vector<StorageView> logits_vec;
       if (return_logits_vocab) {
-        if (is_expanded)
-          logits_vec = build_logits(logits, cur_batch_size * _beam_size);
-        else
-          logits_vec = build_logits(logits, cur_batch_size);
+        // Accumulate logits across time steps
+        StorageView logits_reshaped = logits;
+        if (alive_logits) {
+          logits_reshaped.reshape({cur_batch_size, _beam_size, 1, vocabulary_size});
+          const StorageView cur_alive_logits(std::move(alive_logits));
+          ops::Concat(2)({&cur_alive_logits, &logits_reshaped}, alive_logits);
+        }
+        else {
+          if (is_expanded) {
+            logits_reshaped.reshape({cur_batch_size, _beam_size, 1, vocabulary_size});
+            alive_logits = std::move(logits_reshaped);
+          } else {
+            logits_reshaped.reshape({cur_batch_size, 1, 1, vocabulary_size});
+            ops::Tile(/*axis=*/1, _beam_size)(std::move(logits_reshaped), alive_logits);
+          }
+        }
       }
 
       StorageView log_probs(dtype, device);
@@ -631,7 +676,9 @@ namespace ctranslate2 {
             if (alive_attention)
               result.attention.emplace_back(build_attention(alive_attention, i, k, start, end));
             if (return_logits_vocab) {
-              result.logits_vocab.emplace_back(std::move(logits_vec[i * k]));
+              // map candidates to original beam
+              const dim_t logit_beam = gather_indices.at<int32_t>(i * num_candidates + k);
+              result.logits_vocab.emplace_back(build_logits_for_beamsearch(alive_logits, i, logit_beam));
             }
 
             // Move another active beam to this position.
@@ -686,6 +733,8 @@ namespace ctranslate2 {
         gather_beam_flat(alive_seq, active_beams, _beam_size);
         if (alive_attention)
           gather_beam_flat(alive_attention, active_beams, _beam_size);
+        if (alive_logits)
+          gather_beam_flat(alive_logits, gather_indices, _beam_size);
 
         // If some sentences finished on this step, ignore them for the next step.
         std::unique_ptr<StorageView> keep_batches;
@@ -701,6 +750,9 @@ namespace ctranslate2 {
           gather(alive_seq, *keep_batches);
           if (alive_attention)
             gather(alive_attention, *keep_batches);
+          if (alive_logits)
+            // TODO double check whether this is correct
+            gather(alive_logits, *keep_batches);
           if (keep_batches->device() != device)
             *keep_batches = keep_batches->to(device);
         }