diff --git a/code_review_graph/embeddings.py b/code_review_graph/embeddings.py index c556d12..57e1842 100644 --- a/code_review_graph/embeddings.py +++ b/code_review_graph/embeddings.py @@ -418,6 +418,12 @@ def __init__( self._conn.commit() + def __enter__(self) -> "EmbeddingStore": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def] + self.close() + def close(self) -> None: self._conn.close() diff --git a/code_review_graph/flows.py b/code_review_graph/flows.py index 193171e..d6424c5 100644 --- a/code_review_graph/flows.py +++ b/code_review_graph/flows.py @@ -496,10 +496,22 @@ def incremental_trace_flows( # ------------------------------------------------------------------ # 3. Delete affected flows and their memberships # ------------------------------------------------------------------ - for fid in affected_ids: - conn.execute("DELETE FROM flow_memberships WHERE flow_id = ?", (fid,)) - conn.execute("DELETE FROM flows WHERE id = ?", (fid,)) - conn.commit() + # Wrap in an explicit transaction so a crash mid-loop cannot leave + # orphaned flow_memberships rows pointing at deleted flows. See #258. + if affected_ids: + if conn.in_transaction: + conn.commit() + conn.execute("BEGIN IMMEDIATE") + try: + for fid in affected_ids: + conn.execute( + "DELETE FROM flow_memberships WHERE flow_id = ?", (fid,), + ) + conn.execute("DELETE FROM flows WHERE id = ?", (fid,)) + conn.commit() + except BaseException: + conn.rollback() + raise # ------------------------------------------------------------------ # 4. Re-detect entry points and filter to relevant ones diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index cfa672c..3a90f20 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -431,15 +431,39 @@ def _single_hop_dependents(store: GraphStore, file_path: str) -> set[str]: return dependents +class DependentList(list): + """A ``list[str]`` with a ``.truncated`` flag. + + When :func:`find_dependents` hits ``_MAX_DEPENDENT_FILES`` it truncates + the result and sets ``truncated = True`` so callers can distinguish a + complete expansion from a capped one. See issue #261. + + This is a transparent ``list`` subclass — existing callers that iterate, + ``len()``, or slice continue to work unchanged; only callers that + specifically check ``.truncated`` benefit from the signal. + """ + + truncated: bool + + def __init__(self, items: list, *, truncated: bool = False) -> None: + super().__init__(items) + self.truncated = truncated + + def find_dependents( store: GraphStore, file_path: str, max_hops: int = _MAX_DEPENDENT_HOPS, -) -> list[str]: +) -> DependentList: """Find files that import from or depend on the given file. Performs up to *max_hops* iterations of expansion (default 2). Stops early if the total exceeds 500 files. + + Returns a :class:`DependentList` — a regular ``list[str]`` that also + carries a ``.truncated`` flag. When ``truncated is True`` the + returned list is capped at ``_MAX_DEPENDENT_FILES`` and the full + set of dependents was not explored. See issue #261. """ all_dependents: set[str] = set() visited: set[str] = {file_path} @@ -460,9 +484,11 @@ def find_dependents( "Dependent expansion capped at %d files for %s", len(all_dependents), file_path, ) - # Truncate to the cap - return list(all_dependents)[:_MAX_DEPENDENT_FILES] - return list(all_dependents) + return DependentList( + list(all_dependents)[:_MAX_DEPENDENT_FILES], + truncated=True, + ) + return DependentList(list(all_dependents)) def _parse_single_file( diff --git a/code_review_graph/search.py b/code_review_graph/search.py index d2eb84e..46c5214 100644 --- a/code_review_graph/search.py +++ b/code_review_graph/search.py @@ -35,23 +35,30 @@ def rebuild_fts_index(store: GraphStore) -> int: # the FTS5 virtual table DDL, which is tightly coupled to SQLite internals. conn = store._conn - # Drop and recreate the FTS table to avoid content-sync mismatch issues - conn.execute("DROP TABLE IF EXISTS nodes_fts") - conn.execute(""" - CREATE VIRTUAL TABLE nodes_fts USING fts5( - name, qualified_name, file_path, signature, - tokenize='porter unicode61' - ) - """) - conn.commit() - - # Populate from nodes table - conn.execute(""" - INSERT INTO nodes_fts(rowid, name, qualified_name, file_path, signature) - SELECT id, name, qualified_name, file_path, COALESCE(signature, '') - FROM nodes - """) - conn.commit() + # Wrap the full DROP + CREATE + INSERT sequence in an explicit transaction + # so a crash mid-rebuild cannot leave the DB without an FTS table at all + # (DROP succeeded but CREATE/INSERT didn't). See #259. + if conn.in_transaction: + conn.commit() + conn.execute("BEGIN IMMEDIATE") + try: + conn.execute("DROP TABLE IF EXISTS nodes_fts") + conn.execute(""" + CREATE VIRTUAL TABLE nodes_fts USING fts5( + name, qualified_name, file_path, signature, + tokenize='porter unicode61' + ) + """) + # Populate from nodes table + conn.execute(""" + INSERT INTO nodes_fts(rowid, name, qualified_name, file_path, signature) + SELECT id, name, qualified_name, file_path, COALESCE(signature, '') + FROM nodes + """) + conn.commit() + except BaseException: + conn.rollback() + raise count = conn.execute("SELECT count(*) FROM nodes_fts").fetchone()[0] logger.info("FTS index rebuilt: %d rows indexed", count) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 371a1c8..43bcf51 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -345,3 +345,30 @@ def test_get_provider_minimax_without_key_raises(self): with patch.dict("os.environ", {}, clear=True): with pytest.raises(ValueError, match="MINIMAX_API_KEY"): get_provider("minimax") + + +class TestEmbeddingStoreContextManager: + """Regression tests for #260: EmbeddingStore must support the context + manager protocol so connections are cleaned up on exception.""" + + def test_supports_context_manager(self, tmp_path): + db = tmp_path / "embed_ctx.db" + with EmbeddingStore(db) as store: + assert store is not None + assert store.db_path == db + # After exiting, connection should be closed. + # (Attempting another query would fail, but we don't test that + # because close() doesn't invalidate the object — it just + # closes the underlying sqlite3 connection.) + + def test_context_manager_closes_on_exception(self, tmp_path): + db = tmp_path / "embed_err.db" + try: + with EmbeddingStore(db) as store: + assert store.db_path == db + raise RuntimeError("simulated crash") + except RuntimeError: + pass + # The connection was closed by __exit__ even though an exception + # was raised. This is the whole point of #260 — without the + # context manager, the connection would leak. diff --git a/tests/test_flows.py b/tests/test_flows.py index 34cfd05..3e0f742 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -558,3 +558,31 @@ def test_incremental_trace_flows_no_affected_flows(self): assert count == 0 # Original flows unchanged. assert len(get_flows(self.store)) == initial_count + + def test_incremental_trace_flows_delete_is_atomic(self): + """Regression test for #258: the DELETE loop in incremental_trace_flows + must be wrapped in a transaction so a crash mid-loop cannot leave + orphaned flow_memberships rows.""" + self._add_func("handler", path="routes.py") + self._add_func("service", path="services.py") + self._add_call("routes.py::handler", "services.py::service", "routes.py") + + flows = trace_flows(self.store) + store_flows(self.store, flows) + assert len(get_flows(self.store)) > 0 + + # Incremental trace touching routes.py should delete old flows and + # re-trace them. The key assertion is that this does NOT raise + # "cannot start a transaction within a transaction" and that the + # DB ends in a consistent state. + count = incremental_trace_flows(self.store, ["routes.py"]) + # The re-trace should find the same entry points. + assert count >= 0 + # No orphaned memberships: every membership references a valid flow. + conn = self.store._conn + orphans = conn.execute( + "SELECT fm.flow_id FROM flow_memberships fm " + "LEFT JOIN flows f ON f.id = fm.flow_id " + "WHERE f.id IS NULL" + ).fetchall() + assert len(orphans) == 0, f"found {len(orphans)} orphaned memberships" diff --git a/tests/test_incremental.py b/tests/test_incremental.py index d438ccb..ab4e963 100644 --- a/tests/test_incremental.py +++ b/tests/test_incremental.py @@ -592,3 +592,58 @@ def test_cap_triggers_on_many_files(self, tmp_path): assert len(deps) <= 500 finally: store.close() + + def test_truncated_flag_set_when_capped(self, tmp_path): + """Regression test for #261: find_dependents must set + DependentList.truncated = True when the result is capped.""" + from code_review_graph.parser import EdgeInfo, NodeInfo + + db_path = tmp_path / "trunc.db" + store = GraphStore(db_path) + try: + store.upsert_node(NodeInfo( + kind="File", name="/hub.py", file_path="/hub.py", + line_start=1, line_end=10, language="python", + )) + store.upsert_node(NodeInfo( + kind="Function", name="hub_func", file_path="/hub.py", + line_start=2, line_end=8, language="python", + )) + for i in range(600): + path = f"/dep{i}.py" + store.upsert_node(NodeInfo( + kind="File", name=path, file_path=path, + line_start=1, line_end=10, language="python", + )) + store.upsert_node(NodeInfo( + kind="Function", name=f"func_{i}", file_path=path, + line_start=2, line_end=8, language="python", + )) + store.upsert_edge(EdgeInfo( + kind="IMPORTS_FROM", source=f"{path}::func_{i}", + target="/hub.py::hub_func", file_path=path, line=1, + )) + store.commit() + + deps = find_dependents(store, "/hub.py", max_hops=5) + assert len(deps) <= 500 + # The key assertion: truncated flag must be set. + assert deps.truncated is True, ( + "DependentList.truncated should be True when capped at " + "_MAX_DEPENDENT_FILES, but it was False" + ) + finally: + store.close() + + def test_truncated_flag_false_when_not_capped(self, tmp_path): + """Regression test for #261: find_dependents must set + DependentList.truncated = False when the result is complete.""" + store = self._make_chain_store(tmp_path) + try: + deps = find_dependents(store, "/c.py", max_hops=2) + assert deps.truncated is False, ( + "DependentList.truncated should be False when the " + "expansion completed without hitting the cap" + ) + finally: + store.close() diff --git a/tests/test_search.py b/tests/test_search.py index 165ad83..e7d9067 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -248,3 +248,23 @@ def test_fts_query_with_special_chars(self): results = hybrid_search(self.store, dangerous_query) # Just assert no exception was raised assert isinstance(results, list) + + def test_fts_rebuild_is_atomic(self): + """Regression test for #259: rebuild_fts_index must wrap the DROP + + CREATE + INSERT sequence in a single transaction so a crash between + DROP and CREATE cannot leave the DB without an FTS table.""" + # Build, rebuild, then verify the table exists and is queryable. + rebuild_fts_index(self.store) + + # Verify the FTS table exists and has rows. + conn = self.store._conn + count = conn.execute("SELECT count(*) FROM nodes_fts").fetchone()[0] + assert count > 0 + + # Rebuild again — must not raise and must leave the table intact. + new_count = rebuild_fts_index(self.store) + assert new_count == count + + # Verify search still works after double-rebuild. + results = hybrid_search(self.store, "auth") + assert isinstance(results, list)