diff --git a/cogsol/management/commands/migrate.py b/cogsol/management/commands/migrate.py index 41353a1..b6c55ca 100644 --- a/cogsol/management/commands/migrate.py +++ b/cogsol/management/commands/migrate.py @@ -764,14 +764,15 @@ def _resolve_retrieval_id(value: Any) -> int: return int(retrieval_id) params = list(fields.get("parameters") or []) - if not params: - params.append( + if not any(p.get("name") == "question" for p in params): + params.insert( + 0, { "name": "question", "description": "Search query", "type": "string", "required": True, - } + }, ) description = fields.get("description") or f"Retrieval tool {tool_name}" retrieval_id = _resolve_retrieval_id(fields.get("retrieval")) diff --git a/tests/test_migrate_tools.py b/tests/test_migrate_tools.py index 6cfa820..217cef1 100644 --- a/tests/test_migrate_tools.py +++ b/tests/test_migrate_tools.py @@ -3,6 +3,8 @@ """ import ast +import json +from pathlib import Path from cogsol.management.commands.migrate import Command @@ -131,3 +133,76 @@ def test_preserves_multiline_helper_signature(self) -> None: assert "response = helper(text=text)" in script assert "self.helper" not in script ast.parse(script) + + +def _make_state_file(tmp_path: Path, retrieval_key: str = "my_retrieval") -> Path: + """Create a minimal content state file with a fake retrieval ID.""" + project = tmp_path / "data" / "migrations" + project.mkdir(parents=True, exist_ok=True) + state_path = project / ".state.json" + state_path.write_text( + json.dumps( + { + "state": {}, + "remote": {"retrievals": {retrieval_key: 42}}, + } + ) + ) + return tmp_path + + +class TestRetrievalToolQuestionParam: + """Tests for auto-injection of the question parameter in retrieval tools.""" + + def test_adds_question_when_params_empty(self, tmp_path: Path) -> None: + project = _make_state_file(tmp_path) + definition = { + "fields": { + "retrieval": "my_retrieval", + } + } + + payload = Command()._retrieval_tool_payload( + tool_name="search", definition=definition, project_path=project, + ) + + assert payload["parameters"][0]["name"] == "question" + assert len(payload["parameters"]) == 1 + + def test_adds_question_when_only_filter_params(self, tmp_path: Path) -> None: + project = _make_state_file(tmp_path) + definition = { + "fields": { + "retrieval": "my_retrieval", + "parameters": [ + {"name": "genre", "description": "Filter by genre", "type": "string", "required": False}, + ], + } + } + + payload = Command()._retrieval_tool_payload( + tool_name="search", definition=definition, project_path=project, + ) + + param_names = [p["name"] for p in payload["parameters"]] + assert param_names == ["question", "genre"] + + def test_does_not_duplicate_question_if_already_present(self, tmp_path: Path) -> None: + project = _make_state_file(tmp_path) + definition = { + "fields": { + "retrieval": "my_retrieval", + "parameters": [ + {"name": "question", "description": "Custom query", "type": "string", "required": True}, + {"name": "genre", "description": "Filter by genre", "type": "string", "required": False}, + ], + } + } + + payload = Command()._retrieval_tool_payload( + tool_name="search", definition=definition, project_path=project, + ) + + question_params = [p for p in payload["parameters"] if p["name"] == "question"] + assert len(question_params) == 1 + assert question_params[0]["description"] == "Custom query"