From 9ab29d7578c8344d10c6906f3e28d0eca7914a9f Mon Sep 17 00:00:00 2001 From: Carlos Mullov Date: Fri, 6 Mar 2026 11:59:30 +0100 Subject: [PATCH 1/5] accumulate and return logits in beam search --- src/decoding.cc | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/decoding.cc b/src/decoding.cc index 84f39ac37..5be8c95ba 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -482,6 +482,7 @@ namespace ctranslate2 { } const bool use_hard_prefix = prefix_ids && !bias_towards_prefix; + StorageView alive_logits(dtype, device); StorageView logits(dtype, device); StorageView alive_seq(topk_ids.dtype()); StorageView alive_attention; @@ -531,6 +532,18 @@ namespace ctranslate2 { logits_vec = build_logits(logits, cur_batch_size * _beam_size); else logits_vec = build_logits(logits, cur_batch_size); + + // Accumulate logits across time steps + StorageView logits_reshaped = logits; + if (is_expanded) { + logits_reshaped.reshape({cur_batch_size, _beam_size, 1, vocabulary_size}); + const StorageView cur_alive_logits(std::move(alive_logits)); + ops::Concat(2)({&cur_alive_logits, &logits_reshaped}, alive_logits); + } + else { + logits_reshaped.reshape({cur_batch_size, 1, 1, vocabulary_size}); + ops::Tile(/*axis=*/1, _beam_size)(std::move(logits_reshaped), alive_logits); + } } StorageView log_probs(dtype, device); @@ -631,7 +644,17 @@ namespace ctranslate2 { if (alive_attention) result.attention.emplace_back(build_attention(alive_attention, i, k, start, end)); if (return_logits_vocab) { - result.logits_vocab.emplace_back(std::move(logits_vec[i * k])); + if (is_expanded) { + const dim_t flat_index = i * _beam_size + k; + merge_batch_beam(alive_logits); + StorageView hyp_logits; + ops::Slide slide_axis0(0, flat_index, 1); + slide_axis0(alive_logits, hyp_logits); + result.logits_vocab.emplace_back(std::vector{std::move(hyp_logits.squeeze(0))}); + split_batch_beam(alive_logits, _beam_size); + } else { + // TODO + } } // Move another active beam to this position. @@ -686,6 +709,8 @@ namespace ctranslate2 { gather_beam_flat(alive_seq, active_beams, _beam_size); if (alive_attention) gather_beam_flat(alive_attention, active_beams, _beam_size); + if (alive_logits) + gather_beam_flat(alive_logits, gather_indices, _beam_size); // If some sentences finished on this step, ignore them for the next step. std::unique_ptr keep_batches; @@ -701,6 +726,9 @@ namespace ctranslate2 { gather(alive_seq, *keep_batches); if (alive_attention) gather(alive_attention, *keep_batches); + if (alive_logits) + // TODO double check whether this is correct + gather(alive_logits, *keep_batches); if (keep_batches->device() != device) *keep_batches = keep_batches->to(device); } From f8411b78271b1f4429a7ef1e3c8ee6edc1484b67 Mon Sep 17 00:00:00 2001 From: Carlos Mullov Date: Fri, 6 Mar 2026 13:06:50 +0100 Subject: [PATCH 2/5] correctly build logit vector in beam search result --- src/decoding.cc | 48 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/src/decoding.cc b/src/decoding.cc index 5be8c95ba..da28b1b16 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -173,6 +173,39 @@ namespace ctranslate2 { return logits; } + static std::vector build_logits_for_beamsearch(StorageView& history, + const dim_t batch, + const dim_t beam) { + if (!history) + return {}; + + const dim_t beam_size = history.dim(1); + merge_batch_beam(history); + + // get target beam logits + const dim_t flat_index = batch * beam_size + beam; + ops::Slide slide_axis0(0, flat_index, 1); + StorageView hyp_logits; + slide_axis0(history, hyp_logits); + hyp_logits.squeeze(0); + + // save token logits to std::vector + const dim_t seq_len = hyp_logits.dim(0); + std::vector logits; + logits.reserve(seq_len); + for (dim_t t = 0; t < seq_len; ++t) { + ops::Slide slide(0, t, 1); + StorageView tmp(hyp_logits.dtype(), hyp_logits.device()); + slide(hyp_logits, tmp); + logits.emplace_back(std::move(tmp.squeeze(0))); + } + + // restore original shape + split_batch_beam(history, beam_size); + + return logits; + } + static float compute_coverage_penalty(const std::vector>& attention, const float beta) { float penalty = 0; @@ -643,19 +676,8 @@ namespace ctranslate2 { result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end)); if (alive_attention) result.attention.emplace_back(build_attention(alive_attention, i, k, start, end)); - if (return_logits_vocab) { - if (is_expanded) { - const dim_t flat_index = i * _beam_size + k; - merge_batch_beam(alive_logits); - StorageView hyp_logits; - ops::Slide slide_axis0(0, flat_index, 1); - slide_axis0(alive_logits, hyp_logits); - result.logits_vocab.emplace_back(std::vector{std::move(hyp_logits.squeeze(0))}); - split_batch_beam(alive_logits, _beam_size); - } else { - // TODO - } - } + if (return_logits_vocab) + result.logits_vocab.emplace_back(build_logits_for_beamsearch(alive_logits, i, k)); // Move another active beam to this position. for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) { From bf00a3ed5bd3e1e58e2dd0ab7a7fd0e2dde1c190 Mon Sep 17 00:00:00 2001 From: Carlos Mullov Date: Sun, 8 Mar 2026 00:32:06 +0100 Subject: [PATCH 3/5] fix issue with returned logits in beam search --- src/decoding.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/decoding.cc b/src/decoding.cc index da28b1b16..c3369579d 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -676,8 +676,11 @@ namespace ctranslate2 { result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end)); if (alive_attention) result.attention.emplace_back(build_attention(alive_attention, i, k, start, end)); - if (return_logits_vocab) - result.logits_vocab.emplace_back(build_logits_for_beamsearch(alive_logits, i, k)); + if (return_logits_vocab) { + // map candidates to original beam + const dim_t logit_beam = gather_indices.at(i * num_candidates + k); + result.logits_vocab.emplace_back(build_logits_for_beamsearch(alive_logits, i, logit_beam)); + } // Move another active beam to this position. for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) { From 00092eb7f19e1991e3df0177092aa42fa3f177c1 Mon Sep 17 00:00:00 2001 From: Carlos Mullov Date: Sun, 8 Mar 2026 00:33:47 +0100 Subject: [PATCH 4/5] remove obsolete code in beam search --- src/decoding.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/decoding.cc b/src/decoding.cc index c3369579d..e5cd6321d 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -559,13 +559,7 @@ namespace ctranslate2 { } disable_tokens.apply(); - std::vector logits_vec; if (return_logits_vocab) { - if (is_expanded) - logits_vec = build_logits(logits, cur_batch_size * _beam_size); - else - logits_vec = build_logits(logits, cur_batch_size); - // Accumulate logits across time steps StorageView logits_reshaped = logits; if (is_expanded) { From 933c522efadfe3bea003c1b3759de2719bd6215c Mon Sep 17 00:00:00 2001 From: Carlos Mullov Date: Sun, 8 Mar 2026 01:34:30 +0100 Subject: [PATCH 5/5] fix issue with expansion when returning logits from beam search --- src/decoding.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/decoding.cc b/src/decoding.cc index e5cd6321d..4c9e314cf 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -562,14 +562,19 @@ namespace ctranslate2 { if (return_logits_vocab) { // Accumulate logits across time steps StorageView logits_reshaped = logits; - if (is_expanded) { + if (alive_logits) { logits_reshaped.reshape({cur_batch_size, _beam_size, 1, vocabulary_size}); const StorageView cur_alive_logits(std::move(alive_logits)); ops::Concat(2)({&cur_alive_logits, &logits_reshaped}, alive_logits); } else { - logits_reshaped.reshape({cur_batch_size, 1, 1, vocabulary_size}); - ops::Tile(/*axis=*/1, _beam_size)(std::move(logits_reshaped), alive_logits); + if (is_expanded) { + logits_reshaped.reshape({cur_batch_size, _beam_size, 1, vocabulary_size}); + alive_logits = std::move(logits_reshaped); + } else { + logits_reshaped.reshape({cur_batch_size, 1, 1, vocabulary_size}); + ops::Tile(/*axis=*/1, _beam_size)(std::move(logits_reshaped), alive_logits); + } } }