Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/cocoindex_code/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def process(


@coco.fn
async def app_main() -> None:
async def indexer_main() -> None:
"""Main indexing function - walks files and processes each."""
db = coco.use_context(SQLITE_DB)

Expand Down Expand Up @@ -159,10 +159,3 @@ async def app_main() -> None:
# Process each file
with coco.component_subpath(coco.Symbol("process_file")):
await coco.mount_each(process_file, files.items(), table)


# Create the app
app = coco.App(
coco.AppConfig(name="CocoIndexCode"),
app_main,
)
76 changes: 76 additions & 0 deletions src/cocoindex_code/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

import asyncio

import cocoindex as coco
from cocoindex.connectors import sqlite
from cocoindex.connectors.localfs import register_base_dir

from .config import config
from .indexer import indexer_main
from .shared import CODEBASE_DIR, SQLITE_DB


class Project:
_env: coco.Environment
_app: coco.App[[], None]
_index_lock: asyncio.Lock
_initial_index_done: bool = False

async def update_index(self, *, report_to_stdout: bool = False) -> None:
"""Update the index, serializing concurrent calls via lock."""
async with self._index_lock:
try:
await self._app.update(report_to_stdout=report_to_stdout)
finally:
self._initial_index_done = True

@property
def env(self) -> coco.Environment:
return self._env

@property
def is_initial_index_done(self) -> bool:
return self._initial_index_done

@staticmethod
async def create() -> Project:
# Ensure index directory exists
config.index_dir.mkdir(parents=True, exist_ok=True)

# Set CocoIndex state database path
settings = coco.Settings.from_env(config.cocoindex_db_path)

context = coco.ContextProvider()

# Provide codebase root directory to environment
context.provide(CODEBASE_DIR, register_base_dir("codebase", config.codebase_root_path))
# Connect to SQLite with vector extension
conn = sqlite.connect(str(config.target_sqlite_db_path), load_vec="auto")
context.provide(SQLITE_DB, sqlite.register_db("index_db", conn))

env = coco.Environment(settings, context_provider=context)
app = coco.App(
coco.AppConfig(
name="CocoIndexCode",
environment=env,
),
indexer_main,
)

result = Project.__new__(Project)
result._env = env
result._app = app
result._index_lock = asyncio.Lock()
return result


_project: Project | None = None


async def default_project() -> Project:
"""Factory function to create the CocoIndexCode project."""
global _project
if _project is None:
_project = await Project.create()
return _project
7 changes: 3 additions & 4 deletions src/cocoindex_code/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import sqlite3
from typing import Any

import cocoindex as coco

from .config import config
from .project import default_project
from .schema import QueryResult
from .shared import SQLITE_DB, embedder, query_prompt_name

Expand Down Expand Up @@ -102,8 +101,8 @@ async def query_codebase(
"Please run a query with refresh_index=True first."
)

coco_env = await coco.default_env()
db = coco_env.get_context(SQLITE_DB)
coco_proj = await default_project()
db = coco_proj.env.get_context(SQLITE_DB)

# Generate query embedding.
query_embedding = await embedder.embed(query, query_prompt_name)
Expand Down
31 changes: 10 additions & 21 deletions src/cocoindex_code/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import argparse
import asyncio

import cocoindex as coco
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, Field

from .config import config
from .indexer import app as indexer_app
from .project import default_project
from .query import query_codebase
from .shared import SQLITE_DB

Expand All @@ -27,17 +26,11 @@
),
)

# Lock to prevent concurrent index updates
_index_lock = asyncio.Lock()

# Event set once the initial background index is ready
_initial_index_done = asyncio.Event()


async def _refresh_index() -> None:
"""Refresh the index. Uses lock to prevent concurrent updates."""
async with _index_lock:
await indexer_app.update(report_to_stdout=False)
proj = await default_project()
await proj.update_index()


# === Pydantic Models for Tool Inputs/Outputs ===
Expand Down Expand Up @@ -126,7 +119,8 @@ async def search(
),
) -> SearchResultModel:
"""Query the codebase index."""
if not _initial_index_done.is_set():
proj = await default_project()
if not proj.is_initial_index_done:
return SearchResultModel(
success=False,
message=(
Expand Down Expand Up @@ -182,19 +176,14 @@ async def _async_serve() -> None:
"""Async entry point for the MCP server."""

# Refresh index in background so startup isn't blocked
async def _initial_index() -> None:
try:
await _refresh_index()
finally:
_initial_index_done.set()

asyncio.create_task(_initial_index())
asyncio.create_task(_refresh_index())
await mcp.run_stdio_async()


async def _async_index() -> None:
"""Async entry point for the index command."""
await indexer_app.update(report_to_stdout=True)
proj = await default_project()
await proj.update_index(report_to_stdout=True)
await _print_index_stats()


Expand All @@ -205,8 +194,8 @@ async def _print_index_stats() -> None:
print("No index database found.")
return

coco_env = await coco.default_env()
db = coco_env.get_context(SQLITE_DB)
proj = await default_project()
db = proj.env.get_context(SQLITE_DB)

with db.value.readonly() as conn:
total_chunks = conn.execute("SELECT COUNT(*) FROM code_chunks_vec").fetchone()[0]
Expand Down
24 changes: 1 addition & 23 deletions src/cocoindex_code/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from __future__ import annotations

import logging
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated

import cocoindex as coco
from cocoindex.connectors import sqlite
from cocoindex.connectors.localfs import FilePath, register_base_dir
from cocoindex.connectors.localfs import FilePath
from numpy.typing import NDArray

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,27 +53,6 @@
CODEBASE_DIR = coco.ContextKey[FilePath]("codebase_dir")


@coco.lifespan
def coco_lifespan(builder: coco.EnvironmentBuilder) -> Iterator[None]:
"""Set up database connection."""
# Ensure index directory exists
config.index_dir.mkdir(parents=True, exist_ok=True)

# Set CocoIndex state database path
builder.settings.db_path = config.cocoindex_db_path

# Provide codebase root directory to environment
builder.provide(CODEBASE_DIR, register_base_dir("codebase", config.codebase_root_path))

# Connect to SQLite with vector extension
conn = sqlite.connect(str(config.target_sqlite_db_path), load_vec="auto")
builder.provide(SQLITE_DB, sqlite.register_db("index_db", conn))

yield

conn.close()


@dataclass
class CodeChunk:
"""Schema for storing code chunks in SQLite."""
Expand Down
9 changes: 5 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import AsyncIterator
from pathlib import Path

import cocoindex as coco
import pytest
import pytest_asyncio

Expand All @@ -24,10 +23,12 @@ def test_codebase_root() -> Path:
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def coco_runtime() -> AsyncIterator[None]:
"""
Set up CocoIndex runtime context for the entire test session.
Set up CocoIndex project for the entire test session.

Uses session-scoped event loop to ensure CocoIndex environment
persists across all tests.
"""
async with coco.runtime():
yield
from cocoindex_code.project import default_project

await default_project()
yield
28 changes: 14 additions & 14 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from cocoindex_code.config import _discover_codebase_root
from cocoindex_code.indexer import app
from cocoindex_code.project import default_project
from cocoindex_code.query import query_codebase

pytest_plugins = ("pytest_asyncio",)
Expand Down Expand Up @@ -194,7 +194,7 @@ async def test_index_and_query_codebase(
) -> None:
"""Should index a codebase and return relevant query results."""
setup_base_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

# Verify index was created
index_dir = test_codebase_root / ".cocoindex_code"
Expand All @@ -218,7 +218,7 @@ async def test_incremental_update_add_file(
) -> None:
"""Should reflect newly added files after re-indexing."""
setup_base_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

# Query for ML content - should not find it
results = await query_codebase("machine learning neural network")
Expand All @@ -232,7 +232,7 @@ async def test_incremental_update_add_file(
(test_codebase_root / "ml_model.py").write_text(SAMPLE_ML_MODEL_PY)

# Re-index and query again
await app.update(report_to_stdout=False)
await (await default_project()).update_index()
results = await query_codebase("neural network machine learning")

assert len(results) > 0
Expand All @@ -244,13 +244,13 @@ async def test_incremental_update_modify_file(
) -> None:
"""Should reflect file modifications after re-indexing."""
setup_base_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

# Modify utils.py to add authentication
(test_codebase_root / "utils.py").write_text(SAMPLE_UTILS_AUTH_PY)

# Re-index and query for authentication
await app.update(report_to_stdout=False)
await (await default_project()).update_index()
results = await query_codebase("user authentication login")

assert len(results) > 0
Expand All @@ -264,7 +264,7 @@ async def test_incremental_update_delete_file(
) -> None:
"""Should no longer return results from deleted files after re-indexing."""
setup_base_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

# Query for database - should find it
results = await query_codebase("database connection execute query")
Expand All @@ -274,7 +274,7 @@ async def test_incremental_update_delete_file(
(test_codebase_root / "lib" / "database.py").unlink()

# Re-index and query again - should no longer find database.py
await app.update(report_to_stdout=False)
await (await default_project()).update_index()
results = await query_codebase("database connection execute query")
assert not any("database.py" in r.file_path for r in results)

Expand Down Expand Up @@ -335,7 +335,7 @@ class TestSearchFilters:
async def test_filter_by_language(self, test_codebase_root: Path, coco_runtime: None) -> None:
"""Should return only results matching the specified language."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

results = await query_codebase("function", limit=50, languages=["python"])
assert len(results) > 0
Expand All @@ -347,7 +347,7 @@ async def test_filter_by_language_multiple(
) -> None:
"""Should return results matching any of the specified languages."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

results = await query_codebase("function", limit=50, languages=["python", "javascript"])
assert len(results) > 0
Expand All @@ -363,7 +363,7 @@ async def test_filter_by_file_path_glob(
) -> None:
"""Should return only results matching the file path glob pattern."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

results = await query_codebase("function", limit=50, paths=["lib/*"])
assert len(results) > 0
Expand All @@ -375,7 +375,7 @@ async def test_filter_by_file_path_wildcard_extension(
) -> None:
"""Should filter by file extension using glob wildcard."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

results = await query_codebase("function", limit=50, paths=["*.js"])
assert len(results) > 0
Expand All @@ -387,7 +387,7 @@ async def test_filter_by_both_language_and_file_path(
) -> None:
"""Should apply both language and file path filters together."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

# Filter for Python files under lib/
results = await query_codebase("function", limit=50, languages=["python"], paths=["lib/*"])
Expand All @@ -401,7 +401,7 @@ async def test_no_filter_returns_all_languages(
) -> None:
"""Should return results from all languages when no filter is applied."""
setup_multi_lang_codebase(test_codebase_root)
await app.update(report_to_stdout=False)
await (await default_project()).update_index()

results = await query_codebase("function", limit=50)
languages_found = {r.language for r in results}
Expand Down
Loading