diff --git a/cirq-core/cirq/contrib/acquaintance/__init__.py b/cirq-core/cirq/contrib/acquaintance/__init__.py index 56093a761f0..9ca4d0b67a4 100644 --- a/cirq-core/cirq/contrib/acquaintance/__init__.py +++ b/cirq-core/cirq/contrib/acquaintance/__init__.py @@ -22,6 +22,7 @@ AcquaintanceOperation, GreedyExecutionStrategy, StrategyExecutor, + StrategyExecutorTransformer, ) from cirq.contrib.acquaintance.gates import acquaint, AcquaintanceOpportunityGate, SwapNetworkGate diff --git a/cirq-core/cirq/contrib/acquaintance/executor.py b/cirq-core/cirq/contrib/acquaintance/executor.py index 9939b36628a..b8437d909e8 100644 --- a/cirq-core/cirq/contrib/acquaintance/executor.py +++ b/cirq-core/cirq/contrib/acquaintance/executor.py @@ -17,7 +17,7 @@ import abc from collections import defaultdict -from cirq import circuits, devices, ops, protocols +from cirq import circuits, devices, ops, protocols, transformers, _compat from cirq.contrib.acquaintance.gates import AcquaintanceOpportunityGate from cirq.contrib.acquaintance.permutation import ( @@ -61,9 +61,26 @@ def get_operations( """Gets the logical operations to apply to qubits.""" def __call__(self, *args, **kwargs): - return StrategyExecutor(self)(*args, **kwargs) + """Returns the final mapping of logical indices to qubits after + executing an acquaintance strategy. + """ + if len(args) < 1 or not isinstance(args[0], circuits.AbstractCircuit): + raise ValueError( + ( + "To call ExecutionStrategy, an argument of type " + "circuits.AbstractCircuit must be passed in as the first non-keyword argument" + ) + ) + input_circuit = args[0] + strategy = StrategyExecutorTransformer(self) + final_circuit = strategy(input_circuit, **kwargs) + input_circuit._moments = final_circuit._moments + return strategy.mapping +@_compat.deprecated_class( + deadline='v1.0', fix='Use cirq.contrib.acquaintance.StrategyExecutorTransformer' +) class StrategyExecutor(circuits.PointOptimizer): """Executes an acquaintance strategy.""" @@ -100,6 +117,74 @@ def optimization_at( ) +@transformers.transformer +class StrategyExecutorTransformer: + """Executes an acquaintance strategy.""" + + def __init__(self, execution_strategy: ExecutionStrategy) -> None: + """Initializes transformer. + + Args: + execution_strategy: The `ExecutionStrategy` to execute. + + Raises: + ValueError: if execution_strategy is None. + """ + + if execution_strategy is None: + raise ValueError('execution_strategy cannot be None') + self.execution_strategy = execution_strategy + self._mapping = execution_strategy.initial_mapping.copy() + + def __call__( + self, circuit: circuits.AbstractCircuit, context: Optional['cirq.TransformerContext'] = None + ) -> circuits.Circuit: + """Executes an acquaintance strategy using cirq.map_operations_and_unroll and + mutates initial mapping. + + Args: + circuit: 'cirq.Circuit' input circuit to transform. + context: `cirq.TransformerContext` storing common configurable + options for transformers. + + Returns: + A copy of the modified circuit after executing an acquaintance + strategy on all instances of AcquaintanceOpportunityGate + """ + + circuit = transformers.expand_composite( + circuit, no_decomp=expose_acquaintance_gates.no_decomp + ) + return transformers.map_operations_and_unroll( + circuit=circuit, + map_func=self._map_func, + deep=context.deep if context else False, + tags_to_ignore=context.tags_to_ignore if context else (), + ).unfreeze(copy=False) + + @property + def mapping(self) -> LogicalMapping: + return self._mapping + + def _map_func(self, op: 'cirq.Operation', index) -> 'cirq.OP_TREE': + if isinstance(op.gate, AcquaintanceOpportunityGate): + logical_indices = tuple(self._mapping[q] for q in op.qubits) + logical_operations = self.execution_strategy.get_operations(logical_indices, op.qubits) + clear_span = int(not self.execution_strategy.keep_acquaintance) + + return logical_operations if clear_span else [op, logical_operations] + + if isinstance(op.gate, PermutationGate): + op.gate.update_mapping(self._mapping, op.qubits) + return op + + raise TypeError( + 'Can only execute a strategy consisting of gates that ' + 'are instances of AcquaintanceOpportunityGate or ' + 'PermutationGate.' + ) + + class AcquaintanceOperation(ops.GateOperation): """Represents an a acquaintance opportunity between a particular set of logical indices on a particular set of physical qubits. diff --git a/cirq-core/cirq/contrib/acquaintance/executor_test.py b/cirq-core/cirq/contrib/acquaintance/executor_test.py index 231e7ab9d5c..6e8fb83fcaf 100644 --- a/cirq-core/cirq/contrib/acquaintance/executor_test.py +++ b/cirq-core/cirq/contrib/acquaintance/executor_test.py @@ -36,7 +36,11 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs): return self._wire_symbols -def test_executor_explicit(): +@pytest.mark.parametrize( + 'StrategyType, is_deprecated', + [[cca.StrategyExecutor, True], [cca.StrategyExecutorTransformer, False]], +) +def test_executor_explicit(StrategyType, is_deprecated): num_qubits = 8 qubits = cirq.LineQubit.range(num_qubits) circuit = cca.complete_acquaintance_strategy(qubits, 2) @@ -48,7 +52,20 @@ def test_executor_explicit(): } initial_mapping = {q: i for i, q in enumerate(sorted(qubits))} execution_strategy = cca.GreedyExecutionStrategy(gates, initial_mapping) - executor = cca.StrategyExecutor(execution_strategy) + + if is_deprecated: + with cirq.testing.assert_deprecated( + "Use cirq.contrib.acquaintance.StrategyExecutorTransformer", deadline='v1.0' + ): + executor = StrategyType(execution_strategy) + with pytest.raises(TypeError): + op = cirq.X(qubits[0]) + bad_strategy = cirq.Circuit(op) + executor.optimization_at(bad_strategy, 0, op) + else: + with pytest.raises(ValueError): + executor = StrategyType(None) + executor = StrategyType(execution_strategy) with pytest.raises(NotImplementedError): bad_gates = {(0,): ExampleGate(['0']), (0, 1): ExampleGate(['0', '1'])} @@ -58,12 +75,10 @@ def test_executor_explicit(): bad_strategy = cirq.Circuit(cirq.X(qubits[0])) executor(bad_strategy) - with pytest.raises(TypeError): - op = cirq.X(qubits[0]) - bad_strategy = cirq.Circuit(op) - executor.optimization_at(bad_strategy, 0, op) - - executor(circuit) + if is_deprecated: + executor(circuit) + else: + circuit = executor(circuit) expected_text_diagram = """ 0: ───0───1───╲0╱─────────────────1───3───╲0╱─────────────────3───5───╲0╱─────────────────5───7───╲0╱───────────────── │ │ │ │ │ │ │ │ │ │ │ │ @@ -112,8 +127,11 @@ def test_executor_random( logical_circuit = cirq.Circuit([g(*Q) for Q, g in gates.items()]) expected_unitary = logical_circuit.unitary() - initial_mapping = {q: q for q in qubits} + + with pytest.raises(ValueError): + cca.GreedyExecutionStrategy(gates, initial_mapping)() + final_mapping = cca.GreedyExecutionStrategy(gates, initial_mapping)(circuit) permutation = {q.x: qq.x for q, qq in final_mapping.items()} circuit.append(cca.LinearPermutationGate(num_qubits, permutation)(*qubits))