From 5898e8970b1ac6263b336a3b5f7cda4020e14a7d Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 30 Jul 2025 12:25:01 +0200 Subject: [PATCH 1/2] feat(models)!: based get model on google genai api - `askui` model now uses `gemini-2.5-flash` as default model falling back to original `askui` model (Inference API's VQA endpoint) if the Google GenAI API fails, e.g., because of missing support of schema or for unknown reason. For example, Google GenAI API does not support recursive schemas at the moment. - `askui/gemini-2.5-flash` and `askui/gemini-2.5-pro` are now supported as model choices. - We are using an AskUI hosted VertexAI proxy for the Google GenAI API to ensure compliance, e.g., only EU hosting. BREAKING CHANGE: - The `askui`/default model for `AgentBase.get()` changed. --- README.md | 6 +- pdm.lock | 88 ++++++++++++++++++++- pyproject.toml | 1 + src/askui/models/askui/get_model.py | 72 +++++++++++++++++ src/askui/models/askui/google_genai_api.py | 92 ++++++++++++++++++++++ src/askui/models/model_router.py | 17 +++- src/askui/models/models.py | 4 + src/askui/utils/image_utils.py | 10 +++ tests/e2e/agent/conftest.py | 7 +- tests/e2e/agent/test_get.py | 19 +++-- 10 files changed, 301 insertions(+), 15 deletions(-) create mode 100644 src/askui/models/askui/get_model.py create mode 100644 src/askui/models/askui/google_genai_api.py diff --git a/README.md b/README.md index 8fcf51f6..297f0e0a 100644 --- a/README.md +++ b/README.md @@ -731,10 +731,8 @@ with VisionAgent() as agent: ``` **⚠️ Limitations:** -- Not all models support response schemas or all kinds of properties that a response schema can have at the moment -- Default values are not supported, e.g., `url: str = "github.com"` or `url: str | None = None`. This includes `default_factory` - and `default` args of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or - `url: str = Field(default_factory=lambda: "github.com")`. +- The support for response schemas varies among models. Currently, the `askui` model provides best support for response schemas + as we try different models under the hood with your schema to see which one works best. ## What is AskUI Vision Agent? diff --git a/pdm.lock b/pdm.lock index 6e92459a..94ccb76e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "chat", "dev", "mcp", "pynput", "test", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1798fa81b5af10196977271587aae133fd763fe780218adcece9662c41e65e32" +content_hash = "sha256:02ec481499e740ea23f85f2490af2f9670ce24ac16614a1dc29c2f5db68fdba0" [[metadata.targets]] requires_python = ">=3.10" @@ -144,6 +144,17 @@ files = [ {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, ] +[[package]] +name = "cachetools" +version = "5.5.2" +requires_python = ">=3.7" +summary = "Extensible memoizing collections and decorators" +groups = ["default"] +files = [ + {file = "cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a"}, + {file = "cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4"}, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -689,6 +700,42 @@ files = [ {file = "genson-1.3.0.tar.gz", hash = "sha256:e02db9ac2e3fd29e65b5286f7135762e2cd8a986537c075b06fc5f1517308e37"}, ] +[[package]] +name = "google-auth" +version = "2.40.3" +requires_python = ">=3.7" +summary = "Google Authentication Library" +groups = ["default"] +dependencies = [ + "cachetools<6.0,>=2.0.0", + "pyasn1-modules>=0.2.1", + "rsa<5,>=3.1.4", +] +files = [ + {file = "google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca"}, + {file = "google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77"}, +] + +[[package]] +name = "google-genai" +version = "1.20.0" +requires_python = ">=3.9" +summary = "GenAI Python SDK" +groups = ["default"] +dependencies = [ + "anyio<5.0.0,>=4.8.0", + "google-auth<3.0.0,>=2.14.1", + "httpx<1.0.0,>=0.28.1", + "pydantic<3.0.0,>=2.0.0", + "requests<3.0.0,>=2.28.1", + "typing-extensions<5.0.0,>=4.11.0", + "websockets<15.1.0,>=13.0.0", +] +files = [ + {file = "google_genai-1.20.0-py3-none-any.whl", hash = "sha256:ccd61d6ebcb14f5c778b817b8010e3955ae4f6ddfeaabf65f42f6d5e3e5a8125"}, + {file = "google_genai-1.20.0.tar.gz", hash = "sha256:dccca78f765233844b1bd4f1f7a2237d9a76fe6038cf9aa72c0cd991e3c107b5"}, +] + [[package]] name = "gradio-client" version = "1.8.0" @@ -1512,6 +1559,31 @@ files = [ {file = "py_machineid-0.7.0-py3-none-any.whl", hash = "sha256:3dacc322b0511383d79f1e817a2710b19bcfb820a4c7cea34aaa329775fef468"}, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +requires_python = ">=3.8" +summary = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +groups = ["default"] +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +requires_python = ">=3.8" +summary = "A collection of ASN.1-based protocols modules" +groups = ["default"] +dependencies = [ + "pyasn1<0.7.0,>=0.6.1", +] +files = [ + {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"}, + {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -2212,6 +2284,20 @@ files = [ {file = "rpds_py-0.26.0.tar.gz", hash = "sha256:20dae58a859b0906f0685642e591056f1e787f3a8b39c8e8749a45dc7d26bdb0"}, ] +[[package]] +name = "rsa" +version = "4.9.1" +requires_python = "<4,>=3.6" +summary = "Pure-Python RSA implementation" +groups = ["default"] +dependencies = [ + "pyasn1>=0.1.3", +] +files = [ + {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"}, + {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"}, +] + [[package]] name = "ruff" version = "0.11.3" diff --git a/pyproject.toml b/pyproject.toml index c699f5f1..f93a3d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "tenacity>=9.1.2", "jsonref>=1.1.0", "protobuf>=6.31.1", + "google-genai>=1.20.0", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/askui/models/askui/get_model.py b/src/askui/models/askui/get_model.py new file mode 100644 index 00000000..ef753caf --- /dev/null +++ b/src/askui/models/askui/get_model.py @@ -0,0 +1,72 @@ +from typing import Type + +from google.genai.errors import ClientError +from typing_extensions import override + +from askui.logger import logger +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi +from askui.models.askui.inference_api import AskUiInferenceApi +from askui.models.exceptions import QueryNoResponseError, QueryUnexpectedResponseError +from askui.models.models import GetModel +from askui.models.types.response_schemas import ResponseSchema +from askui.utils.image_utils import ImageSource + + +class AskUiGetModel(GetModel): + """A GetModel implementation that is supposed to be as comprehensive and + powerful as possible using the available AskUi models. + + This model first attempts to use the Google GenAI API for information extraction. + If the Google GenAI API fails (e.g., no response, unexpected response, or other + errors), it falls back to using the AskUI Inference API. + + Args: + google_genai_api (AskUiGoogleGenAiApi): The Google GenAI API instance to use + as primary. + inference_api (AskUiInferenceApi): The Inference API instance to use as + fallback. + """ + + def __init__( + self, + google_genai_api: AskUiGoogleGenAiApi, + inference_api: AskUiInferenceApi, + ) -> None: + self._google_genai_api = google_genai_api + self._inference_api = inference_api + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + try: + logger.debug("Attempting to use Google GenAI API") + return self._google_genai_api.get( + query=query, + image=image, + response_schema=response_schema, + model_choice=model_choice, + ) + except ( + ClientError, + QueryNoResponseError, + QueryUnexpectedResponseError, + NotImplementedError, + ) as e: + if isinstance(e, ClientError) and e.code != 400: + raise + logger.debug( + f"Google GenAI API failed with error that may not occur with other " + f"models/apis: {e}" + ". Falling back to Inference API..." + ) + return self._inference_api.get( + query=query, + image=image, + response_schema=response_schema, + model_choice=model_choice, + ) diff --git a/src/askui/models/askui/google_genai_api.py b/src/askui/models/askui/google_genai_api.py new file mode 100644 index 00000000..f3538da5 --- /dev/null +++ b/src/askui/models/askui/google_genai_api.py @@ -0,0 +1,92 @@ +import json as json_lib +from typing import Type + +import google.genai as genai +from google.genai import types as genai_types +from pydantic import ValidationError +from typing_extensions import override + +from askui.logger import logger +from askui.models.askui.inference_api import AskUiInferenceApiSettings +from askui.models.exceptions import QueryNoResponseError, QueryUnexpectedResponseError +from askui.models.models import GetModel, ModelName +from askui.models.shared.prompts import SYSTEM_PROMPT_GET +from askui.models.types.response_schemas import ResponseSchema, to_response_schema +from askui.utils.image_utils import ImageSource + +ASKUI_MODEL_CHOICE_PREFIX = "askui/" +ASKUI_MODEL_CHOICE_PREFIX_LEN = len(ASKUI_MODEL_CHOICE_PREFIX) + + +def _extract_model_id(model_choice: str) -> str: + if model_choice == ModelName.ASKUI: + return ModelName.GEMINI__2_5__FLASH + if model_choice.startswith(ASKUI_MODEL_CHOICE_PREFIX): + return model_choice[ASKUI_MODEL_CHOICE_PREFIX_LEN:] + return model_choice + + +class AskUiGoogleGenAiApi(GetModel): + def __init__(self, settings: AskUiInferenceApiSettings | None = None) -> None: + self._settings = settings or AskUiInferenceApiSettings() + self._client = genai.Client( + vertexai=True, + api_key="Necessary", + http_options=genai_types.HttpOptions( + base_url=f"{self._settings.base_url}/proxy/vertexai", + headers={ + "Authorization": self._settings.authorization_header, + }, + ), + ) + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + try: + _response_schema = to_response_schema(response_schema) + json_schema = _response_schema.model_json_schema() + logger.debug(f"json_schema:\n{json_lib.dumps(json_schema)}") + content = genai_types.Content( + parts=[ + genai_types.Part.from_bytes( + data=image.to_bytes(), + mime_type="image/png", + ), + genai_types.Part.from_text(text=query), + ], + role="user", + ) + generate_content_response = self._client.models.generate_content( + model=f"models/{_extract_model_id(model_choice)}", + contents=content, + config={ + "response_mime_type": "application/json", + "response_schema": _response_schema, + "system_instruction": SYSTEM_PROMPT_GET, + }, + ) + json_str = generate_content_response.text + if json_str is None: + raise QueryNoResponseError( + message="No response from the model", query=query + ) + try: + return _response_schema.model_validate_json(json_str).root + except ValidationError as e: + error_message = str(e.errors()) + raise QueryUnexpectedResponseError( + message=f"Unexpected response from the model: {error_message}", + query=query, + response=json_str, + ) from e + except RecursionError as e: + error_message = ( + "Recursive response schemas are not supported by AskUiGoogleGenAiApi" + ) + raise NotImplementedError(error_message) from e diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 3429ef0a..889ae0fa 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -7,6 +7,8 @@ from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer from askui.models.anthropic.messages_api import AnthropicMessagesApi from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.askui.get_model import AskUiGetModel +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi from askui.models.askui.model_router import AskUiModelRouter from askui.models.exceptions import ModelNotFoundError, ModelTypeMismatchError from askui.models.huggingface.spaces_api import HFSpacesHandler @@ -57,6 +59,10 @@ def anthropic_facade() -> ModelFacade: locate_model=messages_api, ) + @functools.cache + def askui_google_genai_api() -> AskUiGoogleGenAiApi: + return AskUiGoogleGenAiApi() + @functools.cache def askui_inference_api() -> AskUiInferenceApi: return AskUiInferenceApi( @@ -72,6 +78,13 @@ def askui_model_router() -> AskUiModelRouter: inference_api=askui_inference_api(), ) + @functools.cache + def askui_get_model() -> AskUiGetModel: + return AskUiGetModel( + google_genai_api=askui_google_genai_api(), + inference_api=askui_inference_api(), + ) + @functools.cache def askui_facade() -> ModelFacade: computer_agent = Agent( @@ -80,7 +93,7 @@ def askui_facade() -> ModelFacade: ) return ModelFacade( act_model=computer_agent, - get_model=askui_inference_api(), + get_model=askui_get_model(), locate_model=askui_model_router(), ) @@ -93,6 +106,8 @@ def hf_spaces_handler() -> HFSpacesHandler: return { ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: anthropic_facade, ModelName.ASKUI: askui_facade, + ModelName.ASKUI__GEMINI__2_5__FLASH: askui_google_genai_api, + ModelName.ASKUI__GEMINI__2_5__PRO: askui_google_genai_api, ModelName.ASKUI__AI_ELEMENT: askui_model_router, ModelName.ASKUI__COMBO: askui_model_router, ModelName.ASKUI__OCR: askui_model_router, diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 5f1861f1..5bd65ddc 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -27,11 +27,15 @@ class ModelName(str, Enum): ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022" ASKUI = "askui" + ASKUI__GEMINI__2_5__FLASH = "askui/gemini-2.5-flash" + ASKUI__GEMINI__2_5__PRO = "askui/gemini-2.5-pro" ASKUI__AI_ELEMENT = "askui-ai-element" ASKUI__COMBO = "askui-combo" ASKUI__OCR = "askui-ocr" ASKUI__PTA = "askui-pta" CLAUDE__SONNET__4__20250514 = "claude-sonnet-4-20250514" + GEMINI__2_5__FLASH = "gemini-2.5-flash" + GEMINI__2_5__PRO = "gemini-2.5-pro" HF__SPACES__ASKUI__PTA_1 = "AskUI/PTA-1" HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B = "OS-Copilot/OS-Atlas-Base-7B" HF__SPACES__QWEN__QWEN2_VL_2B_INSTRUCT = "Qwen/Qwen2-VL-2B-Instruct" diff --git a/src/askui/utils/image_utils.py b/src/askui/utils/image_utils.py index d2ca53ea..bfefc4ac 100644 --- a/src/askui/utils/image_utils.py +++ b/src/askui/utils/image_utils.py @@ -410,6 +410,16 @@ def to_base64(self) -> str: """ return image_to_base64(image=self.root) + def to_bytes(self) -> bytes: + """Convert the image to bytes. + + Returns: + bytes: The image as bytes. + """ + img_byte_arr = io.BytesIO() + self.root.save(img_byte_arr, format="PNG") + return img_byte_arr.getvalue() + __all__ = [ "load_image", diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index fbda3329..bcc27d37 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -10,6 +10,8 @@ from askui.agent import VisionAgent from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.askui.get_model import AskUiGetModel +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi from askui.models.askui.inference_api import ( AskUiInferenceApi, AskUiInferenceApiSettings, @@ -72,7 +74,10 @@ def askui_facade( ) return ModelFacade( act_model=agent, - get_model=askui_inference_api, + get_model=AskUiGetModel( + google_genai_api=AskUiGoogleGenAiApi(), + inference_api=askui_inference_api, + ), locate_model=AskUiModelRouter(inference_api=askui_inference_api), ) diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index e174a610..679eb5a2 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -30,6 +30,8 @@ class BrowserContextResponse(ResponseSchemaBase): [ None, ModelName.ASKUI, + ModelName.ASKUI__GEMINI__2_5__FLASH, + ModelName.ASKUI__GEMINI__2_5__PRO, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ModelName.CLAUDE__SONNET__4__20250514, ], @@ -100,17 +102,18 @@ class OptionalUrlResponse(ResponseSchemaBase): url: str = "github.com" -def test_get_with_response_schema_with_default_value_with_askui_model_raises( +def test_get_with_response_schema_with_default_value( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, ) -> None: - with pytest.raises(Exception): # noqa: B017 - vision_agent.get( - "What is the current url shown in the url bar?", - image=github_login_screenshot, - response_schema=OptionalUrlResponse, - model=ModelName.ASKUI, - ) + response = vision_agent.get( + "What is the current url shown in the url bar?", + image=github_login_screenshot, + response_schema=OptionalUrlResponse, + model=ModelName.ASKUI, + ) + assert isinstance(response, OptionalUrlResponse) + assert "github.com" in response.url @pytest.mark.parametrize("model", [None, ModelName.ASKUI]) From 2dd9f545445fdbccfb9afaec865088183742bae7 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 30 Jul 2025 16:44:03 +0200 Subject: [PATCH 2/2] refactor(models): make it more obvious that we are using a dummy api key to google genai api --- src/askui/models/askui/google_genai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/askui/models/askui/google_genai_api.py b/src/askui/models/askui/google_genai_api.py index f3538da5..27af00df 100644 --- a/src/askui/models/askui/google_genai_api.py +++ b/src/askui/models/askui/google_genai_api.py @@ -31,7 +31,7 @@ def __init__(self, settings: AskUiInferenceApiSettings | None = None) -> None: self._settings = settings or AskUiInferenceApiSettings() self._client = genai.Client( vertexai=True, - api_key="Necessary", + api_key="DummyValueRequiredByGenaiClient", http_options=genai_types.HttpOptions( base_url=f"{self._settings.base_url}/proxy/vertexai", headers={