Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
peachnuts committed Jan 18, 2024
2 parents d965dc7 + dcc2bbb commit c6c888a
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 18 deletions.
47 changes: 34 additions & 13 deletions bqskit/ir/lang/qasm2/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,6 @@ def get_circuit(self) -> Circuit:
circuit = Circuit(num_qubits)
circuit.extend(self.op_list)

# Add measurements
if len(self.measurements) > 0:
cregs = cast(List[Tuple[str, int]], self.classical_regs)
mph = MeasurementPlaceholder(cregs, self.measurements)
circuit.append_gate(mph, list(self.measurements.keys()))

return circuit

def fill_gate_defs(self) -> None:
Expand Down Expand Up @@ -301,13 +295,6 @@ def gate(self, tree: lark.Tree) -> None:
qlist = tree.children[-1]
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))

if any(q in self.measurements for q in location):
raise LangException(
'BQSKit currently does not support mid-circuit measurements.'
' Unable to apply a gate on the same qubit where a measurement'
' has been previously made.',
)

# Parse gate object
gate_name = str(tree.children[0])
if gate_name in self.gate_defs:
Expand Down Expand Up @@ -654,6 +641,40 @@ def measure(self, tree: lark.Tree) -> None:
'measured to a single classical bit.',
)


for key, item in enumerate(self.measurements.items()):
cregs = cast(List[Tuple[str, int]], self.classical_regs)
mph = MeasurementPlaceholder(cregs, {key: item})
self.gate_defs['measure'] = GateDef('measure', 0, 1, mph)

params = []
qlist = tree.children[0]
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))

# Parse gate object
gate_name = tree.data
if gate_name in self.gate_defs:
gate_def: GateDef | CustomGateDef = self.gate_defs[gate_name]
elif gate_name in self.custom_gate_defs:
gate_def = self.custom_gate_defs[gate_name]
else:
raise LangException('Unrecognized gate: %s.' % gate_name)

if len(params) != gate_def.num_params:
raise LangException(
'Expected %d params got %d params for gate %s.'
% (gate_def.num_params, len(params), gate_name),
)

if len(location) != gate_def.num_vars:
raise LangException(
'Gate acts on %d qubits, got %d qubit variables.'
% (gate_def.num_vars, len(location)),
)

# Build operation and add to circuit
self.op_list.append(gate_def.build_op(location, params))

def reset(self, tree: lark.Tree) -> None:
"""reset node visitor."""
params: list[float] = []
Expand Down
101 changes: 98 additions & 3 deletions bqskit/passes/partitioning/quick.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from bqskit.ir.location import CircuitLocation
from bqskit.ir.point import CircuitPoint
from bqskit.utils.typing import is_integer
from bqskit.ir.gates import MeasurementPlaceholder
from bqskit.ir.gates import Reset

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,12 +114,12 @@ def process_pending_bins() -> None:
loc = list(sorted(bin.qudits))

# Merge previously placed blocks if possible
merging = not isinstance(bin, BarrierBin)
merging = not isinstance(bin, (BarrierBin, MeasurementBin, ResetBin))
while merging:
merging = False
for p in partitioned_circuit.rear:
op = partitioned_circuit[p]
if isinstance(op.gate, BarrierPlaceholder):
if isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset)):
# Don't merge through barriers
continue
qudits = list(op.location)
Expand Down Expand Up @@ -152,7 +154,7 @@ def process_pending_bins() -> None:
partitioned_circuit.append_circuit(
subc,
loc,
not isinstance(bin, BarrierBin),
not isinstance(bin, (BarrierBin, MeasurementBin, ResetBin)),
True,
)
for qudit in bin.qudits:
Expand Down Expand Up @@ -191,6 +193,32 @@ def process_pending_bins() -> None:
pending_bins.append(BarrierBin(point, location, circuit))
continue

# Measurement close all overlapping bins
if isinstance(op.gate, MeasurementPlaceholder):
for bin in overlapping_bins:
if close_bin_qudits(bin, location, cycle):
num_closed += 1
else:
extended = [q for q in location if q not in bin.qudits]
bin.blocked_qudits.update(extended)

# Track the measurement to restore it in partitioned circuit
pending_bins.append(MeasurementBin(point, location, circuit))
continue

# Reset close all overlapping bins
if isinstance(op.gate, Reset):
for bin in overlapping_bins:
if close_bin_qudits(bin, location, cycle):
num_closed += 1
else:
extended = [q for q in location if q not in bin.qudits]
bin.blocked_qudits.update(extended)

# Track the reset to restore it in partitioned circuit
pending_bins.append(ResetBin(point, location, circuit))
continue

# Get all the currently active bins that can have op added to them
admissible_bins = [
bin for bin in overlapping_bins
Expand Down Expand Up @@ -378,3 +406,70 @@ def __init__(
# Close the bin
for q in location:
self.active_qudits.remove(q)


class MeasurementBin(Bin):
"""A special bin made to mark and preserve measurement location."""

def __init__(
self,
point: CircuitPoint,
location: CircuitLocation,
circuit: Circuit,
) -> None:
"""Initialize a MeasurementBin with the point and location of a measurement."""
super().__init__()

# Add the measurement
self.add_op(point, location)

# Barriar bins fill the volume to the next gates

nexts = circuit.next(point)
ends: dict[int, int | None] = {q: None for q in location}
for p in nexts:
loc = circuit[p].location
for q in loc:
if q in ends and (
ends[q] is None or ends[q] >= p.cycle): # type: ignore # noqa # short-circuit safety for >=
ends[q] = p.cycle - 1

self.ends = ends

# Close the bin
for q in location:
self.active_qudits.remove(q)


class ResetBin(Bin):
"""A special bin made to mark and preserve reset location."""

def __init__(
self,
point: CircuitPoint,
location: CircuitLocation,
circuit: Circuit,
) -> None:
"""Initialize a reset with the point and location of a reset."""
super().__init__()

# Add the reset
self.add_op(point, location)

# Barriar bins fill the volume to the next gates

nexts = circuit.next(point)
ends: dict[int, int | None] = {q: None for q in location}
for p in nexts:
loc = circuit[p].location
for q in loc:
if q in ends and (
ends[q] is None or ends[q] >= p.cycle): # type: ignore # noqa # short-circuit safety for >=
ends[q] = p.cycle - 1

self.ends = ends

# Close the bin
for q in location:
self.active_qudits.remove(q)

5 changes: 3 additions & 2 deletions bqskit/passes/partitioning/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.barrier import BarrierPlaceholder
from bqskit.ir.region import CircuitRegion

from bqskit.ir.gates import MeasurementPlaceholder
from bqskit.ir.gates import Reset

class GroupSingleQuditGatePass(BasePass):
"""
Expand All @@ -31,7 +32,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
op = circuit[c, q]
if (
op.num_qudits == 1
and not isinstance(op.gate, BarrierPlaceholder)
and not isinstance(op.gate, (BarrierPlaceholder, MeasurementPlaceholder, Reset))
):
if region_start is None:
region_start = c
Expand Down

0 comments on commit c6c888a

Please sign in to comment.