diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index e07d0eccd047..b6333071511a 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -959,43 +959,16 @@ def compose( ) edge_map.update(zip(other.clbits, dest.cbit_argument_conversion(clbits))) - # Cache for `map_register_to_dest`. - _map_register_cache = {} - - def map_register_to_dest(theirs): - """Map the target's registers to suitable equivalents in the destination, adding an - extra one if there's no exact match.""" - if theirs.name in _map_register_cache: - return _map_register_cache[theirs.name] - mapped_bits = [edge_map[bit] for bit in theirs] - for ours in dest.cregs: - if mapped_bits == list(ours): - mapped_theirs = ours - break - else: - mapped_theirs = ClassicalRegister(bits=mapped_bits) - dest.add_register(mapped_theirs) - _map_register_cache[theirs.name] = mapped_theirs - return mapped_theirs - + variable_mapper = _ComposeVariableMapper(dest, edge_map) mapped_instrs: list[CircuitInstruction] = [] for instr in other.data: n_qargs: list[Qubit] = [edge_map[qarg] for qarg in instr.qubits] n_cargs: list[Clbit] = [edge_map[carg] for carg in instr.clbits] n_op = instr.operation.copy() - - if getattr(n_op, "condition", None) is not None: - target, value = n_op.condition - if isinstance(target, Clbit): - n_op.condition = (edge_map[target], value) - else: - n_op.condition = (map_register_to_dest(target), value) - elif isinstance(n_op, SwitchCaseOp): - if isinstance(n_op.target, Clbit): - n_op.target = edge_map[n_op.target] - else: - n_op.target = map_register_to_dest(n_op.target) - + if (condition := getattr(n_op, "condition", None)) is not None: + n_op.condition = variable_mapper.map_condition(condition) + if isinstance(n_op, SwitchCaseOp): + n_op.target = variable_mapper.map_target(n_op.target) mapped_instrs.append(CircuitInstruction(n_op, n_qargs, n_cargs)) if front: @@ -5223,3 +5196,79 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_): else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'" ) raise CircuitError(message) + + +class _ComposeVariableMapper(expr.ExprVisitor[expr.Expr]): + """Stateful helper class that manages the mapping of variables in conditions and expressions to + items in the destination ``circuit``. + + This mutates ``circuit`` by adding registers as required.""" + + __slots__ = ("circuit", "register_map", "bit_map") + + def __init__(self, circuit, bit_map): + self.circuit = circuit + self.register_map = {} + self.bit_map = bit_map + + def _map_register(self, theirs): + """Map the target's registers to suitable equivalents in the destination, adding an + extra one if there's no exact match.""" + if (mapped_theirs := self.register_map.get(theirs.name)) is not None: + return mapped_theirs + mapped_bits = [self.bit_map[bit] for bit in theirs] + for ours in self.circuit.cregs: + if mapped_bits == list(ours): + mapped_theirs = ours + break + else: + mapped_theirs = ClassicalRegister(bits=mapped_bits) + self.circuit.add_register(mapped_theirs) + self.register_map[theirs.name] = mapped_theirs + return mapped_theirs + + def map_condition(self, condition, /): + """Map the given ``condition`` so that it only references variables in the destination + circuit (as given to this class on initialisation).""" + if condition is None: + return None + if isinstance(condition, expr.Expr): + return self.map_expr(condition) + target, value = condition + if isinstance(target, Clbit): + return (self.bit_map[target], value) + return (self._map_register(target), value) + + def map_target(self, target, /): + """Map the runtime variables in a ``target`` of a :class:`.SwitchCaseOp` to the new circuit, + as defined in the ``circuit`` argument of the initialiser of this class.""" + if isinstance(target, Clbit): + return self.bit_map[target] + if isinstance(target, ClassicalRegister): + return self._map_register(target) + return self.map_expr(target) + + def map_expr(self, node: expr.Expr, /) -> expr.Expr: + """Map the variables in an :class:`~.expr.Expr` node to the new circuit.""" + return node.accept(self) + + def visit_var(self, node, /): + if isinstance(node.var, Clbit): + return expr.Var(self.bit_map[node.var], node.type) + if isinstance(node.var, ClassicalRegister): + return expr.Var(self._map_register(node.var), node.type) + # Defensive against the expansion of the variable system; we don't want to silently do the + # wrong thing (which would be `return node` without mapping, right now). + raise CircuitError(f"unhandled variable in 'compose': {node}") # pragma: no cover + + def visit_value(self, node, /): + return expr.Value(node.value, node.type) + + def visit_unary(self, node, /): + return expr.Unary(node.op, node.operand.accept(self), node.type) + + def visit_binary(self, node, /): + return expr.Binary(node.op, node.left.accept(self), node.right.accept(self), node.type) + + def visit_cast(self, node, /): + return expr.Cast(node.operand.accept(self), node.type, implicit=node.implicit) diff --git a/test/python/circuit/test_compose.py b/test/python/circuit/test_compose.py index 0f77493361fb..70eaffb0f748 100644 --- a/test/python/circuit/test_compose.py +++ b/test/python/circuit/test_compose.py @@ -33,6 +33,7 @@ SwitchCaseOp, ) from qiskit.circuit.library import HGate, RZGate, CXGate, CCXGate, TwoLocal +from qiskit.circuit.classical import expr from qiskit.test import QiskitTestCase @@ -789,6 +790,96 @@ def test_compose_noclbits_registerless(self): self.assertEqual(outer.clbits, inner.clbits) self.assertEqual(outer.cregs, []) + def test_expr_condition_is_mapped(self): + """Test that an expression in a condition involving several registers is mapped correctly to + the destination circuit.""" + inner = QuantumCircuit(1) + inner.x(0) + a_src = ClassicalRegister(2, "a_src") + b_src = ClassicalRegister(2, "b_src") + c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src)) + source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src) + + test_1 = lambda: expr.lift(a_src[0]) + test_2 = lambda: expr.logic_not(b_src[1]) + test_3 = lambda: expr.logic_and(expr.bit_and(b_src, 2), expr.less(c_src, 7)) + source.if_test(test_1(), inner.copy(), [0], []) + source.if_else(test_2(), inner.copy(), inner.copy(), [0], []) + source.while_loop(test_3(), inner.copy(), [0], []) + + a_dest = ClassicalRegister(2, "a_dest") + b_dest = ClassicalRegister(2, "b_dest") + dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source) + + # Check that the input conditions weren't mutated. + for in_condition, instruction in zip((test_1, test_2, test_3), source.data): + self.assertEqual(in_condition(), instruction.operation.condition) + + # Should be `a_dest`, `b_dest` and an added one to account for `c_src`. + self.assertEqual(len(dest.cregs), 3) + mapped_reg = dest.cregs[-1] + + expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg) + expected.if_test(expr.lift(a_dest[0]), inner.copy(), [0], []) + expected.if_else(expr.logic_not(b_dest[1]), inner.copy(), inner.copy(), [0], []) + expected.while_loop( + expr.logic_and(expr.bit_and(b_dest, 2), expr.less(mapped_reg, 7)), inner.copy(), [0], [] + ) + self.assertEqual(dest, expected) + + def test_expr_target_is_mapped(self): + """Test that an expression in a switch statement's target is mapping correctly to the + destination circuit.""" + inner1 = QuantumCircuit(1) + inner1.x(0) + inner2 = QuantumCircuit(1) + inner2.z(0) + + a_src = ClassicalRegister(2, "a_src") + b_src = ClassicalRegister(2, "b_src") + c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src)) + source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src) + + test_1 = lambda: expr.lift(a_src[0]) + test_2 = lambda: expr.logic_not(b_src[1]) + test_3 = lambda: expr.lift(b_src) + test_4 = lambda: expr.bit_and(c_src, 7) + source.switch(test_1(), [(False, inner1.copy()), (True, inner2.copy())], [0], []) + source.switch(test_2(), [(False, inner1.copy()), (True, inner2.copy())], [0], []) + source.switch(test_3(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], []) + source.switch(test_4(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], []) + + a_dest = ClassicalRegister(2, "a_dest") + b_dest = ClassicalRegister(2, "b_dest") + dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source) + + # Check that the input expressions weren't mutated. + for in_target, instruction in zip((test_1, test_2, test_3, test_4), source.data): + self.assertEqual(in_target(), instruction.operation.target) + + # Should be `a_dest`, `b_dest` and an added one to account for `c_src`. + self.assertEqual(len(dest.cregs), 3) + mapped_reg = dest.cregs[-1] + + expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg) + expected.switch( + expr.lift(a_dest[0]), [(False, inner1.copy()), (True, inner2.copy())], [0], [] + ) + expected.switch( + expr.logic_not(b_dest[1]), [(False, inner1.copy()), (True, inner2.copy())], [0], [] + ) + expected.switch( + expr.lift(b_dest), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [] + ) + expected.switch( + expr.bit_and(mapped_reg, 7), + [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], + [0], + [], + ) + + self.assertEqual(dest, expected) + if __name__ == "__main__": unittest.main()