diff --git a/pyproject.toml b/pyproject.toml index ac782a1aa..db0060405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ sdk = [ "jinja2>=3.1.0", "mcp>=1.15.0", "openai>=1.98", + "litellm>=1.55.0,<1.85", "prometheus-client>=0.16.0", "structlog>=23.0.0", ] diff --git a/sdk/src/openagents/config/llm_configs.py b/sdk/src/openagents/config/llm_configs.py index 429d8d91e..27c26f747 100644 --- a/sdk/src/openagents/config/llm_configs.py +++ b/sdk/src/openagents/config/llm_configs.py @@ -28,6 +28,7 @@ class LLMProviderType(str, Enum): GROQ = "groq" OPENROUTER = "openrouter" MINIMAX = "minimax" + LITELLM = "litellm" CUSTOM = "custom" # Custom OpenAI-compatible endpoint OPENAI_COMPATIBLE = "openai-compatible" # Alias for custom @@ -195,6 +196,12 @@ class LLMProviderType(str, Enum): ], "API_KEY_ENV_VAR": "MINIMAX_API_KEY", }, + # LiteLLM (unified SDK for 100+ providers) + "litellm": { + "provider": "litellm", + "models": [], # User specifies model with provider prefix (e.g., "anthropic/claude-sonnet-4-5") + "API_KEY_ENV_VAR": None, # LiteLLM reads provider-specific env vars automatically + }, # Custom OpenAI-compatible endpoint (e.g., Ollama, vLLM, local models) "custom": { "provider": "generic", @@ -439,6 +446,7 @@ def create_model_provider( AnthropicProvider, BedrockProvider, GeminiProvider, + LiteLLMProvider, MiniMaxProvider, SimpleGenericProvider, ) @@ -464,6 +472,8 @@ def create_model_provider( return MiniMaxProvider( model_name=model_name, api_base=effective_api_base, api_key=api_key, **kwargs ) + elif provider == "litellm": + return LiteLLMProvider(model_name=model_name, **kwargs) elif provider == "custom" or provider == "openai-compatible": # Custom OpenAI-compatible endpoint requires api_base if not api_base: diff --git a/sdk/src/openagents/lms/__init__.py b/sdk/src/openagents/lms/__init__.py index 68530e454..3b7c5e987 100644 --- a/sdk/src/openagents/lms/__init__.py +++ b/sdk/src/openagents/lms/__init__.py @@ -10,6 +10,7 @@ AnthropicProvider, 
class LiteLLMProvider(BaseModelProvider):
    """LiteLLM provider supporting 100+ LLM providers through a unified interface.

    LiteLLM routes requests to the correct provider based on the model string.
    For example, ``anthropic/claude-sonnet-4-5`` routes to Anthropic,
    ``bedrock/anthropic.claude-v2`` routes to AWS Bedrock, and
    ``vertex_ai/gemini-pro`` routes to Google Vertex AI.

    Authentication is handled via provider-specific environment variables
    (e.g. ``ANTHROPIC_API_KEY``, ``OPENAI_API_KEY``, ``AWS_ACCESS_KEY_ID``).
    LiteLLM reads these automatically based on the model prefix.

    For a full list of supported providers, see: https://docs.litellm.ai/docs/providers
    """

    def __init__(self, model_name: str, **kwargs):
        """Initialize the provider.

        Args:
            model_name: LiteLLM model string, including the provider prefix
                (e.g. ``"anthropic/claude-sonnet-4-5"``).
            **kwargs: Extra per-call settings forwarded verbatim to
                ``litellm.acompletion`` on every request (e.g. ``api_base``,
                ``api_key``, ``temperature``). Previously these were silently
                discarded even though the factory passes them through.
        """
        self.model_name = model_name
        # Keep the pass-through settings so create_model_provider(**kwargs)
        # actually takes effect instead of being dropped.
        self.extra_kwargs: Dict[str, Any] = dict(kwargs)

    async def chat_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Generate a chat completion using the LiteLLM SDK.

        Args:
            messages: OpenAI-style message dicts.
            tools: Optional tool schemas as produced by :meth:`format_tools`.

        Returns:
            Dict with ``content``, ``tool_calls`` (list of dicts with ``id``,
            ``name``, ``arguments``), and ``usage`` (token counts) when the
            backend reports usage.
        """
        request: Dict[str, Any] = {"model": self.model_name, "messages": messages}
        # Apply constructor-supplied settings first so explicit request fields
        # built below are never clobbered by them.
        request.update(self.extra_kwargs)

        if tools:
            # NOTE(review): format_tools() returns tool.to_openai_function();
            # if that already yields {"type": "function", "function": {...}}
            # this wraps twice — confirm to_openai_function's return shape.
            request["tools"] = [
                {"type": "function", "function": tool} for tool in tools
            ]
            request["tool_choice"] = "auto"

        response = await litellm.acompletion(**request)

        # Normalize to the project-wide provider response shape.
        message = response.choices[0].message
        result: Dict[str, Any] = {"content": message.content, "tool_calls": []}

        # tool_calls may be absent or None depending on the routed backend.
        for tool_call in getattr(message, "tool_calls", None) or []:
            result["tool_calls"].append(
                {
                    "id": tool_call.id,
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                }
            )

        # Token usage is optional in LiteLLM responses; surface it when present.
        usage = getattr(response, "usage", None)
        if usage:
            result["usage"] = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }

        return result

    def format_tools(self, tools: List[Any]) -> List[Dict[str, Any]]:
        """Format tools for LiteLLM (OpenAI-compatible function schema)."""
        return [tool.to_openai_function() for tool in tools]