From 16ad8981562b9a6245a3ada21bb6bbe3ad36de09 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 15 Mar 2022 14:02:08 -0700 Subject: [PATCH] Allow repetitions to be parameterized (#5043) Since recursive parameter resolution is now working (#5033) we can do this now. The biggest caveat with this code is that params are floats, and repetitions must be an integer. I added a new type IntParam for the `repetitions` field itself, but it's still possible for the resolver to put a float value there. I added a runtime check for that. It may make sense to allow floats if they're really close to an integer, but I didn't do that here yet. Closes #3266 --- cirq/circuits/circuit_operation.py | 148 ++++++++++------- cirq/circuits/circuit_operation_test.py | 154 +++++++++++++++++- cirq/sim/density_matrix_simulator_test.py | 4 +- .../measurement_transformers_test.py | 2 +- 4 files changed, 246 insertions(+), 62 deletions(-) diff --git a/cirq/circuits/circuit_operation.py b/cirq/circuits/circuit_operation.py index c5dc1ba8b04..89b260318e7 100644 --- a/cirq/circuits/circuit_operation.py +++ b/cirq/circuits/circuit_operation.py @@ -17,6 +17,8 @@ applied as part of a larger circuit, a CircuitOperation will execute all component operations in order, including any nested CircuitOperations. """ +import dataclasses +import math from typing import ( AbstractSet, Callable, @@ -31,8 +33,8 @@ Union, ) -import dataclasses import numpy as np +import sympy from cirq import circuits, ops, protocols, value, study from cirq._compat import proper_repr @@ -41,12 +43,14 @@ import cirq +INT_CLASSES = (int, np.integer) INT_TYPE = Union[int, np.integer] +IntParam = Union[INT_TYPE, sympy.Basic] REPETITION_ID_SEPARATOR = '-' -def default_repetition_ids(repetitions: int) -> Optional[List[str]]: - if abs(repetitions) != 1: +def default_repetition_ids(repetitions: IntParam) -> Optional[List[str]]: + if isinstance(repetitions, INT_CLASSES) and abs(repetitions) != 1: return [str(i) for i in range(abs(repetitions))] return None @@ -73,7 +77,10 @@ class CircuitOperation(ops.Operation): Args: circuit: The FrozenCircuit wrapped by this operation. - repetitions: How many times the circuit should be repeated. + repetitions: How many times the circuit should be repeated. This can be + integer, or a sympy expression. If sympy, the expression must + resolve to an integer, or float within 0.001 of integer, at + runtime. qubit_map: Remappings for qubits in the circuit. measurement_key_map: Remappings for measurement keys in the circuit. The keys and values should be unindexed (i.e. without repetition_ids). @@ -115,7 +122,7 @@ class CircuitOperation(ops.Operation): ) circuit: 'cirq.FrozenCircuit' - repetitions: int = 1 + repetitions: IntParam = 1 qubit_map: Dict['cirq.Qid', 'cirq.Qid'] = dataclasses.field(default_factory=dict) measurement_key_map: Dict[str, str] = dataclasses.field(default_factory=dict) param_resolver: study.ParamResolver = study.ParamResolver() @@ -130,20 +137,32 @@ def __post_init__(self): raise TypeError(f'Expected circuit of type FrozenCircuit, got: {type(self.circuit)!r}') # Ensure that the circuit is invertible if the repetitions are negative. - if self.repetitions < 0: - try: - protocols.inverse(self.circuit.unfreeze()) - except TypeError: - raise ValueError('repetitions are negative but the circuit is not invertible') - - # Initialize repetition_ids to default, if unspecified. Else, validate their length. - loop_size = abs(self.repetitions) - if not self.repetition_ids: - object.__setattr__(self, 'repetition_ids', self._default_repetition_ids()) - elif len(self.repetition_ids) != loop_size: - raise ValueError( - f'Expected repetition_ids to be a list of length {loop_size}, ' - f'got: {self.repetition_ids}' + if isinstance(self.repetitions, float): + if math.isclose(self.repetitions, round(self.repetitions)): + object.__setattr__(self, 'repetitions', round(self.repetitions)) + if isinstance(self.repetitions, INT_CLASSES): + if self.repetitions < 0: + try: + protocols.inverse(self.circuit.unfreeze()) + except TypeError: + raise ValueError('repetitions are negative but the circuit is not invertible') + + # Initialize repetition_ids to default, if unspecified. Else, validate their length. + loop_size = abs(self.repetitions) + if not self.repetition_ids: + object.__setattr__(self, 'repetition_ids', self._default_repetition_ids()) + elif len(self.repetition_ids) != loop_size: + raise ValueError( + f'Expected repetition_ids to be a list of length {loop_size}, ' + f'got: {self.repetition_ids}' + ) + elif isinstance(self.repetitions, sympy.Basic): + if self.repetition_ids is not None: + raise ValueError('Cannot use repetition ids with parameterized repetitions') + else: + raise TypeError( + f'Only integer or sympy repetitions are allowed.\n' + f'User provided: {self.repetitions}' ) # Disallow mapping to keys containing the `MEASUREMENT_KEY_SEPARATOR` @@ -213,15 +232,28 @@ def _qid_shape_(self) -> Tuple[int, ...]: def _is_measurement_(self) -> bool: return self.circuit._is_measurement_() + def _has_unitary_(self) -> bool: + # Return false if parameterized for early exit of has_unitary protocol. + # Otherwise return NotImplemented instructing the protocol to try alternate strategies + if self._is_parameterized_() or self.repeat_until: + return False + return NotImplemented + + def _ensure_deterministic_loop_count(self): + if self.repeat_until or isinstance(self.repetitions, sympy.Basic): + raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None and self.use_repetition_ids: - circuit_keys = { - key.with_key_path_prefix(repetition_id) - for repetition_id in self.repetition_ids - for key in circuit_keys - } + if circuit_keys and self.use_repetition_ids: + self._ensure_deterministic_loop_count() + if self.repetition_ids is not None: + circuit_keys = { + key.with_key_path_prefix(repetition_id) + for repetition_id in self.repetition_ids + for key in circuit_keys + } circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} object.__setattr__( self, @@ -241,28 +273,33 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: keys = ( frozenset() if not protocols.control_keys(self.circuit) - else protocols.control_keys(self.mapped_circuit()) + else protocols.control_keys(self._mapped_single_loop()) ) if self.repeat_until is not None: keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore + def _is_parameterized_(self) -> bool: + return any(self._parameter_names_generator()) + def _parameter_names_(self) -> AbstractSet[str]: - return { - name - for symbol in protocols.parameter_symbols(self.circuit) + return frozenset(self._parameter_names_generator()) + + def _parameter_names_generator(self) -> Iterator[str]: + yield from protocols.parameter_names(self.repetitions) + for symbol in protocols.parameter_symbols(self.circuit): for name in protocols.parameter_names( protocols.resolve_parameters(symbol, self.param_resolver, recursive=False) - ) - } + ): + yield name def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': if self._cached_mapped_single_loop is None: circuit = self.circuit.unfreeze() if self.qubit_map: circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) - if self.repetitions < 0: + if isinstance(self.repetitions, INT_CLASSES) and self.repetitions < 0: circuit = circuit ** -1 if self.measurement_key_map: circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) @@ -290,6 +327,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': qubit mapping, parameterization, etc.) applied to it. This behaves like `cirq.decompose(self)`, but preserving moment structure. """ + self._ensure_deterministic_loop_count() if self.repetitions == 0: return circuits.Circuit() circuit = ( @@ -449,7 +487,7 @@ def _from_json_dict_( def repeat( self, - repetitions: Optional[INT_TYPE] = None, + repetitions: Optional[IntParam] = None, repetition_ids: Optional[List[str]] = None, ) -> 'CircuitOperation': """Returns a copy of this operation repeated 'repetitions' times. @@ -480,33 +518,29 @@ def repeat( raise ValueError('At least one of repetitions and repetition_ids must be set') repetitions = len(repetition_ids) - if not isinstance(repetitions, (int, np.integer)): - raise TypeError('Only integer repetitions are allowed.') + if isinstance(repetitions, INT_CLASSES): + if repetitions == 1 and repetition_ids is None: + # As CircuitOperation is immutable, this can safely return the original. + return self - repetitions = int(repetitions) - - if repetitions == 1 and repetition_ids is None: - # As CircuitOperation is immutable, this can safely return the original. - return self - - expected_repetition_id_length = abs(repetitions) - # The eventual number of repetitions of the returned CircuitOperation. - final_repetitions = self.repetitions * repetitions + expected_repetition_id_length = abs(repetitions) - if repetition_ids is None: - repetition_ids = default_repetition_ids(expected_repetition_id_length) - elif len(repetition_ids) != expected_repetition_id_length: - raise ValueError( - f'Expected repetition_ids={repetition_ids} length to be ' - f'{expected_repetition_id_length}' - ) + if repetition_ids is None: + repetition_ids = default_repetition_ids(expected_repetition_id_length) + elif len(repetition_ids) != expected_repetition_id_length: + raise ValueError( + f'Expected repetition_ids={repetition_ids} length to be ' + f'{expected_repetition_id_length}' + ) - # If `self.repetition_ids` is None, this will just return `repetition_ids`. + # If either self.repetition_ids or repetitions is None, it returns the other unchanged. repetition_ids = _full_join_string_lists(repetition_ids, self.repetition_ids) + # The eventual number of repetitions of the returned CircuitOperation. + final_repetitions = protocols.mul(self.repetitions, repetitions) return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids) - def __pow__(self, power: int) -> 'cirq.CircuitOperation': + def __pow__(self, power: IntParam) -> 'cirq.CircuitOperation': return self.repeat(power) def _with_key_path_(self, path: Tuple[str, ...]): @@ -547,8 +581,6 @@ def with_qubit_mapping( Args: qubit_map: A mapping of old qubits to new qubits. This map will be composed with any existing qubit mapping. - transform: A function mapping old qubits to new qubits. This - function will be composed with any existing qubit mapping. Returns: A copy of this operation targeting qubits as indicated by qubit_map. @@ -647,7 +679,8 @@ def with_params( ParamResolver. Note that any resulting parameter mappings with no corresponding - parameter in the base circuit will be omitted. + parameter in the base circuit will be omitted. These parameters do not + apply to the `repetitions` field if that is parameterized. Args: param_values: A map or ParamResolver able to convert old param @@ -674,4 +707,5 @@ def with_params( def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'cirq.CircuitOperation': - return self.with_params(resolver.param_dict, recursive) + resolved = self.with_params(resolver.param_dict, recursive) + return resolved.replace(repetitions=resolver.value_of(self.repetitions, recursive)) diff --git a/cirq/circuits/circuit_operation_test.py b/cirq/circuits/circuit_operation_test.py index 6f94e706bee..a92c7abbd1b 100644 --- a/cirq/circuits/circuit_operation_test.py +++ b/cirq/circuits/circuit_operation_test.py @@ -337,8 +337,10 @@ def test_repeat(add_measurements, use_default_ids_for_initial_rep): ): _ = op_base.repeat() - with pytest.raises(TypeError, match='Only integer repetitions are allowed'): + with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'): _ = op_base.repeat(1.3) + assert op_base.repeat(3.00000000001).repetitions == 3 + assert op_base.repeat(2.99999999999).repetitions == 3 @pytest.mark.parametrize('add_measurements', [True, False]) @@ -359,6 +361,156 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps): assert np.allclose(result.state_vector(), [1, 0]) +def test_parameterized_repeat(): + q = cirq.LineQubit(0) + op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q))) ** sympy.Symbol('a') + assert cirq.parameter_names(op) == {'a'} + assert not cirq.has_unitary(op) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1}) + assert np.allclose(result.state_vector(), [0, 1]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1}) + assert np.allclose(result.state_vector(), [0, 1]) + with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'): + cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5}) + with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'): + cirq.Simulator().simulate(cirq.Circuit(op)) + op = op ** -1 + assert cirq.parameter_names(op) == {'a'} + assert not cirq.has_unitary(op) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 0}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1}) + assert np.allclose(result.state_vector(), [0, 1]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': -1}) + assert np.allclose(result.state_vector(), [0, 1]) + with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'): + cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5}) + with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'): + cirq.Simulator().simulate(cirq.Circuit(op)) + op = op ** sympy.Symbol('b') + assert cirq.parameter_names(op) == {'a', 'b'} + assert not cirq.has_unitary(op) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1}) + assert np.allclose(result.state_vector(), [0, 1]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 2, 'b': 1}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 2}) + assert np.allclose(result.state_vector(), [1, 0]) + with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'): + cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1}) + with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'): + cirq.Simulator().simulate(cirq.Circuit(op)) + op = op ** 2.0 + assert cirq.parameter_names(op) == {'a', 'b'} + assert not cirq.has_unitary(op) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1}) + assert np.allclose(result.state_vector(), [1, 0]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1}) + assert np.allclose(result.state_vector(), [0, 1]) + result = cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1, 'b': 1.5}) + assert np.allclose(result.state_vector(), [0, 1]) + with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'): + cirq.Simulator().simulate(cirq.Circuit(op), param_resolver={'a': 1.5, 'b': 1.5}) + with pytest.raises(ValueError, match='Circuit contains ops whose symbols were not specified'): + cirq.Simulator().simulate(cirq.Circuit(op)) + + +def test_parameterized_repeat_side_effects(): + q = cirq.LineQubit(0) + op = cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')), + repetitions=sympy.Symbol('a'), + ) + + # Control keys can be calculated because they only "lift" if there's a matching + # measurement, in which case they're not returned here. + assert cirq.control_keys(op) == {cirq.MeasurementKey('c')} + + # "local" params do not bind to the repetition param. + assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'} + + # Check errors that require unrolling the circuit. + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + cirq.measurement_key_objs(op) + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + cirq.measurement_key_names(op) + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + op.mapped_circuit() + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + cirq.decompose(op) + + # Not compatible with repetition ids + with pytest.raises(ValueError, match='repetition ids with parameterized repetitions'): + op.with_repetition_ids(['x', 'y']) + with pytest.raises(ValueError, match='repetition ids with parameterized repetitions'): + op.repeat(repetition_ids=['x', 'y']) + + # TODO(daxfohl): This should work, but likely requires a new protocol that returns *just* the + # name of the measurement keys. (measurement_key_names returns the full serialized string). + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + cirq.with_measurement_key_mapping(op, {'m': 'm2'}) + + # Everything should work once resolved + op = cirq.resolve_parameters(op, {'a': 2}) + assert set(map(str, cirq.measurement_key_objs(op))) == {'0:m', '1:m'} + assert op.mapped_circuit() == cirq.Circuit( + cirq.X(q).with_classical_controls('c'), + cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('0:m')), + cirq.X(q).with_classical_controls('c'), + cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('1:m')), + ) + assert cirq.decompose(op) == cirq.decompose( + cirq.Circuit( + cirq.X(q).with_classical_controls('c'), + cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('0:m')), + cirq.X(q).with_classical_controls('c'), + cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('1:m')), + ) + ) + + +def test_parameterized_repeat_side_effects_when_not_using_rep_ids(): + q = cirq.LineQubit(0) + op = cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')), + repetitions=sympy.Symbol('a'), + use_repetition_ids=False, + ) + assert cirq.control_keys(op) == {cirq.MeasurementKey('c')} + assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'} + assert set(map(str, cirq.measurement_key_objs(op))) == {'m'} + assert cirq.measurement_key_names(op) == {'m'} + assert cirq.measurement_key_names(cirq.with_measurement_key_mapping(op, {'m': 'm2'})) == {'m2'} + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + op.mapped_circuit() + with pytest.raises( + ValueError, match='Cannot unroll circuit due to nondeterministic repetitions' + ): + cirq.decompose(op) + with pytest.raises(ValueError, match='repetition ids with parameterized repetitions'): + op.with_repetition_ids(['x', 'y']) + with pytest.raises(ValueError, match='repetition ids with parameterized repetitions'): + op.repeat(repetition_ids=['x', 'y']) + + def test_qid_shape(): circuit = cirq.FrozenCircuit( cirq.IdentityGate(qid_shape=(q.dimension,)).on(q) diff --git a/cirq/sim/density_matrix_simulator_test.py b/cirq/sim/density_matrix_simulator_test.py index 85b4a7352ae..4ad776e0823 100644 --- a/cirq/sim/density_matrix_simulator_test.py +++ b/cirq/sim/density_matrix_simulator_test.py @@ -533,9 +533,7 @@ def test_simulate_ignore_measurements(split: bool): @pytest.mark.parametrize('split', [True, False]) def test_simulate_ignore_measurements_subcircuits(split: bool): q0 = cirq.LineQubit(0) - with cirq.testing.assert_deprecated( - 'ignore_measurement_results', deadline='v0.15', count=6 if split else 4 - ): + with cirq.testing.assert_deprecated('ignore_measurement_results', deadline='v0.15', count=None): simulator = cirq.DensityMatrixSimulator( split_untangled_states=split, ignore_measurement_results=True ) diff --git a/cirq/transformers/measurement_transformers_test.py b/cirq/transformers/measurement_transformers_test.py index b64d25391c2..cb46c0ff357 100644 --- a/cirq/transformers/measurement_transformers_test.py +++ b/cirq/transformers/measurement_transformers_test.py @@ -39,7 +39,7 @@ def assert_equivalent_to_deferred(circuit: cirq.Circuit): def assert_equivalent_to_dephased(circuit: cirq.Circuit): qubits = list(circuit.all_qubits()) - with cirq.testing.assert_deprecated('ignore_measurement_results', deadline='v0.15', count=14): + with cirq.testing.assert_deprecated('ignore_measurement_results', deadline='v0.15', count=None): sim = cirq.DensityMatrixSimulator(ignore_measurement_results=True) num_qubits = len(qubits) backwards = list(circuit.all_operations())[::-1]