Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations for optimization_pass #3962

Merged
merged 6 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
from collections import defaultdict
from random import randint, random, sample, randrange
from typing import Tuple
from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np
import pytest
Expand Down Expand Up @@ -65,6 +65,10 @@ def can_add_operation_into_moment(
)


if TYPE_CHECKING:
import cirq


class _MomentAndOpTypeValidatingDeviceType(cirq.Device):
def validate_operation(self, operation):
if not isinstance(operation, cirq.Operation):
Expand Down Expand Up @@ -958,7 +962,9 @@ def __init__(self, replacer=(lambda x: x)):
super().__init__()
self.replacer = replacer

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
new_ops = self.replacer(op)
return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_ops
Expand Down
33 changes: 22 additions & 11 deletions cirq/circuits/optimization_pass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, TYPE_CHECKING, Set, List

import pytest
import cirq
from cirq import PointOptimizer, PointOptimizationSummary
from cirq import PointOptimizer, PointOptimizationSummary, Operation
from cirq.testing import EqualsTester

if TYPE_CHECKING:
import cirq


def test_equality():
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -57,18 +61,20 @@ class ReplaceWithXGates(PointOptimizer):
operation's qubits.
"""

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
end = index + 1
new_ops = [cirq.X(q) for q in op.qubits]
done = False
while not done:
n = circuit.next_moment_operating_on(op.qubits, end)
if n is None:
break
next_ops = {circuit.operation_at(q, n) for q in op.qubits}
next_ops = [e for e in next_ops if e]
next_ops = sorted(next_ops, key=lambda e: str(e.qubits))
for next_op in next_ops:
next_ops: Set[Optional[Operation]] = {circuit.operation_at(q, n) for q in op.qubits}
next_ops_list: List[Operation] = [e for e in next_ops if e]
next_ops_sorted = sorted(next_ops_list, key=lambda e: str(e.qubits))
for next_op in next_ops_sorted:
if next_op:
if set(next_op.qubits).issubset(op.qubits):
end = n + 1
Expand Down Expand Up @@ -149,14 +155,19 @@ def test_point_optimizer_raises_on_gates_changing_qubits():
class EverythingIs42(cirq.PointOptimizer):
"""Changes all single qubit operations to act on LineQubit(42)"""

def optimization_at(self, circuit, index, op):
if len(op.qubits) == 1:
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
new_op = op
if len(op.qubits) == 1 and isinstance(op, cirq.GateOperation):
new_op = op.gate(cirq.LineQubit(42))
return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_op
)

return cirq.PointOptimizationSummary(
clear_span=1, clear_qubits=op.qubits, new_operations=new_op
)

c = cirq.Circuit(cirq.X(cirq.LineQubit(0)), cirq.X(cirq.LineQubit(1)))

with pytest.raises(ValueError, match='new qubits'):
EverythingIs42().optimize_circuit(c)

Expand Down
8 changes: 5 additions & 3 deletions cirq/contrib/acquaintance/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING
from typing import DefaultDict, Dict, Sequence, TYPE_CHECKING, Optional

import abc
from collections import defaultdict
Expand Down Expand Up @@ -81,7 +81,9 @@ def __call__(self, strategy: 'cirq.Circuit'):
super().optimize_circuit(strategy)
return self.mapping.copy()

def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if isinstance(op.gate, AcquaintanceOpportunityGate):
logical_indices = tuple(self.mapping[q] for q in op.qubits)
logical_operations = self.execution_strategy.get_operations(logical_indices, op.qubits)
Expand All @@ -93,7 +95,7 @@ def optimization_at(self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operati

if isinstance(op, ops.GateOperation) and isinstance(op.gate, PermutationGate):
op.gate.update_mapping(self.mapping, op.qubits)
return
return None

raise TypeError(
'Can only execute a strategy consisting of gates that '
Expand Down
5 changes: 4 additions & 1 deletion cirq/google/optimizers/convert_to_sqrt_iswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from cirq import ops, circuits, protocols


if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -104,7 +105,9 @@ def convert(self, op: 'cirq.Operation') -> List['cirq.Operation']:
)
return a

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if isinstance(op.gate, ops.MatrixGate) and len(op.qubits) == 1:
return None

Expand Down
7 changes: 6 additions & 1 deletion cirq/google/optimizers/convert_to_sycamore_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import scipy.linalg
from cirq import circuits, google, linalg, ops, optimizers, protocols

from cirq.google.ops import SycamoreGate
from cirq.google.optimizers.two_qubit_gates.gate_compilation import GateTabulation

Expand Down Expand Up @@ -129,7 +130,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
if not isinstance(op, ops.GateOperation):
return None

Expand All @@ -144,6 +147,8 @@ def optimization_at(self, circuit, index, op):
ops_in_front = list({circuit.operation_at(q, next_index) for q in op.qubits})
if len(ops_in_front) == 1 and isinstance(ops_in_front[0], ops.GateOperation):
gate2 = ops_in_front[0].gate
else:
next_index = 0

if isinstance(gate, ops.SwapPowGate) and isinstance(gate2, ops.ZZPowGate):
rads = gate2.exponent * np.pi / 2
Expand Down
6 changes: 4 additions & 2 deletions cirq/google/optimizers/convert_to_xmon_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, TYPE_CHECKING
from typing import List, TYPE_CHECKING, Optional

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
Expand Down Expand Up @@ -90,7 +90,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
9 changes: 7 additions & 2 deletions cirq/neutral_atoms/convert_to_neutral_atom_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Optional, TYPE_CHECKING

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
Expand All @@ -20,6 +20,9 @@
)
from cirq import optimizers

if TYPE_CHECKING:
import cirq


class ConvertToNeutralAtomGates(PointOptimizer):
"""Attempts to convert gates into native Atom gates.
Expand Down Expand Up @@ -76,7 +79,9 @@ def on_stuck_raise(bad):
on_stuck_raise=None if self.ignore_failures else on_stuck_raise,
)

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
converted = self.convert(op)
if len(converted) == 1 and converted[0] is op:
return None
Expand Down
9 changes: 7 additions & 2 deletions cirq/optimizers/expand_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

"""An optimizer that expands composite operations via `cirq.decompose`."""

from typing import Callable
from typing import Callable, Optional, TYPE_CHECKING

from cirq import ops, protocols
from cirq.circuits.optimization_pass import (
PointOptimizer,
PointOptimizationSummary,
)

if TYPE_CHECKING:
import cirq


class ExpandComposite(PointOptimizer):
"""An optimizer that expands composite operations via `cirq.decompose`.
Expand All @@ -41,7 +44,9 @@ def __init__(self, no_decomp: Callable[[ops.Operation], bool] = (lambda _: False
super().__init__()
self.no_decomp = no_decomp

def optimization_at(self, circuit, index, op):
def optimization_at(
self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation'
) -> Optional['cirq.PointOptimizationSummary']:
decomposition = protocols.decompose(op, keep=self.no_decomp, on_stuck_raise=None)
if decomposition == [op]:
return None
Expand Down