From 35ebe47c8361e91242fdce8fbc95513b4f665efb Mon Sep 17 00:00:00 2001 From: Akshay Ram Date: Thu, 5 Mar 2026 21:02:34 +0700 Subject: [PATCH] feat: add Swagger 2.0 support and structured HTTP error details --- CLAUDE.md | 13 +- README.md | 19 ++- api_agent/agent/graphql_agent.py | 10 +- api_agent/agent/rest_agent.py | 4 +- api_agent/graphql/client.py | 4 +- api_agent/recipe/common.py | 4 +- api_agent/rest/client.py | 12 +- api_agent/rest/schema_loader.py | 231 ++++++++++++++++++++++++++++++- api_agent/utils/http_errors.py | 47 +++++++ tests/test_graphql_client.py | 130 +++++++++++++++++ tests/test_http_errors.py | 36 +++++ tests/test_rest_client.py | 34 +++++ tests/test_rest_schema.py | 173 +++++++++++++++++++++++ 13 files changed, 696 insertions(+), 21 deletions(-) create mode 100644 api_agent/utils/http_errors.py create mode 100644 tests/test_graphql_client.py create mode 100644 tests/test_http_errors.py diff --git a/CLAUDE.md b/CLAUDE.md index e6476dd..661635d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -85,9 +85,10 @@ docker run -p 3000:3000 -e OPENAI_API_KEY="..." api-agent - **api_agent/utils/**: Shared utilities - **csv.py**: CSV conversion via DuckDB (for recipe `return_directly` output) + - **http_errors.py**: HTTP error response extraction (used by both clients) - **api_agent/graphql/**: GraphQL client (httpx) -- **api_agent/rest/**: REST client (httpx) + OpenAPI loader +- **api_agent/rest/**: REST client (httpx) + OpenAPI loader (supports OpenAPI 3.x and Swagger 2.0) - **api_agent/executor.py**: DuckDB SQL execution, table extraction, context truncation ### Context Management @@ -131,6 +132,16 @@ Query → Agent executes → Extractor LLM → Recipe stored → MCP tool `r_{na - **Templating**: GraphQL `{{param}}`, REST `{"$param": "name"}`, SQL `{{param}}` - **Config**: `ENABLE_RECIPES` (default: True), `RECIPE_CACHE_SIZE` (default: 64) +## After Code Changes + +Always run before marking task complete: +```bash +uv run ruff check --fix api_agent/ # Lint + auto-fix +uv run ruff format api_agent/ # Format +uv run ty check # Type check +uv run pytest tests/ -v # Tests +``` + ## Testing Notes Tests use pytest-asyncio. Mock httpx for HTTP calls. See `tests/test_*.py` for patterns. diff --git a/README.md b/README.md index b84de63..b1b97a2 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ That's it. Agent introspects schema, generates queries, runs SQL post-processing ## More Examples -**REST API (Petstore):** +**REST API (Petstore — OpenAPI 3.x):** ```json { "mcpServers": { @@ -71,6 +71,21 @@ That's it. Agent introspects schema, generates queries, runs SQL post-processing } ``` +**REST API (Petstore — Swagger 2.0):** +```json +{ + "mcpServers": { + "petstore": { + "url": "http://localhost:3000/mcp", + "headers": { + "X-Target-URL": "https://petstore.swagger.io/v2/swagger.json", + "X-API-Type": "rest" + } + } + } +} +``` + **Your own API with auth:** ```json { @@ -95,7 +110,7 @@ That's it. Agent introspects schema, generates queries, runs SQL post-processing | Header | Required | Description | | ---------------------- | -------- | ---------------------------------------------------------- | -| `X-Target-URL` | Yes | GraphQL endpoint OR OpenAPI spec URL | +| `X-Target-URL` | Yes | GraphQL endpoint OR OpenAPI/Swagger spec URL (3.x and 2.0) | | `X-API-Type` | Yes | `graphql` or `rest` | | `X-Target-Headers` | No | JSON auth headers, e.g. `{"Authorization": "Bearer xxx"}` | | `X-API-Name` | No | Override tool name prefix (default: auto-generated) | diff --git a/api_agent/agent/graphql_agent.py b/api_agent/agent/graphql_agent.py index 6b217cc..0f440a9 100644 --- a/api_agent/agent/graphql_agent.py +++ b/api_agent/agent/graphql_agent.py @@ -391,6 +391,11 @@ async def graphql_query(query: str, name: str = "data", return_directly: bool = indent=2, ) + if not result.get("success"): + result["hint"] = ( + "Use search_schema to find valid field names, enum values, or required args" + ) + return json.dumps(result, indent=2) return graphql_query @@ -463,7 +468,10 @@ def _create_individual_recipe_tools( tool_name = deduplicate_tool_name(s.get("tool_name", "unknown_recipe"), seen_names) params_spec = recipe.get("params", {}) docstring = build_recipe_docstring( - s["question"], recipe.get("steps", []), recipe.get("sql_steps", []), "graphql", + s["question"], + recipe.get("steps", []), + recipe.get("sql_steps", []), + "graphql", params_spec=params_spec, ) diff --git a/api_agent/agent/rest_agent.py b/api_agent/agent/rest_agent.py index 36d9e4d..c7fe400 100644 --- a/api_agent/agent/rest_agent.py +++ b/api_agent/agent/rest_agent.py @@ -538,7 +538,9 @@ def _create_individual_recipe_tools( tool_name = deduplicate_tool_name(s.get("tool_name", "unknown_recipe"), seen_names) params_spec = recipe.get("params", {}) docstring = build_recipe_docstring( - s["question"], recipe.get("steps", []), recipe.get("sql_steps", []), + s["question"], + recipe.get("steps", []), + recipe.get("sql_steps", []), params_spec=params_spec, ) diff --git a/api_agent/graphql/client.py b/api_agent/graphql/client.py index 27f9a68..27871af 100644 --- a/api_agent/graphql/client.py +++ b/api_agent/graphql/client.py @@ -6,6 +6,8 @@ import httpx +from ..utils.http_errors import build_http_error_response + logger = logging.getLogger(__name__) # Block mutations (read-only mode) @@ -57,7 +59,7 @@ async def execute_query( return {"success": False, "error": result["errors"]} return {"success": True, "data": result.get("data", {})} except httpx.HTTPStatusError as e: - return {"success": False, "error": f"HTTP {e.response.status_code}"} + return build_http_error_response(e) except Exception as e: logger.exception("GraphQL error") return {"success": False, "error": str(e)} diff --git a/api_agent/recipe/common.py b/api_agent/recipe/common.py index d03604a..d45d785 100644 --- a/api_agent/recipe/common.py +++ b/api_agent/recipe/common.py @@ -202,7 +202,9 @@ def build_recipe_docstring( if params_spec: param_lines = [] for pname, spec in params_spec.items(): - ptype = _JSON_TYPE_NAMES.get(spec.get("type", "str") if isinstance(spec, dict) else "str", "string") + ptype = _JSON_TYPE_NAMES.get( + spec.get("type", "str") if isinstance(spec, dict) else "str", "string" + ) example = spec.get("default") if isinstance(spec, dict) else None hint = f" (e.g. {example})" if example is not None else "" param_lines.append(f" {pname}: {ptype} REQUIRED{hint}") diff --git a/api_agent/rest/client.py b/api_agent/rest/client.py index c742144..a84995a 100644 --- a/api_agent/rest/client.py +++ b/api_agent/rest/client.py @@ -7,6 +7,8 @@ import httpx +from ..utils.http_errors import build_http_error_response + logger = logging.getLogger(__name__) # Unsafe HTTP methods (blocked by default) @@ -138,15 +140,7 @@ async def execute_request( return {"success": True, "data": data} except httpx.HTTPStatusError as e: - # Try to get error body - try: - error_body = e.response.json() - except Exception: - error_body = e.response.text[:500] - return { - "success": False, - "error": f"HTTP {e.response.status_code}: {error_body}", - } + return build_http_error_response(e) except Exception as e: logger.exception("REST API error") return {"success": False, "error": str(e)} diff --git a/api_agent/rest/schema_loader.py b/api_agent/rest/schema_loader.py index f76ead3..571b3fb 100644 --- a/api_agent/rest/schema_loader.py +++ b/api_agent/rest/schema_loader.py @@ -12,6 +12,220 @@ logger = logging.getLogger(__name__) +def _rewrite_swagger_ref(ref: Any) -> Any: + """Rewrite Swagger 2 refs to OpenAPI 3 refs.""" + if not isinstance(ref, str): + return ref + return ref.replace("#/definitions/", "#/components/schemas/") + + +def _rewrite_refs(value: Any) -> Any: + """Recursively rewrite refs from Swagger 2 to OpenAPI 3 paths.""" + if isinstance(value, dict): + out: dict[str, Any] = {} + for k, v in value.items(): + if k == "$ref": + out[k] = _rewrite_swagger_ref(v) + else: + out[k] = _rewrite_refs(v) + return out + if isinstance(value, list): + return [_rewrite_refs(v) for v in value] + return value + + +def _swagger_param_to_oas3(param: Any) -> dict[str, Any] | None: + """Convert non-body Swagger 2 parameter to OpenAPI 3 parameter.""" + if not isinstance(param, dict): + return None + if param.get("in") == "body": + return None + + converted = _rewrite_refs(param) + if "schema" in converted and isinstance(converted["schema"], dict): + return converted + + # Swagger 2 allows schema fields directly on parameter. + schema: dict[str, Any] = {} + for key in ["type", "format", "items", "enum", "default", "minimum", "maximum"]: + if key in converted: + schema[key] = converted[key] # already rewritten + if schema: + converted["schema"] = schema + return converted + + +def _swagger_request_body_to_oas3(parameters: list[Any]) -> tuple[dict[str, Any] | None, list[Any]]: + """Extract Swagger 2 body parameter and return OAS3 requestBody + remaining params.""" + request_body: dict[str, Any] | None = None + remaining: list[Any] = [] + for p in parameters: + if isinstance(p, dict) and p.get("in") == "body" and request_body is None: + schema = p.get("schema", {}) + if isinstance(schema, dict): + request_body = { + "required": bool(p.get("required", False)), + "content": {"application/json": {"schema": _rewrite_refs(schema)}}, + } + continue + remaining.append(p) + return request_body, remaining + + +def _swagger_responses_to_oas3(responses: Any) -> dict[str, Any]: + """Convert Swagger 2 responses shape to OpenAPI 3 responses shape.""" + if not isinstance(responses, dict): + return {} + + out: dict[str, Any] = {} + for code, resp in responses.items(): + if not isinstance(resp, dict): + continue + converted = _rewrite_refs(resp) + schema = converted.pop("schema", None) + if isinstance(schema, dict): + converted["content"] = {"application/json": {"schema": schema}} + out[str(code)] = converted + return out + + +def _swagger_security_to_oas3(security_definitions: Any) -> dict[str, Any]: + """Convert Swagger 2 securityDefinitions to OAS3 securitySchemes.""" + if not isinstance(security_definitions, dict): + return {} + + out: dict[str, Any] = {} + for name, scheme in security_definitions.items(): + if not isinstance(scheme, dict): + continue + scheme_type = scheme.get("type", "") + converted = _rewrite_refs(scheme) + if scheme_type == "basic": + converted = {"type": "http", "scheme": "basic"} + elif scheme_type == "oauth2": + flow = scheme.get("flow", "") + flows: dict[str, Any] = {} + scopes = scheme.get("scopes", {}) + if flow == "accessCode": + flows["authorizationCode"] = { + "authorizationUrl": scheme.get("authorizationUrl", ""), + "tokenUrl": scheme.get("tokenUrl", ""), + "scopes": scopes if isinstance(scopes, dict) else {}, + } + elif flow == "application": + flows["clientCredentials"] = { + "tokenUrl": scheme.get("tokenUrl", ""), + "scopes": scopes if isinstance(scopes, dict) else {}, + } + elif flow == "password": + flows["password"] = { + "tokenUrl": scheme.get("tokenUrl", ""), + "scopes": scopes if isinstance(scopes, dict) else {}, + } + else: + flows["implicit"] = { + "authorizationUrl": scheme.get("authorizationUrl", ""), + "scopes": scopes if isinstance(scopes, dict) else {}, + } + converted = {"type": "oauth2", "flows": flows} + out[name] = converted + return out + + +def _swagger_servers_from_spec(swagger_spec: dict[str, Any]) -> list[dict[str, str]]: + """Build OAS3 servers from Swagger 2 host/basePath/schemes fields.""" + host = swagger_spec.get("host", "") + base_path = swagger_spec.get("basePath", "") + if not isinstance(host, str) or not host: + return [] + if not isinstance(base_path, str): + base_path = "" + if base_path and not base_path.startswith("/"): + base_path = f"/{base_path}" + + schemes = swagger_spec.get("schemes", []) + if not isinstance(schemes, list): + schemes = [] + scheme_list = [s for s in schemes if isinstance(s, str) and s] + if not scheme_list: + scheme_list = ["https"] + return [{"url": f"{s}://{host}{base_path}"} for s in scheme_list] + + +def normalize_swagger2_to_oas3(swagger_spec: dict[str, Any]) -> dict[str, Any]: + """Normalize Swagger 2.0 spec into minimal OpenAPI 3.x structure.""" + out: dict[str, Any] = { + "openapi": "3.0.3", + "info": _rewrite_refs(swagger_spec.get("info", {})) + if isinstance(swagger_spec.get("info"), dict) + else {}, + "paths": {}, + "components": { + "schemas": _rewrite_refs(swagger_spec.get("definitions", {})) + if isinstance(swagger_spec.get("definitions"), dict) + else {}, + "securitySchemes": _swagger_security_to_oas3( + swagger_spec.get("securityDefinitions", {}) + ), + }, + } + + servers = _swagger_servers_from_spec(swagger_spec) + if servers: + out["servers"] = servers + + paths = swagger_spec.get("paths", {}) + if not isinstance(paths, dict): + paths = {} + + for path, path_item in paths.items(): + if not isinstance(path, str) or not isinstance(path_item, dict): + continue + out_path_item: dict[str, Any] = {} + + path_level_params_raw = path_item.get("parameters", []) + if not isinstance(path_level_params_raw, list): + path_level_params_raw = [] + path_level_params = [] + for p in path_level_params_raw: + converted_param = _swagger_param_to_oas3(p) + if converted_param: + path_level_params.append(converted_param) + if path_level_params: + out_path_item["parameters"] = path_level_params + + for method in ["get", "post", "put", "delete", "patch", "options", "head"]: + op = path_item.get(method) + if not isinstance(op, dict): + continue + + new_op = _rewrite_refs(op) + raw_params = op.get("parameters", []) + if not isinstance(raw_params, list): + raw_params = [] + request_body, non_body_params = _swagger_request_body_to_oas3(raw_params) + + converted_params = [] + for p in non_body_params: + converted_param = _swagger_param_to_oas3(p) + if converted_param: + converted_params.append(converted_param) + if converted_params: + new_op["parameters"] = converted_params + else: + new_op.pop("parameters", None) + + if request_body: + new_op["requestBody"] = request_body + + new_op["responses"] = _swagger_responses_to_oas3(op.get("responses", {})) + out_path_item[method] = new_op + + out["paths"][path] = out_path_item + + return out + + async def load_openapi_spec( spec_url: str, headers: dict[str, str] | None = None, @@ -46,13 +260,20 @@ async def load_openapi_spec( logger.warning("OpenAPI spec root is not an object") return {} - # Validate OpenAPI 3.x + # Validate/normalize API schema shape openapi_version = spec.get("openapi", "") - if not isinstance(openapi_version, str) or not openapi_version.startswith("3."): - logger.warning(f"Unsupported OpenAPI version: {openapi_version}, expected 3.x") - return {} + if isinstance(openapi_version, str) and openapi_version.startswith("3."): + return spec + + swagger_version = spec.get("swagger", "") + if isinstance(swagger_version, str) and swagger_version.startswith("2."): + logger.info("Detected Swagger 2.0 spec, normalizing to OpenAPI 3.0 shape") + return normalize_swagger2_to_oas3(spec) - return spec + logger.warning( + f"Unsupported API schema version. openapi={openapi_version!r}, swagger={swagger_version!r}" + ) + return {} except Exception as e: logger.exception(f"Failed to load OpenAPI spec: {e}") diff --git a/api_agent/utils/http_errors.py b/api_agent/utils/http_errors.py new file mode 100644 index 0000000..4b3d4ed --- /dev/null +++ b/api_agent/utils/http_errors.py @@ -0,0 +1,47 @@ +"""HTTP error detail extraction utilities.""" + +from typing import Any + +import httpx + + +def build_http_error_response(e: httpx.HTTPStatusError) -> dict[str, Any]: + """Build a consistent error payload from HTTPStatusError.""" + status_code = e.response.status_code if e.response is not None else 0 + out: dict[str, Any] = { + "success": False, + "error": f"HTTP {status_code}", + "status_code": status_code, + } + details = extract_http_error_details(e.response) + if details is not None: + out["details"] = details + return out + + +def extract_http_error_details(response: httpx.Response | None) -> Any | None: + """Extract useful error payload from non-2xx HTTP responses.""" + if response is None: + return None + + try: + payload = response.json() + except Exception: + payload = None + + if payload is not None: + if isinstance(payload, dict): + if "errors" in payload: + return payload["errors"] + if "error" in payload: + return payload["error"] + if "message" in payload: + return payload["message"] + return payload + + # Bound fallback text extraction to avoid loading huge payloads. + raw = response.content[:1500] if response.content else b"" + text = raw.decode("utf-8", errors="replace").strip() if raw else "" + if text: + return text[:1000] + return None diff --git a/tests/test_graphql_client.py b/tests/test_graphql_client.py new file mode 100644 index 0000000..1534c82 --- /dev/null +++ b/tests/test_graphql_client.py @@ -0,0 +1,130 @@ +"""Tests for GraphQL client behavior.""" + +import httpx +import pytest + +from api_agent.graphql.client import execute_query + + +def _mock_response(json_data=None, raise_status=None): + """Create mock response with optional error.""" + + class _Response: + def raise_for_status(self): + if raise_status: + raise raise_status + return None + + def json(self): + return json_data + + return _Response() + + +def _mock_client(response_or_fn): + """Create mock async client that returns response or calls fn.""" + + class _Client: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, endpoint, json, headers): + if callable(response_or_fn): + return response_or_fn(endpoint, json, headers) + return response_or_fn + + return _Client() + + +@pytest.mark.asyncio +async def test_execute_query_requires_endpoint(): + result = await execute_query("query { users { id } }", endpoint="") + assert result["success"] is False + assert result["error"] == "No endpoint provided" + + +@pytest.mark.asyncio +async def test_execute_query_blocks_mutations(): + result = await execute_query("mutation { createUser(name: \"x\") { id } }", endpoint="https://api") + assert result["success"] is False + assert result["error"] == "Mutations are not allowed (read-only mode)" + + +@pytest.mark.asyncio +async def test_execute_query_success_includes_variables(monkeypatch): + captured: dict = {} + + def _post(endpoint, json, headers): + captured["endpoint"] = endpoint + captured["json"] = json + captured["headers"] = headers + return _mock_response(json_data={"data": {"users": [{"id": 1}]}}) + + client = _mock_client(_post) + monkeypatch.setattr("api_agent.graphql.client.httpx.AsyncClient", lambda **_kwargs: client) + result = await execute_query( + "query GetUser($id: ID!) { user(id: $id) { id } }", + variables={"id": "u1"}, + endpoint="https://api.example.com/graphql", + headers={"Authorization": "Bearer t"}, + ) + + assert result == {"success": True, "data": {"users": [{"id": 1}]}} + assert captured["endpoint"] == "https://api.example.com/graphql" + assert captured["json"]["variables"] == {"id": "u1"} + assert captured["headers"]["Authorization"] == "Bearer t" + + +@pytest.mark.asyncio +async def test_execute_query_success_without_variables_omits_key(monkeypatch): + captured: dict = {} + + def _post(endpoint, json, headers): + captured["json"] = json + return _mock_response(json_data={"data": {"ok": True}}) + + client = _mock_client(_post) + monkeypatch.setattr("api_agent.graphql.client.httpx.AsyncClient", lambda **_kwargs: client) + result = await execute_query("query { ping }", endpoint="https://api.example.com/graphql") + assert result == {"success": True, "data": {"ok": True}} + assert "variables" not in captured["json"] + + +@pytest.mark.asyncio +async def test_execute_query_returns_graphql_errors(monkeypatch): + response = _mock_response(json_data={"errors": [{"message": "bad field"}]}) + client = _mock_client(response) + monkeypatch.setattr("api_agent.graphql.client.httpx.AsyncClient", lambda **_kwargs: client) + result = await execute_query("query { badField }", endpoint="https://api.example.com/graphql") + assert result["success"] is False + assert result["error"] == [{"message": "bad field"}] + + +@pytest.mark.asyncio +async def test_execute_query_returns_http_status_error(monkeypatch): + request = httpx.Request("POST", "https://api.example.com/graphql") + response = httpx.Response(404, request=request) + http_error = httpx.HTTPStatusError("Not found", request=request, response=response) + mock_resp = _mock_response(raise_status=http_error) + client = _mock_client(mock_resp) + + monkeypatch.setattr("api_agent.graphql.client.httpx.AsyncClient", lambda **_kwargs: client) + result = await execute_query("query { users { id } }", endpoint="https://api.example.com/graphql") + assert result["success"] is False + assert result["error"] == "HTTP 404" + assert result["status_code"] == 404 + + +@pytest.mark.asyncio +async def test_execute_query_returns_generic_exception(monkeypatch): + def _raise_error(endpoint, json, headers): + raise RuntimeError("boom") + + client = _mock_client(_raise_error) + monkeypatch.setattr("api_agent.graphql.client.httpx.AsyncClient", lambda **_kwargs: client) + result = await execute_query("query { users { id } }", endpoint="https://api.example.com/graphql") + assert result["success"] is False + assert result["error"] == "boom" diff --git a/tests/test_http_errors.py b/tests/test_http_errors.py new file mode 100644 index 0000000..80a102f --- /dev/null +++ b/tests/test_http_errors.py @@ -0,0 +1,36 @@ +"""Tests for HTTP error detail helpers.""" + +import httpx + +from api_agent.utils.http_errors import build_http_error_response, extract_http_error_details + + +def _response(status: int, *, json_body=None, text_body: str = "") -> httpx.Response: + request = httpx.Request("GET", "https://api.example.com/test") + if json_body is not None: + return httpx.Response(status, request=request, json=json_body) + return httpx.Response(status, request=request, text=text_body) + + +def test_extract_http_error_details_prefers_errors_field(): + response = _response(400, json_body={"errors": [{"message": "bad field"}], "message": "ignored"}) + assert extract_http_error_details(response) == [{"message": "bad field"}] + + +def test_extract_http_error_details_limits_text_fallback(): + response = _response(500, text_body="x" * 5000) + details = extract_http_error_details(response) + assert isinstance(details, str) + assert len(details) == 1000 + + +def test_build_http_error_response_includes_status_and_details(): + response = _response(404, json_body={"error": "missing"}) + request = response.request + exc = httpx.HTTPStatusError("not found", request=request, response=response) + + out = build_http_error_response(exc) + assert out["success"] is False + assert out["error"] == "HTTP 404" + assert out["status_code"] == 404 + assert out["details"] == "missing" diff --git a/tests/test_rest_client.py b/tests/test_rest_client.py index 0ee0a9d..6233646 100644 --- a/tests/test_rest_client.py +++ b/tests/test_rest_client.py @@ -1,5 +1,8 @@ """Tests for REST client.""" +from unittest.mock import patch + +import httpx import pytest from api_agent.rest.client import _build_url, _is_path_allowed, execute_request @@ -181,3 +184,34 @@ async def test_nested_search_pattern(self): ) # Will fail with connection error but NOT blocked assert "not allowed" not in result.get("error", "") + + @pytest.mark.asyncio + async def test_http_status_error_includes_status_code_and_details(self): + request = httpx.Request("GET", "https://api.example.com/users") + response = httpx.Response(404, request=request, json={"error": "missing"}) + + class _Resp: + def raise_for_status(self): + raise httpx.HTTPStatusError("Not found", request=request, response=response) + + class _Client: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, *_args, **_kwargs): + return _Resp() + + with patch("api_agent.rest.client.httpx.AsyncClient", return_value=_Client()): + result = await execute_request( + "GET", + "/users", + base_url="https://api.example.com", + ) + + assert result["success"] is False + assert result["error"] == "HTTP 404" + assert result["status_code"] == 404 + assert result["details"] == "missing" diff --git a/tests/test_rest_schema.py b/tests/test_rest_schema.py index bc5bcbb..58165e5 100644 --- a/tests/test_rest_schema.py +++ b/tests/test_rest_schema.py @@ -1,13 +1,25 @@ """Tests for REST/OpenAPI schema context generation.""" +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx import pytest from api_agent.rest.schema_loader import ( + _rewrite_swagger_ref, + _swagger_param_to_oas3, + _swagger_request_body_to_oas3, + _swagger_responses_to_oas3, + _swagger_security_to_oas3, + _swagger_servers_from_spec, _format_params, _format_schema, _infer_string_format, _schema_to_type, build_schema_context, + load_openapi_spec, + normalize_swagger2_to_oas3, ) @@ -372,3 +384,164 @@ def test_post_endpoint_optional_body(self): ctx = build_schema_context(spec) assert "PUT /update(body: Data)" in ctx assert "body: Data!" not in ctx # not required + + +class TestSwagger2Normalization: + def test_normalize_swagger2_basic_shapes(self): + swagger_spec = { + "swagger": "2.0", + "info": {"title": "OKR API", "version": "2.0"}, + "host": "api.example.com", + "basePath": "/v1", + "schemes": ["https"], + "paths": { + "/users/{id}": { + "parameters": [{"name": "id", "in": "path", "required": True, "type": "string"}], + "get": { + "summary": "Get user", + "responses": {"200": {"schema": {"$ref": "#/definitions/User"}}}, + }, + "post": { + "summary": "Update user", + "parameters": [ + { + "name": "body", + "in": "body", + "required": True, + "schema": {"$ref": "#/definitions/UpdateUser"}, + } + ], + "responses": {"200": {"schema": {"$ref": "#/definitions/User"}}}, + }, + } + }, + "definitions": { + "User": { + "type": "object", + "properties": {"id": {"type": "string"}}, + "required": ["id"], + }, + "UpdateUser": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + }, + "securityDefinitions": {"basicAuth": {"type": "basic"}}, + } + + normalized = normalize_swagger2_to_oas3(swagger_spec) + + assert normalized["openapi"].startswith("3.") + assert normalized["servers"] == [{"url": "https://api.example.com/v1"}] + assert "components" in normalized + assert "schemas" in normalized["components"] + assert "User" in normalized["components"]["schemas"] + assert normalized["components"]["securitySchemes"]["basicAuth"] == { + "type": "http", + "scheme": "basic", + } + + get_op = normalized["paths"]["/users/{id}"]["get"] + assert get_op["responses"]["200"]["content"]["application/json"]["schema"]["$ref"] == ( + "#/components/schemas/User" + ) + + post_op = normalized["paths"]["/users/{id}"]["post"] + assert post_op["requestBody"]["required"] is True + assert ( + post_op["requestBody"]["content"]["application/json"]["schema"]["$ref"] + == "#/components/schemas/UpdateUser" + ) + + def test_build_context_from_normalized_swagger2(self): + swagger_spec = { + "swagger": "2.0", + "paths": { + "/users/{id}": { + "get": { + "parameters": [{"name": "id", "in": "path", "required": True, "type": "string"}], + "responses": {"200": {"schema": {"$ref": "#/definitions/User"}}}, + } + } + }, + "definitions": { + "User": { + "type": "object", + "properties": {"id": {"type": "string"}}, + "required": ["id"], + } + }, + } + normalized = normalize_swagger2_to_oas3(swagger_spec) + ctx = build_schema_context(normalized) + assert "GET /users/{id}(id: str) -> User" in ctx + assert "User { id: str! }" in ctx + + +class TestSwagger2Helpers: + def test_rewrite_swagger_ref(self): + assert _rewrite_swagger_ref("#/definitions/User") == "#/components/schemas/User" + + def test_swagger_param_conversion(self): + result = _swagger_param_to_oas3({"name": "limit", "in": "query", "type": "integer"}) + assert result is not None + assert result["schema"]["type"] == "integer" + + def test_swagger_request_body_extraction(self): + body, remaining = _swagger_request_body_to_oas3( + [ + {"in": "body", "name": "data", "required": True, "schema": {"type": "object"}}, + {"in": "query", "name": "limit", "type": "integer"}, + ] + ) + assert body is not None + assert body["required"] is True + assert len(remaining) == 1 + + def test_swagger_response_conversion(self): + responses = {"200": {"description": "ok", "schema": {"$ref": "#/definitions/User"}}} + result = _swagger_responses_to_oas3(responses) + assert result["200"]["content"]["application/json"]["schema"]["$ref"] == ( + "#/components/schemas/User" + ) + + def test_swagger_security_conversion(self): + result = _swagger_security_to_oas3({"basicAuth": {"type": "basic"}}) + assert result["basicAuth"] == {"type": "http", "scheme": "basic"} + + def test_swagger_servers_from_spec(self): + result = _swagger_servers_from_spec( + {"host": "api.example.com", "basePath": "/v1", "schemes": ["https"]} + ) + assert result == [{"url": "https://api.example.com/v1"}] + + +def _mock_http_response(status: int, text: str): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = status + mock_resp.text = text + if status >= 400: + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "error", request=MagicMock(), response=mock_resp + ) + else: + mock_resp.raise_for_status.return_value = None + return mock_resp + + +def _patch_http(text: str, status: int = 200): + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=_mock_http_response(status, text)) + return patch("httpx.AsyncClient", return_value=mock_client) + + +class TestLoadOpenApiSpec: + @pytest.mark.asyncio + async def test_swagger_2_spec_normalized(self): + spec = json.dumps({"swagger": "2.0", "info": {}, "paths": {}, "host": "api.example.com"}) + with _patch_http(spec): + result = await load_openapi_spec("https://api.example.com/openapi.json") + assert result.get("openapi", "").startswith("3.")