forked from KnowledgeXLab/LeanRAG
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize_graph.py
More file actions
155 lines (133 loc) · 6.03 KB
/
visualize_graph.py
File metadata and controls
155 lines (133 loc) · 6.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
"""Graph visualization utilities for Neo4j data.
Features:
- Fetch a subgraph filtered by 'level' property on nodes if present.
- Limit number of nodes and relationships for lightweight display.
- Layout options (--layout): spring, kamada (Kamada-Kawai), circular.
- Outputs an interactive matplotlib window or saves to file.
Usage examples:
python visualize_graph.py --level 0 --limit 100 --layout spring
python visualize_graph.py --level 2 --limit 50 --output graph_lvl2.png
python visualize_graph.py --limit 200 --relationship-type RELATES_TO
Environment variables used:
GRAPH_URI, GRAPH_USER, GRAPH_PASSWORD
"""
from __future__ import annotations
import os
import argparse
from neo4j import GraphDatabase
import networkx as nx
import matplotlib.pyplot as plt
def fetch_subgraph(level: int | None, limit: int, rel_type: str | None, sample: bool) -> tuple[list[dict], list[dict]]:
    """Fetch a bounded subgraph of Entity/Community nodes from Neo4j.

    Connection settings are read from the GRAPH_URI, GRAPH_USER and
    GRAPH_PASSWORD environment variables, falling back to local defaults.

    Args:
        level: When not None, keep only ``Entity`` nodes whose ``level``
            property equals this value (``Community`` nodes are then
            excluded by the filter). When None, no level filter is applied.
        limit: Maximum number of nodes to return. At most ``limit * 5``
            relationships are returned.
        rel_type: When truthy, keep only relationships of this type.
        sample: When True, order nodes randomly before applying the limit.

    Returns:
        A ``(nodes, relationships)`` tuple of record dicts. Node records have
        the keys ``id``, ``labels``, ``title``, ``level`` and ``description``;
        relationship records have ``src_id``, ``tgt_id``, ``type`` and
        ``description``. Both lists are empty when no node matches.
    """
    uri = os.getenv("GRAPH_URI", "bolt://localhost:7687")
    user = os.getenv("GRAPH_USER", "neo4j")
    # NOTE(review): hard-coded fallback credential; kept for backward
    # compatibility, but consider requiring GRAPH_PASSWORD to be set.
    password = os.getenv("GRAPH_PASSWORD", "test123456")

    # Only fixed fragments are interpolated into the query text below; every
    # user-supplied value is passed as a Cypher parameter.
    # `n.level IS NOT NULL` replaces `exists(n.level)`, which Neo4j 5 removed.
    level_clause = """
          AND (
            $level IS NULL OR (
              n:Entity AND n.level IS NOT NULL AND n.level = $level
            )
          )
    """
    rel_match = "AND type(r) = $rel_type" if rel_type else ""
    # ORDER BY rand() gives a random sample (Neo4j may warn about performance
    # on large graphs).
    order_clause = "ORDER BY rand()" if sample else ""

    # Two-step query: gather node ids, then fetch relationships among them.
    node_query = f"""
        MATCH (n)
        WHERE (n:Entity OR n:Community)
        {level_clause}
        RETURN id(n) AS id,
               labels(n) AS labels,
               CASE WHEN n:Entity THEN n.name ELSE n.entity_name END AS title,
               CASE WHEN n:Entity THEN n.level ELSE NULL END AS level,
               CASE WHEN n:Entity THEN n.description ELSE n.entity_description END AS description
        {order_clause}
        LIMIT $limit
    """
    rel_query = f"""
        MATCH (a)-[r]->(b)
        WHERE id(a) in $ids AND id(b) in $ids {rel_match}
        RETURN id(a) AS src_id, id(b) AS tgt_id, type(r) AS type, r.description AS description
        LIMIT $rel_limit
    """

    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            nodes = session.run(node_query, level=level, limit=limit).data()
            ids = [n['id'] for n in nodes]
            if not ids:
                return [], []
            relationships = session.run(
                rel_query, ids=ids, rel_type=rel_type, rel_limit=limit * 5
            ).data()
    finally:
        # Close the driver on every path, including the early return above
        # and any exception raised while querying.
        driver.close()
    return nodes, relationships
def build_nx_graph(nodes: list[dict], relationships: list[dict]) -> nx.Graph:
    """Assemble a directed NetworkX graph from node and relationship records.

    Args:
        nodes: Records with an ``id`` key and optional ``title``, ``level``,
            ``description`` and ``labels`` keys.
        relationships: Records with ``src_id``, ``tgt_id`` and ``type`` keys
            and an optional ``description`` key.

    Returns:
        An ``nx.DiGraph`` keyed by node id. Each node carries ``label``,
        ``level``, ``description`` and ``labels`` attributes; each edge
        carries ``type`` and ``description``.
    """
    graph = nx.DiGraph()

    # A missing or empty title is shown as the '∅' placeholder.
    graph.add_nodes_from(
        (
            record['id'],
            {
                'label': record.get('title') or '∅',
                'level': record.get('level'),
                'description': record.get('description'),
                'labels': record.get('labels', []),
            },
        )
        for record in nodes
    )

    graph.add_edges_from(
        (
            rel['src_id'],
            rel['tgt_id'],
            {'type': rel['type'], 'description': rel.get('description')},
        )
        for rel in relationships
    )
    return graph
def draw_graph(G: nx.Graph, layout: str, output: str | None, show_labels: bool, node_size: int, font_size: int) -> None:
    """Render a graph with matplotlib, colouring nodes by their ``level``.

    Args:
        G: Graph to draw. The ``level`` node attribute selects the colour and
            the ``label`` attribute supplies the text shown on the node.
        layout: ``'spring'``, ``'kamada'`` (alias ``'kamada_kawai'``) or
            ``'circular'``. Any other value falls back to a spring layout
            with default spacing.
        output: File path to save the figure to. When falsy, an interactive
            window is shown instead.
        show_labels: Whether to draw node labels.
        node_size: Node marker size passed to NetworkX.
        font_size: Font size for node labels.
    """
    if G.number_of_nodes() == 0:
        print("No nodes to display")
        return

    if layout == 'spring':
        pos = nx.spring_layout(G, seed=42, k=0.5)
    elif layout in ('kamada', 'kamada_kawai'):
        pos = nx.kamada_kawai_layout(G)
    elif layout == 'circular':
        pos = nx.circular_layout(G)
    else:
        pos = nx.spring_layout(G, seed=42)

    levels = [G.nodes[n].get('level') for n in G.nodes]
    # Map each distinct level to a colour; nodes without a level are grey.
    unique_levels = sorted({lv for lv in levels if lv is not None})
    # plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
    palette = plt.get_cmap('tab10', max(1, len(unique_levels)))
    color_map = {lv: palette(idx) for idx, lv in enumerate(unique_levels)}
    node_colors = [color_map.get(lv, (0.7, 0.7, 0.7, 1.0)) for lv in levels]

    fig = plt.figure(figsize=(10, 7))
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_size, linewidths=0.5, edgecolors='black')
    nx.draw_networkx_edges(G, pos, alpha=0.4, arrows=True, arrowstyle='-|>', arrowsize=10)
    if show_labels:
        # Fall back to the node id when a node has no 'label' attribute
        # (e.g. a node created implicitly by an edge).
        labels = {n: G.nodes[n].get('label', str(n)) for n in G.nodes}
        nx.draw_networkx_labels(G, pos, labels, font_size=font_size)
    plt.axis('off')

    title_parts = [f"Nodes={G.number_of_nodes()}", f"Edges={G.number_of_edges()}"]
    if unique_levels:
        title_parts.append(f"Levels={len(unique_levels)}")
    plt.title("Graph View (" + ", ".join(title_parts) + ")")

    if output:
        plt.tight_layout()
        plt.savefig(output, dpi=150)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f"Saved visualization to {output}")
    else:
        plt.show()
def main():
    """Parse command-line options, fetch the subgraph and draw it."""
    # (flags, keyword arguments) for each option, in display order.
    option_specs = [
        (('--level',), dict(type=int, help='Filter nodes by specific level value')),
        (('--limit',), dict(type=int, default=200, help='Maximum nodes to fetch')),
        (
            ('--relationship-type', '--rel-type'),
            dict(dest='rel_type', help='Only include this relationship type'),
        ),
        (
            ('--layout',),
            dict(choices=['spring', 'kamada', 'circular'], default='spring', help='Layout algorithm'),
        ),
        (('--no-labels',), dict(action='store_true', help='Hide node labels')),
        (('--output',), dict(help='Path to save image instead of showing interactively')),
        (('--sample',), dict(action='store_true', help='Random sample (ORDER BY rand())')),
        (('--node-size',), dict(type=int, default=500, help='Node size')),
        (('--font-size',), dict(type=int, default=8, help='Font size for labels')),
    ]

    cli = argparse.ArgumentParser(description="Visualize a Neo4j subgraph by level")
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    node_records, rel_records = fetch_subgraph(
        level=opts.level,
        limit=opts.limit,
        rel_type=opts.rel_type,
        sample=opts.sample,
    )
    print(f"Fetched {len(node_records)} nodes and {len(rel_records)} relationships")

    graph = build_nx_graph(node_records, rel_records)
    draw_graph(
        graph,
        layout=opts.layout,
        output=opts.output,
        show_labels=not opts.no_labels,
        node_size=opts.node_size,
        font_size=opts.font_size,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()