From 5e419480537f0f4649d90fb83088d51671e25e78 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 13 Mar 2026 08:36:16 -0700 Subject: [PATCH] refactor: create separate `Project` to encapsulate `coco.Environment` --- src/cocoindex_code/indexer.py | 9 +---- src/cocoindex_code/project.py | 76 +++++++++++++++++++++++++++++++++++ src/cocoindex_code/query.py | 7 ++-- src/cocoindex_code/server.py | 31 +++++--------- src/cocoindex_code/shared.py | 24 +---------- tests/conftest.py | 9 +++-- tests/test_e2e.py | 28 ++++++------- 7 files changed, 110 insertions(+), 74 deletions(-) create mode 100644 src/cocoindex_code/project.py diff --git a/src/cocoindex_code/indexer.py b/src/cocoindex_code/indexer.py index f7b3b60..c6c132b 100644 --- a/src/cocoindex_code/indexer.py +++ b/src/cocoindex_code/indexer.py @@ -129,7 +129,7 @@ async def process( @coco.fn -async def app_main() -> None: +async def indexer_main() -> None: """Main indexing function - walks files and processes each.""" db = coco.use_context(SQLITE_DB) @@ -159,10 +159,3 @@ async def app_main() -> None: # Process each file with coco.component_subpath(coco.Symbol("process_file")): await coco.mount_each(process_file, files.items(), table) - - -# Create the app -app = coco.App( - coco.AppConfig(name="CocoIndexCode"), - app_main, -) diff --git a/src/cocoindex_code/project.py b/src/cocoindex_code/project.py new file mode 100644 index 0000000..eeb5f7a --- /dev/null +++ b/src/cocoindex_code/project.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import asyncio + +import cocoindex as coco +from cocoindex.connectors import sqlite +from cocoindex.connectors.localfs import register_base_dir + +from .config import config +from .indexer import indexer_main +from .shared import CODEBASE_DIR, SQLITE_DB + + +class Project: + _env: coco.Environment + _app: coco.App[[], None] + _index_lock: asyncio.Lock + _initial_index_done: bool = False + + async def update_index(self, *, report_to_stdout: bool = False) -> None: + """Update 
the index, serializing concurrent calls via lock.""" + async with self._index_lock: + try: + await self._app.update(report_to_stdout=report_to_stdout) + finally: + self._initial_index_done = True + + @property + def env(self) -> coco.Environment: + return self._env + + @property + def is_initial_index_done(self) -> bool: + return self._initial_index_done + + @staticmethod + async def create() -> Project: + # Ensure index directory exists + config.index_dir.mkdir(parents=True, exist_ok=True) + + # Set CocoIndex state database path + settings = coco.Settings.from_env(config.cocoindex_db_path) + + context = coco.ContextProvider() + + # Provide codebase root directory to environment + context.provide(CODEBASE_DIR, register_base_dir("codebase", config.codebase_root_path)) + # Connect to SQLite with vector extension + conn = sqlite.connect(str(config.target_sqlite_db_path), load_vec="auto") + context.provide(SQLITE_DB, sqlite.register_db("index_db", conn)) + + env = coco.Environment(settings, context_provider=context) + app = coco.App( + coco.AppConfig( + name="CocoIndexCode", + environment=env, + ), + indexer_main, + ) + + result = Project.__new__(Project) + result._env = env + result._app = app + result._index_lock = asyncio.Lock() + return result + + +_project: Project | None = None + + +async def default_project() -> Project: + """Return the shared CocoIndexCode project, creating it on first use.""" + global _project + if _project is None: + _project = await Project.create() + return _project diff --git a/src/cocoindex_code/query.py b/src/cocoindex_code/query.py index ca6a277..22b0380 100644 --- a/src/cocoindex_code/query.py +++ b/src/cocoindex_code/query.py @@ -4,9 +4,8 @@ import sqlite3 from typing import Any -import cocoindex as coco - from .config import config +from .project import default_project from .schema import QueryResult from .shared import SQLITE_DB, embedder, query_prompt_name @@ -102,8 +101,8 @@ async def query_codebase( "Please run a query with refresh_index=True 
first." ) - coco_env = await coco.default_env() - db = coco_env.get_context(SQLITE_DB) + coco_proj = await default_project() + db = coco_proj.env.get_context(SQLITE_DB) # Generate query embedding. query_embedding = await embedder.embed(query, query_prompt_name) diff --git a/src/cocoindex_code/server.py b/src/cocoindex_code/server.py index 7de4210..605bd8e 100644 --- a/src/cocoindex_code/server.py +++ b/src/cocoindex_code/server.py @@ -3,12 +3,11 @@ import argparse import asyncio -import cocoindex as coco from mcp.server.fastmcp import FastMCP from pydantic import BaseModel, Field from .config import config -from .indexer import app as indexer_app +from .project import default_project from .query import query_codebase from .shared import SQLITE_DB @@ -27,17 +26,11 @@ ), ) -# Lock to prevent concurrent index updates -_index_lock = asyncio.Lock() - -# Event set once the initial background index is ready -_initial_index_done = asyncio.Event() - async def _refresh_index() -> None: """Refresh the index. 
Uses lock to prevent concurrent updates.""" - async with _index_lock: - await indexer_app.update(report_to_stdout=False) + proj = await default_project() + await proj.update_index() # === Pydantic Models for Tool Inputs/Outputs === @@ -126,7 +119,8 @@ async def search( ), ) -> SearchResultModel: """Query the codebase index.""" - if not _initial_index_done.is_set(): + proj = await default_project() + if not proj.is_initial_index_done: return SearchResultModel( success=False, message=( @@ -182,19 +176,14 @@ async def _async_serve() -> None: """Async entry point for the MCP server.""" # Refresh index in background so startup isn't blocked - async def _initial_index() -> None: - try: - await _refresh_index() - finally: - _initial_index_done.set() - - asyncio.create_task(_initial_index()) + asyncio.create_task(_refresh_index()) await mcp.run_stdio_async() async def _async_index() -> None: """Async entry point for the index command.""" - await indexer_app.update(report_to_stdout=True) + proj = await default_project() + await proj.update_index(report_to_stdout=True) await _print_index_stats() @@ -205,8 +194,8 @@ async def _print_index_stats() -> None: print("No index database found.") return - coco_env = await coco.default_env() - db = coco_env.get_context(SQLITE_DB) + proj = await default_project() + db = proj.env.get_context(SQLITE_DB) with db.value.readonly() as conn: total_chunks = conn.execute("SELECT COUNT(*) FROM code_chunks_vec").fetchone()[0] diff --git a/src/cocoindex_code/shared.py b/src/cocoindex_code/shared.py index 7b28d1b..48aa90c 100644 --- a/src/cocoindex_code/shared.py +++ b/src/cocoindex_code/shared.py @@ -3,13 +3,12 @@ from __future__ import annotations import logging -from collections.abc import Iterator from dataclasses import dataclass from typing import TYPE_CHECKING, Annotated import cocoindex as coco from cocoindex.connectors import sqlite -from cocoindex.connectors.localfs import FilePath, register_base_dir +from cocoindex.connectors.localfs 
import FilePath from numpy.typing import NDArray if TYPE_CHECKING: @@ -54,27 +53,6 @@ CODEBASE_DIR = coco.ContextKey[FilePath]("codebase_dir") -@coco.lifespan -def coco_lifespan(builder: coco.EnvironmentBuilder) -> Iterator[None]: - """Set up database connection.""" - # Ensure index directory exists - config.index_dir.mkdir(parents=True, exist_ok=True) - - # Set CocoIndex state database path - builder.settings.db_path = config.cocoindex_db_path - - # Provide codebase root directory to environment - builder.provide(CODEBASE_DIR, register_base_dir("codebase", config.codebase_root_path)) - - # Connect to SQLite with vector extension - conn = sqlite.connect(str(config.target_sqlite_db_path), load_vec="auto") - builder.provide(SQLITE_DB, sqlite.register_db("index_db", conn)) - - yield - - conn.close() - - @dataclass class CodeChunk: """Schema for storing code chunks in SQLite.""" diff --git a/tests/conftest.py b/tests/conftest.py index a9cadc2..faabf0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from collections.abc import AsyncIterator from pathlib import Path -import cocoindex as coco import pytest import pytest_asyncio @@ -24,10 +23,12 @@ def test_codebase_root() -> Path: @pytest_asyncio.fixture(scope="session", loop_scope="session") async def coco_runtime() -> AsyncIterator[None]: """ - Set up CocoIndex runtime context for the entire test session. + Set up CocoIndex project for the entire test session. Uses session-scoped event loop to ensure CocoIndex environment persists across all tests. 
""" - async with coco.runtime(): - yield + from cocoindex_code.project import default_project + + await default_project() + yield diff --git a/tests/test_e2e.py b/tests/test_e2e.py index a802fd9..8071831 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -6,7 +6,7 @@ import pytest from cocoindex_code.config import _discover_codebase_root -from cocoindex_code.indexer import app +from cocoindex_code.project import default_project from cocoindex_code.query import query_codebase pytest_plugins = ("pytest_asyncio",) @@ -194,7 +194,7 @@ async def test_index_and_query_codebase( ) -> None: """Should index a codebase and return relevant query results.""" setup_base_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() # Verify index was created index_dir = test_codebase_root / ".cocoindex_code" @@ -218,7 +218,7 @@ async def test_incremental_update_add_file( ) -> None: """Should reflect newly added files after re-indexing.""" setup_base_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() # Query for ML content - should not find it results = await query_codebase("machine learning neural network") @@ -232,7 +232,7 @@ async def test_incremental_update_add_file( (test_codebase_root / "ml_model.py").write_text(SAMPLE_ML_MODEL_PY) # Re-index and query again - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("neural network machine learning") assert len(results) > 0 @@ -244,13 +244,13 @@ async def test_incremental_update_modify_file( ) -> None: """Should reflect file modifications after re-indexing.""" setup_base_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() # Modify utils.py to add authentication (test_codebase_root / "utils.py").write_text(SAMPLE_UTILS_AUTH_PY) # Re-index and query for 
authentication - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("user authentication login") assert len(results) > 0 @@ -264,7 +264,7 @@ async def test_incremental_update_delete_file( ) -> None: """Should no longer return results from deleted files after re-indexing.""" setup_base_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() # Query for database - should find it results = await query_codebase("database connection execute query") @@ -274,7 +274,7 @@ async def test_incremental_update_delete_file( (test_codebase_root / "lib" / "database.py").unlink() # Re-index and query again - should no longer find database.py - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("database connection execute query") assert not any("database.py" in r.file_path for r in results) @@ -335,7 +335,7 @@ class TestSearchFilters: async def test_filter_by_language(self, test_codebase_root: Path, coco_runtime: None) -> None: """Should return only results matching the specified language.""" setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("function", limit=50, languages=["python"]) assert len(results) > 0 @@ -347,7 +347,7 @@ async def test_filter_by_language_multiple( ) -> None: """Should return results matching any of the specified languages.""" setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("function", limit=50, languages=["python", "javascript"]) assert len(results) > 0 @@ -363,7 +363,7 @@ async def test_filter_by_file_path_glob( ) -> None: """Should return only results matching the file path glob pattern.""" 
setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("function", limit=50, paths=["lib/*"]) assert len(results) > 0 @@ -375,7 +375,7 @@ async def test_filter_by_file_path_wildcard_extension( ) -> None: """Should filter by file extension using glob wildcard.""" setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("function", limit=50, paths=["*.js"]) assert len(results) > 0 @@ -387,7 +387,7 @@ async def test_filter_by_both_language_and_file_path( ) -> None: """Should apply both language and file path filters together.""" setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() # Filter for Python files under lib/ results = await query_codebase("function", limit=50, languages=["python"], paths=["lib/*"]) @@ -401,7 +401,7 @@ async def test_no_filter_returns_all_languages( ) -> None: """Should return results from all languages when no filter is applied.""" setup_multi_lang_codebase(test_codebase_root) - await app.update(report_to_stdout=False) + await (await default_project()).update_index() results = await query_codebase("function", limit=50) languages_found = {r.language for r in results}