diff --git a/cirq/circuits/circuit_operation.py b/cirq/circuits/circuit_operation.py index 8e98d843497..724e579d4d8 100644 --- a/cirq/circuits/circuit_operation.py +++ b/cirq/circuits/circuit_operation.py @@ -592,10 +592,14 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.Circu return self.with_measurement_key_mapping(key_map) def with_params( - self, param_values: 'cirq.ParamResolverOrSimilarType' + self, param_values: 'cirq.ParamResolverOrSimilarType', recursive: bool = False ) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated ParamResolver. + Any existing parameter mappings will have their values updated given + the provided mapping, and any new parameters will be added to the + ParamResolver. + Note that any resulting parameter mappings with no corresponding parameter in the base circuit will be omitted. @@ -603,6 +607,11 @@ def with_params( param_values: A map or ParamResolver able to convert old param values to new param values. This map will be composed with any existing ParamResolver via single-step resolution. + recursive: If True, resolves parameter values recursively over the + resolver; otherwise performs a single resolution step. This + behavior applies only to the passed-in mapping, for the current + application. Existing parameters are never resolved recursively + because a->b and b->a needs to be a valid mapping. Returns: A copy of this operation with its ParamResolver updated as specified @@ -611,18 +620,12 @@ def with_params( new_params = {} for k in protocols.parameter_symbols(self.circuit): v = self.param_resolver.value_of(k, recursive=False) - v = protocols.resolve_parameters(v, param_values, recursive=False) + v = protocols.resolve_parameters(v, param_values, recursive=recursive) if v != k: new_params[k] = v return self.replace(param_resolver=new_params) - # TODO: handle recursive parameter resolution gracefully def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'cirq.CircuitOperation': - if recursive: - raise ValueError( - 'Recursive resolution of CircuitOperation parameters is prohibited. ' - 'Use "recursive=False" to prevent this error.' - ) - return self.with_params(resolver.param_dict) + return self.with_params(resolver.param_dict, recursive) diff --git a/cirq/circuits/circuit_operation_test.py b/cirq/circuits/circuit_operation_test.py index b2846eae82a..6999bfbd8d3 100644 --- a/cirq/circuits/circuit_operation_test.py +++ b/cirq/circuits/circuit_operation_test.py @@ -20,7 +20,6 @@ import cirq from cirq.circuits.circuit_operation import _full_join_string_lists - ALL_SIMULATORS = ( cirq.Simulator(), cirq.DensityMatrixSimulator(), @@ -248,9 +247,40 @@ def test_with_params(): == op_with_params ) - # Recursive parameter resolution is rejected. - with pytest.raises(ValueError, match='Use "recursive=False"'): - _ = cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict)) + +def test_recursive_params(): + q = cirq.LineQubit(0) + a, a2, b, b2 = sympy.symbols('a a2 b b2') + circuitop = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q) ** a, + cirq.Z(q) ** b, + ), + # Not recursive, a and b are swapped. + param_resolver=cirq.ParamResolver({a: b, b: a}), + ) + # Recursive, so a->a2->0 and b->b2->1. + outer_params = {a: a2, a2: 0, b: b2, b2: 1} + resolved = cirq.resolve_parameters(circuitop, outer_params) + # Combined, a->b->b2->1, and b->a->a2->0. + assert resolved.param_resolver.param_dict == {a: 1, b: 0} + + # Non-recursive, so a->a2 and b->b2. + resolved = cirq.resolve_parameters(circuitop, outer_params, recursive=False) + # Combined, a->b->b2, and b->a->a2. + assert resolved.param_resolver.param_dict == {a: b2, b: a2} + + with pytest.raises(RecursionError): + cirq.resolve_parameters(circuitop, {a: a2, a2: a}) + + # Non-recursive, so a->b and b->a. + resolved = cirq.resolve_parameters(circuitop, {a: b, b: a}, recursive=False) + # Combined, a->b->a, and b->a->b. + assert resolved.param_resolver.param_dict == {} + + # First example should behave like an X when simulated + result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params) + assert np.allclose(result.state_vector(), [0, 1]) @pytest.mark.parametrize('add_measurements', [True, False])