Skip to content

Commit

Permalink
Scoping for control keys (#4736)
Browse files Browse the repository at this point in the history
Part 12 of https://tinyurl.com/cirq-feedforward

Allow for classically controlled subcircuits to be controlled by measurements done prior to that subcircuit.

Such behavior "already worked" for subcircuits without repetitions. But for subcircuits with repetitions, all measurement/control keys were "lifted" to the repetition id. i.e. in a subcircuit containing control key "A", that would be lifted to "0:A" and "1:A" to distinguish them. However if the measurement 'A' is outside of the subcircuit, then that lifting would cause the control keys to no longer match the measurement key.

This PR fixes the above problem by passing context data through subcircuit decomposition. The behavior is that if a control key matches a measurement key from the same subcircuit, then both are lifted to the corresponding repetition. Otherwise the control key matches to the nearest-scoped measurement key outside of the subcircuit.
  • Loading branch information
daxfohl authored Dec 20, 2021
1 parent e7892c3 commit 4ebfb1c
Show file tree
Hide file tree
Showing 19 changed files with 559 additions and 45 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@
with_key_path,
with_key_path_prefix,
with_measurement_key_mapping,
with_rescoped_keys,
)

from cirq.ion import (
Expand Down
19 changes: 18 additions & 1 deletion cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,18 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
[protocols.with_key_path_prefix(moment, prefix) for moment in self.moments]
)

def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
moments = []
for moment in self.moments:
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
moments.append(new_moment)
bindable_keys |= protocols.measurement_key_objs(new_moment)
return self._with_sliced_moments(moments)

def _qid_shape_(self) -> Tuple[int, ...]:
return self.qid_shape()

Expand Down Expand Up @@ -1171,7 +1183,8 @@ def to_text_diagram_drawer(
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
cbits = tuple(
sorted(
(key for op in self.all_operations() for key in protocols.control_keys(op)), key=str
set(key for op in self.all_operations() for key in protocols.control_keys(op)),
key=str,
)
)
labels = qubits + cbits
Expand Down Expand Up @@ -1524,6 +1537,10 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors
)

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
return controls - protocols.measurement_key_objs(self)


def _overlap_collision_time(
c1: Sequence['cirq.Moment'], c2: Sequence['cirq.Moment'], align: 'cirq.Alignment'
Expand Down
69 changes: 50 additions & 19 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
AbstractSet,
Callable,
Dict,
FrozenSet,
Iterator,
List,
Optional,
Tuple,
Union,
Iterator,
)

import dataclasses
Expand Down Expand Up @@ -77,11 +78,18 @@ class CircuitOperation(ops.Operation):
The keys and values should be unindexed (i.e. without repetition_ids).
The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`.
param_resolver: Resolved values for parameters in the circuit.
parent_path: A tuple of identifiers for any parent CircuitOperations containing this one.
repetition_ids: List of identifiers for each repetition of the
CircuitOperation. If populated, the length should be equal to the
repetitions. If not populated and abs(`repetitions`) > 1, it is
initialized to strings for numbers in `range(repetitions)`.
parent_path: A tuple of identifiers for any parent CircuitOperations
containing this one.
extern_keys: The set of measurement keys defined at extern scope. The
values here are used by decomposition and simulation routines to
cache which external measurement keys exist as possible binding
targets for unbound `ClassicallyControlledOperation` keys. This
field is not intended to be set or changed manually, and should be
empty in circuits that aren't in the middle of decomposition.
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
Expand All @@ -96,6 +104,7 @@ class CircuitOperation(ops.Operation):
param_resolver: study.ParamResolver = study.ParamResolver()
repetition_ids: Optional[List[str]] = dataclasses.field(default=None)
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
Expand Down Expand Up @@ -184,9 +193,7 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
for repetition_id in self.repetition_ids
for key in circuit_keys
}
circuit_keys = {
protocols.with_key_path_prefix(key, self.parent_path) for key in circuit_keys
}
circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys}
object.__setattr__(
self,
'_cached_measurement_key_objs',
Expand All @@ -200,6 +207,11 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}

def _control_keys_(self) -> AbstractSet[value.MeasurementKey]:
if not protocols.control_keys(self.circuit):
return frozenset()
return protocols.control_keys(self.mapped_circuit())

def _parameter_names_(self) -> AbstractSet[str]:
return {
name
Expand All @@ -222,26 +234,28 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
like `cirq.decompose(self)`, but preserving moment structure.
"""
circuit = self.circuit.unfreeze()
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
has_measurements = protocols.is_measurement(circuit)
if has_measurements:
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if deep:
circuit = circuit.map_operations(
lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op
)
if self.param_resolver:
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if self.repetition_ids:
if not has_measurements:
if not protocols.is_measurement(circuit):
circuit = circuit * abs(self.repetitions)
else:
circuit = circuits.Circuit(
protocols.with_key_path_prefix(circuit, (rep,)) for rep in self.repetition_ids
protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids
)
if self.parent_path:
circuit = protocols.with_key_path_prefix(circuit, self.parent_path)
circuit = protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
)
if deep:
circuit = circuit.map_operations(
lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op
)
return circuit

def mapped_op(self, deep: bool = False) -> 'cirq.CircuitOperation':
Expand Down Expand Up @@ -430,6 +444,21 @@ def _with_key_path_(self, path: Tuple[str, ...]):
def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return dataclasses.replace(self, parent_path=prefix + self.parent_path)

def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
# The following line prevents binding to measurement keys in previous repeated subcircuits
# "just because their repetition ids matched". If we eventually decide to change that
# requirement and allow binding across subcircuits (possibly conditionally upon the key or
# the subcircuit having some 'allow_cross_circuit_binding' field set), this is the line to
# change or remove.
bindable_keys = frozenset(k for k in bindable_keys if len(k.path) <= len(path))
bindable_keys |= {k.with_key_path_prefix(*path) for k in self.extern_keys}
path += self.parent_path
return dataclasses.replace(self, parent_path=path, extern_keys=bindable_keys)

def with_key_path(self, path: Tuple[str, ...]):
return self._with_key_path_(path)

Expand Down Expand Up @@ -518,14 +547,16 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
keys than this operation.
"""
new_map = {}
for k_obj in self.circuit.all_measurement_key_objs():
for k_obj in protocols.measurement_keys_touched(self.circuit):
k = k_obj.name
k_new = self.measurement_key_map.get(k, k)
k_new = key_map.get(k_new, k_new)
if k_new != k:
new_map[k] = k_new
new_op = self.replace(measurement_key_map=new_map)
if len(new_op._measurement_key_objs_()) != len(self._measurement_key_objs_()):
if len(protocols.measurement_keys_touched(new_op)) != len(
protocols.measurement_keys_touched(self)
):
raise ValueError(
f'Collision in measurement key map composition. Original map:\n'
f'{self.measurement_key_map}\nApplied changes: {key_map}'
Expand Down
60 changes: 48 additions & 12 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,17 +647,33 @@ def test_decompose_nested():
op2 = cirq.CircuitOperation(circuit2)
circuit3 = cirq.FrozenCircuit(
op2.with_params({exp1: exp_half}),
op2.with_params({exp1: exp_one}),
op2.with_params({exp1: exp_two}),
op2.with_params({exp1: exp_one})
.with_measurement_key_mapping({'ma': 'ma1'})
.with_measurement_key_mapping({'mb': 'mb1'})
.with_measurement_key_mapping({'mc': 'mc1'})
.with_measurement_key_mapping({'md': 'md1'}),
op2.with_params({exp1: exp_two})
.with_measurement_key_mapping({'ma': 'ma2'})
.with_measurement_key_mapping({'mb': 'mb2'})
.with_measurement_key_mapping({'mc': 'mc2'})
.with_measurement_key_mapping({'md': 'md2'}),
)
op3 = cirq.CircuitOperation(circuit3)

final_op = op3.with_params({exp_half: 0.5, exp_one: 1.0, exp_two: 2.0})

expected_circuit1 = cirq.Circuit(
op2.with_params({exp1: 0.5, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
op2.with_params({exp1: 1.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
op2.with_params({exp1: 2.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
op2.with_params({exp1: 1.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0})
.with_measurement_key_mapping({'ma': 'ma1'})
.with_measurement_key_mapping({'mb': 'mb1'})
.with_measurement_key_mapping({'mc': 'mc1'})
.with_measurement_key_mapping({'md': 'md1'}),
op2.with_params({exp1: 2.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0})
.with_measurement_key_mapping({'ma': 'ma2'})
.with_measurement_key_mapping({'mb': 'mb2'})
.with_measurement_key_mapping({'mc': 'mc2'})
.with_measurement_key_mapping({'md': 'md2'}),
)

result_ops1 = cirq.decompose_once(final_op)
Expand All @@ -673,21 +689,21 @@ def test_decompose_nested():
cirq.X(d) ** 0.5,
cirq.measure(d, key='md'),
cirq.X(a) ** 1.0,
cirq.measure(a, key='ma'),
cirq.measure(a, key='ma1'),
cirq.X(b) ** 1.0,
cirq.measure(b, key='mb'),
cirq.measure(b, key='mb1'),
cirq.X(c) ** 1.0,
cirq.measure(c, key='mc'),
cirq.measure(c, key='mc1'),
cirq.X(d) ** 1.0,
cirq.measure(d, key='md'),
cirq.measure(d, key='md1'),
cirq.X(a) ** 2.0,
cirq.measure(a, key='ma'),
cirq.measure(a, key='ma2'),
cirq.X(b) ** 2.0,
cirq.measure(b, key='mb'),
cirq.measure(b, key='mb2'),
cirq.X(c) ** 2.0,
cirq.measure(c, key='mc'),
cirq.measure(c, key='mc2'),
cirq.X(d) ** 2.0,
cirq.measure(d, key='md'),
cirq.measure(d, key='md2'),
)
assert cirq.Circuit(cirq.decompose(final_op)) == expected_circuit
# Verify that mapped_circuit gives the same operations.
Expand Down Expand Up @@ -816,4 +832,24 @@ def test_mapped_circuit_keeps_keys_under_parent_path():
assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'}


def test_keys_conflict_no_repetitions():
q = cirq.LineQubit(0)
op1 = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.measure(q, key='A'),
)
)
op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1))
with pytest.raises(ValueError, match='Conflicting measurement keys found: A'):
_ = op2.mapped_circuit(deep=True)


def test_keys_conflict_locally():
q = cirq.LineQubit(0)
op1 = cirq.measure(q, key='A')
op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1))
with pytest.raises(ValueError, match='Conflicting measurement keys found: A'):
_ = op2.mapped_circuit()


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
66 changes: 66 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,72 @@ def test_append_multiple():
)


def test_append_control_key_subcircuit():
q0, q1 = cirq.LineQubit.range(2)

c = cirq.Circuit()
c.append(cirq.measure(q0, key='a'))
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'a'))
)
)
assert len(c) == 2

c = cirq.Circuit()
c.append(cirq.measure(q0, key='a'))
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b'))
)
)
assert len(c) == 1

c = cirq.Circuit()
c.append(cirq.measure(q0, key='a'))
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b'))
).with_measurement_key_mapping({'b': 'a'})
)
assert len(c) == 2

c = cirq.Circuit()
c.append(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0, key='a'))))
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b'))
).with_measurement_key_mapping({'b': 'a'})
)
assert len(c) == 2

c = cirq.Circuit()
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q0, key='a'))
).with_measurement_key_mapping({'a': 'c'})
)
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b'))
).with_measurement_key_mapping({'b': 'c'})
)
assert len(c) == 2

c = cirq.Circuit()
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q0, key='a'))
).with_measurement_key_mapping({'a': 'b'})
)
c.append(
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b'))
).with_measurement_key_mapping({'b': 'a'})
)
assert len(c) == 1


def test_append_moments():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet[value.MeasurementKey]] = None

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down Expand Up @@ -133,6 +134,11 @@ def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
return self.all_measurement_key_objs()

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
if self._control_keys is None:
self._control_keys = super()._control_keys_()
return self._control_keys

def are_all_measurements_terminal(self) -> bool:
if self._are_all_measurements_terminal is None:
self._are_all_measurements_terminal = super().are_all_measurements_terminal()
Expand Down
Loading

0 comments on commit 4ebfb1c

Please sign in to comment.