Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/models/vllm-offline.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ result = model(query)
result_non_gim = model(query, use_gim_prompt=True)
```

## Batch inference

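`model.batch(...)` wraps Outlines' batch API for vLLM offline.
Each query can use its own GIM-derived structured output schema.
The return value is a nested list: one entry per input query, and each entry is the list of completions generated for that query.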

```python
batch_results = model.batch([query, query])
first_result = batch_results[0][0]
```

## Output types

### `output_type="cfg"` (default)
Expand Down
10 changes: 10 additions & 0 deletions docs/models/vllm-offline.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ result = model(query)
result_non_gim = model(query, use_gim_prompt=True)
```

## 批量推理

`model.batch(...)` 会包装 Outlines 的 vLLM offline batch API。
每条 query 都可以使用各自从 GIM 推导出的结构化输出 schema。
返回值是一个二维列表：外层与输入的 query 一一对应，内层是该 query 的各个生成结果。

```python
batch_results = model.batch([query, query])
first_result = batch_results[0][0]
```

## 输出类型

### `output_type="cfg"`(默认)
Expand Down
31 changes: 31 additions & 0 deletions examples/vllm_offline_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from vllm import LLM, SamplingParams

from gimkit import from_vllm_offline
from gimkit import guide as g


# Load the model with vLLM and wrap it with gimkit's vLLM offline adapter.
llm = LLM(model="Sculpt-AI/GIM-1.7B", max_model_len=8192)
model = from_vllm_offline(llm)

# Two prompts whose placeholders are built with gimkit's guide helpers.
name_query = f"Extract the person's name: Alice Zhang -> {g.person_name(name='name')}"
contact_query = (
    "Extract contact fields from: Bob Chen, bob@example.com, +1-212-555-0101\n"
    f"Name: {g.person_name(name='name')}\n"
    f"Email: {g.e_mail(name='email')}\n"
    f"Phone: {g.phone_number(name='phone')}"
)
queries = [name_query, contact_query]

# One SamplingParams per query, listed in the same order as `queries`.
name_params = SamplingParams(temperature=0.0, max_tokens=256, seed=0)
contact_params = SamplingParams(temperature=0.0, max_tokens=512, seed=1)
sampling_params = [name_params, contact_params]

batch_results = model.batch(queries, output_type="cfg", sampling_params=sampling_params)

# The outer list follows the inputs; each inner list holds that input's completions.
for input_index, completions in enumerate(batch_results):
    print(f"Input {input_index}")
    for completion_index, result in enumerate(completions):
        print(f"Completion {completion_index}: {result}")
81 changes: 78 additions & 3 deletions src/gimkit/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Literal, overload
from collections.abc import Sequence
from typing import Literal, cast, overload

from outlines.inputs import Chat
from outlines.types.dsl import CFG, JsonSchema
Expand All @@ -14,6 +15,10 @@
from gimkit.schemas import ContextInput, MaskedTag, TagField


def _ensure_query(model_input: ContextInput | Query) -> Query:
    """Return ``model_input`` as a ``Query``.

    An existing ``Query`` is returned unchanged (same object); any other
    value is passed to the ``Query`` constructor.
    """
    if isinstance(model_input, Query):
        return model_input
    return Query(model_input)


def get_outlines_model_input(
model_input: ContextInput | Query,
output_type: Literal["cfg", "json"] | None,
Expand All @@ -31,7 +36,7 @@ def get_outlines_model_input(
If None, uses the Query default (["id", "desc", "content"]).
Example: ["id", "name", "desc", "content", "regex"] to expose all fields.
"""
query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input
query_obj = _ensure_query(model_input)
outlines_model_input: str | Chat = (
query_obj.to_string(fields=visible_tag_fields)
if visible_tag_fields is not None
Expand All @@ -57,11 +62,31 @@ def get_outlines_model_input(
return outlines_model_input


def get_outlines_model_inputs(
    model_inputs: Sequence[ContextInput | Query],
    output_type: Literal["cfg", "json"] | None,
    use_gim_prompt: bool,
    visible_tag_fields: list[TagField] | None = None,
) -> list[str | Chat]:
    """Convert every item of a batch into its Outlines-compatible input.

    Each item is converted independently with ``get_outlines_model_input``,
    using the same ``output_type``, ``use_gim_prompt`` and
    ``visible_tag_fields`` for the whole batch.

    Args:
        model_inputs: The batch items, in the order they should be generated.
        output_type: Forwarded unchanged to ``get_outlines_model_input``.
        use_gim_prompt: Forwarded unchanged to ``get_outlines_model_input``.
        visible_tag_fields: Forwarded unchanged to ``get_outlines_model_input``.

    Returns:
        One converted input per batch item, in the same order.

    Raises:
        ValueError: If ``model_inputs`` is empty.
    """
    if len(model_inputs) == 0:
        raise ValueError("Batch input list is empty.")

    converted: list[str | Chat] = []
    for batch_item in model_inputs:
        converted.append(
            get_outlines_model_input(
                batch_item,
                output_type,
                use_gim_prompt,
                visible_tag_fields=visible_tag_fields,
            )
        )
    return converted


def get_outlines_output_type(
model_input: ContextInput | Query, output_type: Literal["cfg", "json"] | None
) -> None | CFG | JsonSchema:
"""Transform the output type to an Outlines-compatible format."""
query_obj = Query(model_input) if not isinstance(model_input, Query) else model_input
query_obj = _ensure_query(model_input)
if output_type is None:
return None
elif output_type == "cfg":
Expand Down Expand Up @@ -157,3 +182,53 @@ def infill_responses(
raise TypeError(f"All items in the response list must be strings, got: {responses}")

return [infill_responses(query, resp, json_responses=json_responses) for resp in responses]


@overload
def infill_batch_responses(
    queries: Sequence[ContextInput | Query], responses: list[str], json_responses: bool = False
) -> list[Result]: ...


@overload
def infill_batch_responses(
    queries: Sequence[ContextInput | Query],
    responses: list[list[str]],
    json_responses: bool = False,
) -> list[list[Result]]: ...


def infill_batch_responses(
    queries: Sequence[ContextInput | Query],
    responses: list[str] | list[list[str]],
    json_responses: bool = False,
) -> list[Result] | list[list[Result]]:
    """Infill each query in a batch with its corresponding response(s).

    ``queries[i]`` is paired with ``responses[i]`` and the pair is handed to
    ``infill_responses``. The batch must be homogeneous: either every response
    is a string, or every response is a list.

    Args:
        queries: The batch inputs, one per response entry.
        responses: One entry per query; all strings or all lists of strings.
        json_responses: Forwarded unchanged to ``infill_responses``.

    Returns:
        One infilled entry per query, in input order: the ``infill_responses``
        result for a single string when the responses are strings, or for a
        list of strings when the responses are lists.

    Raises:
        ValueError: If ``queries`` is empty, or if ``queries`` and
            ``responses`` differ in length.
        TypeError: If ``responses`` is not a list, if any entry is neither a
            string nor a list, or if strings and lists are mixed.
    """
    if len(queries) == 0:
        raise ValueError("Batch input list is empty.")
    if not isinstance(responses, list):
        raise TypeError(f"Expected batch responses to be a list, got {type(responses)}")
    if len(queries) != len(responses):
        raise ValueError(
            "Mismatched number of batch inputs and responses: "
            f"{len(queries)} input(s), {len(responses)} response(s)."
        )

    # Reject unsupported entry types before doing any infilling work.
    for response in responses:
        if not isinstance(response, (str, list)):
            raise TypeError(
                f"Each batch response must be a string or a list of strings, got {type(response)}"
            )

    if all(isinstance(response, str) for response in responses):
        return [
            infill_responses(query, cast("str", response), json_responses=json_responses)
            for query, response in zip(queries, responses, strict=True)
        ]

    if all(isinstance(response, list) for response in responses):
        return [
            infill_responses(query, cast("list[str]", response), json_responses=json_responses)
            for query, response in zip(queries, responses, strict=True)
        ]

    # Every entry is a str or a list, but not all of one kind. Previously this
    # case fell through to a `next()` over an empty generator and leaked a bare
    # StopIteration to the caller.
    raise TypeError(
        "Batch responses must be either all strings or all lists of strings, "
        "got a mix of both."
    )
124 changes: 117 additions & 7 deletions src/gimkit/models/vllm_offline.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
# Adapted from https://github.com/dottxt-ai/outlines/blob/main/outlines/models/vllm_offline.py


from typing import TYPE_CHECKING, Any, Literal, cast
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast

from outlines.generator import Generator
from outlines.inputs import Chat
from outlines.models.vllm_offline import VLLMOffline as OutlinesVLLMOffline
from outlines.types.dsl import CFG, JsonSchema

from gimkit.contexts import Query, Result
from gimkit.log import get_logger
from gimkit.models.utils import get_outlines_model_input, get_outlines_output_type, infill_responses
from gimkit.models.utils import (
get_outlines_model_input,
get_outlines_model_inputs,
get_outlines_output_type,
infill_batch_responses,
infill_responses,
)
from gimkit.schemas import RESPONSE_SUFFIX, ContextInput, TagField


logger = get_logger(__name__)

if TYPE_CHECKING:
from vllm import LLM
from vllm.sampling_params import SamplingParams


# A single model input in Outlines form: a plain prompt string or a `Chat`.
OutlinesModelInput: TypeAlias = str | Chat
# A single per-input output constraint: a `CFG`, a `JsonSchema`, or None when
# no output type is requested.
OutlinesOutputType: TypeAlias = CFG | JsonSchema | None
# The type that results of `type_adapter.format_input(...)` are cast to in this
# module: either a string or a list (the list form is routed to `model.chat`).
VLLMFormattedInput: TypeAlias = str | list[object]


class VLLMOffline(OutlinesVLLMOffline):
Expand Down Expand Up @@ -46,6 +61,92 @@ def __call__(
json_responses=(output_type == "json"),
)

def batch(
    self,
    model_input: Sequence[ContextInput | Query],
    output_type: Literal["cfg", "json"] | None = "cfg",
    backend: str | None = None,
    use_gim_prompt: bool = False,
    visible_tag_fields: list[TagField] | None = None,
    **inference_kwargs: Any,
) -> list[list[Result]]:  # type: ignore[override]
    """Generate completions for a batch of queries and infill each query.

    Every item of ``model_input`` gets its own Outlines input and its own
    output constraint (derived from that item and ``output_type``), so
    different items may be constrained differently within one call.

    Args:
        model_input: The batch items, in the order they should be generated.
        output_type: Selects how the per-item output constraint is derived;
            ``"json"`` additionally makes the raw responses be infilled as
            JSON responses.
        backend: Accepted in the signature but not referenced by this method.
        use_gim_prompt: Forwarded to ``get_outlines_model_inputs``.
        visible_tag_fields: Forwarded to ``get_outlines_model_inputs``.
        **inference_kwargs: Passed through ``_ensure_response_suffix`` and
            then on to the batch generation helper.

    Returns:
        One list per input, each holding the infilled results for that
        input's completions.
    """
    generation_kwargs = self._ensure_response_suffix(inference_kwargs)

    prompts = get_outlines_model_inputs(
        model_input,
        output_type,
        use_gim_prompt,
        visible_tag_fields=visible_tag_fields,
    )

    # One constraint per batch item, aligned with `prompts` by position.
    constraints = []
    for batch_item in model_input:
        constraints.append(get_outlines_output_type(batch_item, output_type))

    raw_responses = self._generate_batch_with_output_types(
        prompts,
        constraints,
        generation_kwargs,
    )
    logger.debug(f"Raw batch responses of {self}: {raw_responses}")

    expects_json = output_type == "json"
    infilled = infill_batch_responses(
        model_input,
        raw_responses,
        json_responses=expects_json,
    )
    return cast("list[list[Result]]", infilled)

def _generate_batch_with_output_types(
    self,
    model_inputs: list[OutlinesModelInput],
    output_types: list[OutlinesOutputType],
    inference_kwargs: dict[str, Any],
) -> list[list[str]]:
    """Run one vLLM batch request with a separate sampling config per input.

    Args:
        model_inputs: Outlines-form inputs, one per batch item.
        output_types: Output constraints aligned with ``model_inputs``.
        inference_kwargs: Generation keyword arguments. An optional
            ``"sampling_params"`` entry is taken out and expanded per input;
            everything else is forwarded to the vLLM call. The dict passed in
            is not modified.

    Returns:
        For each result returned by vLLM, the ``text`` of every entry in its
        ``outputs``.
    """
    # Work on a copy so that popping "sampling_params" leaves the caller's dict intact.
    extra_kwargs = dict(inference_kwargs)
    base_params = extra_kwargs.pop("sampling_params", None)
    per_input_params = self._build_batch_sampling_params(base_params, output_types)

    formatted = [
        cast("VLLMFormattedInput", self.type_adapter.format_input(entry))
        for entry in model_inputs
    ]

    # The first formatted input decides the route for the whole batch:
    # list-shaped inputs go to `chat`, everything else to `generate`.
    is_chat_batch = bool(formatted) and isinstance(formatted[0], list)
    if is_chat_batch:
        request_outputs = self.model.chat(
            messages=cast("list[list[Any]]", formatted),
            sampling_params=per_input_params,
            **extra_kwargs,
        )
    else:
        request_outputs = self.model.generate(
            prompts=cast("list[str]", formatted),
            sampling_params=per_input_params,
            **extra_kwargs,
        )

    texts_per_input = []
    for request_output in request_outputs:
        texts_per_input.append([completion.text for completion in request_output.outputs])
    return texts_per_input

def _build_batch_sampling_params(
    self,
    sampling_params: "SamplingParams | list[SamplingParams] | None",
    output_types: list[OutlinesOutputType],
) -> list["SamplingParams"]:
    """Build one sampling configuration per batch input.

    Each output type is combined with a sampling configuration through
    ``_build_generation_args``. Every input receives its own clone of the
    configuration, so per-input changes cannot bleed into another input or
    into the object the caller passed in.

    Args:
        sampling_params: Either one configuration reused for every input, a
            list with exactly one configuration per input, or None to let
            ``_build_generation_args`` pick its own default.
        output_types: The per-input output constraints.

    Returns:
        One result of ``_build_generation_args`` per entry of
        ``output_types``, in the same order.

    Raises:
        ValueError: If ``sampling_params`` is a list whose length differs
            from ``output_types``.
    """
    if isinstance(sampling_params, list):
        if len(sampling_params) != len(output_types):
            raise ValueError(
                "sampling_params list must have the same length as model_input: "
                f"{len(sampling_params)} sampling params for {len(output_types)} input(s)."
            )
        per_input_params = sampling_params
    else:
        per_input_params = [sampling_params] * len(output_types)

    # NOTE(review): depending on the installed Outlines version,
    # `_build_generation_args` may attach the constraint to the object it is
    # given — confirm. Cloning keeps inputs independent either way, including
    # when the caller passes the same object several times (e.g. `[sp] * n`).
    return [
        self._build_generation_args(
            {"sampling_params": None if params is None else params.clone()},
            output_type,
        )
        for params, output_type in zip(per_input_params, output_types, strict=True)
    ]

def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str, Any]:
# Using `stop=RESPONSE_SUFFIX` is preferred for two reasons:
# 1. The model might not be trained well enough to generate EOS tokens immediately after RESPONSE_SUFFIX.
Expand All @@ -54,13 +155,22 @@ def _ensure_response_suffix(self, inference_kwargs: dict[str, Any]) -> dict[str,
from vllm import SamplingParams

inference_kwargs["sampling_params"] = SamplingParams(stop=[RESPONSE_SUFFIX])
elif (
isinstance(inference_kwargs["sampling_params"].stop, list)
and RESPONSE_SUFFIX not in inference_kwargs["sampling_params"].stop
):
inference_kwargs["sampling_params"].stop.append(RESPONSE_SUFFIX)
elif isinstance(inference_kwargs["sampling_params"], list):
for sampling_params in inference_kwargs["sampling_params"]:
self._ensure_sampling_params_response_suffix(sampling_params)
else:
self._ensure_sampling_params_response_suffix(inference_kwargs["sampling_params"])
return inference_kwargs

def _ensure_sampling_params_response_suffix(self, sampling_params: "SamplingParams") -> None:
    """Make sure ``RESPONSE_SUFFIX`` is one of ``sampling_params``' stop strings.

    The object is updated in place. A missing ``stop`` becomes a one-item
    list, a single stop string is widened to a two-item list, and an existing
    collection gets the suffix appended. Nothing changes when the suffix is
    already present.
    """
    current_stop = sampling_params.stop

    if current_stop is None:
        sampling_params.stop = [RESPONSE_SUFFIX]
        return

    if isinstance(current_stop, str):
        # A lone stop string equal to the suffix already does the job.
        if current_stop != RESPONSE_SUFFIX:
            sampling_params.stop = [current_stop, RESPONSE_SUFFIX]
        return

    if RESPONSE_SUFFIX not in current_stop:
        current_stop.append(RESPONSE_SUFFIX)


def from_vllm_offline(model: "LLM") -> VLLMOffline:
    """Wrap a vLLM ``LLM`` instance in a ``VLLMOffline`` model."""
    wrapped_model = VLLMOffline(model)
    return wrapped_model
Loading
Loading