From 00281b4a72b90de12b26ed49b845030751d310d5 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:49:25 +0800 Subject: [PATCH 1/7] feat: support vllm offline batching --- docs/models/vllm-offline.md | 10 ++ docs/models/vllm-offline.zh.md | 10 ++ src/gimkit/models/utils.py | 74 +++++++++++++- src/gimkit/models/vllm_offline.py | 108 ++++++++++++++++++-- tests/models/test_utils.py | 54 ++++++++++ tests/models/test_vllm_offline.py | 158 ++++++++++++++++++++++++++++-- 6 files changed, 398 insertions(+), 16 deletions(-) diff --git a/docs/models/vllm-offline.md b/docs/models/vllm-offline.md index 68aa63e..dbc1118 100644 --- a/docs/models/vllm-offline.md +++ b/docs/models/vllm-offline.md @@ -37,6 +37,16 @@ result = model(query) result_non_gim = model(query, use_gim_prompt=True) ``` +## Batch inference + +`model.batch(...)` wraps Outlines' batch API for vLLM offline. +Each query can use its own GIM-derived structured output schema. + +```python +batch_results = model.batch([query, query]) +first_result = batch_results[0][0] +``` + ## Output types ### `output_type="cfg"` (default) diff --git a/docs/models/vllm-offline.zh.md b/docs/models/vllm-offline.zh.md index a3edff0..d29987b 100644 --- a/docs/models/vllm-offline.zh.md +++ b/docs/models/vllm-offline.zh.md @@ -37,6 +37,16 @@ result = model(query) result_non_gim = model(query, use_gim_prompt=True) ``` +## 批量推理 + +`model.batch(...)` 会包装 Outlines 的 vLLM offline batch API。 +每条 query 都可以使用各自从 GIM 推导出的结构化输出 schema。 + +```python +batch_results = model.batch([query, query]) +first_result = batch_results[0][0] +``` + ## 输出类型 ### `output_type="cfg"`(默认) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 878f7bc..0a05813 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -1,4 +1,5 @@ -from typing import Literal, overload +from collections.abc import Sequence +from typing import Any, Literal, overload from outlines.inputs import Chat from outlines.types.dsl import CFG, JsonSchema @@ -14,6 +15,10 @@ from gimkit.schemas import ContextInput, MaskedTag, TagField +def _ensure_query(model_input: ContextInput | Query) -> Query: + return model_input if isinstance(model_input, Query) else Query(model_input) + + def get_outlines_model_input( model_input: ContextInput | Query, output_type: Literal["cfg", "json"] | None, @@ -31,7 +36,7 @@ def get_outlines_model_input( If None, uses the Query default (["id", "desc", "content"]). Example: ["id", "name", "desc", "content", "regex"] to expose all fields. """ - query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input + query_obj = _ensure_query(model_input) outlines_model_input: str | Chat = ( query_obj.to_string(fields=visible_tag_fields) if visible_tag_fields is not None @@ -57,11 +62,31 @@ def get_outlines_model_input( return outlines_model_input +def get_outlines_model_inputs( + model_inputs: Sequence[ContextInput | Query], + output_type: Literal["cfg", "json"] | None, + use_gim_prompt: bool, + visible_tag_fields: list[TagField] | None = None, +) -> list[str | Chat]: + """Transform a batch of model inputs to Outlines-compatible formats.""" + if len(model_inputs) == 0: + raise ValueError("Batch input list is empty.") + return [ + get_outlines_model_input( + model_input, + output_type, + use_gim_prompt, + visible_tag_fields=visible_tag_fields, + ) + for model_input in model_inputs + ] + + def get_outlines_output_type( model_input: ContextInput | Query, output_type: Literal["cfg", "json"] | None ) -> None | CFG | JsonSchema: """Transform the output type to an Outlines-compatible format.""" - query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input + query_obj = _ensure_query(model_input) if output_type is None: return None elif output_type == "cfg": @@ -157,3 +182,46 @@ def infill_responses( raise TypeError(f"All items in the response list must be strings, got: {responses}") return [infill_responses(query, resp, json_responses=json_responses) for resp in responses] + + +@overload +def infill_batch_responses( + queries: Sequence[ContextInput | Query], responses: list[str], json_responses: bool = False +) -> list[Result]: ... + + +@overload +def infill_batch_responses( + queries: Sequence[ContextInput | Query], + responses: list[list[str]], + json_responses: bool = False, +) -> list[list[Result]]: ... + + +def infill_batch_responses( + queries: Sequence[ContextInput | Query], + responses: list[str] | list[list[str]], + json_responses: bool = False, +) -> list[Any]: + """Infill each query in a batch with its corresponding response(s).""" + if len(queries) == 0: + raise ValueError("Batch input list is empty.") + if not isinstance(responses, list): + raise TypeError(f"Expected batch responses to be a list, got {type(responses)}") + if len(queries) != len(responses): + raise ValueError( + "Mismatched number of batch inputs and responses: " + f"{len(queries)} input(s), {len(responses)} response(s)." + ) + + results = [] + for query, response in zip(queries, responses, strict=True): + if isinstance(response, (str, list)): + results.append(infill_responses(query, response, json_responses=json_responses)) + else: + raise TypeError( + "Each batch response must be a string or a list of strings, " + f"got {type(response)}" + ) + + return results diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 9a7f147..589abd0 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -8,7 +8,13 @@ from gimkit.contexts import Query, Result from gimkit.log import get_logger -from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses +from gimkit.models.utils import ( + get_outlines_model_input, + get_outlines_model_inputs, + get_outlines_output_type, + infill_batch_responses, + infill_responses, +) from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField @@ -16,6 +22,7 @@ if TYPE_CHECKING: from vllm import LLM + from vllm.sampling_params import SamplingParams class VLLMOffline(OutlinesVLLMOffline): @@ -46,6 +53,86 @@ def __call__( json_responses=(output_type == "json"), ) + def batch( + self, + model_input: list[Any], + output_type: Any | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + **inference_kwargs: Any, + ) -> list[Any]: + inference_kwargs = self._ensure_response_suffix(inference_kwargs) + model_inputs = cast("list[ContextInput | Query]", model_input) + gim_output_type = cast("Literal['cfg', 'json'] | None", output_type) + + outlines_model_inputs = get_outlines_model_inputs( + model_inputs, + gim_output_type, + use_gim_prompt, + visible_tag_fields=visible_tag_fields, + ) + outlines_output_types = [ + get_outlines_output_type(model_input, gim_output_type) for model_input in model_inputs + ] + raw_responses = self._generate_batch_with_output_types( + outlines_model_inputs, + outlines_output_types, + inference_kwargs, + ) + logger.debug(f"Raw batch responses of {self}: {raw_responses}") + return infill_batch_responses( + model_inputs, + cast("list[str] | list[list[str]]", raw_responses), + json_responses=(gim_output_type == "json"), + ) + + def _generate_batch_with_output_types( + self, + model_inputs: list[Any], + output_types: list[Any], + inference_kwargs: dict[str, Any], + ) -> list[list[str]]: + generation_kwargs = dict(inference_kwargs) + sampling_params = generation_kwargs.pop("sampling_params", None) + sampling_params_list = self._build_batch_sampling_params(sampling_params, output_types) + + formatted_inputs = [self.type_adapter.format_input(item) for item in model_inputs] + if formatted_inputs and isinstance(formatted_inputs[0], list): + results = self.model.chat( + messages=formatted_inputs, + sampling_params=sampling_params_list, + **generation_kwargs, + ) + else: + results = self.model.generate( + prompts=formatted_inputs, + sampling_params=sampling_params_list, + **generation_kwargs, + ) + return [[sample.text for sample in batch.outputs] for batch in results] + + def _build_batch_sampling_params( + self, + sampling_params: Any, + output_types: list[Any], + ) -> list["SamplingParams"]: + if isinstance(sampling_params, list): + if len(sampling_params) != len(output_types): + raise ValueError( + "sampling_params list must have the same length as model_input: " + f"{len(sampling_params)} sampling params for {len(output_types)} input(s)." + ) + return [ + self._build_generation_args({"sampling_params": params}, output_type) + for params, output_type in zip(sampling_params, output_types, strict=True) + ] + + return [ + self._build_generation_args({"sampling_params": sampling_params}, output_type) + for output_type in output_types + ] + def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, Any]: # Using `stop=RESPONSE_SUFFIX` is preferred for two reasons: # 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX. @@ -54,13 +141,22 @@ def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, from vllm import SamplingParams inference_kwargs["sampling_params"] = SamplingParams(stop=[RESPONSE_SUFFIX]) - elif ( - isinstance(inference_kwargs["sampling_params"].stop, list) - and RESPONSE_SUFFIX not in inference_kwargs["sampling_params"].stop - ): - inference_kwargs["sampling_params"].stop.append(RESPONSE_SUFFIX) + elif isinstance(inference_kwargs["sampling_params"], list): + for sampling_params in inference_kwargs["sampling_params"]: + self._ensure_sampling_params_response_suffix(sampling_params) + else: + self._ensure_sampling_params_response_suffix(inference_kwargs["sampling_params"]) return inference_kwargs + def _ensure_sampling_params_response_suffix(self, sampling_params: "SamplingParams") -> None: + if sampling_params.stop is None: + sampling_params.stop = [RESPONSE_SUFFIX] + elif isinstance(sampling_params.stop, str): + if sampling_params.stop != RESPONSE_SUFFIX: + sampling_params.stop = [sampling_params.stop, RESPONSE_SUFFIX] + elif RESPONSE_SUFFIX not in sampling_params.stop: + sampling_params.stop.append(RESPONSE_SUFFIX) + def from_vllm_offline(model: "LLM") -> VLLMOffline: return VLLMOffline(model) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index cfd05e7..a296b2f 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -6,7 +6,9 @@ from gimkit.contexts import Query, Result from gimkit.models.utils import ( get_outlines_model_input, + get_outlines_model_inputs, get_outlines_output_type, + infill_batch_responses, infill_responses, json_responses_to_gim_response, ) @@ -77,6 +79,18 @@ def test_get_outlines_output_type(): get_outlines_output_type(query, "xxx") +def test_get_outlines_model_inputs(): + queries = [Query("Hello, ", MaskedTag()), Query("Goodbye, ", MaskedTag())] + model_inputs = get_outlines_model_inputs(queries, output_type=None, use_gim_prompt=False) + assert model_inputs == [ + '<|GIM_QUERY|>Hello, <|MASKED id="m_0"|><|/MASKED|><|/GIM_QUERY|>', + '<|GIM_QUERY|>Goodbye, <|MASKED id="m_0"|><|/MASKED|><|/GIM_QUERY|>', + ] + + with pytest.raises(ValueError, match="Batch input list is empty"): + get_outlines_model_inputs([], output_type=None, use_gim_prompt=False) + + def test_json_responses_to_gim_response(): json_str = '{"m_0": "John", "m_1": "Doe"}' expected_gim_str = '<|GIM_RESPONSE|><|MASKED id="m_0"|>John<|/MASKED|><|MASKED id="m_1"|>Doe<|/MASKED|><|/GIM_RESPONSE|>' @@ -153,3 +167,43 @@ def test_infill_responses(): # Test list with non-string items with pytest.raises(TypeError, match="All items in the response list must be strings, got"): infill_responses(query, ["a", 1]) + + +def test_infill_batch_responses(): + queries = [ + Query("Hello, ", MaskedTag(id=0)), + Query("Goodbye, ", MaskedTag(id=0)), + ] + responses = [ + '<|GIM_RESPONSE|><|MASKED id="m_0"|>world<|/MASKED|><|/GIM_RESPONSE|>', + '<|GIM_RESPONSE|><|MASKED id="m_0"|>friend<|/MASKED|><|/GIM_RESPONSE|>', + ] + results = infill_batch_responses(queries, responses) + assert isinstance(results[0], Result) + assert [str(result) for result in results] == ["Hello, world", "Goodbye, friend"] + + nested_results = infill_batch_responses(queries, [[responses[0]], [responses[1]]]) + assert isinstance(nested_results[0], list) + assert [[str(result) for result in group] for group in nested_results] == [ + ["Hello, world"], + ["Goodbye, friend"], + ] + + json_results = infill_batch_responses( + queries, + ['{"m_0": "world"}', '{"m_0": "friend"}'], + json_responses=True, + ) + assert [str(result) for result in json_results] == ["Hello, world", "Goodbye, friend"] + + with pytest.raises(ValueError, match="Batch input list is empty"): + infill_batch_responses([], []) + + with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): + infill_batch_responses(queries, [responses[0]]) + + with pytest.raises(TypeError, match="Expected batch responses to be a list"): + infill_batch_responses(queries, "response") + + with pytest.raises(TypeError, match="Each batch response must be a string or a list"): + infill_batch_responses(queries, [responses[0], object()]) diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index c429b40..cc16b87 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -9,7 +9,7 @@ from gimkit.contexts import Result from gimkit.models.vllm_offline import VLLMOffline as GIMVLLMOffline from gimkit.models.vllm_offline import from_vllm_offline -from gimkit.schemas import MaskedTag +from gimkit.schemas import RESPONSE_SUFFIX, MaskedTag pytestmark = pytest.mark.skipif( @@ -17,19 +17,31 @@ ) +def _mock_vllm_client(): + from vllm import LLM + + mock_client = MagicMock(spec=LLM) + mock_client.get_tokenizer.return_value = object() + return mock_client + + +def _request_output(*texts: str): + request_output = MagicMock() + request_output.outputs = [MagicMock(text=text) for text in texts] + return request_output + + def test_from_vllm_offline(): from vllm import LLM - model = from_vllm_offline(MagicMock(spec=LLM)) + model = from_vllm_offline(_mock_vllm_client()) assert type(model) is GIMVLLMOffline assert type(model) is not OutlinesVLLMOffline assert type(model) is not LLM def test_vllm_offline_call(): - from vllm import LLM - - mock_client = MagicMock(spec=LLM) + mock_client = _mock_vllm_client() model = from_vllm_offline(mock_client) with patch("gimkit.models.vllm_offline.Generator") as mock_generator: @@ -44,10 +56,142 @@ def test_vllm_offline_call(): model(MaskedTag(), visible_tag_fields=["id", "desc", "content", "regex"]) +def test_vllm_offline_batch(): + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output( + '<|MASKED id="m_0"|>dear<|/MASKED|><|MASKED id="m_1"|>friend<|/MASKED|>' + ), + ] + model = from_vllm_offline(mock_client) + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag(), " ", MaskedTag()], + ] + ) + + assert len(returned) == 2 + assert isinstance(returned[0], list) + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, dear friend" + + mock_client.generate.assert_called_once() + sampling_params = mock_client.generate.call_args.kwargs["sampling_params"] + assert len(sampling_params) == 2 + assert sampling_params[0].structured_outputs.grammar != sampling_params[1].structured_outputs.grammar + assert RESPONSE_SUFFIX in sampling_params[0].stop + assert RESPONSE_SUFFIX in sampling_params[1].stop + + +def test_vllm_offline_batch_sampling_params_list(): + from vllm import SamplingParams + + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output('<|MASKED id="m_0"|>friend<|/MASKED|>'), + ] + model = from_vllm_offline(mock_client) + sampling_params = [SamplingParams(stop=[""]), SamplingParams()] + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + sampling_params=sampling_params, + ) + + assert len(returned) == 2 + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, friend" + assert RESPONSE_SUFFIX in sampling_params[0].stop + assert RESPONSE_SUFFIX in sampling_params[1].stop + + +def test_vllm_offline_batch_invalid_sampling_params_list_length(): + from vllm import SamplingParams + + model = from_vllm_offline(_mock_vllm_client()) + + with pytest.raises(ValueError, match="sampling_params list must have the same length"): + model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + sampling_params=[SamplingParams()], + ) + + +def test_vllm_offline_batch_invalid_response(): + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [_request_output()] + model = from_vllm_offline(mock_client) + + with pytest.raises(ValueError, match="Response list is empty"): + model.batch([["Hello, ", MaskedTag()]]) + + mock_client.generate.return_value = [ + MagicMock(outputs=[MagicMock(text=object())]), + ] + with pytest.raises(TypeError, match="All items in the response list must be strings"): + model.batch([["Hello, ", MaskedTag()]]) + + +def test_vllm_offline_batch_chat(): + mock_client = _mock_vllm_client() + mock_client.chat.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output('<|MASKED id="m_0"|>friend<|/MASKED|>'), + ] + model = from_vllm_offline(mock_client) + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + use_gim_prompt=True, + ) + + assert len(returned) == 2 + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, friend" + mock_client.chat.assert_called_once() + + +def test_vllm_offline_batch_flat_responses(): + model = from_vllm_offline(_mock_vllm_client()) + + with patch.object( + model, + "_generate_batch_with_output_types", + return_value=[ + '<|MASKED id="m_0"|>world<|/MASKED|>', + '<|MASKED id="m_0"|>friend<|/MASKED|>', + ] + ): + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ] + ) + + assert len(returned) == 2 + assert isinstance(returned[0], Result) + assert str(returned[0]) == "Hello, world" + assert str(returned[1]) == "Goodbye, friend" + + def test_vllm_offline_call_invalid_response(): - from vllm import LLM, SamplingParams + from vllm import SamplingParams - model = from_vllm_offline(MagicMock(spec=LLM)) + model = from_vllm_offline(_mock_vllm_client()) with patch("gimkit.models.vllm_offline.Generator") as mock_generator: generator_instance = MagicMock() From 8c904ddca02a578dcae34bf1aa04c6a9e31dd78d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jun 2026 06:50:21 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/gimkit/models/utils.py | 3 +-- tests/models/test_vllm_offline.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 0a05813..355b30b 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -220,8 +220,7 @@ def infill_batch_responses( results.append(infill_responses(query, response, json_responses=json_responses)) else: raise TypeError( - "Each batch response must be a string or a list of strings, " - f"got {type(response)}" + f"Each batch response must be a string or a list of strings, got {type(response)}" ) return results diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index cc16b87..1b7c9a5 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -60,9 +60,7 @@ def test_vllm_offline_batch(): mock_client = _mock_vllm_client() mock_client.generate.return_value = [ _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), - _request_output( - '<|MASKED id="m_0"|>dear<|/MASKED|><|MASKED id="m_1"|>friend<|/MASKED|>' - ), + _request_output('<|MASKED id="m_0"|>dear<|/MASKED|><|MASKED id="m_1"|>friend<|/MASKED|>'), ] model = from_vllm_offline(mock_client) @@ -81,7 +79,10 @@ def test_vllm_offline_batch(): mock_client.generate.assert_called_once() sampling_params = mock_client.generate.call_args.kwargs["sampling_params"] assert len(sampling_params) == 2 - assert sampling_params[0].structured_outputs.grammar != sampling_params[1].structured_outputs.grammar + assert ( + sampling_params[0].structured_outputs.grammar + != sampling_params[1].structured_outputs.grammar + ) assert RESPONSE_SUFFIX in sampling_params[0].stop assert RESPONSE_SUFFIX in sampling_params[1].stop @@ -173,7 +174,7 @@ def test_vllm_offline_batch_flat_responses(): return_value=[ '<|MASKED id="m_0"|>world<|/MASKED|>', '<|MASKED id="m_0"|>friend<|/MASKED|>', - ] + ], ): returned = model.batch( [ From 306f01397f073702ceea22ddd57e3c4dfd184ddf Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:19:49 +0800 Subject: [PATCH 3/7] test: cover vllm offline stop handling --- tests/models/test_vllm_offline.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index 1b7c9a5..79881a2 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -1,5 +1,6 @@ import sys +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -113,6 +114,22 @@ def test_vllm_offline_batch_sampling_params_list(): assert RESPONSE_SUFFIX in sampling_params[1].stop +def test_vllm_offline_ensure_sampling_params_response_suffix(): + model = from_vllm_offline(_mock_vllm_client()) + + sampling_params = SimpleNamespace(stop=None) + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == [RESPONSE_SUFFIX] + + sampling_params = SimpleNamespace(stop="") + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == ["", RESPONSE_SUFFIX] + + sampling_params = SimpleNamespace(stop=RESPONSE_SUFFIX) + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == RESPONSE_SUFFIX + + def test_vllm_offline_batch_invalid_sampling_params_list_length(): from vllm import SamplingParams From f54400793fad775446c504a6915477eabe2a6b67 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:43:17 +0800 Subject: [PATCH 4/7] refactor: tighten vllm offline batch types --- src/gimkit/models/utils.py | 31 +++++++++++------- src/gimkit/models/vllm_offline.py | 53 +++++++++++++++++++------------ tests/models/test_vllm_offline.py | 24 -------------- 3 files changed, 53 insertions(+), 55 deletions(-) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 355b30b..bef7971 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Literal, overload +from typing import Literal, cast, overload from outlines.inputs import Chat from outlines.types.dsl import CFG, JsonSchema @@ -202,7 +202,7 @@ def infill_batch_responses( queries: Sequence[ContextInput | Query], responses: list[str] | list[list[str]], json_responses: bool = False, -) -> list[Any]: +) -> list[Result] | list[list[Result]]: """Infill each query in a batch with its corresponding response(s).""" if len(queries) == 0: raise ValueError("Batch input list is empty.") @@ -214,13 +214,22 @@ def infill_batch_responses( f"{len(queries)} input(s), {len(responses)} response(s)." ) - results = [] - for query, response in zip(queries, responses, strict=True): - if isinstance(response, (str, list)): - results.append(infill_responses(query, response, json_responses=json_responses)) - else: - raise TypeError( - f"Each batch response must be a string or a list of strings, got {type(response)}" - ) + if all(isinstance(response, str) for response in responses): + return [ + infill_responses(query, cast("str", response), json_responses=json_responses) + for query, response in zip(queries, responses, strict=True) + ] + + if all(isinstance(response, list) for response in responses): + return [ + infill_responses(query, cast("list[str]", response), json_responses=json_responses) + for query, response in zip(queries, responses, strict=True) + ] - return results + invalid_response = next( + response for response in responses if not isinstance(response, (str, list)) + ) + raise TypeError( + "Each batch response must be a string or a list of strings, " + f"got {type(invalid_response)}" + ) diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 589abd0..0a61d66 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -1,10 +1,12 @@ # Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/vllm_offline.py -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast from outlines.generator import Generator +from outlines.inputs import Chat from outlines.models.vllm_offline import VLLMOffline as OutlinesVLLMOffline +from outlines.types.dsl import CFG, JsonSchema from gimkit.contexts import Query, Result from gimkit.log import get_logger @@ -25,6 +27,11 @@ from vllm.sampling_params import SamplingParams +OutlinesModelInput: TypeAlias = str | Chat +OutlinesOutputType: TypeAlias = CFG | JsonSchema | None +VLLMFormattedInput: TypeAlias = str | list[object] + + class VLLMOffline(OutlinesVLLMOffline): def __call__( self, @@ -55,25 +62,23 @@ def __call__( def batch( self, - model_input: list[Any], - output_type: Any | None = "cfg", + model_input: list[ContextInput | Query], + output_type: Literal["cfg", "json"] | None = "cfg", backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, **inference_kwargs: Any, - ) -> list[Any]: + ) -> list[list[Result]]: # type: ignore[override] inference_kwargs = self._ensure_response_suffix(inference_kwargs) - model_inputs = cast("list[ContextInput | Query]", model_input) - gim_output_type = cast("Literal['cfg', 'json'] | None", output_type) outlines_model_inputs = get_outlines_model_inputs( - model_inputs, - gim_output_type, + model_input, + output_type, use_gim_prompt, visible_tag_fields=visible_tag_fields, ) outlines_output_types = [ - get_outlines_output_type(model_input, gim_output_type) for model_input in model_inputs + get_outlines_output_type(batch_item, output_type) for batch_item in model_input ] raw_responses = self._generate_batch_with_output_types( outlines_model_inputs, @@ -81,32 +86,40 @@ def batch( inference_kwargs, ) logger.debug(f"Raw batch responses of {self}: {raw_responses}") - return infill_batch_responses( - model_inputs, - cast("list[str] | list[list[str]]", raw_responses), - json_responses=(gim_output_type == "json"), + return cast( + "list[list[Result]]", + infill_batch_responses( + model_input, + raw_responses, + json_responses=(output_type == "json"), + ), ) def _generate_batch_with_output_types( self, - model_inputs: list[Any], - output_types: list[Any], + model_inputs: list[OutlinesModelInput], + output_types: list[OutlinesOutputType], inference_kwargs: dict[str, Any], ) -> list[list[str]]: generation_kwargs = dict(inference_kwargs) sampling_params = generation_kwargs.pop("sampling_params", None) sampling_params_list = self._build_batch_sampling_params(sampling_params, output_types) - formatted_inputs = [self.type_adapter.format_input(item) for item in model_inputs] + formatted_inputs = [ + cast("VLLMFormattedInput", self.type_adapter.format_input(item)) + for item in model_inputs + ] if formatted_inputs and isinstance(formatted_inputs[0], list): + chat_messages = cast("list[list[Any]]", formatted_inputs) results = self.model.chat( - messages=formatted_inputs, + messages=chat_messages, sampling_params=sampling_params_list, **generation_kwargs, ) else: + prompts = cast("list[str]", formatted_inputs) results = self.model.generate( - prompts=formatted_inputs, + prompts=prompts, sampling_params=sampling_params_list, **generation_kwargs, ) @@ -114,8 +127,8 @@ def _generate_batch_with_output_types( def _build_batch_sampling_params( self, - sampling_params: Any, - output_types: list[Any], + sampling_params: "SamplingParams | list[SamplingParams] | None", + output_types: list[OutlinesOutputType], ) -> list["SamplingParams"]: if isinstance(sampling_params, list): if len(sampling_params) != len(output_types): diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index 79881a2..383a46b 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -182,30 +182,6 @@ def test_vllm_offline_batch_chat(): mock_client.chat.assert_called_once() -def test_vllm_offline_batch_flat_responses(): - model = from_vllm_offline(_mock_vllm_client()) - - with patch.object( - model, - "_generate_batch_with_output_types", - return_value=[ - '<|MASKED id="m_0"|>world<|/MASKED|>', - '<|MASKED id="m_0"|>friend<|/MASKED|>', - ], - ): - returned = model.batch( - [ - ["Hello, ", MaskedTag()], - ["Goodbye, ", MaskedTag()], - ] - ) - - assert len(returned) == 2 - assert isinstance(returned[0], Result) - assert str(returned[0]) == "Hello, world" - assert str(returned[1]) == "Goodbye, friend" - - def test_vllm_offline_call_invalid_response(): from vllm import SamplingParams From d4c1966d27f7cc9d831d6562b024be32c49029b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:43:37 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/gimkit/models/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index bef7971..8a2b165 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -230,6 +230,5 @@ def infill_batch_responses( response for response in responses if not isinstance(response, (str, list)) ) raise TypeError( - "Each batch response must be a string or a list of strings, " - f"got {type(invalid_response)}" + f"Each batch response must be a string or a list of strings, got {type(invalid_response)}" ) From ac0cb1dc39054fa0b2f5b6a55f7b9a84ad43ca36 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 12 Jun 2026 19:45:53 +0800 Subject: [PATCH 6/7] docs: add vllm offline batch example --- examples/vllm_offline_batch.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/vllm_offline_batch.py diff --git a/examples/vllm_offline_batch.py b/examples/vllm_offline_batch.py new file mode 100644 index 0000000..4b959bb --- /dev/null +++ b/examples/vllm_offline_batch.py @@ -0,0 +1,31 @@ +from vllm import LLM, SamplingParams + +from gimkit import from_vllm_offline +from gimkit import guide as g + + +llm = LLM(model="Sculpt-AI/GIM-1.7B", max_model_len=8192) +model = from_vllm_offline(llm) + +queries = [ + f"Extract the person's name: Alice Zhang -> {g.person_name(name='name')}", + ( + "Extract contact fields from: Bob Chen, bob@example.com, +1-212-555-0101\n" + f"Name: {g.person_name(name='name')}\n" + f"Email: {g.e_mail(name='email')}\n" + f"Phone: {g.phone_number(name='phone')}" + ), +] + +sampling_params = [ + SamplingParams(temperature=0.0, max_tokens=256, seed=0), + SamplingParams(temperature=0.0, max_tokens=512, seed=1), +] + +batch_results = model.batch(queries, output_type="cfg", sampling_params=sampling_params) + +# batch_results keeps two dimensions: inputs, then completions per input. +for input_index, completions in enumerate(batch_results): + print(f"Input {input_index}") + for completion_index, result in enumerate(completions): + print(f"Completion {completion_index}: {result}") From c93030146a850d32a3878f0ff76f9330a0e7c316 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 12 Jun 2026 20:33:42 +0800 Subject: [PATCH 7/7] refactor: update model_input type to Sequence in VLLMOffline.batch method --- src/gimkit/models/vllm_offline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 0a61d66..8e6d8c6 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/vllm_offline.py +from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast from outlines.generator import Generator @@ -62,7 +63,7 @@ def __call__( def batch( self, - model_input: list[ContextInput | Query], + model_input: Sequence[ContextInput | Query], output_type: Literal["cfg", "json"] | None = "cfg", backend: str | None = None, use_gim_prompt: bool = False,