From dc78816e895992220172917df3461f3e363a174c Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Sun, 15 Feb 2026 22:57:39 -0600 Subject: [PATCH 01/12] feat: Add interactive visualization for MBQC patterns with step-by-step execution. --- examples/interactive_viz_demo.py | 30 +++ graphix/visualization_interactive.py | 329 +++++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 examples/interactive_viz_demo.py create mode 100644 graphix/visualization_interactive.py diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py new file mode 100644 index 00000000..3f50159f --- /dev/null +++ b/examples/interactive_viz_demo.py @@ -0,0 +1,30 @@ +from graphix.pattern import Pattern +from graphix.command import N, M, E, X, Z +from graphix.fundamentals import Plane +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from graphix.visualization_interactive import InteractiveGraphVisualizer + +def main(): + # optimized pattern for QFT + # Create a simple pattern manually for demonstration + p = Pattern(input_nodes=[0, 1]) + p.add(N(node=2)) + p.add(E(nodes=(0, 2))) + p.add(E(nodes=(1, 2))) + p.add(M(node=0, plane=Plane.XY, angle=0.5)) + p.add(M(node=1, plane=Plane.XY, angle=0.25)) + p.add(X(node=2, domain={0, 1})) + p.add(Z(node=2, domain={0})) + + # Or standardization to make it interesting + # p.standardize() + + print("Pattern created with", len(p), "commands.") + + viz = InteractiveGraphVisualizer(p) + viz.visualize() + +if __name__ == "__main__": + main() diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py new file mode 100644 index 00000000..52bbf8c8 --- /dev/null +++ b/graphix/visualization_interactive.py @@ -0,0 +1,329 @@ +"""Interactive visualization for MBQC patterns.""" + +from __future__ import annotations + +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +from matplotlib.widgets import Button, Slider +from graphix.pattern import Pattern, CommandKind +from graphix.pretty_print import OutputFormat, command_to_str +from graphix.visualization import GraphVisualizer + +class InteractiveGraphVisualizer: + """ + Interactive visualizer for MBQC patterns. + + This tool allows users to visualize the execution of a measurement-based quantum computing + pattern step-by-step. It displays the command sequence and the corresponding graph state, + updating dynamically as the user navigates through the commands. + + Attributes + ---------- + pattern : Pattern + The MBQC pattern to be visualized. + node_distance : tuple[float, float] + Distance multiplication factor between nodes for x and y directions. + total_steps : int + Total number of commands in the pattern. + current_step : int + Current step index in the command sequence. + """ + + def __init__(self, pattern: Pattern, node_distance: tuple[float, float] = (1, 1)): + """ + Construct an InteractiveGraphVisualizer. + + Parameters + ---------- + pattern : Pattern + The MBQC pattern to be visualized. + node_distance : tuple[float, float], optional + Distance scale for the graph layout, by default (1, 1). + """ + self.pattern = pattern + self.node_distance = node_distance + self.total_steps = len(pattern) + self.current_step = 0 + + # Extract graph structure for layout + self.nodes = pattern.input_nodes.copy() if pattern.input_nodes else [] + self.edges = [] + self.node_positions = {} + self.vis = None # GraphVisualizer instance for layout calculation + + # Pre-calculate graph layout + self._prepare_layout() + + def _prepare_layout(self) -> None: + """ + Pre-calculate the layout using the standard GraphVisualizer. + + It constructs a full graph from the pattern commands and uses GraphVisualizer's + layout algorithms (flow, gflow, or spring) to determine node positions. + """ + # We need to construct the full graph to get the layout + # This is a bit of a workaround because GraphVisualizer expects a graph + G = nx.Graph() + for cmd in self.pattern: + if cmd.kind == CommandKind.N: + G.add_node(cmd.node) + elif cmd.kind == CommandKind.E: + G.add_edge(cmd.nodes[0], cmd.nodes[1]) + + # Use GraphVisualizer to determine positions based on flow/structure + # We create a dummy visualizer just for the layout + self.vis = GraphVisualizer(G, self.pattern.input_nodes, self.pattern.output_nodes) + + # Try to find flow/gflow for better layout + # This logic mimics visualize_from_pattern but just gets positions + try: + from graphix.optimization import StandardizedPattern + pattern_std = StandardizedPattern.from_pattern(self.pattern) + try: + flow = pattern_std.extract_causal_flow() + self.node_positions = self.vis.place_flow(flow) + except: + try: + gflow = pattern_std.extract_gflow() + self.node_positions = self.vis.place_gflow(gflow) + except: + self.node_positions = self.vis.place_without_structure() + except Exception: + self.node_positions = self.vis.place_without_structure() + + # Apply scaling + self.node_positions = {k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) + for k, v in self.node_positions.items()} + + def visualize(self) -> None: + """ + Launch the interactive visualization window. + + This method sets up the Matplotlib figure, axes, slider, and buttons, + and starts the event loop. + """ + self.fig = plt.figure(figsize=(15, 8)) + + # Grid layout: Command list on left, Graph on right + # We use a bit of manual placement to fit the slider and list nicely + self.ax_commands = self.fig.add_axes([0.05, 0.2, 0.2, 0.7]) # [left, bottom, width, height] + self.ax_graph = self.fig.add_axes([0.3, 0.2, 0.65, 0.7]) + self.ax_slider = self.fig.add_axes([0.3, 0.05, 0.5, 0.03]) + self.ax_prev = self.fig.add_axes([0.2, 0.05, 0.04, 0.04]) + self.ax_next = self.fig.add_axes([0.85, 0.05, 0.04, 0.04]) + + # Turn off axes for command list + self.ax_commands.axis('off') + + # Setup Slider + self.slider = Slider( + self.ax_slider, "Step", 0, self.total_steps, + valinit=0, valstep=1, color="lightblue" + ) + self.slider.on_changed(self._update) + + # Setup Buttons + self.btn_prev = Button(self.ax_prev, '<') + self.btn_next = Button(self.ax_next, '>') + self.btn_prev.on_clicked(self._prev_step) + self.btn_next.on_clicked(self._next_step) + + # Initial Draw + self._update(0) + self._draw_command_list() + + # Events + self.fig.canvas.mpl_connect('pick_event', self._on_pick) + self.fig.canvas.mpl_connect('key_press_event', self._on_key) + + plt.show() + + def _on_key(self, event) -> None: + """Handle key press events for navigation.""" + if event.key == 'right': + self._next_step(None) + elif event.key == 'left': + self._prev_step(None) + + def _draw_command_list(self) -> None: + """ + Draw the list of commands in the left panel. + + Only draws a window of commands around current_step to maintain performance and readability. + Executed commands are shown in green, the current/next command in red, and future commands in black. + """ + self.ax_commands.clear() + self.ax_commands.axis('off') + self.ax_commands.set_title(f"Commands ({self.total_steps})") + + # Define window size + window_size = 20 + start = max(0, int(self.current_step) - window_size // 2) + end = min(self.total_steps, start + window_size) + + # Adjust start if we are near the end + if end == self.total_steps: + start = max(0, end - window_size) + + cmds = self.pattern[start:end] + + for i, cmd in enumerate(cmds): + abs_idx = start + i + text = f"{abs_idx}: {command_to_str(cmd, OutputFormat.Unicode)}" + + color = "black" + weight = "normal" + if abs_idx < self.current_step: + color = "green" # Executed (Updated to green as per issue suggestion) + elif abs_idx == self.current_step: + color = "red" # Current/Next to be executed + weight = "bold" + else: + color = "black" # Future + + # clickable text + t = self.ax_commands.text( + 0, 1.0 - (i / window_size), text, + transform=self.ax_commands.transAxes, + fontsize=10, color=color, weight=weight, + picker=True + ) + # Store index with artist for picking + t.index = abs_idx + + def _update_graph_state(self, step: int) -> tuple[set, set, set, dict]: + """ + Calculate the state of nodes and edges at a given step. + + Parameters + ---------- + step : int + The command index up to which the pattern is executed. + + Returns + ------- + active_nodes : set + Set of indices of nodes that are initialized and alive (not measured). + measured_nodes : set + Set of indices of nodes that have been measured. + active_edges : set + Set of edges (tuples) that have been created. + corrections : dict + Dictionary mapping node indices to sets of Pauli corrections ('X', 'Z'). + """ + active_nodes = set() + measured_nodes = set() + active_edges = set() + corrections = {} # node -> set of 'X', 'Z' + + # Replay commands up to 'step' + for i in range(int(step)): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + active_nodes.add(cmd.node) + elif cmd.kind == CommandKind.M: + if cmd.node in active_nodes: + active_nodes.remove(cmd.node) + measured_nodes.add(cmd.node) + elif cmd.kind == CommandKind.E: + active_edges.add(cmd.nodes) + elif cmd.kind == CommandKind.X: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add('X') + elif cmd.kind == CommandKind.Z: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add('Z') + + return active_nodes, measured_nodes, active_edges, corrections + + def _draw_graph(self) -> None: + """ + Render the graph in the right panel based on the current state. + + Highlights active nodes (white/red), measured nodes (grey), and active edges. + Displays badges for Pauli corrections. + """ + self.ax_graph.clear() + + active_nodes, measured_nodes, active_edges, corrections = self._update_graph_state(self.current_step) + + # If no nodes, just return + if not self.node_positions: + return + + # Draw Edges + for u, v in active_edges: + if u in self.node_positions and v in self.node_positions: + pos_u = self.node_positions[u] + pos_v = self.node_positions[v] + self.ax_graph.plot([pos_u[0], pos_v[0]], [pos_u[1], pos_v[1]], 'k-', zorder=1) + + # Draw Nodes + for node, pos in self.node_positions.items(): + # Determine visual properties + facecolor = 'none' + edgecolor = 'lightgray' + linestyle = ':' + alpha = 1.0 + label_color = 'k' + + if node in active_nodes: + facecolor = 'white' + edgecolor = 'red' + linestyle = '-' + elif node in measured_nodes: + facecolor = 'lightgray' + edgecolor = 'gray' + linestyle = '-' + + self.ax_graph.scatter(*pos, s=300, c=facecolor, edgecolors=edgecolor, linestyle=linestyle, alpha=alpha, zorder=2) + self.ax_graph.text(*pos, str(node), ha='center', va='center', color=label_color, zorder=3) + + # Draw corrections badges + if node in corrections: + corr_str = "".join(sorted(corrections[node])) + # Offset position for badge + badge_pos = (pos[0] + 0.1, pos[1] + 0.1) + self.ax_graph.text(*badge_pos, corr_str, fontsize=8, color='blue', weight='bold', zorder=4) + + self.ax_graph.set_aspect('equal') + self.ax_graph.axis('off') + + def _update(self, val) -> None: + """ + Update the visualization to a specific step. + + Parameters + ---------- + val : float + The slider value representing the step index. + """ + self.current_step = int(val) + self._draw_command_list() + self._draw_graph() + self.fig.canvas.draw_idle() + + def _prev_step(self, event) -> None: + """Callback for 'Previous' button.""" + if self.current_step > 0: + self.slider.set_val(self.current_step - 1) + + def _next_step(self, event) -> None: + """Callback for 'Next' button.""" + if self.current_step < self.total_steps: + self.slider.set_val(self.current_step + 1) + + def _on_pick(self, event) -> None: + """ + Handle pick events on the command list. + + Clicking a command sets the current step to immediately after that command. + """ + if isinstance(event.artist, plt.Text): + idx = getattr(event.artist, 'index', None) + if idx is not None: + # Set step to idx + 1 so the clicked command is executed + self.slider.set_val(idx + 1) From a6d487d1e927d52f6039ce5d070b34a101ee095f Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Sun, 15 Feb 2026 23:32:20 -0600 Subject: [PATCH 02/12] feat: Add interactive visualization for MBQC patterns with a demo. --- examples/interactive_viz_demo.py | 1 + graphix/visualization_interactive.py | 555 +++++++++++++++++---------- 2 files changed, 343 insertions(+), 213 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index 3f50159f..3757280f 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -22,6 +22,7 @@ def main(): # p.standardize() print("Pattern created with", len(p), "commands.") + print("Launching interactive visualization with real-time simulation...") viz = InteractiveGraphVisualizer(p) viz.visualize() diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index 52bbf8c8..e5bfeff1 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -2,328 +2,457 @@ from __future__ import annotations +import sys +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt import networkx as nx import numpy as np from matplotlib.widgets import Button, Slider -from graphix.pattern import Pattern, CommandKind + +from graphix.clifford import Clifford +from graphix.command import CommandKind, MeasureUpdate +from graphix.measurements import Measurement +from graphix.pattern import Pattern from graphix.pretty_print import OutputFormat, command_to_str +from graphix.sim.statevec import StatevectorBackend from graphix.visualization import GraphVisualizer +if TYPE_CHECKING: + from collections.abc import Collection + + class InteractiveGraphVisualizer: """ - Interactive visualizer for MBQC patterns. - - This tool allows users to visualize the execution of a measurement-based quantum computing - pattern step-by-step. It displays the command sequence and the corresponding graph state, - updating dynamically as the user navigates through the commands. + Interactive visualization tool for MBQC patterns. + + This visualizer provides a matplotlib-based GUI to step through the execution + of an MBQC pattern. It displays the sequence of commands and the corresponding + state of the graph state, including real-time simulation of measurement outcomes. Attributes ---------- pattern : Pattern - The MBQC pattern to be visualized. + The MBQC pattern to visualize. node_distance : tuple[float, float] - Distance multiplication factor between nodes for x and y directions. - total_steps : int - Total number of commands in the pattern. - current_step : int - Current step index in the command sequence. + Scale factors (x, y) for the node positions in the graph layout. """ - - def __init__(self, pattern: Pattern, node_distance: tuple[float, float] = (1, 1)): + + def __init__(self, pattern: Pattern, node_distance: tuple[float, float] = (1, 1)) -> None: """ - Construct an InteractiveGraphVisualizer. + Initialize the interactive visualizer. Parameters ---------- pattern : Pattern - The MBQC pattern to be visualized. + The MBQC pattern to visualize. node_distance : tuple[float, float], optional - Distance scale for the graph layout, by default (1, 1). + Scale factors for x and y coordinates of the graph nodes. Defaults to (1, 1). """ self.pattern = pattern self.node_distance = node_distance - self.total_steps = len(pattern) - self.current_step = 0 - - # Extract graph structure for layout - self.nodes = pattern.input_nodes.copy() if pattern.input_nodes else [] - self.edges = [] - self.node_positions = {} - self.vis = None # GraphVisualizer instance for layout calculation - - # Pre-calculate graph layout + + # Prepare graph layout using Graphix's visualizer or fallbacks self._prepare_layout() + # Figure setup + self.fig = plt.figure(figsize=(15, 8)) + + # Grid layout: Command list on left, Graph on right + # Layout optimized to prevent overlap: + # Commands: Left 2% to 30% + # Graph: Left 40% to 98% + self.ax_commands = self.fig.add_axes([0.02, 0.2, 0.28, 0.7]) # [left, bottom, width, height] + self.ax_graph = self.fig.add_axes([0.4, 0.2, 0.58, 0.7]) + self.ax_slider = self.fig.add_axes([0.4, 0.05, 0.5, 0.03]) + self.ax_prev = self.fig.add_axes([0.3, 0.05, 0.04, 0.04]) + self.ax_next = self.fig.add_axes([0.92, 0.05, 0.04, 0.04]) + + # Turn off axes frame for command list and graph + self.ax_commands.axis("off") + self.ax_graph.axis("off") # Start hidden to avoid "square" artifact + + # Interaction state + self.current_step = 0 + self.total_steps = len(pattern) + + # ... (other methods) ... + + def _draw_graph(self) -> None: + """Render the graph state on the right panel.""" + try: + self.ax_graph.clear() + + # Get current state from simulation + active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( + self.current_step + ) + + # Draw edges + for u, v in active_edges: + x1, y1 = self.node_positions[u] + x2, y2 = self.node_positions[v] + self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) + + # Draw nodes + # 1. Measured nodes (grey, with result text) + for node in measured_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) + self.ax_graph.add_patch(circle) + + label_text = str(node) + # Show measurement outcome if available + if node in results: + label_text += f"\n={results[node]}" + + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) + + # 2. Active nodes (white with colored edge, with correction text) + for node in active_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle( + (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 + ) + self.ax_graph.add_patch(circle) + + label_text = str(node) + # Show accumulated internal corrections + if node in corrections: + label_text += "\n" + "".join(sorted(corrections[node])) + + color = "black" + if node in corrections: + color = "blue" # Highlight corrected nodes + + self.ax_graph.text( + x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 + ) + + # Set aspect close to equal and hide axes + self.ax_graph.set_aspect("equal") + self.ax_graph.set_xlim(self.x_limits) + self.ax_graph.set_ylim(self.y_limits) + self.ax_graph.axis("off") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error drawing graph: {e}", file=sys.stderr) + # Matplotlib widgets placeholders (initialized in visualize) + self.slider: Slider | None = None + self.btn_prev: Button | None = None + self.btn_next: Button | None = None + def _prepare_layout(self) -> None: - """ - Pre-calculate the layout using the standard GraphVisualizer. - - It constructs a full graph from the pattern commands and uses GraphVisualizer's - layout algorithms (flow, gflow, or spring) to determine node positions. - """ - # We need to construct the full graph to get the layout - # This is a bit of a workaround because GraphVisualizer expects a graph - G = nx.Graph() + """Calculate node positions for the graph.""" + # Build full graph to determine positions + g = nx.Graph() for cmd in self.pattern: if cmd.kind == CommandKind.N: - G.add_node(cmd.node) + g.add_node(cmd.node) elif cmd.kind == CommandKind.E: - G.add_edge(cmd.nodes[0], cmd.nodes[1]) - + g.add_edge(cmd.nodes[0], cmd.nodes[1]) + # Use GraphVisualizer to determine positions based on flow/structure - # We create a dummy visualizer just for the layout - self.vis = GraphVisualizer(G, self.pattern.input_nodes, self.pattern.output_nodes) - - # Try to find flow/gflow for better layout - # This logic mimics visualize_from_pattern but just gets positions + vis = GraphVisualizer(g, self.pattern.input_nodes, self.pattern.output_nodes) + + # Try to find flow/gflow for better layout, fallback to spring layout try: from graphix.optimization import StandardizedPattern + pattern_std = StandardizedPattern.from_pattern(self.pattern) try: flow = pattern_std.extract_causal_flow() - self.node_positions = self.vis.place_flow(flow) - except: + self.node_positions = vis.place_flow(flow) + except Exception: try: gflow = pattern_std.extract_gflow() - self.node_positions = self.vis.place_gflow(gflow) - except: - self.node_positions = self.vis.place_without_structure() + self.node_positions = vis.place_gflow(gflow) + except Exception: + self.node_positions = vis.place_without_structure() except Exception: - self.node_positions = self.vis.place_without_structure() - + self.node_positions = vis.place_without_structure() + # Apply scaling - self.node_positions = {k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) - for k, v in self.node_positions.items()} + self.node_positions = { + k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() + } + + # Determine fixed bounds for the graph to prevent autoscaling issues + all_x = [pos[0] for pos in self.node_positions.values()] + all_y = [pos[1] for pos in self.node_positions.values()] + margin = 0.5 + self.x_limits = (min(all_x) - margin, max(all_x) + margin) + self.y_limits = (min(all_y) - margin, max(all_y) + margin) def visualize(self) -> None: - """ - Launch the interactive visualization window. - - This method sets up the Matplotlib figure, axes, slider, and buttons, - and starts the event loop. - """ - self.fig = plt.figure(figsize=(15, 8)) - - # Grid layout: Command list on left, Graph on right - # We use a bit of manual placement to fit the slider and list nicely - self.ax_commands = self.fig.add_axes([0.05, 0.2, 0.2, 0.7]) # [left, bottom, width, height] - self.ax_graph = self.fig.add_axes([0.3, 0.2, 0.65, 0.7]) - self.ax_slider = self.fig.add_axes([0.3, 0.05, 0.5, 0.03]) - self.ax_prev = self.fig.add_axes([0.2, 0.05, 0.04, 0.04]) - self.ax_next = self.fig.add_axes([0.85, 0.05, 0.04, 0.04]) - - # Turn off axes for command list - self.ax_commands.axis('off') - - # Setup Slider + """Launch the interactive visualization window.""" + # Initial draw + self._draw_command_list() + self._draw_graph() + self._update(0) + + # Slider config self.slider = Slider( - self.ax_slider, "Step", 0, self.total_steps, - valinit=0, valstep=1, color="lightblue" + self.ax_slider, "Step", 0, self.total_steps, valinit=0, valstep=1, color="lightblue" ) self.slider.on_changed(self._update) - - # Setup Buttons - self.btn_prev = Button(self.ax_prev, '<') - self.btn_next = Button(self.ax_next, '>') + + # Buttons config + self.btn_prev = Button(self.ax_prev, "<") self.btn_prev.on_clicked(self._prev_step) + + self.btn_next = Button(self.ax_next, ">") self.btn_next.on_clicked(self._next_step) - - # Initial Draw - self._update(0) - self._draw_command_list() - - # Events - self.fig.canvas.mpl_connect('pick_event', self._on_pick) - self.fig.canvas.mpl_connect('key_press_event', self._on_key) - - plt.show() - def _on_key(self, event) -> None: - """Handle key press events for navigation.""" - if event.key == 'right': - self._next_step(None) - elif event.key == 'left': - self._prev_step(None) + # Key events + self.fig.canvas.mpl_connect("key_press_event", self._on_key) + + # Pick events for command list + self.fig.canvas.mpl_connect("pick_event", self._on_pick) + + plt.show() def _draw_command_list(self) -> None: - """ - Draw the list of commands in the left panel. - - Only draws a window of commands around current_step to maintain performance and readability. - Executed commands are shown in green, the current/next command in red, and future commands in black. - """ + """Render the list of commands in the left panel.""" self.ax_commands.clear() - self.ax_commands.axis('off') - self.ax_commands.set_title(f"Commands ({self.total_steps})") - - # Define window size - window_size = 20 + self.ax_commands.axis("off") + self.ax_commands.set_title(f"Commands ({self.total_steps})", loc="left") + + # Windowing logic to show relevant commands + window_size = 30 start = max(0, int(self.current_step) - window_size // 2) end = min(self.total_steps, start + window_size) - - # Adjust start if we are near the end + if end == self.total_steps: start = max(0, end - window_size) - + cmds = self.pattern[start:end] - + for i, cmd in enumerate(cmds): abs_idx = start + i - text = f"{abs_idx}: {command_to_str(cmd, OutputFormat.Unicode)}" - + text_str = f"{abs_idx}: {command_to_str(cmd, OutputFormat.Unicode)}" + color = "black" weight = "normal" if abs_idx < self.current_step: - color = "green" # Executed (Updated to green as per issue suggestion) + color = "green" elif abs_idx == self.current_step: - color = "red" # Current/Next to be executed + color = "red" weight = "bold" - else: - color = "black" # Future - - # clickable text - t = self.ax_commands.text( - 0, 1.0 - (i / window_size), text, + + # Position text from top to bottom + y_pos = 1.0 - (i + 1) * (1.0 / (window_size + 2)) + + text_obj = self.ax_commands.text( + 0.05, + y_pos, + text_str, + color=color, + weight=weight, + fontsize=10, transform=self.ax_commands.transAxes, - fontsize=10, color=color, weight=weight, - picker=True + picker=True, ) # Store index with artist for picking - t.index = abs_idx + text_obj.index = abs_idx - def _update_graph_state(self, step: int) -> tuple[set, set, set, dict]: + def _update_graph_state( + self, step: int + ) -> tuple[set, set, set, dict[int, set[str]], dict[int, int]]: """ - Calculate the state of nodes and edges at a given step. + Calculate the state of the graph by simulating the pattern up to `step`. + + This method performs a full re-simulation using `StatevectorBackend` + to ensure deterministic measurement outcomes and correct adaptive behavior. Parameters ---------- step : int - The command index up to which the pattern is executed. + Current execution step. Returns ------- active_nodes : set - Set of indices of nodes that are initialized and alive (not measured). + Nodes currently in the graph (active). measured_nodes : set - Set of indices of nodes that have been measured. + Nodes that have been measured. active_edges : set - Set of edges (tuples) that have been created. + Edges currently in the graph. corrections : dict - Dictionary mapping node indices to sets of Pauli corrections ('X', 'Z'). + Accumulated Pauli corrections ('X', 'Z') for each node. + results : dict + Measurement outcomes (0 or 1) for measured nodes. """ - active_nodes = set() + # Initialize sets + active_nodes = set(self.pattern.input_nodes) measured_nodes = set() active_edges = set() - corrections = {} # node -> set of 'X', 'Z' + corrections: dict[int, set[str]] = {} + + # Simulation setup + backend = StatevectorBackend() + + # Initialize input nodes in the backend + if self.pattern.input_nodes: + backend.add_nodes(self.pattern.input_nodes) + + # Fixed seed for deterministic scrubbing + rng = np.random.default_rng(42) + results: dict[int, int] = {} - # Replay commands up to 'step' + # Replay commands for i in range(int(step)): cmd = self.pattern[i] + if cmd.kind == CommandKind.N: active_nodes.add(cmd.node) + backend.add_nodes([cmd.node], data=cmd.state) + elif cmd.kind == CommandKind.M: if cmd.node in active_nodes: active_nodes.remove(cmd.node) measured_nodes.add(cmd.node) + + # --- Adaptive Measurement Logic (Feedforward) --- + # Calculate s and t signals from previous measurement results + if cmd.s_domain: + s_signal = sum(results.get(j, 0) for j in cmd.s_domain) + else: + s_signal = 0 + if cmd.t_domain: + t_signal = sum(results.get(j, 0) for j in cmd.t_domain) + else: + t_signal = 0 + + s_bool = s_signal % 2 == 1 + t_bool = t_signal % 2 == 1 + + # Compute the updated angle and plane based on signals + measure_update = MeasureUpdate.compute(cmd.plane, s_bool, t_bool, Clifford.I) + + new_angle = cmd.angle * measure_update.coeff + measure_update.add_term + new_plane = measure_update.new_plane + + # Execute measurement on the backend using the adapted measurement + measurement = Measurement(new_angle, new_plane) + result = backend.measure(cmd.node, measurement, rng=rng) + results[cmd.node] = result + elif cmd.kind == CommandKind.E: active_edges.add(cmd.nodes) - elif cmd.kind == CommandKind.X: - if cmd.node not in corrections: - corrections[cmd.node] = set() - corrections[cmd.node].add('X') - elif cmd.kind == CommandKind.Z: - if cmd.node not in corrections: - corrections[cmd.node] = set() - corrections[cmd.node].add('Z') - - return active_nodes, measured_nodes, active_edges, corrections + # Apply entanglement in simulation + backend.entangle_nodes(cmd.nodes) + + elif cmd.kind in (CommandKind.X, CommandKind.Z): + # Apply Pauli corrections conditionally + do_op = True + if cmd.domain: + do_op = sum(results.get(j, 0) for j in cmd.domain) % 2 == 1 + + if do_op: + backend.correct_byproduct(cmd) + # Visual tracking of corrections + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add(cmd.kind.name) + + # Note: C, S, T, etc. are not explicitly visualized but exist in backend if supported. + # StatevectorBackend handles Clifford logic internally if pattern is standardized, + # but visualizer focuses on MBQC core set {N, M, E, X, Z}. + + return active_nodes, measured_nodes, active_edges, corrections, results def _draw_graph(self) -> None: - """ - Render the graph in the right panel based on the current state. - - Highlights active nodes (white/red), measured nodes (grey), and active edges. - Displays badges for Pauli corrections. - """ + """Render the graph state on the right panel.""" self.ax_graph.clear() - active_nodes, measured_nodes, active_edges, corrections = self._update_graph_state(self.current_step) - - # If no nodes, just return - if not self.node_positions: - return + # Get current state from simulation + active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( + self.current_step + ) - # Draw Edges + # Draw edges for u, v in active_edges: - if u in self.node_positions and v in self.node_positions: - pos_u = self.node_positions[u] - pos_v = self.node_positions[v] - self.ax_graph.plot([pos_u[0], pos_v[0]], [pos_u[1], pos_v[1]], 'k-', zorder=1) + x1, y1 = self.node_positions[u] + x2, y2 = self.node_positions[v] + self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) + + # Draw nodes + # 1. Measured nodes (grey, with result text) + for node in measured_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) + self.ax_graph.add_patch(circle) - # Draw Nodes - for node, pos in self.node_positions.items(): - # Determine visual properties - facecolor = 'none' - edgecolor = 'lightgray' - linestyle = ':' - alpha = 1.0 - label_color = 'k' - - if node in active_nodes: - facecolor = 'white' - edgecolor = 'red' - linestyle = '-' - elif node in measured_nodes: - facecolor = 'lightgray' - edgecolor = 'gray' - linestyle = '-' + label_text = str(node) + # Show measurement outcome if available + if node in results: + label_text += f"\n={results[node]}" - self.ax_graph.scatter(*pos, s=300, c=facecolor, edgecolors=edgecolor, linestyle=linestyle, alpha=alpha, zorder=2) - self.ax_graph.text(*pos, str(node), ha='center', va='center', color=label_color, zorder=3) - - # Draw corrections badges - if node in corrections: - corr_str = "".join(sorted(corrections[node])) - # Offset position for badge - badge_pos = (pos[0] + 0.1, pos[1] + 0.1) - self.ax_graph.text(*badge_pos, corr_str, fontsize=8, color='blue', weight='bold', zorder=4) + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) - self.ax_graph.set_aspect('equal') - self.ax_graph.axis('off') + # 2. Active nodes (white with colored edge, with correction text) + for node in active_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle( + (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 + ) + self.ax_graph.add_patch(circle) + + label_text = str(node) + # Show accumulated internal corrections + if node in corrections: + label_text += "\n" + "".join(sorted(corrections[node])) + + color = "black" + if node in corrections: + color = "blue" # Highlight corrected nodes - def _update(self, val) -> None: - """ - Update the visualization to a specific step. + self.ax_graph.text( + x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 + ) - Parameters - ---------- - val : float - The slider value representing the step index. - """ - self.current_step = int(val) - self._draw_command_list() - self._draw_graph() - self.fig.canvas.draw_idle() + # Set aspect close to equal and hide axes + self.ax_graph.set_aspect("equal") + self.ax_graph.set_xlim(self.x_limits) + self.ax_graph.set_ylim(self.y_limits) + self.ax_graph.axis("off") - def _prev_step(self, event) -> None: - """Callback for 'Previous' button.""" + def _update(self, val: float) -> None: + """Update visualization when slider changes.""" + step = int(val) + if step != self.current_step: + self.current_step = step + self._draw_command_list() + self._draw_graph() + self.fig.canvas.draw_idle() + + def _prev_step(self, event: object) -> None: + """Go backward one step.""" if self.current_step > 0: self.slider.set_val(self.current_step - 1) - def _next_step(self, event) -> None: - """Callback for 'Next' button.""" + def _next_step(self, event: object) -> None: + """Go forward one step.""" if self.current_step < self.total_steps: self.slider.set_val(self.current_step + 1) - - def _on_pick(self, event) -> None: - """ - Handle pick events on the command list. - - Clicking a command sets the current step to immediately after that command. - """ + + def _on_key(self, event: object) -> None: + """Handle keyboard navigation.""" + if event.key == "right": + self._next_step(None) + elif event.key == "left": + self._prev_step(None) + + def _on_pick(self, event: object) -> None: + """Handle clicks on command list items.""" if isinstance(event.artist, plt.Text): - idx = getattr(event.artist, 'index', None) + idx = getattr(event.artist, "index", None) if idx is not None: - # Set step to idx + 1 so the clicked command is executed - self.slider.set_val(idx + 1) + self.slider.set_val(idx + 1) # Jump to state AFTER the clicked command From 11db6c43415dac78fb4c4457647a85a6212998c9 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Mon, 16 Feb 2026 06:06:25 -0600 Subject: [PATCH 03/12] feat: introduce interactive MBQC pattern visualizer with step-by-step GUI and simulation capabilities. --- examples/interactive_viz_qaoa.py | 67 +++++++ graphix/visualization_interactive.py | 280 +++++++++++++++------------ 2 files changed, 222 insertions(+), 125 deletions(-) create mode 100644 examples/interactive_viz_qaoa.py diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py new file mode 100644 index 00000000..649e9096 --- /dev/null +++ b/examples/interactive_viz_qaoa.py @@ -0,0 +1,67 @@ +""" +QAOA Interactive Visualization (Optimized) +========================================== + +This example generates a QAOA pattern using the Graphix Circuit API +and launches the interactive visualizer in simulation-free mode +to demonstrate performance on complex patterns. +""" + +from __future__ import annotations + +import sys +import os +import networkx as nx +import numpy as np + +# Add project root to path to ensure we use local graphix version +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from graphix import Circuit +from graphix.visualization_interactive import InteractiveGraphVisualizer + +def main(): + print("Generating QAOA pattern...") + + # 1. Define QAOA Circuit + n_qubits = 4 + rng = np.random.default_rng(42) # Fixed seed for reproducibility + + # Random parameters for the circuit + xi = rng.random(6) + theta = rng.random(4) + + # Create a complete graph for the problem hamiltonian + g = nx.complete_graph(n_qubits) + circuit = Circuit(n_qubits) + + # Apply unitary evolution for the problem Hamiltonian + for i, (u, v) in enumerate(g.edges): + circuit.cnot(u, v) + circuit.rz(v, xi[i]) # Rotation by random angle + circuit.cnot(u, v) + + # Apply unitary evolution for the mixing Hamiltonian + for v in g.nodes: + circuit.rx(v, theta[v]) + + # 2. Transpile to MBQC Pattern + # This automatically generates the measurement pattern from the gate circuit + pattern = circuit.transpile().pattern + + # Standardize the pattern to ensure it follows the standard MBQC form (N, E, M, C) + pattern.standardize() + pattern.shift_signals() + + print(f"Pattern generated with {len(pattern)} commands.") + print("Launching interactive visualizer...") + print("Optimization enabled: Simulation is DISABLED for performance.") + print("You will see the graph structure and command flow without quantum state calculation.") + + # 3. Launch Visualization + # enable_simulation=False prevents high RAM usage for this complex pattern + viz = InteractiveGraphVisualizer(pattern, node_distance=(1.5, 1.5), enable_simulation=False) + viz.visualize() + +if __name__ == "__main__": + main() diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index e5bfeff1..a4e46d39 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -38,7 +38,12 @@ class InteractiveGraphVisualizer: Scale factors (x, y) for the node positions in the graph layout. """ - def __init__(self, pattern: Pattern, node_distance: tuple[float, float] = (1, 1)) -> None: + def __init__( + self, + pattern: Pattern, + node_distance: tuple[float, float] = (1, 1), + enable_simulation: bool = True, + ) -> None: """ Initialize the interactive visualizer. @@ -50,7 +55,9 @@ def __init__(self, pattern: Pattern, node_distance: tuple[float, float] = (1, 1) Scale factors for x and y coordinates of the graph nodes. Defaults to (1, 1). """ self.pattern = pattern + self.node_positions = {} self.node_distance = node_distance + self.enable_simulation = enable_simulation # Prepare graph layout using Graphix's visualizer or fallbacks self._prepare_layout() @@ -262,7 +269,7 @@ def _draw_command_list(self) -> None: def _update_graph_state( self, step: int - ) -> tuple[set, set, set, dict[int, set[str]], dict[int, int]]: + ) -> tuple[set, set, list[tuple[int, int]], dict[int, set[str]], dict[int, int]]: """ Calculate the state of the graph by simulating the pattern up to `step`. @@ -280,149 +287,172 @@ def _update_graph_state( Nodes currently in the graph (active). measured_nodes : set Nodes that have been measured. - active_edges : set + active_edges : list[tuple[int, int]] Edges currently in the graph. corrections : dict Accumulated Pauli corrections ('X', 'Z') for each node. results : dict Measurement outcomes (0 or 1) for measured nodes. """ - # Initialize sets - active_nodes = set(self.pattern.input_nodes) + # Prepare return containers + active_nodes = set() measured_nodes = set() - active_edges = set() + active_edges = [] corrections: dict[int, set[str]] = {} - - # Simulation setup - backend = StatevectorBackend() - - # Initialize input nodes in the backend - if self.pattern.input_nodes: - backend.add_nodes(self.pattern.input_nodes) - - # Fixed seed for deterministic scrubbing - rng = np.random.default_rng(42) results: dict[int, int] = {} - # Replay commands - for i in range(int(step)): - cmd = self.pattern[i] - - if cmd.kind == CommandKind.N: - active_nodes.add(cmd.node) - backend.add_nodes([cmd.node], data=cmd.state) - - elif cmd.kind == CommandKind.M: - if cmd.node in active_nodes: - active_nodes.remove(cmd.node) - measured_nodes.add(cmd.node) - - # --- Adaptive Measurement Logic (Feedforward) --- - # Calculate s and t signals from previous measurement results - if cmd.s_domain: - s_signal = sum(results.get(j, 0) for j in cmd.s_domain) - else: - s_signal = 0 - if cmd.t_domain: - t_signal = sum(results.get(j, 0) for j in cmd.t_domain) - else: - t_signal = 0 - - s_bool = s_signal % 2 == 1 - t_bool = t_signal % 2 == 1 - - # Compute the updated angle and plane based on signals - measure_update = MeasureUpdate.compute(cmd.plane, s_bool, t_bool, Clifford.I) - - new_angle = cmd.angle * measure_update.coeff + measure_update.add_term - new_plane = measure_update.new_plane - - # Execute measurement on the backend using the adapted measurement - measurement = Measurement(new_angle, new_plane) - result = backend.measure(cmd.node, measurement, rng=rng) - results[cmd.node] = result - - elif cmd.kind == CommandKind.E: - active_edges.add(cmd.nodes) - # Apply entanglement in simulation - backend.entangle_nodes(cmd.nodes) - - elif cmd.kind in (CommandKind.X, CommandKind.Z): - # Apply Pauli corrections conditionally - do_op = True - if cmd.domain: - do_op = sum(results.get(j, 0) for j in cmd.domain) % 2 == 1 - - if do_op: + if self.enable_simulation: + # --- Simulation Mode --- + backend = StatevectorBackend(pattern=self.pattern) + + # Prerun input nodes (standard MBQC initialization) + # Find all input nodes in the pattern + input_nodes = self.pattern.input_nodes + for node in input_nodes: + backend.add_nodes([node]) + + rng = np.random.default_rng(42) # Fixed seed for determinism + + # Re-execute commands up to current step + for i in range(step): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + backend.add_nodes([cmd.node], data=cmd.state) + elif cmd.kind == CommandKind.E: + backend.entangle_nodes(cmd.nodes) + elif cmd.kind == CommandKind.M: + # --- Adaptive Measurement Logic (Feedforward) --- + # Calculate s and t signals from previous measurement results + if cmd.s_domain: + s_signal = sum(results.get(j, 0) for j in cmd.s_domain) + else: + s_signal = 0 + if cmd.t_domain: + t_signal = sum(results.get(j, 0) for j in cmd.t_domain) + else: + t_signal = 0 + + s_bool = s_signal % 2 == 1 + t_bool = t_signal % 2 == 1 + + # Compute the updated angle and plane based on signals + measure_update = MeasureUpdate.compute(cmd.plane, s_bool, t_bool, Clifford.I) + + new_angle = cmd.angle * measure_update.coeff + measure_update.add_term + new_plane = measure_update.new_plane + + # Execute measurement on the backend using the adapted measurement + measurement = Measurement(new_angle, new_plane) + result = backend.measure(cmd.node, measurement, rng=rng) + results[cmd.node] = result + elif cmd.kind == CommandKind.X: + # Accumulate X corrections + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add("X") backend.correct_byproduct(cmd) - # Visual tracking of corrections + elif cmd.kind == CommandKind.Z: if cmd.node not in corrections: corrections[cmd.node] = set() - corrections[cmd.node].add(cmd.kind.name) + corrections[cmd.node].add("Z") + backend.correct_byproduct(cmd) + + # --- Common Logic (Topological Tracking) --- + # We track nodes/edges based on command history regardless of simulation + # This ensures visualization works even if simulation is disabled + + # Reset tracking + current_active_nodes = set(self.pattern.input_nodes) # Start with input nodes + current_edges = set() + current_measured_nodes = set() # Track measured nodes for topological view + + for i in range(step): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + current_active_nodes.add(cmd.node) + elif cmd.kind == CommandKind.E: + u, v = cmd.nodes + # Only add edge if both nodes are currently active (not yet measured) + if u in current_active_nodes and v in current_active_nodes: + current_edges.add(tuple(sorted((u, v)))) + elif cmd.kind == CommandKind.M: + if cmd.node in current_active_nodes: + current_active_nodes.remove(cmd.node) + current_measured_nodes.add(cmd.node) + # Remove connected edges involving the measured node + current_edges = {e for e in current_edges if cmd.node not in e} - # Note: C, S, T, etc. are not explicitly visualized but exist in backend if supported. - # StatevectorBackend handles Clifford logic internally if pattern is standardized, - # but visualizer focuses on MBQC core set {N, M, E, X, Z}. + # Corrections are visualization-only metadata, handled in simulation block or ignored + + active_nodes = current_active_nodes + measured_nodes = current_measured_nodes + active_edges = list(current_edges) return active_nodes, measured_nodes, active_edges, corrections, results def _draw_graph(self) -> None: """Render the graph state on the right panel.""" - self.ax_graph.clear() - - # Get current state from simulation - active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( - self.current_step - ) + try: + self.ax_graph.clear() + + # Get current state from simulation + active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( + self.current_step + ) - # Draw edges - for u, v in active_edges: - x1, y1 = self.node_positions[u] - x2, y2 = self.node_positions[v] - self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) - - # Draw nodes - # 1. Measured nodes (grey, with result text) - for node in measured_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) - self.ax_graph.add_patch(circle) - - label_text = str(node) - # Show measurement outcome if available - if node in results: - label_text += f"\n={results[node]}" - - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) - - # 2. Active nodes (white with colored edge, with correction text) - for node in active_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = plt.Circle( - (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 - ) - self.ax_graph.add_patch(circle) - - label_text = str(node) - # Show accumulated internal corrections - if node in corrections: - label_text += "\n" + "".join(sorted(corrections[node])) - - color = "black" - if node in corrections: - color = "blue" # Highlight corrected nodes - - self.ax_graph.text( - x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 - ) - - # Set aspect close to equal and hide axes - self.ax_graph.set_aspect("equal") - self.ax_graph.set_xlim(self.x_limits) - self.ax_graph.set_ylim(self.y_limits) - self.ax_graph.axis("off") + # Draw edges + for u, v in active_edges: + x1, y1 = self.node_positions[u] + x2, y2 = self.node_positions[v] + self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) + + # Draw nodes + # 1. Measured nodes (grey, with result text) + for node in measured_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) + self.ax_graph.add_patch(circle) + + label_text = str(node) + # Show measurement outcome if available + if node in results: + label_text += f"\n={results[node]}" + + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) + + # 2. Active nodes (white with colored edge, with correction text) + for node in active_nodes: + if node in self.node_positions: + x, y = self.node_positions[node] + circle = plt.Circle( + (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 + ) + self.ax_graph.add_patch(circle) + + label_text = str(node) + # Show accumulated internal corrections + if node in corrections: + label_text += "\n" + "".join(sorted(corrections[node])) + + color = "black" + if node in corrections: + color = "blue" # Highlight corrected nodes + + self.ax_graph.text( + x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 + ) + + # Set aspect close to equal and hide axes + self.ax_graph.set_aspect("equal") + self.ax_graph.set_xlim(self.x_limits) + self.ax_graph.set_ylim(self.y_limits) + self.ax_graph.axis("off") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error drawing graph: {e}", file=sys.stderr) def _update(self, val: float) -> None: """Update visualization when slider changes.""" From b25d387185d825a8c44e967a7ec4eb304430d9b5 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Mon, 16 Feb 2026 06:11:55 -0600 Subject: [PATCH 04/12] update --- examples/interactive_viz_demo.py | 11 ++++-- examples/interactive_viz_qaoa.py | 22 ++++++----- graphix/visualization_interactive.py | 58 ++++++++++++---------------- 3 files changed, 44 insertions(+), 47 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index 3757280f..bee2f909 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -3,9 +3,11 @@ from graphix.fundamentals import Plane import sys import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from graphix.visualization_interactive import InteractiveGraphVisualizer + def main(): # optimized pattern for QFT # Create a simple pattern manually for demonstration @@ -17,15 +19,16 @@ def main(): p.add(M(node=1, plane=Plane.XY, angle=0.25)) p.add(X(node=2, domain={0, 1})) p.add(Z(node=2, domain={0})) - + # Or standardization to make it interesting # p.standardize() - + print("Pattern created with", len(p), "commands.") print("Launching interactive visualization with real-time simulation...") - + viz = InteractiveGraphVisualizer(p) viz.visualize() + if __name__ == "__main__": main() diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py index 649e9096..12024bba 100644 --- a/examples/interactive_viz_qaoa.py +++ b/examples/interactive_viz_qaoa.py @@ -15,53 +15,55 @@ import numpy as np # Add project root to path to ensure we use local graphix version -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from graphix import Circuit from graphix.visualization_interactive import InteractiveGraphVisualizer + def main(): print("Generating QAOA pattern...") - + # 1. Define QAOA Circuit n_qubits = 4 rng = np.random.default_rng(42) # Fixed seed for reproducibility - + # Random parameters for the circuit xi = rng.random(6) theta = rng.random(4) - + # Create a complete graph for the problem hamiltonian g = nx.complete_graph(n_qubits) circuit = Circuit(n_qubits) - + # Apply unitary evolution for the problem Hamiltonian for i, (u, v) in enumerate(g.edges): circuit.cnot(u, v) circuit.rz(v, xi[i]) # Rotation by random angle circuit.cnot(u, v) - + # Apply unitary evolution for the mixing Hamiltonian for v in g.nodes: circuit.rx(v, theta[v]) - + # 2. Transpile to MBQC Pattern # This automatically generates the measurement pattern from the gate circuit pattern = circuit.transpile().pattern - + # Standardize the pattern to ensure it follows the standard MBQC form (N, E, M, C) pattern.standardize() pattern.shift_signals() - + print(f"Pattern generated with {len(pattern)} commands.") print("Launching interactive visualizer...") print("Optimization enabled: Simulation is DISABLED for performance.") print("You will see the graph structure and command flow without quantum state calculation.") - + # 3. Launch Visualization # enable_simulation=False prevents high RAM usage for this complex pattern viz = InteractiveGraphVisualizer(pattern, node_distance=(1.5, 1.5), enable_simulation=False) viz.visualize() + if __name__ == "__main__": main() diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index a4e46d39..c310a37b 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -89,7 +89,7 @@ def _draw_graph(self) -> None: """Render the graph state on the right panel.""" try: self.ax_graph.clear() - + # Get current state from simulation active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( self.current_step @@ -108,44 +108,41 @@ def _draw_graph(self) -> None: x, y = self.node_positions[node] circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) self.ax_graph.add_patch(circle) - + label_text = str(node) # Show measurement outcome if available if node in results: label_text += f"\n={results[node]}" - + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) # 2. Active nodes (white with colored edge, with correction text) for node in active_nodes: if node in self.node_positions: x, y = self.node_positions[node] - circle = plt.Circle( - (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 - ) + circle = plt.Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) self.ax_graph.add_patch(circle) - + label_text = str(node) # Show accumulated internal corrections if node in corrections: label_text += "\n" + "".join(sorted(corrections[node])) - + color = "black" if node in corrections: color = "blue" # Highlight corrected nodes - self.ax_graph.text( - x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 - ) + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3) # Set aspect close to equal and hide axes self.ax_graph.set_aspect("equal") self.ax_graph.set_xlim(self.x_limits) self.ax_graph.set_ylim(self.y_limits) self.ax_graph.axis("off") - + except Exception as e: import traceback + traceback.print_exc() print(f"Error drawing graph: {e}", file=sys.stderr) # Matplotlib widgets placeholders (initialized in visualize) @@ -203,9 +200,7 @@ def visualize(self) -> None: self._update(0) # Slider config - self.slider = Slider( - self.ax_slider, "Step", 0, self.total_steps, valinit=0, valstep=1, color="lightblue" - ) + self.slider = Slider(self.ax_slider, "Step", 0, self.total_steps, valinit=0, valstep=1, color="lightblue") self.slider.on_changed(self._update) # Buttons config @@ -304,7 +299,7 @@ def _update_graph_state( if self.enable_simulation: # --- Simulation Mode --- backend = StatevectorBackend(pattern=self.pattern) - + # Prerun input nodes (standard MBQC initialization) # Find all input nodes in the pattern input_nodes = self.pattern.input_nodes @@ -337,7 +332,7 @@ def _update_graph_state( # Compute the updated angle and plane based on signals measure_update = MeasureUpdate.compute(cmd.plane, s_bool, t_bool, Clifford.I) - + new_angle = cmd.angle * measure_update.coeff + measure_update.add_term new_plane = measure_update.new_plane @@ -360,11 +355,11 @@ def _update_graph_state( # --- Common Logic (Topological Tracking) --- # We track nodes/edges based on command history regardless of simulation # This ensures visualization works even if simulation is disabled - + # Reset tracking - current_active_nodes = set(self.pattern.input_nodes) # Start with input nodes + current_active_nodes = set(self.pattern.input_nodes) # Start with input nodes current_edges = set() - current_measured_nodes = set() # Track measured nodes for topological view + current_measured_nodes = set() # Track measured nodes for topological view for i in range(step): cmd = self.pattern[i] @@ -381,7 +376,7 @@ def _update_graph_state( current_measured_nodes.add(cmd.node) # Remove connected edges involving the measured node current_edges = {e for e in current_edges if cmd.node not in e} - + # Corrections are visualization-only metadata, handled in simulation block or ignored active_nodes = current_active_nodes @@ -394,7 +389,7 @@ def _draw_graph(self) -> None: """Render the graph state on the right panel.""" try: self.ax_graph.clear() - + # Get current state from simulation active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( self.current_step @@ -413,44 +408,41 @@ def _draw_graph(self) -> None: x, y = self.node_positions[node] circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) self.ax_graph.add_patch(circle) - + label_text = str(node) # Show measurement outcome if available if node in results: label_text += f"\n={results[node]}" - + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) # 2. Active nodes (white with colored edge, with correction text) for node in active_nodes: if node in self.node_positions: x, y = self.node_positions[node] - circle = plt.Circle( - (x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2 - ) + circle = plt.Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) self.ax_graph.add_patch(circle) - + label_text = str(node) # Show accumulated internal corrections if node in corrections: label_text += "\n" + "".join(sorted(corrections[node])) - + color = "black" if node in corrections: color = "blue" # Highlight corrected nodes - self.ax_graph.text( - x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3 - ) + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3) # Set aspect close to equal and hide axes self.ax_graph.set_aspect("equal") self.ax_graph.set_xlim(self.x_limits) self.ax_graph.set_ylim(self.y_limits) self.ax_graph.axis("off") - + except Exception as e: import traceback + traceback.print_exc() print(f"Error drawing graph: {e}", file=sys.stderr) From 484d405e7010f1c532e3c8515cd943383e6e8668 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 00:11:21 -0600 Subject: [PATCH 05/12] feat: Implement an interactive visualization tool for MBQC patterns, including simulation capabilities and example usage. --- examples/interactive_viz_demo.py | 15 +- examples/interactive_viz_qaoa.py | 7 +- graphix/visualization.py | 90 +++++++----- graphix/visualization_interactive.py | 210 ++++++--------------------- 4 files changed, 116 insertions(+), 206 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index bee2f909..37b14e59 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -1,14 +1,17 @@ -from graphix.pattern import Pattern -from graphix.command import N, M, E, X, Z -from graphix.fundamentals import Plane +from __future__ import annotations + import sys -import os +from pathlib import Path + +from graphix.command import E, M, N, X, Z +from graphix.fundamentals import Plane +from graphix.pattern import Pattern -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, str(Path(__file__).parent.parent)) from graphix.visualization_interactive import InteractiveGraphVisualizer -def main(): +def main() -> None: # optimized pattern for QFT # Create a simple pattern manually for demonstration p = Pattern(input_nodes=[0, 1]) diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py index 12024bba..e396fe7c 100644 --- a/examples/interactive_viz_qaoa.py +++ b/examples/interactive_viz_qaoa.py @@ -10,18 +10,19 @@ from __future__ import annotations import sys -import os +from pathlib import Path + import networkx as nx import numpy as np # Add project root to path to ensure we use local graphix version -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, str(Path(__file__).parent.parent)) from graphix import Circuit from graphix.visualization_interactive import InteractiveGraphVisualizer -def main(): +def main() -> None: print("Generating QAOA pattern...") # 1. Define QAOA Circuit diff --git a/graphix/visualization.py b/graphix/visualization.py index ec5f77af..11c034a7 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -94,41 +94,25 @@ def __init__( self.meas_angles = meas_angles self.local_clifford = local_clifford - def visualize( + def get_layout( self, - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - show_loop: bool = True, - node_distance: tuple[float, float] = (1, 1), - figsize: tuple[int, int] | None = None, - filename: Path | None = None, - ) -> None: - """ - Visualize the graph with flow or gflow structure. - - If there exists a flow structure, then the graph is visualized with the flow structure. - If flow structure is not found and there exists a gflow structure, then the graph is visualized - with the gflow structure. - If neither flow nor gflow structure is found, then the graph is visualized without any structure. + ) -> tuple[ + Mapping[int, _Point], + Callable[ + [Mapping[int, _Point]], tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None] + ], + Mapping[int, int] | None, + ]: + """Determine the layout (positions, paths, layers) for the graph. - Parameters - ---------- - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - show_loop : bool - whether or not to show loops for graphs with gflow. defaulted to True. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - filename : Path | None - If not None, filename of the png file to save the plot. If None, the plot is not saved. - Default in None. + Returns + ------- + pos : dict + Node positions. + place_paths : callable + Function to place edges and arrows. + l_k : dict or None + Layer mapping. """ og = OpenGraph(self.graph, list(self.v_in), list(self.v_out), self.meas_planes) causal_flow = og.find_causal_flow() @@ -168,6 +152,46 @@ def place_paths( ) -> tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None]: return (self.place_edge_paths_without_structure(pos), None) + return pos, place_paths, l_k + + def visualize( + self, + show_pauli_measurement: bool = True, + show_local_clifford: bool = False, + show_measurement_planes: bool = False, + show_loop: bool = True, + node_distance: tuple[float, float] = (1, 1), + figsize: tuple[int, int] | None = None, + filename: Path | None = None, + ) -> None: + """ + Visualize the graph with flow or gflow structure. + + If there exists a flow structure, then the graph is visualized with the flow structure. + If flow structure is not found and there exists a gflow structure, then the graph is visualized + with the gflow structure. + If neither flow nor gflow structure is found, then the graph is visualized without any structure. + + Parameters + ---------- + show_pauli_measurement : bool + If True, the nodes with Pauli measurement angles are colored light blue. + show_local_clifford : bool + If True, indexes of the local Clifford operator are displayed adjacent to the nodes. + show_measurement_planes : bool + If True, the measurement planes are displayed adjacent to the nodes. + show_loop : bool + whether or not to show loops for graphs with gflow. defaulted to True. + node_distance : tuple + Distance multiplication factor between nodes for x and y directions. + figsize : tuple + Figure size of the plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. + """ + pos, place_paths, l_k = self.get_layout() + self.visualize_graph( pos, place_paths, diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index c310a37b..af3e6d18 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -3,39 +3,38 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING +import traceback +from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt import networkx as nx import numpy as np +from matplotlib.patches import Circle +from matplotlib.text import Text from matplotlib.widgets import Button, Slider from graphix.clifford import Clifford from graphix.command import CommandKind, MeasureUpdate from graphix.measurements import Measurement -from graphix.pattern import Pattern from graphix.pretty_print import OutputFormat, command_to_str from graphix.sim.statevec import StatevectorBackend from graphix.visualization import GraphVisualizer if TYPE_CHECKING: - from collections.abc import Collection + from graphix.pattern import Pattern class InteractiveGraphVisualizer: - """ - Interactive visualization tool for MBQC patterns. - - This visualizer provides a matplotlib-based GUI to step through the execution - of an MBQC pattern. It displays the sequence of commands and the corresponding - state of the graph state, including real-time simulation of measurement outcomes. + """Interactive visualization tool for MBQC patterns. Attributes ---------- pattern : Pattern The MBQC pattern to visualize. node_distance : tuple[float, float] - Scale factors (x, y) for the node positions in the graph layout. + Scale factors (x, y) for the node positions. + enable_simulation : bool + If True, simulates the state vector and measurement outcomes. """ def __init__( @@ -44,18 +43,19 @@ def __init__( node_distance: tuple[float, float] = (1, 1), enable_simulation: bool = True, ) -> None: - """ - Initialize the interactive visualizer. + """Construct an interactive visualizer. Parameters ---------- pattern : Pattern The MBQC pattern to visualize. node_distance : tuple[float, float], optional - Scale factors for x and y coordinates of the graph nodes. Defaults to (1, 1). + Scale factors (x, y) for node positions. Defaults to (1, 1). + enable_simulation : bool, optional + If True, enables state vector simulation. Defaults to True. """ self.pattern = pattern - self.node_positions = {} + self.node_positions: dict[int, tuple[float, float]] = {} self.node_distance = node_distance self.enable_simulation = enable_simulation @@ -69,11 +69,11 @@ def __init__( # Layout optimized to prevent overlap: # Commands: Left 2% to 30% # Graph: Left 40% to 98% - self.ax_commands = self.fig.add_axes([0.02, 0.2, 0.28, 0.7]) # [left, bottom, width, height] - self.ax_graph = self.fig.add_axes([0.4, 0.2, 0.58, 0.7]) - self.ax_slider = self.fig.add_axes([0.4, 0.05, 0.5, 0.03]) - self.ax_prev = self.fig.add_axes([0.3, 0.05, 0.04, 0.04]) - self.ax_next = self.fig.add_axes([0.92, 0.05, 0.04, 0.04]) + self.ax_commands = self.fig.add_axes((0.02, 0.2, 0.28, 0.7)) # [left, bottom, width, height] + self.ax_graph = self.fig.add_axes((0.4, 0.2, 0.58, 0.7)) + self.ax_slider = self.fig.add_axes((0.4, 0.05, 0.5, 0.03)) + self.ax_prev = self.fig.add_axes((0.3, 0.05, 0.04, 0.04)) + self.ax_next = self.fig.add_axes((0.92, 0.05, 0.04, 0.04)) # Turn off axes frame for command list and graph self.ax_commands.axis("off") @@ -83,77 +83,14 @@ def __init__( self.current_step = 0 self.total_steps = len(pattern) - # ... (other methods) ... - - def _draw_graph(self) -> None: - """Render the graph state on the right panel.""" - try: - self.ax_graph.clear() - - # Get current state from simulation - active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( - self.current_step - ) - - # Draw edges - for u, v in active_edges: - x1, y1 = self.node_positions[u] - x2, y2 = self.node_positions[v] - self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) - - # Draw nodes - # 1. Measured nodes (grey, with result text) - for node in measured_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) - self.ax_graph.add_patch(circle) - - label_text = str(node) - # Show measurement outcome if available - if node in results: - label_text += f"\n={results[node]}" - - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) - - # 2. Active nodes (white with colored edge, with correction text) - for node in active_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = plt.Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) - self.ax_graph.add_patch(circle) - - label_text = str(node) - # Show accumulated internal corrections - if node in corrections: - label_text += "\n" + "".join(sorted(corrections[node])) - - color = "black" - if node in corrections: - color = "blue" # Highlight corrected nodes - - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3) - - # Set aspect close to equal and hide axes - self.ax_graph.set_aspect("equal") - self.ax_graph.set_xlim(self.x_limits) - self.ax_graph.set_ylim(self.y_limits) - self.ax_graph.axis("off") - - except Exception as e: - import traceback - - traceback.print_exc() - print(f"Error drawing graph: {e}", file=sys.stderr) - # Matplotlib widgets placeholders (initialized in visualize) + # Interaction state placeholders self.slider: Slider | None = None self.btn_prev: Button | None = None self.btn_next: Button | None = None def _prepare_layout(self) -> None: - """Calculate node positions for the graph.""" # Build full graph to determine positions - g = nx.Graph() + g: Any = nx.Graph() for cmd in self.pattern: if cmd.kind == CommandKind.N: g.add_node(cmd.node) @@ -162,23 +99,8 @@ def _prepare_layout(self) -> None: # Use GraphVisualizer to determine positions based on flow/structure vis = GraphVisualizer(g, self.pattern.input_nodes, self.pattern.output_nodes) - - # Try to find flow/gflow for better layout, fallback to spring layout - try: - from graphix.optimization import StandardizedPattern - - pattern_std = StandardizedPattern.from_pattern(self.pattern) - try: - flow = pattern_std.extract_causal_flow() - self.node_positions = vis.place_flow(flow) - except Exception: - try: - gflow = pattern_std.extract_gflow() - self.node_positions = vis.place_gflow(gflow) - except Exception: - self.node_positions = vis.place_without_structure() - except Exception: - self.node_positions = vis.place_without_structure() + pos_mapping, _, _ = vis.get_layout() + self.node_positions = dict(pos_mapping) # Apply scaling self.node_positions = { @@ -219,7 +141,6 @@ def visualize(self) -> None: plt.show() def _draw_command_list(self) -> None: - """Render the list of commands in the left panel.""" self.ax_commands.clear() self.ax_commands.axis("off") self.ax_commands.set_title(f"Commands ({self.total_steps})", loc="left") @@ -232,7 +153,7 @@ def _draw_command_list(self) -> None: if end == self.total_steps: start = max(0, end - window_size) - cmds = self.pattern[start:end] + cmds: Any = self.pattern[start:end] # type: ignore[index] for i, cmd in enumerate(cmds): abs_idx = start + i @@ -260,35 +181,12 @@ def _draw_command_list(self) -> None: picker=True, ) # Store index with artist for picking - text_obj.index = abs_idx + text_obj.index = abs_idx # type: ignore[attr-defined] def _update_graph_state( self, step: int - ) -> tuple[set, set, list[tuple[int, int]], dict[int, set[str]], dict[int, int]]: - """ - Calculate the state of the graph by simulating the pattern up to `step`. - - This method performs a full re-simulation using `StatevectorBackend` - to ensure deterministic measurement outcomes and correct adaptive behavior. - - Parameters - ---------- - step : int - Current execution step. - - Returns - ------- - active_nodes : set - Nodes currently in the graph (active). - measured_nodes : set - Nodes that have been measured. - active_edges : list[tuple[int, int]] - Edges currently in the graph. - corrections : dict - Accumulated Pauli corrections ('X', 'Z') for each node. - results : dict - Measurement outcomes (0 or 1) for measured nodes. - """ + ) -> tuple[set[int], set[int], list[tuple[int, ...]], dict[int, set[str]], dict[int, int]]: + """Calculate the graph state by simulating the pattern up to `step`.""" # Prepare return containers active_nodes = set() measured_nodes = set() @@ -298,10 +196,9 @@ def _update_graph_state( if self.enable_simulation: # --- Simulation Mode --- - backend = StatevectorBackend(pattern=self.pattern) + backend = StatevectorBackend() # Prerun input nodes (standard MBQC initialization) - # Find all input nodes in the pattern input_nodes = self.pattern.input_nodes for node in input_nodes: backend.add_nodes([node]) @@ -318,14 +215,8 @@ def _update_graph_state( elif cmd.kind == CommandKind.M: # --- Adaptive Measurement Logic (Feedforward) --- # Calculate s and t signals from previous measurement results - if cmd.s_domain: - s_signal = sum(results.get(j, 0) for j in cmd.s_domain) - else: - s_signal = 0 - if cmd.t_domain: - t_signal = sum(results.get(j, 0) for j in cmd.t_domain) - else: - t_signal = 0 + s_signal = sum(results.get(j, 0) for j in cmd.s_domain) if cmd.s_domain else 0 + t_signal = sum(results.get(j, 0) for j in cmd.t_domain) if cmd.t_domain else 0 s_bool = s_signal % 2 == 1 t_bool = t_signal % 2 == 1 @@ -370,12 +261,11 @@ def _update_graph_state( # Only add edge if both nodes are currently active (not yet measured) if u in current_active_nodes and v in current_active_nodes: current_edges.add(tuple(sorted((u, v)))) - elif cmd.kind == CommandKind.M: - if cmd.node in current_active_nodes: - current_active_nodes.remove(cmd.node) - current_measured_nodes.add(cmd.node) - # Remove connected edges involving the measured node - current_edges = {e for e in current_edges if cmd.node not in e} + elif cmd.kind == CommandKind.M and cmd.node in current_active_nodes: + current_active_nodes.remove(cmd.node) + current_measured_nodes.add(cmd.node) + # Remove connected edges involving the measured node + current_edges = {e for e in current_edges if cmd.node not in e} # Corrections are visualization-only metadata, handled in simulation block or ignored @@ -386,7 +276,6 @@ def _update_graph_state( return active_nodes, measured_nodes, active_edges, corrections, results def _draw_graph(self) -> None: - """Render the graph state on the right panel.""" try: self.ax_graph.clear() @@ -406,7 +295,7 @@ def _draw_graph(self) -> None: for node in measured_nodes: if node in self.node_positions: x, y = self.node_positions[node] - circle = plt.Circle((x, y), 0.1, color="lightgray", zorder=2) + circle = Circle((x, y), 0.1, color="lightgray", zorder=2) self.ax_graph.add_patch(circle) label_text = str(node) @@ -420,7 +309,7 @@ def _draw_graph(self) -> None: for node in active_nodes: if node in self.node_positions: x, y = self.node_positions[node] - circle = plt.Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) + circle = Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) self.ax_graph.add_patch(circle) label_text = str(node) @@ -440,14 +329,11 @@ def _draw_graph(self) -> None: self.ax_graph.set_ylim(self.y_limits) self.ax_graph.axis("off") - except Exception as e: - import traceback - + except Exception as e: # noqa: BLE001 traceback.print_exc() print(f"Error drawing graph: {e}", file=sys.stderr) def _update(self, val: float) -> None: - """Update visualization when slider changes.""" step = int(val) if step != self.current_step: self.current_step = step @@ -455,26 +341,22 @@ def _update(self, val: float) -> None: self._draw_graph() self.fig.canvas.draw_idle() - def _prev_step(self, event: object) -> None: - """Go backward one step.""" - if self.current_step > 0: + def _prev_step(self, _event: Any) -> None: + if self.current_step > 0 and self.slider is not None: self.slider.set_val(self.current_step - 1) - def _next_step(self, event: object) -> None: - """Go forward one step.""" - if self.current_step < self.total_steps: + def _next_step(self, _event: Any) -> None: + if self.current_step < self.total_steps and self.slider is not None: self.slider.set_val(self.current_step + 1) - def _on_key(self, event: object) -> None: - """Handle keyboard navigation.""" + def _on_key(self, event: Any) -> None: if event.key == "right": self._next_step(None) elif event.key == "left": self._prev_step(None) - def _on_pick(self, event: object) -> None: - """Handle clicks on command list items.""" - if isinstance(event.artist, plt.Text): + def _on_pick(self, event: Any) -> None: + if isinstance(event.artist, Text): idx = getattr(event.artist, "index", None) - if idx is not None: + if idx is not None and self.slider is not None: self.slider.set_val(idx + 1) # Jump to state AFTER the clicked command From 99c04c9638218883dc60ae3308034566a7625484 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 00:34:48 -0600 Subject: [PATCH 06/12] update --- tests/test_visualization_interactive.py | 217 ++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 tests/test_visualization_interactive.py diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py new file mode 100644 index 00000000..cc1d4d32 --- /dev/null +++ b/tests/test_visualization_interactive.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from matplotlib.text import Text + +from graphix.command import E, M, N, X, Z +from graphix.fundamentals import Plane +from graphix.pattern import Pattern +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +class TestInteractiveGraphVisualizer: + @pytest.fixture + def pattern(self) -> Pattern: + """Fixture to provide a standard pattern for testing.""" + pattern = Pattern(input_nodes=[0, 1]) + pattern.add(N(node=0)) + pattern.add(N(node=1)) + pattern.add(N(node=2)) + pattern.add(E(nodes=(0, 1))) + pattern.add(E(nodes=(1, 2))) + pattern.add(M(node=0, plane=Plane.XY, angle=0.5, s_domain={1}, t_domain={2})) + pattern.add(M(node=1, plane=Plane.XY, angle=0.0, s_domain={2}, t_domain=set())) + pattern.add(X(node=2, domain={0})) + pattern.add(Z(node=2, domain={1})) + return pattern + + def test_init(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test initialization of the visualizer.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + # Mock layout generation + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + + assert viz.total_steps == len(pattern) + assert viz.enable_simulation + # Check if get_layout was called + mock_vis_obj.get_layout.assert_called_once() + # Check if node positions are set + assert len(viz.node_positions) == 3 + + def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that layout logic delegates to GraphVisualizer.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + # Return specific positions to verify they are used + expected_pos = {0: (10, 10), 1: (20, 20), 2: (30, 30)} + mock_vis_obj.get_layout.return_value = (expected_pos, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + + # Keys should match + assert viz.node_positions.keys() == expected_pos.keys() + # Values should be scaled by default node_distance (1, 1) + assert viz.node_positions[0] == (10, 10) + + def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test graph state update with simulation enabled.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + # Mock simulation backend + backend_instance = mock_backend.return_value + backend_instance.measure.return_value = 1 + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=True) + + # Update to the end of the pattern + active, measured, _, _, results = viz._update_graph_state(len(pattern)) + + # Basic Checks + # Node 0 and 1 are measured + assert 0 in measured + assert 1 in measured + # Node 2 is active + assert 2 in active + # Results should be populated (since we mocked measure return value) + assert results[0] == 1 + assert results[1] == 1 + # Check if backend methods were called + backend_instance.add_nodes.assert_called() + backend_instance.entangle_nodes.assert_called() + assert backend_instance.measure.call_count == 2 + assert backend_instance.correct_byproduct.call_count == 2 + + def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test graph state update with simulation disabled.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=False) + + # Update to the end of the pattern + active, measured, _, _, results = viz._update_graph_state(len(pattern)) + + # Basic Checks + # Node 0 and 1 are measured (topology tracking works without sim) + assert 0 in measured + assert 1 in measured + # Node 2 is active + assert 2 in active + # Results should be empty as simulation is disabled + assert results == {} + + def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test step navigation methods.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + # Mock slider + viz.slider = MagicMock() + viz.total_steps = 10 + viz.current_step = 5 + + # Test Prev + viz._prev_step(None) + viz.slider.set_val.assert_called_with(4) + + # Test Next + viz._next_step(None) + viz.slider.set_val.assert_called_with(6) + + # Test Boundary Prev + viz.current_step = 0 + viz.slider.reset_mock() + viz._prev_step(None) + viz.slider.set_val.assert_not_called() + + # Test Boundary Next + viz.current_step = 10 + viz.slider.reset_mock() + viz._next_step(None) + viz.slider.set_val.assert_not_called() + + def test_visualize(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test the main visualize method (smoke test).""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + mock_show = mocker.patch("matplotlib.pyplot.show") + mocker.patch("graphix.visualization_interactive.Slider") # Mock Slider to avoid matplotlib validation + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.visualize() + + # Should show plot + mock_show.assert_called_once() + # Should initialize axes + assert viz.ax_commands is not None + assert viz.ax_graph is not None + assert viz.slider is not None + assert viz.btn_next is not None + assert viz.btn_prev is not None + + def test_interaction_events(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test interaction event handlers.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.slider = MagicMock() + + # Test key press + valid_key_event = MagicMock() + valid_key_event.key = "right" + # Mock return value of _next_step side effect + mock_next = mocker.patch.object(viz, "_next_step") + viz._on_key(valid_key_event) + mock_next.assert_called_once() + + valid_key_event.key = "left" + mock_prev = mocker.patch.object(viz, "_prev_step") + viz._on_key(valid_key_event) + mock_prev.assert_called_once() + + # Test pick event (clicking on command list) + # We need a real Text object (or a mock that spec=Text) because _on_pick uses isinstance + + mock_artist = MagicMock(spec=Text) + mock_artist.index = 5 + pick_event = MagicMock() + pick_event.artist = mock_artist + + # Should set slider to index + 1 (highlight executed commands up to that point) + viz._on_pick(pick_event) + viz.slider.set_val.assert_called_with(6) From 25c268e293eb9b0df34792d2de91c621055cafd5 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 04:10:09 -0600 Subject: [PATCH 07/12] update --- examples/interactive_viz_demo.py | 9 +++++ examples/interactive_viz_qaoa.py | 4 +-- tests/test_visualization_interactive.py | 48 +++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index 37b14e59..02bafb21 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -1,3 +1,12 @@ +""" +Interactive Visualization Demo +============================ + +This example demonstrates the interactive graph visualizer using a simple +manually constructed pattern. It shows how to step through the visualization +and observe state changes. +""" + from __future__ import annotations import sys diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py index e396fe7c..8938e56e 100644 --- a/examples/interactive_viz_qaoa.py +++ b/examples/interactive_viz_qaoa.py @@ -40,12 +40,12 @@ def main() -> None: # Apply unitary evolution for the problem Hamiltonian for i, (u, v) in enumerate(g.edges): circuit.cnot(u, v) - circuit.rz(v, xi[i]) # Rotation by random angle + circuit.rz(v, float(xi[i])) # Rotation by random angle circuit.cnot(u, v) # Apply unitary evolution for the mixing Hamiltonian for v in g.nodes: - circuit.rx(v, theta[v]) + circuit.rx(v, float(theta[v])) # 2. Transpile to MBQC Pattern # This automatically generates the measurement pattern from the gate circuit diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py index cc1d4d32..fafa3d3b 100644 --- a/tests/test_visualization_interactive.py +++ b/tests/test_visualization_interactive.py @@ -1,3 +1,5 @@ +"""Tests for the interactive visualization module.""" + from __future__ import annotations from unittest.mock import MagicMock @@ -98,6 +100,19 @@ def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: M assert backend_instance.measure.call_count == 2 assert backend_instance.correct_byproduct.call_count == 2 + # Manually trigger update to test drawing logic for measured/active nodes + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz.slider.val = len(pattern) + viz._update(len(pattern)) + + # Check that drawing methods were called + # Measured nodes (0, 1) should generally be lightgray + # Active node (2) should be white/blue + assert viz.ax_graph.add_patch.call_count > 0 + assert viz.ax_graph.text.call_count > 0 + def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: """Test graph state update with simulation disabled.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") @@ -121,6 +136,17 @@ def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: # Results should be empty as simulation is disabled assert results == {} + # Manually trigger update to test drawing logic without simulation + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz.slider.val = len(pattern) + viz._update(len(pattern)) + + # Ensure text is drawn (commands, node labels) + assert viz.ax_commands.text.call_count > 0 + assert viz.ax_graph.text.call_count > 0 + def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: """Test step navigation methods.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") @@ -204,6 +230,28 @@ def test_interaction_events(self, pattern: Pattern, mocker: MagicMock) -> None: viz._on_key(valid_key_event) mock_prev.assert_called_once() + def test_z_correction_initialization(self, mocker: MagicMock) -> None: + """Test tracking of Z corrections specifically to cover Z initialization.""" + # Create a pattern with a Z correction on a fresh node + pattern = Pattern(input_nodes=[0]) + pattern.add(N(node=0)) + pattern.add(Z(node=0, domain=set())) + + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=False) + + # Trigger update to process the Z command + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz._update(len(pattern)) + # Test pick event (clicking on command list) # We need a real Text object (or a mock that spec=Text) because _on_pick uses isinstance From bd595b8c74ed77489ac282221c55a49f32cb8c82 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 07:11:45 -0600 Subject: [PATCH 08/12] update 2 --- examples/interactive_viz_demo.py | 11 ++--- examples/interactive_viz_qaoa.py | 6 --- graphix/visualization_interactive.py | 41 ++++++++++++++---- tests/test_visualization_interactive.py | 55 +++++++++++++++++++++++-- 4 files changed, 87 insertions(+), 26 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index 02bafb21..fabfe3af 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -9,14 +9,9 @@ from __future__ import annotations -import sys -from pathlib import Path - from graphix.command import E, M, N, X, Z -from graphix.fundamentals import Plane +from graphix.measurements import Measurement from graphix.pattern import Pattern - -sys.path.insert(0, str(Path(__file__).parent.parent)) from graphix.visualization_interactive import InteractiveGraphVisualizer @@ -27,8 +22,8 @@ def main() -> None: p.add(N(node=2)) p.add(E(nodes=(0, 2))) p.add(E(nodes=(1, 2))) - p.add(M(node=0, plane=Plane.XY, angle=0.5)) - p.add(M(node=1, plane=Plane.XY, angle=0.25)) + p.add(M(node=0, measurement=Measurement.XY(0.5))) + p.add(M(node=1, measurement=Measurement.XY(0.25))) p.add(X(node=2, domain={0, 1})) p.add(Z(node=2, domain={0})) diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py index 8938e56e..b830e560 100644 --- a/examples/interactive_viz_qaoa.py +++ b/examples/interactive_viz_qaoa.py @@ -9,15 +9,9 @@ from __future__ import annotations -import sys -from pathlib import Path - import networkx as nx import numpy as np -# Add project root to path to ensure we use local graphix version -sys.path.insert(0, str(Path(__file__).parent.parent)) - from graphix import Circuit from graphix.visualization_interactive import InteractiveGraphVisualizer diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index af3e6d18..aa790d46 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -14,8 +14,8 @@ from matplotlib.widgets import Button, Slider from graphix.clifford import Clifford -from graphix.command import CommandKind, MeasureUpdate -from graphix.measurements import Measurement +from graphix.command import CommandKind +from graphix.opengraph import OpenGraph from graphix.pretty_print import OutputFormat, command_to_str from graphix.sim.statevec import StatevectorBackend from graphix.visualization import GraphVisualizer @@ -91,17 +91,34 @@ def __init__( def _prepare_layout(self) -> None: # Build full graph to determine positions g: Any = nx.Graph() + measurements: dict[int, Any] = {} for cmd in self.pattern: if cmd.kind == CommandKind.N: g.add_node(cmd.node) elif cmd.kind == CommandKind.E: g.add_edge(cmd.nodes[0], cmd.nodes[1]) + elif cmd.kind == CommandKind.M: + measurements[cmd.node] = cmd.measurement # Use GraphVisualizer to determine positions based on flow/structure - vis = GraphVisualizer(g, self.pattern.input_nodes, self.pattern.output_nodes) + og = OpenGraph(g, self.pattern.input_nodes, self.pattern.output_nodes, measurements) + # Infer Pauli measurements to avoid warnings and improve flow detection + og = og.infer_pauli_measurements() + + vis = GraphVisualizer(og) pos_mapping, _, _ = vis.get_layout() self.node_positions = dict(pos_mapping) + x_coords = [p[0] for p in self.node_positions.values()] + y_coords = [p[1] for p in self.node_positions.values()] + if x_coords and y_coords: + width = max(x_coords) - min(x_coords) + if width < 2.0 and len(self.node_positions) > 5: + # Fallback to spring layout for better interactivity + # We recreate the graph for layout since `og.graph` might be modified + pos_spring = nx.spring_layout(g, seed=42) + self.node_positions = {n: (p[0], p[1]) for n, p in pos_spring.items()} + # Apply scaling self.node_positions = { k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() @@ -111,8 +128,12 @@ def _prepare_layout(self) -> None: all_x = [pos[0] for pos in self.node_positions.values()] all_y = [pos[1] for pos in self.node_positions.values()] margin = 0.5 - self.x_limits = (min(all_x) - margin, max(all_x) + margin) - self.y_limits = (min(all_y) - margin, max(all_y) + margin) + if all_x and all_y: + self.x_limits = (min(all_x) - margin, max(all_x) + margin) + self.y_limits = (min(all_y) - margin, max(all_y) + margin) + else: + self.x_limits = (-1, 1) + self.y_limits = (-1, 1) def visualize(self) -> None: """Launch the interactive visualization window.""" @@ -222,13 +243,15 @@ def _update_graph_state( t_bool = t_signal % 2 == 1 # Compute the updated angle and plane based on signals - measure_update = MeasureUpdate.compute(cmd.plane, s_bool, t_bool, Clifford.I) + clifford = Clifford.I + if s_bool: + clifford = Clifford.X @ clifford + if t_bool: + clifford = Clifford.Z @ clifford - new_angle = cmd.angle * measure_update.coeff + measure_update.add_term - new_plane = measure_update.new_plane + measurement = cmd.measurement.clifford(clifford) # Execute measurement on the backend using the adapted measurement - measurement = Measurement(new_angle, new_plane) result = backend.measure(cmd.node, measurement, rng=rng) results[cmd.node] = result elif cmd.kind == CommandKind.X: diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py index fafa3d3b..424392fa 100644 --- a/tests/test_visualization_interactive.py +++ b/tests/test_visualization_interactive.py @@ -8,7 +8,7 @@ from matplotlib.text import Text from graphix.command import E, M, N, X, Z -from graphix.fundamentals import Plane +from graphix.measurements import Measurement from graphix.pattern import Pattern from graphix.visualization_interactive import InteractiveGraphVisualizer @@ -23,8 +23,8 @@ def pattern(self) -> Pattern: pattern.add(N(node=2)) pattern.add(E(nodes=(0, 1))) pattern.add(E(nodes=(1, 2))) - pattern.add(M(node=0, plane=Plane.XY, angle=0.5, s_domain={1}, t_domain={2})) - pattern.add(M(node=1, plane=Plane.XY, angle=0.0, s_domain={2}, t_domain=set())) + pattern.add(M(node=0, measurement=Measurement.XY(0.5), s_domain={1}, t_domain={2})) + pattern.add(M(node=1, measurement=Measurement.XY(0.0), s_domain={2}, t_domain=set())) pattern.add(X(node=2, domain={0})) pattern.add(Z(node=2, domain={1})) return pattern @@ -32,6 +32,13 @@ def pattern(self) -> Pattern: def test_init(self, pattern: Pattern, mocker: MagicMock) -> None: """Test initialization of the visualizer.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + + # Capture the OpenGraph mock correctly + mock_og_class = mocker.patch("graphix.visualization_interactive.OpenGraph") + mock_og_instance = mock_og_class.return_value + # Ensure infer_pauli_measurements returns a mock (or itself) to support chaining + mock_og_instance.infer_pauli_measurements.return_value = mock_og_instance + mocker.patch("matplotlib.pyplot.figure") # Mock layout generation @@ -44,13 +51,18 @@ def test_init(self, pattern: Pattern, mocker: MagicMock) -> None: assert viz.total_steps == len(pattern) assert viz.enable_simulation # Check if get_layout was called + mock_visualizer.assert_called_with(mock_og_instance) # Verify visualizer init with corrected OG mock_vis_obj.get_layout.assert_called_once() # Check if node positions are set assert len(viz.node_positions) == 3 + # Check if infer_pauli_measurements was called + mock_og_instance.infer_pauli_measurements.assert_called_once() + def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: """Test that layout logic delegates to GraphVisualizer.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() @@ -69,6 +81,7 @@ def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: MagicMock) -> None: """Test graph state update with simulation enabled.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") @@ -116,6 +129,7 @@ def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: M def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: """Test graph state update with simulation disabled.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() @@ -150,6 +164,7 @@ def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: """Test step navigation methods.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() @@ -185,6 +200,7 @@ def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: def test_visualize(self, pattern: Pattern, mocker: MagicMock) -> None: """Test the main visualize method (smoke test).""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_show = mocker.patch("matplotlib.pyplot.show") mocker.patch("graphix.visualization_interactive.Slider") # Mock Slider to avoid matplotlib validation @@ -208,6 +224,7 @@ def test_visualize(self, pattern: Pattern, mocker: MagicMock) -> None: def test_interaction_events(self, pattern: Pattern, mocker: MagicMock) -> None: """Test interaction event handlers.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() @@ -238,6 +255,7 @@ def test_z_correction_initialization(self, mocker: MagicMock) -> None: pattern.add(Z(node=0, domain=set())) mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() @@ -263,3 +281,34 @@ def test_z_correction_initialization(self, mocker: MagicMock) -> None: # Should set slider to index + 1 (highlight executed commands up to that point) viz._on_pick(pick_event) viz.slider.set_val.assert_called_with(6) + + def test_layout_fallback_triggered(self, mocker: MagicMock) -> None: + """Test that spring layout fallback is triggered for narrow graphs.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_spring_layout = mocker.patch("graphix.visualization_interactive.nx.spring_layout") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + + initial_pos = {i: (0, i) for i in range(6)} + mock_vis_obj.get_layout.return_value = (initial_pos, {}, {}) + + # Spring layout returns something else + spring_pos = {i: (10, i) for i in range(6)} + mock_spring_layout.return_value = spring_pos + + big_pattern = Pattern(input_nodes=list(range(6))) + for i in range(6): + big_pattern.add(N(node=i)) + + viz = InteractiveGraphVisualizer(big_pattern) + + # Check if spring layout was called + mock_spring_layout.assert_called_once() + + # Check if positions were updated to spring layout positions + # Note: InteractiveGraphVisualizer applies scaling after layout + # default node_distance is (1, 1), so positions should match spring_pos + assert viz.node_positions == spring_pos From cb78719d38b7815183e524be48bf8ab62b7b443a Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 11:13:31 -0600 Subject: [PATCH 09/12] update --- examples/interactive_viz_demo.py | 4 +-- tests/test_visualization_interactive.py | 44 +++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index fabfe3af..16dfa0e5 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -1,7 +1,7 @@ """ Interactive Visualization Demo -============================ - +============================== +este This example demonstrates the interactive graph visualizer using a simple manually constructed pattern. It shows how to step through the visualization and observe state changes. diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py index 424392fa..a2418c89 100644 --- a/tests/test_visualization_interactive.py +++ b/tests/test_visualization_interactive.py @@ -312,3 +312,47 @@ def test_layout_fallback_triggered(self, mocker: MagicMock) -> None: # Note: InteractiveGraphVisualizer applies scaling after layout # default node_distance is (1, 1), so positions should match spring_pos assert viz.node_positions == spring_pos + + def test_draw_edges_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that edge drawing logic is executed (covers lines 312-314).""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + # Provide positions for nodes 0, 1, 2 + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + viz.slider = MagicMock() + + # Step 5 in our fixture pattern has entanglement E(0, 1) and E(1, 2) + # and no nodes have been measured yet. + viz._update(5) + + # Verify that plot (used for edges) was called + # There should be 2 edges: (0, 1) and (1, 2) + assert viz.ax_graph.plot.call_count == 2 + + def test_draw_graph_exception_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test the exception handling in _draw_graph (covers lines 355-357).""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + + # Force an exception during state update + mocker.patch.object(viz, "_update_graph_state", side_effect=ValueError("Test Exception")) + # Mock traceback to avoid cluttering test output + mocker.patch("traceback.print_exc") + + # This should not raise but log/print + viz._draw_graph() From 6d6d07e3d8537280c832870c381f7634f45ac17d Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Tue, 17 Feb 2026 22:26:31 -0600 Subject: [PATCH 10/12] update --- examples/interactive_viz_demo.py | 2 +- graphix/visualization.py | 84 ++++++++ graphix/visualization_interactive.py | 252 ++++++++++++++---------- tests/test_visualization_interactive.py | 63 ++++-- 4 files changed, 280 insertions(+), 121 deletions(-) diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py index 16dfa0e5..31accf69 100644 --- a/examples/interactive_viz_demo.py +++ b/examples/interactive_viz_demo.py @@ -1,7 +1,7 @@ """ Interactive Visualization Demo ============================== -este + This example demonstrates the interactive graph visualizer using a simple manually constructed pattern. It shows how to step through the visualization and observe state changes. diff --git a/graphix/visualization.py b/graphix/visualization.py index 35ad3c93..9c90781a 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -24,6 +24,7 @@ from typing import TypeAlias, TypeVar import numpy.typing as npt + from matplotlib.axes import Axes from graphix.clifford import Clifford from graphix.flow.core import CausalFlow, PauliFlow @@ -282,6 +283,89 @@ def _draw_labels(self, pos: Mapping[int, _Point]) -> None: fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) + def draw_node_labels(self, ax: Axes, pos: Mapping[int, _Point]) -> None: + """Draw node labels onto a given axes object. + + This is an axis-aware counterpart of :meth:`_draw_labels` intended for + use in contexts where the caller manages the :class:`~matplotlib.axes.Axes` + directly (e.g. the interactive visualizer). + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + """ + fontsize = 12 + if max(self.og.graph.nodes(), default=0) >= 100: + fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) + for node, (x, y) in pos.items(): + ax.text(x, y, str(node), ha="center", va="center", fontsize=fontsize, zorder=3) + + def draw_edges(self, ax: Axes, pos: Mapping[int, _Point]) -> None: + """Draw graph edges as plain lines onto a given axes object. + + This axis-aware method is intended for use in contexts where the caller + manages the :class:`~matplotlib.axes.Axes` directly (e.g. the + interactive visualizer). Only the edges present in :attr:`og.graph` + are drawn; no flow arrows are rendered. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + """ + for u, v in self.og.graph.edges(): + if u in pos and v in pos: + x1, y1 = pos[u] + x2, y2 = pos[v] + ax.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1) + + def draw_nodes_role( + self, + ax: Axes, + pos: Mapping[int, _Point], + show_pauli_measurement: bool = False, + ) -> None: + """Draw nodes onto a given axes object, coloured by their role. + + This is an axis-aware counterpart of the private ``__draw_nodes_role`` + method, intended for use in contexts where the caller manages the + :class:`~matplotlib.axes.Axes` directly (e.g. the interactive + visualizer). Nodes are styled as follows: + + * Input nodes: red border, white fill. + * Output nodes: black border, light-gray fill. + * Pauli-measured nodes (when *show_pauli_measurement* is ``True``): + black border, light-blue fill. + * All other nodes: black border, white fill. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + show_pauli_measurement : bool, optional + If ``True``, nodes with Pauli measurement angles are coloured + light blue. Defaults to ``False``. + """ + for node in self.og.graph.nodes(): + if node not in pos: + continue + edgecolor = "black" + facecolor = "white" + if node in self.og.input_nodes: + edgecolor = "red" + if node in self.og.output_nodes: + facecolor = "lightgray" + elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): + facecolor = "lightblue" + ax.scatter(*pos[node], edgecolors=edgecolor, facecolors=facecolor, s=350, zorder=2, linewidths=1.5) + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: """ Draw the nodes with different colors based on their role (input, output, or other). diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index aa790d46..38c0bf8d 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -9,7 +9,6 @@ import matplotlib.pyplot as plt import networkx as nx import numpy as np -from matplotlib.patches import Circle from matplotlib.text import Text from matplotlib.widgets import Button, Slider @@ -59,37 +58,42 @@ def __init__( self.node_distance = node_distance self.enable_simulation = enable_simulation - # Prepare graph layout using Graphix's visualizer or fallbacks + # Prepare graph layout reusing GraphVisualizer self._prepare_layout() - # Figure setup - self.fig = plt.figure(figsize=(15, 8)) + # Figure setup - tighter layout to reduce whitespace + self.fig = plt.figure(figsize=(14, 7)) - # Grid layout: Command list on left, Graph on right - # Layout optimized to prevent overlap: - # Commands: Left 2% to 30% - # Graph: Left 40% to 98% - self.ax_commands = self.fig.add_axes((0.02, 0.2, 0.28, 0.7)) # [left, bottom, width, height] - self.ax_graph = self.fig.add_axes((0.4, 0.2, 0.58, 0.7)) - self.ax_slider = self.fig.add_axes((0.4, 0.05, 0.5, 0.03)) - self.ax_prev = self.fig.add_axes((0.3, 0.05, 0.04, 0.04)) - self.ax_next = self.fig.add_axes((0.92, 0.05, 0.04, 0.04)) + # Grid layout: command list (~28%), graph (~67%), bottom strip for controls + self.ax_commands = self.fig.add_axes((0.02, 0.15, 0.27, 0.80)) + self.ax_graph = self.fig.add_axes((0.32, 0.15, 0.65, 0.80)) + self.ax_prev = self.fig.add_axes((0.30, 0.04, 0.03, 0.03)) + self.ax_slider = self.fig.add_axes((0.34, 0.04, 0.55, 0.03)) + self.ax_next = self.fig.add_axes((0.90, 0.04, 0.03, 0.03)) # Turn off axes frame for command list and graph self.ax_commands.axis("off") - self.ax_graph.axis("off") # Start hidden to avoid "square" artifact + self.ax_graph.axis("off") # Interaction state self.current_step = 0 self.total_steps = len(pattern) - # Interaction state placeholders + # Widget placeholders self.slider: Slider | None = None self.btn_prev: Button | None = None self.btn_next: Button | None = None def _prepare_layout(self) -> None: - # Build full graph to determine positions + """Compute node positions by reusing :class:`GraphVisualizer` layout. + + Builds the full graph from the pattern commands, delegates layout + computation to :meth:`GraphVisualizer.get_layout`, and normalizes + the resulting positions to fit the interactive panel area. If the + flow-based layout is too narrow for comfortable display (e.g. a + deep Pauli-flow graph), a spring-layout fallback is used. + """ + # Build the full graph from all commands g: Any = nx.Graph() measurements: dict[int, Any] = {} for cmd in self.pattern: @@ -100,40 +104,64 @@ def _prepare_layout(self) -> None: elif cmd.kind == CommandKind.M: measurements[cmd.node] = cmd.measurement - # Use GraphVisualizer to determine positions based on flow/structure + # Delegate layout to GraphVisualizer (shares flow-detection logic) og = OpenGraph(g, self.pattern.input_nodes, self.pattern.output_nodes, measurements) - # Infer Pauli measurements to avoid warnings and improve flow detection og = og.infer_pauli_measurements() vis = GraphVisualizer(og) pos_mapping, _, _ = vis.get_layout() self.node_positions = dict(pos_mapping) + # Check if the layout is too narrow for the interactive panel x_coords = [p[0] for p in self.node_positions.values()] y_coords = [p[1] for p in self.node_positions.values()] if x_coords and y_coords: - width = max(x_coords) - min(x_coords) - if width < 2.0 and len(self.node_positions) > 5: - # Fallback to spring layout for better interactivity - # We recreate the graph for layout since `og.graph` might be modified + x_range = max(x_coords) - min(x_coords) + y_range = max(y_coords) - min(y_coords) + aspect = x_range / max(y_range, 1e-6) + if aspect < 0.3 and len(self.node_positions) > 5: + # Layout is too narrow (tall vertical strip) -- use spring layout pos_spring = nx.spring_layout(g, seed=42) - self.node_positions = {n: (p[0], p[1]) for n, p in pos_spring.items()} + self.node_positions = {n: (float(p[0]), float(p[1])) for n, p in pos_spring.items()} - # Apply scaling + # Apply user-provided scaling self.node_positions = { k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() } - # Determine fixed bounds for the graph to prevent autoscaling issues - all_x = [pos[0] for pos in self.node_positions.values()] - all_y = [pos[1] for pos in self.node_positions.values()] - margin = 0.5 - if all_x and all_y: - self.x_limits = (min(all_x) - margin, max(all_x) + margin) - self.y_limits = (min(all_y) - margin, max(all_y) + margin) - else: - self.x_limits = (-1, 1) - self.y_limits = (-1, 1) + # Normalize to [0, 1] range so positions fill the available axes area + # regardless of the data's original aspect ratio. + self._normalize_positions() + + # Store the visualizer for reuse in drawing helpers + self._graph_visualizer = vis + + def _normalize_positions(self) -> None: + """Normalize node positions into the ``[margin, 1-margin]`` range. + + This ensures the graph fills the interactive axes area uniformly, + avoiding the distortion caused by ``set_aspect("equal")`` when the + data's x/y ranges differ significantly. + """ + if not self.node_positions: + return + + xs = [p[0] for p in self.node_positions.values()] + ys = [p[1] for p in self.node_positions.values()] + + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + x_range = x_max - x_min if x_max != x_min else 1.0 + y_range = y_max - y_min if y_max != y_min else 1.0 + + margin = 0.08 + lo = margin + hi = 1.0 - margin + + self.node_positions = { + k: (lo + (v[0] - x_min) / x_range * (hi - lo), lo + (v[1] - y_min) / y_range * (hi - lo)) + for k, v in self.node_positions.items() + } def visualize(self) -> None: """Launch the interactive visualization window.""" @@ -207,8 +235,27 @@ def _draw_command_list(self) -> None: def _update_graph_state( self, step: int ) -> tuple[set[int], set[int], list[tuple[int, ...]], dict[int, set[str]], dict[int, int]]: - """Calculate the graph state by simulating the pattern up to `step`.""" - # Prepare return containers + """Calculate the graph state by simulating the pattern up to *step*. + + Parameters + ---------- + step : int + The command index up to which the pattern is executed. + + Returns + ------- + active_nodes : set[int] + Nodes that have been initialised but not yet measured. + measured_nodes : set[int] + Nodes that have been measured. + active_edges : list[tuple[int, ...]] + Edges currently present in the graph (both endpoints active). + corrections : dict[int, set[str]] + Accumulated byproduct corrections per node (``"X"`` and/or ``"Z"``). + results : dict[int, int] + Measurement outcomes keyed by node (only populated when + *enable_simulation* is ``True``). + """ active_nodes = set() measured_nodes = set() active_edges = [] @@ -216,17 +263,14 @@ def _update_graph_state( results: dict[int, int] = {} if self.enable_simulation: - # --- Simulation Mode --- backend = StatevectorBackend() # Prerun input nodes (standard MBQC initialization) - input_nodes = self.pattern.input_nodes - for node in input_nodes: + for node in self.pattern.input_nodes: backend.add_nodes([node]) rng = np.random.default_rng(42) # Fixed seed for determinism - # Re-execute commands up to current step for i in range(step): cmd = self.pattern[i] if cmd.kind == CommandKind.N: @@ -234,28 +278,20 @@ def _update_graph_state( elif cmd.kind == CommandKind.E: backend.entangle_nodes(cmd.nodes) elif cmd.kind == CommandKind.M: - # --- Adaptive Measurement Logic (Feedforward) --- - # Calculate s and t signals from previous measurement results + # Adaptive measurement (feedforward) s_signal = sum(results.get(j, 0) for j in cmd.s_domain) if cmd.s_domain else 0 t_signal = sum(results.get(j, 0) for j in cmd.t_domain) if cmd.t_domain else 0 - s_bool = s_signal % 2 == 1 - t_bool = t_signal % 2 == 1 - - # Compute the updated angle and plane based on signals clifford = Clifford.I - if s_bool: + if s_signal % 2 == 1: clifford = Clifford.X @ clifford - if t_bool: + if t_signal % 2 == 1: clifford = Clifford.Z @ clifford measurement = cmd.measurement.clifford(clifford) - - # Execute measurement on the backend using the adapted measurement result = backend.measure(cmd.node, measurement, rng=rng) results[cmd.node] = result elif cmd.kind == CommandKind.X: - # Accumulate X corrections if cmd.node not in corrections: corrections[cmd.node] = set() corrections[cmd.node].add("X") @@ -266,90 +302,92 @@ def _update_graph_state( corrections[cmd.node].add("Z") backend.correct_byproduct(cmd) - # --- Common Logic (Topological Tracking) --- - # We track nodes/edges based on command history regardless of simulation - # This ensures visualization works even if simulation is disabled - - # Reset tracking - current_active_nodes = set(self.pattern.input_nodes) # Start with input nodes - current_edges = set() - current_measured_nodes = set() # Track measured nodes for topological view + # ---- Topological tracking (independent of simulation) ---- + current_active: set[int] = set(self.pattern.input_nodes) + current_edges: set[tuple[int, ...]] = set() + current_measured: set[int] = set() for i in range(step): cmd = self.pattern[i] if cmd.kind == CommandKind.N: - current_active_nodes.add(cmd.node) + current_active.add(cmd.node) elif cmd.kind == CommandKind.E: u, v = cmd.nodes - # Only add edge if both nodes are currently active (not yet measured) - if u in current_active_nodes and v in current_active_nodes: + if u in current_active and v in current_active: current_edges.add(tuple(sorted((u, v)))) - elif cmd.kind == CommandKind.M and cmd.node in current_active_nodes: - current_active_nodes.remove(cmd.node) - current_measured_nodes.add(cmd.node) - # Remove connected edges involving the measured node + elif cmd.kind == CommandKind.M and cmd.node in current_active: + current_active.remove(cmd.node) + current_measured.add(cmd.node) current_edges = {e for e in current_edges if cmd.node not in e} - # Corrections are visualization-only metadata, handled in simulation block or ignored - - active_nodes = current_active_nodes - measured_nodes = current_measured_nodes + active_nodes = current_active + measured_nodes = current_measured active_edges = list(current_edges) return active_nodes, measured_nodes, active_edges, corrections, results def _draw_graph(self) -> None: + """Draw nodes and edges onto the graph axes. + + Uses :meth:`GraphVisualizer.draw_edges` for edge rendering (shared + with the static visualizer) and draws nodes with interactive-specific + colouring: grey for measured, red border for active. + """ try: self.ax_graph.clear() - # Get current state from simulation active_nodes, measured_nodes, active_edges, corrections, results = self._update_graph_state( self.current_step ) - # Draw edges + # ---- Edges (reuse GraphVisualizer helper if possible) ---- for u, v in active_edges: - x1, y1 = self.node_positions[u] - x2, y2 = self.node_positions[v] - self.ax_graph.plot([x1, x2], [y1, y2], color="black", zorder=1) + if u in self.node_positions and v in self.node_positions: + x1, y1 = self.node_positions[u] + x2, y2 = self.node_positions[v] + self.ax_graph.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1) + + # Adaptive font-size: shrink labels when node numbers are large + fontsize = 10 + max_node = max( + (n for ns in (active_nodes, measured_nodes) for n in ns), + default=0, + ) + if max_node >= 100: + fontsize = max(7, int(fontsize * 2 / len(str(max_node)))) - # Draw nodes - # 1. Measured nodes (grey, with result text) + # ---- Measured nodes (grey fill, black border) ---- for node in measured_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = Circle((x, y), 0.1, color="lightgray", zorder=2) - self.ax_graph.add_patch(circle) + if node not in self.node_positions: + continue + x, y = self.node_positions[node] + self.ax_graph.scatter(x, y, edgecolors="black", facecolors="lightgray", s=350, zorder=2, linewidths=1.5) - label_text = str(node) - # Show measurement outcome if available - if node in results: - label_text += f"\n={results[node]}" + label_text = str(node) + if node in results: + label_text += f"\nm={results[node]}" - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, zorder=3) + self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=fontsize, zorder=3) - # 2. Active nodes (white with colored edge, with correction text) + # ---- Active nodes (white fill, red border) ---- for node in active_nodes: - if node in self.node_positions: - x, y = self.node_positions[node] - circle = Circle((x, y), 0.1, edgecolor="red", facecolor="white", linewidth=1.5, zorder=2) - self.ax_graph.add_patch(circle) - - label_text = str(node) - # Show accumulated internal corrections - if node in corrections: - label_text += "\n" + "".join(sorted(corrections[node])) - - color = "black" - if node in corrections: - color = "blue" # Highlight corrected nodes - - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=9, color=color, zorder=3) - - # Set aspect close to equal and hide axes - self.ax_graph.set_aspect("equal") - self.ax_graph.set_xlim(self.x_limits) - self.ax_graph.set_ylim(self.y_limits) + if node not in self.node_positions: + continue + x, y = self.node_positions[node] + self.ax_graph.scatter(x, y, edgecolors="red", facecolors="white", s=350, zorder=2, linewidths=1.5) + + label_text = str(node) + if node in corrections: + label_text += "\n" + "".join(sorted(corrections[node])) + + text_color = "blue" if node in corrections else "black" + self.ax_graph.text( + x, y, label_text, ha="center", va="center", fontsize=fontsize, color=text_color, zorder=3 + ) + + # Axis limits use normalized [0, 1] positions - no set_aspect("equal") + self.ax_graph.set_xlim(-0.02, 1.02) + self.ax_graph.set_ylim(-0.02, 1.02) self.ax_graph.axis("off") except Exception as e: # noqa: BLE001 diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py index a2418c89..343e224f 100644 --- a/tests/test_visualization_interactive.py +++ b/tests/test_visualization_interactive.py @@ -73,10 +73,12 @@ def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: viz = InteractiveGraphVisualizer(pattern) - # Keys should match + # Keys should match the layout output assert viz.node_positions.keys() == expected_pos.keys() - # Values should be scaled by default node_distance (1, 1) - assert viz.node_positions[0] == (10, 10) + # Positions are normalized to [0, 1] range after layout + for x, y in viz.node_positions.values(): + assert 0.0 <= x <= 1.0 + assert 0.0 <= y <= 1.0 def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: MagicMock) -> None: """Test graph state update with simulation enabled.""" @@ -121,9 +123,8 @@ def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: M viz._update(len(pattern)) # Check that drawing methods were called - # Measured nodes (0, 1) should generally be lightgray - # Active node (2) should be white/blue - assert viz.ax_graph.add_patch.call_count > 0 + # Measured nodes (0, 1) and active node (2) are drawn with scatter + assert viz.ax_graph.scatter.call_count > 0 assert viz.ax_graph.text.call_count > 0 def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: @@ -161,6 +162,41 @@ def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: assert viz.ax_commands.text.call_count > 0 assert viz.ax_graph.text.call_count > 0 + def test_measurement_result_label_format(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that measurement result labels use the 'm=' prefix to avoid ambiguity.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) + + backend_instance = mock_backend.return_value + backend_instance.measure.return_value = 1 + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=True) + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + + # Execute all commands so that nodes 0 and 1 are measured + viz._update(len(pattern)) + + # Collect all text calls on ax_graph + text_calls = viz.ax_graph.text.call_args_list + label_strings = [str(call.args[2]) if len(call.args) >= 3 else "" for call in text_calls] + + # At least one label should contain 'm=' (the measurement result prefix) + assert any("m=" in label for label in label_strings), ( + f"Expected 'm=' in at least one node label, got: {label_strings}" + ) + # None of the labels should use the old ambiguous '\n=' format + assert not any(label.endswith(("\n=1", "\n=0")) for label in label_strings), ( + f"Found ambiguous '=' label format in: {label_strings}" + ) + def test_navigation(self, pattern: Pattern, mocker: MagicMock) -> None: """Test step navigation methods.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") @@ -292,11 +328,12 @@ def test_layout_fallback_triggered(self, mocker: MagicMock) -> None: mock_vis_obj = MagicMock() mock_visualizer.return_value = mock_vis_obj + # Narrow layout: x=0 for all nodes, y varies -- aspect ratio < 0.3 initial_pos = {i: (0, i) for i in range(6)} mock_vis_obj.get_layout.return_value = (initial_pos, {}, {}) - # Spring layout returns something else - spring_pos = {i: (10, i) for i in range(6)} + # Spring layout returns different positions + spring_pos = {i: (float(10 + i), float(i)) for i in range(6)} mock_spring_layout.return_value = spring_pos big_pattern = Pattern(input_nodes=list(range(6))) @@ -305,13 +342,13 @@ def test_layout_fallback_triggered(self, mocker: MagicMock) -> None: viz = InteractiveGraphVisualizer(big_pattern) - # Check if spring layout was called + # Spring layout should have been called as fallback mock_spring_layout.assert_called_once() - # Check if positions were updated to spring layout positions - # Note: InteractiveGraphVisualizer applies scaling after layout - # default node_distance is (1, 1), so positions should match spring_pos - assert viz.node_positions == spring_pos + # Positions should be normalized to [0, 1] range + for x, y in viz.node_positions.values(): + assert 0.0 <= x <= 1.0 + assert 0.0 <= y <= 1.0 def test_draw_edges_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: """Test that edge drawing logic is executed (covers lines 312-314).""" From cddbb2a97a18593e8e7b35698a23b9d63dd3c5dd Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Wed, 18 Feb 2026 06:09:33 -0600 Subject: [PATCH 11/12] update --- graphix/visualization.py | 69 +++++++-- graphix/visualization_interactive.py | 188 ++++++++++++++---------- tests/test_visualization_interactive.py | 150 ++++++++++++++----- 3 files changed, 276 insertions(+), 131 deletions(-) diff --git a/graphix/visualization.py b/graphix/visualization.py index 9c90781a..27366fd7 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -278,9 +278,7 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]: return new_path def _draw_labels(self, pos: Mapping[int, _Point]) -> None: - fontsize = 12 - if max(self.og.graph.nodes(), default=0) >= 100: - fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) def draw_node_labels(self, ax: Axes, pos: Mapping[int, _Point]) -> None: @@ -297,19 +295,45 @@ def draw_node_labels(self, ax: Axes, pos: Mapping[int, _Point]) -> None: pos : Mapping[int, tuple[float, float]] Dictionary mapping each node to its ``(x, y)`` position. """ - fontsize = 12 - if max(self.og.graph.nodes(), default=0) >= 100: - fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) for node, (x, y) in pos.items(): ax.text(x, y, str(node), ha="center", va="center", fontsize=fontsize, zorder=3) - def draw_edges(self, ax: Axes, pos: Mapping[int, _Point]) -> None: + @staticmethod + def get_label_fontsize(max_node: int, base_size: int = 12) -> int: + """Compute the font size for node labels. + + When the largest node number has many digits the font is reduced + so that labels still fit inside the scatter markers. + + Parameters + ---------- + max_node : int + The largest node number in the graph. + base_size : int, optional + The default font size used for small node numbers. + Defaults to ``12``. + + Returns + ------- + int + The computed font size, never smaller than ``7``. + """ + if max_node >= 100: + return max(7, int(base_size * 2 / len(str(max_node)))) + return base_size + + def draw_edges( + self, + ax: Axes, + pos: Mapping[int, _Point], + edge_subset: Iterable[tuple[int, ...]] | None = None, + ) -> None: """Draw graph edges as plain lines onto a given axes object. This axis-aware method is intended for use in contexts where the caller manages the :class:`~matplotlib.axes.Axes` directly (e.g. the - interactive visualizer). Only the edges present in :attr:`og.graph` - are drawn; no flow arrows are rendered. + interactive visualizer). Parameters ---------- @@ -317,8 +341,12 @@ def draw_edges(self, ax: Axes, pos: Mapping[int, _Point]) -> None: The matplotlib axes to draw onto. pos : Mapping[int, tuple[float, float]] Dictionary mapping each node to its ``(x, y)`` position. + edge_subset : Iterable[tuple[int, int]] or None, optional + If provided, only these edges are drawn. When ``None`` + (the default), all edges in :attr:`og.graph` are drawn. """ - for u, v in self.og.graph.edges(): + edges: Iterable[tuple[int, ...]] = self.og.graph.edges() if edge_subset is None else edge_subset + for u, v in edges: if u in pos and v in pos: x1, y1 = pos[u] x2, y2 = pos[v] @@ -329,6 +357,9 @@ def draw_nodes_role( ax: Axes, pos: Mapping[int, _Point], show_pauli_measurement: bool = False, + node_facecolors: Mapping[int, str] | None = None, + node_edgecolors: Mapping[int, str] | None = None, + node_size: int = 350, ) -> None: """Draw nodes onto a given axes object, coloured by their role. @@ -343,6 +374,9 @@ def draw_nodes_role( black border, light-blue fill. * All other nodes: black border, white fill. + When *node_facecolors* or *node_edgecolors* are provided, their values + override the role-based defaults for the corresponding nodes. + Parameters ---------- ax : Axes @@ -352,6 +386,14 @@ def draw_nodes_role( show_pauli_measurement : bool, optional If ``True``, nodes with Pauli measurement angles are coloured light blue. Defaults to ``False``. + node_facecolors : Mapping[int, str] or None, optional + Per-node fill colour overrides. When a node appears in this + mapping its value is used instead of the role-based default. + node_edgecolors : Mapping[int, str] or None, optional + Per-node border colour overrides. + node_size : int, optional + Marker size for :meth:`~matplotlib.axes.Axes.scatter`. + Defaults to ``350``. """ for node in self.og.graph.nodes(): if node not in pos: @@ -364,7 +406,12 @@ def draw_nodes_role( facecolor = "lightgray" elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): facecolor = "lightblue" - ax.scatter(*pos[node], edgecolors=edgecolor, facecolors=facecolor, s=350, zorder=2, linewidths=1.5) + # Apply per-node overrides if provided + if node_facecolors is not None and node in node_facecolors: + facecolor = node_facecolors[node] + if node_edgecolors is not None and node in node_edgecolors: + edgecolor = node_edgecolors[node] + ax.scatter(*pos[node], edgecolors=edgecolor, facecolors=facecolor, s=node_size, zorder=2, linewidths=1.5) def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: """ diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index 38c0bf8d..7c29d0bc 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt -import networkx as nx import numpy as np from matplotlib.text import Text from matplotlib.widgets import Button, Slider @@ -34,6 +33,18 @@ class InteractiveGraphVisualizer: Scale factors (x, y) for the node positions. enable_simulation : bool If True, simulates the state vector and measurement outcomes. + marker_fill_ratio : float + Fraction of the inter-node spacing used by each marker (0-1). + label_size_ratio : float + Label font size as a fraction of the marker diameter. + max_label_fontsize : int + Upper bound for label font size in points. + min_inches_per_node : float + Minimum vertical inches per node for adaptive figure height. + active_node_color : str + Border colour for active (current-step) nodes. + measured_node_color : str + Fill colour for already-measured nodes. """ def __init__( @@ -41,6 +52,13 @@ def __init__( pattern: Pattern, node_distance: tuple[float, float] = (1, 1), enable_simulation: bool = True, + *, + marker_fill_ratio: float = 0.80, + label_size_ratio: float = 0.55, + max_label_fontsize: int = 12, + min_inches_per_node: float = 0.3, + active_node_color: str = "#2060cc", + measured_node_color: str = "lightgray", ) -> None: """Construct an interactive visualizer. @@ -52,21 +70,68 @@ def __init__( Scale factors (x, y) for node positions. Defaults to (1, 1). enable_simulation : bool, optional If True, enables state vector simulation. Defaults to True. + marker_fill_ratio : float, optional + Fraction of the inter-node spacing used by marker diameter. + Defaults to 0.80. + label_size_ratio : float, optional + Label font size as a fraction of the marker diameter in points. + Defaults to 0.55. + max_label_fontsize : int, optional + Upper bound for label font size. Prevents text from overflowing + the marker in sparse graphs. Defaults to 12. + min_inches_per_node : float, optional + Minimum vertical inches allocated per node when computing the + adaptive figure height. Defaults to 0.3. + active_node_color : str, optional + Border colour for active nodes. Defaults to ``"#2060cc"``. + measured_node_color : str, optional + Fill colour for measured nodes. Defaults to ``"lightgray"``. """ self.pattern = pattern self.node_positions: dict[int, tuple[float, float]] = {} self.node_distance = node_distance self.enable_simulation = enable_simulation + self.marker_fill_ratio = marker_fill_ratio + self.label_size_ratio = label_size_ratio + self.max_label_fontsize = max_label_fontsize + self.min_inches_per_node = min_inches_per_node + self.active_node_color = active_node_color + self.measured_node_color = measured_node_color # Prepare graph layout reusing GraphVisualizer self._prepare_layout() - # Figure setup - tighter layout to reduce whitespace - self.fig = plt.figure(figsize=(14, 7)) + # Figure height adapts to graph density so circles and labels remain + # readable even for dense layouts like QAOA. + ax_h_frac = 0.80 # height fraction of ax_graph in figure + min_fig_height = 7 + if self.node_positions: + ys = [p[1] for p in self.node_positions.values()] + y_data_span = max(ys) - min(ys) + 1 + needed_height = y_data_span * self.min_inches_per_node / ax_h_frac + fig_height = max(min_fig_height, needed_height) + else: + fig_height = min_fig_height + y_data_span = 1 + + # Compute node marker size and label font size ONCE from the known + # figure geometry. This avoids instability when the user resizes the + # window, because the values are fixed at construction time. + ax_height_inches = fig_height * ax_h_frac + y_margin = y_data_span * 0.08 + 0.5 # mirrors _draw_graph margin + y_range = y_data_span + 2 * y_margin + points_per_unit = (ax_height_inches / y_range) * 72 # 72 pt/inch + marker_diameter = self.marker_fill_ratio * points_per_unit + self.node_size: int = max(30, int(marker_diameter**2)) + self.label_fontsize: int = min( + self.max_label_fontsize, max(6, int(marker_diameter * self.label_size_ratio)) + ) + + self.fig = plt.figure(figsize=(14, fig_height)) # Grid layout: command list (~28%), graph (~67%), bottom strip for controls - self.ax_commands = self.fig.add_axes((0.02, 0.15, 0.27, 0.80)) - self.ax_graph = self.fig.add_axes((0.32, 0.15, 0.65, 0.80)) + self.ax_commands = self.fig.add_axes((0.02, 0.15, 0.27, ax_h_frac)) + self.ax_graph = self.fig.add_axes((0.32, 0.15, 0.65, ax_h_frac)) self.ax_prev = self.fig.add_axes((0.30, 0.04, 0.03, 0.03)) self.ax_slider = self.fig.add_axes((0.34, 0.04, 0.55, 0.03)) self.ax_next = self.fig.add_axes((0.90, 0.04, 0.03, 0.03)) @@ -89,12 +154,11 @@ def _prepare_layout(self) -> None: Builds the full graph from the pattern commands, delegates layout computation to :meth:`GraphVisualizer.get_layout`, and normalizes - the resulting positions to fit the interactive panel area. If the - flow-based layout is too narrow for comfortable display (e.g. a - deep Pauli-flow graph), a spring-layout fallback is used. + the resulting positions to fit the interactive panel area. + The flow-based layout is always preserved. """ # Build the full graph from all commands - g: Any = nx.Graph() + g: Any = __import__("networkx").Graph() measurements: dict[int, Any] = {} for cmd in self.pattern: if cmd.kind == CommandKind.N: @@ -112,57 +176,13 @@ def _prepare_layout(self) -> None: pos_mapping, _, _ = vis.get_layout() self.node_positions = dict(pos_mapping) - # Check if the layout is too narrow for the interactive panel - x_coords = [p[0] for p in self.node_positions.values()] - y_coords = [p[1] for p in self.node_positions.values()] - if x_coords and y_coords: - x_range = max(x_coords) - min(x_coords) - y_range = max(y_coords) - min(y_coords) - aspect = x_range / max(y_range, 1e-6) - if aspect < 0.3 and len(self.node_positions) > 5: - # Layout is too narrow (tall vertical strip) -- use spring layout - pos_spring = nx.spring_layout(g, seed=42) - self.node_positions = {n: (float(p[0]), float(p[1])) for n, p in pos_spring.items()} - # Apply user-provided scaling self.node_positions = { k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() } - - # Normalize to [0, 1] range so positions fill the available axes area - # regardless of the data's original aspect ratio. - self._normalize_positions() - # Store the visualizer for reuse in drawing helpers self._graph_visualizer = vis - def _normalize_positions(self) -> None: - """Normalize node positions into the ``[margin, 1-margin]`` range. - - This ensures the graph fills the interactive axes area uniformly, - avoiding the distortion caused by ``set_aspect("equal")`` when the - data's x/y ranges differ significantly. - """ - if not self.node_positions: - return - - xs = [p[0] for p in self.node_positions.values()] - ys = [p[1] for p in self.node_positions.values()] - - x_min, x_max = min(xs), max(xs) - y_min, y_max = min(ys), max(ys) - x_range = x_max - x_min if x_max != x_min else 1.0 - y_range = y_max - y_min if y_max != y_min else 1.0 - - margin = 0.08 - lo = margin - hi = 1.0 - margin - - self.node_positions = { - k: (lo + (v[0] - x_min) / x_range * (hi - lo), lo + (v[1] - y_min) / y_range * (hi - lo)) - for k, v in self.node_positions.items() - } - def visualize(self) -> None: """Launch the interactive visualization window.""" # Initial draw @@ -329,9 +349,10 @@ def _update_graph_state( def _draw_graph(self) -> None: """Draw nodes and edges onto the graph axes. - Uses :meth:`GraphVisualizer.draw_edges` for edge rendering (shared - with the static visualizer) and draws nodes with interactive-specific - colouring: grey for measured, red border for active. + Delegates to :class:`GraphVisualizer` for edge and node rendering, + passing per-node colour overrides to distinguish measured (grey) from + active (blue border) nodes. Labels are drawn locally because they + include dynamic content (measurement results, corrections). """ try: self.ax_graph.clear() @@ -340,54 +361,59 @@ def _draw_graph(self) -> None: self.current_step ) - # ---- Edges (reuse GraphVisualizer helper if possible) ---- - for u, v in active_edges: - if u in self.node_positions and v in self.node_positions: - x1, y1 = self.node_positions[u] - x2, y2 = self.node_positions[v] - self.ax_graph.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1) - - # Adaptive font-size: shrink labels when node numbers are large - fontsize = 10 - max_node = max( - (n for ns in (active_nodes, measured_nodes) for n in ns), - default=0, + # ---- Edges (delegate to GraphVisualizer) ---- + self._graph_visualizer.draw_edges(self.ax_graph, self.node_positions, edge_subset=active_edges) + + # ---- Axis limits (set before drawing nodes so geometry is known) ---- + xs = [p[0] for p in self.node_positions.values()] + ys = [p[1] for p in self.node_positions.values()] + x_margin = (max(xs) - min(xs)) * 0.08 + 0.5 + y_margin = (max(ys) - min(ys)) * 0.08 + 0.5 + self.ax_graph.set_xlim(min(xs) - x_margin, max(xs) + x_margin) + self.ax_graph.set_ylim(min(ys) - y_margin, max(ys) + y_margin) + + # ---- Nodes (delegate to GraphVisualizer with colour overrides) ---- + node_facecolors: dict[int, str] = {} + node_edgecolors: dict[int, str] = {} + for node in measured_nodes: + node_facecolors[node] = self.measured_node_color + node_edgecolors[node] = "black" + for node in active_nodes: + node_facecolors[node] = "white" + node_edgecolors[node] = self.active_node_color + + self._graph_visualizer.draw_nodes_role( + self.ax_graph, + self.node_positions, + node_facecolors=node_facecolors, + node_edgecolors=node_edgecolors, + node_size=self.node_size, ) - if max_node >= 100: - fontsize = max(7, int(fontsize * 2 / len(str(max_node)))) - # ---- Measured nodes (grey fill, black border) ---- + # ---- Labels (drawn locally for dynamic content) ---- + fontsize = self.label_fontsize + for node in measured_nodes: if node not in self.node_positions: continue x, y = self.node_positions[node] - self.ax_graph.scatter(x, y, edgecolors="black", facecolors="lightgray", s=350, zorder=2, linewidths=1.5) - label_text = str(node) if node in results: label_text += f"\nm={results[node]}" - self.ax_graph.text(x, y, label_text, ha="center", va="center", fontsize=fontsize, zorder=3) - # ---- Active nodes (white fill, red border) ---- for node in active_nodes: if node not in self.node_positions: continue x, y = self.node_positions[node] - self.ax_graph.scatter(x, y, edgecolors="red", facecolors="white", s=350, zorder=2, linewidths=1.5) - label_text = str(node) if node in corrections: label_text += "\n" + "".join(sorted(corrections[node])) - text_color = "blue" if node in corrections else "black" self.ax_graph.text( x, y, label_text, ha="center", va="center", fontsize=fontsize, color=text_color, zorder=3 ) - # Axis limits use normalized [0, 1] positions - no set_aspect("equal") - self.ax_graph.set_xlim(-0.02, 1.02) - self.ax_graph.set_ylim(-0.02, 1.02) self.ax_graph.axis("off") except Exception as e: # noqa: BLE001 diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py index 343e224f..53ed8f63 100644 --- a/tests/test_visualization_interactive.py +++ b/tests/test_visualization_interactive.py @@ -10,6 +10,7 @@ from graphix.command import E, M, N, X, Z from graphix.measurements import Measurement from graphix.pattern import Pattern +from graphix.visualization import GraphVisualizer from graphix.visualization_interactive import InteractiveGraphVisualizer @@ -75,10 +76,11 @@ def test_layout_generation(self, pattern: Pattern, mocker: MagicMock) -> None: # Keys should match the layout output assert viz.node_positions.keys() == expected_pos.keys() - # Positions are normalized to [0, 1] range after layout - for x, y in viz.node_positions.values(): - assert 0.0 <= x <= 1.0 - assert 0.0 <= y <= 1.0 + # Positions are the raw layout scaled by node_distance (default 1, 1) + for node, (ex, ey) in expected_pos.items(): + ax, ay = viz.node_positions[node] + assert ax == pytest.approx(ex * viz.node_distance[0]) + assert ay == pytest.approx(ey * viz.node_distance[1]) def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: MagicMock) -> None: """Test graph state update with simulation enabled.""" @@ -122,9 +124,7 @@ def test_update_graph_state_simulation_enabled(self, pattern: Pattern, mocker: M viz.slider.val = len(pattern) viz._update(len(pattern)) - # Check that drawing methods were called - # Measured nodes (0, 1) and active node (2) are drawn with scatter - assert viz.ax_graph.scatter.call_count > 0 + # Labels are drawn locally: check that text was called assert viz.ax_graph.text.call_count > 0 def test_update_graph_state_simulation_disabled(self, pattern: Pattern, mocker: MagicMock) -> None: @@ -318,63 +318,53 @@ def test_z_correction_initialization(self, mocker: MagicMock) -> None: viz._on_pick(pick_event) viz.slider.set_val.assert_called_with(6) - def test_layout_fallback_triggered(self, mocker: MagicMock) -> None: - """Test that spring layout fallback is triggered for narrow graphs.""" + def test_draw_edges_delegates(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that _draw_graph delegates edge drawing to GraphVisualizer.draw_edges.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") - mock_spring_layout = mocker.patch("graphix.visualization_interactive.nx.spring_layout") mock_vis_obj = MagicMock() mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) - # Narrow layout: x=0 for all nodes, y varies -- aspect ratio < 0.3 - initial_pos = {i: (0, i) for i in range(6)} - mock_vis_obj.get_layout.return_value = (initial_pos, {}, {}) - - # Spring layout returns different positions - spring_pos = {i: (float(10 + i), float(i)) for i in range(6)} - mock_spring_layout.return_value = spring_pos - - big_pattern = Pattern(input_nodes=list(range(6))) - for i in range(6): - big_pattern.add(N(node=i)) - - viz = InteractiveGraphVisualizer(big_pattern) + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + viz.slider = MagicMock() - # Spring layout should have been called as fallback - mock_spring_layout.assert_called_once() + # Step 5: entanglement E(0, 1) and E(1, 2), no measurements yet + viz._update(5) - # Positions should be normalized to [0, 1] range - for x, y in viz.node_positions.values(): - assert 0.0 <= x <= 1.0 - assert 0.0 <= y <= 1.0 + # draw_edges should have been called with edge_subset + mock_vis_obj.draw_edges.assert_called() + call_kwargs = mock_vis_obj.draw_edges.call_args + assert "edge_subset" in call_kwargs.kwargs - def test_draw_edges_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: - """Test that edge drawing logic is executed (covers lines 312-314).""" + def test_draw_nodes_delegates(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that _draw_graph delegates node drawing to GraphVisualizer.draw_nodes_role.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") mock_vis_obj = MagicMock() mock_visualizer.return_value = mock_vis_obj - # Provide positions for nodes 0, 1, 2 mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, {}, {}) viz = InteractiveGraphVisualizer(pattern) viz.ax_graph = MagicMock() viz.slider = MagicMock() - # Step 5 in our fixture pattern has entanglement E(0, 1) and E(1, 2) - # and no nodes have been measured yet. - viz._update(5) + # Step after all N + E commands (5 commands) + measurements + viz._update(len(pattern)) - # Verify that plot (used for edges) was called - # There should be 2 edges: (0, 1) and (1, 2) - assert viz.ax_graph.plot.call_count == 2 + # draw_nodes_role should have been called with colour overrides + mock_vis_obj.draw_nodes_role.assert_called() + call_kwargs = mock_vis_obj.draw_nodes_role.call_args + assert "node_facecolors" in call_kwargs.kwargs + assert "node_edgecolors" in call_kwargs.kwargs def test_draw_graph_exception_coverage(self, pattern: Pattern, mocker: MagicMock) -> None: - """Test the exception handling in _draw_graph (covers lines 355-357).""" + """Test the exception handling in _draw_graph.""" mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") mocker.patch("graphix.visualization_interactive.OpenGraph") mocker.patch("matplotlib.pyplot.figure") @@ -393,3 +383,85 @@ def test_draw_graph_exception_coverage(self, pattern: Pattern, mocker: MagicMock # This should not raise but log/print viz._draw_graph() + + +class TestGraphVisualizerSharedAPI: + """Tests for the shared drawing API exposed by GraphVisualizer.""" + + def test_get_label_fontsize_small_nodes(self) -> None: + """Font size should equal base_size for small node numbers.""" + assert GraphVisualizer.get_label_fontsize(0) == 12 + assert GraphVisualizer.get_label_fontsize(99) == 12 + + def test_get_label_fontsize_large_nodes(self) -> None: + """Font size should shrink for large node numbers.""" + result = GraphVisualizer.get_label_fontsize(100) + assert result < 12 + assert result >= 7 + + def test_get_label_fontsize_custom_base(self) -> None: + """Font size should use the custom base_size.""" + assert GraphVisualizer.get_label_fontsize(0, base_size=10) == 10 + result = GraphVisualizer.get_label_fontsize(1000, base_size=10) + assert result >= 7 + assert result < 10 + + def test_draw_nodes_role_with_overrides(self) -> None: + """Test draw_nodes_role applies per-node colour overrides.""" + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0, 1, 2] + mock_og.input_nodes = [0] + mock_og.output_nodes = [2] + mock_og.measurements = {0: MagicMock(), 1: MagicMock(), 2: MagicMock()} + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0)} + + vis.draw_nodes_role( + ax, + pos, + node_facecolors={0: "yellow", 1: "pink"}, + node_edgecolors={0: "green"}, + ) + + assert ax.scatter.call_count == 3 + # Check overrides were applied by inspecting scatter kwargs + scatter_calls = ax.scatter.call_args_list + # Node 0: facecolors=yellow (override), edgecolors=green (override) + assert scatter_calls[0].kwargs["facecolors"] == "yellow" + assert scatter_calls[0].kwargs["edgecolors"] == "green" + # Node 1: facecolors=pink (override), edgecolors=black (default) + assert scatter_calls[1].kwargs["facecolors"] == "pink" + assert scatter_calls[1].kwargs["edgecolors"] == "black" + # Node 2: facecolors=lightgray (output role), edgecolors=black (default) + assert scatter_calls[2].kwargs["facecolors"] == "lightgray" + assert scatter_calls[2].kwargs["edgecolors"] == "black" + + def test_draw_edges_with_subset(self) -> None: + """Test draw_edges with edge_subset only draws specified edges.""" + mock_og = MagicMock() + mock_og.graph.edges.return_value = [(0, 1), (1, 2), (2, 3)] + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0)} + + # Draw only a subset + vis.draw_edges(ax, pos, edge_subset=[(0, 1), (2, 3)]) + assert ax.plot.call_count == 2 + + def test_draw_edges_without_subset(self) -> None: + """Test draw_edges without edge_subset draws all edges.""" + mock_og = MagicMock() + mock_og.graph.edges.return_value = [(0, 1), (1, 2), (2, 3)] + + vis = GraphVisualizer(og=mock_og) + + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0)} + + vis.draw_edges(ax, pos) + assert ax.plot.call_count == 3 From 4152dcb4e5cc4fa8e3b36c3a79f6a558d81b3213 Mon Sep 17 00:00:00 2001 From: Kitsunp Date: Wed, 18 Feb 2026 06:19:23 -0600 Subject: [PATCH 12/12] update --- graphix/visualization_interactive.py | 55 ++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py index 7c29d0bc..e0d6beac 100644 --- a/graphix/visualization_interactive.py +++ b/graphix/visualization_interactive.py @@ -123,17 +123,16 @@ def __init__( points_per_unit = (ax_height_inches / y_range) * 72 # 72 pt/inch marker_diameter = self.marker_fill_ratio * points_per_unit self.node_size: int = max(30, int(marker_diameter**2)) - self.label_fontsize: int = min( - self.max_label_fontsize, max(6, int(marker_diameter * self.label_size_ratio)) - ) + self.label_fontsize: int = min(self.max_label_fontsize, max(6, int(marker_diameter * self.label_size_ratio))) self.fig = plt.figure(figsize=(14, fig_height)) - # Grid layout: command list (~28%), graph (~67%), bottom strip for controls + # Grid layout: command list (~27%), graph (~65%), bottom strips for controls self.ax_commands = self.fig.add_axes((0.02, 0.15, 0.27, ax_h_frac)) + self.ax_cmd_scroll = self.fig.add_axes((0.02, 0.08, 0.27, 0.03)) self.ax_graph = self.fig.add_axes((0.32, 0.15, 0.65, ax_h_frac)) - self.ax_prev = self.fig.add_axes((0.30, 0.04, 0.03, 0.03)) - self.ax_slider = self.fig.add_axes((0.34, 0.04, 0.55, 0.03)) + self.ax_prev = self.fig.add_axes((0.32, 0.04, 0.03, 0.03)) + self.ax_slider = self.fig.add_axes((0.40, 0.04, 0.48, 0.03)) self.ax_next = self.fig.add_axes((0.90, 0.04, 0.03, 0.03)) # Turn off axes frame for command list and graph @@ -143,9 +142,12 @@ def __init__( # Interaction state self.current_step = 0 self.total_steps = len(pattern) + self.command_window_size = 30 + self._cmd_scroll_offset: int = 0 # first visible command index # Widget placeholders self.slider: Slider | None = None + self.cmd_scroll_slider: Slider | None = None self.btn_prev: Button | None = None self.btn_next: Button | None = None @@ -190,10 +192,23 @@ def visualize(self) -> None: self._draw_graph() self._update(0) - # Slider config + # Step slider (horizontal, bottom) self.slider = Slider(self.ax_slider, "Step", 0, self.total_steps, valinit=0, valstep=1, color="lightblue") self.slider.on_changed(self._update) + # Command list scroll slider (horizontal, below command panel) + max_scroll = max(0, self.total_steps - self.command_window_size) + self.cmd_scroll_slider = Slider( + self.ax_cmd_scroll, + "", + 0, + max(1, max_scroll), + valinit=0, + valstep=1, + color="#cccccc", + ) + self.cmd_scroll_slider.on_changed(self._on_cmd_scroll) + # Buttons config self.btn_prev = Button(self.ax_prev, "<") self.btn_prev.on_clicked(self._prev_step) @@ -214,13 +229,9 @@ def _draw_command_list(self) -> None: self.ax_commands.axis("off") self.ax_commands.set_title(f"Commands ({self.total_steps})", loc="left") - # Windowing logic to show relevant commands - window_size = 30 - start = max(0, int(self.current_step) - window_size // 2) - end = min(self.total_steps, start + window_size) - - if end == self.total_steps: - start = max(0, end - window_size) + # Use scroll offset for visible window + start = max(0, min(self._cmd_scroll_offset, self.total_steps - self.command_window_size)) + end = min(self.total_steps, start + self.command_window_size) cmds: Any = self.pattern[start:end] # type: ignore[index] @@ -237,7 +248,7 @@ def _draw_command_list(self) -> None: weight = "bold" # Position text from top to bottom - y_pos = 1.0 - (i + 1) * (1.0 / (window_size + 2)) + y_pos = 1.0 - (i + 1) * (1.0 / (self.command_window_size + 2)) text_obj = self.ax_commands.text( 0.05, @@ -424,10 +435,24 @@ def _update(self, val: float) -> None: step = int(val) if step != self.current_step: self.current_step = step + # Auto-scroll command list to keep current step visible + if step < self._cmd_scroll_offset or step >= self._cmd_scroll_offset + self.command_window_size: + new_offset = max(0, step - self.command_window_size // 2) + self._cmd_scroll_offset = new_offset + if self.cmd_scroll_slider is not None: + self.cmd_scroll_slider.set_val(new_offset) self._draw_command_list() self._draw_graph() self.fig.canvas.draw_idle() + def _on_cmd_scroll(self, val: float) -> None: + """Handle vertical scroll slider changes.""" + new_offset = int(val) + if new_offset != self._cmd_scroll_offset: + self._cmd_scroll_offset = new_offset + self._draw_command_list() + self.fig.canvas.draw_idle() + def _prev_step(self, _event: Any) -> None: if self.current_step > 0 and self.slider is not None: self.slider.set_val(self.current_step - 1)