diff --git a/src/trio_core/api/routers/inference.py b/src/trio_core/api/routers/inference.py index 33d641b..620cb15 100644 --- a/src/trio_core/api/routers/inference.py +++ b/src/trio_core/api/routers/inference.py @@ -174,6 +174,26 @@ class DescribeRequest(BaseModel): "VLM backends. Ignored by local backends." ), ) + max_tokens: int | None = Field( + default=None, + ge=1, + le=16384, + description=( + "Maximum tokens to generate. None falls back to engine default. " + "Structured-output prompts (e.g. JSON-schema) need a higher value " + "than free-form description." + ), + ) + model: str | None = Field( + default=None, + max_length=128, + description=( + "Override the server-default VLM model for this request. Forwarded " + "to remote backends (the model name is passed straight to the " + "OpenAI-compatible chat.completions call). Local backends cannot " + "swap models per-request and will log a warning, then ignore." + ), + ) class DescribeResponse(BaseModel): @@ -209,6 +229,14 @@ class CropDescribeRequest(BaseModel): "VLM backends. Ignored by local backends." ), ) + model: str | None = Field( + default=None, + max_length=128, + description=( + "Override the server-default VLM model for this request. Forwarded " + "to remote backends; ignored (with a warning) by local backends." + ), + ) class CropDescribeResponse(BaseModel): @@ -556,7 +584,13 @@ def _sync_describe(): frame = _decode_image(req.image_b64) frame_chw = _frame_to_chw(frame) engine = _get_vlm() - return engine.analyze_frame(frame_chw, req.prompt, response_format=req.response_format) + return engine.analyze_frame( + frame_chw, + req.prompt, + max_tokens=req.max_tokens, + response_format=req.response_format, + model=req.model, + ) try: result = await loop.run_in_executor(None, _sync_describe) @@ -638,6 +672,7 @@ async def _crop_describe_inner(req: CropDescribeRequest): frame_chw, scene_prompt, response_format=req.response_format, + model=req.model, ), ) text = _strip_thinking(result.text or "") diff --git a/src/trio_core/backends/base.py b/src/trio_core/backends/base.py index d48a2d1..930a908 100644 --- a/src/trio_core/backends/base.py +++ b/src/trio_core/backends/base.py @@ -64,6 +64,29 @@ def __init__( # remote backends override with contextlib.nullcontext() so multiple # HTTP calls run in parallel. self._lock: AbstractContextManager[bool] = threading.Lock() + self._model_override_warned = False + + def _warn_model_override_once(self, requested: str | None) -> None: + """Local backends call this when a per-request ``model`` is set. + + Local GPU backends load one model at startup and cannot swap per + request, so the override is ignored. We log a single warning per + backend instance to avoid log spam in production paths that pass a + model name on every call. + """ + if not requested or requested == self.model_name: + return + if self._model_override_warned: + return + self._model_override_warned = True + logger.warning( + "%s backend cannot swap model per request; ignoring model=%r and " + "using loaded model %r. (Per-request override is honored only by " + "RemoteHTTPBackend.)", + self.backend_name, + requested, + self.model_name, + ) @property def loaded(self) -> bool: @@ -90,6 +113,7 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: """Run inference on video frames. @@ -99,6 +123,10 @@ def generate( response_format: OpenAI-compatible structured-output spec (e.g. ``{"type": "json_schema", "json_schema": {...}}``). Honored by remote backends; ignored by local backends. + model: Per-request override of the backend's configured model. + Honored by ``RemoteHTTPBackend`` (passed through to the + upstream chat.completions call); local backends cannot swap + models per-request and log a one-shot warning, then ignore. Returns: GenerationResult with text and metrics. @@ -115,6 +143,7 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator[StreamChunk, None, None]: """Stream inference token by token.""" ... diff --git a/src/trio_core/backends/compressed.py b/src/trio_core/backends/compressed.py index e280155..ac907f8 100644 --- a/src/trio_core/backends/compressed.py +++ b/src/trio_core/backends/compressed.py @@ -58,8 +58,10 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: del response_format # remote-only spec; ignored by compressed local backend + self._warn_model_override_once(model) tic = time.perf_counter() y, prompt_cache, prompt_token_count = self._custom_prefill( frames, @@ -215,9 +217,11 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator[StreamChunk, None, None]: """Real token-by-token streaming with compressed prefill.""" del response_format + self._warn_model_override_once(model) y, prompt_cache, prompt_token_count = self._custom_prefill( frames, prompt, diff --git a/src/trio_core/backends/mlx.py b/src/trio_core/backends/mlx.py index a6ad05e..4b606a4 100644 --- a/src/trio_core/backends/mlx.py +++ b/src/trio_core/backends/mlx.py @@ -498,10 +498,12 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: # response_format is a remote-only structured-output spec; local # MLX inference doesn't honor it — ignored. del response_format + self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) input_ids = kwargs.pop("input_ids") pixel_values = kwargs.pop("pixel_values") @@ -525,8 +527,10 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator[StreamChunk, None, None]: del response_format + self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) input_ids = kwargs.pop("input_ids") pixel_values = kwargs.pop("pixel_values") diff --git a/src/trio_core/backends/remote.py b/src/trio_core/backends/remote.py index 53a5695..3710a4d 100644 --- a/src/trio_core/backends/remote.py +++ b/src/trio_core/backends/remote.py @@ -96,6 +96,7 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: t0 = time.monotonic() uris = _frames_to_data_uris(frames) @@ -118,8 +119,9 @@ def generate( if response_format is not None: extra_kwargs["response_format"] = response_format + effective_model = model or self._remote_model response = self._client.chat.completions.create( - model=self._remote_model, + model=effective_model, messages=messages, max_tokens=max_tokens, temperature=temperature, @@ -140,12 +142,13 @@ def generate( gen_tps = completion_tokens / max(elapsed, 1e-9) logger.info( - "[Remote] generate: %d frames, %d+%d tokens, %.1f tps, %.0fms", + "[Remote] generate: %d frames, %d+%d tokens, %.1f tps, %.0fms (model=%s)", frames.shape[0], prompt_tokens, completion_tokens, gen_tps, elapsed * 1000, + effective_model, ) return GenerationResult( @@ -166,6 +169,7 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator[StreamChunk, None, None]: """Stream from remote VLM API. @@ -180,6 +184,7 @@ def stream_generate( temperature=temperature, top_p=top_p, response_format=response_format, + model=model, ) yield StreamChunk( text=result.text, diff --git a/src/trio_core/backends/tome.py b/src/trio_core/backends/tome.py index 9b4a836..335ee65 100644 --- a/src/trio_core/backends/tome.py +++ b/src/trio_core/backends/tome.py @@ -227,9 +227,11 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: """Run inference with ToMe-compressed vision tokens.""" del response_format # remote-only spec; ignored by ToMe local backend + self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) input_ids = kwargs.pop("input_ids") @@ -263,8 +265,10 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator: del response_format + self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) input_ids = kwargs.pop("input_ids") diff --git a/src/trio_core/backends/transformers.py b/src/trio_core/backends/transformers.py index 9acefc9..8968588 100644 --- a/src/trio_core/backends/transformers.py +++ b/src/trio_core/backends/transformers.py @@ -59,10 +59,12 @@ def generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> GenerationResult: import torch del response_format # remote-only spec; ignored by local transformers backend + self._warn_model_override_once(model) inputs = self._prepare(frames, prompt) t0 = time.monotonic() @@ -99,12 +101,14 @@ def stream_generate( temperature: float = 0.0, top_p: float = 1.0, response_format: dict | None = None, + model: str | None = None, ) -> Generator[StreamChunk, None, None]: import threading from transformers import TextIteratorStreamer del response_format + self._warn_model_override_once(model) inputs = self._prepare(frames, prompt) streamer = TextIteratorStreamer(self._processor, skip_prompt=True, skip_special_tokens=True) diff --git a/src/trio_core/engine.py b/src/trio_core/engine.py index 36726d7..cc6ef0c 100644 --- a/src/trio_core/engine.py +++ b/src/trio_core/engine.py @@ -242,6 +242,7 @@ def analyze_video( max_tokens: int | None = None, temperature: float | None = None, response_format: dict | None = None, + model: str | None = None, ) -> VideoResult: """Analyze a video with the loaded VLM. @@ -313,6 +314,7 @@ def analyze_video( temperature=temperature, top_p=self.config.top_p, response_format=response_format, + model=model, ) # ── Phase 3: Postprocess ───────────────────────────────────────── @@ -430,6 +432,7 @@ def analyze_frame( max_tokens: int | None = None, temperature: float | None = None, response_format: dict | None = None, + model: str | None = None, ) -> VideoResult: """Analyze a single frame (image). Convenience wrapper.""" if frame.ndim == 3: @@ -442,6 +445,7 @@ def analyze_frame( max_tokens=max_tokens, temperature=temperature, response_format=response_format, + model=model, ) def health(self) -> dict[str, Any]: diff --git a/tests/test_backends.py b/tests/test_backends.py index ed646bd..1525296 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -269,6 +269,50 @@ def test_generate_omits_response_format_when_none(self): call_kwargs = mock_client.chat.completions.create.call_args.kwargs assert "response_format" not in call_kwargs + def test_generate_uses_per_request_model_override(self): + """When ``model`` is passed to generate(), the OpenAI call uses it. + + This is what lets the cortex client route VLM and segmentation + requests to different upstream models without a server reload. + """ + from trio_core.backends.remote import RemoteHTTPBackend + + b = RemoteHTTPBackend(url="https://api.example.com/v1", model="qwen-vl-plus") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + mock_response.usage = None + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + b._client = mock_client + b._loaded = True + + frames = np.random.rand(1, 3, 64, 64).astype(np.float32) + b.generate(frames, "test", model="qwen3.6-plus") + + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert call_kwargs["model"] == "qwen3.6-plus" + + def test_generate_falls_back_to_configured_model(self): + """When ``model`` is omitted, the OpenAI call uses the configured one.""" + from trio_core.backends.remote import RemoteHTTPBackend + + b = RemoteHTTPBackend(url="https://api.example.com/v1", model="qwen-vl-plus") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + mock_response.usage = None + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + b._client = mock_client + b._loaded = True + + frames = np.random.rand(1, 3, 64, 64).astype(np.float32) + b.generate(frames, "test") # no model override + + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert call_kwargs["model"] == "qwen-vl-plus" + def test_generate_handles_empty_content(self): from trio_core.backends.remote import RemoteHTTPBackend