|
from typing import Any

import httpx
from anthropic import NotGiven
from anthropic.types import ThinkingConfigEnabledParam
from anthropic.types.beta import (
    BetaTextBlockParam,
    BetaThinkingConfigParam,
    BetaToolChoiceParam,
    BetaToolUnionParam,
)
from pydantic import BaseModel, ConfigDict
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from typing_extensions import Literal, override
4 | 13 |
|
5 | 14 | from askui.models.askui.settings import AskUiComputerAgentSettings |
6 | 15 | from askui.models.shared.computer_agent import ComputerAgent |
|
10 | 19 |
|
11 | 20 | from ...logger import logger |
12 | 21 |
|
# Module-level singleton of the anthropic SDK's NotGiven sentinel; used as the
# default for optional request fields so "unset" is distinguishable from None.
NOT_GIVEN = NotGiven()
| 23 | + |
| 24 | + |
class RequestBody(BaseModel):
    """Request payload for the AskUI ``/act/inference`` endpoint.

    Mirrors the Anthropic beta messages API shape; inference is routed
    through the ``gcp_vertex`` provider.

    NOTE(review): callers serialize this with ``model_dump()`` (python mode).
    The previous inline dict dumped each message with ``mode="json"`` —
    confirm python-mode output of ``MessageParam`` stays JSON-encodable.
    """

    # NotGiven is a plain sentinel class from the anthropic SDK, not a
    # pydantic-aware type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    max_tokens: int
    messages: list[MessageParam]
    # Currently the only supported routing target.
    provider: Literal["gcp_vertex"] = "gcp_vertex"
    model: str
    tools: list[BetaToolUnionParam]
    betas: list[str]
    system: list[BetaTextBlockParam]
    thinking: BetaThinkingConfigParam | NotGiven = NOT_GIVEN
    tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN

    @override
    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model, omitting fields left at the ``NOT_GIVEN`` sentinel.

        ``NotGiven`` instances are not JSON-serializable; leaving them in the
        dump makes ``httpx``'s ``json=`` encoding raise ``TypeError`` whenever
        ``thinking`` or ``tool_choice`` is unset. Dropping them reproduces the
        anthropic SDK convention of simply not sending such fields.
        """
        dumped = super().model_dump(**kwargs)
        return {
            key: value
            for key, value in dumped.items()
            if not isinstance(value, NotGiven)
        }
| 36 | + |
13 | 37 |
|
14 | 38 | def is_retryable_error(exception: BaseException) -> bool: |
15 | 39 | """Check if the exception is a retryable error (status codes 429 or 529).""" |
@@ -47,20 +71,26 @@ def _create_message( |
47 | 71 | model_choice: str, # noqa: ARG002 |
48 | 72 | ) -> MessageParam: |
49 | 73 | try: |
50 | | - request_body = { |
51 | | - "max_tokens": self._settings.max_tokens, |
52 | | - "messages": [msg.model_dump(mode="json") for msg in messages], |
53 | | - "model": self._settings.model, |
54 | | - "tools": self._tool_collection.to_params(), |
55 | | - "betas": self._settings.betas, |
56 | | - "system": [self._system], |
57 | | - } |
| 74 | + request_body = RequestBody( |
| 75 | + max_tokens=self._settings.max_tokens, |
| 76 | + messages=messages, |
| 77 | + model=self._settings.model, |
| 78 | + tools=self._tool_collection.to_params(), |
| 79 | + betas=self._settings.betas, |
| 80 | + system=[self._system], |
| 81 | + tool_choice=self._settings.tool_choice, |
| 82 | + ) |
| 83 | + if self._settings.thinking: |
| 84 | + request_body.thinking = ThinkingConfigEnabledParam( |
| 85 | + budget_tokens=self._settings.thinking.budget_tokens, |
| 86 | + type="enabled", |
| 87 | + ) |
| 88 | + |
58 | 89 | response = self._client.post( |
59 | | - "/act/inference", json=request_body, timeout=300.0 |
| 90 | + "/act/inference", json=request_body.model_dump(), timeout=300.0 |
60 | 91 | ) |
61 | 92 | response.raise_for_status() |
62 | | - response_data = response.json() |
63 | | - return MessageParam.model_validate(response_data) |
| 93 | + return MessageParam.model_validate(response.json()) |
64 | 94 | except Exception as e: # noqa: BLE001 |
65 | 95 | if is_retryable_error(e): |
66 | 96 | logger.debug(e) |
|
0 commit comments