diff --git a/code_review_graph/graph.py b/code_review_graph/graph.py index 2dfa97fc..4b2b496b 100644 --- a/code_review_graph/graph.py +++ b/code_review_graph/graph.py @@ -236,16 +236,28 @@ def store_file_nodes_edges( self, file_path: str, nodes: list[NodeInfo], edges: list[EdgeInfo], fhash: str = "" ) -> None: """Atomically replace all data for a file.""" - self._conn.execute("BEGIN IMMEDIATE") + savepoint = "__crg_store_file_nodes_edges__" + use_savepoint = self._conn.in_transaction + if use_savepoint: + self._conn.execute(f"SAVEPOINT {savepoint}") # nosec B608 + else: + self._conn.execute("BEGIN IMMEDIATE") try: self.remove_file_data(file_path) for node in nodes: self.upsert_node(node, file_hash=fhash) for edge in edges: self.upsert_edge(edge) - self._conn.commit() + if use_savepoint: + self._conn.execute(f"RELEASE SAVEPOINT {savepoint}") # nosec B608 + else: + self._conn.commit() except BaseException: - self._conn.rollback() + if use_savepoint: + self._conn.execute(f"ROLLBACK TO SAVEPOINT {savepoint}") # nosec B608 + self._conn.execute(f"RELEASE SAVEPOINT {savepoint}") # nosec B608 + else: + self._conn.rollback() raise self._invalidate_cache() diff --git a/tests/test_graph.py b/tests/test_graph.py index 5923f578..bb253fbb 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -107,6 +107,24 @@ def test_store_file_nodes_edges(self): result = self.store.get_nodes_by_file("/test/file.py") assert len(result) == 2 + def test_store_file_nodes_edges_with_existing_transaction(self): + """File replacement should work even if the caller already opened a transaction.""" + self.store.upsert_node(self._make_file_node("/test/other.py")) + nodes = [self._make_file_node(), self._make_func_node()] + edges = [ + EdgeInfo( + kind="CONTAINS", source="/test/file.py", + target="/test/file.py::my_func", file_path="/test/file.py", + ) + ] + + self.store.store_file_nodes_edges("/test/file.py", nodes, edges) + self.store.commit() + + result = self.store.get_nodes_by_file("/test/file.py") + assert len(result) == 2 + assert self.store.get_node("/test/other.py") is not None + def test_search_nodes(self): self.store.upsert_node(self._make_func_node("authenticate")) self.store.upsert_node(self._make_func_node("authorize"))