From 2d4053b2f2222e9cb14ea801d628c2ad25339cea Mon Sep 17 00:00:00 2001 From: peachnuts Date: Mon, 4 Dec 2023 11:41:47 -0800 Subject: [PATCH 1/2] add support for mid-circuit measurement --- bqskit/ir/lang/qasm2/visitor.py | 48 ++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/bqskit/ir/lang/qasm2/visitor.py b/bqskit/ir/lang/qasm2/visitor.py index 8ff6fecc7..05e0441cf 100644 --- a/bqskit/ir/lang/qasm2/visitor.py +++ b/bqskit/ir/lang/qasm2/visitor.py @@ -181,12 +181,6 @@ def get_circuit(self) -> Circuit: circuit = Circuit(num_qubits) circuit.extend(self.op_list) - # Add measurements - if len(self.measurements) > 0: - cregs = cast(List[Tuple[str, int]], self.classical_regs) - mph = MeasurementPlaceholder(cregs, self.measurements) - circuit.append_gate(mph, list(self.measurements.keys())) - return circuit def fill_gate_defs(self) -> None: @@ -301,13 +295,6 @@ def gate(self, tree: lark.Tree) -> None: qlist = tree.children[-1] location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) - if any(q in self.measurements for q in location): - raise LangException( - 'BQSKit currently does not support mid-circuit measurements.' - ' Unable to apply a gate on the same qubit where a measurement' - ' has been previously made.', - ) - # Parse gate object gate_name = str(tree.children[0]) if gate_name in self.gate_defs: @@ -654,11 +641,46 @@ def measure(self, tree: lark.Tree) -> None: 'measured to a single classical bit.', ) + + for key, item in enumerate(self.measurements.items()): + cregs = cast(List[Tuple[str, int]], self.classical_regs) + mph = MeasurementPlaceholder(cregs, {key: item}) + self.gate_defs['measure'] = GateDef('measure', 0, 1, mph) + + params = [] + qlist = tree.children[0] + location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) + + # Parse gate object + gate_name = tree.data + if gate_name in self.gate_defs: + gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name] + elif gate_name in self.custom_gate_defs: + gate_def = self.custom_gate_defs[gate_name] + else: + raise LangException('Unrecognized gate: %s.' % gate_name) + + if len(params) != gate_def.num_params: + raise LangException( + 'Expected %d params got %d params for gate %s.' + % (gate_def.num_params, len(params), gate_name), + ) + + if len(location) != gate_def.num_vars: + raise LangException( + 'Gate acts on %d qubits, got %d qubit variables.' + % (gate_def.num_vars, len(location)), + ) + + # Build operation and add to circuit + self.op_list.append(gate_def.build_op(location, params)) + def reset(self, tree: lark.Tree) -> None: """reset node visitor.""" params = [] qlist = tree.children[-1] location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist)) + # Parse gate object gate_name = tree.data if gate_name in self.gate_defs: From dcc2bbbc9fef1cc43b6991b585bf0a1efc30132e Mon Sep 17 00:00:00 2001 From: peachnuts Date: Mon, 4 Dec 2023 14:39:14 -0800 Subject: [PATCH 2/2] add MMR partition support --- bqskit/passes/partitioning/quick.py | 101 ++++++++++++++++++++++++++- bqskit/passes/partitioning/single.py | 5 +- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/bqskit/passes/partitioning/quick.py b/bqskit/passes/partitioning/quick.py index 4c9d2e0e4..b5322c1ec 100644 --- a/bqskit/passes/partitioning/quick.py +++ b/bqskit/passes/partitioning/quick.py @@ -13,6 +13,8 @@ from bqskit.ir.location import CircuitLocation from bqskit.ir.point import CircuitPoint from bqskit.utils.typing import is_integer +from bqskit.ir.gates import MeasurementPlaceholder +from bqskit.ir.gates import Reset _logger = logging.getLogger(__name__) @@ -112,12 +114,12 @@ def process_pending_bins() -> None: loc = list(sorted(bin.qudits)) # Merge previously placed blocks if possible - merging = not isinstance(bin, BarrierBin) + merging = not isinstance(bin, (BarrierBin, MeasurementBin, ResetBin)) while merging: merging = False for p in partitioned_circuit.rear: op = partitioned_circuit[p] - if isinstance(op.gate, BarrierPlaceholder): + if isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset)): # Don't merge through barriers continue qudits = list(op.location) @@ -152,7 +154,7 @@ def process_pending_bins() -> None: partitioned_circuit.append_circuit( subc, loc, - not isinstance(bin, BarrierBin), + not isinstance(bin, (BarrierBin, MeasurementBin, ResetBin)), True, ) for qudit in bin.qudits: @@ -191,6 +193,32 @@ def process_pending_bins() -> None: pending_bins.append(BarrierBin(point, location, circuit)) continue + # Measurement close all overlapping bins + if isinstance(op.gate, MeasurementPlaceholder): + for bin in overlapping_bins: + if close_bin_qudits(bin, location, cycle): + num_closed += 1 + else: + extended = [q for q in location if q not in bin.qudits] + bin.blocked_qudits.update(extended) + + # Track the measurement to restore it in partitioned circuit + pending_bins.append(MeasurementBin(point, location, circuit)) + continue + + # Reset close all overlapping bins + if isinstance(op.gate, Reset): + for bin in overlapping_bins: + if close_bin_qudits(bin, location, cycle): + num_closed += 1 + else: + extended = [q for q in location if q not in bin.qudits] + bin.blocked_qudits.update(extended) + + # Track the reset to restore it in partitioned circuit + pending_bins.append(ResetBin(point, location, circuit)) + continue + # Get all the currently active bins that can have op added to them admissible_bins = [ bin for bin in overlapping_bins @@ -378,3 +406,70 @@ def __init__( # Close the bin for q in location: self.active_qudits.remove(q) + + +class MeasurementBin(Bin): + """A special bin made to mark and preserve measurement location.""" + + def __init__( + self, + point: CircuitPoint, + location: CircuitLocation, + circuit: Circuit, + ) -> None: + """Initialize a MeasurementBin with the point and location of a measurement.""" + super().__init__() + + # Add the measurement + self.add_op(point, location) + + # Barriar bins fill the volume to the next gates + + nexts = circuit.next(point) + ends: dict[int, int | None] = {q: None for q in location} + for p in nexts: + loc = circuit[p].location + for q in loc: + if q in ends and ( + ends[q] is None or ends[q] >= p.cycle): # type: ignore # noqa # short-circuit safety for >= + ends[q] = p.cycle - 1 + + self.ends = ends + + # Close the bin + for q in location: + self.active_qudits.remove(q) + + +class ResetBin(Bin): + """A special bin made to mark and preserve reset location.""" + + def __init__( + self, + point: CircuitPoint, + location: CircuitLocation, + circuit: Circuit, + ) -> None: + """Initialize a reset with the point and location of a reset.""" + super().__init__() + + # Add the reset + self.add_op(point, location) + + # Barriar bins fill the volume to the next gates + + nexts = circuit.next(point) + ends: dict[int, int | None] = {q: None for q in location} + for p in nexts: + loc = circuit[p].location + for q in loc: + if q in ends and ( + ends[q] is None or ends[q] >= p.cycle): # type: ignore # noqa # short-circuit safety for >= + ends[q] = p.cycle - 1 + + self.ends = ends + + # Close the bin + for q in location: + self.active_qudits.remove(q) + diff --git a/bqskit/passes/partitioning/single.py b/bqskit/passes/partitioning/single.py index ce9a431e2..814df3dc4 100644 --- a/bqskit/passes/partitioning/single.py +++ b/bqskit/passes/partitioning/single.py @@ -6,7 +6,8 @@ from bqskit.ir.circuit import Circuit from bqskit.ir.gates.barrier import BarrierPlaceholder from bqskit.ir.region import CircuitRegion - +from bqskit.ir.gates import MeasurementPlaceholder +from bqskit.ir.gates import Reset class GroupSingleQuditGatePass(BasePass): """ @@ -31,7 +32,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None: op = circuit[c, q] if ( op.num_qudits == 1 - and not isinstance(op.gate, BarrierPlaceholder) + and not isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset)) ): if region_start is None: region_start = c