Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions cogsol/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
BaseTool,
)

# Base classes checked by serialize_value when it is handed a class object
# (not an instance): a subclass of any of these is serialized to its ``name``
# attribute, falling back to the class ``__name__`` when ``name`` is missing
# or falsy.
_SERIALIZABLE_TYPE_BASES = (
BaseRetrieval,
BaseReferenceFormatter,
BaseTopic,
BaseIngestionConfig,
BaseMetadataConfig,
BaseRetrievalTool,
)


def _normalize_code(value: Any) -> Any:
if not isinstance(value, str):
Expand Down Expand Up @@ -69,15 +78,7 @@ def serialize_value(value: Any) -> Any:
getattr(value, "name", None) or getattr(value, "key", None) or value.__class__.__name__
)
if isinstance(value, type):
if issubclass(value, BaseRetrieval):
return getattr(value, "name", None) or value.__name__
if issubclass(value, BaseReferenceFormatter):
return getattr(value, "name", None) or value.__name__
if issubclass(value, BaseTopic):
return getattr(value, "name", None) or value.__name__
if issubclass(value, BaseIngestionConfig):
return getattr(value, "name", None) or value.__name__
if issubclass(value, BaseRetrievalTool):
if any(issubclass(value, b) for b in _SERIALIZABLE_TYPE_BASES):
return getattr(value, "name", None) or value.__name__
if is_dataclass(value) and not isinstance(value, type):
data = asdict(value)
Expand Down Expand Up @@ -465,7 +466,8 @@ def collect_classes(project_path: Path, app_name: str = "agents") -> dict[str, d
or obj.__module__.startswith(retrieval_prefix)
)
):
classes["retrieval_tools"][obj.__name__] = obj
key = getattr(obj, "name", None) or obj.__name__
classes["retrieval_tools"][key] = obj
except ModuleNotFoundError as exc:
if not _ignore_missing_module(exc, f"{app_name}.searches"):
_raise_import_error("retrieval tools module", f"{app_name}.searches", exc)
Expand Down
100 changes: 59 additions & 41 deletions cogsol/management/commands/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,46 @@ def _sync_content_with_api(
created.append(("formatter", None, new_id))
new_remote.setdefault("formatters", {})[fmt_name] = new_id

# Upsert metadata configs (before retrievals, since filters reference them)
for cfg_key, definition in state.get("metadata_configs", {}).items():
if touched is not None and cfg_key not in touched.get("metadata_configs", set()):
continue
fields = definition.get("fields", {})
topic_path = definition.get("topic", "")
cfg_name = fields.get("name")
if not topic_path or not cfg_name:
continue
node_id = topic_id_map.get(topic_path) or new_remote.get("topics", {}).get(
topic_path
)
if not node_id:
continue

cfg_payload = {
"name": cfg_name,
"type": fields.get("type", "STRING"),
"possible_values": fields.get("possible_values", []),
"default_value": fields.get("default_value"),
"format": fields.get("format"),
"filtrable": fields.get("filtrable", False),
"required": fields.get("required", False),
"in_embedding": fields.get("in_embedding", False),
"in_retrieval": fields.get("in_retrieval", True),
}
if cfg_payload["required"] and cfg_payload.get("default_value") is None:
raise CogSolAPIError(
"Default value is required for required metadata configs. "
f"Set default_value for '{cfg_key}'."
)

cfg_remote_id = new_remote.get("metadata_configs", {}).get(cfg_key)
if cfg_remote_id:
client.update_metadata_config(cfg_remote_id, cfg_payload)
else:
new_cfg_id = client.create_metadata_config(node_id=node_id, payload=cfg_payload)
created.append(("metadata_config", node_id, new_cfg_id))
new_remote.setdefault("metadata_configs", {})[cfg_key] = new_cfg_id

# Upsert retrievals
for ret_name, definition in state.get("retrievals", {}).items():
if touched is not None and ret_name not in touched.get("retrievals", set()):
Expand Down Expand Up @@ -351,7 +391,25 @@ def _set_if_defined(
_set_if_defined("next_blocks")
_set_if_defined("contingency_for_embedding")
_set_if_defined("threshold_similarity")
_set_if_defined("filters")
if "filters" in fields and fields["filters"]:
filters_value = fields["filters"]
if not isinstance(filters_value, list):
raise CogSolAPIError(
f"filters must be a list of metadata config names. "
f"Fix retrieval '{ret_name}'."
)
metadata_configs = new_remote.get("metadata_configs", {})
filters_payload: list[int] = []
for filter_name in filters_value:
filter_id = metadata_configs.get(filter_name)
if filter_id is None:
raise CogSolAPIError(
"MetadataConfig must be migrated before use as filter. "
f"Missing metadata config id for '{filter_name}' "
f"in retrieval '{ret_name}'."
)
filters_payload.append(int(filter_id))
retrieval_payload["filters"] = filters_payload

if (
"strategy_reordering" in retrieval_payload
Expand Down Expand Up @@ -425,46 +483,6 @@ def _set_if_defined(
created.append(("ingestion_config", None, new_id))
new_remote.setdefault("ingestion_configs", {})[cfg_name] = new_id

# Upsert metadata configs
for cfg_key, definition in state.get("metadata_configs", {}).items():
if touched is not None and cfg_key not in touched.get("metadata_configs", set()):
continue
fields = definition.get("fields", {})
topic_path = definition.get("topic", "")
cfg_name = fields.get("name")
if not topic_path or not cfg_name:
continue
node_id = topic_id_map.get(topic_path) or new_remote.get("topics", {}).get(
topic_path
)
if not node_id:
continue

cfg_payload = {
"name": cfg_name,
"type": fields.get("type", "STRING"),
"possible_values": fields.get("possible_values", []),
"default_value": fields.get("default_value"),
"format": fields.get("format"),
"filtrable": fields.get("filtrable", False),
"required": fields.get("required", False),
"in_embedding": fields.get("in_embedding", False),
"in_retrieval": fields.get("in_retrieval", True),
}
if cfg_payload["required"] and cfg_payload.get("default_value") is None:
raise CogSolAPIError(
"Default value is required for required metadata configs. "
f"Set default_value for '{cfg_key}'."
)

cfg_remote_id = new_remote.get("metadata_configs", {}).get(cfg_key)
if cfg_remote_id:
client.update_metadata_config(cfg_remote_id, cfg_payload)
else:
new_cfg_id = client.create_metadata_config(node_id=node_id, payload=cfg_payload)
created.append(("metadata_config", node_id, new_cfg_id))
new_remote.setdefault("metadata_configs", {})[cfg_key] = new_cfg_id

return new_remote

except Exception:
Expand Down
107 changes: 107 additions & 0 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Tests for cogsol.core.loader utilities.
"""

import tempfile
from pathlib import Path

from cogsol.content import BaseMetadataConfig, BaseRetrieval
from cogsol.core.loader import collect_classes, serialize_value


class TestSerializeValueTypeHandling:
    """Check how serialize_value treats class objects of the content bases."""

    def test_metadata_config_with_name(self):
        """A metadata config class that sets ``name`` serializes to that name."""

        class GenreMetadata(BaseMetadataConfig):
            name = "genre"

        serialized = serialize_value(GenreMetadata)
        assert serialized == "genre"

    def test_metadata_config_without_name(self):
        """A metadata config class with no ``name`` serializes to its class name."""

        class GenreMetadata(BaseMetadataConfig):
            pass

        serialized = serialize_value(GenreMetadata)
        assert serialized == "GenreMetadata"

    def test_metadata_config_list(self):
        """Each class inside a list is serialized individually, order preserved."""

        class GenreMetadata(BaseMetadataConfig):
            name = "genre"

        class LanguageMetadata(BaseMetadataConfig):
            name = "language"

        configs = [GenreMetadata, LanguageMetadata]
        assert serialize_value(configs) == ["genre", "language"]

    def test_retrieval_still_works(self):
        """Retrieval classes keep serializing to their ``name`` attribute."""

        class MyRetrieval(BaseRetrieval):
            name = "my_retrieval"

        serialized = serialize_value(MyRetrieval)
        assert serialized == "my_retrieval"


class TestCollectClassesRetrievalToolKey:
    """Tests for collect_classes keying retrieval tools by name attr."""

    @staticmethod
    def _write_agents_package(project_path: Path, searches_source: str) -> None:
        """Create a minimal ``agents`` package inside *project_path*.

        Writes an empty ``__init__.py`` and ``tools.py`` plus a ``searches.py``
        containing *searches_source*.
        """
        agents_path = project_path / "agents"
        agents_path.mkdir(parents=True)

        (agents_path / "__init__.py").write_text("", encoding="utf-8")
        (agents_path / "tools.py").write_text("", encoding="utf-8")
        (agents_path / "searches.py").write_text(searches_source, encoding="utf-8")

    def test_retrieval_tool_keyed_by_name_attr(self):
        """collect_classes should use the name attribute, not __name__, as dict key."""
        # Built line by line so the class body keeps the indentation Python
        # needs; an unindented body would make searches.py fail to import.
        searches_source = (
            "\n"
            "from cogsol.tools import BaseRetrievalTool\n"
            "\n"
            "class ProductDocsSearch(BaseRetrievalTool):\n"
            '    name = "product_docs_search"\n'
            '    description = "Search product docs."\n'
            '    retrieval = "product_docs"\n'
            "    parameters = [\n"
            '        {"name": "question", "description": "Query", '
            '"type": "string", "required": True}\n'
            "    ]\n"
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            project_path = Path(tmpdir)
            self._write_agents_package(project_path, searches_source)

            classes = collect_classes(project_path, "agents")
            assert "product_docs_search" in classes["retrieval_tools"]
            assert "ProductDocsSearch" not in classes["retrieval_tools"]

    def test_retrieval_tool_without_name_falls_back_to_class_name(self):
        """collect_classes should fall back to __name__ when name attr is not set."""
        # Same layout as above, but the tool class deliberately omits ``name``.
        searches_source = (
            "\n"
            "from cogsol.tools import BaseRetrievalTool\n"
            "\n"
            "class MySearch(BaseRetrievalTool):\n"
            '    description = "Search."\n'
            '    retrieval = "my_retrieval"\n'
            "    parameters = [\n"
            '        {"name": "question", "description": "Query", '
            '"type": "string", "required": True}\n'
            "    ]\n"
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            project_path = Path(tmpdir)
            self._write_agents_package(project_path, searches_source)

            classes = collect_classes(project_path, "agents")
            assert "MySearch" in classes["retrieval_tools"]