diff --git a/src/askui/__init__.py b/src/askui/__init__.py
index e3cc9de7..f9f884f4 100644
--- a/src/askui/__init__.py
+++ b/src/askui/__init__.py
@@ -6,6 +6,7 @@
 from .models import ModelComposition, ModelDefinition
 from .models.router import Point
 from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase
+from .retry import ConfigurableRetry, Retry
 from .tools import ModifierKey, PcKey
 from .utils.image_utils import Img
 
@@ -18,5 +19,7 @@
     "Point",
     "ResponseSchema",
     "ResponseSchemaBase",
+    "Retry",
+    "ConfigurableRetry",
     "VisionAgent",
 ]
diff --git a/src/askui/agent.py b/src/askui/agent.py
index 490dd0c2..d0724235 100644
--- a/src/askui/agent.py
+++ b/src/askui/agent.py
@@ -11,11 +11,13 @@
 from askui.locators.locators import Locator
 from askui.utils.image_utils import ImageSource, Img
 
+from .exceptions import ElementNotFoundError
 from .logger import configure_logging, logger
 from .models import ModelComposition
 from .models.router import ModelRouter, Point
 from .models.types.response_schemas import ResponseSchema
 from .reporting import CompositeReporter, Reporter
+from .retry import ConfigurableRetry, Retry
 from .tools import AgentToolbox, ModifierKey, PcKey
 from .tools.askui import AskUiControllerClient, AskUiControllerServer
 
@@ -34,7 +36,8 @@ class VisionAgent:
         reporters (list[Reporter] | None, optional): List of reporter instances for logging and reporting. If `None`, an empty list is used.
         tools (AgentToolbox | None, optional): Custom toolbox instance. If `None`, a default one will be created with `AskUiControllerClient`.
         model (ModelComposition | str | None, optional): The default composition or name of the model(s) to be used for vision tasks. Can be overridden by the `model` parameter in the `click()`, `get()`, `act()` etc. methods.
-
+        retry (Retry, optional): The retry instance to use for retrying failed actions. Defaults to `ConfigurableRetry` with exponential backoff. Currently only supported for `locate()` method.
+
     Example:
         ```python
         from askui import VisionAgent
@@ -56,6 +59,7 @@ def __init__(
         reporters: list[Reporter] | None = None,
         tools: AgentToolbox | None = None,
         model: ModelComposition | str | None = None,
+        retry: Retry | None = None,
     ) -> None:
         load_dotenv()
         configure_logging(level=log_level)
@@ -73,6 +77,12 @@ def __init__(
             else model_router
         )
         self.model = model
+        self._retry = retry or ConfigurableRetry(
+            strategy="Exponential",
+            base_delay=1000,
+            retry_count=3,
+            on_exception_types=(ElementNotFoundError,),
+        )
 
     @telemetry.record_call(exclude={"locator"})
     @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@@ -126,7 +136,14 @@ def _locate(
-        _screenshot = ImageSource(
-            self.tools.os.screenshot() if screenshot is None else screenshot
-        )
-        point = self.model_router.locate(_screenshot.root, locator, model or self.model)
+        def _locate_once() -> Point:
+            # Take a fresh screenshot on every attempt (unless the caller passed
+            # one) so that a retry can actually observe a changed screen.
+            _screenshot = ImageSource(
+                self.tools.os.screenshot() if screenshot is None else screenshot
+            )
+            return self.model_router.locate(
+                _screenshot.root, locator, model or self.model
+            )
+
+        point = self._retry.attempt(_locate_once)
         self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})")
         return point
 
diff --git a/src/askui/retry.py b/src/askui/retry.py
new file mode 100644
index 00000000..13c1e2ad
--- /dev/null
+++ b/src/askui/retry.py
@@ -0,0 +1,130 @@
+from abc import ABC, abstractmethod
+from typing import Annotated, Callable, Literal, Tuple, Type, TypeVar
+
+from pydantic import ConfigDict, Field, validate_call
+from tenacity import (
+    RetryCallState,
+    Retrying,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+    wait_fixed,
+    wait_incrementing,
+)
+
+from askui.logger import logger
+
+R = TypeVar("R")
+
+
+class Retry(ABC):
+    """Abstract base class for implementing retry mechanisms.
+
+    This abstract class defines the interface for retry mechanisms. Concrete
+    implementations should define how the retry logic works by implementing
+    the abstract `attempt` method.
+
+    Example:
+        ```python
+        class MyRetry(Retry):
+            def attempt(self, func: Callable[..., R]) -> R:
+                # Custom retry implementation
+                return func()
+
+        retry = MyRetry()
+        result = retry.attempt(some_function)
+        ```
+    """
+
+    @abstractmethod
+    def attempt(self, func: Callable[..., R]) -> R:
+        """Attempt to execute a function with retry logic.
+
+        Args:
+            func: The function to execute with retry logic
+
+        Returns:
+            The result of the function execution
+
+        Raises:
+            Exception: Any exception that occurs during execution after
+                all retry attempts are exhausted
+        """
+
+
+class ConfigurableRetry(Retry):
+    """A configurable retry implementation with different strategies.
+
+    This class provides a flexible way to retry operations that may fail temporarily,
+    supporting different retry strategies (Exponential, Fixed, Linear) and configurable
+    parameters for delay and retry count.
+
+    Args:
+        on_exception_types (Tuple[Type[Exception], ...]): Tuple of one or more exception types that should trigger a retry
+        strategy (Literal["Exponential", "Fixed", "Linear"]): The retry strategy to use:
+            - `"Exponential"`: Delay increases exponentially between retries
+            - `"Fixed"`: Constant delay between retries
+            - `"Linear"`: Delay increases linearly (by `base_delay` per attempt) between retries
+        base_delay (int, optional): Base delay in milliseconds between retries.
+        retry_count (int, optional): Maximum number of attempts in total, i.e. the first call counts as one attempt.
+
+    Example:
+        ```python
+        retry = ConfigurableRetry(
+            on_exception_types=(ConnectionError, TimeoutError),
+            strategy="Exponential",
+            base_delay=1000,
+            retry_count=3
+        )
+        result = retry.attempt(some_function)
+        ```
+    """  # noqa: E501
+
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+    def __init__(
+        self,
+        on_exception_types: Tuple[Type[Exception], ...],
+        strategy: Literal["Exponential", "Fixed", "Linear"],
+        base_delay: Annotated[int, Field(gt=0)] = 1000,
+        retry_count: Annotated[int, Field(gt=0)] = 3,
+    ):
+        self._strategy = strategy
+        self._base_delay = base_delay
+        self._retry_count = retry_count
+        self._on_exception_types = on_exception_types
+
+    def _get_retry_wait_strategy(
+        self,
+    ) -> wait_fixed | wait_incrementing | wait_exponential:
+        """Get the appropriate wait strategy based on the configured retry strategy."""
+        base_delay_s = self._base_delay / 1000  # tenacity waits are in seconds
+        if self._strategy == "Fixed":
+            return wait_fixed(base_delay_s)
+        if self._strategy == "Linear":
+            # `wait_incrementing` defaults to `increment=100` (seconds), so the
+            # step must be set explicitly to grow linearly by the base delay.
+            return wait_incrementing(start=base_delay_s, increment=base_delay_s)
+        return wait_exponential(multiplier=base_delay_s)
+
+    def _log_retry_attempt(self, retry_state: RetryCallState) -> None:
+        """Log the outcome of an attempt (used as tenacity's `after` callback)."""
+        logger.info(
+            "Retrying %s: attempt %s ended with: %s",
+            retry_state.fn,
+            retry_state.attempt_number,
+            retry_state.outcome,
+        )
+
+    def attempt(self, func: Callable[..., R]) -> R:
+        """Call `func`, retrying on the configured exception types.
+
+        The last exception is re-raised once all attempts are exhausted.
+        """
+        retryer = Retrying(
+            stop=stop_after_attempt(self._retry_count),
+            wait=self._get_retry_wait_strategy(),
+            reraise=True,
+            after=self._log_retry_attempt,
+            retry=retry_if_exception_type(self._on_exception_types),
+        )
+        return retryer(func)