Skip to content

Commit

Permalink
runtime deferred measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl committed Mar 29, 2023
1 parent acf2c66 commit f7de78d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,11 @@ def _json_dict_(self) -> Dict[str, Any]:
return {'conditions': self._conditions, 'sub_operation': self._sub_operation}

def _act_on_(self, sim_state: 'cirq.SimulationStateBase') -> bool:
if all(c.resolve(sim_state.classical_data) for c in self._conditions):
protocols.act_on(self._sub_operation, sim_state)
from cirq.sim import SimulationState

if not isinstance(sim_state, SimulationState):
return NotImplemented
sim_state.controlled_act(self._conditions, self._sub_operation)
return True

def _with_measurement_key_mapping_(
Expand Down
34 changes: 34 additions & 0 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
prng: Optional[np.random.RandomState] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
deferred_mode: bool = False
):
"""Inits SimulationState.
Expand All @@ -72,6 +73,7 @@ def __init__(
prng = cast(np.random.RandomState, np.random)
self._prng = prng
self._state = state
self._deferred_mode = deferred_mode

@property
def prng(self) -> np.random.RandomState:
Expand Down Expand Up @@ -99,13 +101,45 @@ def measure(
Raises:
ValueError: If a measurement key has already been logged to a key.
"""
if self._deferred_mode:
from cirq.transformers.measurement_transformers import _MeasurementQid
from cirq import ops
targets = [_MeasurementQid(key, q, 0) for q in qubits]
self.add_qubits(targets)
for q, target in zip(qubits, targets):
protocols.act_on(ops.CX(q, target), self)
return
bits = self._perform_measurement(qubits)
confused = self._confuse_result(bits, qubits, confusion_map)
corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(confused, invert_mask)]
self._classical_data.record_measurement(
value.MeasurementKey.parse_serialized(key), corrected, qubits
)

def controlled_act(self, conditions: Tuple['cirq.Condition', ...], sub_operation: 'cirq.Operation'):
from cirq.transformers.measurement_transformers import _MeasurementQid
if self._deferred_mode:
controls = []
for c in conditions:
if isinstance(c, value.KeyCondition):
qubits = [q for q in self.qubits if isinstance(q, _MeasurementQid) and q._key == c.key]
if len(qubits) != 1:
# TODO: Multi-qubit conditions require
# https://github.com/quantumlib/Cirq/issues/4512
# Remember to update docstring above once this works.
raise ValueError('Only single qubit conditions are allowed.')
controls.extend(qubits)
else:
raise ValueError('Only KeyConditions are allowed.')
op = sub_operation.without_classical_controls().controlled_by(
*controls, control_values=[tuple(range(1, q.dimension)) for q in controls]
)
protocols.act_on(op, self)
return

if all(c.resolve(self.classical_data) for c in conditions):
protocols.act_on(sub_operation, self)

def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]

Expand Down
34 changes: 34 additions & 0 deletions cirq-core/cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,37 @@ def _decompose_(self, qubits):
control_circuit = cirq.Circuit(cirq.XPowGate(exponent=exp, dimension=dim).on(q))

assert np.allclose(resolve(test_circuit), resolve(control_circuit))

@pytest.mark.parametrize('state_type', [cirq.StateVectorSimulationState, cirq.DensityMatrixSimulationState])
def test_basic(state_type):
from cirq.transformers.measurement_transformers import _MeasurementQid
q0, q1 = cirq.LineQubit.range(2)
state = state_type(qubits=[q0, q1])
state._deferred_mode = True
circuit = [
cirq.X(q0),
cirq.measure(q0, key='a'),
cirq.X(q1).with_classical_controls('a'),
]
print()
for op in circuit:
cirq.act_on(op, state)
print(state)
state._deferred_mode = False

cirq.act_on(cirq.measure(q1, key='b'), state)
print(state)

q_ma = _MeasurementQid('a', q0)
control = [
cirq.X(q0),
cirq.CX(q0, q_ma),
cirq.CX(q_ma, q1),
cirq.measure(q_ma, key='a'),
cirq.measure(q1, key='b'),
]
print()
state = state_type(qubits=[q0, q1, q_ma])
for op in control:
cirq.act_on(op, state)
print(state)

0 comments on commit f7de78d

Please sign in to comment.