From 31121c00575399c9adb3cf57ca81f2e4a790a8e0 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 19 Jun 2026 15:53:04 +0800 Subject: [PATCH 1/5] collect per-generation parsing errors --- docs/models/openai.md | 22 ++++ docs/models/openai.zh.md | 21 ++++ docs/models/vllm-offline.md | 21 ++++ docs/models/vllm-offline.zh.md | 19 +++ docs/models/vllm.md | 11 ++ docs/models/vllm.zh.md | 10 ++ src/gimkit/__init__.py | 3 +- src/gimkit/models/__init__.py | 3 +- src/gimkit/models/base.py | 38 ++++-- src/gimkit/models/openai.py | 63 +++++++++- src/gimkit/models/types.py | 22 ++++ src/gimkit/models/utils.py | 186 +++++++++++++++++++++++++++++- src/gimkit/models/vllm.py | 63 +++++++++- src/gimkit/models/vllm_offline.py | 94 ++++++++++++--- tests/models/test_openai.py | 90 +++++++++++++++ tests/models/test_utils.py | 106 +++++++++++++++++ tests/models/test_vllm_offline.py | 56 +++++++++ 17 files changed, 798 insertions(+), 30 deletions(-) create mode 100644 src/gimkit/models/types.py diff --git a/docs/models/openai.md b/docs/models/openai.md index 442e06f..aa3deb1 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -53,8 +53,30 @@ result = model(query, output_type=None, use_gim_prompt=True) print(result.tags["email"].content) ``` +## Per-generation error collection + +The default `error_mode="raise"` preserves fail-fast behavior. For multiple +candidates, use `error_mode="collect"` to parse each raw response independently: + +```python +generations = model(query, n=2, error_mode="collect") + +for generation in generations: + if generation.ok: + print(generation.result) + else: + print(generation.error_type, generation.error_message) + print(generation.raw_response) +``` + +`collect` only captures parsing and infill errors after raw text has been +generated. Network, authentication, timeout, model request, and invalid response +container errors still fail the whole call. Async models use the same parameter +through `await model(...)`. + ## Advanced options - `visible_tag_fields`: control which `MaskedTag` fields are visible to the model (e.g. `["id", "name", "desc", "content", "regex"]`). Defaults to `None` (basic fields only: `["id", "desc", "content"]`). - `backend`: pass through to Outlines generator backend selection. +- `error_mode`: `"raise"` (default) or `"collect"`. - `**inference_kwargs`: forwarded to the underlying OpenAI call. diff --git a/docs/models/openai.zh.md b/docs/models/openai.zh.md index adb34c4..f249b81 100644 --- a/docs/models/openai.zh.md +++ b/docs/models/openai.zh.md @@ -53,8 +53,29 @@ result = model(query, output_type=None, use_gim_prompt=True) print(result.tags["email"].content) ``` +## 逐生成错误收集 + +默认的 `error_mode="raise"` 保持快速失败行为。多候选生成时,可使用 +`error_mode="collect"` 让每条原始响应独立解析: + +```python +generations = model(query, n=2, error_mode="collect") + +for generation in generations: + if generation.ok: + print(generation.result) + else: + print(generation.error_type, generation.error_message) + print(generation.raw_response) +``` + +`collect` 只收集已经获得原始文本后的解析和 infill 错误。网络、认证、超时、 +模型请求失败以及无效的响应容器仍会作为整个调用异常抛出。异步模型使用相同参数, +调用方式为 `await model(...)`。 + ## 高级参数 - `visible_tag_fields`:控制哪些 `MaskedTag` 字段对模型可见(如 `["id", "name", "desc", "content", "regex"]`)。默认为 `None`(仅基础字段:`["id", "desc", "content"]`)。 - `backend`:透传给 Outlines 生成器后端选择。 +- `error_mode`:`"raise"`(默认)或 `"collect"`。 - `**inference_kwargs`:透传到底层 OpenAI 推理参数。 diff --git a/docs/models/vllm-offline.md b/docs/models/vllm-offline.md index dbc1118..57f1d0c 100644 --- a/docs/models/vllm-offline.md +++ b/docs/models/vllm-offline.md @@ -47,6 +47,27 @@ batch_results = model.batch([query, query]) first_result = batch_results[0][0] ``` +With `error_mode="collect"`, batch always returns a two-dimensional +`list[list[GenerationResult]]`: the outer list maps to queries and the inner list +maps to candidates. + +```python +generation_groups = model.batch(queries, error_mode="collect") + +for generation_group in generation_groups: + for generation in generation_group: + if generation.ok: + print(generation.result) + else: + print(generation.error_type, generation.error_message) + print(generation.raw_response) +``` + +A parsing failure for one candidate does not affect other candidates or queries. +The default `error_mode="raise"` preserves existing return types and fail-fast +behavior. Generation failures, invalid batch shapes, and invalid arguments still +fail the whole call. + ## Output types ### `output_type="cfg"` (default) diff --git a/docs/models/vllm-offline.zh.md b/docs/models/vllm-offline.zh.md index d29987b..d95b96f 100644 --- a/docs/models/vllm-offline.zh.md +++ b/docs/models/vllm-offline.zh.md @@ -47,6 +47,25 @@ batch_results = model.batch([query, query]) first_result = batch_results[0][0] ``` +使用 `error_mode="collect"` 时,batch 始终返回二维 +`list[list[GenerationResult]]`:外层对应 query,内层对应候选。 + +```python +generation_groups = model.batch(queries, error_mode="collect") + +for generation_group in generation_groups: + for generation in generation_group: + if generation.ok: + print(generation.result) + else: + print(generation.error_type, generation.error_message) + print(generation.raw_response) +``` + +单个候选的解析失败不会影响同一 query 的其他候选或其他 query。默认 +`error_mode="raise"` 的返回类型和快速失败行为保持不变。模型生成失败、batch +形状错误和无效参数仍会作为整个调用异常抛出。 + ## 输出类型 ### `output_type="cfg"`(默认) diff --git a/docs/models/vllm.md b/docs/models/vllm.md index 503125a..b75bc67 100644 --- a/docs/models/vllm.md +++ b/docs/models/vllm.md @@ -50,7 +50,18 @@ result = model(query, output_type="cfg") result = model(query, output_type="json", use_gim_prompt=True) ``` +## Per-generation error collection + +For multiple candidates, pass `error_mode="collect"` to receive one +`GenerationResult` per raw response. Successful items expose `.result`; failed +items retain `.raw_response`, `.error_type`, and `.error_message`. The default +`error_mode="raise"` preserves existing behavior and return types. Model request, +network, and response-container failures still fail the whole call. + +Async clients use the same parameter through `await model(...)`. + ## Notes - GIMKit automatically adds `stop="<|/GIM_RESPONSE|>"` for safer termination. +- `error_mode` accepts `"raise"` (default) or `"collect"`. - You can still pass extra generation args via `**inference_kwargs`. diff --git a/docs/models/vllm.zh.md b/docs/models/vllm.zh.md index 6a1cce7..0dc0512 100644 --- a/docs/models/vllm.zh.md +++ b/docs/models/vllm.zh.md @@ -50,7 +50,17 @@ result = model(query, output_type="cfg") result = model(query, output_type="json", use_gim_prompt=True) ``` +## 逐生成错误收集 + +多候选生成时,可传入 `error_mode="collect"`,逐条获得 +`GenerationResult`。成功项通过 `.result` 访问,失败项保留 `.raw_response`、 +`.error_type` 和 `.error_message`。默认 `error_mode="raise"` 的行为和返回类型 +保持不变。模型请求、网络和响应容器错误仍会作为整个调用异常抛出。 + +异步客户端使用相同参数,调用方式为 `await model(...)`。 + ## 说明 - GIMKit 会自动添加 `stop="<|/GIM_RESPONSE|>"`,确保更稳定停止。 +- `error_mode` 可设为 `"raise"`(默认)或 `"collect"`。 - 可通过 `**inference_kwargs` 继续传递生成参数。 diff --git a/src/gimkit/__init__.py b/src/gimkit/__init__.py index 37e53fe..facaca7 100644 --- a/src/gimkit/__init__.py +++ b/src/gimkit/__init__.py @@ -1,7 +1,7 @@ from importlib.metadata import PackageNotFoundError, version from gimkit.guides import guide -from gimkit.models import from_openai, from_vllm, from_vllm_offline +from gimkit.models import GenerationResult, from_openai, from_vllm, from_vllm_offline try: @@ -11,6 +11,7 @@ __all__ = [ + "GenerationResult", "from_openai", "from_vllm", "from_vllm_offline", diff --git a/src/gimkit/models/__init__.py b/src/gimkit/models/__init__.py index 26d35ce..2376d58 100644 --- a/src/gimkit/models/__init__.py +++ b/src/gimkit/models/__init__.py @@ -1,6 +1,7 @@ from .openai import from_openai +from .types import GenerationResult from .vllm import from_vllm from .vllm_offline import from_vllm_offline -__all__ = ["from_openai", "from_vllm", "from_vllm_offline"] +__all__ = ["GenerationResult", "from_openai", "from_vllm", "from_vllm_offline"] diff --git a/src/gimkit/models/base.py b/src/gimkit/models/base.py index 495df5c..d4d8b10 100644 --- a/src/gimkit/models/base.py +++ b/src/gimkit/models/base.py @@ -5,7 +5,13 @@ from gimkit.contexts import Query, Result from gimkit.log import get_logger -from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses +from gimkit.models.types import ErrorMode, GenerationResult +from gimkit.models.utils import ( + get_outlines_model_input, + get_outlines_output_type, + parse_generation_responses, + validate_error_mode, +) from gimkit.schemas import ContextInput, TagField @@ -19,8 +25,11 @@ def _call( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", **inference_kwargs: Any, -) -> Result | list[Result]: +) -> Result | list[Result] | GenerationResult | list[GenerationResult]: + validate_error_mode(error_mode) outlines_model_input = get_outlines_model_input( model_input, output_type, use_gim_prompt, visible_tag_fields ) @@ -29,8 +38,14 @@ def _call( generator = Generator(self, outlines_output_type, backend) raw_responses = generator(outlines_model_input, **inference_kwargs) logger.debug(f"Raw responses of {self}: {raw_responses}") - return infill_responses( - model_input, cast("str | list[str]", raw_responses), json_responses=(output_type == "json") + return cast( + "Result | list[Result] | GenerationResult | list[GenerationResult]", + cast("Any", parse_generation_responses)( + model_input, + cast("str | list[str]", raw_responses), + json_responses=(output_type == "json"), + error_mode=error_mode, + ), ) @@ -41,8 +56,11 @@ async def _acall( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", **inference_kwargs: Any, -) -> Result | list[Result]: +) -> Result | list[Result] | GenerationResult | list[GenerationResult]: + validate_error_mode(error_mode) outlines_model_input = get_outlines_model_input( model_input, output_type, use_gim_prompt, visible_tag_fields ) @@ -51,6 +69,12 @@ async def _acall( generator = Generator(self, outlines_output_type, backend) raw_responses = await generator(outlines_model_input, **inference_kwargs) logger.debug(f"Raw responses of {self}: {raw_responses}") - return infill_responses( - model_input, cast("str | list[str]", raw_responses), json_responses=(output_type == "json") + return cast( + "Result | list[Result] | GenerationResult | list[GenerationResult]", + cast("Any", parse_generation_responses)( + model_input, + cast("str | list[str]", raw_responses), + json_responses=(output_type == "json"), + error_mode=error_mode, + ), ) diff --git a/src/gimkit/models/openai.py b/src/gimkit/models/openai.py index d2ca51a..1a2330f 100644 --- a/src/gimkit/models/openai.py +++ b/src/gimkit/models/openai.py @@ -11,10 +11,12 @@ from gimkit.contexts import Query, Result from gimkit.models.base import _acall, _call +from gimkit.models.types import ErrorMode, GenerationResult from gimkit.schemas import ContextInput, TagField class OpenAI(OutlinesOpenAI): + @overload def __call__( self, model_input: ContextInput | Query, @@ -22,8 +24,35 @@ def __call__( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", **inference_kwargs: Any, - ) -> Result | list[Result]: + ) -> Result | list[Result]: ... + + @overload + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["json"] | None = None, + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> GenerationResult | list[GenerationResult]: ... + + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["json"] | None = None, + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", + **inference_kwargs: Any, + ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: return _call( self, model_input, @@ -31,11 +60,38 @@ def __call__( backend, use_gim_prompt, visible_tag_fields, + error_mode=error_mode, **inference_kwargs, ) class AsyncOpenAI(OutlinesAsyncOpenAI): + @overload + async def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["json"] | None = None, + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", + **inference_kwargs: Any, + ) -> Result | list[Result]: ... + + @overload + async def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["json"] | None = None, + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> GenerationResult | list[GenerationResult]: ... + async def __call__( self, model_input: ContextInput | Query, @@ -43,8 +99,10 @@ async def __call__( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", **inference_kwargs: Any, - ) -> Result | list[Result]: + ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: return await _acall( self, model_input, @@ -53,6 +111,7 @@ async def __call__( use_gim_prompt, visible_tag_fields, **inference_kwargs, + error_mode=error_mode, ) diff --git a/src/gimkit/models/types.py b/src/gimkit/models/types.py new file mode 100644 index 0000000..944f277 --- /dev/null +++ b/src/gimkit/models/types.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Literal, TypeAlias + +from gimkit.contexts import Result + + +ErrorMode: TypeAlias = Literal["raise", "collect"] + + +@dataclass(frozen=True, slots=True) +class GenerationResult: + """The parsed result or parsing error for one raw model generation.""" + + raw_response: str + result: Result | None = None + error_type: str | None = None + error_message: str | None = None + + @property + def ok(self) -> bool: + """Whether the raw generation was parsed and infilled successfully.""" + return self.result is not None diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 8a2b165..cb2f55c 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -1,11 +1,12 @@ from collections.abc import Sequence -from typing import Literal, cast, overload +from typing import Any, Literal, cast, overload from outlines.inputs import Chat from outlines.types.dsl import CFG, JsonSchema from gimkit.contexts import Query, Response, Result, infill from gimkit.dsls import build_cfg, build_json_schema +from gimkit.models.types import ErrorMode, GenerationResult from gimkit.prompts import ( DEMO_CONVERSATION_MSGS, DEMO_CONVERSATION_MSGS_JSON, @@ -149,6 +150,189 @@ def json_responses_to_gim_response(json_response: str) -> str: ) +def validate_error_mode(error_mode: ErrorMode) -> None: + if error_mode not in ("raise", "collect"): + raise ValueError(f"Invalid error mode: {error_mode}. Expected 'raise' or 'collect'.") + + +@overload +def parse_generation_response( + query: ContextInput | Query, + raw_response: str, + *, + json_response: bool = False, + error_mode: Literal["raise"] = "raise", +) -> Result: ... + + +@overload +def parse_generation_response( + query: ContextInput | Query, + raw_response: str, + *, + json_response: bool = False, + error_mode: Literal["collect"], +) -> GenerationResult: ... + + +def parse_generation_response( + query: ContextInput | Query, + raw_response: str, + *, + json_response: bool = False, + error_mode: ErrorMode = "raise", +) -> Result | GenerationResult: + """Parse and infill one raw model generation. + + ``collect`` only isolates errors raised while parsing and infilling an + already generated string. Model invocation and response-container errors + remain whole-call failures. + """ + validate_error_mode(error_mode) + if not isinstance(raw_response, str): + raise TypeError(f"Expected raw response to be str, got {type(raw_response)}") + + try: + result = infill_responses(query, raw_response, json_responses=json_response) + except Exception as exc: + if error_mode == "raise": + raise + return GenerationResult( + raw_response=raw_response, + error_type=type(exc).__name__, + error_message=str(exc), + ) + + if error_mode == "raise": + return result + return GenerationResult(raw_response=raw_response, result=result) + + +@overload +def parse_generation_responses( + query: ContextInput | Query, + raw_responses: str | list[str], + *, + json_responses: bool = False, + error_mode: Literal["raise"] = "raise", +) -> Result | list[Result]: ... + + +@overload +def parse_generation_responses( + query: ContextInput | Query, + raw_responses: str | list[str], + *, + json_responses: bool = False, + error_mode: Literal["collect"], +) -> GenerationResult | list[GenerationResult]: ... + + +def parse_generation_responses( + query: ContextInput | Query, + raw_responses: str | list[str], + *, + json_responses: bool = False, + error_mode: ErrorMode = "raise", +) -> Result | list[Result] | GenerationResult | list[GenerationResult]: + """Parse one or more raw generations while preserving their container shape.""" + validate_error_mode(error_mode) + if isinstance(raw_responses, str): + return parse_generation_response( + query, + raw_responses, + json_response=json_responses, + error_mode=error_mode, + ) + if not isinstance(raw_responses, list): + raise TypeError(f"Expected responses to be str or list of str, got {type(raw_responses)}") + if len(raw_responses) == 0: + raise ValueError("Response list is empty.") + if not all(isinstance(response, str) for response in raw_responses): + raise TypeError(f"All items in the response list must be strings, got: {raw_responses}") + + parsed = [ + cast("Any", parse_generation_response)( + query, + raw_response, + json_response=json_responses, + error_mode=error_mode, + ) + for raw_response in raw_responses + ] + return cast("list[Result] | list[GenerationResult]", parsed) + + +@overload +def parse_batch_generation_responses( + queries: Sequence[ContextInput | Query], + raw_responses: list[list[str]], + *, + json_responses: bool = False, + error_mode: Literal["raise"] = "raise", +) -> list[list[Result]]: ... + + +@overload +def parse_batch_generation_responses( + queries: Sequence[ContextInput | Query], + raw_responses: list[list[str]], + *, + json_responses: bool = False, + error_mode: Literal["collect"], +) -> list[list[GenerationResult]]: ... + + +def parse_batch_generation_responses( + queries: Sequence[ContextInput | Query], + raw_responses: list[list[str]], + *, + json_responses: bool = False, + error_mode: ErrorMode = "raise", +) -> list[list[Result]] | list[list[GenerationResult]]: + """Parse batch generations, preserving query and candidate dimensions.""" + validate_error_mode(error_mode) + if len(queries) == 0: + raise ValueError("Batch input list is empty.") + if not isinstance(raw_responses, list): + raise TypeError(f"Expected batch responses to be a list, got {type(raw_responses)}") + if len(queries) != len(raw_responses): + raise ValueError( + "Mismatched number of batch inputs and responses: " + f"{len(queries)} input(s), {len(raw_responses)} response group(s)." + ) + if not all(isinstance(response_group, list) for response_group in raw_responses): + invalid_group = next( + response_group + for response_group in raw_responses + if not isinstance(response_group, list) + ) + raise TypeError( + f"Each batch response group must be a list of strings, got {type(invalid_group)}" + ) + for response_group in raw_responses: + if len(response_group) == 0: + raise ValueError("Response list is empty.") + if not all(isinstance(response, str) for response in response_group): + raise TypeError( + f"All items in the response list must be strings, got: {response_group}" + ) + + parsed = [ + cast( + "list[Result] | list[GenerationResult]", + cast("Any", parse_generation_responses)( + query, + response_group, + json_responses=json_responses, + error_mode=error_mode, + ), + ) + for query, response_group in zip(queries, raw_responses, strict=True) + ] + return cast("list[list[Result]] | list[list[GenerationResult]]", parsed) + + @overload def infill_responses( query: ContextInput | Query, responses: str, json_responses: bool = False diff --git a/src/gimkit/models/vllm.py b/src/gimkit/models/vllm.py index a77af2b..9398e5d 100644 --- a/src/gimkit/models/vllm.py +++ b/src/gimkit/models/vllm.py @@ -10,10 +10,12 @@ from gimkit.contexts import Query, Result from gimkit.models.base import _acall, _call +from gimkit.models.types import ErrorMode, GenerationResult from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField class VLLM(OutlinesVLLM): + @overload def __call__( self, model_input: ContextInput | Query, @@ -21,8 +23,35 @@ def __call__( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", **inference_kwargs: Any, - ) -> Result | list[Result]: + ) -> Result | list[Result]: ... + + @overload + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> GenerationResult | list[GenerationResult]: ... + + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", + **inference_kwargs: Any, + ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: # Using `stop=RESPONSE_SUFFIX` is preferred for two reasons: # 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX. # 2. Even with CFG, inference engines like vLLM do not guarantee termination when the CFG is satisfied (See https://github.com/vllm-project/vllm/issues/29632). @@ -33,12 +62,39 @@ def __call__( backend, use_gim_prompt, visible_tag_fields, + error_mode=error_mode, stop=RESPONSE_SUFFIX, **inference_kwargs, ) class AsyncVLLM(OutlinesAsyncVLLM): + @overload + async def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", + **inference_kwargs: Any, + ) -> Result | list[Result]: ... + + @overload + async def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> GenerationResult | list[GenerationResult]: ... + async def __call__( self, model_input: ContextInput | Query, @@ -46,8 +102,10 @@ async def __call__( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", **inference_kwargs: Any, - ) -> Result | list[Result]: + ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: return await _acall( self, model_input, @@ -56,6 +114,7 @@ async def __call__( use_gim_prompt, visible_tag_fields, stop=RESPONSE_SUFFIX, + error_mode=error_mode, **inference_kwargs, ) diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 8e6d8c6..8387b1e 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -2,7 +2,7 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload from outlines.generator import Generator from outlines.inputs import Chat @@ -11,12 +11,14 @@ from gimkit.contexts import Query, Result from gimkit.log import get_logger +from gimkit.models.types import ErrorMode, GenerationResult from gimkit.models.utils import ( get_outlines_model_input, get_outlines_model_inputs, get_outlines_output_type, - infill_batch_responses, - infill_responses, + parse_batch_generation_responses, + parse_generation_responses, + validate_error_mode, ) from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField @@ -34,6 +36,7 @@ class VLLMOffline(OutlinesVLLMOffline): + @overload def __call__( self, model_input: ContextInput | Query, @@ -41,8 +44,36 @@ def __call__( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", **inference_kwargs: Any, - ) -> Result | list[Result]: + ) -> Result | list[Result]: ... + + @overload + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> GenerationResult | list[GenerationResult]: ... + + def __call__( + self, + model_input: ContextInput | Query, + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", + **inference_kwargs: Any, + ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: + validate_error_mode(error_mode) inference_kwargs = self._ensure_response_suffix(inference_kwargs) outlines_model_input = get_outlines_model_input( @@ -55,12 +86,17 @@ def __call__( generator = Generator(self, outlines_output_type, backend) raw_responses = generator(outlines_model_input, **inference_kwargs) logger.debug(f"Raw responses of {self}: {raw_responses}") - return infill_responses( - model_input, - cast("str | list[str]", raw_responses), - json_responses=(output_type == "json"), + return cast( + "Result | list[Result] | GenerationResult | list[GenerationResult]", + cast("Any", parse_generation_responses)( + model_input, + cast("str | list[str]", raw_responses), + json_responses=(output_type == "json"), + error_mode=error_mode, + ), ) + @overload def batch( self, model_input: Sequence[ContextInput | Query], @@ -68,8 +104,36 @@ def batch( backend: str | None = None, use_gim_prompt: bool = False, visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["raise"] = "raise", **inference_kwargs: Any, - ) -> list[list[Result]]: # type: ignore[override] + ) -> list[list[Result]]: ... + + @overload + def batch( + self, + model_input: Sequence[ContextInput | Query], + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: Literal["collect"], + **inference_kwargs: Any, + ) -> list[list[GenerationResult]]: ... + + def batch( + self, + model_input: Sequence[ContextInput | Query], + output_type: Literal["cfg", "json"] | None = "cfg", + backend: str | None = None, + use_gim_prompt: bool = False, + visible_tag_fields: list[TagField] | None = None, + *, + error_mode: ErrorMode = "raise", + **inference_kwargs: Any, + ) -> list[list[Result]] | list[list[GenerationResult]]: # type: ignore[override] + validate_error_mode(error_mode) inference_kwargs = self._ensure_response_suffix(inference_kwargs) outlines_model_inputs = get_outlines_model_inputs( @@ -87,13 +151,11 @@ def batch( inference_kwargs, ) logger.debug(f"Raw batch responses of {self}: {raw_responses}") - return cast( - "list[list[Result]]", - infill_batch_responses( - model_input, - raw_responses, - json_responses=(output_type == "json"), - ), + return parse_batch_generation_responses( + model_input, + raw_responses, + json_responses=(output_type == "json"), + error_mode=error_mode, ) def _generate_batch_with_output_types( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 8bd73a7..ab7f6f1 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -11,6 +11,7 @@ from gimkit.models.openai import AsyncOpenAI as GIMAsyncOpenAI from gimkit.models.openai import OpenAI as GIMOpenAI from gimkit.models.openai import from_openai +from gimkit.models.types import GenerationResult from gimkit.schemas import MaskedTag @@ -87,3 +88,92 @@ async def test_async_call(): assert isinstance(result, Result) assert result.tags[0] == MaskedTag(id=0, content="world") mock_create.assert_awaited_once() + + +def test_sync_call_collects_candidate_errors(): + client = OpenAI(api_key="test", timeout=0, max_retries=0) + valid = '<|MASKED id="m_0"|>world<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + mock_response = MagicMock() + mock_response.choices = [MagicMock(), MagicMock()] + for choice, content in zip(mock_response.choices, [valid, invalid], strict=True): + choice.message.content = content + choice.message.refusal = None + + with patch.object(client.chat.completions, "create", return_value=mock_response): + model = from_openai(client, model_name="gpt-4o") + generations = model( + "Hello, " + guide(), + output_type=None, + error_mode="collect", + n=2, + ) + + assert isinstance(generations, list) + assert all(isinstance(item, GenerationResult) for item in generations) + assert [item.ok for item in generations] == [True, False] + assert str(generations[0].result) == "Hello, world" + assert generations[1].raw_response == invalid + + +@pytest.mark.asyncio +async def test_async_call_collects_candidate_errors(): + client = AsyncOpenAI(api_key="test", timeout=0, max_retries=0) + valid = '<|MASKED id="m_0"|>world<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + mock_response = MagicMock() + mock_response.choices = [MagicMock(), MagicMock()] + for choice, content in zip(mock_response.choices, [valid, invalid], strict=True): + choice.message.content = content + choice.message.refusal = None + + with patch.object( + client.chat.completions, + "create", + new_callable=AsyncMock, + return_value=mock_response, + ): + model = from_openai(client, model_name="gpt-4o") + generations = await model( + "Hello, " + guide(), + output_type=None, + error_mode="collect", + n=2, + ) + + assert isinstance(generations, list) + assert all(isinstance(item, GenerationResult) for item in generations) + assert [item.ok for item in generations] == [True, False] + assert str(generations[0].result) == "Hello, world" + assert generations[1].raw_response == invalid + + +@pytest.mark.asyncio +async def test_async_request_error_is_not_collected(): + client = AsyncOpenAI(api_key="test", timeout=0, max_retries=0) + request_error = RuntimeError("request failed") + + with patch.object( + client.chat.completions, + "create", + new_callable=AsyncMock, + side_effect=request_error, + ): + model = from_openai(client, model_name="gpt-4o") + with pytest.raises(RuntimeError, match="request failed"): + await model("Hello, " + guide(), output_type=None, error_mode="collect") + + +def test_invalid_error_mode_fails_before_model_request(): + client = OpenAI(api_key="test", timeout=0, max_retries=0) + + with patch.object(client.chat.completions, "create") as mock_create: + model = from_openai(client, model_name="gpt-4o") + with pytest.raises(ValueError, match="Invalid error mode"): + model( + "Hello, " + guide(), + output_type=None, + error_mode="invalid", + ) + + mock_create.assert_not_called() diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index a296b2f..08c1657 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -4,6 +4,7 @@ from outlines.types.dsl import CFG, JsonSchema from gimkit.contexts import Query, Result +from gimkit.models.types import GenerationResult from gimkit.models.utils import ( get_outlines_model_input, get_outlines_model_inputs, @@ -11,6 +12,9 @@ infill_batch_responses, infill_responses, json_responses_to_gim_response, + parse_batch_generation_responses, + parse_generation_response, + parse_generation_responses, ) from gimkit.prompts import SYSTEM_PROMPT_MSG, SYSTEM_PROMPT_MSG_JSON from gimkit.schemas import MaskedTag @@ -207,3 +211,105 @@ def test_infill_batch_responses(): with pytest.raises(TypeError, match="Each batch response must be a string or a list"): infill_batch_responses(queries, [responses[0], object()]) + + +def test_parse_generation_response_error_modes(): + query = Query("Hello, ", MaskedTag(id=0)) + valid = '<|MASKED id="m_0"|>world<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + + result = parse_generation_response(query, valid) + assert isinstance(result, Result) + assert str(result) == "Hello, world" + + with pytest.raises(Exception, match="Mismatched or nested masked tags"): + parse_generation_response(query, invalid) + + collected = parse_generation_response(query, valid, error_mode="collect") + assert isinstance(collected, GenerationResult) + assert collected.ok + assert collected.raw_response == valid + assert str(collected.result) == "Hello, world" + assert collected.error_type is None + assert collected.error_message is None + + failed = parse_generation_response(query, invalid, error_mode="collect") + assert isinstance(failed, GenerationResult) + assert not failed.ok + assert failed.raw_response == invalid + assert failed.result is None + assert failed.error_type == "InvalidFormatError" + assert "Mismatched or nested masked tags" in failed.error_message + + +def test_parse_generation_responses_isolates_candidates(): + query = Query("Hello, ", MaskedTag(id=0)) + valid = '<|MASKED id="m_0"|>world<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + + collected = parse_generation_responses( + query, + [valid, invalid], + error_mode="collect", + ) + + assert isinstance(collected, list) + assert [item.ok for item in collected] == [True, False] + assert str(collected[0].result) == "Hello, world" + assert collected[1].raw_response == invalid + + with pytest.raises(Exception, match="Mismatched or nested masked tags"): + parse_generation_responses(query, [valid, invalid]) + + +def test_parse_batch_generation_responses_isolates_queries_and_candidates(): + queries = [ + Query("Hello, ", MaskedTag(id=0)), + Query("Goodbye, ", MaskedTag(id=0)), + ] + valid_world = '<|MASKED id="m_0"|>world<|/MASKED|>' + valid_friend = '<|MASKED id="m_0"|>friend<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + + collected = parse_batch_generation_responses( + queries, + [[valid_world, invalid], [valid_friend]], + error_mode="collect", + ) + + assert [[item.ok for item in group] for group in collected] == [[True, False], [True]] + assert str(collected[0][0].result) == "Hello, world" + assert str(collected[1][0].result) == "Goodbye, friend" + + +def test_parse_generation_response_json_preserves_original_text(): + query = Query("Hello, ", MaskedTag(id=0)) + raw_response = '{\n "invalid": "world"\n}' + + failed = parse_generation_response( + query, + raw_response, + json_response=True, + error_mode="collect", + ) + + assert not failed.ok + assert failed.raw_response == raw_response + assert failed.error_type == "ValueError" + assert "Invalid field name" in failed.error_message + + +def test_parse_generation_response_rejects_invalid_configuration_and_container(): + query = Query(MaskedTag(id=0)) + + with pytest.raises(ValueError, match="Invalid error mode"): + parse_generation_response(query, "response", error_mode="invalid") + + with pytest.raises(TypeError, match="All items in the response list must be strings"): + parse_generation_responses(query, ["response", object()], error_mode="collect") + + with pytest.raises(ValueError, match="Response list is empty"): + parse_batch_generation_responses([query], [[]], error_mode="collect") + + with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): + parse_batch_generation_responses([query], [], error_mode="collect") diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index 383a46b..7e8ce15 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -8,6 +8,7 @@ from outlines.models.vllm_offline import VLLMOffline as OutlinesVLLMOffline from gimkit.contexts import Result +from gimkit.models.types import GenerationResult from gimkit.models.vllm_offline import VLLMOffline as GIMVLLMOffline from gimkit.models.vllm_offline import from_vllm_offline from gimkit.schemas import RESPONSE_SUFFIX, MaskedTag @@ -207,3 +208,58 @@ def test_vllm_offline_call_invalid_response(): mock_generator.return_value = generator_instance with pytest.raises(ValueError, match="Response list is empty"): model(MaskedTag()) + + +def test_vllm_offline_call_collects_candidate_errors(): + model = from_vllm_offline(_mock_vllm_client()) + valid = '<|MASKED id="m_0"|>hi<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + + with patch("gimkit.models.vllm_offline.Generator") as mock_generator: + generator_instance = MagicMock(return_value=[valid, invalid]) + mock_generator.return_value = generator_instance + + returned = model(MaskedTag(), error_mode="collect") + + assert isinstance(returned, list) + assert all(isinstance(item, GenerationResult) for item in returned) + assert [item.ok for item in returned] == [True, False] + assert str(returned[0].result) == "hi" + assert returned[1].raw_response == invalid + + +def test_vllm_offline_batch_collects_query_and_candidate_errors(): + mock_client = _mock_vllm_client() + valid_world = '<|MASKED id="m_0"|>world<|/MASKED|>' + valid_friend = '<|MASKED id="m_0"|>friend<|/MASKED|>' + invalid = '<|MASKED id="m_0"|><|MASKED id="m_1"|>nested<|/MASKED|><|/MASKED|>' + mock_client.generate.return_value = [ + _request_output(valid_world, invalid), + _request_output(valid_friend), + ] + model = from_vllm_offline(mock_client) + + returned = model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ], + error_mode="collect", + ) + + assert [[item.ok for item in group] for group in returned] == [[True, False], [True]] + assert str(returned[0][0].result) == "Hello, world" + assert returned[0][1].raw_response == invalid + assert str(returned[1][0].result) == "Goodbye, friend" + + mock_client.generate.return_value = [ + _request_output(valid_world, invalid), + _request_output(valid_friend), + ] + with pytest.raises(Exception, match="Mismatched or nested masked tags"): + model.batch( + [ + ["Hello, ", MaskedTag()], + ["Goodbye, ", MaskedTag()], + ] + ) From 29676121ae572bd2d507962206cc0815f1e16c60 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Fri, 19 Jun 2026 18:06:11 +0800 Subject: [PATCH 2/5] cover generation error validation --- tests/models/test_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 08c1657..2801817 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -313,3 +313,19 @@ def test_parse_generation_response_rejects_invalid_configuration_and_container() with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): parse_batch_generation_responses([query], [], error_mode="collect") + + with pytest.raises(TypeError, match="Expected raw response to be str"): + parse_generation_response(query, object(), error_mode="collect") + + with pytest.raises(ValueError, match="Batch input list is empty"): + parse_batch_generation_responses([], [], error_mode="collect") + + with pytest.raises(TypeError, match="Expected batch responses to be a list"): + parse_batch_generation_responses([query], "response", error_mode="collect") + + with pytest.raises(TypeError, match="Each batch response group must be a list of strings"): + parse_batch_generation_responses( + [query], + ["response"], + error_mode="collect", + ) From c8c6dd7cc32f081ff6ffaf20930a52feef923320 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 22 Jun 2026 16:27:33 +0800 Subject: [PATCH 3/5] feat: add Query and Response to public API --- src/gimkit/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gimkit/__init__.py b/src/gimkit/__init__.py index facaca7..ddda735 100644 --- a/src/gimkit/__init__.py +++ b/src/gimkit/__init__.py @@ -1,5 +1,6 @@ from importlib.metadata import PackageNotFoundError, version +from gimkit.contexts import Query, Response from gimkit.guides import guide from gimkit.models import GenerationResult, from_openai, from_vllm, from_vllm_offline @@ -12,6 +13,8 @@ __all__ = [ "GenerationResult", + "Query", + "Response", "from_openai", "from_vllm", "from_vllm_offline", From 1b1cfc00ae44a17c254ed8d669235f9d66c13d5e Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:29:56 +0800 Subject: [PATCH 4/5] refactor: remove validate_error_mode function and related checks --- src/gimkit/models/base.py | 3 -- src/gimkit/models/utils.py | 72 ++----------------------------- src/gimkit/models/vllm_offline.py | 28 ++++++------ tests/models/test_openai.py | 15 ------- tests/models/test_utils.py | 46 -------------------- tests/models/test_vllm_offline.py | 68 +++++++---------------------- 6 files changed, 33 insertions(+), 199 deletions(-) diff --git a/src/gimkit/models/base.py b/src/gimkit/models/base.py index d4d8b10..19d4b5c 100644 --- a/src/gimkit/models/base.py +++ b/src/gimkit/models/base.py @@ -10,7 +10,6 @@ get_outlines_model_input, get_outlines_output_type, parse_generation_responses, - validate_error_mode, ) from gimkit.schemas import ContextInput, TagField @@ -29,7 +28,6 @@ def _call( error_mode: ErrorMode = "raise", **inference_kwargs: Any, ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: - validate_error_mode(error_mode) outlines_model_input = get_outlines_model_input( model_input, output_type, use_gim_prompt, visible_tag_fields ) @@ -60,7 +58,6 @@ async def _acall( error_mode: ErrorMode = "raise", **inference_kwargs: Any, ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: - validate_error_mode(error_mode) outlines_model_input = get_outlines_model_input( model_input, output_type, use_gim_prompt, visible_tag_fields ) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index cb2f55c..9b10fc2 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -150,11 +150,6 @@ def json_responses_to_gim_response(json_response: str) -> str: ) -def validate_error_mode(error_mode: ErrorMode) -> None: - if error_mode not in ("raise", "collect"): - raise ValueError(f"Invalid error mode: {error_mode}. Expected 'raise' or 'collect'.") - - @overload def parse_generation_response( query: ContextInput | Query, @@ -188,9 +183,6 @@ def parse_generation_response( already generated string. Model invocation and response-container errors remain whole-call failures. """ - validate_error_mode(error_mode) - if not isinstance(raw_response, str): - raise TypeError(f"Expected raw response to be str, got {type(raw_response)}") try: result = infill_responses(query, raw_response, json_responses=json_response) @@ -236,7 +228,6 @@ def parse_generation_responses( error_mode: ErrorMode = "raise", ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: """Parse one or more raw generations while preserving their container shape.""" - validate_error_mode(error_mode) if isinstance(raw_responses, str): return parse_generation_response( query, @@ -244,12 +235,6 @@ def parse_generation_responses( json_response=json_responses, error_mode=error_mode, ) - if not isinstance(raw_responses, list): - raise TypeError(f"Expected responses to be str or list of str, got {type(raw_responses)}") - if len(raw_responses) == 0: - raise ValueError("Response list is empty.") - if not all(isinstance(response, str) for response in raw_responses): - raise TypeError(f"All items in the response list must be strings, got: {raw_responses}") parsed = [ cast("Any", parse_generation_response)( @@ -291,32 +276,11 @@ def parse_batch_generation_responses( error_mode: ErrorMode = "raise", ) -> list[list[Result]] | list[list[GenerationResult]]: """Parse batch generations, preserving query and candidate dimensions.""" - validate_error_mode(error_mode) - if len(queries) == 0: - raise ValueError("Batch input list is empty.") - if not isinstance(raw_responses, list): - raise TypeError(f"Expected batch responses to be a list, got {type(raw_responses)}") if len(queries) != len(raw_responses): raise ValueError( "Mismatched number of batch inputs and responses: " f"{len(queries)} input(s), {len(raw_responses)} response group(s)." ) - if not all(isinstance(response_group, list) for response_group in raw_responses): - invalid_group = next( - response_group - for response_group in raw_responses - if not isinstance(response_group, list) - ) - raise TypeError( - f"Each batch response group must be a list of strings, got {type(invalid_group)}" - ) - for response_group in raw_responses: - if len(response_group) == 0: - raise ValueError("Response list is empty.") - if not all(isinstance(response, str) for response in response_group): - raise TypeError( - f"All items in the response list must be strings, got: {response_group}" - ) parsed = [ cast( @@ -355,16 +319,6 @@ def infill_responses( responses = json_responses_to_gim_response(responses) return infill(query, responses) - # Handle list of responses - if not isinstance(responses, list): - raise TypeError(f"Expected responses to be str or list of str, got {type(responses)}") - - if len(responses) == 0: - raise ValueError("Response list is empty.") - - if not all(isinstance(resp, str) for resp in responses): - raise TypeError(f"All items in the response list must be strings, got: {responses}") - return [infill_responses(query, resp, json_responses=json_responses) for resp in responses] @@ -388,31 +342,13 @@ def infill_batch_responses( json_responses: bool = False, ) -> list[Result] | list[list[Result]]: """Infill each query in a batch with its corresponding response(s).""" - if len(queries) == 0: - raise ValueError("Batch input list is empty.") - if not isinstance(responses, list): - raise TypeError(f"Expected batch responses to be a list, got {type(responses)}") if len(queries) != len(responses): raise ValueError( "Mismatched number of batch inputs and responses: " f"{len(queries)} input(s), {len(responses)} response(s)." ) - if all(isinstance(response, str) for response in responses): - return [ - infill_responses(query, cast("str", response), json_responses=json_responses) - for query, response in zip(queries, responses, strict=True) - ] - - if all(isinstance(response, list) for response in responses): - return [ - infill_responses(query, cast("list[str]", response), json_responses=json_responses) - for query, response in zip(queries, responses, strict=True) - ] - - invalid_response = next( - response for response in responses if not isinstance(response, (str, list)) - ) - raise TypeError( - f"Each batch response must be a string or a list of strings, got {type(invalid_response)}" - ) + return [ + infill_responses(query, response, json_responses=json_responses) + for query, response in zip(queries, responses, strict=True) + ] diff --git a/src/gimkit/models/vllm_offline.py b/src/gimkit/models/vllm_offline.py index 8387b1e..54cc576 100644 --- a/src/gimkit/models/vllm_offline.py +++ b/src/gimkit/models/vllm_offline.py @@ -18,7 +18,6 @@ get_outlines_output_type, parse_batch_generation_responses, parse_generation_responses, - validate_error_mode, ) from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField @@ -73,7 +72,6 @@ def __call__( error_mode: ErrorMode = "raise", **inference_kwargs: Any, ) -> Result | list[Result] | GenerationResult | list[GenerationResult]: - validate_error_mode(error_mode) inference_kwargs = self._ensure_response_suffix(inference_kwargs) outlines_model_input = get_outlines_model_input( @@ -133,7 +131,6 @@ def batch( error_mode: ErrorMode = "raise", **inference_kwargs: Any, ) -> list[list[Result]] | list[list[GenerationResult]]: # type: ignore[override] - validate_error_mode(error_mode) inference_kwargs = self._ensure_response_suffix(inference_kwargs) outlines_model_inputs = get_outlines_model_inputs( @@ -213,26 +210,27 @@ def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, # Using `stop=RESPONSE_SUFFIX` is preferred for two reasons: # 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX. # 2. Even with CFG, inference engines like vLLM do not guarantee termination when the CFG is satisfied (See https://github.com/vllm-project/vllm/issues/29632). + + def _ensure_sampling_params_response_suffix(sampling_params: "SamplingParams") -> None: + if sampling_params.stop is None: + sampling_params.stop = [RESPONSE_SUFFIX] + elif isinstance(sampling_params.stop, str): + if sampling_params.stop != RESPONSE_SUFFIX: + sampling_params.stop = [sampling_params.stop, RESPONSE_SUFFIX] + elif RESPONSE_SUFFIX not in sampling_params.stop: + sampling_params.stop.append(RESPONSE_SUFFIX) + if "sampling_params" not in inference_kwargs: from vllm import SamplingParams inference_kwargs["sampling_params"] = SamplingParams(stop=[RESPONSE_SUFFIX]) - elif isinstance(inference_kwargs["sampling_params"], list): + elif isinstance(inference_kwargs["sampling_params"], list): # For batch inference for sampling_params in inference_kwargs["sampling_params"]: - self._ensure_sampling_params_response_suffix(sampling_params) + _ensure_sampling_params_response_suffix(sampling_params) else: - self._ensure_sampling_params_response_suffix(inference_kwargs["sampling_params"]) + _ensure_sampling_params_response_suffix(inference_kwargs["sampling_params"]) return inference_kwargs - def _ensure_sampling_params_response_suffix(self, sampling_params: "SamplingParams") -> None: - if sampling_params.stop is None: - sampling_params.stop = [RESPONSE_SUFFIX] - elif isinstance(sampling_params.stop, str): - if sampling_params.stop != RESPONSE_SUFFIX: - sampling_params.stop = [sampling_params.stop, RESPONSE_SUFFIX] - elif RESPONSE_SUFFIX not in sampling_params.stop: - sampling_params.stop.append(RESPONSE_SUFFIX) - def from_vllm_offline(model: "LLM") -> VLLMOffline: return VLLMOffline(model) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index ab7f6f1..2812de0 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -162,18 +162,3 @@ async def test_async_request_error_is_not_collected(): model = from_openai(client, model_name="gpt-4o") with pytest.raises(RuntimeError, match="request failed"): await model("Hello, " + guide(), output_type=None, error_mode="collect") - - -def test_invalid_error_mode_fails_before_model_request(): - client = OpenAI(api_key="test", timeout=0, max_retries=0) - - with patch.object(client.chat.completions, "create") as mock_create: - model = from_openai(client, model_name="gpt-4o") - with pytest.raises(ValueError, match="Invalid error mode"): - model( - "Hello, " + guide(), - output_type=None, - error_mode="invalid", - ) - - mock_create.assert_not_called() diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 2801817..3cc12a6 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -160,18 +160,6 @@ def test_infill_responses(): assert isinstance(result_from_json, Result) assert str(result_from_json) == "Hello, world and friend" - # Test invalid response type - with pytest.raises(TypeError, match="Expected responses to be str or list of str, got"): - infill_responses(query, 123) - - # Test empty list - with pytest.raises(ValueError, match="Response list is empty"): - infill_responses(query, []) - - # Test list with non-string items - with pytest.raises(TypeError, match="All items in the response list must be strings, got"): - infill_responses(query, ["a", 1]) - def test_infill_batch_responses(): queries = [ @@ -200,18 +188,9 @@ def test_infill_batch_responses(): ) assert [str(result) for result in json_results] == ["Hello, world", "Goodbye, friend"] - with pytest.raises(ValueError, match="Batch input list is empty"): - infill_batch_responses([], []) - with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): infill_batch_responses(queries, [responses[0]]) - with pytest.raises(TypeError, match="Expected batch responses to be a list"): - infill_batch_responses(queries, "response") - - with pytest.raises(TypeError, match="Each batch response must be a string or a list"): - infill_batch_responses(queries, [responses[0], object()]) - def test_parse_generation_response_error_modes(): query = Query("Hello, ", MaskedTag(id=0)) @@ -302,30 +281,5 @@ def test_parse_generation_response_json_preserves_original_text(): def test_parse_generation_response_rejects_invalid_configuration_and_container(): query = Query(MaskedTag(id=0)) - with pytest.raises(ValueError, match="Invalid error mode"): - parse_generation_response(query, "response", error_mode="invalid") - - with pytest.raises(TypeError, match="All items in the response list must be strings"): - parse_generation_responses(query, ["response", object()], error_mode="collect") - - with pytest.raises(ValueError, match="Response list is empty"): - parse_batch_generation_responses([query], [[]], error_mode="collect") - with pytest.raises(ValueError, match="Mismatched number of batch inputs and responses"): parse_batch_generation_responses([query], [], error_mode="collect") - - with pytest.raises(TypeError, match="Expected raw response to be str"): - parse_generation_response(query, object(), error_mode="collect") - - with pytest.raises(ValueError, match="Batch input list is empty"): - parse_batch_generation_responses([], [], error_mode="collect") - - with pytest.raises(TypeError, match="Expected batch responses to be a list"): - parse_batch_generation_responses([query], "response", error_mode="collect") - - with pytest.raises(TypeError, match="Each batch response group must be a list of strings"): - parse_batch_generation_responses( - [query], - ["response"], - error_mode="collect", - ) diff --git a/tests/models/test_vllm_offline.py b/tests/models/test_vllm_offline.py index 7e8ce15..5b49820 100644 --- a/tests/models/test_vllm_offline.py +++ b/tests/models/test_vllm_offline.py @@ -115,20 +115,26 @@ def test_vllm_offline_batch_sampling_params_list(): assert RESPONSE_SUFFIX in sampling_params[1].stop -def test_vllm_offline_ensure_sampling_params_response_suffix(): +def test_vllm_offline_ensure_response_suffix(): model = from_vllm_offline(_mock_vllm_client()) - sampling_params = SimpleNamespace(stop=None) - model._ensure_sampling_params_response_suffix(sampling_params) - assert sampling_params.stop == [RESPONSE_SUFFIX] + sampling_params1 = SimpleNamespace(stop=None) + model._ensure_response_suffix({"sampling_params": sampling_params1}) + assert sampling_params1.stop == [RESPONSE_SUFFIX] - sampling_params = SimpleNamespace(stop="") - model._ensure_sampling_params_response_suffix(sampling_params) - assert sampling_params.stop == ["", RESPONSE_SUFFIX] + sampling_params2 = SimpleNamespace(stop="") + model._ensure_response_suffix({"sampling_params": sampling_params2}) + assert sampling_params2.stop == ["", RESPONSE_SUFFIX] - sampling_params = SimpleNamespace(stop=RESPONSE_SUFFIX) - model._ensure_sampling_params_response_suffix(sampling_params) - assert sampling_params.stop == RESPONSE_SUFFIX + sampling_params3 = SimpleNamespace(stop=RESPONSE_SUFFIX) + model._ensure_response_suffix({"sampling_params": sampling_params3}) + assert sampling_params3.stop == RESPONSE_SUFFIX + + sp1 = SimpleNamespace(stop=None) + sp2 = SimpleNamespace(stop="") + model._ensure_response_suffix({"sampling_params": [sp1, sp2]}) + assert sp1.stop == [RESPONSE_SUFFIX] + assert sp2.stop == ["", RESPONSE_SUFFIX] def test_vllm_offline_batch_invalid_sampling_params_list_length(): @@ -146,21 +152,6 @@ def test_vllm_offline_batch_invalid_sampling_params_list_length(): ) -def test_vllm_offline_batch_invalid_response(): - mock_client = _mock_vllm_client() - mock_client.generate.return_value = [_request_output()] - model = from_vllm_offline(mock_client) - - with pytest.raises(ValueError, match="Response list is empty"): - model.batch([["Hello, ", MaskedTag()]]) - - mock_client.generate.return_value = [ - MagicMock(outputs=[MagicMock(text=object())]), - ] - with pytest.raises(TypeError, match="All items in the response list must be strings"): - model.batch([["Hello, ", MaskedTag()]]) - - def test_vllm_offline_batch_chat(): mock_client = _mock_vllm_client() mock_client.chat.return_value = [ @@ -183,33 +174,6 @@ def test_vllm_offline_batch_chat(): mock_client.chat.assert_called_once() -def test_vllm_offline_call_invalid_response(): - from vllm import SamplingParams - - model = from_vllm_offline(_mock_vllm_client()) - - with patch("gimkit.models.vllm_offline.Generator") as mock_generator: - generator_instance = MagicMock() - generator_instance.return_value = set() - mock_generator.return_value = generator_instance - with pytest.raises(TypeError, match="Expected responses to be str or list of str, got"): - model(MaskedTag()) - - with patch("gimkit.models.vllm_offline.Generator") as mock_generator: - generator_instance = MagicMock() - generator_instance.return_value = [object, "response2"] - mock_generator.return_value = generator_instance - with pytest.raises(TypeError, match="All items in the response list must be strings, got"): - model(MaskedTag(), sampling_params=SamplingParams(n=2)) - - with patch("gimkit.models.vllm_offline.Generator") as mock_generator: - generator_instance = MagicMock() - generator_instance.return_value = [] - mock_generator.return_value = generator_instance - with pytest.raises(ValueError, match="Response list is empty"): - model(MaskedTag()) - - def test_vllm_offline_call_collects_candidate_errors(): model = from_vllm_offline(_mock_vllm_client()) valid = '<|MASKED id="m_0"|>hi<|/MASKED|>' From 8b17e075205b82a3c86afc742e5985518b1abd86 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:51:11 +0800 Subject: [PATCH 5/5] fix: suppress type checking error in infill_batch_responses --- src/gimkit/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gimkit/models/utils.py b/src/gimkit/models/utils.py index 9b10fc2..fb36bbd 100644 --- a/src/gimkit/models/utils.py +++ b/src/gimkit/models/utils.py @@ -349,6 +349,6 @@ def infill_batch_responses( ) return [ - infill_responses(query, response, json_responses=json_responses) + infill_responses(query, response, json_responses=json_responses) # type: ignore[call-overload] for query, response in zip(queries, responses, strict=True) ]