Skip to content

Commit

Permalink
Allow repetitions to be parameterized (quantumlib#5043)
Browse files Browse the repository at this point in the history
Since recursive parameter resolution is now working (quantumlib#5033) we can do this now.

The biggest caveat with this code is that params are floats, and repetitions must be an integer. I added a new type IntParam for the `repetitions` field itself, but it's still possible for the resolver to put a float value there. I added a runtime check for that. It may make sense to allow floats if they're really close to an integer, but I didn't do that here yet.

Closes quantumlib#3266
  • Loading branch information
daxfohl authored Mar 15, 2022
1 parent e8b5f49 commit 16ad898
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 62 deletions.
148 changes: 91 additions & 57 deletions cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
applied as part of a larger circuit, a CircuitOperation will execute all
component operations in order, including any nested CircuitOperations.
"""
import dataclasses
import math
from typing import (
AbstractSet,
Callable,
Expand All @@ -31,8 +33,8 @@
Union,
)

import dataclasses
import numpy as np
import sympy

from cirq import circuits, ops, protocols, value, study
from cirq._compat import proper_repr
Expand All @@ -41,12 +43,14 @@
import cirq


INT_CLASSES = (int, np.integer)
INT_TYPE = Union[int, np.integer]
IntParam = Union[INT_TYPE, sympy.Basic]
REPETITION_ID_SEPARATOR = '-'


def default_repetition_ids(repetitions: int) -> Optional[List[str]]:
if abs(repetitions) != 1:
def default_repetition_ids(repetitions: IntParam) -> Optional[List[str]]:
if isinstance(repetitions, INT_CLASSES) and abs(repetitions) != 1:
return [str(i) for i in range(abs(repetitions))]
return None

Expand All @@ -73,7 +77,10 @@ class CircuitOperation(ops.Operation):
Args:
circuit: The FrozenCircuit wrapped by this operation.
repetitions: How many times the circuit should be repeated.
repetitions: How many times the circuit should be repeated. This can be
integer, or a sympy expression. If sympy, the expression must
resolve to an integer, or float within 0.001 of integer, at
runtime.
qubit_map: Remappings for qubits in the circuit.
measurement_key_map: Remappings for measurement keys in the circuit.
The keys and values should be unindexed (i.e. without repetition_ids).
Expand Down Expand Up @@ -115,7 +122,7 @@ class CircuitOperation(ops.Operation):
)

circuit: 'cirq.FrozenCircuit'
repetitions: int = 1
repetitions: IntParam = 1
qubit_map: Dict['cirq.Qid', 'cirq.Qid'] = dataclasses.field(default_factory=dict)
measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict)
param_resolver: study.ParamResolver = study.ParamResolver()
Expand All @@ -130,20 +137,32 @@ def __post_init__(self):
raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self.circuit)!r}')

# Ensure that the circuit is invertible if the repetitions are negative.
if self.repetitions < 0:
try:
protocols.inverse(self.circuit.unfreeze())
except TypeError:
raise ValueError('repetitions are negative but the circuit is not invertible')

# Initialize repetition_ids to default, if unspecified. Else, validate their length.
loop_size = abs(self.repetitions)
if not self.repetition_ids:
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
elif len(self.repetition_ids) != loop_size:
raise ValueError(
f'Expected repetition_ids to be a list of length {loop_size}, '
f'got: {self.repetition_ids}'
if isinstance(self.repetitions, float):
if math.isclose(self.repetitions, round(self.repetitions)):
object.__setattr__(self, 'repetitions', round(self.repetitions))
if isinstance(self.repetitions, INT_CLASSES):
if self.repetitions < 0:
try:
protocols.inverse(self.circuit.unfreeze())
except TypeError:
raise ValueError('repetitions are negative but the circuit is not invertible')

# Initialize repetition_ids to default, if unspecified. Else, validate their length.
loop_size = abs(self.repetitions)
if not self.repetition_ids:
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
elif len(self.repetition_ids) != loop_size:
raise ValueError(
f'Expected repetition_ids to be a list of length {loop_size}, '
f'got: {self.repetition_ids}'
)
elif isinstance(self.repetitions, sympy.Basic):
if self.repetition_ids is not None:
raise ValueError('Cannot use repetition ids with parameterized repetitions')
else:
raise TypeError(
f'Only integer or sympy repetitions are allowed.\n'
f'User provided: {self.repetitions}'
)

# Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR`
Expand Down Expand Up @@ -213,15 +232,28 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _is_measurement_(self) -> bool:
return self.circuit._is_measurement_()

def _has_unitary_(self) -> bool:
# Return false if parameterized for early exit of has_unitary protocol.
# Otherwise return NotImplemented instructing the protocol to try alternate strategies
if self._is_parameterized_() or self.repeat_until:
return False
return NotImplemented

def _ensure_deterministic_loop_count(self):
if self.repeat_until or isinstance(self.repetitions, sympy.Basic):
raise ValueError('Cannot unroll circuit due to nondeterministic repetitions')

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
if self._cached_measurement_key_objs is None:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if self.repetition_ids is not None and self.use_repetition_ids:
circuit_keys = {
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
}
if circuit_keys and self.use_repetition_ids:
self._ensure_deterministic_loop_count()
if self.repetition_ids is not None:
circuit_keys = {
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
}
circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys}
object.__setattr__(
self,
Expand All @@ -241,28 +273,33 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
keys = (
frozenset()
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self.mapped_circuit())
else protocols.control_keys(self._mapped_single_loop())
)
if self.repeat_until is not None:
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
object.__setattr__(self, '_cached_control_keys', keys)
return self._cached_control_keys # type: ignore

def _is_parameterized_(self) -> bool:
return any(self._parameter_names_generator())

def _parameter_names_(self) -> AbstractSet[str]:
return {
name
for symbol in protocols.parameter_symbols(self.circuit)
return frozenset(self._parameter_names_generator())

def _parameter_names_generator(self) -> Iterator[str]:
yield from protocols.parameter_names(self.repetitions)
for symbol in protocols.parameter_symbols(self.circuit):
for name in protocols.parameter_names(
protocols.resolve_parameters(symbol, self.param_resolver, recursive=False)
)
}
):
yield name

def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit':
if self._cached_mapped_single_loop is None:
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
if isinstance(self.repetitions, INT_CLASSES) and self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
Expand Down Expand Up @@ -290,6 +327,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
qubit mapping, parameterization, etc.) applied to it. This behaves
like `cirq.decompose(self)`, but preserving moment structure.
"""
self._ensure_deterministic_loop_count()
if self.repetitions == 0:
return circuits.Circuit()
circuit = (
Expand Down Expand Up @@ -449,7 +487,7 @@ def _from_json_dict_(

def repeat(
self,
repetitions: Optional[INT_TYPE] = None,
repetitions: Optional[IntParam] = None,
repetition_ids: Optional[List[str]] = None,
) -> 'CircuitOperation':
"""Returns a copy of this operation repeated 'repetitions' times.
Expand Down Expand Up @@ -480,33 +518,29 @@ def repeat(
raise ValueError('At least one of repetitions and repetition_ids must be set')
repetitions = len(repetition_ids)

if not isinstance(repetitions, (int, np.integer)):
raise TypeError('Only integer repetitions are allowed.')
if isinstance(repetitions, INT_CLASSES):
if repetitions == 1 and repetition_ids is None:
# As CircuitOperation is immutable, this can safely return the original.
return self

repetitions = int(repetitions)

if repetitions == 1 and repetition_ids is None:
# As CircuitOperation is immutable, this can safely return the original.
return self

expected_repetition_id_length = abs(repetitions)
# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = self.repetitions * repetitions
expected_repetition_id_length = abs(repetitions)

if repetition_ids is None:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
f'Expected repetition_ids={repetition_ids} length to be '
f'{expected_repetition_id_length}'
)
if repetition_ids is None:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
f'Expected repetition_ids={repetition_ids} length to be '
f'{expected_repetition_id_length}'
)

# If `self.repetition_ids` is None, this will just return `repetition_ids`.
# If either self.repetition_ids or repetitions is None, it returns the other unchanged.
repetition_ids = _full_join_string_lists(repetition_ids, self.repetition_ids)

# The eventual number of repetitions of the returned CircuitOperation.
final_repetitions = protocols.mul(self.repetitions, repetitions)
return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: int) -> 'cirq.CircuitOperation':
def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation':
return self.repeat(power)

def _with_key_path_(self, path: Tuple[str, ...]):
Expand Down Expand Up @@ -547,8 +581,6 @@ def with_qubit_mapping(
Args:
qubit_map: A mapping of old qubits to new qubits. This map will be
composed with any existing qubit mapping.
transform: A function mapping old qubits to new qubits. This
function will be composed with any existing qubit mapping.
Returns:
A copy of this operation targeting qubits as indicated by qubit_map.
Expand Down Expand Up @@ -647,7 +679,8 @@ def with_params(
ParamResolver.
Note that any resulting parameter mappings with no corresponding
parameter in the base circuit will be omitted.
parameter in the base circuit will be omitted. These parameters do not
apply to the `repetitions` field if that is parameterized.
Args:
param_values: A map or ParamResolver able to convert old param
Expand All @@ -674,4 +707,5 @@ def with_params(
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.CircuitOperation':
return self.with_params(resolver.param_dict, recursive)
resolved = self.with_params(resolver.param_dict, recursive)
return resolved.replace(repetitions=resolver.value_of(self.repetitions, recursive))
Loading

0 comments on commit 16ad898

Please sign in to comment.