diff --git a/cogsol/management/commands/startproject.py b/cogsol/management/commands/startproject.py index 46dbe32..365ed91 100644 --- a/cogsol/management/commands/startproject.py +++ b/cogsol/management/commands/startproject.py @@ -121,12 +121,13 @@ def run(self, chat=None, data=None, secrets=None, log=None, text: str = "", coun DATA_RETRIEVALS_PY = """\ from cogsol.content import BaseRetrieval, ReorderingStrategy +# from data.product_docs import ProductDocsTopic # # class ProductDocsRetrieval(BaseRetrieval): # \"\"\"Sample retrieval configuration.\"\"\" # # name = "product_docs_search" -# topic = "product_docs" +# topic = ProductDocsTopic # num_refs = 10 # reordering = False # strategy_reordering = ReorderingStrategy.NONE diff --git a/docs/agents-tools.md b/docs/agents-tools.md index d4637cd..f676c07 100644 --- a/docs/agents-tools.md +++ b/docs/agents-tools.md @@ -462,10 +462,11 @@ Retrieval tools reference Content API retrievals defined in `data/retrievals.py` ```python # data/retrievals.py from cogsol.content import BaseRetrieval +from data.product_docs import ProductDocsTopic class ProductDocsRetrieval(BaseRetrieval): name = "product_docs_search" - topic = "product_docs" + topic = ProductDocsTopic num_refs = 10 # agents/searches.py @@ -682,11 +683,12 @@ Retrievals define semantic search behavior. Place them in `data/retrievals.py`. ```python from cogsol.content import BaseRetrieval, ReorderingStrategy from data.formatters import DefaultFormatter +from data.product_docs import ProductDocsTopic from data.product_docs.metadata import ProductMetadata class ProductDocsRetrieval(BaseRetrieval): name = "product_docs_search" - topic = "product_docs" + topic = ProductDocsTopic num_refs = 10 reordering = False strategy_reordering = ReorderingStrategy.NONE @@ -699,7 +701,7 @@ class ProductDocsRetrieval(BaseRetrieval): | Attribute | Type | Description | |-----------|------|-------------| | `name` | `str` | Retrieval identifier | -| `topic` | `str` or `BaseTopic` | Topic name or topic class | +| `topic` | `type[BaseTopic]` | Topic class reference (e.g. `ProductDocsTopic`) | | `num_refs` | `int` | Number of references to return | | `max_msg_length` | `int` | Max response length | | `reordering` | `bool` | Enable reordering | @@ -753,11 +755,12 @@ class DetailedFormatter(BaseReferenceFormatter): # data/retrievals.py from cogsol.content import BaseRetrieval from data.formatters import DetailedFormatter +from data.knowledge_base import KnowledgeBaseTopic from data.knowledge_base.metadata import DepartmentMetadata class KnowledgeBaseRetrieval(BaseRetrieval): name = "kb_search" - topic = "knowledge_base" + topic = KnowledgeBaseTopic num_refs = 8 formatters = {"Text Document": DetailedFormatter} filters = [DepartmentMetadata] diff --git a/docs/getting-started.md b/docs/getting-started.md index 2db7f72..9bc79eb 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -574,11 +574,12 @@ Add a retrieval in `data/retrievals.py` to enable semantic search: ```python from cogsol.content import BaseRetrieval +from data.product_docs import ProductDocsTopic class ProductDocsRetrieval(BaseRetrieval): name = "product_docs_search" - topic = "product_docs" + topic = ProductDocsTopic num_refs = 10 ``` diff --git a/tests/test_content_commands.py b/tests/test_content_commands.py index 3dc458f..8f16c53 100644 --- a/tests/test_content_commands.py +++ b/tests/test_content_commands.py @@ -10,6 +10,9 @@ collect_files, load_ingestion_config, ) +from cogsol.management.commands.startproject import ( + Command as StartprojectCommand, +) from cogsol.management.commands.starttopic import ( Command as StarttopicCommand, ) @@ -19,6 +22,62 @@ ) +class TestStartprojectCommand: + """Tests for startproject command.""" + + def test_creates_project_structure(self): + """Command should create the full project directory structure.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) / "myproject" + + cmd = StartprojectCommand() + result = cmd.handle(project_path=None, name="myproject", directory=str(project_dir)) + + assert result == 0 + assert (project_dir / "manage.py").exists() + assert (project_dir / "settings.py").exists() + assert (project_dir / "data" / "retrievals.py").exists() + assert (project_dir / "data" / "migrations").exists() + assert (project_dir / "agents" / "searches.py").exists() + + def test_retrievals_template_uses_topic_class_reference(self): + """Generated data/retrievals.py must use a Topic class reference, not a string.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) / "myproject" + + cmd = StartprojectCommand() + cmd.handle(project_path=None, name="myproject", directory=str(project_dir)) + + content = (project_dir / "data" / "retrievals.py").read_text(encoding="utf-8") + + assert 'topic = "product_docs"' not in content + assert "topic = ProductDocsTopic" in content + + def test_retrievals_template_includes_topic_import(self): + """Generated data/retrievals.py must include the commented import for the Topic class.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) / "myproject" + + cmd = StartprojectCommand() + cmd.handle(project_path=None, name="myproject", directory=str(project_dir)) + + content = (project_dir / "data" / "retrievals.py").read_text(encoding="utf-8") + + assert "from data.product_docs import ProductDocsTopic" in content + + def test_fails_if_directory_not_empty(self): + """Command should fail if the target directory already has files.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) / "myproject" + project_dir.mkdir() + (project_dir / "existing_file.txt").write_text("content", encoding="utf-8") + + cmd = StartprojectCommand() + result = cmd.handle(project_path=None, name="myproject", directory=str(project_dir)) + + assert result == 1 + + class TestToClassName: """Tests for to_class_name helper."""