Skip to content

Commit

Permalink
Add minimal support for standalone Var to visualisers (#12307)
Browse files Browse the repository at this point in the history
* Add minimal support for standalone `Var` to visualisers

This adds best-effort only support to the visualisers for handling
stand-alone `Var` nodes.  Most of the changes are actually in `qasm3`,
since the visualisers use internal details of that to handle the nodes.

This commit decouples the visualisers _slightly_ more from the inner
workings of the OQ3 exporter by having them manage their own
variable-naming contexts and using the encapsulated `_ExprBuilder`,
rather than poking into random internals of the full circuit exporter.
This is necessary to allow the OQ3 exporter to expand to support these
variables itself, and also for the visualisers, since variables may now
be introduced in inner scopes.

This commit does not attempt to solve many of the known problems around
zero-operand "gates", of which `Store` is one, just leaving it un-drawn.
Printing to OpenQASM 3 is possibly a better visualisation strategy for
large dynamic circuits for the time being.

* Fix typos

Co-authored-by: Matthew Treinish <[email protected]>

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
jakelishman and mtreinish authored May 1, 2024
1 parent 6b73b58 commit a78c941
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 42 deletions.
11 changes: 11 additions & 0 deletions qiskit/qasm3/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,24 @@ class FloatType(ClassicalType, enum.Enum):
OCT = 256


class BoolType(ClassicalType):
"""Type information for a Boolean."""


class IntType(ClassicalType):
"""Type information for a signed integer."""

def __init__(self, size: Optional[int] = None):
self.size = size


class UintType(ClassicalType):
"""Type information for an unsigned integer."""

def __init__(self, size: Optional[int] = None):
self.size = size


class BitType(ClassicalType):
"""Type information for a single bit."""

Expand Down
20 changes: 11 additions & 9 deletions qiskit/qasm3/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,14 @@ def _lift_condition(condition):
return expr.lift_legacy_condition(condition)


def _build_ast_type(type_: types.Type) -> ast.ClassicalType:
if type_.kind is types.Bool:
return ast.BoolType()
if type_.kind is types.Uint:
return ast.UintType(type_.width)
raise RuntimeError(f"unhandled expr type '{type_}'") # pragma: no cover


class _ExprBuilder(expr.ExprVisitor[ast.Expression]):
__slots__ = ("lookup",)

Expand All @@ -1069,7 +1077,7 @@ def __init__(self, lookup):
self.lookup = lookup

def visit_var(self, node, /):
return self.lookup(node.var)
return self.lookup(node) if node.standalone else self.lookup(node.var)

def visit_value(self, node, /):
if node.type.kind is types.Bool:
Expand All @@ -1080,14 +1088,8 @@ def visit_value(self, node, /):

def visit_cast(self, node, /):
if node.implicit:
return node.accept(self)
if node.type.kind is types.Bool:
oq3_type = ast.BoolType()
elif node.type.kind is types.Uint:
oq3_type = ast.BitArrayType(node.type.width)
else:
raise RuntimeError(f"unhandled cast type '{node.type}'")
return ast.Cast(oq3_type, node.operand.accept(self))
return node.operand.accept(self)
return ast.Cast(_build_ast_type(node.type), node.operand.accept(self))

def visit_unary(self, node, /):
return ast.Unary(ast.Unary.Op[node.op.name], node.operand.accept(self))
Expand Down
8 changes: 8 additions & 0 deletions qiskit/qasm3/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,19 @@ def _visit_CalibrationGrammarDeclaration(self, node: ast.CalibrationGrammarDecla
def _visit_FloatType(self, node: ast.FloatType) -> None:
self.stream.write(f"float[{self._FLOAT_WIDTH_LOOKUP[node]}]")

def _visit_BoolType(self, _node: ast.BoolType) -> None:
self.stream.write("bool")

def _visit_IntType(self, node: ast.IntType) -> None:
self.stream.write("int")
if node.size is not None:
self.stream.write(f"[{node.size}]")

def _visit_UintType(self, node: ast.UintType) -> None:
self.stream.write("uint")
if node.size is not None:
self.stream.write(f"[{node.size}]")

def _visit_BitType(self, _node: ast.BitType) -> None:
self.stream.write("bit")

Expand Down
56 changes: 43 additions & 13 deletions qiskit/visualization/circuit/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
IfElseOp,
ForLoopOp,
SwitchCaseOp,
CircuitError,
)
from qiskit.circuit.controlflow import condition_resources
from qiskit.circuit.classical import expr
Expand All @@ -46,7 +47,8 @@
XGate,
ZGate,
)
from qiskit.qasm3.exporter import QASM3Builder
from qiskit.qasm3 import ast
from qiskit.qasm3.exporter import _ExprBuilder
from qiskit.qasm3.printer import BasicPrinter

from qiskit.circuit.tools.pi_check import pi_check
Expand Down Expand Up @@ -393,7 +395,7 @@ def draw(self, filename=None, verbose=False):
matplotlib_close_if_inline(mpl_figure)
return mpl_figure

def _get_layer_widths(self, node_data, wire_map, outer_circuit, glob_data, builder=None):
def _get_layer_widths(self, node_data, wire_map, outer_circuit, glob_data):
"""Compute the layer_widths for the layers"""

layer_widths = {}
Expand Down Expand Up @@ -482,18 +484,41 @@ def _get_layer_widths(self, node_data, wire_map, outer_circuit, glob_data, build
if (isinstance(op, SwitchCaseOp) and isinstance(op.target, expr.Expr)) or (
getattr(op, "condition", None) and isinstance(op.condition, expr.Expr)
):
condition = op.target if isinstance(op, SwitchCaseOp) else op.condition
if builder is None:
builder = QASM3Builder(
outer_circuit,
includeslist=("stdgates.inc",),
basis_gates=("U",),
disable_constants=False,
allow_aliasing=False,

def lookup_var(var):
"""Look up a classical-expression variable or register/bit in our
internal symbol table, and return an OQ3-like identifier."""
# We don't attempt to disambiguate anything like register/var naming
# collisions; we already don't really show classical variables.
if isinstance(var, expr.Var):
return ast.Identifier(var.name)
if isinstance(var, ClassicalRegister):
return ast.Identifier(var.name)
# Single clbit. This is not actually the correct way to lookup a bit on
# the circuit (it doesn't handle bit bindings fully), but the mpl
# drawer doesn't completely track inner-outer _bit_ bindings, only
# inner-indices, so we can't fully recover the information losslessly.
# Since most control-flow uses the control-flow builders, we should
# decay to something usable most of the time.
try:
register, bit_index, reg_index = get_bit_reg_index(
outer_circuit, var
)
except CircuitError:
# We failed to find the bit due to binding problems - fall back to
# something that's probably wrong, but at least disambiguating.
return ast.Identifier(f"bit{wire_map[var]}")
if register is None:
return ast.Identifier(f"bit{bit_index}")
return ast.SubscriptedIdentifier(
register.name, ast.IntegerLiteral(reg_index)
)
builder.build_classical_declarations()

condition = op.target if isinstance(op, SwitchCaseOp) else op.condition
stream = StringIO()
BasicPrinter(stream, indent=" ").visit(builder.build_expression(condition))
BasicPrinter(stream, indent=" ").visit(
condition.accept(_ExprBuilder(lookup_var))
)
expr_text = stream.getvalue()
# Truncate expr_text so that first gate is no more than about 3 x_index's over
if len(expr_text) > self._expr_len:
Expand Down Expand Up @@ -570,7 +595,7 @@ def _get_layer_widths(self, node_data, wire_map, outer_circuit, glob_data, build

# Recursively call _get_layer_widths for the circuit inside the ControlFlowOp
flow_widths = flow_drawer._get_layer_widths(
node_data, flow_wire_map, outer_circuit, glob_data, builder
node_data, flow_wire_map, outer_circuit, glob_data
)
layer_widths.update(flow_widths)

Expand Down Expand Up @@ -1243,6 +1268,11 @@ def _condition(self, node, node_data, wire_map, outer_circuit, cond_xy, glob_dat
self._ax.add_patch(box)
xy_plot.append(xy)

if not xy_plot:
# Expression that's only on new-style `expr.Var` nodes, and doesn't need any vertical
# line drawing.
return

qubit_b = min(node_data[node].q_xy, key=lambda xy: xy[1])
clbit_b = min(xy_plot, key=lambda xy: xy[1])

Expand Down
56 changes: 38 additions & 18 deletions qiskit/visualization/circuit/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@
import collections
import sys

from qiskit.circuit import Qubit, Clbit, ClassicalRegister
from qiskit.circuit import Qubit, Clbit, ClassicalRegister, CircuitError
from qiskit.circuit import ControlledGate, Reset, Measure
from qiskit.circuit import ControlFlowOp, WhileLoopOp, IfElseOp, ForLoopOp, SwitchCaseOp
from qiskit.circuit.classical import expr
from qiskit.circuit.controlflow import node_resources
from qiskit.circuit.library.standard_gates import IGate, RZZGate, SwapGate, SXGate, SXdgGate
from qiskit.circuit.annotated_operation import _canonicalize_modifiers, ControlModifier
from qiskit.circuit.tools.pi_check import pi_check
from qiskit.qasm3.exporter import QASM3Builder
from qiskit.qasm3 import ast
from qiskit.qasm3.printer import BasicPrinter
from qiskit.qasm3.exporter import _ExprBuilder

from ._utils import (
get_gate_ctrl_text,
Expand Down Expand Up @@ -748,7 +749,6 @@ def __init__(

self._nest_depth = 0 # nesting depth for control flow ops
self._expr_text = "" # expression text to display
self._builder = None # QASM3Builder class instance for expressions

# Because jupyter calls both __repr__ and __repr_html__ for some backends,
# the entire drawer can be run twice which can result in different output
Expand Down Expand Up @@ -1306,25 +1306,44 @@ def add_control_flow(self, node, layers, wire_map):
if (isinstance(node.op, SwitchCaseOp) and isinstance(node.op.target, expr.Expr)) or (
getattr(node.op, "condition", None) and isinstance(node.op.condition, expr.Expr)
):

def lookup_var(var):
"""Look up a classical-expression variable or register/bit in our internal symbol
table, and return an OQ3-like identifier."""
# We don't attempt to disambiguate anything like register/var naming collisions; we
# already don't really show classical variables.
if isinstance(var, expr.Var):
return ast.Identifier(var.name)
if isinstance(var, ClassicalRegister):
return ast.Identifier(var.name)
# Single clbit. This is not actually the correct way to lookup a bit on the
# circuit (it doesn't handle bit bindings fully), but the text drawer doesn't
# completely track inner-outer _bit_ bindings, only inner-indices, so we can't fully
# recover the information losslessly. Since most control-flow uses the control-flow
# builders, we should decay to something usable most of the time.
try:
register, bit_index, reg_index = get_bit_reg_index(self._circuit, var)
except CircuitError:
# We failed to find the bit due to binding problems - fall back to something
# that's probably wrong, but at least disambiguating.
return ast.Identifier(f"_bit{wire_map[var]}")
if register is None:
return ast.Identifier(f"_bit{bit_index}")
return ast.SubscriptedIdentifier(register.name, ast.IntegerLiteral(reg_index))

condition = node.op.target if isinstance(node.op, SwitchCaseOp) else node.op.condition
if self._builder is None:
self._builder = QASM3Builder(
self._circuit,
includeslist=("stdgates.inc",),
basis_gates=("U",),
disable_constants=False,
allow_aliasing=False,
)
self._builder.build_classical_declarations()
draw_conditional = bool(node_resources(condition).clbits)
stream = StringIO()
BasicPrinter(stream, indent=" ").visit(self._builder.build_expression(condition))
BasicPrinter(stream, indent=" ").visit(condition.accept(_ExprBuilder(lookup_var)))
self._expr_text = stream.getvalue()
# Truncate expr_text at 30 chars or user-set expr_len
if len(self._expr_text) > self.expr_len:
self._expr_text = self._expr_text[: self.expr_len] + "..."
else:
draw_conditional = not isinstance(node.op, ForLoopOp)

# # Draw a left box such as If, While, For, and Switch
flow_layer = self.draw_flow_box(node, wire_map, CF_LEFT)
flow_layer = self.draw_flow_box(node, wire_map, CF_LEFT, conditional=draw_conditional)
layers.append(flow_layer.full_layer)

# Get the list of circuits in the ControlFlowOp from the node blocks
Expand All @@ -1351,7 +1370,9 @@ def add_control_flow(self, node, layers, wire_map):

if circ_num > 0:
# Draw a middle box such as Else and Case
flow_layer = self.draw_flow_box(node, flow_wire_map, CF_MID, circ_num - 1)
flow_layer = self.draw_flow_box(
node, flow_wire_map, CF_MID, circ_num - 1, conditional=False
)
layers.append(flow_layer.full_layer)

_, _, nodes = _get_layered_instructions(circuit, wire_map=flow_wire_map)
Expand Down Expand Up @@ -1380,14 +1401,13 @@ def add_control_flow(self, node, layers, wire_map):
layers.append(flow_layer2.full_layer)

# Draw the right box for End
flow_layer = self.draw_flow_box(node, flow_wire_map, CF_RIGHT)
flow_layer = self.draw_flow_box(node, flow_wire_map, CF_RIGHT, conditional=False)
layers.append(flow_layer.full_layer)

def draw_flow_box(self, node, flow_wire_map, section, circ_num=0):
def draw_flow_box(self, node, flow_wire_map, section, circ_num=0, conditional=False):
"""Draw the left, middle, or right of a control flow box"""

op = node.op
conditional = section == CF_LEFT and not isinstance(op, ForLoopOp)
depth = str(self._nest_depth)
if section == CF_LEFT:
etext = ""
Expand Down
Loading

0 comments on commit a78c941

Please sign in to comment.