Skip to content

feat: add circuit drawer module with Matplotlib (#1083) #1094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"openqasm3",
"sympy",
"backports.entry-points-selectable",
"matplotlib",
],
extras_require={
"test": [
Expand Down
3 changes: 3 additions & 0 deletions src/braket/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
from braket.circuits.text_diagram_builders.ascii_circuit_diagram import (
AsciiCircuitDiagram, # noqa: F401
)
from braket.circuits.text_diagram_builders.matplotlib_circuit_diagram import (
MatplotlibCircuitDiagram, # noqa: F401
)
from braket.circuits.text_diagram_builders.unicode_circuit_diagram import (
UnicodeCircuitDiagram, # noqa: F401
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
from __future__ import annotations

from typing import Union

import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.axes import Axes
from matplotlib.figure import Figure

import braket.circuits.circuit as cir
from braket.circuits.circuit_diagram import CircuitDiagram
from braket.circuits.instruction import Instruction
from braket.circuits.moments import MomentType
from braket.circuits.result_type import ResultType
from braket.circuits.result_types import ObservableResultType
from braket.circuits.text_diagram_builders.text_circuit_diagram_utils import _group_items
from braket.registers.qubit_set import QubitSet


class MatplotlibCircuitDiagram(CircuitDiagram):
"""A class that builds circuit diagrams using matplotlib."""

# Default styling parameters
GATE_BOX_HEIGHT = 0.6
GATE_BOX_WIDTH = 1.0
GATE_BOX_PADDING = 0.2
WIRE_SPACING = 1.0
COLUMN_SPACING = 1.5
GATE_COLOR = "#E0E0E0"
WIRE_COLOR = "#000000"
TEXT_COLOR = "#000000"
GLOBAL_PHASE_COLOR = "#666666"

@classmethod
def build_diagram(cls, circuit: cir.Circuit) -> Figure:
"""Build a matplotlib figure for the specified circuit.

Args:
circuit (cir.Circuit): The circuit to build a diagram for.

Returns:
Figure: Matplotlib figure containing the circuit diagram.
"""
if not circuit.instructions:
return plt.figure()

if all(m.moment_type == MomentType.GLOBAL_PHASE for m in circuit._moments):
fig, ax = plt.subplots(figsize=(4, 1))
ax.text(0.5, 0.5, f"Global phase: {circuit.global_phase}", ha="center", va="center")
ax.axis("off")
return fig

# Get circuit qubits and sort them
circuit_qubits = circuit.qubits
circuit_qubits.sort()

# Create figure and axis
num_qubits = len(circuit_qubits)
fig_width = 10 # Will be adjusted based on circuit size
fig_height = max(4, num_qubits * 1.5)
fig, ax = plt.subplots(figsize=(fig_width, fig_height))

# Draw the circuit
cls._draw_circuit(ax, circuit, circuit_qubits)

# Adjust layout and return
plt.tight_layout()
return fig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a Jupyter notebook, when calling build_diagram(), it's always displaying the image automatically - even if I capture the return value in a variable. If I capture the return value of build_diagram() in a variable, or if this is not the last line in the cell, then the image should not be displayed automatically.


@classmethod
def _draw_circuit(cls, ax: Axes, circuit: cir.Circuit, circuit_qubits: QubitSet) -> None:
"""Draw the circuit on the given axis.

Args:
ax (Axes): Matplotlib axis to draw on
circuit (Circuit): Circuit to draw
circuit_qubits (QubitSet): Sorted set of qubits in the circuit
"""
# Set up the axis
ax.set_aspect("equal")
ax.axis("off")

time_slices = circuit.moments.time_slices()
num_columns = len(time_slices) + (1 if circuit.result_types else 0)

# Draw qubit lines
wire_color = getattr(cls, "WIRE_COLOR", "#000000")
text_color = getattr(cls, "TEXT_COLOR", "#000000")
for i, qubit in enumerate(circuit_qubits):
y = -i * cls.WIRE_SPACING
ax.plot(
[0, num_columns * cls.COLUMN_SPACING],
[y, y],
color=wire_color,
linewidth=1,
alpha=1.0,
)
ax.text(-0.5, y, f"q[{qubit}]", ha="right", va="center", color=text_color)

# Draw gates and connections
x_pos = 0

for instructions in time_slices.values():
# Group instructions to handle overlapping gates
groupings = _group_items(circuit_qubits, instructions)
for grouping in groupings:
cls._draw_instruction_group(ax, grouping[1], circuit_qubits, x_pos)
x_pos += cls.COLUMN_SPACING

# Draw result types if any
if circuit.result_types:
cls._draw_result_types(ax, circuit.result_types, circuit_qubits, x_pos)

# Draw global phase if present (after all instructions, so it's not overlapped)
if circuit.global_phase != 0:
y_center = (
0 if len(circuit_qubits) == 1 else -len(circuit_qubits) * cls.WIRE_SPACING / 2
)
# If global phase is exactly 0.5, display as 0.50π, else as a number with two decimals
if abs(circuit.global_phase - 0.5) < 1e-9:
phase_str = "0.50π"
Comment on lines +120 to +121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool. Two comments
(1) would be cool to do this for other multiples of pi, not just 0.5
(2) we should share the logic for displaying global phase and for angle parameters to other gates

else:
phase_str = f"{circuit.global_phase:.2f}"
# Add the global phase text to the axis
ax.text(
x_pos,
y_center,
f"Phase: {phase_str}",
color=getattr(cls, "GLOBAL_PHASE_COLOR", "#666666"),
ha="center",
va="center",
fontsize=10,
bbox={"facecolor": "white", "alpha": 0.7, "edgecolor": "none", "pad": 1},
)

@classmethod
def _draw_instruction_group(
cls,
ax: Axes,
instructions: list[Instruction | ResultType],
circuit_qubits: QubitSet,
x_pos: float,
) -> None:
"""Draw a group of instructions at the given x position.

Args:
ax (Axes): Matplotlib axis to draw on
instructions (list[Instruction | ResultType]): Instructions to draw
circuit_qubits (QubitSet): All qubits in the circuit
x_pos (float): X position to draw at
"""
gate_color = getattr(cls, "GATE_COLOR", "#E0E0E0")
text_color = getattr(cls, "TEXT_COLOR", "#000000")
wire_color = getattr(cls, "WIRE_COLOR", "#000000")
# Draw each instruction
for instruction in instructions:
# Draw compiler directives as boxes with labels
if (
hasattr(instruction, "operator")
and instruction.operator.__class__.__name__ == "CompilerDirective"
):
qubits = instruction.target
if not isinstance(qubits, QubitSet):
qubits = QubitSet(qubits)

Check warning on line 164 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L164

Added line #L164 was not covered by tests
y_positions = (
[-circuit_qubits.index(q) * cls.WIRE_SPACING for q in qubits] if qubits else [0]
)
y_min = min(y_positions)
y_max = max(y_positions)
box = patches.Rectangle(
(x_pos - cls.GATE_BOX_WIDTH / 2, y_min - cls.GATE_BOX_HEIGHT / 2),
cls.GATE_BOX_WIDTH,
y_max - y_min + cls.GATE_BOX_HEIGHT,
facecolor=gate_color,
edgecolor="black",
alpha=0.7,
)
ax.add_patch(box)
# Use the directive string exactly
operator = instruction.operator
text = str(operator.ascii_symbols[0]).strip()
ax.text(
x_pos,
(y_min + y_max) / 2,
text,
ha="center",
va="center",
color=text_color,
fontsize=10,
bbox={"facecolor": "white", "alpha": 0.7, "edgecolor": "none", "pad": 1},
)
continue
if hasattr(instruction, "target"):
qubits = instruction.target
if not isinstance(qubits, QubitSet):
qubits = QubitSet(qubits)

Check warning on line 196 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L196

Added line #L196 was not covered by tests
if not qubits:
continue # Skip drawing if no qubits
# Calculate y positions for the gate
y_positions = [-circuit_qubits.index(q) * cls.WIRE_SPACING for q in qubits]
y_min = min(y_positions)
y_max = max(y_positions)
# Draw the gate box
box = patches.Rectangle(
(x_pos - cls.GATE_BOX_WIDTH / 2, y_min - cls.GATE_BOX_HEIGHT / 2),
cls.GATE_BOX_WIDTH,
y_max - y_min + cls.GATE_BOX_HEIGHT,
facecolor=gate_color,
edgecolor="black",
alpha=0.7,
)
ax.add_patch(box)
# Draw the gate name
text = cls._get_instruction_text(instruction)
ax.text(
x_pos,
(y_min + y_max) / 2,
text,
ha="center",
va="center",
color=text_color,
fontsize=10,
bbox={"facecolor": "white", "alpha": 0.7, "edgecolor": "none", "pad": 1},
)
# Draw connections between qubits
if len(qubits) > 1:
for y in y_positions[1:]:
ax.plot([x_pos, x_pos], [y_positions[0], y], color=wire_color, linewidth=1)

@classmethod
def _draw_result_types(
cls, ax: Axes, result_types: list[ResultType], circuit_qubits: QubitSet, x_pos: float
) -> None:
"""Draw result types at the given x position.

Args:
ax (Axes): Matplotlib axis to draw on
result_types (list[ResultType]): Result types to draw
circuit_qubits (QubitSet): All qubits in the circuit
x_pos (float): X position to draw at
"""
for result_type in result_types:
if hasattr(result_type, "target") and result_type.target:
qubits = result_type.target
if not isinstance(qubits, QubitSet):
qubits = QubitSet(qubits)

Check warning on line 246 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L246

Added line #L246 was not covered by tests
y_positions = [-circuit_qubits.index(q) * cls.WIRE_SPACING for q in qubits]
y_min = min(y_positions)
y_max = max(y_positions)
else:
# For result types without targets, draw at the top wire
y_min = -0.5 * cls.WIRE_SPACING
y_max = y_min + cls.GATE_BOX_HEIGHT
# Draw the result type box
box = patches.Rectangle(
(x_pos - cls.GATE_BOX_WIDTH / 2, y_min - cls.GATE_BOX_HEIGHT / 2),
cls.GATE_BOX_WIDTH,
y_max - y_min + cls.GATE_BOX_HEIGHT,
facecolor=cls.GATE_COLOR if isinstance(cls.GATE_COLOR, str) else "#E0E0E0",
edgecolor="black",
alpha=0.7,
)
ax.add_patch(box)
# Draw the result type name
text = cls._get_instruction_text(result_type)
ax.text(
x_pos,
(y_min + y_max) / 2,
text,
ha="center",
va="center",
color=cls.TEXT_COLOR,
fontsize=10,
bbox={"facecolor": "white", "alpha": 0.7, "edgecolor": "none", "pad": 1},
)

@classmethod
def _get_instruction_text(cls, instruction: Union[Instruction, ResultType]) -> str:
"""Get the text representation of an instruction.

Args:
instruction (Union[Instruction, ResultType]): The instruction to get text for

Returns:
str: Text representation of the instruction
"""
if hasattr(instruction, "operator"):
operator = instruction.operator
# Compiler directive
if operator.__class__.__name__ == "CompilerDirective":
return str(operator.ascii_symbols[0])

Check warning on line 291 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L291

Added line #L291 was not covered by tests
# Measurement
if operator.__class__.__name__ == "Measure":
return "Measure"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for text diagrams we just use M here - could do the same for consistency

# Noise operations
noise_names = [
"Depolarizing",
"BitFlip",
"PhaseFlip",
"AmplitudeDamping",
"GeneralizedAmplitudeDamping",
"PhaseDamping",
"Kraus",
]
if hasattr(operator, "name") and operator.name in noise_names:
return operator.name.capitalize()
# Parameterized gate
if hasattr(operator, "angle") and hasattr(operator, "name") and operator.name:
angle = getattr(operator, "angle", None)
if angle is not None:
# Handle FreeParameter objects and expressions
if hasattr(angle, "name"):
angle_str = angle.name
elif hasattr(angle, "__str__"):
angle_str = str(angle)
else:
# Format numeric angles
angle_str = f"{angle:.2f}"

Check warning on line 318 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L318

Added line #L318 was not covered by tests
if abs(angle - 0.5) < 1e-9:
angle_str = "0.50π"

Check warning on line 320 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L320

Added line #L320 was not covered by tests
return f"{operator.name.capitalize()}({angle_str})"
# Standard gate
if hasattr(operator, "name") and operator.name:
return operator.name.upper()
# Fallback to ascii_symbols
if hasattr(operator, "ascii_symbols"):
return str(operator.ascii_symbols[0]).upper()
return str(operator)

Check warning on line 328 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L328

Added line #L328 was not covered by tests
# Handle ResultType and ObservableResultType
if isinstance(instruction, ObservableResultType):
# Use the ascii_symbols, which are like 'Expectation(X)', 'Sample(Y)', etc.
return str(instruction.ascii_symbols[0])
if isinstance(instruction, ResultType):
# For other result types, use their ascii_symbols (e.g., 'Probability', 'StateVector', etc.)
return str(instruction.ascii_symbols[0])
if hasattr(instruction, "ascii_symbols"):
return str(instruction.ascii_symbols[0])
return str(instruction)

Check warning on line 338 in src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py

View check run for this annotation

Codecov / codecov/patch

src/braket/circuits/text_diagram_builders/matplotlib_circuit_diagram.py#L337-L338

Added lines #L337 - L338 were not covered by tests
Loading
Loading