diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py new file mode 100644 index 00000000..31accf69 --- /dev/null +++ b/examples/interactive_viz_demo.py @@ -0,0 +1,41 @@ +""" +Interactive Visualization Demo +============================== + +This example demonstrates the interactive graph visualizer using a simple +manually constructed pattern. It shows how to step through the visualization +and observe state changes. +""" + +from __future__ import annotations + +from graphix.command import E, M, N, X, Z +from graphix.measurements import Measurement +from graphix.pattern import Pattern +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +def main() -> None: + # optimized pattern for QFT + # Create a simple pattern manually for demonstration + p = Pattern(input_nodes=[0, 1]) + p.add(N(node=2)) + p.add(E(nodes=(0, 2))) + p.add(E(nodes=(1, 2))) + p.add(M(node=0, measurement=Measurement.XY(0.5))) + p.add(M(node=1, measurement=Measurement.XY(0.25))) + p.add(X(node=2, domain={0, 1})) + p.add(Z(node=2, domain={0})) + + # Or standardization to make it interesting + # p.standardize() + + print("Pattern created with", len(p), "commands.") + print("Launching interactive visualization with real-time simulation...") + + viz = InteractiveGraphVisualizer(p) + viz.visualize() + + +if __name__ == "__main__": + main() diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py new file mode 100644 index 00000000..b830e560 --- /dev/null +++ b/examples/interactive_viz_qaoa.py @@ -0,0 +1,64 @@ +""" +QAOA Interactive Visualization (Optimized) +========================================== + +This example generates a QAOA pattern using the Graphix Circuit API +and launches the interactive visualizer in simulation-free mode +to demonstrate performance on complex patterns. +""" + +from __future__ import annotations + +import networkx as nx +import numpy as np + +from graphix import Circuit +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +def main() -> None: + print("Generating QAOA pattern...") + + # 1. Define QAOA Circuit + n_qubits = 4 + rng = np.random.default_rng(42) # Fixed seed for reproducibility + + # Random parameters for the circuit + xi = rng.random(6) + theta = rng.random(4) + + # Create a complete graph for the problem hamiltonian + g = nx.complete_graph(n_qubits) + circuit = Circuit(n_qubits) + + # Apply unitary evolution for the problem Hamiltonian + for i, (u, v) in enumerate(g.edges): + circuit.cnot(u, v) + circuit.rz(v, float(xi[i])) # Rotation by random angle + circuit.cnot(u, v) + + # Apply unitary evolution for the mixing Hamiltonian + for v in g.nodes: + circuit.rx(v, float(theta[v])) + + # 2. Transpile to MBQC Pattern + # This automatically generates the measurement pattern from the gate circuit + pattern = circuit.transpile().pattern + + # Standardize the pattern to ensure it follows the standard MBQC form (N, E, M, C) + pattern.standardize() + pattern.shift_signals() + + print(f"Pattern generated with {len(pattern)} commands.") + print("Launching interactive visualizer...") + print("Optimization enabled: Simulation is DISABLED for performance.") + print("You will see the graph structure and command flow without quantum state calculation.") + + # 3. Launch Visualization + # enable_simulation=False prevents high RAM usage for this complex pattern + viz = InteractiveGraphVisualizer(pattern, node_distance=(1.5, 1.5), enable_simulation=False) + viz.visualize() + + +if __name__ == "__main__": + main() diff --git a/graphix/visualization.py b/graphix/visualization.py index 4322b82b..27366fd7 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -24,6 +24,7 @@ from typing import TypeAlias, TypeVar import numpy.typing as npt + from matplotlib.axes import Axes from graphix.clifford import Clifford from graphix.flow.core import CausalFlow, PauliFlow @@ -52,41 +53,25 @@ class GraphVisualizer: og: OpenGraph[Measurement] local_clifford: Mapping[int, Clifford] | None = None - def visualize( + def get_layout( self, - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - show_loop: bool = True, - node_distance: tuple[float, float] = (1, 1), - figsize: tuple[int, int] | None = None, - filename: Path | None = None, - ) -> None: - """ - Visualize the graph with flow or gflow structure. - - If there exists a flow structure, then the graph is visualized with the flow structure. - If flow structure is not found and there exists a gflow structure, then the graph is visualized - with the gflow structure. - If neither flow nor gflow structure is found, then the graph is visualized without any structure. + ) -> tuple[ + Mapping[int, _Point], + Callable[ + [Mapping[int, _Point]], tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None] + ], + Mapping[int, int] | None, + ]: + """Determine the layout (positions, paths, layers) for the graph. - Parameters - ---------- - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - show_loop : bool - whether or not to show loops for graphs with gflow. defaulted to True. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - filename : Path | None - If not None, filename of the png file to save the plot. If None, the plot is not saved. - Default in None. + Returns + ------- + pos : dict + Node positions. + place_paths : callable + Function to place edges and arrows. + l_k : dict or None + Layer mapping. """ try: bloch_graph = self.og.downcast_bloch() @@ -131,6 +116,46 @@ def place_paths( ) -> tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None]: return (self.place_edge_paths_without_structure(pos), None) + return pos, place_paths, l_k + + def visualize( + self, + show_pauli_measurement: bool = True, + show_local_clifford: bool = False, + show_measurement_planes: bool = False, + show_loop: bool = True, + node_distance: tuple[float, float] = (1, 1), + figsize: tuple[int, int] | None = None, + filename: Path | None = None, + ) -> None: + """ + Visualize the graph with flow or gflow structure. + + If there exists a flow structure, then the graph is visualized with the flow structure. + If flow structure is not found and there exists a gflow structure, then the graph is visualized + with the gflow structure. + If neither flow nor gflow structure is found, then the graph is visualized without any structure. + + Parameters + ---------- + show_pauli_measurement : bool + If True, the nodes with Pauli measurement angles are colored light blue. + show_local_clifford : bool + If True, indexes of the local Clifford operator are displayed adjacent to the nodes. + show_measurement_planes : bool + If True, the measurement planes are displayed adjacent to the nodes. + show_loop : bool + whether or not to show loops for graphs with gflow. defaulted to True. + node_distance : tuple + Distance multiplication factor between nodes for x and y directions. + figsize : tuple + Figure size of the plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. + """ + pos, place_paths, l_k = self.get_layout() + self.visualize_graph( pos, place_paths, @@ -253,11 +278,141 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]: return new_path def _draw_labels(self, pos: Mapping[int, _Point]) -> None: - fontsize = 12 - if max(self.og.graph.nodes(), default=0) >= 100: - fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) + def draw_node_labels(self, ax: Axes, pos: Mapping[int, _Point]) -> None: + """Draw node labels onto a given axes object. + + This is an axis-aware counterpart of :meth:`_draw_labels` intended for + use in contexts where the caller manages the :class:`~matplotlib.axes.Axes` + directly (e.g. the interactive visualizer). + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + """ + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) + for node, (x, y) in pos.items(): + ax.text(x, y, str(node), ha="center", va="center", fontsize=fontsize, zorder=3) + + @staticmethod + def get_label_fontsize(max_node: int, base_size: int = 12) -> int: + """Compute the font size for node labels. + + When the largest node number has many digits the font is reduced + so that labels still fit inside the scatter markers. + + Parameters + ---------- + max_node : int + The largest node number in the graph. + base_size : int, optional + The default font size used for small node numbers. + Defaults to ``12``. + + Returns + ------- + int + The computed font size, never smaller than ``7``. + """ + if max_node >= 100: + return max(7, int(base_size * 2 / len(str(max_node)))) + return base_size + + def draw_edges( + self, + ax: Axes, + pos: Mapping[int, _Point], + edge_subset: Iterable[tuple[int, ...]] | None = None, + ) -> None: + """Draw graph edges as plain lines onto a given axes object. + + This axis-aware method is intended for use in contexts where the caller + manages the :class:`~matplotlib.axes.Axes` directly (e.g. the + interactive visualizer). + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + edge_subset : Iterable[tuple[int, int]] or None, optional + If provided, only these edges are drawn. When ``None`` + (the default), all edges in :attr:`og.graph` are drawn. + """ + edges: Iterable[tuple[int, ...]] = self.og.graph.edges() if edge_subset is None else edge_subset + for u, v in edges: + if u in pos and v in pos: + x1, y1 = pos[u] + x2, y2 = pos[v] + ax.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1) + + def draw_nodes_role( + self, + ax: Axes, + pos: Mapping[int, _Point], + show_pauli_measurement: bool = False, + node_facecolors: Mapping[int, str] | None = None, + node_edgecolors: Mapping[int, str] | None = None, + node_size: int = 350, + ) -> None: + """Draw nodes onto a given axes object, coloured by their role. + + This is an axis-aware counterpart of the private ``__draw_nodes_role`` + method, intended for use in contexts where the caller manages the + :class:`~matplotlib.axes.Axes` directly (e.g. the interactive + visualizer). Nodes are styled as follows: + + * Input nodes: red border, white fill. + * Output nodes: black border, light-gray fill. + * Pauli-measured nodes (when *show_pauli_measurement* is ``True``): + black border, light-blue fill. + * All other nodes: black border, white fill. + + When *node_facecolors* or *node_edgecolors* are provided, their values + override the role-based defaults for the corresponding nodes. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + show_pauli_measurement : bool, optional + If ``True``, nodes with Pauli measurement angles are coloured + light blue. Defaults to ``False``. + node_facecolors : Mapping[int, str] or None, optional + Per-node fill colour overrides. When a node appears in this + mapping its value is used instead of the role-based default. + node_edgecolors : Mapping[int, str] or None, optional + Per-node border colour overrides. + node_size : int, optional + Marker size for :meth:`~matplotlib.axes.Axes.scatter`. + Defaults to ``350``. + """ + for node in self.og.graph.nodes(): + if node not in pos: + continue + edgecolor = "black" + facecolor = "white" + if node in self.og.input_nodes: + edgecolor = "red" + if node in self.og.output_nodes: + facecolor = "lightgray" + elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): + facecolor = "lightblue" + # Apply per-node overrides if provided + if node_facecolors is not None and node in node_facecolors: + facecolor = node_facecolors[node] + if node_edgecolors is not None and node in node_edgecolors: + edgecolor = node_edgecolors[node] + ax.scatter(*pos[node], edgecolors=edgecolor, facecolors=facecolor, s=node_size, zorder=2, linewidths=1.5) + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: """ Draw the nodes with different colors based on their role (input, output, or other). diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py new file mode 100644 index 00000000..e0d6beac --- /dev/null +++ b/graphix/visualization_interactive.py @@ -0,0 +1,474 @@ +"""Interactive visualization for MBQC patterns.""" + +from __future__ import annotations + +import sys +import traceback +from typing import TYPE_CHECKING, Any + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.text import Text +from matplotlib.widgets import Button, Slider + +from graphix.clifford import Clifford +from graphix.command import CommandKind +from graphix.opengraph import OpenGraph +from graphix.pretty_print import OutputFormat, command_to_str +from graphix.sim.statevec import StatevectorBackend +from graphix.visualization import GraphVisualizer + +if TYPE_CHECKING: + from graphix.pattern import Pattern + + +class InteractiveGraphVisualizer: + """Interactive visualization tool for MBQC patterns. + + Attributes + ---------- + pattern : Pattern + The MBQC pattern to visualize. + node_distance : tuple[float, float] + Scale factors (x, y) for the node positions. + enable_simulation : bool + If True, simulates the state vector and measurement outcomes. + marker_fill_ratio : float + Fraction of the inter-node spacing used by each marker (0-1). + label_size_ratio : float + Label font size as a fraction of the marker diameter. + max_label_fontsize : int + Upper bound for label font size in points. + min_inches_per_node : float + Minimum vertical inches per node for adaptive figure height. + active_node_color : str + Border colour for active (current-step) nodes. + measured_node_color : str + Fill colour for already-measured nodes. + """ + + def __init__( + self, + pattern: Pattern, + node_distance: tuple[float, float] = (1, 1), + enable_simulation: bool = True, + *, + marker_fill_ratio: float = 0.80, + label_size_ratio: float = 0.55, + max_label_fontsize: int = 12, + min_inches_per_node: float = 0.3, + active_node_color: str = "#2060cc", + measured_node_color: str = "lightgray", + ) -> None: + """Construct an interactive visualizer. + + Parameters + ---------- + pattern : Pattern + The MBQC pattern to visualize. + node_distance : tuple[float, float], optional + Scale factors (x, y) for node positions. Defaults to (1, 1). + enable_simulation : bool, optional + If True, enables state vector simulation. Defaults to True. + marker_fill_ratio : float, optional + Fraction of the inter-node spacing used by marker diameter. + Defaults to 0.80. + label_size_ratio : float, optional + Label font size as a fraction of the marker diameter in points. + Defaults to 0.55. + max_label_fontsize : int, optional + Upper bound for label font size. Prevents text from overflowing + the marker in sparse graphs. Defaults to 12. + min_inches_per_node : float, optional + Minimum vertical inches allocated per node when computing the + adaptive figure height. Defaults to 0.3. + active_node_color : str, optional + Border colour for active nodes. Defaults to ``"#2060cc"``. + measured_node_color : str, optional + Fill colour for measured nodes. Defaults to ``"lightgray"``. + """ + self.pattern = pattern + self.node_positions: dict[int, tuple[float, float]] = {} + self.node_distance = node_distance + self.enable_simulation = enable_simulation + self.marker_fill_ratio = marker_fill_ratio + self.label_size_ratio = label_size_ratio + self.max_label_fontsize = max_label_fontsize + self.min_inches_per_node = min_inches_per_node + self.active_node_color = active_node_color + self.measured_node_color = measured_node_color + + # Prepare graph layout reusing GraphVisualizer + self._prepare_layout() + + # Figure height adapts to graph density so circles and labels remain + # readable even for dense layouts like QAOA. + ax_h_frac = 0.80 # height fraction of ax_graph in figure + min_fig_height = 7 + if self.node_positions: + ys = [p[1] for p in self.node_positions.values()] + y_data_span = max(ys) - min(ys) + 1 + needed_height = y_data_span * self.min_inches_per_node / ax_h_frac + fig_height = max(min_fig_height, needed_height) + else: + fig_height = min_fig_height + y_data_span = 1 + + # Compute node marker size and label font size ONCE from the known + # figure geometry. This avoids instability when the user resizes the + # window, because the values are fixed at construction time. + ax_height_inches = fig_height * ax_h_frac + y_margin = y_data_span * 0.08 + 0.5 # mirrors _draw_graph margin + y_range = y_data_span + 2 * y_margin + points_per_unit = (ax_height_inches / y_range) * 72 # 72 pt/inch + marker_diameter = self.marker_fill_ratio * points_per_unit + self.node_size: int = max(30, int(marker_diameter**2)) + self.label_fontsize: int = min(self.max_label_fontsize, max(6, int(marker_diameter * self.label_size_ratio))) + + self.fig = plt.figure(figsize=(14, fig_height)) + + # Grid layout: command list (~27%), graph (~65%), bottom strips for controls + self.ax_commands = self.fig.add_axes((0.02, 0.15, 0.27, ax_h_frac)) + self.ax_cmd_scroll = self.fig.add_axes((0.02, 0.08, 0.27, 0.03)) + self.ax_graph = self.fig.add_axes((0.32, 0.15, 0.65, ax_h_frac)) + self.ax_prev = self.fig.add_axes((0.32, 0.04, 0.03, 0.03)) + self.ax_slider = self.fig.add_axes((0.40, 0.04, 0.48, 0.03)) + self.ax_next = self.fig.add_axes((0.90, 0.04, 0.03, 0.03)) + + # Turn off axes frame for command list and graph + self.ax_commands.axis("off") + self.ax_graph.axis("off") + + # Interaction state + self.current_step = 0 + self.total_steps = len(pattern) + self.command_window_size = 30 + self._cmd_scroll_offset: int = 0 # first visible command index + + # Widget placeholders + self.slider: Slider | None = None + self.cmd_scroll_slider: Slider | None = None + self.btn_prev: Button | None = None + self.btn_next: Button | None = None + + def _prepare_layout(self) -> None: + """Compute node positions by reusing :class:`GraphVisualizer` layout. + + Builds the full graph from the pattern commands, delegates layout + computation to :meth:`GraphVisualizer.get_layout`, and normalizes + the resulting positions to fit the interactive panel area. + The flow-based layout is always preserved. + """ + # Build the full graph from all commands + g: Any = __import__("networkx").Graph() + measurements: dict[int, Any] = {} + for cmd in self.pattern: + if cmd.kind == CommandKind.N: + g.add_node(cmd.node) + elif cmd.kind == CommandKind.E: + g.add_edge(cmd.nodes[0], cmd.nodes[1]) + elif cmd.kind == CommandKind.M: + measurements[cmd.node] = cmd.measurement + + # Delegate layout to GraphVisualizer (shares flow-detection logic) + og = OpenGraph(g, self.pattern.input_nodes, self.pattern.output_nodes, measurements) + og = og.infer_pauli_measurements() + + vis = GraphVisualizer(og) + pos_mapping, _, _ = vis.get_layout() + self.node_positions = dict(pos_mapping) + + # Apply user-provided scaling + self.node_positions = { + k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() + } + # Store the visualizer for reuse in drawing helpers + self._graph_visualizer = vis + + def visualize(self) -> None: + """Launch the interactive visualization window.""" + # Initial draw + self._draw_command_list() + self._draw_graph() + self._update(0) + + # Step slider (horizontal, bottom) + self.slider = Slider(self.ax_slider, "Step", 0, self.total_steps, valinit=0, valstep=1, color="lightblue") + self.slider.on_changed(self._update) + + # Command list scroll slider (horizontal, below command panel) + max_scroll = max(0, self.total_steps - self.command_window_size) + self.cmd_scroll_slider = Slider( + self.ax_cmd_scroll, + "", + 0, + max(1, max_scroll), + valinit=0, + valstep=1, + color="#cccccc", + ) + self.cmd_scroll_slider.on_changed(self._on_cmd_scroll) + + # Buttons config + self.btn_prev = Button(self.ax_prev, "<") + self.btn_prev.on_clicked(self._prev_step) + + self.btn_next = Button(self.ax_next, ">") + self.btn_next.on_clicked(self._next_step) + + # Key events + self.fig.canvas.mpl_connect("key_press_event", self._on_key) + + # Pick events for command list + self.fig.canvas.mpl_connect("pick_event", self._on_pick) + + plt.show() + + def _draw_command_list(self) -> None: + self.ax_commands.clear() + self.ax_commands.axis("off") + self.ax_commands.set_title(f"Commands ({self.total_steps})", loc="left") + + # Use scroll offset for visible window + start = max(0, min(self._cmd_scroll_offset, self.total_steps - self.command_window_size)) + end = min(self.total_steps, start + self.command_window_size) + + cmds: Any = self.pattern[start:end] # type: ignore[index] + + for i, cmd in enumerate(cmds): + abs_idx = start + i + text_str = f"{abs_idx}: {command_to_str(cmd, OutputFormat.Unicode)}" + + color = "black" + weight = "normal" + if abs_idx < self.current_step: + color = "green" + elif abs_idx == self.current_step: + color = "red" + weight = "bold" + + # Position text from top to bottom + y_pos = 1.0 - (i + 1) * (1.0 / (self.command_window_size + 2)) + + text_obj = self.ax_commands.text( + 0.05, + y_pos, + text_str, + color=color, + weight=weight, + fontsize=10, + transform=self.ax_commands.transAxes, + picker=True, + ) + # Store index with artist for picking + text_obj.index = abs_idx # type: ignore[attr-defined] + + def _update_graph_state( + self, step: int + ) -> tuple[set[int], set[int], list[tuple[int, ...]], dict[int, set[str]], dict[int, int]]: + """Calculate the graph state by simulating the pattern up to *step*. + + Parameters + ---------- + step : int + The command index up to which the pattern is executed. + + Returns + ------- + active_nodes : set[int] + Nodes that have been initialised but not yet measured. + measured_nodes : set[int] + Nodes that have been measured. + active_edges : list[tuple[int, ...]] + Edges currently present in the graph (both endpoints active). + corrections : dict[int, set[str]] + Accumulated byproduct corrections per node (``"X"`` and/or ``"Z"``). + results : dict[int, int] + Measurement outcomes keyed by node (only populated when + *enable_simulation* is ``True``). + """ + active_nodes = set() + measured_nodes = set() + active_edges = [] + corrections: dict[int, set[str]] = {} + results: dict[int, int] = {} + + if self.enable_simulation: + backend = StatevectorBackend() + + # Prerun input nodes (standard MBQC initialization) + for node in self.pattern.input_nodes: + backend.add_nodes([node]) + + rng = np.random.default_rng(42) # Fixed seed for determinism + + for i in range(step): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + backend.add_nodes([cmd.node], data=cmd.state) + elif cmd.kind == CommandKind.E: + backend.entangle_nodes(cmd.nodes) + elif cmd.kind == CommandKind.M: + # Adaptive measurement (feedforward) + s_signal = sum(results.get(j, 0) for j in cmd.s_domain) if cmd.s_domain else 0 + t_signal = sum(results.get(j, 0) for j in cmd.t_domain) if cmd.t_domain else 0 + + clifford = Clifford.I + if s_signal % 2 == 1: + clifford = Clifford.X @ clifford + if t_signal % 2 == 1: + clifford = Clifford.Z @ clifford + + measurement = cmd.measurement.clifford(clifford) + result = backend.measure(cmd.node, measurement, rng=rng) + results[cmd.node] = result + elif cmd.kind == CommandKind.X: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add("X") + backend.correct_byproduct(cmd) + elif cmd.kind == CommandKind.Z: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add("Z") + backend.correct_byproduct(cmd) + + # ---- Topological tracking (independent of simulation) ---- + current_active: set[int] = set(self.pattern.input_nodes) + current_edges: set[tuple[int, ...]] = set() + current_measured: set[int] = set() + + for i in range(step): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + current_active.add(cmd.node) + elif cmd.kind == CommandKind.E: + u, v = cmd.nodes + if u in current_active and v in current_active: + current_edges.add(tuple(sorted((u, v)))) + elif cmd.kind == CommandKind.M and cmd.node in current_active: + current_active.remove(cmd.node) + current_measured.add(cmd.node) + current_edges = {e for e in current_edges if cmd.node not in e} + + active_nodes = current_active + measured_nodes = current_measured + active_edges = list(current_edges) + + return active_nodes, measured_nodes, active_edges, corrections, results + + def _draw_graph(self) -> None: + """Draw nodes and edges onto the graph axes. + + Delegates to :class:`GraphVisualizer` for edge and node rendering, + passing per-node colour overrides to distinguish measured (grey) from + active (blue border) nodes. Labels are drawn locally because they + include dynamic content (measurement results, corrections). + """ + try: + self.ax_graph.clear() + + active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( + self.current_step + ) + + # ---- Edges (delegate to GraphVisualizer) ---- + self._graph_visualizer.draw_edges(self.ax_graph, self.node_positions, edge_subset=active_edges) + + # ---- Axis limits (set before drawing nodes so geometry is known) ---- + xs = [p[0] for p in self.node_positions.values()] + ys = [p[1] for p in self.node_positions.values()] + x_margin = (max(xs) - min(xs)) * 0.08 + 0.5 + y_margin = (max(ys) - min(ys)) * 0.08 + 0.5 + self.ax_graph.set_xlim(min(xs) - x_margin, max(xs) + x_margin) + self.ax_graph.set_ylim(min(ys) - y_margin, max(ys) + y_margin) + + # ---- Nodes (delegate to GraphVisualizer with colour overrides) ---- + node_facecolors: dict[int, str] = {} + node_edgecolors: dict[int, str] = {} + for node in measured_nodes: + node_facecolors[node] = self.measured_node_color + node_edgecolors[node] = "black" + for node in active_nodes: + node_facecolors[node] = "white" + node_edgecolors[node] = self.active_node_color + + self._graph_visualizer.draw_nodes_role( + self.ax_graph, + self.node_positions, + node_facecolors=node_facecolors, + node_edgecolors=node_edgecolors, + node_size=self.node_size, + ) + + # ---- Labels (drawn locally for dynamic content) ---- + fontsize = self.label_fontsize + + for node in measured_nodes: + if node not in self.node_positions: + continue + x, y = self.node_positions[node] + label_text = str(node) + if node in results: + label_text += f"\nm={results[node]}" + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=fontsize, zorder=3) + + for node in active_nodes: + if node not in self.node_positions: + continue + x, y = self.node_positions[node] + label_text = str(node) + if node in corrections: + label_text += "\n" + "".join(sorted(corrections[node])) + text_color = "blue" if node in corrections else "black" + self.ax_graph.text( + x, y, label_text, ha="center", va="center", fontsize=fontsize, color=text_color, zorder=3 + ) + + self.ax_graph.axis("off") + + except Exception as e: # noqa: BLE001 + traceback.print_exc() + print(f"Error drawing graph: {e}", file=sys.stderr) + + def _update(self, val: float) -> None: + step = int(val) + if step != self.current_step: + self.current_step = step + # Auto-scroll command list to keep current step visible + if step < self._cmd_scroll_offset or step >= self._cmd_scroll_offset + self.command_window_size: + new_offset = max(0, step - self.command_window_size // 2) + self._cmd_scroll_offset = new_offset + if self.cmd_scroll_slider is not None: + self.cmd_scroll_slider.set_val(new_offset) + self._draw_command_list() + self._draw_graph() + self.fig.canvas.draw_idle() + + def _on_cmd_scroll(self, val: float) -> None: + """Handle vertical scroll slider changes.""" + new_offset = int(val) + if new_offset != self._cmd_scroll_offset: + self._cmd_scroll_offset = new_offset + self._draw_command_list() + self.fig.canvas.draw_idle() + + def _prev_step(self, _event: Any) -> None: + if self.current_step > 0 and self.slider is not None: + self.slider.set_val(self.current_step - 1) + + def _next_step(self, _event: Any) -> None: + if self.current_step < self.total_steps and self.slider is not None: + self.slider.set_val(self.current_step + 1) + + def _on_key(self, event: Any) -> None: + if event.key == "right": + self._next_step(None) + elif event.key == "left": + self._prev_step(None) + + def _on_pick(self, event: Any) -> None: + if isinstance(event.artist, Text): + idx = getattr(event.artist, "index", None) + if idx is not None and self.slider is not None: + self.slider.set_val(idx + 1) # Jump to state AFTER the clicked command diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py new file mode 100644 index 00000000..53ed8f63 --- /dev/null +++ b/tests/test_visualization_interactive.py @@ -0,0 +1,467 @@ +"""Tests for the interactive visualization module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from matplotlib.text import Text + +from graphix.command import E, M, N, X, Z +from graphix.measurements import Measurement +from graphix.pattern import Pattern +from graphix.visualization import GraphVisualizer +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +class TestInteractiveGraphVisualizer: + @pytest.fixture + def pattern(self) -> Pattern: + """Fixture to provide a standard pattern for testing.""" + pattern = Pattern(input_nodes=[0, 1]) + pattern.add(N(node=0)) + pattern.add(N(node=1)) + pattern.add(N(node=2)) + pattern.add(E(nodes=(0, 1))) + pattern.add(E(nodes=(1, 2))) + pattern.add(M(node=0, measurement=Measurement.XY(0.5), s_domain={1}, t_domain={2})) + pattern.add(M(node=1, measurement=Measurement.XY(0.0), s_domain={2}, t_domain=set())) + pattern.add(X(node=2, domain={0})) + pattern.add(Z(node=2, domain={1})) + return pattern + + def test_init(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test initialization of the visualizer.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + + # Capture the OpenGraph mock correctly + mock_og_class = mocker.patch("graphix.visualization_interactive.OpenGraph") + mock_og_instance = mock_og_class.return_value + # Ensure infer_pauli_measurements returns a mock (or itself) to support chaining + mock_og_instance.infer_pauli_measurements.return_value = mock_og_instance + + mocker.patch("matplotlib.pyplot.figure") + + # Mock layout generation + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + + assert viz.total_steps == len(pattern) + assert viz.enable_simulation + # Check if get_layout was called + mock_visualizer.assert_called_with(mock_og_instance) # Verify visualizer init with corrected OG + mock_vis_obj.get_layout.assert_called_once() + # Check if node positions are set + assert len(viz.node_positions) == 3 + + # Check if infer_pauli_measurements was called + mock_og_instance.infer_pauli_measurements.assert_called_once() + + def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that layout logic delegates to GraphVisualizer.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + # Return specific positions to verify they are used + expected_pos = {0: (10, 10), 1: (20, 20), 2: (30, 30)} + mock_vis_obj.get_layout.return_value = (expected_pos, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + + # Keys should match the layout output + assert viz.node_positions.keys() == expected_pos.keys() + # Positions are the raw layout scaled by node_distance (default 1, 1) + for node, (ex, ey) in expected_pos.items(): + ax, ay = viz.node_positions[node] + assert ax == pytest.approx(ex * viz.node_distance[0]) + assert ay == pytest.approx(ey * viz.node_distance[1]) + + def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test graph state update with simulation enabled.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + # Mock simulation backend + backend_instance = mock_backend.return_value + backend_instance.measure.return_value = 1 + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=True) + + # Update to the end of the pattern + active, measured, _, _, results = viz._update_graph_state(len(pattern)) + + # Basic Checks + # Node 0 and 1 are measured + assert 0 in measured + assert 1 in measured + # Node 2 is active + assert 2 in active + # Results should be populated (since we mocked measure return value) + assert results[0] == 1 + assert results[1] == 1 + # Check if backend methods were called + backend_instance.add_nodes.assert_called() + backend_instance.entangle_nodes.assert_called() + assert backend_instance.measure.call_count == 2 + assert backend_instance.correct_byproduct.call_count == 2 + + # Manually trigger update to test drawing logic for measured/active nodes + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz.slider.val = len(pattern) + viz._update(len(pattern)) + + # Labels are drawn locally: check that text was called + assert viz.ax_graph.text.call_count > 0 + + def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test graph state update with simulation disabled.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=False) + + # Update to the end of the pattern + active, measured, _, _, results = viz._update_graph_state(len(pattern)) + + # Basic Checks + # Node 0 and 1 are measured (topology tracking works without sim) + assert 0 in measured + assert 1 in measured + # Node 2 is active + assert 2 in active + # Results should be empty as simulation is disabled + assert results == {} + + # Manually trigger update to test drawing logic without simulation + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz.slider.val = len(pattern) + viz._update(len(pattern)) + + # Ensure text is drawn (commands, node labels) + assert viz.ax_commands.text.call_count > 0 + assert viz.ax_graph.text.call_count > 0 + + def test_measurement_result_label_format(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that measurement result labels use the 'm=' prefix to avoid ambiguity.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + backend_instance = mock_backend.return_value + backend_instance.measure.return_value = 1 + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=True) + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + + # Execute all commands so that nodes 0 and 1 are measured + viz._update(len(pattern)) + + # Collect all text calls on ax_graph + text_calls = viz.ax_graph.text.call_args_list + label_strings = [str(call.args[2]) if len(call.args) >= 3 else "" for call in text_calls] + + # At least one label should contain 'm=' (the measurement result prefix) + assert any("m=" in label for label in label_strings), ( + f"Expected 'm=' in at least one node label, got: {label_strings}" + ) + # None of the labels should use the old ambiguous '\n=' format + assert not any(label.endswith(("\n=1", "\n=0")) for label in label_strings), ( + f"Found ambiguous '=' label format in: {label_strings}" + ) + + def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test step navigation methods.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + # Mock slider + viz.slider = MagicMock() + viz.total_steps = 10 + viz.current_step = 5 + + # Test Prev + viz._prev_step(None) + viz.slider.set_val.assert_called_with(4) + + # Test Next + viz._next_step(None) + viz.slider.set_val.assert_called_with(6) + + # Test Boundary Prev + viz.current_step = 0 + viz.slider.reset_mock() + viz._prev_step(None) + viz.slider.set_val.assert_not_called() + + # Test Boundary Next + viz.current_step = 10 + viz.slider.reset_mock() + viz._next_step(None) + viz.slider.set_val.assert_not_called() + + def test_visualize(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test the main visualize method (smoke test).""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_show = mocker.patch("matplotlib.pyplot.show") + mocker.patch("graphix.visualization_interactive.Slider") # Mock Slider to avoid matplotlib validation + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.visualize() + + # Should show plot + mock_show.assert_called_once() + # Should initialize axes + assert viz.ax_commands is not None + assert viz.ax_graph is not None + assert viz.slider is not None + assert viz.btn_next is not None + assert viz.btn_prev is not None + + def test_interaction_events(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test interaction event handlers.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.slider = MagicMock() + + # Test key press + valid_key_event = MagicMock() + valid_key_event.key = "right" + # Mock return value of _next_step side effect + mock_next = mocker.patch.object(viz, "_next_step") + viz._on_key(valid_key_event) + mock_next.assert_called_once() + + valid_key_event.key = "left" + mock_prev = mocker.patch.object(viz, "_prev_step") + viz._on_key(valid_key_event) + mock_prev.assert_called_once() + + def test_z_correction_initialization(self, mocker: MagicMock) -> None: + """Test tracking of Z corrections specifically to cover Z initialization.""" + # Create a pattern with a Z correction on a fresh node + pattern = Pattern(input_nodes=[0]) + pattern.add(N(node=0)) + pattern.add(Z(node=0, domain=set())) + + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=False) + + # Trigger update to process the Z command + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz._update(len(pattern)) + + # Test pick event (clicking on command list) + # We need a real Text object (or a mock that spec=Text) because _on_pick uses isinstance + + mock_artist = MagicMock(spec=Text) + mock_artist.index = 5 + pick_event = MagicMock() + pick_event.artist = mock_artist + + # Should set slider to index + 1 (highlight executed commands up to that point) + viz._on_pick(pick_event) + viz.slider.set_val.assert_called_with(6) + + def test_draw_edges_delegates(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that _draw_graph delegates edge drawing to GraphVisualizer.draw_edges.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + viz.slider = MagicMock() + + # Step 5: entanglement E(0, 1) and E(1, 2), no measurements yet + viz._update(5) + + # draw_edges should have been called with edge_subset + mock_vis_obj.draw_edges.assert_called() + call_kwargs = mock_vis_obj.draw_edges.call_args + assert "edge_subset" in call_kwargs.kwargs + + def test_draw_nodes_delegates(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that _draw_graph delegates node drawing to GraphVisualizer.draw_nodes_role.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + viz.slider = MagicMock() + + # Step after all N + E commands (5 commands) + measurements + viz._update(len(pattern)) + + # draw_nodes_role should have been called with colour overrides + mock_vis_obj.draw_nodes_role.assert_called() + call_kwargs = mock_vis_obj.draw_nodes_role.call_args + assert "node_facecolors" in call_kwargs.kwargs + assert "node_edgecolors" in call_kwargs.kwargs + + def test_draw_graph_exception_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test the exception handling in _draw_graph.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + + # Force an exception during state update + mocker.patch.object(viz, "_update_graph_state", side_effect=ValueError("Test Exception")) + # Mock traceback to avoid cluttering test output + mocker.patch("traceback.print_exc") + + # This should not raise but log/print + viz._draw_graph() + + +class TestGraphVisualizerSharedAPI: + """Tests for the shared drawing API exposed by GraphVisualizer.""" + + def test_get_label_fontsize_small_nodes(self) -> None: + """Font size should equal base_size for small node numbers.""" + assert GraphVisualizer.get_label_fontsize(0) == 12 + assert GraphVisualizer.get_label_fontsize(99) == 12 + + def test_get_label_fontsize_large_nodes(self) -> None: + """Font size should shrink for large node numbers.""" + result = GraphVisualizer.get_label_fontsize(100) + assert result < 12 + assert result >= 7 + + def test_get_label_fontsize_custom_base(self) -> None: + """Font size should use the custom base_size.""" + assert GraphVisualizer.get_label_fontsize(0, base_size=10) == 10 + result = GraphVisualizer.get_label_fontsize(1000, base_size=10) + assert result >= 7 + assert result < 10 + + def test_draw_nodes_role_with_overrides(self) -> None: + """Test draw_nodes_role applies per-node colour overrides.""" + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0, 1, 2] + mock_og.input_nodes = [0] + mock_og.output_nodes = [2] + mock_og.measurements = {0: MagicMock(), 1: MagicMock(), 2: MagicMock()} + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0)} + + vis.draw_nodes_role( + ax, + pos, + node_facecolors={0: "yellow", 1: "pink"}, + node_edgecolors={0: "green"}, + ) + + assert ax.scatter.call_count == 3 + # Check overrides were applied by inspecting scatter kwargs + scatter_calls = ax.scatter.call_args_list + # Node 0: facecolors=yellow (override), edgecolors=green (override) + assert scatter_calls[0].kwargs["facecolors"] == "yellow" + assert scatter_calls[0].kwargs["edgecolors"] == "green" + # Node 1: facecolors=pink (override), edgecolors=black (default) + assert scatter_calls[1].kwargs["facecolors"] == "pink" + assert scatter_calls[1].kwargs["edgecolors"] == "black" + # Node 2: facecolors=lightgray (output role), edgecolors=black (default) + assert scatter_calls[2].kwargs["facecolors"] == "lightgray" + assert scatter_calls[2].kwargs["edgecolors"] == "black" + + def test_draw_edges_with_subset(self) -> None: + """Test draw_edges with edge_subset only draws specified edges.""" + mock_og = MagicMock() + mock_og.graph.edges.return_value = [(0, 1), (1, 2), (2, 3)] + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0)} + + # Draw only a subset + vis.draw_edges(ax, pos, edge_subset=[(0, 1), (2, 3)]) + assert ax.plot.call_count == 2 + + def test_draw_edges_without_subset(self) -> None: + """Test draw_edges without edge_subset draws all edges.""" + mock_og = MagicMock() + mock_og.graph.edges.return_value = [(0, 1), (1, 2), (2, 3)] + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0)} + + vis.draw_edges(ax, pos) + assert ax.plot.call_count == 3