From 71af9ccab42bdd63eac956d46426a35d6f0bc976 Mon Sep 17 00:00:00 2001 From: Prathamesh Jadhav <55660103+lollinng@users.noreply.github.com> Date: Sat, 6 Jun 2026 02:26:40 +0530 Subject: [PATCH] Don't disable other samplers for non-mirostat rows in a mixed batch _execute_samplers_in_order returned right after apply_mirostat whenever any request in the batch used mirostat v2, skipping temperature/top-p/top-k/ penalties for every co-batched request. apply_mirostat already restricts itself to its own (mode==2) rows, so route the remaining rows through the normal sampler order instead of returning early. Co-authored-by: Claude Opus 4.8 (1M context) --- aphrodite/v1/sample/sampler.py | 34 +++++++++++++++---- tests/v1/sample/test_aphrodite_sampler_ops.py | 30 ++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/aphrodite/v1/sample/sampler.py b/aphrodite/v1/sample/sampler.py index 57a0267d4c..d4aa57fd5b 100644 --- a/aphrodite/v1/sample/sampler.py +++ b/aphrodite/v1/sample/sampler.py @@ -171,22 +171,44 @@ def _execute_samplers_in_order( Returns: Modified logits tensor after applying samplers in priority order """ - # Check if mirostat is active - if so, disable other samplers - has_mirostat = False + # apply_mirostat only touches its own (mode==2) rows, so the remaining + # rows in a mixed batch must still go through the normal sampler order + mirostat_rows: list[int] = [] if ( sampling_metadata.mirostat_mode is not None and sampling_metadata.mirostat_tau is not None and sampling_metadata.mirostat_eta is not None ): batch_size = len(sampling_metadata.output_token_ids) - has_mirostat = any(sampling_metadata.mirostat_mode[i].item() == 2 for i in range(batch_size)) + mirostat_rows = [ + i for i in range(batch_size) if sampling_metadata.mirostat_mode[i].item() == 2 + ] - if has_mirostat: - # Mirostat is active - only apply mirostat and skip other samplers - logger.debug("Mirostat active - applying mirostat only") + if mirostat_rows: + logger.debug("Mirostat active for %d request(s)", len(mirostat_rows)) logits = self.sampling_ops.apply_mirostat(logits, sampling_metadata) + + mirostat_set = set(mirostat_rows) + non_mirostat_indices = [ + i for i in range(len(sampling_metadata.output_token_ids)) if i not in mirostat_set + ] + if non_mirostat_indices: + non_mirostat_metadata = self._subset_sampling_metadata( + sampling_metadata, non_mirostat_indices + ) + logits[non_mirostat_indices] = self._apply_normal_sampler_order( + logits[non_mirostat_indices], non_mirostat_metadata + ) return logits + return self._apply_normal_sampler_order(logits, sampling_metadata) + + def _apply_normal_sampler_order( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + """Apply the standard sampler order, honoring per-row temperature_last.""" temperature_last_flags = sampling_metadata.temperature_last if not temperature_last_flags: return self._apply_sampler_order(logits, sampling_metadata, do_temperature_last=False) diff --git a/tests/v1/sample/test_aphrodite_sampler_ops.py b/tests/v1/sample/test_aphrodite_sampler_ops.py index 4746119c4d..fd6e328df6 100644 --- a/tests/v1/sample/test_aphrodite_sampler_ops.py +++ b/tests/v1/sample/test_aphrodite_sampler_ops.py @@ -676,3 +676,33 @@ def test_input_batch_keeps_token_history_for_no_repeat_ngram_only(): assert metadata.no_repeat_ngram_size is not None assert metadata.prompt_token_ids is not None assert metadata.output_token_ids == [[]] + + +def test_mixed_mirostat_batch_still_runs_normal_samplers_on_other_rows(): + from types import SimpleNamespace + + sampler = Sampler.__new__(Sampler) + normal_calls = [] + sampler._apply_normal_sampler_order = lambda logits, md: ( + normal_calls.append(list(md.indices)) or logits + ) + sampler._subset_sampling_metadata = lambda md, indices: SimpleNamespace(indices=indices) + sampler.sampling_ops = SimpleNamespace(apply_mirostat=lambda logits, md: logits) + + def meta(modes): + return SimpleNamespace( + mirostat_mode=torch.tensor(modes), + mirostat_tau=torch.ones(len(modes)), + mirostat_eta=torch.ones(len(modes)), + output_token_ids=[[] for _ in modes], + ) + + # row 0 uses mirostat v2, row 1 does not -> row 1 must still get the normal order + normal_calls.clear() + sampler._execute_samplers_in_order(torch.randn(2, 5), meta([2, 0])) + assert normal_calls == [[1]] + + # every row uses mirostat -> nothing left for the normal order + normal_calls.clear() + sampler._execute_samplers_in_order(torch.randn(2, 5), meta([2, 2])) + assert normal_calls == []