Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/interactive_viz_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Interactive Visualization Demo
==============================

This example demonstrates the interactive graph visualizer using a simple
manually constructed pattern. It shows how to step through the visualization
and observe state changes.
"""

from __future__ import annotations

from graphix.command import E, M, N, X, Z
from graphix.measurements import Measurement
from graphix.pattern import Pattern
from graphix.visualization_interactive import InteractiveGraphVisualizer


def main() -> None:
# optimized pattern for QFT
# Create a simple pattern manually for demonstration
p = Pattern(input_nodes=[0, 1])
p.add(N(node=2))
p.add(E(nodes=(0, 2)))
p.add(E(nodes=(1, 2)))
p.add(M(node=0, measurement=Measurement.XY(0.5)))
p.add(M(node=1, measurement=Measurement.XY(0.25)))
p.add(X(node=2, domain={0, 1}))
p.add(Z(node=2, domain={0}))

# Or standardization to make it interesting
# p.standardize()

print("Pattern created with", len(p), "commands.")
print("Launching interactive visualization with real-time simulation...")

viz = InteractiveGraphVisualizer(p)
viz.visualize()


if __name__ == "__main__":
main()
64 changes: 64 additions & 0 deletions examples/interactive_viz_qaoa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
QAOA Interactive Visualization (Optimized)
==========================================

This example generates a QAOA pattern using the Graphix Circuit API
and launches the interactive visualizer in simulation-free mode
to demonstrate performance on complex patterns.
"""

from __future__ import annotations

import networkx as nx
import numpy as np

from graphix import Circuit
from graphix.visualization_interactive import InteractiveGraphVisualizer


def main() -> None:
print("Generating QAOA pattern...")

# 1. Define QAOA Circuit
n_qubits = 4
rng = np.random.default_rng(42) # Fixed seed for reproducibility

# Random parameters for the circuit
xi = rng.random(6)
theta = rng.random(4)

# Create a complete graph for the problem hamiltonian
g = nx.complete_graph(n_qubits)
circuit = Circuit(n_qubits)

# Apply unitary evolution for the problem Hamiltonian
for i, (u, v) in enumerate(g.edges):
circuit.cnot(u, v)
circuit.rz(v, float(xi[i])) # Rotation by random angle
circuit.cnot(u, v)

# Apply unitary evolution for the mixing Hamiltonian
for v in g.nodes:
circuit.rx(v, float(theta[v]))

# 2. Transpile to MBQC Pattern
# This automatically generates the measurement pattern from the gate circuit
pattern = circuit.transpile().pattern

# Standardize the pattern to ensure it follows the standard MBQC form (N, E, M, C)
pattern.standardize()
pattern.shift_signals()

print(f"Pattern generated with {len(pattern)} commands.")
print("Launching interactive visualizer...")
print("Optimization enabled: Simulation is DISABLED for performance.")
print("You will see the graph structure and command flow without quantum state calculation.")

# 3. Launch Visualization
# enable_simulation=False prevents high RAM usage for this complex pattern
viz = InteractiveGraphVisualizer(pattern, node_distance=(1.5, 1.5), enable_simulation=False)
viz.visualize()


if __name__ == "__main__":
main()
227 changes: 191 additions & 36 deletions graphix/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import TypeAlias, TypeVar

import numpy.typing as npt
from matplotlib.axes import Axes

from graphix.clifford import Clifford
from graphix.flow.core import CausalFlow, PauliFlow
Expand Down Expand Up @@ -52,41 +53,25 @@ class GraphVisualizer:
og: OpenGraph[Measurement]
local_clifford: Mapping[int, Clifford] | None = None

def visualize(
def get_layout(
self,
show_pauli_measurement: bool = True,
show_local_clifford: bool = False,
show_measurement_planes: bool = False,
show_loop: bool = True,
node_distance: tuple[float, float] = (1, 1),
figsize: tuple[int, int] | None = None,
filename: Path | None = None,
) -> None:
"""
Visualize the graph with flow or gflow structure.

If there exists a flow structure, then the graph is visualized with the flow structure.
If flow structure is not found and there exists a gflow structure, then the graph is visualized
with the gflow structure.
If neither flow nor gflow structure is found, then the graph is visualized without any structure.
) -> tuple[
Mapping[int, _Point],
Callable[
[Mapping[int, _Point]], tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None]
],
Mapping[int, int] | None,
]:
"""Determine the layout (positions, paths, layers) for the graph.

Parameters
----------
show_pauli_measurement : bool
If True, the nodes with Pauli measurement angles are colored light blue.
show_local_clifford : bool
If True, indexes of the local Clifford operator are displayed adjacent to the nodes.
show_measurement_planes : bool
If True, the measurement planes are displayed adjacent to the nodes.
show_loop : bool
whether or not to show loops for graphs with gflow. defaulted to True.
node_distance : tuple
Distance multiplication factor between nodes for x and y directions.
figsize : tuple
Figure size of the plot.
filename : Path | None
If not None, filename of the png file to save the plot. If None, the plot is not saved.
Default in None.
Returns
-------
pos : dict
Node positions.
place_paths : callable
Function to place edges and arrows.
l_k : dict or None
Layer mapping.
"""
try:
bloch_graph = self.og.downcast_bloch()
Expand Down Expand Up @@ -131,6 +116,46 @@ def place_paths(
) -> tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None]:
return (self.place_edge_paths_without_structure(pos), None)

return pos, place_paths, l_k

def visualize(
self,
show_pauli_measurement: bool = True,
show_local_clifford: bool = False,
show_measurement_planes: bool = False,
show_loop: bool = True,
node_distance: tuple[float, float] = (1, 1),
figsize: tuple[int, int] | None = None,
filename: Path | None = None,
) -> None:
"""
Visualize the graph with flow or gflow structure.

If there exists a flow structure, then the graph is visualized with the flow structure.
If flow structure is not found and there exists a gflow structure, then the graph is visualized
with the gflow structure.
If neither flow nor gflow structure is found, then the graph is visualized without any structure.

Parameters
----------
show_pauli_measurement : bool
If True, the nodes with Pauli measurement angles are colored light blue.
show_local_clifford : bool
If True, indexes of the local Clifford operator are displayed adjacent to the nodes.
show_measurement_planes : bool
If True, the measurement planes are displayed adjacent to the nodes.
show_loop : bool
whether or not to show loops for graphs with gflow. defaulted to True.
node_distance : tuple
Distance multiplication factor between nodes for x and y directions.
figsize : tuple
Figure size of the plot.
filename : Path | None
If not None, filename of the png file to save the plot. If None, the plot is not saved.
Default in None.
"""
pos, place_paths, l_k = self.get_layout()

self.visualize_graph(
pos,
place_paths,
Expand Down Expand Up @@ -253,11 +278,141 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]:
return new_path

def _draw_labels(self, pos: Mapping[int, _Point]) -> None:
fontsize = 12
if max(self.og.graph.nodes(), default=0) >= 100:
fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes()))))
fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0))
nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize)

def draw_node_labels(self, ax: Axes, pos: Mapping[int, _Point]) -> None:
"""Draw node labels onto a given axes object.

This is an axis-aware counterpart of :meth:`_draw_labels` intended for
use in contexts where the caller manages the :class:`~matplotlib.axes.Axes`
directly (e.g. the interactive visualizer).

Parameters
----------
ax : Axes
The matplotlib axes to draw onto.
pos : Mapping[int, tuple[float, float]]
Dictionary mapping each node to its ``(x, y)`` position.
"""
fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0))
for node, (x, y) in pos.items():
ax.text(x, y, str(node), ha="center", va="center", fontsize=fontsize, zorder=3)

@staticmethod
def get_label_fontsize(max_node: int, base_size: int = 12) -> int:
"""Compute the font size for node labels.

When the largest node number has many digits the font is reduced
so that labels still fit inside the scatter markers.

Parameters
----------
max_node : int
The largest node number in the graph.
base_size : int, optional
The default font size used for small node numbers.
Defaults to ``12``.

Returns
-------
int
The computed font size, never smaller than ``7``.
"""
if max_node >= 100:
return max(7, int(base_size * 2 / len(str(max_node))))
return base_size

def draw_edges(
self,
ax: Axes,
pos: Mapping[int, _Point],
edge_subset: Iterable[tuple[int, ...]] | None = None,
) -> None:
"""Draw graph edges as plain lines onto a given axes object.

This axis-aware method is intended for use in contexts where the caller
manages the :class:`~matplotlib.axes.Axes` directly (e.g. the
interactive visualizer).

Parameters
----------
ax : Axes
The matplotlib axes to draw onto.
pos : Mapping[int, tuple[float, float]]
Dictionary mapping each node to its ``(x, y)`` position.
edge_subset : Iterable[tuple[int, int]] or None, optional
If provided, only these edges are drawn. When ``None``
(the default), all edges in :attr:`og.graph` are drawn.
"""
edges: Iterable[tuple[int, ...]] = self.og.graph.edges() if edge_subset is None else edge_subset
for u, v in edges:
if u in pos and v in pos:
x1, y1 = pos[u]
x2, y2 = pos[v]
ax.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1)

def draw_nodes_role(
self,
ax: Axes,
pos: Mapping[int, _Point],
show_pauli_measurement: bool = False,
node_facecolors: Mapping[int, str] | None = None,
node_edgecolors: Mapping[int, str] | None = None,
node_size: int = 350,
) -> None:
"""Draw nodes onto a given axes object, coloured by their role.

This is an axis-aware counterpart of the private ``__draw_nodes_role``
method, intended for use in contexts where the caller manages the
:class:`~matplotlib.axes.Axes` directly (e.g. the interactive
visualizer). Nodes are styled as follows:

* Input nodes: red border, white fill.
* Output nodes: black border, light-gray fill.
* Pauli-measured nodes (when *show_pauli_measurement* is ``True``):
black border, light-blue fill.
* All other nodes: black border, white fill.

When *node_facecolors* or *node_edgecolors* are provided, their values
override the role-based defaults for the corresponding nodes.

Parameters
----------
ax : Axes
The matplotlib axes to draw onto.
pos : Mapping[int, tuple[float, float]]
Dictionary mapping each node to its ``(x, y)`` position.
show_pauli_measurement : bool, optional
If ``True``, nodes with Pauli measurement angles are coloured
light blue. Defaults to ``False``.
node_facecolors : Mapping[int, str] or None, optional
Per-node fill colour overrides. When a node appears in this
mapping its value is used instead of the role-based default.
node_edgecolors : Mapping[int, str] or None, optional
Per-node border colour overrides.
node_size : int, optional
Marker size for :meth:`~matplotlib.axes.Axes.scatter`.
Defaults to ``350``.
"""
for node in self.og.graph.nodes():
if node not in pos:
continue
edgecolor = "black"
facecolor = "white"
if node in self.og.input_nodes:
edgecolor = "red"
if node in self.og.output_nodes:
facecolor = "lightgray"
elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement):
facecolor = "lightblue"
# Apply per-node overrides if provided
if node_facecolors is not None and node in node_facecolors:
facecolor = node_facecolors[node]
if node_edgecolors is not None and node in node_edgecolors:
edgecolor = node_edgecolors[node]
ax.scatter(*pos[node], edgecolors=edgecolor, facecolors=facecolor, s=node_size, zorder=2, linewidths=1.5)

def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None:
"""
Draw the nodes with different colors based on their role (input, output, or other).
Expand Down
Loading