Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions aphrodite/v1/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,22 +171,44 @@ def _execute_samplers_in_order(
Returns:
Modified logits tensor after applying samplers in priority order
"""
# Check if mirostat is active - if so, disable other samplers
has_mirostat = False
# apply_mirostat only touches its own (mode==2) rows, so the remaining
# rows in a mixed batch must still go through the normal sampler order
mirostat_rows: list[int] = []
if (
sampling_metadata.mirostat_mode is not None
and sampling_metadata.mirostat_tau is not None
and sampling_metadata.mirostat_eta is not None
):
batch_size = len(sampling_metadata.output_token_ids)
has_mirostat = any(sampling_metadata.mirostat_mode[i].item() == 2 for i in range(batch_size))
mirostat_rows = [
i for i in range(batch_size) if sampling_metadata.mirostat_mode[i].item() == 2
]

if has_mirostat:
# Mirostat is active - only apply mirostat and skip other samplers
logger.debug("Mirostat active - applying mirostat only")
if mirostat_rows:
logger.debug("Mirostat active for %d request(s)", len(mirostat_rows))
logits = self.sampling_ops.apply_mirostat(logits, sampling_metadata)

mirostat_set = set(mirostat_rows)
non_mirostat_indices = [
i for i in range(len(sampling_metadata.output_token_ids)) if i not in mirostat_set
]
if non_mirostat_indices:
non_mirostat_metadata = self._subset_sampling_metadata(
sampling_metadata, non_mirostat_indices
)
logits[non_mirostat_indices] = self._apply_normal_sampler_order(
logits[non_mirostat_indices], non_mirostat_metadata
)
return logits

return self._apply_normal_sampler_order(logits, sampling_metadata)

def _apply_normal_sampler_order(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Apply the standard sampler order, honoring per-row temperature_last."""
temperature_last_flags = sampling_metadata.temperature_last
if not temperature_last_flags:
return self._apply_sampler_order(logits, sampling_metadata, do_temperature_last=False)
Expand Down
30 changes: 30 additions & 0 deletions tests/v1/sample/test_aphrodite_sampler_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,33 @@ def test_input_batch_keeps_token_history_for_no_repeat_ngram_only():
assert metadata.no_repeat_ngram_size is not None
assert metadata.prompt_token_ids is not None
assert metadata.output_token_ids == [[]]


def test_mixed_mirostat_batch_still_runs_normal_samplers_on_other_rows():
    """Check which rows reach the normal sampler order when mirostat is used.

    The sampler's collaborators are replaced with stubs: the normal sampler
    order records the row indices it is handed, the metadata subsetter just
    wraps those indices, and ``apply_mirostat`` is a pass-through.
    """
    from types import SimpleNamespace

    recorded_rows = []

    def record_normal_order(logits, md):
        # Remember which rows were routed here, then leave logits untouched.
        recorded_rows.append(list(md.indices))
        return logits

    def subset_metadata(md, indices):
        # Stand-in subset: only carries the selected row indices.
        return SimpleNamespace(indices=indices)

    def passthrough_mirostat(logits, md):
        return logits

    # Bypass __init__ so only the attributes set below exist on the instance.
    sampler = Sampler.__new__(Sampler)
    sampler._apply_normal_sampler_order = record_normal_order
    sampler._subset_sampling_metadata = subset_metadata
    sampler.sampling_ops = SimpleNamespace(apply_mirostat=passthrough_mirostat)

    def build_metadata(modes):
        row_count = len(modes)
        return SimpleNamespace(
            mirostat_mode=torch.tensor(modes),
            mirostat_tau=torch.ones(row_count),
            mirostat_eta=torch.ones(row_count),
            output_token_ids=[[] for _ in range(row_count)],
        )

    scenarios = (
        # row 0 uses mirostat v2, row 1 does not -> row 1 must still get the normal order
        ([2, 0], [[1]]),
        # every row uses mirostat -> nothing left for the normal order
        ([2, 2], []),
    )
    for modes, expected_rows in scenarios:
        recorded_rows.clear()
        sampler._execute_samplers_in_order(torch.randn(2, 5), build_metadata(modes))
        assert recorded_rows == expected_rows