Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/trio_core/api/routers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ class DescribeRequest(BaseModel):
"swap models per-request and will log a warning, then ignore."
),
)
extra_body: dict | None = Field(
default=None,
description=(
"Backend-specific kwargs forwarded as the OpenAI SDK extra_body "
"(e.g. DashScope's enable_thinking). Honored by RemoteHTTPBackend; "
"ignored by local backends."
),
)


class DescribeResponse(BaseModel):
Expand Down Expand Up @@ -590,6 +598,7 @@ def _sync_describe():
max_tokens=req.max_tokens,
response_format=req.response_format,
model=req.model,
extra_body=req.extra_body,
)

try:
Expand Down
5 changes: 5 additions & 0 deletions src/trio_core/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
"""Run inference on video frames.

Expand All @@ -127,6 +128,9 @@ def generate(
Honored by ``RemoteHTTPBackend`` (passed through to the
upstream chat.completions call); local backends cannot swap
models per-request and log a one-shot warning, then ignore.
extra_body: Backend-specific kwargs forwarded as the OpenAI
SDK ``extra_body`` (e.g. DashScope's ``enable_thinking``).
Honored by ``RemoteHTTPBackend``; ignored by local backends.

Returns:
GenerationResult with text and metrics.
Expand All @@ -144,6 +148,7 @@ def stream_generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> Generator[StreamChunk, None, None]:
"""Stream inference token by token."""
...
Expand Down
3 changes: 2 additions & 1 deletion src/trio_core/backends/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
del response_format # remote-only spec; ignored by compressed local backend
del response_format, extra_body # remote-only specs; ignored by compressed local backend
self._warn_model_override_once(model)
tic = time.perf_counter()
y, prompt_cache, prompt_token_count = self._custom_prefill(
Expand Down
7 changes: 4 additions & 3 deletions src/trio_core/backends/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,11 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
# response_format is a remote-only structured-output spec; local
# MLX inference doesn't honor it — ignored.
del response_format
# response_format / extra_body are remote-only specs; local MLX
# inference doesn't honor them — ignored.
del response_format, extra_body
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)
input_ids = kwargs.pop("input_ids")
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/backends/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
t0 = time.monotonic()
uris = _frames_to_data_uris(frames)
Expand All @@ -118,6 +119,8 @@ def generate(
extra_kwargs: dict[str, object] = {}
if response_format is not None:
extra_kwargs["response_format"] = response_format
if extra_body is not None:
extra_kwargs["extra_body"] = extra_body

effective_model = model or self._remote_model
response = self._client.chat.completions.create(
Expand Down Expand Up @@ -170,6 +173,7 @@ def stream_generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> Generator[StreamChunk, None, None]:
"""Stream from remote VLM API.

Expand Down
3 changes: 2 additions & 1 deletion src/trio_core/backends/tome.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,10 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
"""Run inference with ToMe-compressed vision tokens."""
del response_format # remote-only spec; ignored by ToMe local backend
del response_format, extra_body # remote-only specs; ignored by ToMe local backend
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)

Expand Down
3 changes: 2 additions & 1 deletion src/trio_core/backends/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ def generate(
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> GenerationResult:
import torch

del response_format # remote-only spec; ignored by local transformers backend
del response_format, extra_body # remote-only specs; ignored by local transformers backend
self._warn_model_override_once(model)
inputs = self._prepare(frames, prompt)

Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def analyze_video(
temperature: float | None = None,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> VideoResult:
"""Analyze a video with the loaded VLM.

Expand Down Expand Up @@ -315,6 +316,7 @@ def analyze_video(
top_p=self.config.top_p,
response_format=response_format,
model=model,
extra_body=extra_body,
)

# ── Phase 3: Postprocess ─────────────────────────────────────────
Expand Down Expand Up @@ -433,6 +435,7 @@ def analyze_frame(
temperature: float | None = None,
response_format: dict | None = None,
model: str | None = None,
extra_body: dict | None = None,
) -> VideoResult:
"""Analyze a single frame (image). Convenience wrapper."""
if frame.ndim == 3:
Expand All @@ -446,6 +449,7 @@ def analyze_frame(
temperature=temperature,
response_format=response_format,
model=model,
extra_body=extra_body,
)

def health(self) -> dict[str, Any]:
Expand Down
Loading