Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 83 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -640,23 +640,97 @@ with VisionAgent() as agent:
For structured data extraction, use Pydantic models extending `ResponseSchemaBase`:

```python
from askui import ResponseSchemaBase
from askui import ResponseSchemaBase, VisionAgent
from PIL import Image
import json

class UserInfo(ResponseSchemaBase):
username: str
is_online: bool

# Get structured data
user_info = agent.get(
"What is the username and online status?",
response_schema=UserInfo
)
print(f"User {user_info.username} is {'online' if user_info.is_online else 'offline'}")
class UrlResponse(ResponseSchemaBase):
url: str

class NestedResponse(ResponseSchemaBase):
nested: UrlResponse

class LinkedListNode(ResponseSchemaBase):
value: str
next: "LinkedListNode | None"

with VisionAgent() as agent:
# Get structured data
user_info = agent.get(
"What is the username and online status?",
response_schema=UserInfo
)
print(f"User {user_info.username} is {'online' if user_info.is_online else 'offline'}")

# Get URL as string
url = agent.get("What is the current url shown in the url bar?")
print(url) # e.g., "github.com/login"

# Get URL as Pydantic model from image at (relative) path
response = agent.get(
"What is the current url shown in the url bar?",
response_schema=UrlResponse,
image="screenshot.png",
)

# Dump whole model
print(response.model_dump_json(indent=2))
# or
response_json_dict = response.model_dump(mode="json")
print(json.dumps(response_json_dict, indent=2))
# or for regular dict
response_dict = response.model_dump()
print(response_dict["url"])

# Get boolean response from PIL Image
is_login_page = agent.get(
"Is this a login page?",
response_schema=bool,
image=Image.open("screenshot.png"),
)
print(is_login_page)

# Get integer response
input_count = agent.get(
"How many input fields are visible on this page?",
response_schema=int,
)
print(input_count)

# Get float response
design_rating = agent.get(
"Rate the page design quality from 0 to 1",
response_schema=float,
)
print(design_rating)

# Get nested response
nested = agent.get(
"Extract the URL and its metadata from the page",
response_schema=NestedResponse,
)
print(nested.nested.url)

# Get recursive response
linked_list = agent.get(
"Extract the breadcrumb navigation as a linked list",
response_schema=LinkedListNode,
)
current = linked_list
while current:
print(current.value)
current = current.next
```

**⚠️ Limitations:**
- Nested Pydantic schemas are not currently supported
- Response schema is currently only supported by "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set)
- Not all models currently support response schemas, and those that do may not support every kind of property a response schema can have
- Default values are not supported, e.g., `url: str = "github.com"` or `url: str | None = None`. This includes `default_factory`
and `default` args of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or
`url: str = Field(default_factory=lambda: "github.com")`.

## What is AskUI Vision Agent?

Expand Down
41 changes: 36 additions & 5 deletions src/askui/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,22 @@ def get(
Returns:
ResponseSchema | str: The extracted information, `str` if no `response_schema` is provided.

Limitations:
- Nested Pydantic schemas are not currently supported
- Schema support is only available with "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set) at the moment

Example:
```python
from askui import ResponseSchemaBase, VisionAgent
from PIL import Image
import json

class UrlResponse(ResponseSchemaBase):
url: str

class NestedResponse(ResponseSchemaBase):
nested: UrlResponse

class LinkedListNode(ResponseSchemaBase):
value: str
next: "LinkedListNode | None"

with VisionAgent() as agent:
# Get URL as string
url = agent.get("What is the current url shown in the url bar?")
Expand All @@ -369,26 +373,53 @@ class UrlResponse(ResponseSchemaBase):
response_schema=UrlResponse,
image="screenshot.png",
)
print(response.url)
# Dump whole model
print(response.model_dump_json(indent=2))
# or
response_json_dict = response.model_dump(mode="json")
print(json.dumps(response_json_dict, indent=2))
# or for regular dict
response_dict = response.model_dump()
print(response_dict["url"])

# Get boolean response from PIL Image
is_login_page = agent.get(
"Is this a login page?",
response_schema=bool,
image=Image.open("screenshot.png"),
)
print(is_login_page)

# Get integer response
input_count = agent.get(
"How many input fields are visible on this page?",
response_schema=int,
)
print(input_count)

# Get float response
design_rating = agent.get(
"Rate the page design quality from 0 to 1",
response_schema=float,
)
print(design_rating)

# Get nested response
nested = agent.get(
"Extract the URL and its metadata from the page",
response_schema=NestedResponse,
)
print(nested.nested.url)

# Get recursive response
linked_list = agent.get(
"Extract the breadcrumb navigation as a linked list",
response_schema=LinkedListNode,
)
current = linked_list
while current:
print(current.value)
current = current.next
```
"""
logger.debug("VisionAgent received instruction to get '%s'", query)
Expand Down
8 changes: 3 additions & 5 deletions src/askui/models/askui/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Type

import requests
from pydantic import RootModel
from typing_extensions import override

from askui.locators.locators import Locator
Expand Down Expand Up @@ -92,11 +91,10 @@ def get(
"prompt": query,
}
_response_schema = to_response_schema(response_schema)
json["config"] = {"json_schema": _response_schema.model_json_schema()}
json_schema = _response_schema.model_json_schema()
json["config"] = {"json_schema": json_schema}
logger.debug(f"json_schema:\n{json_lib.dumps(json['config']['json_schema'])}")
content = self._request(endpoint="vqa/inference", json=json)
response = content["data"]["response"]
validated_response = _response_schema.model_validate(response)
if isinstance(validated_response, RootModel):
return validated_response.root
return validated_response
return validated_response.root
70 changes: 24 additions & 46 deletions src/askui/models/types/response_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ class ResponseSchemaBase(BaseModel):
on top so that it can be used with models to define the schema (type) of
the data to be extracted.

**Important**: Default values are not supported, e.g., `url: str = "github.com"` or
`url: str | None = None`. This includes `default_factory` and `default` args
of `pydantic.Field` as well, e.g., `url: str = Field(default="github.com")` or
`url: str = Field(default_factory=lambda: "github.com")`.

Example:
```python
class UrlResponse(ResponseSchemaBase):
Expand All @@ -27,19 +32,22 @@ class UrlResponse(ResponseSchemaBase):
description="The URL of the response. Should use the `\"https\"` scheme.",
examples=["https://www.example.com"]
)

# To define recursive response schemas, you can use quotation marks around the
# type of the field, e.g., `next: "LinkedListNode | None"`.
class LinkedListNode(ResponseSchemaBase):
value: str
next: "LinkedListNode | None"
```
"""

model_config = ConfigDict(extra="forbid")


String = RootModel[str]
Boolean = RootModel[bool]
Integer = RootModel[int]
Float = RootModel[float]


ResponseSchema = TypeVar("ResponseSchema", ResponseSchemaBase, str, bool, int, float)
ResponseSchema = TypeVar(
"ResponseSchema",
bound=ResponseSchemaBase | str | bool | int | float,
)
"""Type of the responses of data extracted, e.g., using `askui.VisionAgent.get()`.

The following types are allowed:
Expand All @@ -49,51 +57,21 @@ class UrlResponse(ResponseSchemaBase):
- `int`: Integer responses
- `float`: Floating point responses

Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be
passed to model(s). Also used for validating the responses of the model(s) used for
Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be
passed to model(s). Also used for validating the responses of the model(s) used for
data extraction.
"""


@overload
def to_response_schema(response_schema: None) -> Type[String]: ...
@overload
def to_response_schema(response_schema: Type[str]) -> Type[String]: ...
@overload
def to_response_schema(response_schema: Type[bool]) -> Type[Boolean]: ...
@overload
def to_response_schema(response_schema: Type[int]) -> Type[Integer]: ...
@overload
def to_response_schema(response_schema: Type[float]) -> Type[Float]: ...
def to_response_schema(response_schema: None) -> Type[RootModel[str]]: ...
@overload
def to_response_schema(
response_schema: Type[ResponseSchemaBase],
) -> Type[ResponseSchemaBase]: ...
response_schema: Type[ResponseSchema],
) -> Type[RootModel[ResponseSchema]]: ...
def to_response_schema(
response_schema: Type[ResponseSchemaBase]
| Type[str]
| Type[bool]
| Type[int]
| Type[float]
| None = None,
) -> (
Type[ResponseSchemaBase]
| Type[String]
| Type[Boolean]
| Type[Integer]
| Type[Float]
):
response_schema: Type[ResponseSchema] | None,
) -> Type[RootModel[str]] | Type[RootModel[ResponseSchema]]:
if response_schema is None:
return String
if response_schema is str:
return String
if response_schema is bool:
return Boolean
if response_schema is int:
return Integer
if response_schema is float:
return Float
if issubclass(response_schema, ResponseSchemaBase):
return response_schema
error_msg = f"Invalid response schema type: {response_schema}"
raise ValueError(error_msg)
return RootModel[str]
return RootModel[response_schema] # type: ignore[valid-type]
Loading