diff --git a/src/trio_core/api/routers/inference.py b/src/trio_core/api/routers/inference.py index 620cb15..75bf69d 100644 --- a/src/trio_core/api/routers/inference.py +++ b/src/trio_core/api/routers/inference.py @@ -194,6 +194,14 @@ class DescribeRequest(BaseModel): "swap models per-request and will log a warning, then ignore." ), ) + extra_body: dict | None = Field( + default=None, + description=( + "Backend-specific kwargs forwarded as the OpenAI SDK extra_body " + "(e.g. DashScope's enable_thinking). Honored by RemoteHTTPBackend; " + "ignored by local backends." + ), + ) class DescribeResponse(BaseModel): @@ -590,6 +598,7 @@ def _sync_describe(): max_tokens=req.max_tokens, response_format=req.response_format, model=req.model, + extra_body=req.extra_body, ) try: diff --git a/src/trio_core/backends/base.py b/src/trio_core/backends/base.py index 930a908..280b03f 100644 --- a/src/trio_core/backends/base.py +++ b/src/trio_core/backends/base.py @@ -114,6 +114,7 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: """Run inference on video frames. @@ -127,6 +128,9 @@ def generate( Honored by ``RemoteHTTPBackend`` (passed through to the upstream chat.completions call); local backends cannot swap models per-request and log a one-shot warning, then ignore. + extra_body: Backend-specific kwargs forwarded as the OpenAI + SDK ``extra_body`` (e.g. DashScope's ``enable_thinking``). + Honored by ``RemoteHTTPBackend``; ignored by local backends. Returns: GenerationResult with text and metrics. @@ -144,6 +148,7 @@ def stream_generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> Generator[StreamChunk, None, None]: """Stream inference token by token.""" ... diff --git a/src/trio_core/backends/compressed.py b/src/trio_core/backends/compressed.py index ac907f8..6ab9ad9 100644 --- a/src/trio_core/backends/compressed.py +++ b/src/trio_core/backends/compressed.py @@ -59,8 +59,9 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: - del response_format # remote-only spec; ignored by compressed local backend + del response_format, extra_body # remote-only specs; ignored by compressed local backend self._warn_model_override_once(model) tic = time.perf_counter() y, prompt_cache, prompt_token_count = self._custom_prefill( diff --git a/src/trio_core/backends/mlx.py b/src/trio_core/backends/mlx.py index 4b606a4..ea0acc5 100644 --- a/src/trio_core/backends/mlx.py +++ b/src/trio_core/backends/mlx.py @@ -499,10 +499,11 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: - # response_format is a remote-only structured-output spec; local - # MLX inference doesn't honor it — ignored. - del response_format + # response_format / extra_body are remote-only specs; local MLX + # inference doesn't honor them — ignored. + del response_format, extra_body self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) input_ids = kwargs.pop("input_ids") diff --git a/src/trio_core/backends/remote.py b/src/trio_core/backends/remote.py index 3710a4d..d954929 100644 --- a/src/trio_core/backends/remote.py +++ b/src/trio_core/backends/remote.py @@ -97,6 +97,7 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: t0 = time.monotonic() uris = _frames_to_data_uris(frames) @@ -118,6 +119,8 @@ def generate( extra_kwargs: dict[str, object] = {} if response_format is not None: extra_kwargs["response_format"] = response_format + if extra_body is not None: + extra_kwargs["extra_body"] = extra_body effective_model = model or self._remote_model response = self._client.chat.completions.create( @@ -170,6 +173,7 @@ def stream_generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> Generator[StreamChunk, None, None]: """Stream from remote VLM API. diff --git a/src/trio_core/backends/tome.py b/src/trio_core/backends/tome.py index 335ee65..fc12009 100644 --- a/src/trio_core/backends/tome.py +++ b/src/trio_core/backends/tome.py @@ -228,9 +228,10 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: """Run inference with ToMe-compressed vision tokens.""" - del response_format # remote-only spec; ignored by ToMe local backend + del response_format, extra_body # remote-only specs; ignored by ToMe local backend self._warn_model_override_once(model) formatted, kwargs = self._prepare(frames, prompt) diff --git a/src/trio_core/backends/transformers.py b/src/trio_core/backends/transformers.py index 8968588..fe2d794 100644 --- a/src/trio_core/backends/transformers.py +++ b/src/trio_core/backends/transformers.py @@ -60,10 +60,11 @@ def generate( top_p: float = 1.0, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> GenerationResult: import torch - del response_format # remote-only spec; ignored by local transformers backend + del response_format, extra_body # remote-only specs; ignored by local transformers backend self._warn_model_override_once(model) inputs = self._prepare(frames, prompt) diff --git a/src/trio_core/engine.py b/src/trio_core/engine.py index cc6ef0c..c7a2a46 100644 --- a/src/trio_core/engine.py +++ b/src/trio_core/engine.py @@ -243,6 +243,7 @@ def analyze_video( temperature: float | None = None, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> VideoResult: """Analyze a video with the loaded VLM. @@ -315,6 +316,7 @@ def analyze_video( top_p=self.config.top_p, response_format=response_format, model=model, + extra_body=extra_body, ) # ── Phase 3: Postprocess ───────────────────────────────────────── @@ -433,6 +435,7 @@ def analyze_frame( temperature: float | None = None, response_format: dict | None = None, model: str | None = None, + extra_body: dict | None = None, ) -> VideoResult: """Analyze a single frame (image). Convenience wrapper.""" if frame.ndim == 3: @@ -446,6 +449,7 @@ def analyze_frame( temperature=temperature, response_format=response_format, model=model, + extra_body=extra_body, ) def health(self) -> dict[str, Any]: