Skip to content

Commit

Permalink
Add Expr support to QuantumCircuit.compose
Browse files Browse the repository at this point in the history
This relatively straightforwardly generalises the mapping of variables
that already exists in `QuantumCircuit.compose` to be able to handle
arbitrary `Expr` nodes as well.  We must take care not to accidentally
mutate the `Expr` nodes in the input circuit in the name of efficiency.
  • Loading branch information
jakelishman committed Jul 3, 2023
1 parent 0abfb1f commit 3d77997
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 32 deletions.
113 changes: 81 additions & 32 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,43 +959,16 @@ def compose(
)
edge_map.update(zip(other.clbits, dest.cbit_argument_conversion(clbits)))

# Cache for `map_register_to_dest`.
_map_register_cache = {}

def map_register_to_dest(theirs):
"""Map the target's registers to suitable equivalents in the destination, adding an
extra one if there's no exact match."""
if theirs.name in _map_register_cache:
return _map_register_cache[theirs.name]
mapped_bits = [edge_map[bit] for bit in theirs]
for ours in dest.cregs:
if mapped_bits == list(ours):
mapped_theirs = ours
break
else:
mapped_theirs = ClassicalRegister(bits=mapped_bits)
dest.add_register(mapped_theirs)
_map_register_cache[theirs.name] = mapped_theirs
return mapped_theirs

variable_mapper = _ComposeVariableMapper(dest, edge_map)
mapped_instrs: list[CircuitInstruction] = []
for instr in other.data:
n_qargs: list[Qubit] = [edge_map[qarg] for qarg in instr.qubits]
n_cargs: list[Clbit] = [edge_map[carg] for carg in instr.clbits]
n_op = instr.operation.copy()

if getattr(n_op, "condition", None) is not None:
target, value = n_op.condition
if isinstance(target, Clbit):
n_op.condition = (edge_map[target], value)
else:
n_op.condition = (map_register_to_dest(target), value)
elif isinstance(n_op, SwitchCaseOp):
if isinstance(n_op.target, Clbit):
n_op.target = edge_map[n_op.target]
else:
n_op.target = map_register_to_dest(n_op.target)

if (condition := getattr(n_op, "condition", None)) is not None:
n_op.condition = variable_mapper.map_condition(condition)
if isinstance(n_op, SwitchCaseOp):
n_op.target = variable_mapper.map_target(n_op.target)
mapped_instrs.append(CircuitInstruction(n_op, n_qargs, n_cargs))

if front:
Expand Down Expand Up @@ -5223,3 +5196,79 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_):
else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'"
)
raise CircuitError(message)


class _ComposeVariableMapper(expr.ExprVisitor[expr.Expr]):
"""Stateful helper class that manages the mapping of variables in conditions and expressions to
items in the destination ``circuit``.
This mutates ``circuit`` by adding registers as required."""

__slots__ = ("circuit", "register_map", "bit_map")

def __init__(self, circuit, bit_map):
self.circuit = circuit
self.register_map = {}
self.bit_map = bit_map

def _map_register(self, theirs):
"""Map the target's registers to suitable equivalents in the destination, adding an
extra one if there's no exact match."""
if (mapped_theirs := self.register_map.get(theirs.name)) is not None:
return mapped_theirs
mapped_bits = [self.bit_map[bit] for bit in theirs]
for ours in self.circuit.cregs:
if mapped_bits == list(ours):
mapped_theirs = ours
break
else:
mapped_theirs = ClassicalRegister(bits=mapped_bits)
self.circuit.add_register(mapped_theirs)
self.register_map[theirs.name] = mapped_theirs
return mapped_theirs

def map_condition(self, condition, /):
"""Map the given ``condition`` so that it only references variables in the destination
circuit (as given to this class on initialisation)."""
if condition is None:
return None
if isinstance(condition, expr.Expr):
return self.map_expr(condition)
target, value = condition
if isinstance(target, Clbit):
return (self.bit_map[target], value)
return (self._map_register(target), value)

def map_target(self, target, /):
"""Map the runtime variables in a ``target`` of a :class:`.SwitchCaseOp` to the new circuit,
as defined in the ``circuit`` argument of the initialiser of this class."""
if isinstance(target, Clbit):
return self.bit_map[target]
if isinstance(target, ClassicalRegister):
return self._map_register(target)
return self.map_expr(target)

def map_expr(self, node: expr.Expr, /) -> expr.Expr:
"""Map the variables in an :class:`~.expr.Expr` node to the new circuit."""
return node.accept(self)

def visit_var(self, node, /):
if isinstance(node.var, Clbit):
return expr.Var(self.bit_map[node.var], node.type)
if isinstance(node.var, ClassicalRegister):
return expr.Var(self._map_register(node.var), node.type)
# Defensive against the expansion of the variable system; we don't want to silently do the
# wrong thing (which would be `return node` without mapping, right now).
raise CircuitError(f"unhandled variable in 'compose': {node}") # pragma: no cover

def visit_value(self, node, /):
return expr.Value(node.value, node.type)

def visit_unary(self, node, /):
return expr.Unary(node.op, node.operand.accept(self), node.type)

def visit_binary(self, node, /):
return expr.Binary(node.op, node.left.accept(self), node.right.accept(self), node.type)

def visit_cast(self, node, /):
return expr.Cast(node.operand.accept(self), node.type, implicit=node.implicit)
91 changes: 91 additions & 0 deletions test/python/circuit/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SwitchCaseOp,
)
from qiskit.circuit.library import HGate, RZGate, CXGate, CCXGate, TwoLocal
from qiskit.circuit.classical import expr
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -789,6 +790,96 @@ def test_compose_noclbits_registerless(self):
self.assertEqual(outer.clbits, inner.clbits)
self.assertEqual(outer.cregs, [])

def test_expr_condition_is_mapped(self):
"""Test that an expression in a condition involving several registers is mapped correctly to
the destination circuit."""
inner = QuantumCircuit(1)
inner.x(0)
a_src = ClassicalRegister(2, "a_src")
b_src = ClassicalRegister(2, "b_src")
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)

test_1 = lambda: expr.lift(a_src[0])
test_2 = lambda: expr.logic_not(b_src[1])
test_3 = lambda: expr.logic_and(expr.bit_and(b_src, 2), expr.less(c_src, 7))
source.if_test(test_1(), inner.copy(), [0], [])
source.if_else(test_2(), inner.copy(), inner.copy(), [0], [])
source.while_loop(test_3(), inner.copy(), [0], [])

a_dest = ClassicalRegister(2, "a_dest")
b_dest = ClassicalRegister(2, "b_dest")
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)

# Check that the input conditions weren't mutated.
for in_condition, instruction in zip((test_1, test_2, test_3), source.data):
self.assertEqual(in_condition(), instruction.operation.condition)

# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
self.assertEqual(len(dest.cregs), 3)
mapped_reg = dest.cregs[-1]

expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
expected.if_test(expr.lift(a_dest[0]), inner.copy(), [0], [])
expected.if_else(expr.logic_not(b_dest[1]), inner.copy(), inner.copy(), [0], [])
expected.while_loop(
expr.logic_and(expr.bit_and(b_dest, 2), expr.less(mapped_reg, 7)), inner.copy(), [0], []
)
self.assertEqual(dest, expected)

def test_expr_target_is_mapped(self):
"""Test that an expression in a switch statement's target is mapping correctly to the
destination circuit."""
inner1 = QuantumCircuit(1)
inner1.x(0)
inner2 = QuantumCircuit(1)
inner2.z(0)

a_src = ClassicalRegister(2, "a_src")
b_src = ClassicalRegister(2, "b_src")
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)

test_1 = lambda: expr.lift(a_src[0])
test_2 = lambda: expr.logic_not(b_src[1])
test_3 = lambda: expr.lift(b_src)
test_4 = lambda: expr.bit_and(c_src, 7)
source.switch(test_1(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
source.switch(test_2(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
source.switch(test_3(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])
source.switch(test_4(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])

a_dest = ClassicalRegister(2, "a_dest")
b_dest = ClassicalRegister(2, "b_dest")
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)

# Check that the input expressions weren't mutated.
for in_target, instruction in zip((test_1, test_2, test_3, test_4), source.data):
self.assertEqual(in_target(), instruction.operation.target)

# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
self.assertEqual(len(dest.cregs), 3)
mapped_reg = dest.cregs[-1]

expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
expected.switch(
expr.lift(a_dest[0]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
)
expected.switch(
expr.logic_not(b_dest[1]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
)
expected.switch(
expr.lift(b_dest), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], []
)
expected.switch(
expr.bit_and(mapped_reg, 7),
[(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())],
[0],
[],
)

self.assertEqual(dest, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3d77997

Please sign in to comment.