From 6acb0b777c5c5c8c13eb909a9d73c6b0633f125f Mon Sep 17 00:00:00 2001 From: Micaela Perdomo Date: Sun, 22 Feb 2026 11:28:34 -0300 Subject: [PATCH] fix(loader): use consistent key for retrieval tools in collect_classes collect_classes stores retrieval tool entries using obj.__name__ (e.g. "ProductDocsSearch"), but collect_definitions uses the explicit name attribute (e.g. "product_docs_search"). When the migration runner looks up classes by definition key, the mismatch causes "retrieval is required for retrieval tool" errors. Use getattr(obj, "name", None) or obj.__name__ as the key, matching the logic already used in collect_definitions. --- cogsol/core/loader.py | 3 ++- tests/test_agents.py | 58 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/cogsol/core/loader.py b/cogsol/core/loader.py index cc42ade..bb82fb9 100644 --- a/cogsol/core/loader.py +++ b/cogsol/core/loader.py @@ -465,7 +465,8 @@ def collect_classes(project_path: Path, app_name: str = "agents") -> dict[str, d or obj.__module__.startswith(retrieval_prefix) ) ): - classes["retrieval_tools"][obj.__name__] = obj + key = getattr(obj, "name", None) or obj.__name__ + classes["retrieval_tools"][key] = obj except ModuleNotFoundError as exc: if not _ignore_missing_module(exc, f"{app_name}.searches"): _raise_import_error("retrieval tools module", f"{app_name}.searches", exc) diff --git a/tests/test_agents.py b/tests/test_agents.py index 146d4cb..73c39e0 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -280,3 +280,61 @@ class ProductDocsSearch(BaseRetrievalTool): ops = diff_states(empty_state(), defs, app="agents") create_ops = [op for op in ops if isinstance(op, CreateRetrievalTool)] assert len(create_ops) == 1 + + +class TestCollectClassesRetrievalToolKey: + """Tests for consistent keying of retrieval tools in collect_classes.""" + + def test_explicit_name_used_as_key(self): + """collect_classes should use the explicit name attribute as the key.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_path = Path(tmpdir) + agents_path = project_path / "agents" + agents_path.mkdir(parents=True) + + (agents_path / "__init__.py").write_text("", encoding="utf-8") + (agents_path / "tools.py").write_text("", encoding="utf-8") + (agents_path / "searches.py").write_text( + """ +from cogsol.tools import BaseRetrievalTool + +class ProductDocsSearch(BaseRetrievalTool): + name = "product_docs_search" + description = "Search product docs." + retrieval = "product_docs_search" + parameters = [] +""", + encoding="utf-8", + ) + + from cogsol.core.loader import collect_classes + + classes = collect_classes(project_path, "agents") + assert "product_docs_search" in classes["retrieval_tools"] + assert "ProductDocsSearch" not in classes["retrieval_tools"] + + def test_fallback_to_class_name(self): + """collect_classes should fall back to __name__ when no name attribute.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_path = Path(tmpdir) + agents_path = project_path / "agents" + agents_path.mkdir(parents=True) + + (agents_path / "__init__.py").write_text("", encoding="utf-8") + (agents_path / "tools.py").write_text("", encoding="utf-8") + (agents_path / "searches.py").write_text( + """ +from cogsol.tools import BaseRetrievalTool + +class SimpleSearch(BaseRetrievalTool): + description = "Simple search." + retrieval = "simple_search" + parameters = [] +""", + encoding="utf-8", + ) + + from cogsol.core.loader import collect_classes + + classes = collect_classes(project_path, "agents") + assert "SimpleSearch" in classes["retrieval_tools"]