From b0c19e3ba94ca04e32c360fa0341e7e59cfb2279 Mon Sep 17 00:00:00 2001 From: Arsh Verma Date: Wed, 8 Apr 2026 21:04:51 +0530 Subject: [PATCH 1/2] Allow /reset params via body or query; add tests Make the /reset endpoint more robust by accepting task_id and seed from either a JSON body or query parameters. ResetRequest.task_id is now optional (defaults to BUG_DETECTION). The endpoint resolves final values with precedence: default < body < query, then calls env.reset(final_task_id, final_seed). Added fastapi Body import, an info log on reset, and a test (test_api_reset_robustness) that verifies no-body, empty-body, query-only, and query-overrides-body behavior. --- app.py | 31 +++++++++++++++++++++++++++---- tests/test_api.py | 26 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index 31239a0..b9382c3 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request, Body from fastapi.responses import JSONResponse, FileResponse from fastapi.exceptions import RequestValidationError from fastapi.security.api_key import APIKeyHeader @@ -148,7 +148,7 @@ async def cleanup_expired_episodes(): # ── Models ──────────────────────────────────────────────────────────────────── class ResetRequest(BaseModel): - task_id: TaskId + task_id: Optional[TaskId] = TaskId.BUG_DETECTION seed: int = 42 class ResetResponse(BaseModel): @@ -216,12 +216,35 @@ def health_check(): @app.post("/reset", response_model=ResetResponse) @limiter.limit(f"{settings.rate_limit_per_minute}/minute") -def reset_env(request: Request, req: ResetRequest, _: None = Depends(verify_api_key)): +def reset_env( + request: Request, + req: Optional[ResetRequest] = Body(None), + task_id: Optional[TaskId] = Query(None), + seed: Optional[int] = Query(None), + _: None = Depends(verify_api_key) +): + # Determine task_id: Body (if present) > Query params > Default + final_task_id = TaskId.BUG_DETECTION + final_seed = 42 + + if req: + if req.task_id: + final_task_id = req.task_id + final_seed = req.seed + + # Query parameters override body if provided explicitly + if task_id: + final_task_id = task_id + if seed is not None: + final_seed = seed + episode_id = str(uuid.uuid4()) env = CodeLensEnv() - result = env.reset(req.task_id, req.seed) + result = env.reset(final_task_id, final_seed) episodes[episode_id] = env episode_timestamps[episode_id] = datetime.now(timezone.utc) + + logger.info(f"Reset environment: task={final_task_id.value}, seed={final_seed}, id={episode_id}") return ResetResponse(episode_id=episode_id, result=result) @app.post("/step/{episode_id}", response_model=StepResult) diff --git a/tests/test_api.py b/tests/test_api.py index 853f49a..caf1c33 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -139,3 +139,29 @@ def test_api_leaderboard_pagination(client): # Test ordering (best first) assert data["entries"][0]["score"] >= data["entries"][1]["score"] + +def test_api_reset_robustness(client): + # 1. No body at all + resp = client.post("/reset") + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 2. Empty JSON body + resp = client.post("/reset", json={}) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 3. Query params only + resp = client.post("/reset?task_id=security_audit&seed=100") + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "security_audit" + assert resp.json()["result"]["seed"] == 100 + + # 4. Query params overriding body + resp = client.post( + "/reset?task_id=architectural_review", + json={"task_id": "bug_detection", "seed": 50} + ) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "architectural_review" + assert resp.json()["result"]["seed"] == 50 From 8c875cef53ea660bf3aac486dd7f416a5aa89619 Mon Sep 17 00:00:00 2001 From: Arsh Verma Date: Wed, 8 Apr 2026 21:44:44 +0530 Subject: [PATCH 2/2] Implement Maximum Robustness for /reset to bypass body validation errors --- app.py | 27 ++++++++++++++++++--------- tests/test_api.py | 29 ++++++++++++++++------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/app.py b/app.py index b9382c3..c0afde6 100644 --- a/app.py +++ b/app.py @@ -216,23 +216,32 @@ def health_check(): @app.post("/reset", response_model=ResetResponse) @limiter.limit(f"{settings.rate_limit_per_minute}/minute") -def reset_env( +async def reset_env( request: Request, - req: Optional[ResetRequest] = Body(None), task_id: Optional[TaskId] = Query(None), seed: Optional[int] = Query(None), _: None = Depends(verify_api_key) ): - # Determine task_id: Body (if present) > Query params > Default + # Determine task_id and seed with manual fallback strategy final_task_id = TaskId.BUG_DETECTION final_seed = 42 - if req: - if req.task_id: - final_task_id = req.task_id - final_seed = req.seed + # 1. Try to extract from body manually (handles empty/malformed bodies) + try: + body = await request.json() + if body and isinstance(body, dict): + if body.get("task_id"): + try: + final_task_id = TaskId(body["task_id"]) + except ValueError: + pass + if body.get("seed") is not None: + final_seed = int(body["seed"]) + except Exception: + # Ignore body parsing errors (empty/malformed) and fall back + pass - # Query parameters override body if provided explicitly + # 2. Query parameters override body if provided explicitly if task_id: final_task_id = task_id if seed is not None: @@ -244,7 +253,7 @@ def reset_env( episodes[episode_id] = env episode_timestamps[episode_id] = datetime.now(timezone.utc) - logger.info(f"Reset environment: task={final_task_id.value}, seed={final_seed}, id={episode_id}") + logger.info(f"Reset environment (Robust): task={final_task_id.value}, seed={final_seed}, id={episode_id}") return ResetResponse(episode_id=episode_id, result=result) @app.post("/step/{episode_id}", response_model=StepResult) diff --git a/tests/test_api.py b/tests/test_api.py index caf1c33..383b906 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -68,7 +68,8 @@ def test_api_health_fields(client): def test_api_reset_invalid_task(client): resp = client.post("/reset", json={"task_id": "invalid_task", "seed": 0}) - assert resp.status_code == 422 + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" # Fallback def test_api_step_invalid_action_type(client): reset_resp = client.post("/reset", json={"task_id": "bug_detection", "seed": 0}) @@ -151,17 +152,19 @@ def test_api_reset_robustness(client): assert resp.status_code == 200 assert resp.json()["result"]["task_id"] == "bug_detection" - # 3. Query params only - resp = client.post("/reset?task_id=security_audit&seed=100") + # 3. Invalid JSON (should not trigger 422 now) + resp = client.post("/reset", content="invalid json {", headers={"Content-Type": "application/json"}) + assert resp.status_code == 200 + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 4. Plain text body (unexpected header, should still pass) + resp = client.post("/reset", content="just some text", headers={"Content-Type": "text/plain"}) assert resp.status_code == 200 - assert resp.json()["result"]["task_id"] == "security_audit" - assert resp.json()["result"]["seed"] == 100 - - # 4. Query params overriding body - resp = client.post( - "/reset?task_id=architectural_review", - json={"task_id": "bug_detection", "seed": 50} - ) + assert resp.json()["result"]["task_id"] == "bug_detection" + + # 5. Query params override + resp = client.post("/reset?task_id=security_audit&seed=100") assert resp.status_code == 200 - assert resp.json()["result"]["task_id"] == "architectural_review" - assert resp.json()["result"]["seed"] == 50 + data = resp.json() + assert data["result"]["task_id"] == "security_audit" + assert data["result"]["seed"] == 100