diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py
index a673c64d..b40935f1 100644
--- a/src/askui/android_agent.py
+++ b/src/askui/android_agent.py
@@ -14,9 +14,14 @@
 from askui.tools.android.ppadb_agent_os import PpadbAgentOs
 from askui.tools.android.tools import (
     AndroidDragAndDropTool,
+    AndroidGetConnectedDevicesSerialNumbersTool,
+    AndroidGetConnectedDisplaysInfosTool,
+    AndroidGetCurrentConnectedDeviceInfosTool,
     AndroidKeyCombinationTool,
     AndroidKeyTapEventTool,
     AndroidScreenshotTool,
+    AndroidSelectDeviceBySerialNumberTool,
+    AndroidSelectDisplayByIndex,
     AndroidShellTool,
     AndroidSwipeTool,
     AndroidTapTool,
@@ -77,6 +82,11 @@ def __init__(
                 AndroidSwipeTool(act_agent_os_facade),
                 AndroidKeyCombinationTool(act_agent_os_facade),
                 AndroidShellTool(act_agent_os_facade),
+                AndroidSelectDeviceBySerialNumberTool(act_agent_os_facade),
+                AndroidSelectDisplayByIndex(act_agent_os_facade),
+                AndroidGetConnectedDevicesSerialNumbersTool(act_agent_os_facade),
+                AndroidGetConnectedDisplaysInfosTool(act_agent_os_facade),
+                AndroidGetCurrentConnectedDeviceInfosTool(act_agent_os_facade),
                 ExceptionTool(),
             ],
             agent_os=self.os,
diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py
index b21db2d0..a1b57ced 100644
--- a/src/askui/chat/api/app.py
+++ b/src/askui/chat/api/app.py
@@ -15,6 +15,7 @@
 from askui.chat.api.mcp_clients.manager import McpServerConnectionError
 from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service
 from askui.chat.api.mcp_configs.router import router as mcp_configs_router
+from askui.chat.api.mcp_servers.android import mcp as android_mcp
 from askui.chat.api.mcp_servers.computer import mcp as computer_mcp
 from askui.chat.api.mcp_servers.testing import mcp as testing_mcp
 from askui.chat.api.mcp_servers.utility import mcp as utility_mcp
@@ -64,6 +65,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:  # noqa: ARG001
 
 mcp = FastMCP.from_fastapi(app=app, name="AskUI Chat MCP")
 mcp.mount(computer_mcp)
+mcp.mount(android_mcp)
 mcp.mount(testing_mcp)
 mcp.mount(utility_mcp)
 mcp_app = mcp.http_app("/sse", transport="sse")
diff --git a/src/askui/chat/api/assistants/seeds.py b/src/askui/chat/api/assistants/seeds.py
index 2ad29e5a..2c14ea5a 100644
--- a/src/askui/chat/api/assistants/seeds.py
+++ b/src/askui/chat/api/assistants/seeds.py
@@ -37,6 +37,12 @@
         "android_swipe_tool",
         "android_key_combination_tool",
         "android_shell_tool",
+        "android_connect_tool",
+        "android_get_connected_devices_serial_numbers_tool",
+        "android_get_connected_displays_infos_tool",
+        "android_get_current_connected_device_infos_tool",
+        "android_select_device_by_serial_number_tool",
+        "android_select_display_by_index_tool",
     ],
 )
 
diff --git a/src/askui/chat/api/mcp_servers/android.py b/src/askui/chat/api/mcp_servers/android.py
new file mode 100644
index 00000000..ffa3255f
--- /dev/null
+++ b/src/askui/chat/api/mcp_servers/android.py
@@ -0,0 +1,45 @@
+from fastmcp import FastMCP
+
+from askui.tools.android.agent_os_facade import AndroidAgentOsFacade
+from askui.tools.android.ppadb_agent_os import PpadbAgentOs
+from askui.tools.android.tools import (
+    AndroidConnectTool,
+    AndroidDragAndDropTool,
+    AndroidGetConnectedDevicesSerialNumbersTool,
+    AndroidGetConnectedDisplaysInfosTool,
+    AndroidGetCurrentConnectedDeviceInfosTool,
+    AndroidKeyCombinationTool,
+    AndroidKeyTapEventTool,
+    AndroidScreenshotTool,
+    AndroidSelectDeviceBySerialNumberTool,
+    AndroidSelectDisplayByIndex,
+    AndroidShellTool,
+    AndroidSwipeTool,
+    AndroidTapTool,
+    AndroidTypeTool,
+)
+
+mcp = FastMCP(name="AskUI Android MCP")
+
+# Initialize the AndroidAgentOsFacade
+ANDROID_AGENT_OS = PpadbAgentOs()
+ANDROID_AGENT_OS_FACADE = AndroidAgentOsFacade(ANDROID_AGENT_OS)
+TOOLS = [
+    AndroidSelectDeviceBySerialNumberTool(ANDROID_AGENT_OS_FACADE),
+    AndroidSelectDisplayByIndex(ANDROID_AGENT_OS_FACADE),
+    AndroidGetConnectedDevicesSerialNumbersTool(ANDROID_AGENT_OS_FACADE),
+    AndroidGetConnectedDisplaysInfosTool(ANDROID_AGENT_OS_FACADE),
+    AndroidGetCurrentConnectedDeviceInfosTool(ANDROID_AGENT_OS_FACADE),
+    AndroidConnectTool(ANDROID_AGENT_OS_FACADE),
+    AndroidScreenshotTool(ANDROID_AGENT_OS_FACADE),
+    AndroidTapTool(ANDROID_AGENT_OS_FACADE),
+    AndroidTypeTool(ANDROID_AGENT_OS_FACADE),
+    AndroidDragAndDropTool(ANDROID_AGENT_OS_FACADE),
+    AndroidKeyTapEventTool(ANDROID_AGENT_OS_FACADE),
+    AndroidSwipeTool(ANDROID_AGENT_OS_FACADE),
+    AndroidKeyCombinationTool(ANDROID_AGENT_OS_FACADE),
+    AndroidShellTool(ANDROID_AGENT_OS_FACADE),
+]
+
+for tool in TOOLS:
+    mcp.add_tool(tool.to_mcp_tool({"android"}))
diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py
index 3618a4bb..5f5b6a61 100644
--- a/src/askui/chat/api/runs/runner/runner.py
+++ b/src/askui/chat/api/runs/runner/runner.py
@@ -31,35 +31,6 @@
 logger = logging.getLogger(__name__)
 
 
-def _get_android_tools() -> list[Tool]:
-    from askui.tools.android.agent_os_facade import AndroidAgentOsFacade
-    from askui.tools.android.ppadb_agent_os import PpadbAgentOs
-    from askui.tools.android.tools import (
-        AndroidDragAndDropTool,
-        AndroidKeyCombinationTool,
-        AndroidKeyTapEventTool,
-        AndroidScreenshotTool,
-        AndroidShellTool,
-        AndroidSwipeTool,
-        AndroidTapTool,
-        AndroidTypeTool,
-    )
-
-    agent_os = PpadbAgentOs()
-    agent_os.connect()
-    act_agent_os_facade = AndroidAgentOsFacade(agent_os)
-    return [
-        AndroidScreenshotTool(act_agent_os_facade),
-        AndroidTapTool(act_agent_os_facade),
-        AndroidTypeTool(act_agent_os_facade),
-        AndroidDragAndDropTool(act_agent_os_facade),
-        AndroidKeyTapEventTool(act_agent_os_facade),
-        AndroidSwipeTool(act_agent_os_facade),
-        AndroidKeyCombinationTool(act_agent_os_facade),
-        AndroidShellTool(act_agent_os_facade),
-    ]
-
-
 class RunnerRunService(ABC):
     @abstractmethod
     def retrieve(self, thread_id: ThreadId, run_id: RunId) -> Run:
@@ -167,9 +138,6 @@ def _run_agent_inner() -> None:
                 include=set(self._assistant.tools),
             )
             betas = tools.retrieve_tool_beta_flags()
-            # Remove this after having extracted tools into Android MCP
-            if self._run.assistant_id == ANDROID_AGENT.id:
-                tools.append_tool(*_get_android_tools())
             system = self._build_system()
             model = str(ModelName.CLAUDE__SONNET__4__20250514)
             messages = syncify(self._chat_history_manager.retrieve_message_params)(
diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py
index 361ef3d8..a4bd7e1d 100644
--- a/src/askui/models/shared/tools.py
+++ b/src/askui/models/shared/tools.py
@@ -2,6 +2,7 @@
 import types
 from abc import ABC, abstractmethod
 from datetime import timedelta
+from functools import wraps
 from typing import Any, Literal, Protocol, Type
 
 import jsonref
@@ -14,6 +15,8 @@
 from anthropic.types.beta.beta_tool_param import InputSchema
 from asyncer import syncify
 from fastmcp.client.client import CallToolResult, ProgressHandler
+from fastmcp.tools import Tool as FastMcpTool
+from fastmcp.utilities.types import Image as FastMcpImage
 from mcp import Tool as McpTool
 from PIL import Image
 from pydantic import BaseModel, Field
@@ -101,6 +104,22 @@ def _default_input_schema() -> InputSchema:
     return {"type": "object", "properties": {}, "required": []}
 
 
+def _convert_to_mcp_content(
+    result: Any,
+) -> Any:
+    if isinstance(result, tuple):
+        return tuple(_convert_to_mcp_content(item) for item in result)
+
+    if isinstance(result, list):
+        return [_convert_to_mcp_content(item) for item in result]
+
+    if isinstance(result, Image.Image):
+        src = ImageSource(result)
+        return FastMcpImage(data=src.to_bytes(), format="png")
+
+    return result
+
+
 class Tool(BaseModel, ABC):
     name: str = Field(description="Name of the tool")
     description: str = Field(description="Description of what the tool does")
@@ -124,6 +143,21 @@ def to_params(
             input_schema=self.input_schema,
         )
 
+    def to_mcp_tool(self, tags: set[str]) -> FastMcpTool:
+        """Convert the AskUI tool to an MCP tool."""
+        tool_call = self.__call__
+
+        @wraps(tool_call)
+        def wrapped_tool_call(*args: Any, **kwargs: Any) -> Any:
+            return _convert_to_mcp_content(tool_call(*args, **kwargs))
+
+        return FastMcpTool.from_function(
+            wrapped_tool_call,
+            name=self.name,
+            description=self.description,
+            tags=tags,
+        )
+
 
 class AgentException(Exception):
     """
diff --git a/src/askui/tools/android/agent_os.py b/src/askui/tools/android/agent_os.py
index 7fecc2fb..05bd3a31 100644
--- a/src/askui/tools/android/agent_os.py
+++ b/src/askui/tools/android/agent_os.py
@@ -387,3 +387,24 @@ def get_connected_displays(self) -> list[AndroidDisplay]:
         Gets the connected displays for screen interactions.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def get_connected_devices_serial_numbers(self) -> list[str]:
+        """
+        Gets the connected devices serial numbers.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_selected_device_infos(self) -> tuple[str, AndroidDisplay]:
+        """
+        Gets the selected device infos.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def connect_adb_client(self) -> None:
+        """
+        Connects the adb client to the server.
+        """
+        raise NotImplementedError
diff --git a/src/askui/tools/android/agent_os_facade.py b/src/askui/tools/android/agent_os_facade.py
index f61a1a7b..c521a5be 100644
--- a/src/askui/tools/android/agent_os_facade.py
+++ b/src/askui/tools/android/agent_os_facade.py
@@ -143,3 +143,26 @@ def set_device_by_serial_number(self, device_sn: str) -> None:
         self._reporter.add_message(
             "AndroidAgentOS", f"Set device by serial number: {device_sn}"
         )
+
+    def get_connected_devices_serial_numbers(self) -> list[str]:
+        devices_sn = self._agent_os.get_connected_devices_serial_numbers()
+        self._reporter.add_message(
+            "AndroidAgentOS",
+            f"Retrieved connected devices serial numbers, length: {len(devices_sn)}",
+        )
+        return devices_sn
+
+    def get_selected_device_infos(self) -> tuple[str, AndroidDisplay]:
+        device_sn, selected_display = self._agent_os.get_selected_device_infos()
+        self._reporter.add_message(
+            "AndroidAgentOS",
+            (
+                f"Selected device serial number '{device_sn}'"
+                f" and selected display: {str(selected_display)}"
+            ),
+        )
+        return device_sn, selected_display
+
+    def connect_adb_client(self) -> None:
+        self._agent_os.connect_adb_client()
+        self._reporter.add_message("AndroidAgentOS", "Connected to adb client")
diff --git a/src/askui/tools/android/ppadb_agent_os.py b/src/askui/tools/android/ppadb_agent_os.py
index 52e439a1..7d7ea15f 100644
--- a/src/askui/tools/android/ppadb_agent_os.py
+++ b/src/askui/tools/android/ppadb_agent_os.py
@@ -23,8 +23,24 @@ def __init__(self) -> None:
         self._displays: list[AndroidDisplay] = []
         self._selected_display: Optional[AndroidDisplay] = None
 
+    def connect_adb_client(self) -> None:
+        if self._client is not None:
+            msg = "Adb client is already connected"
+            raise RuntimeError(msg)
+        try:
+            self._client = AdbClient()
+        except Exception as e:  # noqa: BLE001
+            msg = f""" Failed to connect the adb client to the server.
+            Make sure the adb server is running.
+            If you are using a real device, make sure the device is connected.
+            And listed after executing the 'adb devices' command.
+            If you are using an emulator, make sure the emulator is running.
+            The error message: {e}
+            """
+            raise RuntimeError(msg) from e
+
     def connect(self) -> None:
-        self._client = AdbClient()
+        self.connect_adb_client()
         self.set_device_by_index(0)
         assert self._device is not None
         self._device.wait_boot_complete()
@@ -38,9 +54,8 @@ def _set_display(self, display: AndroidDisplay) -> None:
         self._mouse_position = (0, 0)
 
     def get_connected_displays(self) -> list[AndroidDisplay]:
-        if not self._device:
-            msg = "No device connected"
-            raise RuntimeError(msg)
+        self._check_if_device_is_selected()
+        assert self._device is not None
         displays: list[AndroidDisplay] = []
         output: str = self._device.shell(
             "dumpsys SurfaceFlinger --display-id",
@@ -122,7 +137,8 @@ def set_device_by_serial_number(self, device_sn: str) -> None:
         raise RuntimeError(msg)
 
     def screenshot(self) -> Image.Image:
-        self._check_if_device_is_connected()
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
         assert self._device is not None
         assert self._selected_display is not None
         connection_to_device = self._device.create_connection()
@@ -136,16 +152,19 @@ def screenshot(self) -> Image.Image:
         return Image.open(io.BytesIO(response))
 
     def shell(self, command: str) -> str:
-        self._check_if_device_is_connected()
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
         assert self._device is not None
         response: str = self._device.shell(command)
         return response
 
     def tap(self, x: int, y: int) -> None:
-        self._check_if_device_is_connected()
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
-        self.shell(f"input -d {display_index} tap {x} {y}")
+        self._device.shell(f"input -d {display_index} tap {x} {y}")
         self._mouse_position = (x, y)
 
     def swipe(
@@ -156,9 +175,12 @@ def swipe(
         y2: int,
         duration_in_ms: int = 1000,
     ) -> None:
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
-        self.shell(
+        self._device.shell(
             f"input -d {display_index} swipe {x1} {y1} {x2} {y2} {duration_in_ms}"
         )
         self._mouse_position = (x2, y2)
@@ -171,9 +193,12 @@ def drag_and_drop(
         y2: int,
         duration_in_ms: int = 1000,
     ) -> None:
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
-        self.shell(
+        self._device.shell(
             f"input -d {display_index} draganddrop {x1} {y1} {x2} {y2} {duration_in_ms}"
         )
         self._mouse_position = (x2, y2)
@@ -185,19 +210,25 @@ def type(self, text: str) -> None:
                 + "or special characters which are not supported by the device"
             )
             raise RuntimeError(error_msg_nonprintable)
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
         escaped_text = shlex.quote(text)
         shell_safe_text = escaped_text.replace(" ", "%s")
-        self.shell(f"input -d {display_index} text {shell_safe_text}")
+        self._device.shell(f"input -d {display_index} text {shell_safe_text}")
 
     def key_tap(self, key: ANDROID_KEY) -> None:
         if key not in get_args(ANDROID_KEY):
             error_msg_invalid_key: str = f"Invalid key: {key}"
             raise RuntimeError(error_msg_invalid_key)
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
-        self.shell(f"input -d {display_index} keyevent {key}")
+        self._device.shell(f"input -d {display_index} keyevent {key}")
 
     def key_combination(
         self, keys: List[ANDROID_KEY], duration_in_ms: int = 100
@@ -211,19 +242,20 @@ def key_combination(
             raise RuntimeError(error_msg_too_few)
 
         keys_string = " ".join(keys)
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
         assert self._selected_display is not None
         display_index: int = self._selected_display.display_index
-        self.shell(
+        self._device.shell(
             f"input -d {display_index} keycombination -t {duration_in_ms} {keys_string}"
         )
 
-    def _check_if_device_is_connected(self) -> None:
-        if not self._client or not self._device:
-            msg = "No device connected"
-            raise RuntimeError(msg)
-        devices: list[AndroidDevice] = self._client.devices()
-        if not devices:
-            msg = "No devices connected"
+    def _check_if_device_is_selected(self) -> None:
+        devices: list[AndroidDevice] = self._get_connected_devices()
+
+        if not self._device:
+            msg = "No device is selected, did you call one of the set_device methods?"
             raise RuntimeError(msg)
 
         for device in devices:
@@ -232,12 +264,40 @@ def _check_if_device_is_connected(self) -> None:
         msg = f"Device {self._device.serial} not found in connected devices"
         raise RuntimeError(msg)
 
+    def _check_if_display_is_selected(self) -> None:
+        if self._selected_display is None:
+            msg = "No display is selected, did you call one of the set_display methods?"
+            raise RuntimeError(msg)
+
     def _get_connected_devices(self) -> list[AndroidDevice]:
+        """
+        Get the connected devices.
+        """
         if not self._client:
-            msg = "No client connected"
+            msg = "No adb client is connected, did you call the connect method?"
             raise RuntimeError(msg)
         devices: list[AndroidDevice] = self._client.devices()
         if not devices:
-            msg = "No devices connected"
+            msg = """No devices are connected,
+            If you are using an emulator, make sure the emulator is running.
+            If you are using a real device, make sure the device is connected.
+            """
             raise RuntimeError(msg)
         return devices
+
+    def get_connected_devices_serial_numbers(self) -> list[str]:
+        """
+        Get the connected devices serial numbers.
+        """
+        devices: list[AndroidDevice] = self._get_connected_devices()
+        return [device.serial for device in devices]
+
+    def get_selected_device_infos(self) -> tuple[str, AndroidDisplay]:
+        """
+        Get the selected device infos.
+        """
+        self._check_if_device_is_selected()
+        self._check_if_display_is_selected()
+        assert self._device is not None
+        assert self._selected_display is not None
+        return (self._device.serial, self._selected_display)
diff --git a/src/askui/tools/android/tools.py b/src/askui/tools/android/tools.py
index 79bcc430..b7d51c88 100644
--- a/src/askui/tools/android/tools.py
+++ b/src/askui/tools/android/tools.py
@@ -344,6 +344,7 @@ def __init__(self, agent_os_facade: AndroidAgentOsFacade) -> None:
                 run system commands, check device status, or perform administrative tasks.
                 The command will be executed in the Android shell environment with the current
                 user's permissions.
+                It adds the adb shell prefix to the provided command.
                 """
             ),
             input_schema={
@@ -369,3 +370,142 @@ def __init__(self, agent_os_facade: AndroidAgentOsFacade) -> None:
     def __call__(self, command: str) -> str:
         output = self._agent_os_facade.shell(command)
         return f"Shell command executed. Output: {output}"
+
+
+class AndroidGetConnectedDevicesSerialNumbersTool(Tool):
+    """
+    Get the connected devices serial numbers.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_get_connected_devices_serial_numbers_tool",
+            description="Can be used to get all connected devices serial numbers.",
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self) -> str:
+        devices_sn = self._agent_os_facade.get_connected_devices_serial_numbers()
+        return f"Connected devices serial numbers: [{', '.join(devices_sn)}]"
+
+
+class AndroidGetConnectedDisplaysInfosTool(Tool):
+    """
+    Get the connected displays infos for the current connected device.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_get_connected_displays_infos_tool",
+            description="Can be used to get all connected displays infos for the "
+            "current selected device.",
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self) -> str:
+        displays = self._agent_os_facade.get_connected_displays()
+        display_infos = [str(display) for display in displays]
+        return f"Connected displays infos: [{', '.join(display_infos)}]"
+
+
+class AndroidGetCurrentConnectedDeviceInfosTool(Tool):
+    """
+    Get the current selected device infos.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_get_current_connected_device_infos_tool",
+            description="""
+            Can be used to get the current selected device and selected display infos.
+            """,
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self) -> str:
+        device_serial_number, selected_display = (
+            self._agent_os_facade.get_selected_device_infos()
+        )
+        return (
+            f"The device with the serial number {device_serial_number} is selected"
+            f" and its selected display is {str(selected_display)}."
+        )
+
+
+class AndroidSelectDeviceBySerialNumberTool(Tool):
+    """
+    Select a device by its serial number.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_select_device_by_serial_number_tool",
+            description="Can be used to select a device by its serial number.",
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "device_sn": {
+                        "type": "string",
+                        "description": "The serial number of the device to select.",
+                    },
+                },
+                "required": ["device_sn"],
+            },
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self, device_sn: str) -> str:
+        self._agent_os_facade.set_device_by_serial_number(device_sn)
+        return f"Device with the serial number {device_sn} was selected."
+
+
+class AndroidSelectDisplayByIndex(Tool):
+    """
+    Select a display by its index.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_select_display_by_index_tool",
+            description="Can be used to select a display by its index.",
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "display_index": {
+                        "type": "integer",
+                        "description": "The index of the display to select.",
+                    },
+                },
+                "required": ["display_index"],
+            },
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self, display_index: int) -> str:
+        self._agent_os_facade.set_display_by_index(display_index)
+        return f"Display with the index {display_index} was selected."
+
+
+class AndroidConnectTool(Tool):
+    """
+    Connect to the Android device.
+    """
+
+    def __init__(self, agent_os_facade: AndroidAgentOsFacade):
+        super().__init__(
+            name="android_connect_tool",
+            description="""Can be used to connect the adb client to the server.
+            Needs to select a device after connecting the adb client.
+            """,
+        )
+        self._agent_os_facade = agent_os_facade
+
+    @override
+    def __call__(self) -> str:
+        self._agent_os_facade.connect_adb_client()
+        return "adb client is connected to the server."