Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/trio_core/api/routers/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ class DescribeRequest(BaseModel):
"VLM backends. Ignored by local backends."
),
)
max_tokens: int | None = Field(
default=None,
ge=1,
le=16384,
description=(
"Maximum tokens to generate. None falls back to engine default. "
"Structured-output prompts (e.g. JSON-schema) need a higher value "
"than free-form description."
),
)
model: str | None = Field(
default=None,
max_length=128,
description=(
"Override the server-default VLM model for this request. Forwarded "
"to remote backends (the model name is passed straight to the "
"OpenAI-compatible chat.completions call). Local backends cannot "
"swap models per-request and will log a warning, then ignore."
),
)


class DescribeResponse(BaseModel):
Expand Down Expand Up @@ -209,6 +229,14 @@ class CropDescribeRequest(BaseModel):
"VLM backends. Ignored by local backends."
),
)
model: str | None = Field(
default=None,
max_length=128,
description=(
"Override the server-default VLM model for this request. Forwarded "
"to remote backends; ignored (with a warning) by local backends."
),
)


class CropDescribeResponse(BaseModel):
Expand Down Expand Up @@ -556,7 +584,13 @@ def _sync_describe():
frame = _decode_image(req.image_b64)
frame_chw = _frame_to_chw(frame)
engine = _get_vlm()
return engine.analyze_frame(frame_chw, req.prompt, response_format=req.response_format)
return engine.analyze_frame(
frame_chw,
req.prompt,
max_tokens=req.max_tokens,
response_format=req.response_format,
model=req.model,
)

try:
result = await loop.run_in_executor(None, _sync_describe)
Expand Down Expand Up @@ -638,6 +672,7 @@ async def _crop_describe_inner(req: CropDescribeRequest):
frame_chw,
scene_prompt,
response_format=req.response_format,
model=req.model,
),
)
text = _strip_thinking(result.text or "")
Expand Down
29 changes: 29 additions & 0 deletions src/trio_core/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ def __init__(
# remote backends override with contextlib.nullcontext() so multiple
# HTTP calls run in parallel.
self._lock: AbstractContextManager[bool] = threading.Lock()
self._model_override_warned = False

def _warn_model_override_once(self, requested: str | None) -> None:
    """Log a single warning that a per-request ``model`` is being ignored.

    Meant to be called by local backends, which keep serving the model
    they already have loaded and cannot switch on a per-request basis.

    Nothing is logged when ``requested`` is empty/``None`` or already
    equals ``self.model_name``. Once a warning has been emitted, the
    ``_model_override_warned`` flag makes later calls on this backend
    instance return early, so callers that send a model name with every
    request do not flood the log.

    Args:
        requested: Model name supplied with the request, if any.
    """
    nothing_to_override = not requested or requested == self.model_name
    if nothing_to_override or self._model_override_warned:
        return

    # Set the flag before logging so subsequent calls take the early return.
    self._model_override_warned = True
    logger.warning(
        "%s backend cannot swap model per request; ignoring model=%r and "
        "using loaded model %r. (Per-request override is honored only by "
        "RemoteHTTPBackend.)",
        self.backend_name,
        requested,
        self.model_name,
    )

@property
def loaded(self) -> bool:
Expand All @@ -90,6 +113,7 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
"""Run inference on video frames.

Expand All @@ -99,6 +123,10 @@ def generate(
response_format: OpenAI-compatible structured-output spec
(e.g. ``{"type": "json_schema", "json_schema": {...}}``).
Honored by remote backends; ignored by local backends.
model: Per-request override of the backend's configured model.
Honored by ``RemoteHTTPBackend`` (passed through to the
upstream chat.completions call); local backends cannot swap
models per-request and log a one-shot warning, then ignore.

Returns:
GenerationResult with text and metrics.
Expand All @@ -115,6 +143,7 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator[StreamChunk, None, None]:
"""Stream inference token by token."""
...
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/backends/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
del response_format # remote-only spec; ignored by compressed local backend
self._warn_model_override_once(model)
tic = time.perf_counter()
y, prompt_cache, prompt_token_count = self._custom_prefill(
frames,
Expand Down Expand Up @@ -215,9 +217,11 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator[StreamChunk, None, None]:
"""Real token-by-token streaming with compressed prefill."""
del response_format
self._warn_model_override_once(model)
y, prompt_cache, prompt_token_count = self._custom_prefill(
frames,
prompt,
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/backends/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,12 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
# response_format is a remote-only structured-output spec; local
# MLX inference doesn't honor it — ignored.
del response_format
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)
input_ids = kwargs.pop("input_ids")
pixel_values = kwargs.pop("pixel_values")
Expand All @@ -525,8 +527,10 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator[StreamChunk, None, None]:
del response_format
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)
input_ids = kwargs.pop("input_ids")
pixel_values = kwargs.pop("pixel_values")
Expand Down
9 changes: 7 additions & 2 deletions src/trio_core/backends/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
t0 = time.monotonic()
uris = _frames_to_data_uris(frames)
Expand All @@ -118,8 +119,9 @@ def generate(
if response_format is not None:
extra_kwargs["response_format"] = response_format

effective_model = model or self._remote_model
response = self._client.chat.completions.create(
model=self._remote_model,
model=effective_model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
Expand All @@ -140,12 +142,13 @@ def generate(
gen_tps = completion_tokens / max(elapsed, 1e-9)

logger.info(
"[Remote] generate: %d frames, %d+%d tokens, %.1f tps, %.0fms",
"[Remote] generate: %d frames, %d+%d tokens, %.1f tps, %.0fms (model=%s)",
frames.shape[0],
prompt_tokens,
completion_tokens,
gen_tps,
elapsed * 1000,
effective_model,
)

return GenerationResult(
Expand All @@ -166,6 +169,7 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator[StreamChunk, None, None]:
"""Stream from remote VLM API.

Expand All @@ -180,6 +184,7 @@ def stream_generate(
temperature=temperature,
top_p=top_p,
response_format=response_format,
model=model,
)
yield StreamChunk(
text=result.text,
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/backends/tome.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,11 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
"""Run inference with ToMe-compressed vision tokens."""
del response_format # remote-only spec; ignored by ToMe local backend
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)

input_ids = kwargs.pop("input_ids")
Expand Down Expand Up @@ -263,8 +265,10 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator:
del response_format
self._warn_model_override_once(model)
formatted, kwargs = self._prepare(frames, prompt)

input_ids = kwargs.pop("input_ids")
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/backends/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ def generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> GenerationResult:
import torch

del response_format # remote-only spec; ignored by local transformers backend
self._warn_model_override_once(model)
inputs = self._prepare(frames, prompt)

t0 = time.monotonic()
Expand Down Expand Up @@ -99,12 +101,14 @@ def stream_generate(
temperature: float = 0.0,
top_p: float = 1.0,
response_format: dict | None = None,
model: str | None = None,
) -> Generator[StreamChunk, None, None]:
import threading

from transformers import TextIteratorStreamer

del response_format
self._warn_model_override_once(model)
inputs = self._prepare(frames, prompt)

streamer = TextIteratorStreamer(self._processor, skip_prompt=True, skip_special_tokens=True)
Expand Down
4 changes: 4 additions & 0 deletions src/trio_core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def analyze_video(
max_tokens: int | None = None,
temperature: float | None = None,
response_format: dict | None = None,
model: str | None = None,
) -> VideoResult:
"""Analyze a video with the loaded VLM.

Expand Down Expand Up @@ -313,6 +314,7 @@ def analyze_video(
temperature=temperature,
top_p=self.config.top_p,
response_format=response_format,
model=model,
)

# ── Phase 3: Postprocess ─────────────────────────────────────────
Expand Down Expand Up @@ -430,6 +432,7 @@ def analyze_frame(
max_tokens: int | None = None,
temperature: float | None = None,
response_format: dict | None = None,
model: str | None = None,
) -> VideoResult:
"""Analyze a single frame (image). Convenience wrapper."""
if frame.ndim == 3:
Expand All @@ -442,6 +445,7 @@ def analyze_frame(
max_tokens=max_tokens,
temperature=temperature,
response_format=response_format,
model=model,
)

def health(self) -> dict[str, Any]:
Expand Down
44 changes: 44 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,50 @@ def test_generate_omits_response_format_when_none(self):
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "response_format" not in call_kwargs

def test_generate_uses_per_request_model_override(self):
    """A ``model`` passed to generate() is the one sent to the OpenAI call.

    This is the mechanism that lets the cortex client send VLM and
    segmentation requests to different upstream models without a server
    reload.
    """
    from trio_core.backends.remote import RemoteHTTPBackend

    # Canned upstream reply: a single choice with text and no usage block.
    choice = MagicMock()
    choice.message.content = "ok"
    fake_reply = MagicMock(choices=[choice], usage=None)

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = fake_reply

    # Backend configured with one model, then wired to the fake client.
    backend = RemoteHTTPBackend(url="https://api.example.com/v1", model="qwen-vl-plus")
    backend._client = fake_client
    backend._loaded = True

    single_frame = np.random.rand(1, 3, 64, 64).astype(np.float32)
    backend.generate(single_frame, "test", model="qwen3.6-plus")

    sent_kwargs = fake_client.chat.completions.create.call_args.kwargs
    assert sent_kwargs["model"] == "qwen3.6-plus"

def test_generate_falls_back_to_configured_model(self):
    """Without a ``model`` argument, the OpenAI call gets the configured model."""
    from trio_core.backends.remote import RemoteHTTPBackend

    # Canned upstream reply: a single choice with text and no usage block.
    choice = MagicMock()
    choice.message.content = "ok"
    fake_reply = MagicMock(choices=[choice], usage=None)

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = fake_reply

    backend = RemoteHTTPBackend(url="https://api.example.com/v1", model="qwen-vl-plus")
    backend._client = fake_client
    backend._loaded = True

    single_frame = np.random.rand(1, 3, 64, 64).astype(np.float32)
    backend.generate(single_frame, "test")  # deliberately no model override

    sent_kwargs = fake_client.chat.completions.create.call_args.kwargs
    assert sent_kwargs["model"] == "qwen-vl-plus"

def test_generate_handles_empty_content(self):
from trio_core.backends.remote import RemoteHTTPBackend

Expand Down
Loading