diff --git a/cirq/circuits/circuit_test.py b/cirq/circuits/circuit_test.py index 2b420247ef8..95c1b3941cc 100644 --- a/cirq/circuits/circuit_test.py +++ b/cirq/circuits/circuit_test.py @@ -14,7 +14,7 @@ import os from collections import defaultdict from random import randint, random, sample, randrange -from typing import Tuple +from typing import Optional, Tuple, TYPE_CHECKING import numpy as np import pytest @@ -65,6 +65,10 @@ def can_add_operation_into_moment( ) +if TYPE_CHECKING: + import cirq + + class _MomentAndOpTypeValidatingDeviceType(cirq.Device): def validate_operation(self, operation): if not isinstance(operation, cirq.Operation): @@ -958,7 +962,9 @@ def __init__(self, replacer=(lambda x: x)): super().__init__() self.replacer = replacer - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: new_ops = self.replacer(op) return cirq.PointOptimizationSummary( clear_span=1, clear_qubits=op.qubits, new_operations=new_ops diff --git a/cirq/circuits/optimization_pass_test.py b/cirq/circuits/optimization_pass_test.py index 1fa33dcc4cd..15184a0f149 100644 --- a/cirq/circuits/optimization_pass_test.py +++ b/cirq/circuits/optimization_pass_test.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, TYPE_CHECKING, Set, List import pytest import cirq -from cirq import PointOptimizer, PointOptimizationSummary +from cirq import PointOptimizer, PointOptimizationSummary, Operation from cirq.testing import EqualsTester +if TYPE_CHECKING: + import cirq + def test_equality(): a = cirq.NamedQubit('a') @@ -57,7 +61,9 @@ class ReplaceWithXGates(PointOptimizer): operation's qubits. """ - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: end = index + 1 new_ops = [cirq.X(q) for q in op.qubits] done = False @@ -65,10 +71,10 @@ def optimization_at(self, circuit, index, op): n = circuit.next_moment_operating_on(op.qubits, end) if n is None: break - next_ops = {circuit.operation_at(q, n) for q in op.qubits} - next_ops = [e for e in next_ops if e] - next_ops = sorted(next_ops, key=lambda e: str(e.qubits)) - for next_op in next_ops: + next_ops: Set[Optional[Operation]] = {circuit.operation_at(q, n) for q in op.qubits} + next_ops_list: List[Operation] = [e for e in next_ops if e] + next_ops_sorted = sorted(next_ops_list, key=lambda e: str(e.qubits)) + for next_op in next_ops_sorted: if next_op: if set(next_op.qubits).issubset(op.qubits): end = n + 1 @@ -149,14 +155,19 @@ def test_point_optimizer_raises_on_gates_changing_qubits(): class EverythingIs42(cirq.PointOptimizer): """Changes all single qubit operations to act on LineQubit(42)""" - def optimization_at(self, circuit, index, op): - if len(op.qubits) == 1: + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: + new_op = op + if len(op.qubits) == 1 and isinstance(op, cirq.GateOperation): new_op = op.gate(cirq.LineQubit(42)) - return cirq.PointOptimizationSummary( - clear_span=1, clear_qubits=op.qubits, new_operations=new_op - ) + + return cirq.PointOptimizationSummary( + clear_span=1, clear_qubits=op.qubits, new_operations=new_op + ) c = cirq.Circuit(cirq.X(cirq.LineQubit(0)), cirq.X(cirq.LineQubit(1))) + with pytest.raises(ValueError, match='new qubits'): EverythingIs42().optimize_circuit(c) diff --git a/cirq/contrib/acquaintance/executor.py b/cirq/contrib/acquaintance/executor.py index b37be6dc63b..299ee6c347c 100644 --- a/cirq/contrib/acquaintance/executor.py +++ b/cirq/contrib/acquaintance/executor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING +from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING, Optional import abc from collections import defaultdict @@ -81,7 +81,9 @@ def __call__(self, strategy: 'cirq.Circuit'): super().optimize_circuit(strategy) return self.mapping.copy() - def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: if isinstance(op.gate, AcquaintanceOpportunityGate): logical_indices = tuple(self.mapping[q] for q in op.qubits) logical_operations = self.execution_strategy.get_operations(logical_indices, op.qubits) @@ -93,7 +95,7 @@ def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operati if isinstance(op, ops.GateOperation) and isinstance(op.gate, PermutationGate): op.gate.update_mapping(self.mapping, op.qubits) - return + return None raise TypeError( 'Can only execute a strategy consisting of gates that ' diff --git a/cirq/google/optimizers/convert_to_sqrt_iswap.py b/cirq/google/optimizers/convert_to_sqrt_iswap.py index 4f7d41d99ef..cf228fd3995 100644 --- a/cirq/google/optimizers/convert_to_sqrt_iswap.py +++ b/cirq/google/optimizers/convert_to_sqrt_iswap.py @@ -18,6 +18,7 @@ from cirq import ops, circuits, protocols + if TYPE_CHECKING: import cirq @@ -104,7 +105,9 @@ def convert(self, op: 'cirq.Operation') -> List['cirq.Operation']: ) return a - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: if isinstance(op.gate, ops.MatrixGate) and len(op.qubits) == 1: return None diff --git a/cirq/google/optimizers/convert_to_sycamore_gates.py b/cirq/google/optimizers/convert_to_sycamore_gates.py index accc92e1758..225028d80de 100644 --- a/cirq/google/optimizers/convert_to_sycamore_gates.py +++ b/cirq/google/optimizers/convert_to_sycamore_gates.py @@ -17,6 +17,7 @@ import numpy as np import scipy.linalg from cirq import circuits, google, linalg, ops, optimizers, protocols + from cirq.google.ops import SycamoreGate from cirq.google.optimizers.two_qubit_gates.gate_compilation import GateTabulation @@ -129,7 +130,9 @@ def on_stuck_raise(bad): on_stuck_raise=None if self.ignore_failures else on_stuck_raise, ) - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: if not isinstance(op, ops.GateOperation): return None @@ -144,6 +147,8 @@ def optimization_at(self, circuit, index, op): ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits}) if len(ops_in_front) == 1 and isinstance(ops_in_front[0], ops.GateOperation): gate2 = ops_in_front[0].gate + else: + next_index = 0 if isinstance(gate, ops.SwapPowGate) and isinstance(gate2, ops.ZZPowGate): rads = gate2.exponent * np.pi / 2 diff --git a/cirq/google/optimizers/convert_to_xmon_gates.py b/cirq/google/optimizers/convert_to_xmon_gates.py index 85852e64391..699c91b4039 100644 --- a/cirq/google/optimizers/convert_to_xmon_gates.py +++ b/cirq/google/optimizers/convert_to_xmon_gates.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Optional from cirq import ops, protocols from cirq.circuits.optimization_pass import ( @@ -90,7 +90,9 @@ def on_stuck_raise(bad): on_stuck_raise=None if self.ignore_failures else on_stuck_raise, ) - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: converted = self.convert(op) if len(converted) == 1 and converted[0] is op: return None diff --git a/cirq/neutral_atoms/convert_to_neutral_atom_gates.py b/cirq/neutral_atoms/convert_to_neutral_atom_gates.py index b21800888bb..cc7236fa19c 100644 --- a/cirq/neutral_atoms/convert_to_neutral_atom_gates.py +++ b/cirq/neutral_atoms/convert_to_neutral_atom_gates.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional, TYPE_CHECKING from cirq import ops, protocols from cirq.circuits.optimization_pass import ( @@ -20,6 +20,9 @@ ) from cirq import optimizers +if TYPE_CHECKING: + import cirq + class ConvertToNeutralAtomGates(PointOptimizer): """Attempts to convert gates into native Atom gates. @@ -76,7 +79,9 @@ def on_stuck_raise(bad): on_stuck_raise=None if self.ignore_failures else on_stuck_raise, ) - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: converted = self.convert(op) if len(converted) == 1 and converted[0] is op: return None diff --git a/cirq/optimizers/expand_composite.py b/cirq/optimizers/expand_composite.py index 35288f32951..ba9a9ebd494 100644 --- a/cirq/optimizers/expand_composite.py +++ b/cirq/optimizers/expand_composite.py @@ -14,7 +14,7 @@ """An optimizer that expands composite operations via `cirq.decompose`.""" -from typing import Callable +from typing import Callable, Optional, TYPE_CHECKING from cirq import ops, protocols from cirq.circuits.optimization_pass import ( @@ -22,6 +22,9 @@ PointOptimizationSummary, ) +if TYPE_CHECKING: + import cirq + class ExpandComposite(PointOptimizer): """An optimizer that expands composite operations via `cirq.decompose`. @@ -41,7 +44,9 @@ def __init__(self, no_decomp: Callable[[ops.Operation], bool] = (lambda _: False super().__init__() self.no_decomp = no_decomp - def optimization_at(self, circuit, index, op): + def optimization_at( + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: decomposition = protocols.decompose(op, keep=self.no_decomp, on_stuck_raise=None) if decomposition == [op]: return None