diff --git a/docs/models/vllm-offline.md b/docs/models/vllm-offline.md index 68aa63e..dbc1118 100644 --- a/docs/models/vllm-offline.md +++ b/docs/models/vllm-offline.md @@ -37,6 +37,16 @@ result = model(query) result_non_gim = model(query, use_gim_prompt=True) ``` +## Batch inference + +`model.batch(...)` wraps Outlines' batch API for vLLM offline. +Each query can use its own GIM-derived structured output schema. + +```python +batch_results = model.batch([query, query]) +first_result = batch_results[0][0] +``` + ## Output types ### `output_type="cfg"` (default) diff --git a/docs/models/vllm-offline.zh.md b/docs/models/vllm-offline.zh.md index a3edff0..d29987b 100644 --- a/docs/models/vllm-offline.zh.md +++ b/docs/models/vllm-offline.zh.md @@ -37,6 +37,16 @@ result = model(query) result_non_gim = model(query, use_gim_prompt=True) ``` +## 批量推理 + +`model.batch(...)` 会包装 Outlines 的 vLLM offline batch API。 +每条 query 都可以使用各自从 GIM 推导出的结构化输出 schema。 + +```python +batch_results = model.batch([query, query]) +first_result = batch_results[0][0] +``` + ## 输出类型 ### `output_type="cfg"`(默认) diff --git a/examples/vllm_offline_batch.py b/examples/vllm_offline_batch.py new file mode 100644 index 0000000..4b959bb --- /dev/null +++ b/examples/vllm_offline_batch.py @@ -0,0 +1,31 @@ +from vllm import LLM, SamplingParams + +from gimkit import from_vllm_offline +from gimkit import guide as g + + +llm = LLM(model="Sculpt-AI/GIM-1.7B", max_model_len=8192) +model = from_vllm_offline(llm) + +queries = [ + f"Extract the person's name: Alice Zhang -> {g.person_name(name='name')}", + ( + "Extract contact fields from: Bob Chen, bob@example.com, +1-212-555-0101\n" + f"Name: {g.person_name(name='name')}\n" + f"Email: {g.e_mail(name='email')}\n" + f"Phone: {g.phone_number(name='phone')}" + ), +] + +sampling_params = [ + SamplingParams(temperature=0.0, max_tokens=256, seed=0), + SamplingParams(temperature=0.0, max_tokens=512, seed=1), +] + +batch_results = model.batch(queries, output_type="cfg", sampling_params=sampling_params) + +# batch_results keeps two dimensions: inputs, then completions per input. +for input_index, completions in enumerate(batch_results): + print(f"Input {input_index}") + for completion_index, result in enumerate(completions): + print(f"Completion {completion_index}: {result}") diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 878f7bc..8a2b165 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -1,4 +1,5 @@ -from typing import Literal, overload +from collections.abc import Sequence +from typing import Literal, cast, overload from outlines.inputs import Chat from outlines.types.dsl import CFG, JsonSchema @@ -14,6 +15,10 @@ from gimkit.schemas import ContextInput, MaskedTag, TagField +def _ensure_query(model_input: ContextInput | Query) -> Query: + return model_input if isinstance(model_input, Query) else Query(model_input) + + def get_outlines_model_input( model_input: ContextInput | Query, output_type: Literal["cfg", "json"] | None, @@ -31,7 +36,7 @@ def get_outlines_model_input( If None, uses the Query default (["id", "desc", "content"]). Example: ["id", "name", "desc", "content", "regex"] to expose all fields. """ - query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input + query_obj = _ensure_query(model_input) outlines_model_input: str | Chat = ( query_obj.to_string(fields=visible_tag_fields) if visible_tag_fields is not None @@ -57,11 +62,31 @@ def get_outlines_model_input( return outlines_model_input +def get_outlines_model_inputs( + model_inputs: Sequence[ContextInput | Query], + output_type: Literal["cfg", "json"] | None, + use_gim_prompt: bool, + visible_tag_fields: list[TagField] | None = None, +) -> list[str | Chat]: + """Transform a batch of model inputs to Outlines-compatible formats.""" + if len(model_inputs) == 0: + raise ValueError("Batch input list is empty.") + return [ + get_outlines_model_input( + model_input, + output_type, + use_gim_prompt, + visible_tag_fields=visible_tag_fields, + ) + for model_input in model_inputs + ] + + def get_outlines_output_type( model_input: ContextInput | Query, output_type: Literal["cfg", "json"] | None ) -> None | CFG | JsonSchema: """Transform the output type to an Outlines-compatible format.""" - query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input + query_obj = _ensure_query(model_input) if output_type is None: return None elif output_type == "cfg": @@ -157,3 +182,53 @@ def infill_responses( raise TypeError(f"All items in the response list must be strings, got: {responses}") return [infill_responses(query, resp, json_responses=json_responses) for resp in responses] + + +@overload +def infill_batch_responses( + queries: Sequence[ContextInput | Query], responses: list[str], json_responses: bool = False +) -> list[Result]: ... + + +@overload +def infill_batch_responses( + queries: Sequence[ContextInput | Query], + responses: list[list[str]], + json_responses: bool = False, +) -> list[list[Result]]: ... + + +def infill_batch_responses( + queries: Sequence[ContextInput | Query], + responses: list[str] | list[list[str]], + json_responses: bool = False, +) -> list[Result] | list[list[Result]]: + """Infill each query in a batch with its corresponding response(s).""" + if len(queries) == 0: + raise ValueError("Batch input list is empty.") + if not isinstance(responses, list): + raise TypeError(f"Expected batch responses to be a list, got {type(responses)}") + if len(queries) != len(responses): + raise ValueError( + "Mismatched number of batch inputs and responses: " + f"{len(queries)} input(s), {len(responses)} response(s)." + ) + + if all(isinstance(response, str) for response in responses): + return [ + infill_responses(query, cast("str", response), json_responses=json_responses) + for query, response in zip(queries, responses, strict=True) + ] + + if all(isinstance(response, list) for response in responses): + return [ + infill_responses(query, cast("list[str]", response), json_responses=json_responses) + for query, response in zip(queries, responses, strict=True) + ] + + invalid_response = next( + response for response in responses if not isinstance(response, (str, list)) + ) + raise TypeError( + f"Each batch response must be a string or a list of strings, got {type(invalid_response)}" + ) diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 9a7f147..8e6d8c6 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -1,14 +1,23 @@ # Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/vllm_offline.py -from typing import TYPE_CHECKING, Any, Literal, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast from outlines.generator import Generator +from outlines.inputs import Chat from outlines.models.vllm_offline import VLLMOffline as OutlinesVLLMOffline +from outlines.types.dsl import CFG, JsonSchema from gimkit.contexts import Query, Result from gimkit.log import get_logger -from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses +from gimkit.models.utils import ( + get_outlines_model_input, + get_outlines_model_inputs, + get_outlines_output_type, + infill_batch_responses, + infill_responses, +) from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField @@ -16,6 +25,12 @@ if TYPE_CHECKING: from vllm import LLM + from vllm.sampling_params import SamplingParams + + +OutlinesModelInput: TypeAlias = str | Chat +OutlinesOutputType: TypeAlias = CFG | JsonSchema | None +VLLMFormattedInput: TypeAlias = str | list[object] class VLLMOffline(OutlinesVLLMOffline): @@ -46,6 +61,92 @@ def __call__( json_responses=(output_type == "json"), ) + def batch( + self, + model_input: Sequence[ContextInput | Query], + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + **inference_kwargs: Any, + ) -> list[list[Result]]: # type: ignore[override] + inference_kwargs = self._ensure_response_suffix(inference_kwargs) + + outlines_model_inputs = get_outlines_model_inputs( + model_input, + output_type, + use_gim_prompt, + visible_tag_fields=visible_tag_fields, + ) + outlines_output_types = [ + get_outlines_output_type(batch_item, output_type) for batch_item in model_input + ] + raw_responses = self._generate_batch_with_output_types( + outlines_model_inputs, + outlines_output_types, + inference_kwargs, + ) + logger.debug(f"Raw batch responses of {self}: {raw_responses}") + return cast( + "list[list[Result]]", + infill_batch_responses( + model_input, + raw_responses, + json_responses=(output_type == "json"), + ), + ) + + def _generate_batch_with_output_types( + self, + model_inputs: list[OutlinesModelInput], + output_types: list[OutlinesOutputType], + inference_kwargs: dict[str, Any], + ) -> list[list[str]]: + generation_kwargs = dict(inference_kwargs) + sampling_params = generation_kwargs.pop("sampling_params", None) + sampling_params_list = self._build_batch_sampling_params(sampling_params, output_types) + + formatted_inputs = [ + cast("VLLMFormattedInput", self.type_adapter.format_input(item)) + for item in model_inputs + ] + if formatted_inputs and isinstance(formatted_inputs[0], list): + chat_messages = cast("list[list[Any]]", formatted_inputs) + results = self.model.chat( + messages=chat_messages, + sampling_params=sampling_params_list, + **generation_kwargs, + ) + else: + prompts = cast("list[str]", formatted_inputs) + results = self.model.generate( + prompts=prompts, + sampling_params=sampling_params_list, + **generation_kwargs, + ) + return [[sample.text for sample in batch.outputs] for batch in results] + + def _build_batch_sampling_params( + self, + sampling_params: "SamplingParams | list[SamplingParams] | None", + output_types: list[OutlinesOutputType], + ) -> list["SamplingParams"]: + if isinstance(sampling_params, list): + if len(sampling_params) != len(output_types): + raise ValueError( + "sampling_params list must have the same length as model_input: " + f"{len(sampling_params)} sampling params for {len(output_types)} input(s)." + ) + return [ + self._build_generation_args({"sampling_params": params}, output_type) + for params, output_type in zip(sampling_params, output_types, strict=True) + ] + + return [ + self._build_generation_args({"sampling_params": sampling_params}, output_type) + for output_type in output_types + ] + def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, Any]: # Using `stop=RESPONSE_SUFFIX` is preferred for two reasons: # 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX. @@ -54,13 +155,22 @@ def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, from vllm import SamplingParams inference_kwargs["sampling_params"] = SamplingParams(stop=[RESPONSE_SUFFIX]) - elif ( - isinstance(inference_kwargs["sampling_params"].stop, list) - and RESPONSE_SUFFIX not in inference_kwargs["sampling_params"].stop - ): - inference_kwargs["sampling_params"].stop.append(RESPONSE_SUFFIX) + elif isinstance(inference_kwargs["sampling_params"], list): + for sampling_params in inference_kwargs["sampling_params"]: + self._ensure_sampling_params_response_suffix(sampling_params) + else: + self._ensure_sampling_params_response_suffix(inference_kwargs["sampling_params"]) return inference_kwargs + def _ensure_sampling_params_response_suffix(self, sampling_params: "SamplingParams") -> None: + if sampling_params.stop is None: + sampling_params.stop = [RESPONSE_SUFFIX] + elif isinstance(sampling_params.stop, str): + if sampling_params.stop != RESPONSE_SUFFIX: + sampling_params.stop = [sampling_params.stop, RESPONSE_SUFFIX] + elif RESPONSE_SUFFIX not in sampling_params.stop: + sampling_params.stop.append(RESPONSE_SUFFIX) + def from_vllm_offline(model: "LLM") -> VLLMOffline: return VLLMOffline(model) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index cfd05e7..a296b2f 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -6,7 +6,9 @@ from gimkit.contexts import Query, Result from gimkit.models.utils import ( get_outlines_model_input, + get_outlines_model_inputs, get_outlines_output_type, + infill_batch_responses, infill_responses, json_responses_to_gim_response, ) @@ -77,6 +79,18 @@ def test_get_outlines_output_type(): get_outlines_output_type(query, "xxx") +def test_get_outlines_model_inputs(): + queries = [Query("Hello, ", MaskedTag()), Query("Goodbye, ", MaskedTag())] + model_inputs = get_outlines_model_inputs(queries, output_type=None, use_gim_prompt=False) + assert model_inputs == [ + '<|GIM_QUERY|>Hello, <|MASKED id="m_0"|><|/MASKED|><|/GIM_QUERY|>', + '<|GIM_QUERY|>Goodbye, <|MASKED id="m_0"|><|/MASKED|><|/GIM_QUERY|>', + ] + + with pytest.raises(ValueError, match="Batch input list is empty"): + get_outlines_model_inputs([], output_type=None, use_gim_prompt=False) + + def test_json_responses_to_gim_response(): json_str = '{"m_0": "John", "m_1": "Doe"}' expected_gim_str = '<|GIM_RESPONSE|><|MASKED id="m_0"|>John<|/MASKED|><|MASKED id="m_1"|>Doe<|/MASKED|><|/GIM_RESPONSE|>' @@ -153,3 +167,43 @@ def test_infill_responses(): # Test list with non-string items with pytest.raises(TypeError, match="All items in the response list must be strings, got"): infill_responses(query, ["a", 1]) + + +def test_infill_batch_responses(): + queries = [ + Query("Hello, ", MaskedTag(id=0)), + Query("Goodbye, ", MaskedTag(id=0)), + ] + responses = [ + '<|GIM_RESPONSE|><|MASKED id="m_0"|>world<|/MASKED|><|/GIM_RESPONSE|>', + '<|GIM_RESPONSE|><|MASKED id="m_0"|>friend<|/MASKED|><|/GIM_RESPONSE|>', + ] + results = infill_batch_responses(queries, responses) + assert isinstance(results[0], Result) + assert [str(result) for result in results] == ["Hello, world", "Goodbye, friend"] + + nested_results = infill_batch_responses(queries, [[responses[0]], [responses[1]]]) + assert isinstance(nested_results[0], list) + assert [[str(result) for result in group] for group in nested_results] == [ + ["Hello, world"], + ["Goodbye, friend"], + ] + + json_results = infill_batch_responses( + queries, + ['{"m_0": "world"}', '{"m_0": "friend"}'], + json_responses=True, + ) + assert [str(result) for result in json_results] == ["Hello, world", "Goodbye, friend"] + + with pytest.raises(ValueError, match="Batch input list is empty"): + infill_batch_responses([], []) + + with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): + infill_batch_responses(queries, [responses[0]]) + + with pytest.raises(TypeError, match="Expected batch responses to be a list"): + infill_batch_responses(queries, "response") + + with pytest.raises(TypeError, match="Each batch response must be a string or a list"): + infill_batch_responses(queries, [responses[0], object()]) diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index c429b40..383a46b 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -1,5 +1,6 @@ import sys +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -9,7 +10,7 @@ from gimkit.contexts import Result from gimkit.models.vllm_offline import VLLMOffline as GIMVLLMOffline from gimkit.models.vllm_offline import from_vllm_offline -from gimkit.schemas import MaskedTag +from gimkit.schemas import RESPONSE_SUFFIX, MaskedTag pytestmark = pytest.mark.skipif( @@ -17,19 +18,31 @@ ) +def _mock_vllm_client(): + from vllm import LLM + + mock_client = MagicMock(spec=LLM) + mock_client.get_tokenizer.return_value = object() + return mock_client + + +def _request_output(*texts: str): + request_output = MagicMock() + request_output.outputs = [MagicMock(text=text) for text in texts] + return request_output + + def test_from_vllm_offline(): from vllm import LLM - model = from_vllm_offline(MagicMock(spec=LLM)) + model = from_vllm_offline(_mock_vllm_client()) assert type(model) is GIMVLLMOffline assert type(model) is not OutlinesVLLMOffline assert type(model) is not LLM def test_vllm_offline_call(): - from vllm import LLM - - mock_client = MagicMock(spec=LLM) + mock_client = _mock_vllm_client() model = from_vllm_offline(mock_client) with patch("gimkit.models.vllm_offline.Generator") as mock_generator: @@ -44,10 +57,135 @@ def test_vllm_offline_call(): model(MaskedTag(), visible_tag_fields=["id", "desc", "content", "regex"]) +def test_vllm_offline_batch(): + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output('<|MASKED id="m_0"|>dear<|/MASKED|><|MASKED id="m_1"|>friend<|/MASKED|>'), + ] + model = from_vllm_offline(mock_client) + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag(), " ", MaskedTag()], + ] + ) + + assert len(returned) == 2 + assert isinstance(returned[0], list) + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, dear friend" + + mock_client.generate.assert_called_once() + sampling_params = mock_client.generate.call_args.kwargs["sampling_params"] + assert len(sampling_params) == 2 + assert ( + sampling_params[0].structured_outputs.grammar + != sampling_params[1].structured_outputs.grammar + ) + assert RESPONSE_SUFFIX in sampling_params[0].stop + assert RESPONSE_SUFFIX in sampling_params[1].stop + + +def test_vllm_offline_batch_sampling_params_list(): + from vllm import SamplingParams + + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output('<|MASKED id="m_0"|>friend<|/MASKED|>'), + ] + model = from_vllm_offline(mock_client) + sampling_params = [SamplingParams(stop=[""]), SamplingParams()] + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + sampling_params=sampling_params, + ) + + assert len(returned) == 2 + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, friend" + assert RESPONSE_SUFFIX in sampling_params[0].stop + assert RESPONSE_SUFFIX in sampling_params[1].stop + + +def test_vllm_offline_ensure_sampling_params_response_suffix(): + model = from_vllm_offline(_mock_vllm_client()) + + sampling_params = SimpleNamespace(stop=None) + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == [RESPONSE_SUFFIX] + + sampling_params = SimpleNamespace(stop="") + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == ["", RESPONSE_SUFFIX] + + sampling_params = SimpleNamespace(stop=RESPONSE_SUFFIX) + model._ensure_sampling_params_response_suffix(sampling_params) + assert sampling_params.stop == RESPONSE_SUFFIX + + +def test_vllm_offline_batch_invalid_sampling_params_list_length(): + from vllm import SamplingParams + + model = from_vllm_offline(_mock_vllm_client()) + + with pytest.raises(ValueError, match="sampling_params list must have the same length"): + model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + sampling_params=[SamplingParams()], + ) + + +def test_vllm_offline_batch_invalid_response(): + mock_client = _mock_vllm_client() + mock_client.generate.return_value = [_request_output()] + model = from_vllm_offline(mock_client) + + with pytest.raises(ValueError, match="Response list is empty"): + model.batch([["Hello, ", MaskedTag()]]) + + mock_client.generate.return_value = [ + MagicMock(outputs=[MagicMock(text=object())]), + ] + with pytest.raises(TypeError, match="All items in the response list must be strings"): + model.batch([["Hello, ", MaskedTag()]]) + + +def test_vllm_offline_batch_chat(): + mock_client = _mock_vllm_client() + mock_client.chat.return_value = [ + _request_output('<|MASKED id="m_0"|>world<|/MASKED|>'), + _request_output('<|MASKED id="m_0"|>friend<|/MASKED|>'), + ] + model = from_vllm_offline(mock_client) + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + use_gim_prompt=True, + ) + + assert len(returned) == 2 + assert str(returned[0][0]) == "Hello, world" + assert str(returned[1][0]) == "Goodbye, friend" + mock_client.chat.assert_called_once() + + def test_vllm_offline_call_invalid_response(): - from vllm import LLM, SamplingParams + from vllm import SamplingParams - model = from_vllm_offline(MagicMock(spec=LLM)) + model = from_vllm_offline(_mock_vllm_client()) with patch("gimkit.models.vllm_offline.Generator") as mock_generator: generator_instance = MagicMock()