Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MMR partition #210

Closed
wants to merge 12 commits into from
4 changes: 3 additions & 1 deletion bqskit/ir/gates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@

CircuitGate
MeasurementPlaceholder
Reset
BarrierPlaceholder

.. rubric:: Gate Base Classes
Expand Down Expand Up @@ -141,6 +142,7 @@
from bqskit.ir.gates.constantgate import ConstantGate
from bqskit.ir.gates.generalgate import GeneralGate
from bqskit.ir.gates.measure import MeasurementPlaceholder
from bqskit.ir.gates.reset import Reset
from bqskit.ir.gates.parameterized import * # noqa
from bqskit.ir.gates.parameterized import __all__ as parameterized_all
from bqskit.ir.gates.qubitgate import QubitGate
Expand All @@ -150,7 +152,7 @@
__all__ = composed_all + constant_all + parameterized_all
__all__ += ['ComposedGate', 'ConstantGate']
__all__ += ['QubitGate', 'QutritGate', 'QuditGate']
__all__ += ['CircuitGate', 'MeasurementPlaceholder', 'BarrierPlaceholder']
__all__ += ['CircuitGate', 'MeasurementPlaceholder', 'Reset', 'BarrierPlaceholder']
__all__ += ['GeneralGate']

# TODO: Implement the rest of the gates in:
Expand Down
28 changes: 28 additions & 0 deletions bqskit/ir/gates/reset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""This module implements the Reset class."""
from __future__ import annotations

from bqskit.ir.gates.constantgate import ConstantGate
from bqskit.qis.unitary.unitary import RealVector
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix

class Reset(ConstantGate):
"""Pseudogate to initialize the qudit to |0>."""

def __init__(self, radix: int = 2) -> None:
"""
Construct a Reset.

Args:
radix (int): the dimension of the qudit. (Default: 2)
"""
self._num_qudits = 1
self._qasm_name = 'reset'
self._radixes = tuple([radix])
self._num_params = 0

def get_unitary(self, params: RealVector = []) -> UnitaryMatrix:
raise RuntimeError(
'Cannot compute unitary for a reset.',
)


96 changes: 80 additions & 16 deletions bqskit/ir/lang/qasm2/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from bqskit.ir.gates.constant.z import ZGate
from bqskit.ir.gates.constant.zz import ZZGate
from bqskit.ir.gates.measure import MeasurementPlaceholder
from bqskit.ir.gates.reset import Reset
from bqskit.ir.gates.parameterized.ccp import CCPGate
from bqskit.ir.gates.parameterized.cp import CPGate
from bqskit.ir.gates.parameterized.crx import CRXGate
Expand Down Expand Up @@ -180,12 +181,6 @@ def get_circuit(self) -> Circuit:
circuit = Circuit(num_qubits)
circuit.extend(self.op_list)

# Add measurements
if len(self.measurements) > 0:
cregs = cast(List[Tuple[str, int]], self.classical_regs)
mph = MeasurementPlaceholder(cregs, self.measurements)
circuit.append_gate(mph, list(self.measurements.keys()))

return circuit

def fill_gate_defs(self) -> None:
Expand Down Expand Up @@ -261,6 +256,9 @@ def fill_gate_defs(self) -> None:
self.gate_defs['rccx'] = GateDef('rccx', 0, 3, RCCXGate())
self.gate_defs['rc3x'] = GateDef('rc3x', 0, 4, RC3XGate())

# reset
self.gate_defs['reset'] = GateDef('reset', 0, 1, Reset())

def qreg(self, tree: lark.Tree) -> None:
"""Qubit register node visitor."""
reg_name = tree.children[0]
Expand Down Expand Up @@ -297,13 +295,6 @@ def gate(self, tree: lark.Tree) -> None:
qlist = tree.children[-1]
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))

if any(q in self.measurements for q in location):
raise LangException(
'BQSKit currently does not support mid-circuit measurements.'
' Unable to apply a gate on the same qubit where a measurement'
' has been previously made.',
)

# Parse gate object
gate_name = str(tree.children[0])
if gate_name in self.gate_defs:
Expand Down Expand Up @@ -595,6 +586,7 @@ def measure(self, tree: lark.Tree) -> None:
class_childs = tree.children[1].children
qubit_reg_name = str(qubit_childs[0])
class_reg_name = str(class_childs[0])
cregs = cast(List[Tuple[str, int]], self.classical_regs)
if not any(r.name == qubit_reg_name for r in self.qubit_regs):
raise LangException(
f'Measuring undefined qubit register: {qubit_reg_name}',
Expand All @@ -605,7 +597,7 @@ def measure(self, tree: lark.Tree) -> None:
f'Measuring undefined classical register: {class_reg_name}',
)

if len(qubit_childs) == 1 and len(class_childs) == 1:
if len(qubit_childs) == 1 and len(class_childs) == 1: # for measure all
for name, size in self.qubit_regs:
if qubit_reg_name == name:
qubit_size = size
Expand All @@ -628,6 +620,9 @@ def measure(self, tree: lark.Tree) -> None:

for i in range(qubit_size):
self.measurements[outer_idx + i] = (class_reg_name, i)
mph = MeasurementPlaceholder(cregs, self.measurements)
self.gate_defs['measure'] = GateDef('measure', 0, qubit_size, mph)


elif len(qubit_childs) == 2 and len(class_childs) == 2:
qubit_index = int(qubit_childs[1])
Expand All @@ -642,6 +637,9 @@ def measure(self, tree: lark.Tree) -> None:
outer_idx += size

self.measurements[qubit_index] = (class_reg_name, class_index)
mph = MeasurementPlaceholder(cregs, self.measurements)
self.gate_defs['measure'] = GateDef('measure', 0, 1, mph)


else:
raise LangException(
Expand All @@ -650,9 +648,75 @@ def measure(self, tree: lark.Tree) -> None:
'measured to a single classical bit.',
)

params: list[float] = []
qlist = tree.children[0]
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))

# Parse gate object
gate_name = tree.data
if gate_name in self.gate_defs:
gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name]
elif gate_name in self.custom_gate_defs:
gate_def = self.custom_gate_defs[gate_name]
else:
raise LangException('Unrecognized gate: %s.' % gate_name)

if len(params) != gate_def.num_params:
raise LangException(
'Expected %d params got %d params for gate %s.'
% (gate_def.num_params, len(params), gate_name),
)

if len(location) != gate_def.num_vars:
raise LangException(
'Gate acts on %d qubits, got %d qubit variables.'
% (gate_def.num_vars, len(location)),
)

# Build operation and add to circuit
self.op_list.append(gate_def.build_op(location, params))

def reset(self, tree: lark.Tree) -> None:
"""Reset statement node visitor."""
raise LangException('BQSKit currently does not support resets.')
"""reset node visitor."""
params: list[float] = []
qlist = tree.children[-1]
gate_name = tree.data
if gate_name in self.gate_defs:
gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name]
elif gate_name in self.custom_gate_defs:
gate_def = self.custom_gate_defs[gate_name]
else:
raise LangException('Unrecognized gate: %s.' % gate_name)

if len(params) != gate_def.num_params:
raise LangException(
'Expected %d params got %d params for gate %s.'
% (gate_def.num_params, len(params), gate_name),
)

if len(qlist.children) == 2:
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))
# Parse gate object
if len(location) != gate_def.num_vars:
raise LangException(
'Gate acts on %d qubits, got %d qubit variables.'
% (gate_def.num_vars, len(location)),
)

# Build operation and add to circuit
self.op_list.append(gate_def.build_op(location, params))
else:
locations = [CircuitLocation(i) for i in range(self.qubit_regs[0][1])]
for location in locations:
if len(location) != gate_def.num_vars:
raise LangException(
'Gate acts on %d qubits, got %d qubit variables.'
% (gate_def.num_vars, len(location)),
)

# Build operation and add to circuit
self.op_list.append(gate_def.build_op(location, params))


def convert_qubit_ids_to_indices(self, qlist: lark.Tree) -> list[int]:
if qlist.data == 'anylist':
Expand Down
2 changes: 1 addition & 1 deletion bqskit/passes/partitioning/quick.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,4 +377,4 @@ def __init__(

# Close the bin
for q in location:
self.active_qudits.remove(q)
self.active_qudits.remove(q)
5 changes: 3 additions & 2 deletions bqskit/passes/partitioning/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.barrier import BarrierPlaceholder
from bqskit.ir.region import CircuitRegion

from bqskit.ir.gates import MeasurementPlaceholder
from bqskit.ir.gates import Reset

class GroupSingleQuditGatePass(BasePass):
"""
Expand All @@ -31,7 +32,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
op = circuit[c, q]
if (
op.num_qudits == 1
and not isinstance(op.gate, BarrierPlaceholder)
and not isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset))
):
if region_start is None:
region_start = c
Expand Down
24 changes: 24 additions & 0 deletions tests/ir/lang/test_qasm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from bqskit.ir.gates.circuitgate import CircuitGate
from bqskit.ir.gates.constant.cx import CNOTGate
from bqskit.ir.gates.measure import MeasurementPlaceholder
from bqskit.ir.gates.reset import Reset
from bqskit.ir.gates.parameterized.u1 import U1Gate
from bqskit.ir.gates.parameterized.u1q import U1qGate
from bqskit.ir.gates.parameterized.u2 import U2Gate
Expand Down Expand Up @@ -296,6 +297,29 @@ def test_include_simple(self) -> None:
assert circuit.get_unitary().get_distance_from(gate_unitary) < 1e-7


class TestReset:
def test_reset_single_qubit(self) -> None:
input = """
OPENQASM 2.0;
qreg q[1];
reset q[0];
"""
circuit = OPENQASM2Language().decode(input)
expected = Reset()
assert circuit[0, 0].gate == expected

def test_reset_register(self) -> None:
input = """
OPENQASM 2.0;
qreg q[2];
reset q;
"""
circuit = OPENQASM2Language().decode(input)
expected = Reset()
assert circuit[0, 0].gate == expected
assert circuit[0, 1].gate == expected


class TestMeasure:
def test_measure_single_bit(self) -> None:
input = """
Expand Down
15 changes: 15 additions & 0 deletions tests/ir/lang/test_qasm_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from bqskit.ir.circuit import Circuit
from bqskit.ir.gates import CNOTGate
from bqskit.ir.gates import Reset
from bqskit.ir.gates import U3Gate
from bqskit.ir.lang.qasm2 import OPENQASM2Language

Expand Down Expand Up @@ -41,3 +42,17 @@ def test_nested_circuitgate(self) -> None:
qasm = OPENQASM2Language().encode(circuit)
parsed_circuit = OPENQASM2Language().decode(qasm)
assert parsed_circuit.get_unitary().get_distance_from(in_utry) < 1e-7

def test_reset(self) -> None:
circuit = Circuit(1)
circuit.append_gate(Reset(), 0)

qasm = OPENQASM2Language().encode(circuit)
expected = (
'OPENQASM 2.0;\n'
'include "qelib1.inc";\n'
'qreg q[1];\n'
'reset q[0];\n'
)
assert qasm == expected

Loading