Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions code_review_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,28 @@ def store_file_nodes_edges(
self, file_path: str, nodes: list[NodeInfo], edges: list[EdgeInfo], fhash: str = ""
) -> None:
"""Atomically replace all data for a file."""
self._conn.execute("BEGIN IMMEDIATE")
savepoint = "__crg_store_file_nodes_edges__"
use_savepoint = self._conn.in_transaction
if use_savepoint:
self._conn.execute(f"SAVEPOINT {savepoint}") # nosec B608
else:
self._conn.execute("BEGIN IMMEDIATE")
try:
self.remove_file_data(file_path)
for node in nodes:
self.upsert_node(node, file_hash=fhash)
for edge in edges:
self.upsert_edge(edge)
self._conn.commit()
if use_savepoint:
self._conn.execute(f"RELEASE SAVEPOINT {savepoint}") # nosec B608
else:
self._conn.commit()
except BaseException:
self._conn.rollback()
if use_savepoint:
self._conn.execute(f"ROLLBACK TO SAVEPOINT {savepoint}") # nosec B608
self._conn.execute(f"RELEASE SAVEPOINT {savepoint}") # nosec B608
else:
self._conn.rollback()
raise
self._invalidate_cache()

Expand Down
18 changes: 18 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ def test_store_file_nodes_edges(self):
result = self.store.get_nodes_by_file("/test/file.py")
assert len(result) == 2

def test_store_file_nodes_edges_with_existing_transaction(self):
"""File replacement should work even if the caller already opened a transaction."""
self.store.upsert_node(self._make_file_node("/test/other.py"))
nodes = [self._make_file_node(), self._make_func_node()]
edges = [
EdgeInfo(
kind="CONTAINS", source="/test/file.py",
target="/test/file.py::my_func", file_path="/test/file.py",
)
]

self.store.store_file_nodes_edges("/test/file.py", nodes, edges)
self.store.commit()

result = self.store.get_nodes_by_file("/test/file.py")
assert len(result) == 2
assert self.store.get_node("/test/other.py") is not None

def test_search_nodes(self):
self.store.upsert_node(self._make_func_node("authenticate"))
self.store.upsert_node(self._make_func_node("authorize"))
Expand Down