diff --git a/app.py b/app.py index 31239a0..c0afde6 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request, Body from fastapi.responses import JSONResponse, FileResponse from fastapi.exceptions import RequestValidationError from fastapi.security.api_key import APIKeyHeader @@ -148,7 +148,7 @@ async def cleanup_expired_episodes(): # ── Models ──────────────────────────────────────────────────────────────────── class ResetRequest(BaseModel): - task_id: TaskId + task_id: Optional[TaskId] = TaskId.BUG_DETECTION seed: int = 42 class ResetResponse(BaseModel): @@ -216,12 +216,44 @@ def health_check(): @app.post("/reset", response_model=ResetResponse) @limiter.limit(f"{settings.rate_limit_per_minute}/minute") -def reset_env(request: Request, req: ResetRequest, _: None = Depends(verify_api_key)): +async def reset_env( + request: Request, + task_id: Optional[TaskId] = Query(None), + seed: Optional[int] = Query(None), + _: None = Depends(verify_api_key) +): + # Determine task_id and seed with manual fallback strategy + final_task_id = TaskId.BUG_DETECTION + final_seed = 42 + + # 1. Try to extract from body manually (handles empty/malformed bodies) + try: + body = await request.json() + if body and isinstance(body, dict): + if body.get("task_id"): + try: + final_task_id = TaskId(body["task_id"]) + except ValueError: + pass + if body.get("seed") is not None: + final_seed = int(body["seed"]) + except Exception: + # Ignore body parsing errors (empty/malformed) and fall back + pass + + # 2. Query parameters override body if provided explicitly + if task_id: + final_task_id = task_id + if seed is not None: + final_seed = seed + episode_id = str(uuid.uuid4()) env = CodeLensEnv() - result = env.reset(req.task_id, req.seed) + result = env.reset(final_task_id, final_seed) episodes[episode_id] = env episode_timestamps[episode_id] = datetime.now(timezone.utc) + + logger.info(f"Reset environment (Robust): task={final_task_id.value}, seed={final_seed}, id={episode_id}") return ResetResponse(episode_id=episode_id, result=result) @app.post("/step/{episode_id}", response_model=StepResult) diff --git a/tests/test_api.py b/tests/test_api.py index 853f49a..383b906 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -68,7 +68,8 @@ def test_api_health_fields(client): def test_api_reset_invalid_task(client): resp = client.post("/reset", json={"task_id": "invalid_task", "seed": 0}) - assert resp.status_code == 422 + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" # Fallback def test_api_step_invalid_action_type(client): reset_resp = client.post("/reset", json={"task_id": "bug_detection", "seed": 0}) @@ -139,3 +140,31 @@ def test_api_leaderboard_pagination(client): # Test ordering (best first) assert data["entries"][0]["score"] >= data["entries"][1]["score"] + +def test_api_reset_robustness(client): + # 1. No body at all + resp = client.post("/reset") + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 2. Empty JSON body + resp = client.post("/reset", json={}) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 3. Invalid JSON (should not trigger 422 now) + resp = client.post("/reset", content="invalid json {", headers={"Content-Type": "application/json"}) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 4. Plain text body (unexpected header, should still pass) + resp = client.post("/reset", content="just some text", headers={"Content-Type": "text/plain"}) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 5. Query params override + resp = client.post("/reset?task_id=security_audit&seed=100") + assert resp.status_code == 200 + data = resp.json() + assert data["result"]["task_id"] == "security_audit" + assert data["result"]["seed"] == 100