diff --git a/briefing.py b/briefing.py index d11f533d..7653b16f 100755 --- a/briefing.py +++ b/briefing.py @@ -1588,7 +1588,8 @@ def _apply_feedback_bias_to_knowledge( verdicts_by_id: dict[str, list[int]] = {} if normalized_query: for r in rows: - if _normalize_feedback_query(r["query"] or "") != normalized_query: + row_q = _normalize_feedback_query(r["query"] or "") + if row_q != normalized_query and row_q != "*": continue rid = str(r["result_id"] or "") verdicts_by_id.setdefault(rid, []).append(int(r["verdict"])) diff --git a/mcp-server.py b/mcp-server.py index 202e3f4d..e0899600 100644 --- a/mcp-server.py +++ b/mcp-server.py @@ -267,6 +267,24 @@ def _load_script_module(module_name: str, filename: str): "additionalProperties": False, }, }, + { + "name": "rate_entry", + "description": "Rate a knowledge entry as helpful or misleading to improve future briefings.", + "inputSchema": { + "type": "object", + "properties": { + "entry_id": {"type": "integer", "description": "ID of the knowledge entry to rate"}, + "verdict": { + "type": "string", + "enum": ["good", "bad", "neutral"], + "description": "helpful=good, misleading=bad, neutral=neutral", + }, + "note": {"type": "string", "description": "Optional note (max 500 chars)", "maxLength": 500}, + }, + "required": ["entry_id", "verdict"], + "additionalProperties": False, + }, + }, ] @@ -813,6 +831,51 @@ def _run_code_search(arguments: dict) -> dict: } +def _run_rate_entry(arguments: dict[str, Any]) -> dict[str, Any]: + """Write a feedback row for a knowledge entry (good/bad/neutral).""" + _check_auth(arguments) + entry_id = arguments.get("entry_id") + verdict = arguments.get("verdict", "neutral") + note = str(arguments.get("note") or "")[:500] + + if not isinstance(entry_id, int) or verdict not in ("good", "bad", "neutral"): + raise JsonRpcError(JSONRPC_INVALID_PARAMS, "entry_id (int) and verdict (good|bad|neutral) required") + + verdict_map = {"good": 1, "neutral": 0, "bad": -1} + score = verdict_map[verdict] + + if not _DB_PATH.exists(): + raise JsonRpcError(JSONRPC_INTERNAL_ERROR, f"Knowledge DB not found: {_DB_PATH}") + + try: + db = sqlite3.connect(str(_DB_PATH)) + except sqlite3.OperationalError as exc: + raise JsonRpcError(JSONRPC_INTERNAL_ERROR, f"DB open error: {exc}") from exc + + try: + row = db.execute("SELECT id, title FROM knowledge_entries WHERE id = ?", (entry_id,)).fetchone() + if not row: + body = {"error": f"Entry #{entry_id} not found"} + return {"content": [{"type": "text", "text": json.dumps(body)}], "structuredContent": body} + + import time as _time + + created_at = _time.strftime("%Y-%m-%dT%H:%M:%S", _time.gmtime()) + db.execute( + "INSERT INTO search_feedback (query, result_id, result_kind, verdict, created_at, note)" + " VALUES (?, ?, 'knowledge', ?, ?, ?)", + ("*", str(entry_id), score, created_at, note or None), + ) + db.commit() + except sqlite3.OperationalError as exc: + raise JsonRpcError(JSONRPC_INTERNAL_ERROR, f"DB write error: {exc}") from exc + finally: + db.close() + + body = {"status": "ok", "entry_id": entry_id, "title": row[1], "verdict": verdict} + return {"content": [{"type": "text", "text": json.dumps(body, ensure_ascii=False)}], "structuredContent": body} + + def _handle_tools_call(params: dict[str, Any]) -> dict[str, Any]: name = params.get("name") if not isinstance(name, str) or not name: @@ -836,6 +899,8 @@ def _handle_tools_call(params: dict[str, Any]) -> dict[str, Any]: return _run_session_list(arguments) if name == "code_search": return _run_code_search(arguments) + if name == "rate_entry": + return _run_rate_entry(arguments) raise JsonRpcError(JSONRPC_INVALID_PARAMS, f"Unknown tool: {name}") diff --git a/migrate.py b/migrate.py index ae1b42a5..2a59faa3 100755 --- a/migrate.py +++ b/migrate.py @@ -1839,6 +1839,13 @@ def _ensure_base_schema(db: sqlite3.Connection): "CREATE INDEX IF NOT EXISTS idx_tool_spans_tool ON tool_spans (tool_name)", ], ), + ( + 42, + "search_feedback_note", + [ + "ALTER TABLE search_feedback ADD COLUMN note TEXT", + ], + ), ] applied = 0 for ver, name, stmts in MIGRATIONS: diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 8e6346f2..f224a2a6 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -25,6 +25,7 @@ import os import sqlite3 import sys +import tempfile import types import unittest from pathlib import Path @@ -102,8 +103,8 @@ def setUp(self): self.tools = {t["name"]: t for t in mcp.TOOLS} def test_exactly_two_tools(self): - # Updated: wave 8 added learn, status, session_list; code_search added later; now 7 tools total - self.assertEqual(len(mcp.TOOLS), 7) + # Updated: wave 8 added learn, status, session_list; code_search added later; rate_entry added (#820); now 8 tools total + self.assertEqual(len(mcp.TOOLS), 8) def test_briefing_tool_present(self): self.assertIn("briefing", self.tools) @@ -1097,8 +1098,23 @@ def test_code_search_tool_has_description(self): self.assertTrue(self.tools["code_search"]["description"]) def test_exactly_three_tools(self): - # Updated: wave 8 added learn, status, session_list; code_search added later; now 7 tools total - self.assertEqual(len(mcp.TOOLS), 7) + # Updated: wave 8 added learn, status, session_list; code_search added later; rate_entry added (#820); now 8 tools total + self.assertEqual(len(mcp.TOOLS), 8) + + def test_rate_entry_tool_present(self): + self.assertIn("rate_entry", self.tools) + + def test_rate_entry_has_description(self): + self.assertTrue(self.tools["rate_entry"]["description"]) + + def test_rate_entry_required_fields(self): + schema = self.tools["rate_entry"]["inputSchema"] + self.assertIn("entry_id", schema["required"]) + self.assertIn("verdict", schema["required"]) + + def test_rate_entry_verdict_enum(self): + schema = self.tools["rate_entry"]["inputSchema"] + self.assertEqual(schema["properties"]["verdict"]["enum"], ["good", "bad", "neutral"]) # --------------------------------------------------------------------------- @@ -1107,3 +1123,57 @@ def test_exactly_three_tools(self): if __name__ == "__main__": unittest.main(verbosity=2) + + +class TestRateEntryExecution(unittest.TestCase): + """Test _run_rate_entry writes feedback and requires auth.""" + + def setUp(self): + self.db_path = Path(tempfile.mkdtemp()) / "test.db" + db = sqlite3.connect(str(self.db_path)) + db.execute( + "CREATE TABLE knowledge_entries (id INTEGER PRIMARY KEY, title TEXT, content TEXT, " + "category TEXT, confidence REAL, session_id TEXT, occurrence_count INTEGER, " + "first_seen TEXT, last_seen TEXT)" + ) + db.execute( + "INSERT INTO knowledge_entries (id, title, content, category, confidence, session_id) VALUES (1, 'Test', 'body', 'pattern', 0.8, 's1')" + ) + db.execute( + "CREATE TABLE search_feedback (id INTEGER PRIMARY KEY AUTOINCREMENT, query TEXT, " + "result_id TEXT, result_kind TEXT, verdict INTEGER, created_at TEXT, note TEXT)" + ) + db.commit() + db.close() + + def test_rate_entry_writes_feedback(self): + import importlib + import sys + + # Patch _DB_PATH + spec = importlib.util.spec_from_file_location("mcp_server", "mcp-server.py") + mod = importlib.util.module_from_spec(spec) + mod._DB_PATH = self.db_path + # Mock _check_auth to pass + mod._check_auth = lambda args: None + spec.loader.exec_module(mod) + mod._DB_PATH = self.db_path + mod._check_auth = lambda args: None + + result = mod._run_rate_entry({"entry_id": 1, "verdict": "good"}) + body = json.loads(result["content"][0]["text"]) + self.assertEqual(body["status"], "ok") + self.assertEqual(body["entry_id"], 1) + self.assertEqual(body["verdict"], "good") + + # Verify feedback row uses "*" as query for universal matching + db = sqlite3.connect(str(self.db_path)) + row = db.execute("SELECT query, result_id, verdict FROM search_feedback").fetchone() + db.close() + self.assertEqual(row[0], "*") + self.assertEqual(row[1], "1") + self.assertEqual(row[2], 1) + + +if __name__ == "__main__": + unittest.main()