diff --git a/cogsol/core/loader.py b/cogsol/core/loader.py index cc42ade..922dc7c 100644 --- a/cogsol/core/loader.py +++ b/cogsol/core/loader.py @@ -31,6 +31,15 @@ BaseTool, ) +_SERIALIZABLE_TYPE_BASES = ( + BaseRetrieval, + BaseReferenceFormatter, + BaseTopic, + BaseIngestionConfig, + BaseMetadataConfig, + BaseRetrievalTool, +) + def _normalize_code(value: Any) -> Any: if not isinstance(value, str): @@ -69,15 +78,7 @@ def serialize_value(value: Any) -> Any: getattr(value, "name", None) or getattr(value, "key", None) or value.__class__.__name__ ) if isinstance(value, type): - if issubclass(value, BaseRetrieval): - return getattr(value, "name", None) or value.__name__ - if issubclass(value, BaseReferenceFormatter): - return getattr(value, "name", None) or value.__name__ - if issubclass(value, BaseTopic): - return getattr(value, "name", None) or value.__name__ - if issubclass(value, BaseIngestionConfig): - return getattr(value, "name", None) or value.__name__ - if issubclass(value, BaseRetrievalTool): + if any(issubclass(value, b) for b in _SERIALIZABLE_TYPE_BASES): return getattr(value, "name", None) or value.__name__ if is_dataclass(value) and not isinstance(value, type): data = asdict(value) @@ -465,7 +466,8 @@ def collect_classes(project_path: Path, app_name: str = "agents") -> dict[str, d or obj.__module__.startswith(retrieval_prefix) ) ): - classes["retrieval_tools"][obj.__name__] = obj + key = getattr(obj, "name", None) or obj.__name__ + classes["retrieval_tools"][key] = obj except ModuleNotFoundError as exc: if not _ignore_missing_module(exc, f"{app_name}.searches"): _raise_import_error("retrieval tools module", f"{app_name}.searches", exc) diff --git a/cogsol/management/commands/migrate.py b/cogsol/management/commands/migrate.py index 35763c0..3b995c8 100644 --- a/cogsol/management/commands/migrate.py +++ b/cogsol/management/commands/migrate.py @@ -307,6 +307,46 @@ def _sync_content_with_api( created.append(("formatter", None, new_id)) new_remote.setdefault("formatters", {})[fmt_name] = new_id + # Upsert metadata configs (before retrievals, since filters reference them) + for cfg_key, definition in state.get("metadata_configs", {}).items(): + if touched is not None and cfg_key not in touched.get("metadata_configs", set()): + continue + fields = definition.get("fields", {}) + topic_path = definition.get("topic", "") + cfg_name = fields.get("name") + if not topic_path or not cfg_name: + continue + node_id = topic_id_map.get(topic_path) or new_remote.get("topics", {}).get( + topic_path + ) + if not node_id: + continue + + cfg_payload = { + "name": cfg_name, + "type": fields.get("type", "STRING"), + "possible_values": fields.get("possible_values", []), + "default_value": fields.get("default_value"), + "format": fields.get("format"), + "filtrable": fields.get("filtrable", False), + "required": fields.get("required", False), + "in_embedding": fields.get("in_embedding", False), + "in_retrieval": fields.get("in_retrieval", True), + } + if cfg_payload["required"] and cfg_payload.get("default_value") is None: + raise CogSolAPIError( + "Default value is required for required metadata configs. " + f"Set default_value for '{cfg_key}'." + ) + + cfg_remote_id = new_remote.get("metadata_configs", {}).get(cfg_key) + if cfg_remote_id: + client.update_metadata_config(cfg_remote_id, cfg_payload) + else: + new_cfg_id = client.create_metadata_config(node_id=node_id, payload=cfg_payload) + created.append(("metadata_config", node_id, new_cfg_id)) + new_remote.setdefault("metadata_configs", {})[cfg_key] = new_cfg_id + # Upsert retrievals for ret_name, definition in state.get("retrievals", {}).items(): if touched is not None and ret_name not in touched.get("retrievals", set()): @@ -351,7 +391,25 @@ def _set_if_defined( _set_if_defined("next_blocks") _set_if_defined("contingency_for_embedding") _set_if_defined("threshold_similarity") - _set_if_defined("filters") + if "filters" in fields and fields["filters"]: + filters_value = fields["filters"] + if not isinstance(filters_value, list): + raise CogSolAPIError( + f"filters must be a list of metadata config names. " + f"Fix retrieval '{ret_name}'." + ) + metadata_configs = new_remote.get("metadata_configs", {}) + filters_payload: list[int] = [] + for filter_name in filters_value: + filter_id = metadata_configs.get(filter_name) + if filter_id is None: + raise CogSolAPIError( + "MetadataConfig must be migrated before use as filter. " + f"Missing metadata config id for '{filter_name}' " + f"in retrieval '{ret_name}'." + ) + filters_payload.append(int(filter_id)) + retrieval_payload["filters"] = filters_payload if ( "strategy_reordering" in retrieval_payload @@ -425,46 +483,6 @@ def _set_if_defined( created.append(("ingestion_config", None, new_id)) new_remote.setdefault("ingestion_configs", {})[cfg_name] = new_id - # Upsert metadata configs - for cfg_key, definition in state.get("metadata_configs", {}).items(): - if touched is not None and cfg_key not in touched.get("metadata_configs", set()): - continue - fields = definition.get("fields", {}) - topic_path = definition.get("topic", "") - cfg_name = fields.get("name") - if not topic_path or not cfg_name: - continue - node_id = topic_id_map.get(topic_path) or new_remote.get("topics", {}).get( - topic_path - ) - if not node_id: - continue - - cfg_payload = { - "name": cfg_name, - "type": fields.get("type", "STRING"), - "possible_values": fields.get("possible_values", []), - "default_value": fields.get("default_value"), - "format": fields.get("format"), - "filtrable": fields.get("filtrable", False), - "required": fields.get("required", False), - "in_embedding": fields.get("in_embedding", False), - "in_retrieval": fields.get("in_retrieval", True), - } - if cfg_payload["required"] and cfg_payload.get("default_value") is None: - raise CogSolAPIError( - "Default value is required for required metadata configs. " - f"Set default_value for '{cfg_key}'." - ) - - cfg_remote_id = new_remote.get("metadata_configs", {}).get(cfg_key) - if cfg_remote_id: - client.update_metadata_config(cfg_remote_id, cfg_payload) - else: - new_cfg_id = client.create_metadata_config(node_id=node_id, payload=cfg_payload) - created.append(("metadata_config", node_id, new_cfg_id)) - new_remote.setdefault("metadata_configs", {})[cfg_key] = new_cfg_id - return new_remote except Exception: diff --git a/tests/test_loader.py b/tests/test_loader.py new file mode 100644 index 0000000..69aaab9 --- /dev/null +++ b/tests/test_loader.py @@ -0,0 +1,107 @@ +""" +Tests for cogsol.core.loader utilities. +""" + +import tempfile +from pathlib import Path + +from cogsol.content import BaseMetadataConfig, BaseRetrieval +from cogsol.core.loader import collect_classes, serialize_value + + +class TestSerializeValueTypeHandling: + """Tests for serialize_value handling of type subclasses.""" + + def test_metadata_config_with_name(self): + """BaseMetadataConfig subclass with explicit name returns name.""" + + class GenreMetadata(BaseMetadataConfig): + name = "genre" + + assert serialize_value(GenreMetadata) == "genre" + + def test_metadata_config_without_name(self): + """BaseMetadataConfig subclass without name returns __name__.""" + + class GenreMetadata(BaseMetadataConfig): + pass + + assert serialize_value(GenreMetadata) == "GenreMetadata" + + def test_metadata_config_list(self): + """List of BaseMetadataConfig subclasses serializes each element.""" + + class GenreMetadata(BaseMetadataConfig): + name = "genre" + + class LanguageMetadata(BaseMetadataConfig): + name = "language" + + result = serialize_value([GenreMetadata, LanguageMetadata]) + assert result == ["genre", "language"] + + def test_retrieval_still_works(self): + """Existing BaseRetrieval handling is not broken.""" + + class MyRetrieval(BaseRetrieval): + name = "my_retrieval" + + assert serialize_value(MyRetrieval) == "my_retrieval" + + +class TestCollectClassesRetrievalToolKey: + """Tests for collect_classes keying retrieval tools by name attr.""" + + def test_retrieval_tool_keyed_by_name_attr(self): + """collect_classes should use the name attribute, not __name__, as dict key.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_path = Path(tmpdir) + agents_path = project_path / "agents" + agents_path.mkdir(parents=True) + + (agents_path / "__init__.py").write_text("", encoding="utf-8") + (agents_path / "tools.py").write_text("", encoding="utf-8") + (agents_path / "searches.py").write_text( + """ +from cogsol.tools import BaseRetrievalTool + +class ProductDocsSearch(BaseRetrievalTool): + name = "product_docs_search" + description = "Search product docs." + retrieval = "product_docs" + parameters = [ + {"name": "question", "description": "Query", "type": "string", "required": True} + ] +""", + encoding="utf-8", + ) + + classes = collect_classes(project_path, "agents") + assert "product_docs_search" in classes["retrieval_tools"] + assert "ProductDocsSearch" not in classes["retrieval_tools"] + + def test_retrieval_tool_without_name_falls_back_to_class_name(self): + """collect_classes should fall back to __name__ when name attr is not set.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_path = Path(tmpdir) + agents_path = project_path / "agents" + agents_path.mkdir(parents=True) + + (agents_path / "__init__.py").write_text("", encoding="utf-8") + (agents_path / "tools.py").write_text("", encoding="utf-8") + (agents_path / "searches.py").write_text( + """ +from cogsol.tools import BaseRetrievalTool + +class MySearch(BaseRetrievalTool): + description = "Search." + retrieval = "my_retrieval" + parameters = [ + {"name": "question", "description": "Query", "type": "string", "required": True} + ] +""", + encoding="utf-8", + ) + + classes = collect_classes(project_path, "agents") + assert "MySearch" in classes["retrieval_tools"]