diff --git a/README.md b/README.md index 13b1df3..e8fb631 100644 --- a/README.md +++ b/README.md @@ -178,8 +178,20 @@ brew install android-platform-tools # macOS # 4. Connect device & verify adb devices -# 5. Set API key -export OPENAI_API_KEY="sk-..." +# 5. Configure LLM Provider (choose one) +export LLM_PROVIDER="openai" # or anthropic, gemini, bedrock + +# Set appropriate API key +export OPENAI_API_KEY="sk-..." # for OpenAI +# export ANTHROPIC_API_KEY="sk-..." # for Anthropic +# export GOOGLE_API_KEY="..." # for Gemini +# export AWS_PROFILE="default" # for Bedrock + +# Optional: Override default model +# export OPENAI_MODEL="gpt-4o" +# export ANTHROPIC_MODEL="claude-sonnet-4" +# export GEMINI_MODEL="gemini-2.0-flash-exp" +# export BEDROCK_MODEL="anthropic.claude-sonnet-4-20250514-v1:0" # 6. Run your first agent python kernel.py @@ -390,7 +402,8 @@ screen_json = get_screen_state() ### Next 2 Weeks - [ ] **PyPI package:** `pip install android-use` -- [ ] **Multi-LLM support:** Claude, Gemini, Llama +- [x] **Multi-LLM support:** OpenAI, Claude, Gemini, Bedrock +- [ ] **Llama support:** Local model integration - [ ] **WhatsApp integration:** Pre-built actions for messaging - [ ] **Error recovery:** Retry logic, fallback strategies diff --git a/action_models.py b/action_models.py new file mode 100644 index 0000000..e8404b1 --- /dev/null +++ b/action_models.py @@ -0,0 +1,32 @@ +from typing import Literal, Union, List +from pydantic import BaseModel, Field, field_validator + +class TapAction(BaseModel): + action: Literal["tap"] = "tap" + coordinates: List[int] = Field(..., description="[x, y] coordinates to tap") + reason: str = Field(..., description="Why this tap is needed") + + @field_validator("coordinates") + @classmethod + def validate_coordinates(cls, v): + if len(v) != 2: + raise ValueError("coordinates must be [x, y]") + if not all(isinstance(coord, int) and coord >= 0 for coord in v): + raise ValueError("coordinates must be positive integers") + return v + 
+class TypeAction(BaseModel): + action: Literal["type"] = "type" + text: str = Field(..., description="Text to type") + reason: str = Field(..., description="Why this text is needed") + +class NavigationAction(BaseModel): + action: Literal["home", "back"] = Field(..., description="Navigation action") + reason: str = Field(..., description="Why this navigation is needed") + +class ControlAction(BaseModel): + action: Literal["wait", "done"] = Field(..., description="Control action") + reason: str = Field(..., description="Why this action is needed") + +# Union type for all possible actions +AndroidAction = Union[TapAction, TypeAction, NavigationAction, ControlAction] diff --git a/examples/anthropic_example.sh b/examples/anthropic_example.sh new file mode 100755 index 0000000..14dcf0a --- /dev/null +++ b/examples/anthropic_example.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Example: Using Anthropic Claude + +export LLM_PROVIDER="anthropic" +export ANTHROPIC_API_KEY="sk-..." # Replace with your key +export ANTHROPIC_MODEL="claude-sonnet-4" # Optional + +python kernel.py diff --git a/examples/bedrock_example.sh b/examples/bedrock_example.sh new file mode 100755 index 0000000..d91e9c5 --- /dev/null +++ b/examples/bedrock_example.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Example: Using AWS Bedrock with Claude + +export LLM_PROVIDER="bedrock" +export AWS_PROFILE="default" # Or use AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY +export BEDROCK_MODEL="anthropic.claude-sonnet-4-20250514-v1:0" # Optional + +python kernel.py diff --git a/examples/gemini_example.sh b/examples/gemini_example.sh new file mode 100755 index 0000000..2682875 --- /dev/null +++ b/examples/gemini_example.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Example: Using Google Gemini (cheapest option) + +export LLM_PROVIDER="gemini" +export GOOGLE_API_KEY="..." 
# Replace with your key +export GEMINI_MODEL="gemini-2.0-flash-exp" # Optional + +python kernel.py diff --git a/examples/openai_example.sh b/examples/openai_example.sh new file mode 100755 index 0000000..d67b214 --- /dev/null +++ b/examples/openai_example.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# Example: Using OpenAI GPT-4o + +export LLM_PROVIDER="openai" +export OPENAI_API_KEY="sk-..." # Replace with your key +export OPENAI_MODEL="gpt-4o" # Optional: override default + +python kernel.py diff --git a/kernel.py b/kernel.py index f827897..57dfda5 100644 --- a/kernel.py +++ b/kernel.py @@ -2,17 +2,25 @@ import time import subprocess import json -from typing import Dict, Any -from openai import OpenAI +import asyncio +from typing import Dict, Any, List +from llm_manager import LLMManager +from action_models import TapAction, TypeAction, NavigationAction, ControlAction import sanitizer # --- CONFIGURATION --- -ADB_PATH = "adb" # Ensure adb is in your PATH -MODEL = "gpt-4o" # Or "gpt-4-turbo" for faster/cheaper execution +ADB_PATH = "adb" SCREEN_DUMP_PATH = "/sdcard/window_dump.xml" LOCAL_DUMP_PATH = "window_dump.xml" -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +# Initialize LLM manager +llm_manager = None + +def initialize_llm(): + """Initialize the LLM manager.""" + global llm_manager + if llm_manager is None: + llm_manager = LLMManager() def run_adb_command(command: List[str]): """Executes a shell command via ADB.""" @@ -39,93 +47,68 @@ def get_screen_state() -> str: elements = sanitizer.get_interactive_elements(xml_content) return json.dumps(elements, indent=2) -def execute_action(action: Dict[str, Any]): +def execute_action(action): """Executes the action decided by the LLM.""" - act_type = action.get("action") - - if act_type == "tap": - x, y = action.get("coordinates") + if isinstance(action, TapAction): + x, y = action.coordinates print(f"👉 Tapping: ({x}, {y})") run_adb_command(["shell", "input", "tap", str(x), str(y)]) - - elif act_type == "type": - 
text = action.get("text").replace(" ", "%s") # ADB requires %s for spaces - print(f"⌨️ Typing: {action.get('text')}") + + elif isinstance(action, TypeAction): + text = action.text.replace(" ", "%s") # ADB requires %s for spaces + print(f"⌨️ Typing: {action.text}") run_adb_command(["shell", "input", "text", text]) - - elif act_type == "home": - print("🏠 Going Home") - run_adb_command(["shell", "input", "keyevent", "KEYWORDS_HOME"]) - - elif act_type == "back": - print("🔙 Going Back") - run_adb_command(["shell", "input", "keyevent", "KEYWORDS_BACK"]) - - elif act_type == "wait": - print("⏳ Waiting...") - time.sleep(2) - - elif act_type == "done": - print("✅ Goal Achieved.") - exit(0) -def get_llm_decision(goal: str, screen_context: str) -> Dict[str, Any]: + elif isinstance(action, NavigationAction): + if action.action == "home": + print("🏠 Going Home") + run_adb_command(["shell", "input", "keyevent", "KEYCODE_HOME"]) + elif action.action == "back": + print("🔙 Going Back") + run_adb_command(["shell", "input", "keyevent", "KEYCODE_BACK"]) + + elif isinstance(action, ControlAction): + if action.action == "wait": + print("⏳ Waiting...") + time.sleep(2) + elif action.action == "done": + print("✅ Goal Achieved.") + exit(0) + +async def get_llm_decision(goal: str, screen_context: str): """Sends screen context to LLM and asks for the next move.""" - system_prompt = """ - You are an Android Driver Agent. Your job is to achieve the user's goal by navigating the UI. - - You will receive: - 1. The User's Goal. - 2. A list of interactive UI elements (JSON) with their (x,y) center coordinates. - - You must output ONLY a valid JSON object with your next action. 
- - Available Actions: - - {"action": "tap", "coordinates": [x, y], "reason": "Why you are tapping"} - - {"action": "type", "text": "Hello World", "reason": "Why you are typing"} - - {"action": "home", "reason": "Go to home screen"} - - {"action": "back", "reason": "Go back"} - - {"action": "wait", "reason": "Wait for loading"} - - {"action": "done", "reason": "Task complete"} - - Example Output: - {"action": "tap", "coordinates": [540, 1200], "reason": "Clicking the 'Connect' button"} - """ - - response = client.chat.completions.create( - model=MODEL, - response_format={"type": "json_object"}, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"GOAL: {goal}\n\nSCREEN_CONTEXT:\n{screen_context}"} - ] - ) - - return json.loads(response.choices[0].message.content) + global llm_manager + if llm_manager is None: + initialize_llm() + + action = await llm_manager.get_decision(goal, screen_context) + return action + +async def run_agent(goal: str, max_steps=10): + """Main agent loop.""" + initialize_llm() + print(f"🚀 Android Use Agent Started") + print(f"📡 Provider: {llm_manager.provider} | Model: {llm_manager.model}") + print(f"🎯 Goal: {goal}\n") -def run_agent(goal: str, max_steps=10): - print(f"🚀 Android Use Agent Started. Goal: {goal}") - for step in range(max_steps): print(f"\n--- Step {step + 1} ---") - + # 1. Perception print("👀 Scanning Screen...") screen_context = get_screen_state() - + # 2. Reasoning print("🧠 Thinking...") - decision = get_llm_decision(goal, screen_context) - print(f"💡 Decision: {decision.get('reason')}") - + decision = await get_llm_decision(goal, screen_context) + print(f"💡 Decision: {decision.reason}") + # 3. 
Action execute_action(decision) - + # Wait for UI to update time.sleep(2) if __name__ == "__main__": - # Example Goal: "Open settings and turn on Wi-Fi" - # Or your demo goal: "Find the 'Connect' button and tap it" GOAL = input("Enter your goal: ") - run_agent(GOAL) \ No newline at end of file + asyncio.run(run_agent(GOAL)) \ No newline at end of file diff --git a/llm_manager.py b/llm_manager.py new file mode 100644 index 0000000..dfcdfaf --- /dev/null +++ b/llm_manager.py @@ -0,0 +1,115 @@ +import os +from typing import Optional +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.models.google import GoogleModel +from pydantic_ai.models.bedrock import BedrockConverseModel +from action_models import AndroidAction + +class LLMManager: + """Manages LLM provider initialization and agent creation.""" + + DEFAULT_MODELS = { + "openai": "gpt-4o", + "anthropic": "claude-sonnet-4", + "gemini": "gemini-2.0-flash-exp", + "bedrock": "anthropic.claude-sonnet-4-20250514-v1:0" + } + + def __init__(self): + self.provider = self._get_provider() + self.model = self._get_model() + self.agent = self._create_agent() + + def _get_provider(self) -> str: + """Get provider from environment.""" + provider = os.environ.get("LLM_PROVIDER") + if not provider: + raise ValueError( + "LLM_PROVIDER environment variable must be set. " + "Valid values: openai, anthropic, gemini, bedrock" + ) + if provider not in self.DEFAULT_MODELS: + raise ValueError( + f"Invalid LLM_PROVIDER '{provider}'. 
" + f"Valid values: {', '.join(self.DEFAULT_MODELS.keys())}" + ) + return provider + + def _get_model(self) -> str: + """Get model name from environment or use default.""" + env_var = f"{self.provider.upper()}_MODEL" + model = os.environ.get(env_var) + if not model: + model = self.DEFAULT_MODELS[self.provider] + return model + + def _validate_credentials(self): + """Validate that required credentials are present.""" + if self.provider == "openai": + if not os.environ.get("OPENAI_API_KEY"): + raise ValueError("OPENAI_API_KEY environment variable must be set") + elif self.provider == "anthropic": + if not os.environ.get("ANTHROPIC_API_KEY"): + raise ValueError("ANTHROPIC_API_KEY environment variable must be set") + elif self.provider == "gemini": + if not os.environ.get("GOOGLE_API_KEY"): + raise ValueError("GOOGLE_API_KEY environment variable must be set") + elif self.provider == "bedrock": + # Bedrock uses AWS credentials - boto3 will handle validation + pass + + def _create_agent(self) -> Agent: + """Create Pydantic AI agent with appropriate model.""" + self._validate_credentials() + + if self.provider == "openai": + model = OpenAIChatModel(self.model) + elif self.provider == "anthropic": + model = AnthropicModel(self.model) + elif self.provider == "gemini": + model = GoogleModel(self.model) + elif self.provider == "bedrock": + model = BedrockConverseModel(self.model) + + # Create agent with structured output + agent = Agent( + model=model, + output_type=AndroidAction, + system_prompt=self._get_system_prompt() + ) + + return agent + + def _get_system_prompt(self) -> str: + """Get the system prompt for the Android agent.""" + return """You are an Android Driver Agent. Your job is to achieve the user's goal by navigating the UI. + +You will receive: +1. The User's Goal. +2. A list of interactive UI elements (JSON) with their (x,y) center coordinates. + +You must decide the next action to take. 
+ +Available Actions: +- tap: Tap at specific coordinates +- type: Type text into a field +- home: Go to home screen +- back: Go back to previous screen +- wait: Wait for loading or animation +- done: Task is complete + +Always provide a clear reason for your action.""" + + async def get_decision(self, goal: str, screen_context: str) -> AndroidAction: + """Get LLM decision for next action.""" + prompt = f"""GOAL: {goal} + +SCREEN_CONTEXT: +{screen_context} + +What action should I take next?""" + + result = await self.agent.run(prompt) + return result.output diff --git a/requirements.txt b/requirements.txt index 06018fe..7733d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ -openai>=1.12.0 \ No newline at end of file +pydantic-ai-slim[openai,anthropic,google,bedrock] + +# Dev dependencies +pytest>=8.0.0 +pytest-asyncio>=0.24.0 diff --git a/tests/test_action_models.py b/tests/test_action_models.py new file mode 100644 index 0000000..99c6f33 --- /dev/null +++ b/tests/test_action_models.py @@ -0,0 +1,25 @@ +import pytest +from action_models import AndroidAction, TapAction, TypeAction, NavigationAction, ControlAction + +def test_tap_action_valid(): + action = TapAction(coordinates=[100, 200], reason="tap button") + assert action.action == "tap" + assert action.coordinates == [100, 200] + assert action.reason == "tap button" + +def test_tap_action_invalid_coordinates(): + with pytest.raises(ValueError): + TapAction(coordinates=[100], reason="invalid") + +def test_type_action_valid(): + action = TypeAction(text="Hello", reason="enter text") + assert action.action == "type" + assert action.text == "Hello" + +def test_navigation_action_home(): + action = NavigationAction(action="home", reason="go home") + assert action.action == "home" + +def test_control_action_done(): + action = ControlAction(action="done", reason="complete") + assert action.action == "done" diff --git a/tests/test_kernel_integration.py b/tests/test_kernel_integration.py new file mode 
100644 index 0000000..86fc05f --- /dev/null +++ b/tests/test_kernel_integration.py @@ -0,0 +1,22 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from kernel import get_llm_decision +from action_models import TapAction + +@pytest.mark.asyncio +async def test_get_llm_decision_returns_action(): + with patch.dict('os.environ', { + "LLM_PROVIDER": "openai", + "OPENAI_API_KEY": "test-key" + }): + with patch('kernel.LLMManager') as mock_manager_class: + mock_manager = MagicMock() + mock_manager.get_decision = AsyncMock( + return_value=TapAction(coordinates=[100, 200], reason="test") + ) + mock_manager_class.return_value = mock_manager + + action = await get_llm_decision("test goal", "test context") + + assert isinstance(action, TapAction) + assert action.coordinates == [100, 200] diff --git a/tests/test_llm_manager.py b/tests/test_llm_manager.py new file mode 100644 index 0000000..389c952 --- /dev/null +++ b/tests/test_llm_manager.py @@ -0,0 +1,42 @@ +import pytest +import os +from unittest.mock import patch +from llm_manager import LLMManager + +def test_init_openai_provider(): + with patch.dict(os.environ, { + "LLM_PROVIDER": "openai", + "OPENAI_API_KEY": "test-key", + "OPENAI_MODEL": "gpt-4o" + }): + manager = LLMManager() + assert manager.provider == "openai" + assert manager.model == "gpt-4o" + +def test_init_anthropic_provider(): + with patch.dict(os.environ, { + "LLM_PROVIDER": "anthropic", + "ANTHROPIC_API_KEY": "test-key", + "ANTHROPIC_MODEL": "claude-sonnet-4" + }): + manager = LLMManager() + assert manager.provider == "anthropic" + assert manager.model == "claude-sonnet-4" + +def test_missing_provider_raises_error(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="LLM_PROVIDER"): + LLMManager() + +def test_missing_api_key_raises_error(): + with patch.dict(os.environ, {"LLM_PROVIDER": "openai"}, clear=True): + with pytest.raises(ValueError, match="OPENAI_API_KEY"): + LLMManager() + +def 
test_default_models(): + with patch.dict(os.environ, { + "LLM_PROVIDER": "openai", + "OPENAI_API_KEY": "test-key" + }, clear=True): + manager = LLMManager() + assert manager.model == "gpt-4o" # default