diff --git a/README.md b/README.md index 8fcf51f6..297f0e0a 100644 --- a/README.md +++ b/README.md @@ -731,10 +731,8 @@ with VisionAgent() as agent: ``` **⚠️ Limitations:** -- Not all models support response schemas or all kinds of properties that a response schema can have at the moment -- Default values are not supported, e.g., `url: str = "github.com"` or `url: str | None = None`. This includes `default_factory` - and `default` args of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or - `url: str = Field(default_factory=lambda: "github.com")`. +- The support for response schemas varies among models. Currently, the `askui` model provides best support for response schemas + as we try different models under the hood with your schema to see which one works best. ## What is AskUI Vision Agent? diff --git a/pdm.lock b/pdm.lock index 6e92459a..94ccb76e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "android", "chat", "dev", "mcp", "pynput", "test", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1798fa81b5af10196977271587aae133fd763fe780218adcece9662c41e65e32" +content_hash = "sha256:02ec481499e740ea23f85f2490af2f9670ce24ac16614a1dc29c2f5db68fdba0" [[metadata.targets]] requires_python = ">=3.10" @@ -144,6 +144,17 @@ files = [ {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, ] +[[package]] +name = "cachetools" +version = "5.5.2" +requires_python = ">=3.7" +summary = "Extensible memoizing collections and decorators" +groups = ["default"] +files = [ + {file = "cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a"}, + {file = "cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4"}, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -689,6 +700,42 @@ files = [ {file = "genson-1.3.0.tar.gz", hash = "sha256:e02db9ac2e3fd29e65b5286f7135762e2cd8a986537c075b06fc5f1517308e37"}, ] +[[package]] +name = "google-auth" +version = "2.40.3" +requires_python = ">=3.7" +summary = "Google Authentication Library" +groups = ["default"] +dependencies = [ + "cachetools<6.0,>=2.0.0", + "pyasn1-modules>=0.2.1", + "rsa<5,>=3.1.4", +] +files = [ + {file = "google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca"}, + {file = "google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77"}, +] + +[[package]] +name = "google-genai" +version = "1.20.0" +requires_python = ">=3.9" +summary = "GenAI Python SDK" +groups = ["default"] +dependencies = [ + "anyio<5.0.0,>=4.8.0", + "google-auth<3.0.0,>=2.14.1", + "httpx<1.0.0,>=0.28.1", + "pydantic<3.0.0,>=2.0.0", + "requests<3.0.0,>=2.28.1", + "typing-extensions<5.0.0,>=4.11.0", + "websockets<15.1.0,>=13.0.0", +] +files = [ + {file = "google_genai-1.20.0-py3-none-any.whl", hash = "sha256:ccd61d6ebcb14f5c778b817b8010e3955ae4f6ddfeaabf65f42f6d5e3e5a8125"}, + {file = "google_genai-1.20.0.tar.gz", hash = "sha256:dccca78f765233844b1bd4f1f7a2237d9a76fe6038cf9aa72c0cd991e3c107b5"}, +] + [[package]] name = "gradio-client" version = "1.8.0" @@ -1512,6 +1559,31 @@ files = [ {file = "py_machineid-0.7.0-py3-none-any.whl", hash = "sha256:3dacc322b0511383d79f1e817a2710b19bcfb820a4c7cea34aaa329775fef468"}, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +requires_python = ">=3.8" +summary = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +groups = ["default"] +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +requires_python = ">=3.8" +summary = "A collection of ASN.1-based protocols modules" +groups = ["default"] +dependencies = [ + "pyasn1<0.7.0,>=0.6.1", +] +files = [ + {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"}, + {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -2212,6 +2284,20 @@ files = [ {file = "rpds_py-0.26.0.tar.gz", hash = "sha256:20dae58a859b0906f0685642e591056f1e787f3a8b39c8e8749a45dc7d26bdb0"}, ] +[[package]] +name = "rsa" +version = "4.9.1" +requires_python = "<4,>=3.6" +summary = "Pure-Python RSA implementation" +groups = ["default"] +dependencies = [ + "pyasn1>=0.1.3", +] +files = [ + {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"}, + {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"}, +] + [[package]] name = "ruff" version = "0.11.3" diff --git a/pyproject.toml b/pyproject.toml index c699f5f1..f93a3d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "tenacity>=9.1.2", "jsonref>=1.1.0", "protobuf>=6.31.1", + "google-genai>=1.20.0", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/askui/models/askui/get_model.py b/src/askui/models/askui/get_model.py new file mode 100644 index 00000000..ef753caf --- /dev/null +++ b/src/askui/models/askui/get_model.py @@ -0,0 +1,72 @@ +from typing import Type + +from google.genai.errors import ClientError +from typing_extensions import override + +from askui.logger import logger +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi +from askui.models.askui.inference_api import AskUiInferenceApi +from askui.models.exceptions import QueryNoResponseError, QueryUnexpectedResponseError +from askui.models.models import GetModel +from askui.models.types.response_schemas import ResponseSchema +from askui.utils.image_utils import ImageSource + + +class AskUiGetModel(GetModel): + """A GetModel implementation that is supposed to be as comprehensive and + powerful as possible using the available AskUi models. + + This model first attempts to use the Google GenAI API for information extraction. + If the Google GenAI API fails (e.g., no response, unexpected response, or other + errors), it falls back to using the AskUI Inference API. + + Args: + google_genai_api (AskUiGoogleGenAiApi): The Google GenAI API instance to use + as primary. + inference_api (AskUiInferenceApi): The Inference API instance to use as + fallback. + """ + + def __init__( + self, + google_genai_api: AskUiGoogleGenAiApi, + inference_api: AskUiInferenceApi, + ) -> None: + self._google_genai_api = google_genai_api + self._inference_api = inference_api + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + try: + logger.debug("Attempting to use Google GenAI API") + return self._google_genai_api.get( + query=query, + image=image, + response_schema=response_schema, + model_choice=model_choice, + ) + except ( + ClientError, + QueryNoResponseError, + QueryUnexpectedResponseError, + NotImplementedError, + ) as e: + if isinstance(e, ClientError) and e.code != 400: + raise + logger.debug( + f"Google GenAI API failed with error that may not occur with other " + f"models/apis: {e}" + ". Falling back to Inference API..." + ) + return self._inference_api.get( + query=query, + image=image, + response_schema=response_schema, + model_choice=model_choice, + ) diff --git a/src/askui/models/askui/google_genai_api.py b/src/askui/models/askui/google_genai_api.py new file mode 100644 index 00000000..27af00df --- /dev/null +++ b/src/askui/models/askui/google_genai_api.py @@ -0,0 +1,92 @@ +import json as json_lib +from typing import Type + +import google.genai as genai +from google.genai import types as genai_types +from pydantic import ValidationError +from typing_extensions import override + +from askui.logger import logger +from askui.models.askui.inference_api import AskUiInferenceApiSettings +from askui.models.exceptions import QueryNoResponseError, QueryUnexpectedResponseError +from askui.models.models import GetModel, ModelName +from askui.models.shared.prompts import SYSTEM_PROMPT_GET +from askui.models.types.response_schemas import ResponseSchema, to_response_schema +from askui.utils.image_utils import ImageSource + +ASKUI_MODEL_CHOICE_PREFIX = "askui/" +ASKUI_MODEL_CHOICE_PREFIX_LEN = len(ASKUI_MODEL_CHOICE_PREFIX) + + +def _extract_model_id(model_choice: str) -> str: + if model_choice == ModelName.ASKUI: + return ModelName.GEMINI__2_5__FLASH + if model_choice.startswith(ASKUI_MODEL_CHOICE_PREFIX): + return model_choice[ASKUI_MODEL_CHOICE_PREFIX_LEN:] + return model_choice + + +class AskUiGoogleGenAiApi(GetModel): + def __init__(self, settings: AskUiInferenceApiSettings | None = None) -> None: + self._settings = settings or AskUiInferenceApiSettings() + self._client = genai.Client( + vertexai=True, + api_key="DummyValueRequiredByGenaiClient", + http_options=genai_types.HttpOptions( + base_url=f"{self._settings.base_url}/proxy/vertexai", + headers={ + "Authorization": self._settings.authorization_header, + }, + ), + ) + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + try: + _response_schema = to_response_schema(response_schema) + json_schema = _response_schema.model_json_schema() + logger.debug(f"json_schema:\n{json_lib.dumps(json_schema)}") + content = genai_types.Content( + parts=[ + genai_types.Part.from_bytes( + data=image.to_bytes(), + mime_type="image/png", + ), + genai_types.Part.from_text(text=query), + ], + role="user", + ) + generate_content_response = self._client.models.generate_content( + model=f"models/{_extract_model_id(model_choice)}", + contents=content, + config={ + "response_mime_type": "application/json", + "response_schema": _response_schema, + "system_instruction": SYSTEM_PROMPT_GET, + }, + ) + json_str = generate_content_response.text + if json_str is None: + raise QueryNoResponseError( + message="No response from the model", query=query + ) + try: + return _response_schema.model_validate_json(json_str).root + except ValidationError as e: + error_message = str(e.errors()) + raise QueryUnexpectedResponseError( + message=f"Unexpected response from the model: {error_message}", + query=query, + response=json_str, + ) from e + except RecursionError as e: + error_message = ( + "Recursive response schemas are not supported by AskUiGoogleGenAiApi" + ) + raise NotImplementedError(error_message) from e diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 3429ef0a..889ae0fa 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -7,6 +7,8 @@ from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer from askui.models.anthropic.messages_api import AnthropicMessagesApi from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.askui.get_model import AskUiGetModel +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi from askui.models.askui.model_router import AskUiModelRouter from askui.models.exceptions import ModelNotFoundError, ModelTypeMismatchError from askui.models.huggingface.spaces_api import HFSpacesHandler @@ -57,6 +59,10 @@ def anthropic_facade() -> ModelFacade: locate_model=messages_api, ) + @functools.cache + def askui_google_genai_api() -> AskUiGoogleGenAiApi: + return AskUiGoogleGenAiApi() + @functools.cache def askui_inference_api() -> AskUiInferenceApi: return AskUiInferenceApi( @@ -72,6 +78,13 @@ def askui_model_router() -> AskUiModelRouter: inference_api=askui_inference_api(), ) + @functools.cache + def askui_get_model() -> AskUiGetModel: + return AskUiGetModel( + google_genai_api=askui_google_genai_api(), + inference_api=askui_inference_api(), + ) + @functools.cache def askui_facade() -> ModelFacade: computer_agent = Agent( @@ -80,7 +93,7 @@ def askui_facade() -> ModelFacade: ) return ModelFacade( act_model=computer_agent, - get_model=askui_inference_api(), + get_model=askui_get_model(), locate_model=askui_model_router(), ) @@ -93,6 +106,8 @@ def hf_spaces_handler() -> HFSpacesHandler: return { ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: anthropic_facade, ModelName.ASKUI: askui_facade, + ModelName.ASKUI__GEMINI__2_5__FLASH: askui_google_genai_api, + ModelName.ASKUI__GEMINI__2_5__PRO: askui_google_genai_api, ModelName.ASKUI__AI_ELEMENT: askui_model_router, ModelName.ASKUI__COMBO: askui_model_router, ModelName.ASKUI__OCR: askui_model_router, diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 5f1861f1..5bd65ddc 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -27,11 +27,15 @@ class ModelName(str, Enum): ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022" ASKUI = "askui" + ASKUI__GEMINI__2_5__FLASH = "askui/gemini-2.5-flash" + ASKUI__GEMINI__2_5__PRO = "askui/gemini-2.5-pro" ASKUI__AI_ELEMENT = "askui-ai-element" ASKUI__COMBO = "askui-combo" ASKUI__OCR = "askui-ocr" ASKUI__PTA = "askui-pta" CLAUDE__SONNET__4__20250514 = "claude-sonnet-4-20250514" + GEMINI__2_5__FLASH = "gemini-2.5-flash" + GEMINI__2_5__PRO = "gemini-2.5-pro" HF__SPACES__ASKUI__PTA_1 = "AskUI/PTA-1" HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B = "OS-Copilot/OS-Atlas-Base-7B" HF__SPACES__QWEN__QWEN2_VL_2B_INSTRUCT = "Qwen/Qwen2-VL-2B-Instruct" diff --git a/src/askui/utils/image_utils.py b/src/askui/utils/image_utils.py index d2ca53ea..bfefc4ac 100644 --- a/src/askui/utils/image_utils.py +++ b/src/askui/utils/image_utils.py @@ -410,6 +410,16 @@ def to_base64(self) -> str: """ return image_to_base64(image=self.root) + def to_bytes(self) -> bytes: + """Convert the image to bytes. + + Returns: + bytes: The image as bytes. + """ + img_byte_arr = io.BytesIO() + self.root.save(img_byte_arr, format="PNG") + return img_byte_arr.getvalue() + __all__ = [ "load_image", diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index fbda3329..bcc27d37 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -10,6 +10,8 @@ from askui.agent import VisionAgent from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.askui.get_model import AskUiGetModel +from askui.models.askui.google_genai_api import AskUiGoogleGenAiApi from askui.models.askui.inference_api import ( AskUiInferenceApi, AskUiInferenceApiSettings, @@ -72,7 +74,10 @@ def askui_facade( ) return ModelFacade( act_model=agent, - get_model=askui_inference_api, + get_model=AskUiGetModel( + google_genai_api=AskUiGoogleGenAiApi(), + inference_api=askui_inference_api, + ), locate_model=AskUiModelRouter(inference_api=askui_inference_api), ) diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index e174a610..679eb5a2 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -30,6 +30,8 @@ class BrowserContextResponse(ResponseSchemaBase): [ None, ModelName.ASKUI, + ModelName.ASKUI__GEMINI__2_5__FLASH, + ModelName.ASKUI__GEMINI__2_5__PRO, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ModelName.CLAUDE__SONNET__4__20250514, ], @@ -100,17 +102,18 @@ class OptionalUrlResponse(ResponseSchemaBase): url: str = "github.com" -def test_get_with_response_schema_with_default_value_with_askui_model_raises( +def test_get_with_response_schema_with_default_value( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, ) -> None: - with pytest.raises(Exception): # noqa: B017 - vision_agent.get( - "What is the current url shown in the url bar?", - image=github_login_screenshot, - response_schema=OptionalUrlResponse, - model=ModelName.ASKUI, - ) + response = vision_agent.get( + "What is the current url shown in the url bar?", + image=github_login_screenshot, + response_schema=OptionalUrlResponse, + model=ModelName.ASKUI, + ) + assert isinstance(response, OptionalUrlResponse) + assert "github.com" in response.url @pytest.mark.parametrize("model", [None, ModelName.ASKUI])