Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Tests for the FastAPI task endpoints in apps/api/main.py.

Covers regressions for:
* CR-03 — DELETE /tasks/{id} cancels the in-flight asyncio task and
removes the dict entry without leaving the background coroutine to
KeyError on resume.
* WR-07 — _tasks OrderedDict is bounded by _MAX_TASKS via LRU eviction.
"""

from __future__ import annotations

import asyncio

import fakeredis
import fakeredis.aioredis
import pytest
from fastapi import HTTPException

from apps.api.main import (
TaskRequest,
_task_handles,
_tasks,
delete_task,
get_task,
list_tasks,
submit_task,
)


@pytest.fixture(autouse=True)
def mock_redis(monkeypatch):
    """Point the sync and async ``from_url`` factories at one in-memory fakeredis server.

    Every client handed out shares the same ``FakeServer``, so sync and async
    callers within a test see the same data.
    """
    shared_server = fakeredis.FakeServer()

    def _sync_client(*args, **kwargs):
        # Arguments (URL, options) are accepted and ignored.
        return fakeredis.FakeRedis(server=shared_server, decode_responses=True)

    def _async_client(*args, **kwargs):
        return fakeredis.aioredis.FakeRedis(server=shared_server, decode_responses=True)

    monkeypatch.setattr("redis.from_url", _sync_client)
    monkeypatch.setattr("redis.asyncio.from_url", _async_client)


@pytest.fixture(autouse=True)
def clean_task_state():
    """Empty the module-level task registries around each test.

    On teardown, any handle that has not finished is cancelled before the
    registries are cleared again.
    """

    def _reset():
        _tasks.clear()
        _task_handles.clear()

    _reset()
    yield
    unfinished = [handle for handle in _task_handles.values() if not handle.done()]
    for handle in unfinished:
        handle.cancel()
    _reset()


class _GatedOrchestrator:
    """Orchestrator stand-in whose ``execute()`` parks on an event owned by the test.

    This makes the in-flight cancellation path reproducible: the test decides
    when ``execute()`` may finish, so nothing depends on timing or on real
    LLM/Qdrant calls.
    """

    def __init__(self) -> None:
        # Set by execute() on entry, so a test can wait until the call is in flight.
        self.execute_started = asyncio.Event()
        # Set by the test to let execute() run to completion.
        self.gate = asyncio.Event()

    async def execute(self, research_context):
        """Signal that execution began, wait for the gate, then return a canned session."""
        from core.schemas import ResearchSession

        self.execute_started.set()
        await self.gate.wait()
        session_fields = {
            "session_id": "stub-session",
            "research_context": research_context,
            "created_at": "2026-01-01T00:00:00Z",
            "status": "complete",
        }
        return ResearchSession(**session_fields)


@pytest.fixture
def stub_orchestrator(monkeypatch):
    """Patch ``_make_orchestrator`` to build gated stubs; return the list of stubs built.

    The returned list starts empty and grows each time the patched factory is
    called, so tests can reach the stub a task is running against.
    """
    created: list[_GatedOrchestrator] = []

    def _factory():
        created.append(_GatedOrchestrator())
        return created[-1]

    monkeypatch.setattr("apps.api.main._make_orchestrator", _factory)
    return created


@pytest.mark.asyncio
async def test_submit_task_returns_queued_response(stub_orchestrator):
    """Submitting reports "queued" and registers both a state entry and a handle."""
    response = await submit_task(TaskRequest(query="hello"))
    assert response.status == "queued"

    new_id = response.task_id
    assert new_id in _tasks
    # Either state is accepted for the stored entry.
    stored_status = _tasks[new_id]["status"]
    assert stored_status in {"queued", "running"}
    assert new_id in _task_handles


@pytest.mark.asyncio
async def test_get_task_returns_task_state(stub_orchestrator):
    """``get_task`` echoes the submitted id and a queued-or-running status."""
    submitted = await submit_task(TaskRequest(query="q"))
    fetched = await get_task(submitted.task_id)

    assert fetched.task_id == submitted.task_id
    assert fetched.status in {"queued", "running"}


@pytest.mark.asyncio
async def test_get_task_404_when_unknown():
    """``get_task`` raises ``HTTPException`` with status 404 for an unknown id."""
    with pytest.raises(HTTPException) as raised:
        await get_task("nonexistent")

    error = raised.value
    assert error.status_code == 404


@pytest.mark.asyncio
async def test_list_tasks_returns_all_submitted(stub_orchestrator):
    """Every submitted task id appears in the ``list_tasks`` output."""
    first = await submit_task(TaskRequest(query="q1"))
    second = await submit_task(TaskRequest(query="q2"))

    listed_ids = set()
    for item in await list_tasks():
        listed_ids.add(item["task_id"])

    assert {first.task_id, second.task_id}.issubset(listed_ids)


@pytest.mark.asyncio
async def test_delete_task_404_when_unknown():
    """``delete_task`` raises ``HTTPException`` with status 404 for an unknown id."""
    with pytest.raises(HTTPException) as raised:
        await delete_task("nonexistent")

    error = raised.value
    assert error.status_code == 404


@pytest.mark.asyncio
async def test_delete_task_cancels_in_flight_handle_cr03(stub_orchestrator):
    """CR-03 regression: DELETE /tasks/{id} cancels the running asyncio.Task.

    Pre-fix the handle was never stored, so DELETE only removed the dict
    entry while the background coroutine kept running and eventually
    KeyError'd on resume.
    """
    resp = await submit_task(TaskRequest(query="slow"))
    task_id = resp.task_id

    async def _until_executing():
        # Let _run_task start so it calls _make_orchestrator and reaches execute().
        while not stub_orchestrator:
            await asyncio.sleep(0)
        await stub_orchestrator[0].execute_started.wait()

    # Bounded: if the background task never starts, fail with a timeout
    # instead of spinning forever and hanging the whole test run.
    await asyncio.wait_for(_until_executing(), timeout=2.0)

    handle = _task_handles[task_id]
    assert not handle.done(), "handle should still be running while gate is closed"

    await delete_task(task_id)

    assert task_id not in _tasks
    assert task_id not in _task_handles
    # Wait for cancellation to propagate. Awaiting the handle either returns
    # (the coroutine absorbed the cancellation) or raises CancelledError
    # (the handle itself ended up cancelled); both are acceptable outcomes.
    try:
        await asyncio.wait_for(handle, timeout=2.0)
    except asyncio.CancelledError:
        # Only tolerate the handle's own cancellation. If the handle was not
        # cancelled, this test is the one being cancelled — let that through.
        if not handle.cancelled():
            raise
    # cancelled() implies done(), so done() is the whole condition.
    assert handle.done()


@pytest.mark.asyncio
async def test_run_task_recovers_from_dict_eviction_cr03(stub_orchestrator):
    """CR-03 secondary: _run_task guards against _tasks[task_id] disappearing mid-flight.

    Simulates the original race: dict entry deleted while orchestrator is
    awaiting. Pre-fix the resume would KeyError on `_tasks[task_id][...] = ...`;
    post-fix the `if task_id not in _tasks: return` guards short-circuit.
    """
    resp = await submit_task(TaskRequest(query="q"))
    task_id = resp.task_id

    async def _until_executing():
        # Let _run_task start so it calls _make_orchestrator and reaches execute().
        while not stub_orchestrator:
            await asyncio.sleep(0)
        await stub_orchestrator[0].execute_started.wait()

    # Bounded: if the background task never starts, fail with a timeout
    # instead of spinning forever and hanging the whole test run.
    await asyncio.wait_for(_until_executing(), timeout=2.0)

    # Simulate the race: drop both bookkeeping entries but DON'T cancel the
    # handle. Keep the reference so the coroutine's outcome can be checked.
    handle = _task_handles.pop(task_id)
    del _tasks[task_id]

    # Release the orchestrator so _run_task tries to write back to _tasks[task_id].
    stub_orchestrator[0].gate.set()

    # Await the handle itself: any exception that escapes the background
    # coroutine (the pre-fix KeyError included) is re-raised here and fails
    # the test. Draining through gather(return_exceptions=True) would
    # swallow it and let the regression pass unnoticed.
    await asyncio.wait_for(handle, timeout=2.0)


async def _pending_handles():
    """Wait for every unfinished ``_run_task`` asyncio task on the running loop.

    Exceptions raised by those tasks are collected by ``gather`` rather than
    re-raised, so this only drains them; it does not verify they succeeded.
    """
    # Select tasks by the name of the coroutine they wrap. getattr() with a
    # default keeps this from raising AttributeError on tasks whose wrapped
    # object has no __name__ attribute.
    pending = [
        t
        for t in asyncio.all_tasks()
        if getattr(t.get_coro(), "__name__", None) == "_run_task"
    ]
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)


@pytest.mark.asyncio
async def test_max_tasks_lru_eviction_wr07(monkeypatch, stub_orchestrator):
    """WR-07 regression: submitting past ``_MAX_TASKS`` evicts the oldest entries."""
    monkeypatch.setattr("apps.api.main._MAX_TASKS", 3)

    # Submissions are awaited one at a time, in order q0..q4.
    submitted_ids = [
        (await submit_task(TaskRequest(query=f"q{i}"))).task_id for i in range(5)
    ]
    evicted_ids, kept_ids = submitted_ids[:2], submitted_ids[2:]

    assert len(_tasks) == 3
    # The two oldest submissions are gone.
    for old_id in evicted_ids:
        assert old_id not in _tasks
    # The three newest remain, still in insertion order.
    assert list(_tasks) == kept_ids
123 changes: 123 additions & 0 deletions tests/test_blackboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Tests for core/blackboard/engine.py.

Covers regressions for:
* CR-01 — "memory_guidance" is a valid entry type. The pre-fix
validate-list raised ValueError, which was silently swallowed by
the orchestrator's _add_entry, so memory was never recorded.
* WR-08 — Timestamps are timezone-aware ISO 8601 (datetime.now(timezone.utc)),
not naive datetime.utcnow() with a manually appended "Z".
"""

from __future__ import annotations

from datetime import datetime

import fakeredis
import pytest

from core.blackboard.engine import Blackboard

# Entry types that Blackboard.add_entry is expected to accept. Used as the
# parametrize source for test_all_documented_entry_types_accepted, so each
# name here becomes its own test case.
_DOCUMENTED_TYPES = (
    "task",
    "evidence_ref",
    "route_decision",
    "agent_output",
    "status",
    "lit_search",
    "lit_map",
    "critique",
    "citation_audit",
    "corpus_benchmarks",
    "memory_guidance",
)


@pytest.fixture(autouse=True)
def mock_redis(monkeypatch):
    """Replace ``redis.from_url`` with a factory backed by one in-memory fakeredis server."""
    shared_server = fakeredis.FakeServer()

    def _sync_client(*args, **kwargs):
        # Arguments (URL, options) are accepted and ignored.
        return fakeredis.FakeRedis(server=shared_server, decode_responses=True)

    monkeypatch.setattr("redis.from_url", _sync_client)


def test_memory_guidance_is_allowed_entry_type_cr01():
    """CR-01 regression: ``add_entry`` accepts the "memory_guidance" type."""
    payload = {"guidance": "consider prior episodes"}

    recorded = Blackboard().add_entry("memory_guidance", payload)

    assert recorded.entry_type == "memory_guidance"
    assert recorded.content == {"guidance": "consider prior episodes"}


@pytest.mark.parametrize("entry_type", _DOCUMENTED_TYPES)
def test_all_documented_entry_types_accepted(entry_type):
    """Snapshot of the accepted entry types: each one must round-trip through add_entry."""
    recorded = Blackboard().add_entry(entry_type, {"x": 1})
    assert recorded.entry_type == entry_type


def test_invalid_entry_type_raises_value_error():
    """An unrecognised entry type is rejected with ValueError("Invalid entry_type...")."""
    board = Blackboard()
    expect_rejection = pytest.raises(ValueError, match="Invalid entry_type")
    with expect_rejection:
        board.add_entry("not_a_real_type", {"x": 1})


def test_timestamp_is_timezone_aware_iso8601_wr08():
    """WR-08 regression: timestamp uses datetime.now(timezone.utc), not naive utcnow + 'Z'."""
    bb = Blackboard()
    entry = bb.add_entry("status", {"x": 1})

    # Tz-aware ISO 8601 ends with +00:00 (or an offset). The pre-fix
    # naive utcnow().isoformat() + "Z" produced "...000Z" which has no offset.
    # Reject that suffix explicitly: from Python 3.11 on, fromisoformat()
    # accepts a trailing "Z" and returns an aware datetime, so the tzinfo
    # check below would pass on the pre-fix format by itself.
    assert not entry.timestamp.endswith(
        "Z"
    ), f"Expected an explicit UTC offset, got 'Z'-suffixed timestamp: {entry.timestamp!r}"

    parsed = datetime.fromisoformat(entry.timestamp)
    assert parsed.tzinfo is not None, f"Expected tz-aware datetime, got naive: {entry.timestamp!r}"


def test_blackboard_isolates_per_session_id():
    """Boards with different session ids each see only their own entry."""
    boards = {who: Blackboard(session_id=f"session-{who}") for who in ("a", "b")}

    for who, board in boards.items():
        board.add_entry("task", {"who": who})

    # Read both sessions back before asserting anything.
    fetched = {who: board.get_all_entries() for who, board in boards.items()}

    for entries in fetched.values():
        assert len(entries) == 1
    for who, entries in fetched.items():
        assert entries[0].content == {"who": who}


def test_get_entries_by_type_filters_correctly():
    """``get_entries_by_type`` returns only the entries of the requested type."""
    board = Blackboard(session_id="filter-test")
    for entry_type, payload in (
        ("task", {"x": 1}),
        ("status", {"x": 2}),
        ("task", {"x": 3}),
    ):
        board.add_entry(entry_type, payload)

    task_entries = board.get_entries_by_type("task")
    status_entries = board.get_entries_by_type("status")

    assert len(task_entries) == 2
    assert len(status_entries) == 1
    assert not any(e.entry_type != "task" for e in task_entries)


def test_get_all_entries_returns_only_session_keys_wr04():
    """WR-04 supporting: ``get_all_entries`` counts only the calling session's entries.

    Two sessions write different numbers of entries to the same fake Redis
    server; each board must report exactly its own count.
    """
    bb_a = Blackboard(session_id="a")
    bb_b = Blackboard(session_id="b")
    # (board, number of entries to write, value stored under "x")
    plan = ((bb_a, 5, 1), (bb_b, 3, 2))

    for board, count, marker in plan:
        for _ in range(count):
            board.add_entry("task", {"x": marker})

    for board, count, _marker in plan:
        assert len(board.get_all_entries()) == count
30 changes: 30 additions & 0 deletions tests/test_budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,33 @@ async def _capture_guard(task, ctx, mode="planner"):
def test_module_singleton_resets():
    """Smoke test: no budget guard is left active from an earlier test."""
    leftover_guard = budget_module.get_current_guard()
    assert leftover_guard is None


def test_unified_estimate_cost_includes_input_tokens_wr05():
    """WR-05 regression: ``_estimate_cost`` must count prompt and system text too.

    Pre-fix only the response string was measured. With a long prompt and a
    short response (the typical thesis-pipeline pattern) that under-reported
    spend and let cumulative cost drift past MAX_BUDGET_USD.
    """
    from providers.unified import COST_PER_1K_TOKENS, UnifiedLLM

    short_response = "ok"
    long_prompt = "x" * 10000
    long_system = "y" * 5000
    rate = COST_PER_1K_TOKENS["openai"]

    cost = UnifiedLLM()._estimate_cost(short_response, "openai", long_prompt, long_system)

    # Lower bound: 15000 input chars at 3.5 chars/token, priced per 1K tokens,
    # with a 5% tolerance.
    expected_min = (15000 / 3.5 / 1000) * rate * 0.95
    assert (
        cost >= expected_min
    ), f"Estimated cost {cost} is too low — input tokens may not be counted (WR-05)"

    # Sanity: pricing the 2-character response alone lands far below expected_min.
    response_only = (len(short_response) / 3.5 / 1000) * rate
    assert (
        cost > response_only * 100
    ), f"Cost {cost} is closer to response-only ({response_only}) than to input-inclusive"
Loading
Loading