Skip to content

Commit

Permalink
Add repeat-until functionality to subcircuits (quantumlib#5018)
Browse files Browse the repository at this point in the history
* Allow flattening of subcircuits

* format

* Add serialization logic and tests

* Change flatten_repetitions (default False) to use_repetition_ids (default True)

* Add shape tests for simulation results from flattened subcircuits

* docs

* add repeat_until

* repr/json/etc

* format

* chagne do_while to repeat_until

* merge fix

* make mapped_single_loop private

* Address code review comments.

* Fix test

* simplify branch

* simplify branch

* simplify branch

* add unbound controls in repeat_until to control_keys
  • Loading branch information
daxfohl authored Mar 1, 2022
1 parent cc04cfe commit 2dbba7a
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 23 deletions.
90 changes: 68 additions & 22 deletions cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
component operations in order, including any nested CircuitOperations.
"""
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
cast,
Dict,
FrozenSet,
Iterator,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -94,6 +95,12 @@ class CircuitOperation(ops.Operation):
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated.
repeat_until: A condition that will be tested after each iteration of
the subcircuit. The subcircuit will repeat until condition returns
True, but will always run at least once, and the measurement key
need not be defined prior to the subcircuit (but must be defined in
a measurement within the subcircuit). This field is incompatible
with repetitions or repetition_ids.
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
Expand All @@ -103,6 +110,9 @@ class CircuitOperation(ops.Operation):
_cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
default=None, init=False
)
_cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field(
default=None, init=False
)

circuit: 'cirq.FrozenCircuit'
repetitions: int = 1
Expand All @@ -113,6 +123,7 @@ class CircuitOperation(ops.Operation):
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset)
use_repetition_ids: bool = True
repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
Expand Down Expand Up @@ -148,6 +159,14 @@ def __post_init__(self):
if q_new.dimension != q.dimension:
raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}')

if self.repeat_until:
if self.use_repetition_ids or self.repetitions != 1:
raise ValueError('Cannot use repetitions with repeat_until')
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
self.repeat_until.keys
):
raise ValueError('Infinite loop: condition is not modified in subcircuit.')

# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

Expand All @@ -174,6 +193,7 @@ def __eq__(self, other) -> bool:
and self.repetition_ids == other.repetition_ids
and self.parent_path == other.parent_path
and self.use_repetition_ids == other.use_repetition_ids
and self.repeat_until == other.repeat_until
)

# Methods for getting post-mapping properties of the contained circuit.
Expand Down Expand Up @@ -223,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self.mapped_circuit())
)
if self.repeat_until is not None:
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
object.__setattr__(self, '_cached_control_keys', keys)
return self._cached_control_keys # type: ignore

Expand All @@ -235,6 +257,27 @@ def _parameter_names_(self) -> AbstractSet[str]:
)
}

def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit':
if self._cached_mapped_single_loop is None:
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(
circuit, self.param_resolver, recursive=False
)
object.__setattr__(self, '_cached_mapped_single_loop', circuit)
circuit = cast(circuits.Circuit, self._cached_mapped_single_loop)
if repetition_id:
circuit = protocols.with_rescoped_keys(circuit, (repetition_id,))
return protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
)

def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
"""Applies all maps to the contained circuit and returns the result.
Expand All @@ -249,24 +292,12 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
"""
if self.repetitions == 0:
return circuits.Circuit()
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if self.repetition_ids is not None:
if not self.use_repetition_ids or not protocols.is_measurement(circuit):
circuit = circuit * abs(self.repetitions)
else:
circuit = circuits.Circuit(
protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids
)
circuit = protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
circuit = (
circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids)
if self.repetition_ids is not None
and self.use_repetition_ids
and protocols.is_measurement(self.circuit)
else self._mapped_single_loop() * abs(self.repetitions)
)
if deep:
circuit = circuit.map_operations(
Expand All @@ -282,8 +313,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
return self.mapped_circuit(deep=False).all_operations()

def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
for op in self._decompose_():
protocols.act_on(op, args)
if self.repeat_until:
circuit = self._mapped_single_loop()
while True:
for op in circuit.all_operations():
protocols.act_on(op, args)
if self.repeat_until.resolve(args.classical_data):
break
else:
for op in self._decompose_():
protocols.act_on(op, args)
return True

# Methods for string representation of the operation.
Expand All @@ -305,6 +344,8 @@ def __repr__(self):
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
return f'cirq.CircuitOperation({indented_args[:-4]})'

Expand Down Expand Up @@ -337,6 +378,8 @@ def dict_str(d: Dict) -> str:
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
return circuit_msg
return f'{circuit_msg}({", ".join(args)})'
Expand Down Expand Up @@ -375,6 +418,8 @@ def _json_dict_(self):
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp

@classmethod
Expand All @@ -388,10 +433,11 @@ def _from_json_dict_(
repetition_ids,
parent_path=(),
use_repetition_ids=True,
repeat_until=None,
**kwargs,
):
return (
cls(circuit, use_repetition_ids=use_repetition_ids)
cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until)
.with_qubit_mapping(dict(qubit_map))
.with_measurement_key_mapping(measurement_key_map)
.with_params(param_resolver)
Expand Down
121 changes: 121 additions & 0 deletions cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,24 @@ def test_string_format():
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(x, key='a')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
assert (
repr(op7)
== """\
cirq.CircuitOperation(
circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)


def test_json_dict():
Expand Down Expand Up @@ -977,4 +995,107 @@ def test_simulate_no_repetition_ids_inner(sim):
assert result.records['1:a'].shape == (1, 2, 1)


@pytest.mark.parametrize('sim', ALL_SIMULATORS)
def test_repeat_until(sim):
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.X(q),
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q),
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
measurements = sim.run(c).records['m'][0]
assert len(measurements) == 2
assert measurements[0] == (0,)
assert measurements[1] == (1,)


@pytest.mark.parametrize('sim', ALL_SIMULATORS)
def test_repeat_until_sympy(sim):
q1, q2 = cirq.LineQubit.range(2)
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q2),
cirq.measure(q2, key='b'),
),
use_repetition_ids=False,
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
)
c = cirq.Circuit(
cirq.measure(q1, key='a'),
circuitop,
)
# Validate commutation
assert len(c) == 2
assert cirq.control_keys(circuitop) == {cirq.MeasurementKey('a')}
measurements = sim.run(c).records['b'][0]
assert len(measurements) == 2
assert measurements[0] == (1,)
assert measurements[1] == (0,)


@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()])
def test_post_selection(sim):
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
result = sim.run(c)
assert result.records['m'][0][-1] == (1,)
for i in range(len(result.records['m'][0]) - 1):
assert result.records['m'][0][i] == (0,)


def test_repeat_until_diagram():
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
""",
use_unicode_characters=True,
)


def test_repeat_until_error():
q = cirq.LineQubit(0)
with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'):
cirq.CircuitOperation(
cirq.FrozenCircuit(),
use_repetition_ids=True,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
26 changes: 26 additions & 0 deletions cirq/protocols/json_test_data/CircuitOperation.json
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,32 @@
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false
},
{
"cirq_type": "CircuitOperation",
"circuit": {
"cirq_type": "_SerializedKey",
"key": 1
},
"repetitions": 1,
"qubit_map": [],
"measurement_key_map": {},
"param_resolver": {
"cirq_type": "ParamResolver",
"param_dict": []
},
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false,
"repeat_until": {
"cirq_type": "KeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "0,1,2,3,4",
"path": []
},
"index": -1
}
}
]
]
Expand Down
17 changes: 16 additions & 1 deletion cirq/protocols/json_test_data/CircuitOperation.repr
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
),
]),
param_resolver={sympy.Symbol('theta'): 1.5},
use_repetition_ids=False)]
use_repetition_ids=False),
cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
),
cirq.Moment(
cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')),
)]

0 comments on commit 2dbba7a

Please sign in to comment.