From 27e150cfaedb1620cf3b74e89af19b2bb8ef3da0 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 11 Jun 2025 13:18:46 +0200 Subject: [PATCH 1/3] docs(agent): update docs + tests with regards to nested/recursive response schemas supported now for AskUI model --- README.md | 83 +++++++++++++++++++--- src/askui/agent.py | 31 ++++++-- src/askui/models/types/response_schemas.py | 15 +++- tests/e2e/agent/test_get.py | 57 ++++++++++++--- 4 files changed, 163 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 34b28c15..5054fef6 100644 --- a/README.md +++ b/README.md @@ -640,23 +640,88 @@ with VisionAgent() as agent: For structured data extraction, use Pydantic models extending `ResponseSchemaBase`: ```python -from askui import ResponseSchemaBase +from askui import ResponseSchemaBase, VisionAgent +from PIL import Image class UserInfo(ResponseSchemaBase): username: str is_online: bool -# Get structured data -user_info = agent.get( - "What is the username and online status?", - response_schema=UserInfo -) -print(f"User {user_info.username} is {'online' if user_info.is_online else 'offline'}") +class UrlResponse(ResponseSchemaBase): + url: str + +class NestedResponse(ResponseSchemaBase): + nested: UrlResponse + +class LinkedListNode(ResponseSchemaBase): + value: str + next: "LinkedListNode | None" + +with VisionAgent() as agent: + # Get structured data + user_info = agent.get( + "What is the username and online status?", + response_schema=UserInfo + ) + print(f"User {user_info.username} is {'online' if user_info.is_online else 'offline'}") + + # Get URL as string + url = agent.get("What is the current url shown in the url bar?") + print(url) # e.g., "github.com/login" + + # Get URL as Pydantic model from image at (relative) path + response = agent.get( + "What is the current url shown in the url bar?", + response_schema=UrlResponse, + image="screenshot.png", + ) + print(response.url) + + # Get boolean response from PIL Image + is_login_page = agent.get( + "Is this a login page?", + response_schema=bool, + image=Image.open("screenshot.png"), + ) + print(is_login_page) + + # Get integer response + input_count = agent.get( + "How many input fields are visible on this page?", + response_schema=int, + ) + print(input_count) + + # Get float response + design_rating = agent.get( + "Rate the page design quality from 0 to 1", + response_schema=float, + ) + print(design_rating) + + # Get nested response + nested = agent.get( + "Extract the URL and its metadata from the page", + response_schema=NestedResponse, + ) + print(nested.nested.url) + + # Get recursive response + linked_list = agent.get( + "Extract the breadcrumb navigation as a linked list", + response_schema=LinkedListNode, + ) + current = linked_list + while current: + print(current.value) + current = current.next ``` **⚠️ Limitations:** -- Nested Pydantic schemas are not currently supported -- Response schema is currently only supported by "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set) +- Not all models support response schemas or all kinds of properties that a response schema can have at the moment +- Default values are not supported, e.g., `url: str = "github.com"` or `url: str | None = None`. This includes `default_factory` + and `default` args of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or + `url: str = Field(default_factory=lambda: "github.com")`. ## What is AskUI Vision Agent? diff --git a/src/askui/agent.py b/src/askui/agent.py index b59711e0..4217b26d 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -347,10 +347,6 @@ def get( Returns: ResponseSchema | str: The extracted information, `str` if no `response_schema` is provided. - Limitations: - - Nested Pydantic schemas are not currently supported - - Schema support is only available with "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set) at the moment - Example: ```python from askui import ResponseSchemaBase, VisionAgent @@ -359,6 +355,13 @@ def get( class UrlResponse(ResponseSchemaBase): url: str + class NestedResponse(ResponseSchemaBase): + nested: UrlResponse + + class LinkedListNode(ResponseSchemaBase): + value: str + next: "LinkedListNode | None" + with VisionAgent() as agent: # Get URL as string url = agent.get("What is the current url shown in the url bar?") @@ -377,18 +380,38 @@ class UrlResponse(ResponseSchemaBase): response_schema=bool, image=Image.open("screenshot.png"), ) + print(is_login_page) # Get integer response input_count = agent.get( "How many input fields are visible on this page?", response_schema=int, ) + print(input_count) # Get float response design_rating = agent.get( "Rate the page design quality from 0 to 1", response_schema=float, ) + print(design_rating) + + # Get nested response + nested = agent.get( + "Extract the URL and its metadata from the page", + response_schema=NestedResponse, + ) + print(nested.nested.url) + + # Get recursive response + linked_list = agent.get( + "Extract the breadcrumb navigation as a linked list", + response_schema=LinkedListNode, + ) + current = linked_list + while current: + print(current.value) + current = current.next ``` """ logger.debug("VisionAgent received instruction to get '%s'", query) diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py index 607770ae..56b363c8 100644 --- a/src/askui/models/types/response_schemas.py +++ b/src/askui/models/types/response_schemas.py @@ -11,6 +11,11 @@ class ResponseSchemaBase(BaseModel): on top so that it can be used with models to define the schema (type) of the data to be extracted. + **Important**: Default values are not supported, e.g., `url: str = "github.com"` or + `url: str | None = None`. This includes `default_factory` and `default` args + of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or + `url: str = Field(default_factory=lambda: "github.com")`. + Example: ```python class UrlResponse(ResponseSchemaBase): @@ -27,6 +32,12 @@ class UrlResponse(ResponseSchemaBase): description="The URL of the response. Should used `\"https\"` scheme.", examples=["https://www.example.com"] ) + + # To define recursive response schemas, you can use quotation marks around the + # type of the field, e.g., `next: "LinkedListNode | None"`. + class LinkedListNode(ResponseSchemaBase): + value: str + next: "LinkedListNode | None" ``` """ @@ -49,8 +60,8 @@ class UrlResponse(ResponseSchemaBase): - `int`: Integer responses - `float`: Floating point responses -Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be -passed to model(s). Also used for validating the responses of the model(s) used for +Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be +passed to model(s). Also used for validating the responses of the model(s) used for data extraction. """ diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index be3fbb26..6d41f3d8 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -2,6 +2,7 @@ import pytest from PIL import Image as PILImage +from pydantic import BaseModel from askui import ResponseSchemaBase, VisionAgent from askui.models import ModelName @@ -72,18 +73,36 @@ def test_get_with_model_composition_should_use_default_model( assert url in ["github.com/login", "https://github.com/login"] -@pytest.mark.skip( - "Skip for now as this pops up in our observability systems as a false positive" -) +class UrlResponseBaseModel(BaseModel): + url: str + + def test_get_with_response_schema_without_additional_properties_with_askui_model_raises( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, +) -> None: + with pytest.raises(Exception): # noqa: B017 + vision_agent.get( # type: ignore[type-var] + "What is the current url shown in the url bar?", + image=github_login_screenshot, + response_schema=UrlResponseBaseModel, + model=ModelName.ASKUI, + ) + + +class OptionalUrlResponse(ResponseSchemaBase): + url: str = "github.com" + + +def test_get_with_response_schema_with_default_value_with_askui_model_raises( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, ) -> None: with pytest.raises(Exception): # noqa: B017 vision_agent.get( "What is the current url shown in the url bar?", image=github_login_screenshot, - response_schema=UrlResponse, + response_schema=OptionalUrlResponse, model=ModelName.ASKUI, ) @@ -134,9 +153,6 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( @pytest.mark.parametrize("model", [ModelName.ASKUI]) -@pytest.mark.skip( - "Skip as there is currently a bug on the api side not supporting definitions used for nested schemas" -) def test_get_with_nested_and_inherited_response_schema( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, @@ -150,10 +166,35 @@ def test_get_with_nested_and_inherited_response_schema( ) assert isinstance(response, BrowserContextResponse) assert response.page_context.url in ["https://github.com/login", "github.com/login"] - assert "Github" in response.page_context.title + assert "GitHub" in response.page_context.title assert response.browser_type in ["chrome", "firefox", "edge", "safari"] +class LinkedListNode(ResponseSchemaBase): + value: str + next: "LinkedListNode | None" + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_recursive_response_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "Can you extract all segments (domain, path etc.) from the url as a linked list, " + "e.g. 'https://google.com/test' -> 'google.com->test->None'?", + image=github_login_screenshot, + response_schema=LinkedListNode, + model=model, + ) + assert isinstance(response, LinkedListNode) + assert response.value == "github.com" + assert response.next is not None + assert response.next.value == "login" + assert response.next.next is None + + @pytest.mark.parametrize("model", [ModelName.ASKUI]) def test_get_with_string_schema( vision_agent: VisionAgent, From 9b33b2e09a1fe576cdccb4abfe8ee55dbe6e16c6 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 12 Jun 2025 09:33:39 +0200 Subject: [PATCH 2/3] fix(agent): fix response type of `VisionAgent.get()` - was `ResponseSchemaBase` instead of class extending `ResponseSchemaBase` if class extending `ResponseSchemaBase` was passed --- src/askui/models/askui/inference_api.py | 8 ++-- src/askui/models/types/response_schemas.py | 55 +++++----------------- tests/e2e/agent/test_get.py | 54 ++++++++++++--------- tests/integration/test_custom_models.py | 8 +--- 4 files changed, 48 insertions(+), 77 deletions(-) diff --git a/src/askui/models/askui/inference_api.py b/src/askui/models/askui/inference_api.py index 6221a8f6..3ec933c0 100644 --- a/src/askui/models/askui/inference_api.py +++ b/src/askui/models/askui/inference_api.py @@ -2,7 +2,6 @@ from typing import Any, Type import requests -from pydantic import RootModel from typing_extensions import override from askui.locators.locators import Locator @@ -92,11 +91,10 @@ def get( "prompt": query, } _response_schema = to_response_schema(response_schema) - json["config"] = {"json_schema": _response_schema.model_json_schema()} + json_schema = _response_schema.model_json_schema() + json["config"] = {"json_schema": json_schema} logger.debug(f"json_schema:\n{json_lib.dumps(json['config']['json_schema'])}") content = self._request(endpoint="vqa/inference", json=json) response = content["data"]["response"] validated_response = _response_schema.model_validate(response) - if isinstance(validated_response, RootModel): - return validated_response.root - return validated_response + return validated_response.root diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py index 56b363c8..60651722 100644 --- a/src/askui/models/types/response_schemas.py +++ b/src/askui/models/types/response_schemas.py @@ -44,13 +44,10 @@ class LinkedListNode(ResponseSchemaBase): model_config = ConfigDict(extra="forbid") -String = RootModel[str] -Boolean = RootModel[bool] -Integer = RootModel[int] -Float = RootModel[float] - - -ResponseSchema = TypeVar("ResponseSchema", ResponseSchemaBase, str, bool, int, float) +ResponseSchema = TypeVar( + "ResponseSchema", + bound=ResponseSchemaBase | str | bool | int | float, +) """Type of the responses of data extracted, e.g., using `askui.VisionAgent.get()`. The following types are allowed: @@ -67,44 +64,14 @@ class LinkedListNode(ResponseSchemaBase): @overload -def to_response_schema(response_schema: None) -> Type[String]: ... -@overload -def to_response_schema(response_schema: Type[str]) -> Type[String]: ... -@overload -def to_response_schema(response_schema: Type[bool]) -> Type[Boolean]: ... -@overload -def to_response_schema(response_schema: Type[int]) -> Type[Integer]: ... -@overload -def to_response_schema(response_schema: Type[float]) -> Type[Float]: ... +def to_response_schema(response_schema: None) -> Type[RootModel[str]]: ... @overload def to_response_schema( - response_schema: Type[ResponseSchemaBase], -) -> Type[ResponseSchemaBase]: ... + response_schema: Type[ResponseSchema], +) -> Type[RootModel[ResponseSchema]]: ... def to_response_schema( - response_schema: Type[ResponseSchemaBase] - | Type[str] - | Type[bool] - | Type[int] - | Type[float] - | None = None, -) -> ( - Type[ResponseSchemaBase] - | Type[String] - | Type[Boolean] - | Type[Integer] - | Type[Float] -): + response_schema: Type[ResponseSchema] | None, +) -> Type[RootModel[str]] | Type[RootModel[ResponseSchema]]: if response_schema is None: - return String - if response_schema is str: - return String - if response_schema is bool: - return Boolean - if response_schema is int: - return Integer - if response_schema is float: - return Float - if issubclass(response_schema, ResponseSchemaBase): - return response_schema - error_msg = f"Invalid response schema type: {response_schema}" - raise ValueError(error_msg) + return RootModel[str] + return RootModel[response_schema] # type: ignore[valid-type] diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index 6d41f3d8..9a8d974e 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -2,7 +2,7 @@ import pytest from PIL import Image as PILImage -from pydantic import BaseModel +from pydantic import BaseModel, RootModel from askui import ResponseSchemaBase, VisionAgent from askui.models import ModelName @@ -82,10 +82,10 @@ def test_get_with_response_schema_without_additional_properties_with_askui_model github_login_screenshot: PILImage.Image, ) -> None: with pytest.raises(Exception): # noqa: B017 - vision_agent.get( # type: ignore[type-var] + vision_agent.get( "What is the current url shown in the url bar?", image=github_login_screenshot, - response_schema=UrlResponseBaseModel, + response_schema=UrlResponseBaseModel, # type: ignore[type-var] model=ModelName.ASKUI, ) @@ -107,22 +107,6 @@ def test_get_with_response_schema_with_default_value_with_askui_model_raises( ) -@pytest.mark.skip( - "Skip for now as this pops up in our observability systems as a false positive" -) -def test_get_with_response_schema_without_required_with_askui_model_raises( - vision_agent: VisionAgent, - github_login_screenshot: PILImage.Image, -) -> None: - with pytest.raises(Exception): # noqa: B017 - vision_agent.get( - "What is the current url shown in the url bar?", - image=github_login_screenshot, - response_schema=UrlResponse, - model=ModelName.ASKUI, - ) - - @pytest.mark.parametrize("model", [None, ModelName.ASKUI]) def test_get_with_response_schema( vision_agent: VisionAgent, @@ -192,7 +176,11 @@ def test_get_with_recursive_response_schema( assert response.value == "github.com" assert response.next is not None assert response.next.value == "login" - assert response.next.next is None + assert ( + response.next.next is None + or response.next.next.value == "" + and response.next.next.next is None + ) @pytest.mark.parametrize("model", [ModelName.ASKUI]) @@ -289,4 +277,28 @@ def test_get_with_basis_schema( model=model, ) assert isinstance(response, Basis) - assert response.answer != '"What is the display showing?"' + assert isinstance(response.answer, str) + + +class Answer(ResponseSchemaBase): + answer: str + + +class BasisWithNestedRootModel(ResponseSchemaBase): + answer: RootModel[Answer] + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_nested_root_model( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "What is the display showing?", + image=github_login_screenshot, + response_schema=BasisWithNestedRootModel, + model=model, + ) + assert isinstance(response, BasisWithNestedRootModel) + assert isinstance(response.answer.root.answer, str) diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 01db14fd..a64525ba 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -47,13 +47,7 @@ class SimpleGetModel(GetModel): def __init__(self, response: str | ResponseSchemaBase = "test response") -> None: self.queries: list[str] = [] self.images: list[ImageSource] = [] - self.schemas: list[ - Optional[type[ResponseSchemaBase]] - | Optional[type[str]] - | Optional[type[bool]] - | Optional[type[int]] - | Optional[type[float]] - ] = [] + self.schemas: list[Any] = [] self.model_choices: list[str] = [] self.response = response From 15065c189f65b6ca00874fb56bd2f3d0a046d6f1 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 12 Jun 2025 09:40:22 +0200 Subject: [PATCH 3/3] docs(agent): add examples of dumping response models --- README.md | 11 ++++++++++- src/askui/agent.py | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5054fef6..c27d5408 100644 --- a/README.md +++ b/README.md @@ -642,6 +642,7 @@ For structured data extraction, use Pydantic models extending `ResponseSchemaBas ```python from askui import ResponseSchemaBase, VisionAgent from PIL import Image +import json class UserInfo(ResponseSchemaBase): username: str @@ -675,7 +676,15 @@ with VisionAgent() as agent: response_schema=UrlResponse, image="screenshot.png", ) - print(response.url) + + # Dump whole model + print(response.model_dump_json(indent=2)) + # or + response_json_dict = response.model_dump(mode="json") + print(json.dumps(response_json_dict, indent=2)) + # or for regular dict + response_dict = response.model_dump() + print(response_dict["url"]) # Get boolean response from PIL Image is_login_page = agent.get( diff --git a/src/askui/agent.py b/src/askui/agent.py index 4217b26d..981839df 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -351,6 +351,7 @@ def get( ```python from askui import ResponseSchemaBase, VisionAgent from PIL import Image + import json class UrlResponse(ResponseSchemaBase): url: str @@ -372,7 +373,14 @@ class LinkedListNode(ResponseSchemaBase): response_schema=UrlResponse, image="screenshot.png", ) - print(response.url) + # Dump whole model + print(response.model_dump_json(indent=2)) + # or + response_json_dict = response.model_dump(mode="json") + print(json.dumps(response_json_dict, indent=2)) + # or for regular dict + response_dict = response.model_dump() + print(response_dict["url"]) # Get boolean response from PIL Image is_login_page = agent.get(