Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions code_review_graph/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@ def __init__(

self._conn.commit()

def __enter__(self) -> "EmbeddingStore":
    """Enter the runtime context (``with EmbeddingStore(...) as store:``).

    Returns the store itself so the ``as`` target is the open store.
    """
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
    """Close the store's connection when the ``with`` block exits.

    Always closes, even when the block raised. Returns ``None`` (falsy),
    so any in-flight exception propagates to the caller unchanged.
    """
    self.close()

def close(self) -> None:
    """Close the underlying sqlite3 connection.

    After this call the store can no longer be used for queries; the
    object itself is not invalidated, only the connection is closed.
    """
    self._conn.close()

Expand Down
20 changes: 16 additions & 4 deletions code_review_graph/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,22 @@ def incremental_trace_flows(
# ------------------------------------------------------------------
# 3. Delete affected flows and their memberships
# ------------------------------------------------------------------
for fid in affected_ids:
conn.execute("DELETE FROM flow_memberships WHERE flow_id = ?", (fid,))
conn.execute("DELETE FROM flows WHERE id = ?", (fid,))
conn.commit()
# Wrap in an explicit transaction so a crash mid-loop cannot leave
# orphaned flow_memberships rows pointing at deleted flows. See #258.
if affected_ids:
if conn.in_transaction:
conn.commit()
conn.execute("BEGIN IMMEDIATE")
try:
for fid in affected_ids:
conn.execute(
"DELETE FROM flow_memberships WHERE flow_id = ?", (fid,),
)
conn.execute("DELETE FROM flows WHERE id = ?", (fid,))
conn.commit()
except BaseException:
conn.rollback()
raise

# ------------------------------------------------------------------
# 4. Re-detect entry points and filter to relevant ones
Expand Down
34 changes: 30 additions & 4 deletions code_review_graph/incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,15 +431,39 @@ def _single_hop_dependents(store: GraphStore, file_path: str) -> set[str]:
return dependents


class DependentList(list):
    """A ``list`` of file paths that also records whether it was capped.

    :func:`find_dependents` returns one of these. When the expansion hits
    ``_MAX_DEPENDENT_FILES`` the result is cut off and ``truncated`` is set
    to ``True``, letting callers tell a complete expansion from a capped
    one (issue #261).

    Because this subclasses ``list`` transparently, code that merely
    iterates, takes ``len()``, or slices the result keeps working without
    changes; only callers that inspect ``.truncated`` see the new signal.
    """

    # True when the expansion was capped and this list is only a prefix
    # of the full dependent set.
    truncated: bool

    def __init__(self, items: list, *, truncated: bool = False) -> None:
        super().__init__(items)
        self.truncated = truncated


def find_dependents(
store: GraphStore,
file_path: str,
max_hops: int = _MAX_DEPENDENT_HOPS,
) -> list[str]:
) -> DependentList:
"""Find files that import from or depend on the given file.

Performs up to *max_hops* iterations of expansion (default 2).
Stops early if the total exceeds 500 files.

Returns a :class:`DependentList` — a regular ``list[str]`` that also
carries a ``.truncated`` flag. When ``truncated is True`` the
returned list is capped at ``_MAX_DEPENDENT_FILES`` and the full
set of dependents was not explored. See issue #261.
"""
all_dependents: set[str] = set()
visited: set[str] = {file_path}
Expand All @@ -460,9 +484,11 @@ def find_dependents(
"Dependent expansion capped at %d files for %s",
len(all_dependents), file_path,
)
# Truncate to the cap
return list(all_dependents)[:_MAX_DEPENDENT_FILES]
return list(all_dependents)
return DependentList(
list(all_dependents)[:_MAX_DEPENDENT_FILES],
truncated=True,
)
return DependentList(list(all_dependents))


def _parse_single_file(
Expand Down
41 changes: 24 additions & 17 deletions code_review_graph/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,30 @@ def rebuild_fts_index(store: GraphStore) -> int:
# the FTS5 virtual table DDL, which is tightly coupled to SQLite internals.
conn = store._conn

# Drop and recreate the FTS table to avoid content-sync mismatch issues
conn.execute("DROP TABLE IF EXISTS nodes_fts")
conn.execute("""
CREATE VIRTUAL TABLE nodes_fts USING fts5(
name, qualified_name, file_path, signature,
tokenize='porter unicode61'
)
""")
conn.commit()

# Populate from nodes table
conn.execute("""
INSERT INTO nodes_fts(rowid, name, qualified_name, file_path, signature)
SELECT id, name, qualified_name, file_path, COALESCE(signature, '')
FROM nodes
""")
conn.commit()
# Wrap the full DROP + CREATE + INSERT sequence in an explicit transaction
# so a crash mid-rebuild cannot leave the DB without an FTS table at all
# (DROP succeeded but CREATE/INSERT didn't). See #259.
if conn.in_transaction:
conn.commit()
conn.execute("BEGIN IMMEDIATE")
try:
conn.execute("DROP TABLE IF EXISTS nodes_fts")
conn.execute("""
CREATE VIRTUAL TABLE nodes_fts USING fts5(
name, qualified_name, file_path, signature,
tokenize='porter unicode61'
)
""")
# Populate from nodes table
conn.execute("""
INSERT INTO nodes_fts(rowid, name, qualified_name, file_path, signature)
SELECT id, name, qualified_name, file_path, COALESCE(signature, '')
FROM nodes
""")
conn.commit()
except BaseException:
conn.rollback()
raise

count = conn.execute("SELECT count(*) FROM nodes_fts").fetchone()[0]
logger.info("FTS index rebuilt: %d rows indexed", count)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,30 @@ def test_get_provider_minimax_without_key_raises(self):
with patch.dict("os.environ", {}, clear=True):
with pytest.raises(ValueError, match="MINIMAX_API_KEY"):
get_provider("minimax")


class TestEmbeddingStoreContextManager:
    """Regression tests for #260: EmbeddingStore must support the context
    manager protocol so connections are cleaned up on exception.

    Both tests verify *actual* closure: a closed sqlite3 connection raises
    ``sqlite3.ProgrammingError`` on any further use, which is the
    observable effect #260 is about. (The earlier versions of these tests
    only exercised ``__enter__``/``__exit__`` without asserting that the
    connection was really closed.)
    """

    def test_supports_context_manager(self, tmp_path):
        import sqlite3

        db = tmp_path / "embed_ctx.db"
        with EmbeddingStore(db) as store:
            assert store is not None
            assert store.db_path == db
        # After exiting, the underlying sqlite3 connection must be closed:
        # using a closed connection raises ProgrammingError.
        with pytest.raises(sqlite3.ProgrammingError):
            store._conn.execute("SELECT 1")

    def test_context_manager_closes_on_exception(self, tmp_path):
        import sqlite3

        db = tmp_path / "embed_err.db"
        with pytest.raises(RuntimeError, match="simulated crash"):
            with EmbeddingStore(db) as store:
                assert store.db_path == db
                raise RuntimeError("simulated crash")
        # __exit__ closed the connection even though the body raised —
        # this is the whole point of #260; without the context manager
        # the connection would leak.
        with pytest.raises(sqlite3.ProgrammingError):
            store._conn.execute("SELECT 1")
28 changes: 28 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,31 @@ def test_incremental_trace_flows_no_affected_flows(self):
assert count == 0
# Original flows unchanged.
assert len(get_flows(self.store)) == initial_count

def test_incremental_trace_flows_delete_is_atomic(self):
    """Regression test for #258: the DELETE loop in incremental_trace_flows
    must run inside a transaction so a mid-loop crash cannot strand
    flow_memberships rows whose flow no longer exists."""
    self._add_func("handler", path="routes.py")
    self._add_func("service", path="services.py")
    self._add_call("routes.py::handler", "services.py::service", "routes.py")

    store_flows(self.store, trace_flows(self.store))
    assert len(get_flows(self.store)) > 0

    # An incremental re-trace touching routes.py deletes the old flows and
    # rebuilds them. The key assertions: this must not raise "cannot start
    # a transaction within a transaction", and the DB must end consistent.
    count = incremental_trace_flows(self.store, ["routes.py"])
    assert count >= 0

    # Consistency check: every remaining membership references a flow that
    # still exists (a LEFT JOIN with no match would expose an orphan).
    conn = self.store._conn
    orphans = conn.execute(
        "SELECT fm.flow_id FROM flow_memberships fm "
        "LEFT JOIN flows f ON f.id = fm.flow_id "
        "WHERE f.id IS NULL"
    ).fetchall()
    assert len(orphans) == 0, f"found {len(orphans)} orphaned memberships"
55 changes: 55 additions & 0 deletions tests/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,58 @@ def test_cap_triggers_on_many_files(self, tmp_path):
assert len(deps) <= 500
finally:
store.close()

def test_truncated_flag_set_when_capped(self, tmp_path):
    """Regression test for #261: find_dependents must set
    DependentList.truncated = True when the result is capped."""
    from code_review_graph.parser import EdgeInfo, NodeInfo

    store = GraphStore(tmp_path / "trunc.db")
    try:
        # A single hub function that 600 other files all import from —
        # well past the 500-file cap, so truncation must kick in.
        store.upsert_node(NodeInfo(
            kind="File", name="/hub.py", file_path="/hub.py",
            line_start=1, line_end=10, language="python",
        ))
        store.upsert_node(NodeInfo(
            kind="Function", name="hub_func", file_path="/hub.py",
            line_start=2, line_end=8, language="python",
        ))
        for i in range(600):
            path = f"/dep{i}.py"
            store.upsert_node(NodeInfo(
                kind="File", name=path, file_path=path,
                line_start=1, line_end=10, language="python",
            ))
            store.upsert_node(NodeInfo(
                kind="Function", name=f"func_{i}", file_path=path,
                line_start=2, line_end=8, language="python",
            ))
            store.upsert_edge(EdgeInfo(
                kind="IMPORTS_FROM", source=f"{path}::func_{i}",
                target="/hub.py::hub_func", file_path=path, line=1,
            ))
        store.commit()

        deps = find_dependents(store, "/hub.py", max_hops=5)

        # The list is capped ...
        assert len(deps) <= 500
        # ... and — the key assertion — the flag reports the truncation.
        assert deps.truncated is True, (
            "DependentList.truncated should be True when capped at "
            "_MAX_DEPENDENT_FILES, but it was False"
        )
    finally:
        store.close()

def test_truncated_flag_false_when_not_capped(self, tmp_path):
    """Regression test for #261: find_dependents must set
    DependentList.truncated = False when the result is complete."""
    store = self._make_chain_store(tmp_path)
    try:
        # A small dependency chain never approaches the cap, so the
        # returned list must report a complete, untruncated expansion.
        result = find_dependents(store, "/c.py", max_hops=2)
        assert result.truncated is False, (
            "DependentList.truncated should be False when the "
            "expansion completed without hitting the cap"
        )
    finally:
        store.close()
20 changes: 20 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,23 @@ def test_fts_query_with_special_chars(self):
results = hybrid_search(self.store, dangerous_query)
# Just assert no exception was raised
assert isinstance(results, list)

def test_fts_rebuild_is_atomic(self):
    """Regression test for #259: rebuild_fts_index must wrap the DROP +
    CREATE + INSERT sequence in a single transaction so a crash between
    DROP and CREATE cannot leave the DB without an FTS table."""
    # First rebuild: afterwards the FTS table must exist and hold rows.
    rebuild_fts_index(self.store)
    rows = self.store._conn.execute(
        "SELECT count(*) FROM nodes_fts"
    ).fetchone()[0]
    assert rows > 0

    # Second rebuild: must not raise, and must report the same row count
    # (the table is dropped and repopulated from the same nodes).
    assert rebuild_fts_index(self.store) == rows

    # Search stays functional after the double rebuild.
    assert isinstance(hybrid_search(self.store, "auth"), list)
Loading