From ebc4c973bb00eeda8835c2a3142b70ce650d5170 Mon Sep 17 00:00:00 2001 From: ttlequals0 Date: Fri, 30 Jan 2026 19:52:37 -0500 Subject: [PATCH] feat: add JSON response format support and dynamic model fetching - Add response_format parameter for OpenAI-compatible JSON mode - Add ModelService for dynamic model fetching from Anthropic API - Add claude-opus-4-5-20251101 model to supported models - Add JSON extraction and enforcement methods to MessageAdapter - Update docker-compose.yml to use published image - Bump version to 2.3.0 --- docker-compose.yml | 26 ++- pyproject.toml | 2 +- src/__init__.py | 2 +- src/constants.py | 3 +- src/main.py | 125 ++++++++++--- src/message_adapter.py | 121 ++++++++++++ src/model_service.py | 141 ++++++++++++++ src/models.py | 13 ++ src/parameter_validator.py | 27 ++- tests/test_json_format_unit.py | 305 +++++++++++++++++++++++++++++++ tests/test_model_service_unit.py | 255 ++++++++++++++++++++++++++ 11 files changed, 987 insertions(+), 33 deletions(-) create mode 100644 src/model_service.py create mode 100644 tests/test_json_format_unit.py create mode 100644 tests/test_model_service_unit.py diff --git a/docker-compose.yml b/docker-compose.yml index 6d0d141..95d993d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,16 +1,34 @@ -version: '3' +version: '3.8' services: claude-wrapper: - build: . + image: ttlequals0/claude-code-openai-wrapper:latest + container_name: claude-wrapper ports: - "8000:8000" volumes: + # Mount Claude CLI credentials - ~/.claude:/root/.claude # Optional: Mount a specific workspace directory - # Uncomment and modify the line below to use a custom workspace # - ./workspace:/workspace environment: - PORT=8000 + - MAX_TIMEOUT=600000 + # Authentication (choose one method): + # Option 1: Direct API key (recommended) + # - ANTHROPIC_API_KEY=your-api-key + # Option 2: Explicit auth method selection + # - CLAUDE_AUTH_METHOD=cli # Options: cli, api_key, bedrock, vertex # Optional: Set Claude's working directory (defaults to isolated temp dir) - # Uncomment and modify the line below to set a custom working directory # - CLAUDE_CWD=/workspace + # Optional: Enable debug logging + # - DEBUG_MODE=true + # Optional: Rate limiting configuration + # - RATE_LIMIT_ENABLED=true + # - RATE_LIMIT_CHAT_PER_MINUTE=10 + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s diff --git a/pyproject.toml b/pyproject.toml index e0cc381..dcc6fe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "claude-code-openai-wrapper" -version = "2.2.0" +version = "2.3.0" description = "OpenAI API-compatible wrapper for Claude Code" authors = ["Richard Atkinson "] readme = "README.md" diff --git a/src/__init__.py b/src/__init__.py index ca47b3b..4642a13 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +1,3 @@ """Claude Code OpenAI Wrapper - A FastAPI-based OpenAI-compatible API for Claude Code.""" -__version__ = "2.2.0" +__version__ = "2.3.0" diff --git a/src/constants.py b/src/constants.py index 5fb452b..5eb4149 100644 --- a/src/constants.py +++ b/src/constants.py @@ -70,7 +70,8 @@ async def chat_endpoint(): ... # NOTE: Claude Agent SDK only supports Claude 4+ models, not Claude 3.x CLAUDE_MODELS = [ # Claude 4.5 Family (Latest - Fall 2025) - RECOMMENDED - "claude-opus-4-5-20250929", # Latest Opus 4.5 - Most capable + "claude-opus-4-5-20251101", # Latest Opus 4.5 - Most capable (November 2025) + "claude-opus-4-5-20250929", # Opus 4.5 - September version "claude-sonnet-4-5-20250929", # Recommended - best coding model "claude-haiku-4-5-20251001", # Fast & cheap # Claude 4.1 diff --git a/src/main.py b/src/main.py index 4a74aa4..eb1b286 100644 --- a/src/main.py +++ b/src/main.py @@ -52,6 +52,7 @@ rate_limit_endpoint, ) from src.constants import CLAUDE_MODELS, CLAUDE_TOOLS, DEFAULT_ALLOWED_TOOLS +from src.model_service import model_service # Load environment variables load_dotenv() @@ -133,6 +134,9 @@ async def lifespan(app: FastAPI): """Verify Claude Code authentication and CLI on startup.""" logger.info("Verifying Claude Code authentication and CLI...") + # Initialize model service (fetch models from API or use fallback) + await model_service.initialize() + # Validate authentication first auth_valid, auth_info = validate_claude_code_auth() @@ -197,6 +201,9 @@ async def lifespan(app: FastAPI): logger.info("Shutting down session manager...") session_manager.shutdown() + # Shutdown model service + await model_service.shutdown() + # Create FastAPI app app = FastAPI( @@ -410,6 +417,16 @@ async def generate_streaming_response( system_prompt = sampling_instructions logger.debug(f"Added sampling instructions: {sampling_instructions}") + # Check for JSON mode + json_mode = request.response_format and request.response_format.type == "json_object" + if json_mode: + # Prepend JSON instruction to system prompt + if system_prompt: + system_prompt = f"{MessageAdapter.JSON_MODE_INSTRUCTION}\n\n{system_prompt}" + else: + system_prompt = MessageAdapter.JSON_MODE_INSTRUCTION + logger.info("JSON mode enabled (streaming) - response will be accumulated and formatted") + # Filter content for unsupported features prompt = MessageAdapter.filter_content(prompt) if system_prompt: @@ -443,6 +460,7 @@ async def generate_streaming_response( chunks_buffer = [] role_sent = False # Track if we've sent the initial role chunk content_sent = False # Track if we've sent any content + json_mode_buffer = [] # Buffer for JSON mode - accumulate all content async for chunk in claude_cli.run_completion( prompt=prompt, @@ -501,15 +519,42 @@ async def generate_streaming_response( filtered_text = MessageAdapter.filter_content(raw_text) if filtered_text and not filtered_text.isspace(): + if json_mode: + # In JSON mode, buffer content for later processing + json_mode_buffer.append(filtered_text) + else: + # Create streaming chunk + stream_chunk = ChatCompletionStreamResponse( + id=request_id, + model=request.model, + choices=[ + StreamChoice( + index=0, + delta={"content": filtered_text}, + finish_reason=None, + ) + ], + ) + + yield f"data: {stream_chunk.model_dump_json()}\n\n" + content_sent = True + + elif isinstance(content, str): + # Filter out tool usage and thinking blocks + filtered_content = MessageAdapter.filter_content(content) + + if filtered_content and not filtered_content.isspace(): + if json_mode: + # In JSON mode, buffer content for later processing + json_mode_buffer.append(filtered_content) + else: # Create streaming chunk stream_chunk = ChatCompletionStreamResponse( id=request_id, model=request.model, choices=[ StreamChoice( - index=0, - delta={"content": filtered_text}, - finish_reason=None, + index=0, delta={"content": filtered_content}, finish_reason=None ) ], ) @@ -517,24 +562,38 @@ async def generate_streaming_response( yield f"data: {stream_chunk.model_dump_json()}\n\n" content_sent = True - elif isinstance(content, str): - # Filter out tool usage and thinking blocks - filtered_content = MessageAdapter.filter_content(content) - - if filtered_content and not filtered_content.isspace(): - # Create streaming chunk - stream_chunk = ChatCompletionStreamResponse( - id=request_id, - model=request.model, - choices=[ - StreamChoice( - index=0, delta={"content": filtered_content}, finish_reason=None - ) - ], + # Handle JSON mode: emit accumulated content as single JSON-formatted chunk + if json_mode and json_mode_buffer: + # Send role chunk first if not sent + if not role_sent: + initial_chunk = ChatCompletionStreamResponse( + id=request_id, + model=request.model, + choices=[ + StreamChoice( + index=0, delta={"role": "assistant", "content": ""}, finish_reason=None ) + ], + ) + yield f"data: {initial_chunk.model_dump_json()}\n\n" + role_sent = True - yield f"data: {stream_chunk.model_dump_json()}\n\n" - content_sent = True + # Combine buffered content and enforce JSON format + combined_content = "".join(json_mode_buffer) + json_content = MessageAdapter.enforce_json_format(combined_content, strict=True) + + # Emit as single chunk + json_chunk = ChatCompletionStreamResponse( + id=request_id, + model=request.model, + choices=[ + StreamChoice( + index=0, delta={"content": json_content}, finish_reason=None + ) + ], + ) + yield f"data: {json_chunk.model_dump_json()}\n\n" + content_sent = True # Handle case where no role was sent (send at least role chunk) if not role_sent: @@ -553,13 +612,16 @@ async def generate_streaming_response( # If we sent role but no content, send a minimal response if role_sent and not content_sent: + fallback_content = ( + "[]" if json_mode else "I'm unable to provide a response at the moment." + ) fallback_chunk = ChatCompletionStreamResponse( id=request_id, model=request.model, choices=[ StreamChoice( index=0, - delta={"content": "I'm unable to provide a response at the moment."}, + delta={"content": fallback_content}, finish_reason=None, ) ], @@ -672,6 +734,19 @@ async def chat_completions( system_prompt = sampling_instructions logger.debug(f"Added sampling instructions: {sampling_instructions}") + # Check for JSON mode + json_mode = ( + request_body.response_format + and request_body.response_format.type == "json_object" + ) + if json_mode: + # Prepend JSON instruction to system prompt + if system_prompt: + system_prompt = f"{MessageAdapter.JSON_MODE_INSTRUCTION}\n\n{system_prompt}" + else: + system_prompt = MessageAdapter.JSON_MODE_INSTRUCTION + logger.info("JSON mode enabled - response will be enforced as valid JSON") + # Filter content prompt = MessageAdapter.filter_content(prompt) if system_prompt: @@ -724,6 +799,12 @@ async def chat_completions( # Filter out tool usage and thinking blocks assistant_content = MessageAdapter.filter_content(raw_assistant_content) + # Enforce JSON format if JSON mode is enabled + if json_mode: + assistant_content = MessageAdapter.enforce_json_format( + assistant_content, strict=True + ) + # Add assistant response to session if using session mode if actual_session_id: assistant_message = Message(role="assistant", content=assistant_content) @@ -864,12 +945,12 @@ async def list_models( # Check FastAPI API key if configured await verify_api_key(request, credentials) - # Use constants for single source of truth + # Use dynamic models from model_service (fetched from API or fallback to constants) return { "object": "list", "data": [ {"id": model_id, "object": "model", "owned_by": "anthropic"} - for model_id in CLAUDE_MODELS + for model_id in model_service.get_models() ], } diff --git a/src/message_adapter.py b/src/message_adapter.py index 1c9d732..3f26661 100644 --- a/src/message_adapter.py +++ b/src/message_adapter.py @@ -1,11 +1,132 @@ from typing import List, Optional, Dict, Any from src.models import Message import re +import json class MessageAdapter: """Converts between OpenAI message format and Claude Code prompts.""" + # Instruction to prepend to system prompt for JSON mode + JSON_MODE_INSTRUCTION = ( + "CRITICAL: Respond with ONLY valid JSON. " + "No explanations, no markdown, no code blocks. " + "Start with [ or { and end with ] or }." + ) + + @staticmethod + def extract_json(content: str) -> Optional[str]: + """ + Extract JSON from content. + + Handles: + 1. Pure JSON (content is already valid JSON) + 2. Markdown code blocks (```json ... ```) + 3. Embedded JSON (JSON within other text) + + Args: + content: The content to extract JSON from + + Returns: + Extracted JSON string, or None if no valid JSON found + """ + if not content: + return None + + content = content.strip() + + # Case 1: Try parsing as pure JSON first + try: + json.loads(content) + return content + except json.JSONDecodeError: + pass + + # Case 2: Extract from markdown code blocks + # Match ```json ... ``` or ``` ... ``` + code_block_patterns = [ + r"```json\s*([\s\S]*?)\s*```", # ```json block + r"```\s*([\s\S]*?)\s*```", # generic ``` block + ] + + for pattern in code_block_patterns: + matches = re.findall(pattern, content, re.IGNORECASE) + for match in matches: + match = match.strip() + try: + json.loads(match) + return match + except json.JSONDecodeError: + continue + + # Case 3: Find embedded JSON (objects or arrays) + # Look for JSON objects {...} + object_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}" + for match in re.finditer(object_pattern, content): + candidate = match.group() + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + continue + + # Look for JSON arrays [...] + array_pattern = r"\[[^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*\]" + for match in re.finditer(array_pattern, content): + candidate = match.group() + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + continue + + # Try more aggressive nested JSON extraction for complex objects + # Find the first { and match to the last } + first_brace = content.find("{") + last_brace = content.rfind("}") + if first_brace != -1 and last_brace > first_brace: + candidate = content[first_brace : last_brace + 1] + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + pass + + # Try for arrays + first_bracket = content.find("[") + last_bracket = content.rfind("]") + if first_bracket != -1 and last_bracket > first_bracket: + candidate = content[first_bracket : last_bracket + 1] + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + pass + + return None + + @staticmethod + def enforce_json_format(content: str, strict: bool = False) -> str: + """ + Enforce JSON format on content. + + Args: + content: The content to enforce JSON format on + strict: If True, return '[]' on failure. If False, return original content. + + Returns: + Valid JSON string, or fallback value based on strict mode + """ + extracted = MessageAdapter.extract_json(content) + + if extracted: + return extracted + + if strict: + return "[]" + + return content + @staticmethod def messages_to_prompt(messages: List[Message]) -> tuple[str, Optional[str]]: """ diff --git a/src/model_service.py b/src/model_service.py new file mode 100644 index 0000000..7254937 --- /dev/null +++ b/src/model_service.py @@ -0,0 +1,141 @@ +""" +Model service for dynamically fetching available models from Anthropic API. + +This service provides: +- Dynamic model discovery from Anthropic API on startup +- Graceful fallback to static CLAUDE_MODELS when API is unavailable +- Caching of fetched models for the session lifetime +""" + +import os +import logging +from typing import List, Optional + +import httpx + +from src.constants import CLAUDE_MODELS + +logger = logging.getLogger(__name__) + +# Anthropic API configuration +ANTHROPIC_API_BASE = "https://api.anthropic.com" +ANTHROPIC_API_VERSION = "2023-06-01" +MODEL_FETCH_TIMEOUT = 10.0 # seconds + + +class ModelService: + """Fetches models from Anthropic API with fallback to constants.""" + + def __init__(self): + self._cached_models: Optional[List[str]] = None + self._http_client: Optional[httpx.AsyncClient] = None + self._initialized: bool = False + + async def initialize(self) -> None: + """Called during app startup - fetch models from API.""" + if self._initialized: + return + + self._http_client = httpx.AsyncClient(timeout=MODEL_FETCH_TIMEOUT) + + # Attempt to fetch models from API + fetched_models = await self.fetch_models_from_api() + + if fetched_models: + self._cached_models = fetched_models + logger.info(f"Successfully fetched {len(fetched_models)} models from Anthropic API") + else: + self._cached_models = None + logger.info("Using fallback static model list from constants") + + self._initialized = True + + async def shutdown(self) -> None: + """Close HTTP client on app shutdown.""" + if self._http_client: + await self._http_client.aclose() + self._http_client = None + self._cached_models = None + self._initialized = False + + async def fetch_models_from_api(self) -> Optional[List[str]]: + """ + Fetch models from Anthropic API. + + GET https://api.anthropic.com/v1/models + Headers: + - x-api-key: {ANTHROPIC_API_KEY} + - anthropic-version: 2023-06-01 + + Returns list of model IDs on success, None on failure. + """ + api_key = os.getenv("ANTHROPIC_API_KEY") + + if not api_key: + logger.debug("ANTHROPIC_API_KEY not set, skipping API model fetch") + return None + + if not self._http_client: + self._http_client = httpx.AsyncClient(timeout=MODEL_FETCH_TIMEOUT) + + try: + response = await self._http_client.get( + f"{ANTHROPIC_API_BASE}/v1/models", + headers={ + "x-api-key": api_key, + "anthropic-version": ANTHROPIC_API_VERSION, + }, + ) + + if response.status_code == 200: + data = response.json() + # Extract model IDs from the response + # API returns {"data": [{"id": "claude-...", ...}, ...]} + models = [] + for model_data in data.get("data", []): + model_id = model_data.get("id") + if model_id: + models.append(model_id) + + if models: + logger.debug(f"Fetched models from API: {models}") + return models + else: + logger.warning("API returned empty model list") + return None + + elif response.status_code == 401: + logger.warning("Anthropic API authentication failed (401). Check ANTHROPIC_API_KEY.") + return None + elif response.status_code == 429: + logger.warning("Anthropic API rate limited (429). Using fallback models.") + return None + else: + logger.warning( + f"Anthropic API returned status {response.status_code}. Using fallback models." + ) + return None + + except httpx.TimeoutException: + logger.warning(f"Anthropic API request timed out after {MODEL_FETCH_TIMEOUT}s") + return None + except httpx.RequestError as e: + logger.warning(f"Network error fetching models from Anthropic API: {e}") + return None + except Exception as e: + logger.warning(f"Unexpected error fetching models: {e}") + return None + + def get_models(self) -> List[str]: + """Return cached models or CLAUDE_MODELS fallback.""" + if self._cached_models: + return self._cached_models + return list(CLAUDE_MODELS) + + def is_initialized(self) -> bool: + """Check if service has been initialized.""" + return self._initialized + + +# Global singleton instance +model_service = ModelService() diff --git a/src/models.py b/src/models.py index 82e85f4..b513f2e 100644 --- a/src/models.py +++ b/src/models.py @@ -53,6 +53,15 @@ class StreamOptions(BaseModel): ) +class ResponseFormat(BaseModel): + """OpenAI-compatible response format specification.""" + + type: Literal["text", "json_object"] = Field( + default="text", + description="Response format type - 'text' for regular text, 'json_object' for JSON mode", + ) + + class ChatCompletionRequest(BaseModel): model: str = Field(default_factory=get_default_model) messages: List[Message] @@ -79,6 +88,10 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = Field( default=None, description="Options for streaming responses" ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description="Response format - use {'type': 'json_object'} for JSON mode", + ) @field_validator("n") @classmethod diff --git a/src/parameter_validator.py b/src/parameter_validator.py index e45452f..2bf1b70 100644 --- a/src/parameter_validator.py +++ b/src/parameter_validator.py @@ -3,17 +3,33 @@ """ import logging -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Set from src.models import ChatCompletionRequest from src.constants import CLAUDE_MODELS logger = logging.getLogger(__name__) +def get_supported_models() -> Set[str]: + """Get supported models from model_service or fallback to constants.""" + try: + from src.model_service import model_service + + return set(model_service.get_models()) + except ImportError: + return set(CLAUDE_MODELS) + + class ParameterValidator: """Validates and maps OpenAI Chat Completions parameters to Claude Code SDK options.""" - # Use models from constants (single source of truth) + @classmethod + def get_supported_models(cls) -> Set[str]: + """Get currently supported models (dynamic or fallback).""" + return get_supported_models() + + # Legacy class attribute for backwards compatibility + # Use get_supported_models() method for dynamic models SUPPORTED_MODELS = set(CLAUDE_MODELS) # Valid permission modes for Claude Code SDK @@ -22,9 +38,10 @@ class ParameterValidator: @classmethod def validate_model(cls, model: str) -> bool: """Validate that the model is supported by Claude Code SDK.""" - if model not in cls.SUPPORTED_MODELS: + supported = cls.get_supported_models() + if model not in supported: logger.warning( - f"Model '{model}' is not in the known supported models list. It will still be attempted but may fail. Supported models: {sorted(cls.SUPPORTED_MODELS)}" + f"Model '{model}' is not in the known supported models list. It will still be attempted but may fail. Supported models: {sorted(supported)}" ) # Return True anyway to allow graceful degradation return True @@ -164,6 +181,8 @@ def generate_compatibility_report(cls, request: ChatCompletionRequest) -> Dict[s report["supported_parameters"].append("stream") if request.user: report["supported_parameters"].append("user (for logging)") + if request.response_format: + report["supported_parameters"].append("response_format") # Check unsupported parameters with suggestions if request.temperature != 1.0: diff --git a/tests/test_json_format_unit.py b/tests/test_json_format_unit.py new file mode 100644 index 0000000..102db4d --- /dev/null +++ b/tests/test_json_format_unit.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +""" +Unit tests for JSON format functionality. + +Tests the JSON extraction and enforcement methods in MessageAdapter, +as well as the ResponseFormat model. +""" + +import pytest + +from src.message_adapter import MessageAdapter +from src.models import ResponseFormat, ChatCompletionRequest, Message + + +class TestExtractJson: + """Test MessageAdapter.extract_json() method.""" + + def test_extract_json_pure(self): + """Pure JSON content is returned as-is.""" + content = '{"name": "test", "value": 123}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_extract_json_pure_array(self): + """Pure JSON array is returned as-is.""" + content = '[1, 2, 3, 4, 5]' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_extract_json_pure_with_whitespace(self): + """Pure JSON with surrounding whitespace is extracted.""" + content = ' \n{"key": "value"}\n ' + result = MessageAdapter.extract_json(content) + assert result == '{"key": "value"}' + + def test_extract_json_markdown_block(self): + """Extracts JSON from ```json code block.""" + content = '''Here is the data: +```json +{"items": [1, 2, 3]} +``` +That's all!''' + result = MessageAdapter.extract_json(content) + assert result == '{"items": [1, 2, 3]}' + + def test_extract_json_generic_code_block(self): + """Extracts JSON from generic ``` code block.""" + content = '''Response: +``` +{"status": "ok"} +```''' + result = MessageAdapter.extract_json(content) + assert result == '{"status": "ok"}' + + def test_extract_json_embedded_object(self): + """Finds JSON object embedded in text.""" + content = 'The result is {"success": true, "count": 42} as expected.' + result = MessageAdapter.extract_json(content) + assert result == '{"success": true, "count": 42}' + + def test_extract_json_embedded_array(self): + """Finds JSON array embedded in text.""" + content = 'Available items: [1, 2, 3] are ready.' + result = MessageAdapter.extract_json(content) + assert result == '[1, 2, 3]' + + def test_extract_json_nested_object(self): + """Extracts nested JSON objects.""" + content = '''Result: {"outer": {"inner": {"deep": "value"}}}''' + result = MessageAdapter.extract_json(content) + assert result is not None + assert '"deep": "value"' in result + + def test_extract_json_complex_array(self): + """Extracts complex JSON arrays.""" + content = '''Data: [{"id": 1}, {"id": 2}]''' + result = MessageAdapter.extract_json(content) + assert result is not None + assert '"id": 1' in result + + def test_extract_json_no_json(self): + """Returns None when no valid JSON found.""" + content = 'This is just plain text with no JSON.' + result = MessageAdapter.extract_json(content) + assert result is None + + def test_extract_json_invalid_json(self): + """Returns None for malformed JSON.""" + content = '{"broken: json' + result = MessageAdapter.extract_json(content) + assert result is None + + def test_extract_json_empty_string(self): + """Returns None for empty string.""" + result = MessageAdapter.extract_json('') + assert result is None + + def test_extract_json_none_input(self): + """Returns None for None input.""" + result = MessageAdapter.extract_json(None) + assert result is None + + def test_extract_json_prefers_code_block(self): + """Prefers code block JSON over embedded JSON.""" + content = '''Text {"wrong": "json"} +```json +{"correct": "json"} +```''' + result = MessageAdapter.extract_json(content) + assert result == '{"correct": "json"}' + + def test_extract_json_multiline(self): + """Extracts multiline JSON from code block.""" + content = '''```json +{ + "name": "test", + "items": [ + 1, + 2, + 3 + ] +} +```''' + result = MessageAdapter.extract_json(content) + assert result is not None + assert '"name": "test"' in result + assert '"items"' in result + + +class TestEnforceJsonFormat: + """Test MessageAdapter.enforce_json_format() method.""" + + def test_enforce_json_valid_object(self): + """Valid JSON object passes through.""" + content = '{"key": "value"}' + result = MessageAdapter.enforce_json_format(content) + assert result == content + + def test_enforce_json_valid_array(self): + """Valid JSON array passes through.""" + content = '[1, 2, 3]' + result = MessageAdapter.enforce_json_format(content) + assert result == content + + def test_enforce_json_extracts_from_text(self): + """Extracts JSON from surrounding text.""" + content = 'Here is the result: {"data": 123}' + result = MessageAdapter.enforce_json_format(content) + assert result == '{"data": 123}' + + def test_enforce_json_strict_fallback(self): + """Returns '[]' on failure in strict mode.""" + content = 'No JSON here at all!' + result = MessageAdapter.enforce_json_format(content, strict=True) + assert result == '[]' + + def test_enforce_json_non_strict_returns_original(self): + """Returns original content on failure in non-strict mode.""" + content = 'No JSON here at all!' + result = MessageAdapter.enforce_json_format(content, strict=False) + assert result == content + + def test_enforce_json_from_markdown(self): + """Extracts JSON from markdown code block.""" + content = '''```json +{"extracted": true} +```''' + result = MessageAdapter.enforce_json_format(content) + assert result == '{"extracted": true}' + + def test_enforce_json_empty_strict(self): + """Empty input returns '[]' in strict mode.""" + result = MessageAdapter.enforce_json_format('', strict=True) + assert result == '[]' + + +class TestResponseFormatModel: + """Test ResponseFormat Pydantic model.""" + + def test_response_format_default_text(self): + """Default type is 'text'.""" + rf = ResponseFormat() + assert rf.type == "text" + + def test_response_format_text_explicit(self): + """Can explicitly set type to 'text'.""" + rf = ResponseFormat(type="text") + assert rf.type == "text" + + def test_response_format_json_object(self): + """Can set type to 'json_object'.""" + rf = ResponseFormat(type="json_object") + assert rf.type == "json_object" + + def test_response_format_invalid_type(self): + """Invalid type raises validation error.""" + with pytest.raises(ValueError): + ResponseFormat(type="invalid") + + def test_response_format_in_request(self): + """ResponseFormat can be used in ChatCompletionRequest.""" + request = ChatCompletionRequest( + messages=[Message(role="user", content="Return JSON")], + response_format=ResponseFormat(type="json_object"), + ) + assert request.response_format is not None + assert request.response_format.type == "json_object" + + def test_response_format_none_in_request(self): + """ResponseFormat can be None in ChatCompletionRequest.""" + request = ChatCompletionRequest( + messages=[Message(role="user", content="Hello")], + ) + assert request.response_format is None + + def test_response_format_dict_input(self): + """ResponseFormat accepts dict input (OpenAI client style).""" + request = ChatCompletionRequest( + messages=[Message(role="user", content="Return JSON")], + response_format={"type": "json_object"}, + ) + assert request.response_format.type == "json_object" + + +class TestJsonModeInstruction: + """Test JSON_MODE_INSTRUCTION constant.""" + + def test_json_mode_instruction_exists(self): + """JSON_MODE_INSTRUCTION constant exists.""" + assert hasattr(MessageAdapter, "JSON_MODE_INSTRUCTION") + + def test_json_mode_instruction_not_empty(self): + """JSON_MODE_INSTRUCTION is not empty.""" + assert len(MessageAdapter.JSON_MODE_INSTRUCTION) > 0 + + def test_json_mode_instruction_mentions_json(self): + """JSON_MODE_INSTRUCTION mentions JSON.""" + assert "JSON" in MessageAdapter.JSON_MODE_INSTRUCTION.upper() + + def test_json_mode_instruction_is_string(self): + """JSON_MODE_INSTRUCTION is a string.""" + assert isinstance(MessageAdapter.JSON_MODE_INSTRUCTION, str) + + +class TestJsonExtractionEdgeCases: + """Test edge cases for JSON extraction.""" + + def test_json_with_escaped_quotes(self): + """Handles JSON with escaped quotes.""" + content = '{"message": "He said \\"hello\\""}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_json_with_unicode(self): + """Handles JSON with unicode characters.""" + content = '{"emoji": "\\u2764", "text": "hello"}' + result = MessageAdapter.extract_json(content) + assert result is not None + + def test_json_boolean_values(self): + """Handles JSON boolean values.""" + content = '{"active": true, "deleted": false}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_json_null_value(self): + """Handles JSON null value.""" + content = '{"data": null}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_json_number_types(self): + """Handles various JSON number types.""" + content = '{"int": 42, "float": 3.14, "negative": -10, "exp": 1e5}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_deeply_nested_json(self): + """Handles deeply nested JSON.""" + content = '{"a": {"b": {"c": {"d": {"e": 1}}}}}' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_json_array_of_objects(self): + """Handles array of objects.""" + content = '[{"id": 1}, {"id": 2}, {"id": 3}]' + result = MessageAdapter.extract_json(content) + assert result == content + + def test_multiple_json_blocks_returns_first_valid(self): + """When multiple code blocks exist, returns valid JSON from first.""" + content = '''```json +{"first": true} +``` +```json +{"second": true} +```''' + result = MessageAdapter.extract_json(content) + assert result == '{"first": true}' + + def test_json_with_newlines(self): + """Handles JSON with embedded newlines.""" + content = '{"text": "line1\\nline2"}' + result = MessageAdapter.extract_json(content) + assert result == content diff --git a/tests/test_model_service_unit.py b/tests/test_model_service_unit.py new file mode 100644 index 0000000..54ee3b7 --- /dev/null +++ b/tests/test_model_service_unit.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +""" +Unit tests for src/model_service.py + +Tests the ModelService class that fetches models from Anthropic API +with graceful fallback to static constants. +""" + +import pytest +from unittest.mock import patch, AsyncMock, MagicMock +import httpx + +from src.model_service import ModelService, MODEL_FETCH_TIMEOUT +from src.constants import CLAUDE_MODELS + + +class TestModelService: + """Test ModelService class.""" + + @pytest.fixture + def model_service(self): + """Create a fresh ModelService instance for each test.""" + return ModelService() + + @pytest.mark.asyncio + async def test_fetch_models_success(self, model_service): + """Successfully fetches models from API.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet"}, + {"id": "claude-haiku-4-5-20251001", "name": "Claude Haiku"}, + ] + } + + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock(return_value=mock_response) + + result = await model_service.fetch_models_from_api() + + assert result is not None + assert len(result) == 2 + assert "claude-sonnet-4-5-20250929" in result + assert "claude-haiku-4-5-20251001" in result + + @pytest.mark.asyncio + async def test_fetch_models_timeout(self, model_service): + """Returns None on timeout, allowing fallback to constants.""" + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock(side_effect=httpx.TimeoutException("timeout")) + + result = await model_service.fetch_models_from_api() + + assert result is None + + @pytest.mark.asyncio + async def test_fetch_models_auth_error(self, model_service): + """Returns None on 401 auth error, allowing fallback.""" + mock_response = MagicMock() + mock_response.status_code = 401 + + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "invalid-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock(return_value=mock_response) + + result = await model_service.fetch_models_from_api() + + assert result is None + + @pytest.mark.asyncio + async def test_fetch_models_rate_limited(self, model_service): + """Returns None on 429 rate limit, allowing fallback.""" + mock_response = MagicMock() + mock_response.status_code = 429 + + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock(return_value=mock_response) + + result = await model_service.fetch_models_from_api() + + assert result is None + + @pytest.mark.asyncio + async def test_fetch_models_network_error(self, model_service): + """Returns None on network error, allowing fallback.""" + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock( + side_effect=httpx.RequestError("connection failed") + ) + + result = await model_service.fetch_models_from_api() + + assert result is None + + @pytest.mark.asyncio + async def test_fetch_models_no_api_key(self, model_service): + """Returns None when no API key is set.""" + with patch.dict("os.environ", {}, clear=True): + # Ensure ANTHROPIC_API_KEY is not set + import os + if "ANTHROPIC_API_KEY" in os.environ: + del os.environ["ANTHROPIC_API_KEY"] + + result = await model_service.fetch_models_from_api() + + assert result is None + + @pytest.mark.asyncio + async def test_fetch_models_empty_response(self, model_service): + """Returns None when API returns empty model list.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": []} + + with patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test-key"}): + with patch.object(model_service, "_http_client") as mock_client: + mock_client.get = AsyncMock(return_value=mock_response) + + result = await model_service.fetch_models_from_api() + + assert result is None + + def test_get_models_returns_cached(self, model_service): + """Returns cached models when available.""" + model_service._cached_models = ["model-a", "model-b", "model-c"] + + result = model_service.get_models() + + assert result == ["model-a", "model-b", "model-c"] + + def test_get_models_returns_fallback(self, model_service): + """Returns CLAUDE_MODELS fallback when no cached models.""" + model_service._cached_models = None + + result = model_service.get_models() + + assert result == list(CLAUDE_MODELS) + + def test_get_models_returns_fallback_empty_cache(self, model_service): + """Returns CLAUDE_MODELS fallback when cache is empty list.""" + # Empty list is falsy, so should fall back + model_service._cached_models = [] + + result = model_service.get_models() + + # Empty list is falsy, so fallback is used + assert result == list(CLAUDE_MODELS) + + def test_is_initialized_false_by_default(self, model_service): + """Service is not initialized by default.""" + assert model_service.is_initialized() is False + + @pytest.mark.asyncio + async def test_initialize_sets_initialized(self, model_service): + """Initialize sets initialized flag.""" + with patch.object(model_service, "fetch_models_from_api", new_callable=AsyncMock) as mock: + mock.return_value = None + + await model_service.initialize() + + assert model_service.is_initialized() is True + + @pytest.mark.asyncio + async def test_initialize_caches_fetched_models(self, model_service): + """Initialize caches successfully fetched models.""" + fetched = ["claude-3-opus", "claude-3-sonnet"] + + with patch.object(model_service, "fetch_models_from_api", new_callable=AsyncMock) as mock: + mock.return_value = fetched + + await model_service.initialize() + + assert model_service._cached_models == fetched + + @pytest.mark.asyncio + async def test_initialize_only_once(self, model_service): + """Initialize only fetches models once.""" + with patch.object(model_service, "fetch_models_from_api", new_callable=AsyncMock) as mock: + mock.return_value = ["model-1"] + + await model_service.initialize() + await model_service.initialize() # Second call should be no-op + + mock.assert_called_once() + + @pytest.mark.asyncio + async def test_shutdown_closes_client(self, model_service): + """Shutdown closes the HTTP client.""" + mock_client = AsyncMock() + model_service._http_client = mock_client + model_service._initialized = True + + await model_service.shutdown() + + mock_client.aclose.assert_called_once() + assert model_service._http_client is None + assert model_service._initialized is False + + @pytest.mark.asyncio + async def test_shutdown_safe_when_not_initialized(self, model_service): + """Shutdown is safe when called before initialization.""" + # Should not raise + await model_service.shutdown() + + assert model_service._http_client is None + + +class TestModelServiceIntegration: + """Integration-style tests for ModelService.""" + + @pytest.mark.asyncio + async def test_full_lifecycle(self): + """Test full initialize-use-shutdown lifecycle.""" + service = ModelService() + + # Mock the API call + with patch.object(service, "fetch_models_from_api", new_callable=AsyncMock) as mock: + mock.return_value = ["test-model-1", "test-model-2"] + + # Initialize + await service.initialize() + assert service.is_initialized() + + # Use + models = service.get_models() + assert models == ["test-model-1", "test-model-2"] + + # Shutdown + await service.shutdown() + assert not service.is_initialized() + + # After shutdown, should return fallback + models = service.get_models() + assert models == list(CLAUDE_MODELS) + + @pytest.mark.asyncio + async def test_fallback_on_api_failure(self): + """Test that API failure results in fallback models.""" + service = ModelService() + + # Mock API failure + with patch.object(service, "fetch_models_from_api", new_callable=AsyncMock) as mock: + mock.return_value = None # API failed + + await service.initialize() + + models = service.get_models() + assert models == list(CLAUDE_MODELS) + + await service.shutdown()