Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cogsol/management/commands/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,14 +764,15 @@ def _resolve_retrieval_id(value: Any) -> int:
return int(retrieval_id)

params = list(fields.get("parameters") or [])
if not params:
params.append(
if not any(p.get("name") == "question" for p in params):
params.insert(
0,
{
"name": "question",
"description": "Search query",
"type": "string",
"required": True,
}
},
)
description = fields.get("description") or f"Retrieval tool {tool_name}"
retrieval_id = _resolve_retrieval_id(fields.get("retrieval"))
Expand Down
75 changes: 75 additions & 0 deletions tests/test_migrate_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

import ast
import json
from pathlib import Path

from cogsol.management.commands.migrate import Command

Expand Down Expand Up @@ -131,3 +133,76 @@ def test_preserves_multiline_helper_signature(self) -> None:
assert "response = helper(text=text)" in script
assert "self.helper" not in script
ast.parse(script)


def _make_state_file(tmp_path: Path, retrieval_key: str = "my_retrieval") -> Path:
"""Create a minimal content state file with a fake retrieval ID."""
project = tmp_path / "data" / "migrations"
project.mkdir(parents=True, exist_ok=True)
state_path = project / ".state.json"
state_path.write_text(
json.dumps(
{
"state": {},
"remote": {"retrievals": {retrieval_key: 42}},
}
)
)
return tmp_path


class TestRetrievalToolQuestionParam:
"""Tests for auto-injection of the question parameter in retrieval tools."""

def test_adds_question_when_params_empty(self, tmp_path: Path) -> None:
project = _make_state_file(tmp_path)
definition = {
"fields": {
"retrieval": "my_retrieval",
}
}

payload = Command()._retrieval_tool_payload(
tool_name="search", definition=definition, project_path=project,
)

assert payload["parameters"][0]["name"] == "question"
assert len(payload["parameters"]) == 1

def test_adds_question_when_only_filter_params(self, tmp_path: Path) -> None:
project = _make_state_file(tmp_path)
definition = {
"fields": {
"retrieval": "my_retrieval",
"parameters": [
{"name": "genre", "description": "Filter by genre", "type": "string", "required": False},
],
}
}

payload = Command()._retrieval_tool_payload(
tool_name="search", definition=definition, project_path=project,
)

param_names = [p["name"] for p in payload["parameters"]]
assert param_names == ["question", "genre"]

def test_does_not_duplicate_question_if_already_present(self, tmp_path: Path) -> None:
project = _make_state_file(tmp_path)
definition = {
"fields": {
"retrieval": "my_retrieval",
"parameters": [
{"name": "question", "description": "Custom query", "type": "string", "required": True},
{"name": "genre", "description": "Filter by genre", "type": "string", "required": False},
],
}
}

payload = Command()._retrieval_tool_payload(
tool_name="search", definition=definition, project_path=project,
)

question_params = [p for p in payload["parameters"] if p["name"] == "question"]
assert len(question_params) == 1
assert question_params[0]["description"] == "Custom query"
Loading