Skip to content

Commit

Permalink
Recursive subop parameter resolution (#5033)
Browse files Browse the repository at this point in the history
Preserves existing behavior in circuitoperations, where `with_params({a: b, b: a})` just swaps the parameter names and preserves that behavior for subsequent application (we don't change like 613), but we allow optional recursive application for each individual resolution applied (line 614). @95-martin-orion 

Fixes #5016
Closes #3619
  • Loading branch information
daxfohl authored Mar 1, 2022
1 parent 5ff22ce commit d5be95f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 13 deletions.
21 changes: 12 additions & 9 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,17 +592,26 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.Circu
return self.with_measurement_key_mapping(key_map)

def with_params(
self, param_values: 'cirq.ParamResolverOrSimilarType'
self, param_values: 'cirq.ParamResolverOrSimilarType', recursive: bool = False
) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated ParamResolver.
Any existing parameter mappings will have their values updated given
the provided mapping, and any new parameters will be added to the
ParamResolver.
Note that any resulting parameter mappings with no corresponding
parameter in the base circuit will be omitted.
Args:
param_values: A map or ParamResolver able to convert old param
values to new param values. This map will be composed with any
existing ParamResolver via single-step resolution.
recursive: If True, resolves parameter values recursively over the
resolver; otherwise performs a single resolution step. This
behavior applies only to the passed-in mapping, for the current
application. Existing parameters are never resolved recursively
because a->b and b->a needs to be a valid mapping.
Returns:
A copy of this operation with its ParamResolver updated as specified
Expand All @@ -611,18 +620,12 @@ def with_params(
new_params = {}
for k in protocols.parameter_symbols(self.circuit):
v = self.param_resolver.value_of(k, recursive=False)
v = protocols.resolve_parameters(v, param_values, recursive=False)
v = protocols.resolve_parameters(v, param_values, recursive=recursive)
if v != k:
new_params[k] = v
return self.replace(param_resolver=new_params)

# TODO: handle recursive parameter resolution gracefully
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.CircuitOperation':
if recursive:
raise ValueError(
'Recursive resolution of CircuitOperation parameters is prohibited. '
'Use "recursive=False" to prevent this error.'
)
return self.with_params(resolver.param_dict)
return self.with_params(resolver.param_dict, recursive)
38 changes: 34 additions & 4 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import cirq
from cirq.circuits.circuit_operation import _full_join_string_lists


ALL_SIMULATORS = (
cirq.Simulator(),
cirq.DensityMatrixSimulator(),
Expand Down Expand Up @@ -248,9 +247,40 @@ def test_with_params():
== op_with_params
)

# Recursive parameter resolution is rejected.
with pytest.raises(ValueError, match='Use "recursive=False"'):
_ = cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict))

def test_recursive_params():
q = cirq.LineQubit(0)
a, a2, b, b2 = sympy.symbols('a a2 b b2')
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** a,
cirq.Z(q) ** b,
),
# Not recursive, a and b are swapped.
param_resolver=cirq.ParamResolver({a: b, b: a}),
)
# Recursive, so a->a2->0 and b->b2->1.
outer_params = {a: a2, a2: 0, b: b2, b2: 1}
resolved = cirq.resolve_parameters(circuitop, outer_params)
# Combined, a->b->b2->1, and b->a->a2->0.
assert resolved.param_resolver.param_dict == {a: 1, b: 0}

# Non-recursive, so a->a2 and b->b2.
resolved = cirq.resolve_parameters(circuitop, outer_params, recursive=False)
# Combined, a->b->b2, and b->a->a2.
assert resolved.param_resolver.param_dict == {a: b2, b: a2}

with pytest.raises(RecursionError):
cirq.resolve_parameters(circuitop, {a: a2, a2: a})

# Non-recursive, so a->b and b->a.
resolved = cirq.resolve_parameters(circuitop, {a: b, b: a}, recursive=False)
# Combined, a->b->a, and b->a->b.
assert resolved.param_resolver.param_dict == {}

# First example should behave like an X when simulated
result = cirq.Simulator().simulate(cirq.Circuit(circuitop), param_resolver=outer_params)
assert np.allclose(result.state_vector(), [0, 1])


@pytest.mark.parametrize('add_measurements', [True, False])
Expand Down

0 comments on commit d5be95f

Please sign in to comment.