# =============================================================================
# ASB grounding helpers: _asb_kb_slug / ensure_kb / ground_paper
# =============================================================================

# Prefix shared by every per-paper KB created for ASB grounding.
_ASB_KB_PREFIX = "asb-paper-"

# Any run of characters outside [a-zA-Z0-9] collapses to a single hyphen.
_ASB_SLUG_SEPARATOR_RUN = re.compile(r"[^a-zA-Z0-9]+")

# Hint prepended to the query when grounding against the "si" tier.
_ASB_SI_CONTEXT_HINT = (
    "Prefer evidence from the supplementary information / supplementary "
    "tables and figures of the source paper."
)


def _asb_kb_slug(doi: str) -> str:
    """Derive the canonical KB slug for a given DOI.

    Produces names that match the ASB binder's convention so KB names
    are consistent across the ASB pipeline and Perspicacité. Every run of
    non-alphanumeric characters becomes one hyphen, leading/trailing hyphens
    are stripped, and the result is lower-cased.

    A DOI without any alphanumeric character yields the bare prefix
    ``"asb-paper-"``; callers treat that value as "invalid DOI".

    Example:
        "10.1021/acs.jnatprod.7b00737" → "asb-paper-10-1021-acs-jnatprod-7b00737"
    """
    return (_ASB_KB_PREFIX + _ASB_SLUG_SEPARATOR_RUN.sub("-", doi).strip("-")).lower()


@mcp.tool()
async def ensure_kb(doi: str, mode: str = "paper") -> str:
    """Idempotently create and ingest a per-paper KB for ASB grounding.

    Derives the KB slug via _asb_kb_slug (matches the ASB binder convention)
    and checks whether a KB already exists with chunks. If it does, returns
    immediately without re-ingesting. Otherwise creates the KB and ingests
    the paper via add_dois_to_kb.

    Args:
        doi: DOI of the source paper (e.g. "10.1021/acs.jnatprod.7b00737").
            Must contain at least one alphanumeric character; surrounding
            whitespace is ignored.
        mode: Reserved for future per-mode ingest strategies (currently unused).

    Returns:
        JSON with:
        - kb_slug (str): the derived KB name
        - status (str): "exists" (already populated) or "created" (just ingested)
        - chunks (int): chunk count of the existing KB ("exists"), or the
          number of chunks reported by add_dois_to_kb ("created")
        - added_with_full_text (int): full-text papers added (only on "created")
        - added_metadata_only (int): metadata-only papers added (only on "created")
        On failure, an error JSON produced by _json_error.
    """
    state = _require_state()
    if isinstance(state, str):
        return state

    # Reject DOIs that would all collapse onto the same degenerate slug
    # ("asb-paper-") instead of creating/ingesting a meaningless shared KB.
    if not isinstance(doi, str) or _asb_kb_slug(doi) == _ASB_KB_PREFIX:
        return _json_error("ensure_kb: 'doi' must be a non-empty DOI string")

    doi = doi.strip()
    slug = _asb_kb_slug(doi)

    try:
        # Idempotency check: KB exists AND has content?
        existing = await state.session_store.get_kb_metadata(slug)
        chunk_count = (getattr(existing, "chunk_count", 0) or 0) if existing else 0
        if chunk_count > 0:
            return _json_ok({"kb_slug": slug, "status": "exists", "chunks": chunk_count})

        # Create (ignore 'already exists' error — could be a zero-chunk KB)
        create_result = json.loads(
            await create_knowledge_base(name=slug, description=f"ASB grounding KB for {doi}")
        )
        if not create_result.get("success"):
            # "error" may be missing or None; normalise before the substring test.
            create_error = str(create_result.get("error") or "")
            if "already exists" not in create_error:
                return _json_error(
                    f"ensure_kb: create_knowledge_base failed: {create_error or 'unknown'}"
                )

        # Ingest the paper
        add_result = json.loads(await add_dois_to_kb(kb_name=slug, dois=[doi]))
        if not add_result.get("success"):
            return _json_error(
                f"ensure_kb: add_dois_to_kb failed: {add_result.get('error', 'unknown')}"
            )

        added_chunks = add_result.get("added_chunks", 0)
        if not added_chunks:
            # Ingestion reported success but nothing landed in the KB; surface it
            # in the logs because downstream grounding will have nothing to cite.
            logger.warning("mcp_ensure_kb_empty_ingest", doi=doi, slug=slug)

        return _json_ok(
            {
                "kb_slug": slug,
                "status": "created",
                "chunks": added_chunks,
                "added_with_full_text": add_result.get("added_with_full_text", 0),
                "added_metadata_only": add_result.get("added_metadata_only", 0),
            }
        )

    except Exception as e:
        logger.error("mcp_ensure_kb_error", doi=doi, slug=slug, error=str(e))
        return _json_error(f"ensure_kb failed: {e}")


@mcp.tool()
async def ground_paper(doi: str, question: str, tier: str = "paper") -> str:
    """Ground a research question against a specific paper's KB (ASB grounding).

    Idempotently ensures the paper has a dedicated KB (via ensure_kb), then
    runs a RAG query against that KB to answer the question.

    Args:
        doi: DOI of the source paper (e.g. "10.1021/acs.jnatprod.7b00737")
        question: Research question to answer using the paper's content.
            Must not be blank; it is forwarded verbatim.
        tier: "paper" (default) or "si" (case-insensitive). When "si", adds a
            context hint to prefer evidence from the supplementary
            information / supplementary tables and figures of the source
            paper. Any other value behaves like "paper".

    Returns:
        JSON with:
        - kb_slug (str): the KB name used
        - answer (str): synthesized answer from the paper's content
        - sources (list): cited chunks/papers from the KB
        On failure, an error JSON produced by _json_error.
    """
    state = _require_state()
    if isinstance(state, str):
        return state

    # Validate up front so a bad request never triggers ingestion or a report run.
    if not isinstance(doi, str) or _asb_kb_slug(doi) == _ASB_KB_PREFIX:
        return _json_error("ground_paper: 'doi' must be a non-empty DOI string")
    if not isinstance(question, str) or not question.strip():
        return _json_error("ground_paper: 'question' must be a non-empty string")

    slug = _asb_kb_slug(doi)
    normalized_tier = tier.strip().lower() if isinstance(tier, str) else "paper"

    try:
        # Step 1: ensure the KB exists and has content
        ensure_result = json.loads(await ensure_kb(doi=doi, mode=normalized_tier))
        if not ensure_result.get("success"):
            return _json_error(ensure_result.get("error", "ensure_kb failed"))

        # Query the KB that ensure_kb actually reports, falling back to our own slug.
        kb_slug = ensure_result.get("kb_slug") or slug

        # Step 2: context hint for SI tier — prepend to query
        # (generate_report has no context param)
        if normalized_tier == "si":
            effective_query = f"{_ASB_SI_CONTEXT_HINT}\n\n{question}"
        else:
            effective_query = question

        # Step 3: run RAG query against the per-paper KB
        report_result = json.loads(
            await generate_report(
                query=effective_query,
                kb_names=[kb_slug],
                mode="basic",
            )
        )
        if not report_result.get("success"):
            return _json_error(
                f"ground_paper: generate_report failed: {report_result.get('error', 'unknown')}"
            )

        return _json_ok(
            {
                "kb_slug": kb_slug,
                "answer": report_result.get("report", ""),
                "sources": report_result.get("sources", []),
            }
        )

    except Exception as e:
        logger.error("mcp_ground_paper_error", doi=doi, slug=slug, error=str(e))
        return _json_error(f"ground_paper failed: {e}")


# ---------------------------------------------------------------------------
# DRY helpers for fastmcp 3.x (removed _tool_manager)
# (these two helpers belong to tests/test_mcp_server.py)
# ---------------------------------------------------------------------------


def _tool_fn(name: str):
    """Return the underlying callable for a registered tool by module attribute.

    Looks the tool up as an attribute of the loaded server module
    (``_mcp_mod``) rather than through the FastMCP tool manager.
    """
    return getattr(_mcp_mod, name)


async def _registered_tool_names() -> list[str]:
    """Return the list of tool names registered with the FastMCP instance."""
    # NOTE(review): _list_tools is a private FastMCP method — confirm it is
    # stable across the fastmcp versions this suite targets.
    tools = await mcp._list_tools()
    return [t.name for t in tools]
--------------------------------------------------------------------------- @@ -125,18 +141,20 @@ class TestToolRegistration: def test_mcp_object_exists(self): assert mcp is not None - def test_all_tools_registered(self): + @pytest.mark.asyncio + async def test_all_tools_registered(self): """Check that all expected tool names are registered.""" - # FastMCP stores tools internally; access via _tool_manager - tool_mgr = mcp._tool_manager - registered = set(tool_mgr._tools.keys()) + registered = set(await _registered_tool_names()) for name in self.EXPECTED_TOOLS: assert name in registered, f"Tool '{name}' not found in {registered}" - def test_tool_count(self): - """Should have exactly the expected number of tools.""" - tool_mgr = mcp._tool_manager - assert len(tool_mgr._tools) == len(self.EXPECTED_TOOLS) + @pytest.mark.asyncio + async def test_tool_count(self): + """Should have at least the expected number of tools (server may have more).""" + registered = await _registered_tool_names() + assert len(registered) >= len(self.EXPECTED_TOOLS), ( + f"Expected at least {len(self.EXPECTED_TOOLS)} tools, got {len(registered)}: {registered}" + ) # --------------------------------------------------------------------------- @@ -200,10 +218,7 @@ async def test_returns_json_with_kbs(self): _mcp_mod.mcp_state = state - # Get the underlying function from the FastMCP FunctionTool wrapper - tool_mgr = mcp._tool_manager - fn = tool_mgr._tools["list_knowledge_bases"].fn - + fn = _tool_fn("list_knowledge_bases") result = await fn() parsed = json.loads(result) @@ -226,9 +241,7 @@ async def test_creates_new_kb(self): _mcp_mod.mcp_state = state - tool_mgr = mcp._tool_manager - fn = tool_mgr._tools["create_knowledge_base"].fn - + fn = _tool_fn("create_knowledge_base") result = await fn(name="new_kb", description="Test") parsed = json.loads(result) @@ -246,9 +259,7 @@ async def test_rejects_duplicate(self): _mcp_mod.mcp_state = state - tool_mgr = mcp._tool_manager - fn = 
tool_mgr._tools["create_knowledge_base"].fn - + fn = _tool_fn("create_knowledge_base") result = await fn(name="existing_kb") parsed = json.loads(result) @@ -266,9 +277,7 @@ async def test_returns_error_when_search_fails(self): _mcp_mod.mcp_state = state - tool_mgr = mcp._tool_manager - fn = tool_mgr._tools["search_literature"].fn - + fn = _tool_fn("search_literature") # Search may fail if scilex not installed — should return error JSON result = await fn(query="test", max_results=5) parsed = json.loads(result) @@ -388,7 +397,7 @@ class _FakeEngine: def __init__(self, *a, **k): pass - async def query_stream(self, req): + async def query_stream(self, req, **kwargs): yield StreamEvent(event="content", data=json.dumps({"delta": "hello"})) yield StreamEvent(event="done", data="{}") @@ -458,10 +467,379 @@ async def test_get_info_includes_push_to_zotero(): assert "build_kbs_from_zotero" in info["tools"], ( f"build_kbs_from_zotero missing from tools list: {info['tools']}" ) - assert len(info["tools"]) == 15, ( - f"Expected 15 tools in get_info(), got {len(info['tools'])}: {info['tools']}" + # _TOOL_NAMES in server.py now lists 51 tools (grew from original 49 + ensure_kb + ground_paper). + # Update this assertion whenever new tools are added to _TOOL_NAMES. 
# ---------------------------------------------------------------------------
# Tests: _asb_kb_slug helper
# ---------------------------------------------------------------------------


class TestAsbKbSlug:
    """Verify _asb_kb_slug produces the correct KB name from a DOI."""

    def test_standard_doi(self):
        slug = _mcp_mod._asb_kb_slug("10.1021/acs.jnatprod.7b00737")
        assert slug == "asb-paper-10-1021-acs-jnatprod-7b00737"

    def test_doi_with_slashes(self):
        slug = _mcp_mod._asb_kb_slug("10.1038/nature12345")
        assert slug == "asb-paper-10-1038-nature12345"

    def test_doi_with_dots_and_hyphens(self):
        # "/(" is a run of two non-alphanumerics and must collapse to ONE
        # hyphen; the literal hyphen in "1099-0682" stays a single hyphen.
        slug = _mcp_mod._asb_kb_slug("10.1002/(sici)1099-0682")
        assert slug == "asb-paper-10-1002-sici-1099-0682"

    def test_result_is_lowercase(self):
        slug = _mcp_mod._asb_kb_slug("10.1234/ABC.XYZ")
        assert slug == "asb-paper-10-1234-abc-xyz"

    def test_no_leading_trailing_hyphens(self):
        # Separators at both ends of the input would otherwise surface as
        # "asb-paper--..." and a trailing "-".
        slug = _mcp_mod._asb_kb_slug(" /10.1234/test/ ")
        assert slug == "asb-paper-10-1234-test"
        assert not slug.startswith("asb-paper--")
        assert not slug.endswith("-")


# ---------------------------------------------------------------------------
# Tests: ensure_kb — idempotent create+ingest
# ---------------------------------------------------------------------------


class TestEnsureKb:
    """ensure_kb(doi) — idempotent: existing KB with chunks returns 'exists'.

    Module state and collaborators are swapped in with ``monkeypatch`` so
    pytest restores them automatically after each test.
    """

    @pytest.mark.asyncio
    async def test_existing_kb_with_chunks_returns_exists(self, monkeypatch):
        """If KB already has chunks > 0, return status='exists' without calling create/add."""
        state = _make_mock_state()

        # KB metadata exists and has chunks
        mock_kb_meta = MagicMock()
        mock_kb_meta.chunk_count = 42
        state.session_store.get_kb_metadata = AsyncMock(return_value=mock_kb_meta)

        create_calls = []
        add_calls = []

        async def _fake_create(**kw):
            create_calls.append(kw)
            return _json_ok({"name": kw["name"]})

        async def _fake_add(**kw):
            add_calls.append(kw)
            return _json_ok({"added_chunks": 5})

        monkeypatch.setattr(_mcp_mod, "create_knowledge_base", _fake_create)
        monkeypatch.setattr(_mcp_mod, "add_dois_to_kb", _fake_add)
        monkeypatch.setattr(_mcp_mod, "mcp_state", state)

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi="10.1021/acs.jnatprod.7b00737"))

        assert parsed["success"] is True
        assert parsed["status"] == "exists"
        assert parsed["chunks"] == 42
        assert parsed["kb_slug"] == "asb-paper-10-1021-acs-jnatprod-7b00737"
        # Must NOT have called create or add
        assert create_calls == []
        assert add_calls == []

    @pytest.mark.asyncio
    async def test_absent_kb_calls_create_and_add(self, monkeypatch):
        """If KB is absent, ensure_kb calls create then add_dois and returns 'created'."""
        state = _make_mock_state()

        # Existence check returns None → absent
        state.session_store.get_kb_metadata = AsyncMock(return_value=None)

        doi = "10.1021/acs.jnatprod.7b00737"
        expected_slug = _mcp_mod._asb_kb_slug(doi)

        create_calls = []
        add_calls = []

        async def _fake_create(name, description=""):
            create_calls.append({"name": name, "description": description})
            return _json_ok({"name": name, "chunk_count": 0})

        async def _fake_add(kb_name, dois):
            add_calls.append({"kb_name": kb_name, "dois": dois})
            return _json_ok(
                {
                    "kb_name": kb_name,
                    "added_chunks": 17,
                    "added_with_full_text": 1,
                    "added_metadata_only": 0,
                }
            )

        monkeypatch.setattr(_mcp_mod, "create_knowledge_base", _fake_create)
        monkeypatch.setattr(_mcp_mod, "add_dois_to_kb", _fake_add)
        monkeypatch.setattr(_mcp_mod, "mcp_state", state)

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi=doi))

        assert parsed["success"] is True
        assert parsed["status"] == "created"
        assert parsed["kb_slug"] == expected_slug
        assert parsed["chunks"] == 17
        assert parsed["added_with_full_text"] == 1

        # create was called with the correct slug
        assert len(create_calls) == 1
        assert create_calls[0]["name"] == expected_slug

        # add_dois_to_kb was called with the correct slug + doi
        assert len(add_calls) == 1
        assert add_calls[0]["kb_name"] == expected_slug
        assert doi in add_calls[0]["dois"]

    @pytest.mark.asyncio
    async def test_kb_exists_zero_chunks_reingest(self, monkeypatch):
        """If KB exists but chunk_count == 0, treat as absent and re-ingest.

        The create step reports 'already exists' here, which ensure_kb must
        tolerate rather than abort on.
        """
        state = _make_mock_state()

        mock_kb_meta = MagicMock()
        mock_kb_meta.chunk_count = 0
        state.session_store.get_kb_metadata = AsyncMock(return_value=mock_kb_meta)

        add_calls = []

        async def _fake_create(name, description=""):
            return _json_error(f"Knowledge base '{name}' already exists")

        async def _fake_add(kb_name, dois):
            add_calls.append({"kb_name": kb_name})
            return _json_ok(
                {
                    "kb_name": kb_name,
                    "added_chunks": 8,
                    "added_with_full_text": 1,
                    "added_metadata_only": 0,
                }
            )

        monkeypatch.setattr(_mcp_mod, "create_knowledge_base", _fake_create)
        monkeypatch.setattr(_mcp_mod, "add_dois_to_kb", _fake_add)
        monkeypatch.setattr(_mcp_mod, "mcp_state", state)

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi="10.1038/nature12345"))

        assert parsed["success"] is True
        assert parsed["status"] == "created"
        assert parsed["chunks"] == 8
        assert len(add_calls) == 1

    @pytest.mark.asyncio
    async def test_create_failure_is_propagated(self, monkeypatch):
        """A create error other than 'already exists' aborts before ingestion."""
        state = _make_mock_state()
        state.session_store.get_kb_metadata = AsyncMock(return_value=None)

        add_calls = []

        async def _fake_create(name, description=""):
            return _json_error("Simulated storage failure")

        async def _fake_add(kb_name, dois):
            add_calls.append({"kb_name": kb_name})
            return _json_ok({"kb_name": kb_name, "added_chunks": 1})

        monkeypatch.setattr(_mcp_mod, "create_knowledge_base", _fake_create)
        monkeypatch.setattr(_mcp_mod, "add_dois_to_kb", _fake_add)
        monkeypatch.setattr(_mcp_mod, "mcp_state", state)

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi="10.1038/test"))

        assert parsed["success"] is False
        assert "Simulated storage failure" in parsed["error"]
        assert add_calls == []

    @pytest.mark.asyncio
    async def test_graceful_on_add_failure(self, monkeypatch):
        """If add_dois_to_kb returns an error JSON, ensure_kb propagates _json_error."""
        state = _make_mock_state()
        state.session_store.get_kb_metadata = AsyncMock(return_value=None)

        async def _fake_create(name, description=""):
            return _json_ok({"name": name})

        async def _fake_add(kb_name, dois):
            return _json_error("Simulated network failure")

        monkeypatch.setattr(_mcp_mod, "create_knowledge_base", _fake_create)
        monkeypatch.setattr(_mcp_mod, "add_dois_to_kb", _fake_add)
        monkeypatch.setattr(_mcp_mod, "mcp_state", state)

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi="10.1038/test"))

        assert parsed["success"] is False
        assert "error" in parsed

    @pytest.mark.asyncio
    async def test_uninitialized_state_returns_error(self, monkeypatch):
        """ensure_kb must return error JSON when state not initialized."""
        monkeypatch.setattr(_mcp_mod, "mcp_state", _mcp_mod.MCPState())

        fn = _tool_fn("ensure_kb")
        parsed = json.loads(await fn(doi="10.1/x"))

        assert parsed["success"] is False


# ---------------------------------------------------------------------------
# Tests: ground_paper — compose ensure_kb + generate_report
# ---------------------------------------------------------------------------


class TestGroundPaper:
    """ground_paper(doi, question, tier) — composes ensure_kb + generate_report."""

    @pytest.mark.asyncio
    async def test_basic_composition(self, monkeypatch):
        """ground_paper calls ensure_kb then generate_report with kb_names=[slug]."""
        monkeypatch.setattr(_mcp_mod, "mcp_state", _make_mock_state())

        doi = "10.1021/acs.jnatprod.7b00737"
        slug = _mcp_mod._asb_kb_slug(doi)

        ensure_calls = []
        report_calls = []

        async def _fake_ensure(doi, mode="paper"):
            ensure_calls.append({"doi": doi})
            return _json_ok({"kb_slug": slug, "status": "exists", "chunks": 10})

        async def _fake_report(query, kb_names=None, mode="advanced", **kw):
            report_calls.append({"query": query, "kb_names": kb_names, "mode": mode, **kw})
            return _json_ok(
                {
                    "report": "Answer about natural products.",
                    "sources": [{"doi": doi, "title": "Test Paper"}],
                }
            )

        monkeypatch.setattr(_mcp_mod, "ensure_kb", _fake_ensure)
        monkeypatch.setattr(_mcp_mod, "generate_report", _fake_report)

        fn = _tool_fn("ground_paper")
        parsed = json.loads(await fn(doi=doi, question="What are the main compounds?"))

        assert parsed["success"] is True
        assert parsed["kb_slug"] == slug
        assert parsed["answer"] == "Answer about natural products."
        assert parsed["sources"] == [{"doi": doi, "title": "Test Paper"}]

        # ensure_kb was called with the doi
        assert len(ensure_calls) == 1
        assert ensure_calls[0]["doi"] == doi

        # generate_report got kb_names=[slug] and mode="basic"
        assert len(report_calls) == 1
        assert report_calls[0]["kb_names"] == [slug]
        assert report_calls[0]["mode"] == "basic"

    @pytest.mark.asyncio
    async def test_tier_si_passes_context_hint(self, monkeypatch):
        """When tier='si', ground_paper prepends a supplementary context hint to the query."""
        monkeypatch.setattr(_mcp_mod, "mcp_state", _make_mock_state())

        doi = "10.1093/nar/gkad540"
        slug = _mcp_mod._asb_kb_slug(doi)
        question = "What do the SI tables show?"

        report_calls = []

        async def _fake_ensure(doi, mode="paper"):
            return _json_ok({"kb_slug": slug, "status": "exists", "chunks": 5})

        async def _fake_report(query, kb_names=None, mode="advanced", **kw):
            report_calls.append({"query": query})
            return _json_ok({"report": "SI answer.", "sources": []})

        monkeypatch.setattr(_mcp_mod, "ensure_kb", _fake_ensure)
        monkeypatch.setattr(_mcp_mod, "generate_report", _fake_report)

        fn = _tool_fn("ground_paper")
        parsed = json.loads(await fn(doi=doi, question=question, tier="si"))

        assert parsed["success"] is True
        # The SI hint is prepended, so the question itself closes the query.
        assert len(report_calls) == 1
        effective_query = report_calls[0]["query"]
        assert "supplementary" in effective_query.lower()
        assert effective_query.endswith(question)
        assert effective_query != question

    @pytest.mark.asyncio
    async def test_tier_paper_passes_no_context(self, monkeypatch):
        """When tier='paper' (default), query is passed verbatim (no prepended hint)."""
        monkeypatch.setattr(_mcp_mod, "mcp_state", _make_mock_state())

        doi = "10.1038/nature12345"
        slug = _mcp_mod._asb_kb_slug(doi)

        report_calls = []

        async def _fake_ensure(doi, mode="paper"):
            return _json_ok({"kb_slug": slug, "status": "exists", "chunks": 5})

        async def _fake_report(query, kb_names=None, mode="advanced", **kw):
            report_calls.append({"query": query})
            return _json_ok({"report": "Paper answer.", "sources": []})

        monkeypatch.setattr(_mcp_mod, "ensure_kb", _fake_ensure)
        monkeypatch.setattr(_mcp_mod, "generate_report", _fake_report)

        fn = _tool_fn("ground_paper")
        parsed = json.loads(await fn(doi=doi, question="What is the method?"))

        assert parsed["success"] is True
        # With tier="paper", query should be passed verbatim (no SI hint prepended)
        assert report_calls[0]["query"] == "What is the method?"

    @pytest.mark.asyncio
    async def test_propagates_ensure_kb_error(self, monkeypatch):
        """If ensure_kb errors, ground_paper returns that error and skips the report."""
        monkeypatch.setattr(_mcp_mod, "mcp_state", _make_mock_state())

        async def _fake_ensure(doi, mode="paper"):
            return _json_error("KB creation failed")

        report_calls = []

        async def _fake_report(*a, **kw):
            report_calls.append(kw)
            return _json_ok({"report": "", "sources": []})

        monkeypatch.setattr(_mcp_mod, "ensure_kb", _fake_ensure)
        monkeypatch.setattr(_mcp_mod, "generate_report", _fake_report)

        fn = _tool_fn("ground_paper")
        parsed = json.loads(await fn(doi="10.1/bad", question="What is the answer?"))

        assert parsed["success"] is False
        assert "KB creation failed" in parsed["error"]
        assert report_calls == []