diff --git a/CHANGELOG.md b/CHANGELOG.md index 21980c1c..428e00bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - #386, #433: Added `Statevec.fidelity` and `Statevec.isclose` methods for pure-state fidelity computation and equality check up to global phase. +- #387, #444: Improved `Pattern.draw_graph` visualization: MBQC literature node shapes (squares for inputs, filled/empty circles for measured/output), solid gray edges, measurement order arrow, `show_measurements` and `show_legend` parameters. ### Fixed diff --git a/examples/visualization.py b/examples/visualization.py index 2c08fbeb..28668caa 100644 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -33,13 +33,13 @@ pattern = circuit.transpile().pattern # note that this visualization is not always consistent with the correction set of pattern, # since we find the correction sets with flow-finding algorithms. -pattern.draw_graph(flow_from_pattern=False, show_measurement_planes=True) +pattern.draw_graph(flow_from_pattern=False, show_measurements=True) # %% # next, show the gflow: pattern.remove_input_nodes() pattern.perform_pauli_measurements() -pattern.draw_graph(flow_from_pattern=False, show_measurement_planes=True, node_distance=(1, 0.6)) +pattern.draw_graph(flow_from_pattern=False, show_measurements=True, node_distance=(1, 0.6)) # %% @@ -49,7 +49,7 @@ # # node_distance argument specifies the scale of the node arrangement in x and y directions. -pattern.draw_graph(flow_from_pattern=True, show_measurement_planes=True, node_distance=(0.7, 0.6)) +pattern.draw_graph(flow_from_pattern=True, show_measurements=True, node_distance=(0.7, 0.6)) # %% # Instead of the measurement planes, we can show the local Clifford of the resource graph. @@ -75,7 +75,7 @@ measurements = {node: Measurement.XY(0) for node in graph.nodes() if node not in outputs} og = OpenGraph(graph, inputs, outputs, measurements) vis = GraphVisualizer(og) -vis.visualize(show_measurement_planes=True) +vis.visualize(show_measurements=True) # %% @@ -91,6 +91,6 @@ } og = OpenGraph(graph, inputs, outputs, measurements) vis = GraphVisualizer(og) -vis.visualize(show_measurement_planes=True) +vis.visualize(show_measurements=True) # %% diff --git a/graphix/pattern.py b/graphix/pattern.py index ec900639..ad19a16b 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -1437,7 +1437,8 @@ def draw_graph( flow_from_pattern: bool = True, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, @@ -1445,16 +1446,21 @@ def draw_graph( ) -> None: """Visualize the underlying graph of the pattern with flow or gflow structure. + Nodes are drawn following MBQC literature conventions: inputs as squares, + measured nodes as filled circles, and outputs as empty circles. + Parameters ---------- flow_from_pattern : bool If True, the command sequence of the pattern is used to derive flow or gflow structure. If False, only the underlying graph is used. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -1475,7 +1481,8 @@ def draw_graph( pattern=self.copy(), show_pauli_measurement=show_pauli_measurement, show_local_clifford=show_local_clifford, - show_measurement_planes=show_measurement_planes, + show_measurements=show_measurements, + show_legend=show_legend, show_loop=show_loop, node_distance=node_distance, figsize=figsize, @@ -1485,7 +1492,8 @@ def draw_graph( vis.visualize( show_pauli_measurement=show_pauli_measurement, show_local_clifford=show_local_clifford, - show_measurement_planes=show_measurement_planes, + show_measurements=show_measurements, + show_legend=show_legend, show_loop=show_loop, node_distance=node_distance, figsize=figsize, diff --git a/graphix/visualization.py b/graphix/visualization.py index 4322b82b..c8499d88 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -9,13 +9,15 @@ import networkx as nx import numpy as np from matplotlib import pyplot as plt +from matplotlib.lines import Line2D from graphix.flow.exceptions import FlowError -from graphix.measurements import Measurement, PauliMeasurement +from graphix.measurements import BlochMeasurement, Measurement, PauliMeasurement # OpenGraph is needed for dataclass from graphix.opengraph import OpenGraph # noqa: TC001 from graphix.optimization import StandardizedPattern +from graphix.pretty_print import OutputFormat, angle_to_str if TYPE_CHECKING: from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence @@ -56,14 +58,14 @@ def visualize( self, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, filename: Path | None = None, ) -> None: - """ - Visualize the graph with flow or gflow structure. + """Visualize the graph with flow or gflow structure. If there exists a flow structure, then the graph is visualized with the flow structure. If flow structure is not found and there exists a gflow structure, then the graph is visualized @@ -73,11 +75,13 @@ def visualize( Parameters ---------- show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -138,7 +142,8 @@ def place_paths( None, show_pauli_measurement, show_local_clifford, - show_measurement_planes, + show_measurements, + show_legend, show_loop, node_distance, figsize, @@ -150,14 +155,14 @@ def visualize_from_pattern( pattern: Pattern, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: tuple[int, int] | None = None, filename: Path | None = None, ) -> None: - """ - Visualize the graph with flow or gflow structure found from the given pattern. + """Visualize the graph with flow or gflow structure found from the given pattern. If pattern sequence is consistent with flow structure, then the graph is visualized with the flow structure. If it is not consistent with flow structure and consistent with gflow structure, then the graph is visualized @@ -168,11 +173,13 @@ def visualize_from_pattern( pattern : Pattern pattern to be visualized show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -235,7 +242,8 @@ def place_paths( corrections, show_pauli_measurement, show_local_clifford, - show_measurement_planes, + show_measurements, + show_legend, show_loop, node_distance, figsize, @@ -252,35 +260,66 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]: new_path[-1] = last_edge return new_path - def _draw_labels(self, pos: Mapping[int, _Point]) -> None: + def _draw_labels(self, pos: Mapping[int, _Point], font_color: Mapping[int, str] | str = "black") -> None: + """Draw node number labels with appropriate text color. + + Parameters + ---------- + pos : Mapping[int, tuple[float, float]] + Dictionary of node positions. + font_color : Mapping[int, str] | str + Font color for node labels. Can be a single color string or a mapping from node to color. + """ fontsize = 12 if max(self.og.graph.nodes(), default=0) >= 100: fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) - nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) + nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize, font_color=font_color) # type: ignore[arg-type] - def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: - """ - Draw the nodes with different colors based on their role (input, output, or other). + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> dict[int, str]: + """Draw the nodes with shapes and fills following MBQC literature conventions. + + Input nodes are drawn as squares, measured (non-output) nodes as filled circles, + and output nodes as empty circles. Pauli-measured nodes are optionally distinguished + with a blue fill. Parameters ---------- pos : Mapping[int, tuple[float, float]] - dictionary of node positions. + Dictionary of node positions. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. + + Returns + ------- + dict[int, str] + Mapping from node index to font color for label rendering. """ + font_colors: dict[int, str] = {} + for node in self.og.graph.nodes(): - color = "black" # default color for 'other' nodes - inner_color = "white" - if node in self.og.input_nodes: - color = "red" + marker = "s" if node in self.og.input_nodes else "o" + is_pauli = node in self.og.measurements and isinstance(self.og.measurements[node], PauliMeasurement) + if node in self.og.output_nodes: - inner_color = "lightgray" - elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): - inner_color = "lightblue" + facecolor = "white" + elif show_pauli_measurement and is_pauli: + facecolor = "#4292c6" + else: + facecolor = "black" + + font_colors[node] = "white" if facecolor == "black" else "black" + plt.scatter( - *pos[node], edgecolor=color, facecolor=inner_color, s=350, zorder=2 - ) # Draw the nodes manually with scatter() + *pos[node], + marker=marker, + edgecolor="black", + facecolor=facecolor, + s=350, + zorder=2, + linewidths=1.5, + ) + + return font_colors def visualize_graph( self, @@ -292,19 +331,19 @@ def visualize_graph( corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, show_pauli_measurement: bool = True, show_local_clifford: bool = False, - show_measurement_planes: bool = False, + show_measurements: bool = False, + show_legend: bool = False, show_loop: bool = True, node_distance: tuple[float, float] = (1, 1), figsize: _Point | None = None, filename: Path | None = None, ) -> None: - """ - Visualizes the graph. + """Visualize the graph. - Nodes are colored based on their role (input, output, or other) and edges are depicted as arrows - or dashed lines depending on whether they are in the flow mapping. Vertical dashed lines separate - different layers of the graph. This function does not return anything but plots the graph - using matplotlib's pyplot. + Nodes are drawn following MBQC literature conventions: inputs as squares, + measured nodes as filled circles, and outputs as empty circles. Graph edges + are solid lines and flow arrows indicate corrections. A horizontal arrow + below the graph indicates the measurement order. Parameters ---------- @@ -319,11 +358,13 @@ def visualize_graph( corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None X and Z corrections if any. show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. + If True, Pauli-measured nodes are filled with blue instead of black. show_local_clifford : bool If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. + show_measurements : bool + If True, measurement labels are displayed adjacent to the nodes. + show_legend : bool + If True, a legend is displayed indicating node types and edge meanings. show_loop : bool whether or not to show loops for graphs with gflow. defaulted to True. node_distance : tuple @@ -341,7 +382,7 @@ def visualize_graph( edge_path, arrow_path = place_paths(pos) - if corrections is not None: + if show_legend or corrections is not None: # add some padding to the right for the legend figsize = (figsize[0] + 3.0, figsize[1]) @@ -349,10 +390,12 @@ def visualize_graph( for edge, path in edge_path.items(): if len(path) == 2: - nx.draw_networkx_edges(self.og.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) + nx.draw_networkx_edges( + self.og.graph, pos, edgelist=[edge], style="dashed", edge_color="gray", alpha=0.6 + ) else: curve = self._bezier_curve_linspace(path) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) + plt.plot(curve[:, 0], curve[:, 1], color="gray", linewidth=1, alpha=0.6, linestyle="dashed") if arrow_path is not None: for arrow, path in arrow_path.items(): @@ -391,19 +434,21 @@ def visualize_graph( arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, ) - self.__draw_nodes_role(pos, show_pauli_measurement) + font_colors = self.__draw_nodes_role(pos, show_pauli_measurement) if show_local_clifford: self.__draw_local_clifford(pos) - if show_measurement_planes: - self.__draw_measurement_planes(pos) + if show_measurements: + self.__draw_measurement_labels(pos) - self._draw_labels(pos) + self._draw_labels(pos, font_colors) - if corrections is not None: - # legend for arrow colors - plt.plot([], [], "k--", alpha=0.7, label="graph edge") + if show_legend: + self.__draw_legend(show_pauli_measurement, corrections, arrow_path is not None) + elif corrections is not None: + # backward-compatible minimal legend for correction arrows + plt.plot([], [], color="gray", alpha=0.6, linestyle="dashed", label="graph edge") plt.plot([], [], color="tab:red", label="xflow") plt.plot([], [], color="tab:green", label="zflow") plt.plot([], [], color="tab:brown", label="xflow and zflow") @@ -414,26 +459,43 @@ def visualize_graph( y_min = min((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the minimum y coordinate y_max = max((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the maximum y coordinate - if l_k is not None and l_k: - # Draw the vertical lines to separate different layers - for layer in range(min(l_k.values()), max(l_k.values())): - plt.axvline( - x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 - ) # Draw line between layers - for layer in range(min(l_k.values()), max(l_k.values()) + 1): + has_layers = l_k is not None and len(l_k) > 0 + if has_layers and l_k is not None: + l_min_val = min(l_k.values()) + l_max_val = max(l_k.values()) + # Draw layer labels below nodes + for layer in range(l_min_val, l_max_val + 1): plt.text( - layer * node_distance[0], y_min - 0.5, f"L: {max(l_k.values()) - layer}", ha="center", va="top" - ) # Add layer label at bottom + layer * node_distance[0], + y_min - 0.4, + f"L{l_max_val - layer}", + ha="center", + va="top", + fontsize=8, + color="gray", + ) + # Draw horizontal arrow indicating measurement order + if l_max_val > l_min_val: + arrow_y = y_min - 0.7 + plt.annotate( + "", + xy=(l_max_val * node_distance[0] + 0.3, arrow_y), + xytext=(l_min_val * node_distance[0] - 0.3, arrow_y), + arrowprops={"arrowstyle": "->", "color": "gray", "lw": 1.2}, + ) + mid_x = (l_min_val + l_max_val) / 2 * node_distance[0] + plt.text(mid_x, arrow_y - 0.2, "Measurement order", ha="center", va="top", fontsize=8, color="gray") plt.xlim( x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] ) # Add some padding to the left and right - plt.ylim(y_min - 1, y_max + 0.5) # Add some padding to the top and bottom + bottom_margin = 1.3 if has_layers else 1 + plt.ylim(y_min - bottom_margin, y_max + 0.5) if filename is None: plt.show() else: - plt.savefig(filename) + plt.savefig(filename, bbox_inches="tight") def __draw_local_clifford(self, pos: Mapping[int, _Point]) -> None: if self.local_clifford is not None: @@ -441,12 +503,133 @@ def __draw_local_clifford(self, pos: Mapping[int, _Point]) -> None: x, y = pos[node] + np.array([0.2, 0.2]) plt.text(x, y, f"{self.local_clifford[node]}", fontsize=10, zorder=3) - def __draw_measurement_planes(self, pos: Mapping[int, _Point]) -> None: + @staticmethod + def __draw_legend( + show_pauli_measurement: bool, + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None, + has_arrows: bool, + ) -> None: + """Draw a legend indicating node types and edge meanings. + + Parameters + ---------- + show_pauli_measurement : bool + Whether Pauli-measured nodes are visually distinct. + corrections : tuple or None + X and Z corrections if any, to determine arrow legend entries. + has_arrows : bool + Whether flow arrows are present in the graph. + """ + elements: list[Line2D] = [ + Line2D( + [0], + [0], + marker="s", + color="w", + markerfacecolor="black", + markeredgecolor="black", + markersize=10, + label="Input", + ), + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="black", + markeredgecolor="black", + markersize=10, + label="Measured", + ), + ] + if show_pauli_measurement: + elements.append( + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="#4292c6", + markeredgecolor="black", + markersize=10, + label="Pauli-measured", + ) + ) + elements.extend( + [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="white", + markeredgecolor="black", + markersize=10, + label="Output", + ), + Line2D([0], [0], color="gray", linewidth=1, alpha=0.6, linestyle="dashed", label="Graph edge"), + ] + ) + + if corrections is not None: + elements.extend( + [ + Line2D([0], [0], color="tab:red", linewidth=1, label="X-correction"), + Line2D([0], [0], color="tab:green", linewidth=1, label="Z-correction"), + Line2D([0], [0], color="tab:brown", linewidth=1, label="X & Z-correction"), + ] + ) + elif has_arrows: + elements.append(Line2D([0], [0], color="black", linewidth=1, label="Flow")) + + plt.legend(handles=elements, loc="center left", fontsize=9, bbox_to_anchor=(1, 0.5)) + + def __draw_measurement_labels(self, pos: Mapping[int, _Point]) -> None: + """Draw measurement labels next to measured nodes. + + Labels are rendered with a white background to ensure legibility over graph edges. + + Parameters + ---------- + pos : Mapping[int, tuple[float, float]] + Dictionary of node positions. + """ for node, meas in self.og.measurements.items(): - x, y = pos[node] + np.array([0.22, -0.2]) - label = meas.to_plane_or_axis().name + label = self._format_measurement_label(meas) + if label is not None: + x, y = pos[node] + plt.text( + x + 0.18, + y - 0.2, + label, + fontsize=8, + zorder=3, + bbox={"boxstyle": "round,pad=0.15", "facecolor": "white", "edgecolor": "none", "alpha": 0.85}, + ) + + @staticmethod + def _format_measurement_label(meas: Measurement) -> str | None: + """Format a measurement label for display. - plt.text(x, y, label, fontsize=9, zorder=3) + Parameters + ---------- + meas : Measurement + The measurement to format. + + Returns + ------- + str | None + Formatted label string, or None if nothing to show. + """ + if isinstance(meas, PauliMeasurement): + return str(meas) + if isinstance(meas, BlochMeasurement): + if isinstance(meas.angle, (int, float)): + angle_str = angle_to_str(meas.angle, OutputFormat.Unicode) + else: + angle_str = str(meas.angle) + return f"{meas.plane.name}({angle_str})" + return None def determine_figsize( self, diff --git a/tests/baseline/test_draw_graph_reference_False.png b/tests/baseline/test_draw_graph_reference_False.png index aca8ad90..22c9742d 100644 Binary files a/tests/baseline/test_draw_graph_reference_False.png and b/tests/baseline/test_draw_graph_reference_False.png differ diff --git a/tests/baseline/test_draw_graph_reference_True.png b/tests/baseline/test_draw_graph_reference_True.png index 91e28966..07ae5dc8 100644 Binary files a/tests/baseline/test_draw_graph_reference_True.png and b/tests/baseline/test_draw_graph_reference_True.png differ diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 030ada64..5a1d23f5 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -9,8 +9,8 @@ import pytest from graphix import Circuit, Pattern, command, visualization -from graphix.fundamentals import ANGLE_PI -from graphix.measurements import Measurement +from graphix.fundamentals import ANGLE_PI, Axis, Sign +from graphix.measurements import Measurement, PauliMeasurement from graphix.opengraph import OpenGraph, OpenGraphError from graphix.visualization import GraphVisualizer @@ -144,10 +144,10 @@ def test_draw_graph_show_local_clifford() -> None: @pytest.mark.usefixtures("mock_plot") -def test_draw_graph_show_measurement_planes(fx_rng: Generator) -> None: +def test_draw_graph_show_measurements_basic(fx_rng: Generator) -> None: pattern = example_pflow(fx_rng) pattern.draw_graph( - show_measurement_planes=True, + show_measurements=True, node_distance=(0.7, 0.6), ) @@ -247,6 +247,83 @@ def test_draw_graph_reference(flow_and_not_pauli_presimulate: bool) -> Figure: pattern.perform_pauli_measurements() pattern.standardize() pattern.draw_graph( - flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(0.7, 0.6), show_measurement_planes=True + flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(0.7, 0.6), show_measurements=True ) return plt.gcf() + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_measurements(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + show_measurements=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_measurements_pflow(fx_rng: Generator) -> None: + pattern = example_pflow(fx_rng) + pattern.draw_graph( + show_measurements=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_legend(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + show_legend=True, + node_distance=(0.7, 0.6), + ) + + +@pytest.mark.usefixtures("mock_plot") +def test_draw_graph_show_legend_with_corrections(fx_rng: Generator) -> None: + pattern = example_flow(fx_rng) + pattern.draw_graph( + flow_from_pattern=True, + show_legend=True, + show_pauli_measurement=True, + node_distance=(0.7, 0.6), + ) + + +def test_format_measurement_label_bloch() -> None: + bloch_xy = Measurement.XY(0.25) + label = GraphVisualizer._format_measurement_label(bloch_xy) + assert label is not None + assert "XY" in label + assert "/" in label # pi/4 contains "/" + + +def test_format_measurement_label_bloch_zero() -> None: + bloch_zero = Measurement.XY(0) + label = GraphVisualizer._format_measurement_label(bloch_zero) + assert label is not None + assert "XY" in label + assert "0" in label + + +def test_format_measurement_label_bloch_xz() -> None: + bloch_xz = Measurement.XZ(0.5) + label = GraphVisualizer._format_measurement_label(bloch_xz) + assert label is not None + assert "XZ" in label + + +def test_format_measurement_label_pauli() -> None: + pauli_x = Measurement.X + label = GraphVisualizer._format_measurement_label(pauli_x) + assert label is not None + assert label == str(pauli_x) + assert "X" in label + + +def test_format_measurement_label_pauli_minus() -> None: + pauli_minus_z = PauliMeasurement(Axis.Z, Sign.MINUS) + label = GraphVisualizer._format_measurement_label(pauli_minus_z) + assert label is not None + assert label == str(pauli_minus_z) + assert "-Z" in label