diff --git a/bqskit/ir/gates/__init__.py b/bqskit/ir/gates/__init__.py index b6a2c7956..56c963f5a 100644 --- a/bqskit/ir/gates/__init__.py +++ b/bqskit/ir/gates/__init__.py @@ -112,6 +112,7 @@ CircuitGate MeasurementPlaceholder + Reset BarrierPlaceholder .. rubric:: Gate Base Classes @@ -141,6 +142,7 @@ from bqskit.ir.gates.constantgate import ConstantGate from bqskit.ir.gates.generalgate import GeneralGate from bqskit.ir.gates.measure import MeasurementPlaceholder +from bqskit.ir.gates.reset import Reset from bqskit.ir.gates.parameterized import * # noqa from bqskit.ir.gates.parameterized import __all__ as parameterized_all from bqskit.ir.gates.qubitgate import QubitGate @@ -150,7 +152,7 @@ __all__ = composed_all + constant_all + parameterized_all __all__ += ['ComposedGate', 'ConstantGate'] __all__ += ['QubitGate', 'QutritGate', 'QuditGate'] -__all__ += ['CircuitGate', 'MeasurementPlaceholder', 'BarrierPlaceholder'] +__all__ += ['CircuitGate', 'MeasurementPlaceholder', 'Reset', 'BarrierPlaceholder'] __all__ += ['GeneralGate'] # TODO: Implement the rest of the gates in: diff --git a/bqskit/ir/gates/reset.py b/bqskit/ir/gates/reset.py new file mode 100644 index 000000000..e72049003 --- /dev/null +++ b/bqskit/ir/gates/reset.py @@ -0,0 +1,28 @@ +"""This module implements the Reset class.""" +from __future__ import annotations + +from bqskit.ir.gates.constantgate import ConstantGate +from bqskit.qis.unitary.unitary import RealVector +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix + +class Reset(ConstantGate): + """Pseudogate to initialize the qudit to |0>.""" + + def __init__(self, radix: int = 2) -> None: + """ + Construct a Reset. + + Args: + radix (int): the dimension of the qudit. (Default: 2) + """ + self._num_qudits = 1 + self._qasm_name = 'reset' + self._radixes = tuple([radix]) + self._num_params = 0 + + def get_unitary(self, params: RealVector = []) -> UnitaryMatrix: + raise RuntimeError( + 'Cannot compute unitary for a reset.', + ) + + diff --git a/bqskit/ir/lang/qasm2/visitor.py b/bqskit/ir/lang/qasm2/visitor.py index 4f7c289ba..23a51c083 100644 --- a/bqskit/ir/lang/qasm2/visitor.py +++ b/bqskit/ir/lang/qasm2/visitor.py @@ -50,6 +50,7 @@ from bqskit.ir.gates.constant.z import ZGate from bqskit.ir.gates.constant.zz import ZZGate from bqskit.ir.gates.measure import MeasurementPlaceholder +from bqskit.ir.gates.reset import Reset from bqskit.ir.gates.parameterized.ccp import CCPGate from bqskit.ir.gates.parameterized.cp import CPGate from bqskit.ir.gates.parameterized.crx import CRXGate @@ -180,12 +181,6 @@ def get_circuit(self) -> Circuit: circuit = Circuit(num_qubits) circuit.extend(self.op_list) - # Add measurements - if len(self.measurements) > 0: - cregs = cast(List[Tuple[str, int]], self.classical_regs) - mph = MeasurementPlaceholder(cregs, self.measurements) - circuit.append_gate(mph, list(self.measurements.keys())) - return circuit def fill_gate_defs(self) -> None: @@ -261,6 +256,9 @@ def fill_gate_defs(self) -> None: self.gate_defs['rccx'] = GateDef('rccx', 0, 3, RCCXGate()) self.gate_defs['rc3x'] = GateDef('rc3x', 0, 4, RC3XGate()) + # reset + self.gate_defs['reset'] = GateDef('reset', 0, 1, Reset()) + def qreg(self, tree: lark.Tree) -> None: """Qubit register node visitor.""" reg_name = tree.children[0] @@ -297,13 +295,6 @@ def gate(self, tree: lark.Tree) -> None: qlist = tree.children[-1] location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) - if any(q in self.measurements for q in location): - raise LangException( - 'BQSKit currently does not support mid-circuit measurements.' - ' Unable to apply a gate on the same qubit where a measurement' - ' has been previously made.', - ) - # Parse gate object gate_name = str(tree.children[0]) if gate_name in self.gate_defs: @@ -595,6 +586,7 @@ def measure(self, tree: lark.Tree) -> None: class_childs = tree.children[1].children qubit_reg_name = str(qubit_childs[0]) class_reg_name = str(class_childs[0]) + cregs = cast(List[Tuple[str, int]], self.classical_regs) if not any(r.name == qubit_reg_name for r in self.qubit_regs): raise LangException( f'Measuring undefined qubit register: {qubit_reg_name}', @@ -605,7 +597,7 @@ def measure(self, tree: lark.Tree) -> None: f'Measuring undefined classical register: {class_reg_name}', ) - if len(qubit_childs) == 1 and len(class_childs) == 1: + if len(qubit_childs) == 1 and len(class_childs) == 1: # for measure all for name, size in self.qubit_regs: if qubit_reg_name == name: qubit_size = size @@ -628,6 +620,9 @@ def measure(self, tree: lark.Tree) -> None: for i in range(qubit_size): self.measurements[outer_idx + i] = (class_reg_name, i) + mph = MeasurementPlaceholder(cregs, self.measurements) + self.gate_defs['measure'] = GateDef('measure', 0, qubit_size, mph) + elif len(qubit_childs) == 2 and len(class_childs) == 2: qubit_index = int(qubit_childs[1]) @@ -642,6 +637,9 @@ def measure(self, tree: lark.Tree) -> None: outer_idx += size self.measurements[qubit_index] = (class_reg_name, class_index) + mph = MeasurementPlaceholder(cregs, self.measurements) + self.gate_defs['measure'] = GateDef('measure', 0, 1, mph) + else: raise LangException( @@ -650,9 +648,75 @@ def measure(self, tree: lark.Tree) -> None: 'measured to a single classical bit.', ) + params: list[float] = [] + qlist = tree.children[0] + location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) + + # Parse gate object + gate_name = tree.data + if gate_name in self.gate_defs: + gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name] + elif gate_name in self.custom_gate_defs: + gate_def = self.custom_gate_defs[gate_name] + else: + raise LangException('Unrecognized gate: %s.' % gate_name) + + if len(params) != gate_def.num_params: + raise LangException( + 'Expected %d params got %d params for gate %s.' + % (gate_def.num_params, len(params), gate_name), + ) + + if len(location) != gate_def.num_vars: + raise LangException( + 'Gate acts on %d qubits, got %d qubit variables.' + % (gate_def.num_vars, len(location)), + ) + + # Build operation and add to circuit + self.op_list.append(gate_def.build_op(location, params)) + def reset(self, tree: lark.Tree) -> None: - """Reset statement node visitor.""" - raise LangException('BQSKit currently does not support resets.') + """reset node visitor.""" + params: list[float] = [] + qlist = tree.children[-1] + gate_name = tree.data + if gate_name in self.gate_defs: + gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name] + elif gate_name in self.custom_gate_defs: + gate_def = self.custom_gate_defs[gate_name] + else: + raise LangException('Unrecognized gate: %s.' % gate_name) + + if len(params) != gate_def.num_params: + raise LangException( + 'Expected %d params got %d params for gate %s.' + % (gate_def.num_params, len(params), gate_name), + ) + + if len(qlist.children) == 2: + location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) + # Parse gate object + if len(location) != gate_def.num_vars: + raise LangException( + 'Gate acts on %d qubits, got %d qubit variables.' + % (gate_def.num_vars, len(location)), + ) + + # Build operation and add to circuit + self.op_list.append(gate_def.build_op(location, params)) + else: + locations = [CircuitLocation(i) for i in range(self.qubit_regs[0][1])] + for location in locations: + if len(location) != gate_def.num_vars: + raise LangException( + 'Gate acts on %d qubits, got %d qubit variables.' + % (gate_def.num_vars, len(location)), + ) + + # Build operation and add to circuit + self.op_list.append(gate_def.build_op(location, params)) + def convert_qubit_ids_to_indices(self, qlist: lark.Tree) -> list[int]: if qlist.data == 'anylist': diff --git a/bqskit/passes/partitioning/quick.py b/bqskit/passes/partitioning/quick.py index 4c9d2e0e4..d010c6def 100644 --- a/bqskit/passes/partitioning/quick.py +++ b/bqskit/passes/partitioning/quick.py @@ -377,4 +377,4 @@ def __init__( # Close the bin for q in location: - self.active_qudits.remove(q) + self.active_qudits.remove(q) \ No newline at end of file diff --git a/bqskit/passes/partitioning/single.py b/bqskit/passes/partitioning/single.py index ce9a431e2..814df3dc4 100644 --- a/bqskit/passes/partitioning/single.py +++ b/bqskit/passes/partitioning/single.py @@ -6,7 +6,8 @@ from bqskit.ir.circuit import Circuit from bqskit.ir.gates.barrier import BarrierPlaceholder from bqskit.ir.region import CircuitRegion - +from bqskit.ir.gates import MeasurementPlaceholder +from bqskit.ir.gates import Reset class GroupSingleQuditGatePass(BasePass): """ @@ -31,7 +32,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None: op = circuit[c, q] if ( op.num_qudits == 1 - and not isinstance(op.gate, BarrierPlaceholder) + and not isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset)) ): if region_start is None: region_start = c diff --git a/tests/ir/lang/test_qasm_decode.py b/tests/ir/lang/test_qasm_decode.py index c912705b6..7227ea1bc 100644 --- a/tests/ir/lang/test_qasm_decode.py +++ b/tests/ir/lang/test_qasm_decode.py @@ -11,6 +11,7 @@ from bqskit.ir.gates.circuitgate import CircuitGate from bqskit.ir.gates.constant.cx import CNOTGate from bqskit.ir.gates.measure import MeasurementPlaceholder +from bqskit.ir.gates.reset import Reset from bqskit.ir.gates.parameterized.u1 import U1Gate from bqskit.ir.gates.parameterized.u1q import U1qGate from bqskit.ir.gates.parameterized.u2 import U2Gate @@ -296,6 +297,29 @@ def test_include_simple(self) -> None: assert circuit.get_unitary().get_distance_from(gate_unitary) < 1e-7 +class TestReset: + def test_reset_single_qubit(self) -> None: + input = """ + OPENQASM 2.0; + qreg q[1]; + reset q[0]; + """ + circuit = OPENQASM2Language().decode(input) + expected = Reset() + assert circuit[0, 0].gate == expected + + def test_reset_register(self) -> None: + input = """ + OPENQASM 2.0; + qreg q[2]; + reset q; + """ + circuit = OPENQASM2Language().decode(input) + expected = Reset() + assert circuit[0, 0].gate == expected + assert circuit[0, 1].gate == expected + + class TestMeasure: def test_measure_single_bit(self) -> None: input = """ diff --git a/tests/ir/lang/test_qasm_encode.py b/tests/ir/lang/test_qasm_encode.py index 182540f39..3ba60775d 100644 --- a/tests/ir/lang/test_qasm_encode.py +++ b/tests/ir/lang/test_qasm_encode.py @@ -2,6 +2,7 @@ from bqskit.ir.circuit import Circuit from bqskit.ir.gates import CNOTGate +from bqskit.ir.gates import Reset from bqskit.ir.gates import U3Gate from bqskit.ir.lang.qasm2 import OPENQASM2Language @@ -41,3 +42,17 @@ def test_nested_circuitgate(self) -> None: qasm = OPENQASM2Language().encode(circuit) parsed_circuit = OPENQASM2Language().decode(qasm) assert parsed_circuit.get_unitary().get_distance_from(in_utry) < 1e-7 + + def test_reset(self) -> None: + circuit = Circuit(1) + circuit.append_gate(Reset(), 0) + + qasm = OPENQASM2Language().encode(circuit) + expected = ( + 'OPENQASM 2.0;\n' + 'include "qelib1.inc";\n' + 'qreg q[1];\n' + 'reset q[0];\n' + ) + assert qasm == expected +