Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/models/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,30 @@ result = model(query, output_type=None, use_gim_prompt=True)
print(result.tags["email"].content)
```

## Per-generation error collection

The default `error_mode="raise"` preserves fail-fast behavior. For multiple
candidates, use `error_mode="collect"` to parse each raw response independently:

```python
generations = model(query, n=2, error_mode="collect")

for generation in generations:
if generation.ok:
print(generation.result)
else:
print(generation.error_type, generation.error_message)
print(generation.raw_response)
```

`collect` only captures parsing and infill errors after raw text has been
generated. Network, authentication, timeout, model request, and invalid response
container errors still fail the whole call. Async models use the same parameter
through `await model(...)`.

## Advanced options

- `visible_tag_fields`: control which `MaskedTag` fields are visible to the model (e.g. `["id", "name", "desc", "content", "regex"]`). Defaults to `None` (basic fields only: `["id", "desc", "content"]`).
- `backend`: pass through to Outlines generator backend selection.
- `error_mode`: `"raise"` (default) or `"collect"`.
- `**inference_kwargs`: forwarded to the underlying OpenAI call.
21 changes: 21 additions & 0 deletions docs/models/openai.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,29 @@ result = model(query, output_type=None, use_gim_prompt=True)
print(result.tags["email"].content)
```

## 逐生成错误收集

默认的 `error_mode="raise"` 保持快速失败行为。多候选生成时,可使用
`error_mode="collect"` 让每条原始响应独立解析:

```python
generations = model(query, n=2, error_mode="collect")

for generation in generations:
if generation.ok:
print(generation.result)
else:
print(generation.error_type, generation.error_message)
print(generation.raw_response)
```

`collect` 只收集已经获得原始文本后的解析和 infill 错误。网络、认证、超时、
模型请求失败以及无效的响应容器仍会作为整个调用异常抛出。异步模型使用相同参数,
调用方式为 `await model(...)`。

## 高级参数

- `visible_tag_fields`:控制哪些 `MaskedTag` 字段对模型可见(如 `["id", "name", "desc", "content", "regex"]`)。默认为 `None`(仅基础字段:`["id", "desc", "content"]`)。
- `backend`:透传给 Outlines 生成器后端选择。
- `error_mode`:`"raise"`(默认)或 `"collect"`。
- `**inference_kwargs`:透传到底层 OpenAI 推理参数。
21 changes: 21 additions & 0 deletions docs/models/vllm-offline.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ batch_results = model.batch([query, query])
first_result = batch_results[0][0]
```

With `error_mode="collect"`, batch always returns a two-dimensional
`list[list[GenerationResult]]`: the outer list maps to queries and the inner list
maps to candidates.

```python
generation_groups = model.batch(queries, error_mode="collect")

for generation_group in generation_groups:
for generation in generation_group:
if generation.ok:
print(generation.result)
else:
print(generation.error_type, generation.error_message)
print(generation.raw_response)
```

A parsing failure for one candidate does not affect other candidates or queries.
The default `error_mode="raise"` preserves existing return types and fail-fast
behavior. Generation failures, invalid batch shapes, and invalid arguments still
fail the whole call.

## Output types

### `output_type="cfg"` (default)
Expand Down
19 changes: 19 additions & 0 deletions docs/models/vllm-offline.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ batch_results = model.batch([query, query])
first_result = batch_results[0][0]
```

使用 `error_mode="collect"` 时,batch 始终返回二维
`list[list[GenerationResult]]`:外层对应 query,内层对应候选。

```python
generation_groups = model.batch(queries, error_mode="collect")

for generation_group in generation_groups:
for generation in generation_group:
if generation.ok:
print(generation.result)
else:
print(generation.error_type, generation.error_message)
print(generation.raw_response)
```

单个候选的解析失败不会影响同一 query 的其他候选或其他 query。默认
`error_mode="raise"` 的返回类型和快速失败行为保持不变。模型生成失败、batch
形状错误和无效参数仍会作为整个调用异常抛出。

## 输出类型

### `output_type="cfg"`(默认)
Expand Down
11 changes: 11 additions & 0 deletions docs/models/vllm.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,18 @@ result = model(query, output_type="cfg")
result = model(query, output_type="json", use_gim_prompt=True)
```

## Per-generation error collection

For multiple candidates, pass `error_mode="collect"` to receive one
`GenerationResult` per raw response. Successful items expose `.result`; failed
items retain `.raw_response`, `.error_type`, and `.error_message`. The default
`error_mode="raise"` preserves existing behavior and return types. Model request,
network, and response-container failures still fail the whole call.

Async clients use the same parameter through `await model(...)`.

## Notes

- GIMKit automatically adds `stop="<|/GIM_RESPONSE|>"` for safer termination.
- `error_mode` accepts `"raise"` (default) or `"collect"`.
- You can still pass extra generation args via `**inference_kwargs`.
10 changes: 10 additions & 0 deletions docs/models/vllm.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@ result = model(query, output_type="cfg")
result = model(query, output_type="json", use_gim_prompt=True)
```

## 逐生成错误收集

多候选生成时,可传入 `error_mode="collect"`,逐条获得
`GenerationResult`。成功项通过 `.result` 访问,失败项保留 `.raw_response`、
`.error_type` 和 `.error_message`。默认 `error_mode="raise"` 的行为和返回类型
保持不变。模型请求、网络和响应容器错误仍会作为整个调用异常抛出。

异步客户端使用相同参数,调用方式为 `await model(...)`。

## 说明

- GIMKit 会自动添加 `stop="<|/GIM_RESPONSE|>"`,确保更稳定停止。
- `error_mode` 可设为 `"raise"`(默认)或 `"collect"`。
- 可通过 `**inference_kwargs` 继续传递生成参数。
6 changes: 5 additions & 1 deletion src/gimkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from importlib.metadata import PackageNotFoundError, version

from gimkit.contexts import Query, Response
from gimkit.guides import guide
from gimkit.models import from_openai, from_vllm, from_vllm_offline
from gimkit.models import GenerationResult, from_openai, from_vllm, from_vllm_offline


try:
Expand All @@ -11,6 +12,9 @@


__all__ = [
"GenerationResult",
"Query",
"Response",
"from_openai",
"from_vllm",
"from_vllm_offline",
Expand Down
3 changes: 2 additions & 1 deletion src/gimkit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .openai import from_openai
from .types import GenerationResult
from .vllm import from_vllm
from .vllm_offline import from_vllm_offline


__all__ = ["from_openai", "from_vllm", "from_vllm_offline"]
__all__ = ["GenerationResult", "from_openai", "from_vllm", "from_vllm_offline"]
35 changes: 28 additions & 7 deletions src/gimkit/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

from gimkit.contexts import Query, Result
from gimkit.log import get_logger
from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses
from gimkit.models.types import ErrorMode, GenerationResult
from gimkit.models.utils import (
get_outlines_model_input,
get_outlines_output_type,
parse_generation_responses,
)
from gimkit.schemas import ContextInput, TagField


Expand All @@ -19,8 +24,10 @@ def _call(
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: ErrorMode = "raise",
**inference_kwargs: Any,
) -> Result | list[Result]:
) -> Result | list[Result] | GenerationResult | list[GenerationResult]:
outlines_model_input = get_outlines_model_input(
model_input, output_type, use_gim_prompt, visible_tag_fields
)
Expand All @@ -29,8 +36,14 @@ def _call(
generator = Generator(self, outlines_output_type, backend)
raw_responses = generator(outlines_model_input, **inference_kwargs)
logger.debug(f"Raw responses of {self}: {raw_responses}")
return infill_responses(
model_input, cast("str | list[str]", raw_responses), json_responses=(output_type == "json")
return cast(
"Result | list[Result] | GenerationResult | list[GenerationResult]",
cast("Any", parse_generation_responses)(
model_input,
cast("str | list[str]", raw_responses),
json_responses=(output_type == "json"),
error_mode=error_mode,
),
)


Expand All @@ -41,8 +54,10 @@ async def _acall(
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: ErrorMode = "raise",
**inference_kwargs: Any,
) -> Result | list[Result]:
) -> Result | list[Result] | GenerationResult | list[GenerationResult]:
outlines_model_input = get_outlines_model_input(
model_input, output_type, use_gim_prompt, visible_tag_fields
)
Expand All @@ -51,6 +66,12 @@ async def _acall(
generator = Generator(self, outlines_output_type, backend)
raw_responses = await generator(outlines_model_input, **inference_kwargs)
logger.debug(f"Raw responses of {self}: {raw_responses}")
return infill_responses(
model_input, cast("str | list[str]", raw_responses), json_responses=(output_type == "json")
return cast(
"Result | list[Result] | GenerationResult | list[GenerationResult]",
cast("Any", parse_generation_responses)(
model_input,
cast("str | list[str]", raw_responses),
json_responses=(output_type == "json"),
error_mode=error_mode,
),
)
63 changes: 61 additions & 2 deletions src/gimkit/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,98 @@

from gimkit.contexts import Query, Result
from gimkit.models.base import _acall, _call
from gimkit.models.types import ErrorMode, GenerationResult
from gimkit.schemas import ContextInput, TagField


class OpenAI(OutlinesOpenAI):
@overload
def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: Literal["raise"] = "raise",
**inference_kwargs: Any,
) -> Result | list[Result]:
) -> Result | list[Result]: ...

@overload
def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: Literal["collect"],
**inference_kwargs: Any,
) -> GenerationResult | list[GenerationResult]: ...

def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: ErrorMode = "raise",
**inference_kwargs: Any,
) -> Result | list[Result] | GenerationResult | list[GenerationResult]:
return _call(
self,
model_input,
output_type,
backend,
use_gim_prompt,
visible_tag_fields,
error_mode=error_mode,
**inference_kwargs,
)


class AsyncOpenAI(OutlinesAsyncOpenAI):
@overload
async def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: Literal["raise"] = "raise",
**inference_kwargs: Any,
) -> Result | list[Result]: ...

@overload
async def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: Literal["collect"],
**inference_kwargs: Any,
) -> GenerationResult | list[GenerationResult]: ...

async def __call__(
self,
model_input: ContextInput | Query,
output_type: Literal["json"] | None = None,
backend: str | None = None,
use_gim_prompt: bool = False,
visible_tag_fields: list[TagField] | None = None,
*,
error_mode: ErrorMode = "raise",
**inference_kwargs: Any,
) -> Result | list[Result]:
) -> Result | list[Result] | GenerationResult | list[GenerationResult]:
return await _acall(
self,
model_input,
Expand All @@ -53,6 +111,7 @@ async def __call__(
use_gim_prompt,
visible_tag_fields,
**inference_kwargs,
error_mode=error_mode,
)


Expand Down
22 changes: 22 additions & 0 deletions src/gimkit/models/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass
from typing import Literal, TypeAlias

from gimkit.contexts import Result


ErrorMode: TypeAlias = Literal["raise", "collect"]


@dataclass(frozen=True, slots=True)
class GenerationResult:
"""The parsed result or parsing error for one raw model generation."""

raw_response: str
result: Result | None = None
error_type: str | None = None
error_message: str | None = None

@property
def ok(self) -> bool:
"""Whether the raw generation was parsed and infilled successfully."""
return self.result is not None
Loading
Loading