From 9aede582fd0b32f2e947d3175673852fcd067881 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 23 Dec 2021 13:29:32 -0800 Subject: [PATCH] Sanitize type annotations in cirq.circuits (#4776) Follow-up to the cirq.sim one. Note this reduces things that cause recursive dependencies as well, so watch out for this in code reviews. --- cirq-core/cirq/circuits/circuit.py | 38 ++++++++++--------- cirq-core/cirq/circuits/circuit_dag.py | 22 +++++------ cirq-core/cirq/circuits/circuit_operation.py | 22 ++++++----- cirq-core/cirq/circuits/frozen_circuit.py | 20 +++++----- cirq-core/cirq/circuits/optimization_pass.py | 14 +++---- cirq-core/cirq/circuits/qasm_output.py | 2 +- .../cirq/circuits/text_diagram_drawer.py | 18 +++++---- 7 files changed, 71 insertions(+), 65 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index abb23189fdb..53e042bdcfe 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -122,7 +122,7 @@ def moments(self) -> Sequence['cirq.Moment']: @property @abc.abstractmethod - def device(self) -> devices.Device: + def device(self) -> 'cirq.Device': pass def freeze(self) -> 'cirq.FrozenCircuit': @@ -589,7 +589,7 @@ def findall_operations_until_blocked( self, start_frontier: Dict['cirq.Qid', int], is_blocker: Callable[['cirq.Operation'], bool] = lambda op: False, - ) -> List[Tuple[int, ops.Operation]]: + ) -> List[Tuple[int, 'cirq.Operation']]: """Finds all operations until a blocking operation is hit. An operation is considered blocking if @@ -740,7 +740,7 @@ def findall_operations( def findall_operations_with_gate_type( self, gate_type: Type[T_DESIRED_GATE_TYPE] - ) -> Iterable[Tuple[int, ops.GateOperation, T_DESIRED_GATE_TYPE]]: + ) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]: """Find the locations of all gate operations of a given type. Args: @@ -852,7 +852,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']: """Returns the qubits acted upon by Operations in this circuit.""" return frozenset(q for m in self.moments for q in m.qubits) - def all_operations(self) -> Iterator[ops.Operation]: + def all_operations(self) -> Iterator['cirq.Operation']: """Iterates over the operations applied by this circuit. Operations from earlier moments will be iterated over first. Operations @@ -1162,7 +1162,7 @@ def to_text_diagram_drawer( get_circuit_diagram_info: Optional[ Callable[['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo'] ] = None, - ) -> TextDiagramDrawer: + ) -> 'cirq.TextDiagramDrawer': """Returns a TextDiagramDrawer with the circuit drawn into it. Args: @@ -1250,7 +1250,7 @@ def _to_qasm_output( header: Optional[str] = None, precision: int = 10, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, - ) -> QasmOutput: + ) -> 'cirq.QasmOutput': """Returns a QASM object equivalent to the circuit. Args: @@ -1273,7 +1273,7 @@ def _to_qasm_output( def _to_quil_output( self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT - ) -> QuilOutput: + ) -> 'cirq.QuilOutput': qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) return QuilOutput(operations=self.all_operations(), qubits=qubits) @@ -1697,7 +1697,7 @@ def __init__( self.append(contents, strategy=strategy) @property - def device(self) -> devices.Device: + def device(self) -> 'cirq.Device': return self._device @device.setter @@ -1705,15 +1705,15 @@ def device(self, new_device: 'cirq.Device') -> None: new_device.validate_circuit(self) self._device = new_device - def __copy__(self) -> 'Circuit': + def __copy__(self) -> 'cirq.Circuit': return self.copy() - def copy(self) -> 'Circuit': + def copy(self) -> 'cirq.Circuit': copied_circuit = Circuit(device=self._device) copied_circuit._moments = self._moments[:] return copied_circuit - def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit': + def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.Circuit': new_circuit = Circuit(device=self.device) new_circuit._moments = list(moments) return new_circuit @@ -1793,7 +1793,7 @@ def __rmul__(self, repetitions: INT_TYPE): return NotImplemented return self * int(repetitions) - def __pow__(self, exponent: int) -> 'Circuit': + def __pow__(self, exponent: int) -> 'cirq.Circuit': """A circuit raised to a power, only valid for exponent -1, the inverse. This will fail if anything other than -1 is passed to the Circuit by @@ -1819,7 +1819,7 @@ def with_device( self, new_device: 'cirq.Device', qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e, - ) -> 'Circuit': + ) -> 'cirq.Circuit': """Maps the current circuit onto a new device, and validates. Args: @@ -2296,7 +2296,9 @@ def clear_operations_touching( if 0 <= k < len(self._moments): self._moments[k] = self._moments[k].without_operations_touching(qubits) - def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Circuit': + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> 'cirq.Circuit': resolved_moments = [] for moment in self: resolved_operations = _resolve_operations(moment.operations, resolver, recursive) @@ -2391,7 +2393,7 @@ def _draw_moment_annotations( col: int, use_unicode_characters: bool, label_map: Dict['cirq.LabelEntity', int], - out_diagram: TextDiagramDrawer, + out_diagram: 'cirq.TextDiagramDrawer', precision: Optional[int], get_circuit_diagram_info: Callable[ ['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo' @@ -2421,7 +2423,7 @@ def _draw_moment_in_diagram( moment: 'cirq.Moment', use_unicode_characters: bool, label_map: Dict['cirq.LabelEntity', int], - out_diagram: TextDiagramDrawer, + out_diagram: 'cirq.TextDiagramDrawer', precision: Optional[int], moment_groups: List[Tuple[int, int]], get_circuit_diagram_info: Optional[ @@ -2542,7 +2544,7 @@ def _formatted_phase(coefficient: complex, unicode: bool, precision: Optional[in def _draw_moment_groups_in_diagram( moment_groups: List[Tuple[int, int]], use_unicode_characters: bool, - out_diagram: TextDiagramDrawer, + out_diagram: 'cirq.TextDiagramDrawer', ): out_diagram.insert_empty_rows(0) h = out_diagram.height() @@ -2572,7 +2574,7 @@ def _draw_moment_groups_in_diagram( def _apply_unitary_circuit( - circuit: AbstractCircuit, + circuit: 'cirq.AbstractCircuit', state: np.ndarray, qubits: Tuple['cirq.Qid', ...], dtype: Type[np.number], diff --git a/cirq-core/cirq/circuits/circuit_dag.py b/cirq-core/cirq/circuits/circuit_dag.py index 33250aaeec6..4728fbe8132 100644 --- a/cirq-core/cirq/circuits/circuit_dag.py +++ b/cirq-core/cirq/circuits/circuit_dag.py @@ -74,7 +74,7 @@ def __init__( self, can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, incoming_graph_data: Any = None, - device: devices.Device = devices.UNCONSTRAINED_DEVICE, + device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE, ) -> None: """Initializes a CircuitDag. @@ -100,7 +100,7 @@ def make_node(op: 'cirq.Operation') -> Unique: @staticmethod def from_circuit( - circuit: circuit.Circuit, + circuit: 'cirq.Circuit', can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, ) -> 'CircuitDag': return CircuitDag.from_ops( @@ -111,7 +111,7 @@ def from_circuit( def from_ops( *operations: 'cirq.OP_TREE', can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, - device: devices.Device = devices.UNCONSTRAINED_DEVICE, + device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE, ) -> 'CircuitDag': dag = CircuitDag(can_reorder=can_reorder, device=device) for op in ops.flatten_op_tree(operations): @@ -147,21 +147,21 @@ def __ne__(self, other): __hash__ = None # type: ignore - def ordered_nodes(self) -> Iterator[Unique[ops.Operation]]: + def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]: if not self.nodes(): return g = self.copy() - def get_root_node(some_node: Unique[ops.Operation]) -> Unique[ops.Operation]: + def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']: pred = g.pred while pred[some_node]: some_node = next(iter(pred[some_node])) return some_node - def get_first_node() -> Unique[ops.Operation]: + def get_first_node() -> Unique['cirq.Operation']: return get_root_node(next(iter(g.nodes()))) - def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Operation]: + def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique['cirq.Operation']: if succ: return get_root_node(next(iter(succ))) @@ -178,20 +178,20 @@ def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Oper node = get_next_node(succ) - def all_operations(self) -> Iterator[ops.Operation]: + def all_operations(self) -> Iterator['cirq.Operation']: return (node.val for node in self.ordered_nodes()) def all_qubits(self): return frozenset(q for node in self.nodes for q in node.val.qubits) - def to_circuit(self) -> circuit.Circuit: + def to_circuit(self) -> 'cirq.Circuit': return circuit.Circuit( self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST, device=self.device ) def findall_nodes_until_blocked( - self, is_blocker: Callable[[ops.Operation], bool] - ) -> Iterator[Unique[ops.Operation]]: + self, is_blocker: Callable[['cirq.Operation'], bool] + ) -> Iterator[Unique['cirq.Operation']]: """Finds all nodes before blocking ones. Args: diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index d10857d1a5f..7cfd135edc7 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -143,14 +143,14 @@ def __post_init__(self): # Ensure that param_resolver is converted to an actual ParamResolver. object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) - def base_operation(self) -> 'CircuitOperation': + def base_operation(self) -> 'cirq.CircuitOperation': """Returns a copy of this operation with only the wrapped circuit. Key and qubit mappings, parameter values, and repetitions are not copied. """ return CircuitOperation(self.circuit) - def replace(self, **changes) -> 'CircuitOperation': + def replace(self, **changes) -> 'cirq.CircuitOperation': """Returns a copy of this operation with the specified changes.""" return dataclasses.replace(self, **changes) @@ -435,7 +435,7 @@ def repeat( return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids) - def __pow__(self, power: int) -> 'CircuitOperation': + def __pow__(self, power: int) -> 'cirq.CircuitOperation': return self.repeat(power) def _with_key_path_(self, path: Tuple[str, ...]): @@ -462,13 +462,13 @@ def _with_rescoped_keys_( def with_key_path(self, path: Tuple[str, ...]): return self._with_key_path_(path) - def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation': + def with_repetition_ids(self, repetition_ids: List[str]) -> 'cirq.CircuitOperation': return self.replace(repetition_ids=repetition_ids) def with_qubit_mapping( self, qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']], - ) -> 'CircuitOperation': + ) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated qubit mapping. Users should pass either 'qubit_map' or 'transform' to this method. @@ -509,7 +509,7 @@ def with_qubit_mapping( ) return new_op - def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation': + def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated qubit mapping. Args: @@ -529,7 +529,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation': raise ValueError(f'Expected {expected} qubits, got {len(new_qubits)}.') return self.with_qubit_mapping(dict(zip(self.qubits, new_qubits))) - def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOperation': + def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated key mapping. Args: @@ -563,10 +563,12 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera ) return new_op - def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'CircuitOperation': + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation': return self.with_measurement_key_mapping(key_map) - def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'CircuitOperation': + def with_params( + self, param_values: 'cirq.ParamResolverOrSimilarType' + ) -> 'cirq.CircuitOperation': """Returns a copy of this operation with an updated ParamResolver. Note that any resulting parameter mappings with no corresponding @@ -592,7 +594,7 @@ def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'Circui # TODO: handle recursive parameter resolution gracefully def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool - ) -> 'CircuitOperation': + ) -> 'cirq.CircuitOperation': if recursive: raise ValueError( 'Recursive resolution of CircuitOperation parameters is prohibited. ' diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index be21ae0f95f..f7633907cd2 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -83,7 +83,7 @@ def moments(self) -> Sequence['cirq.Moment']: return self._moments @property - def device(self) -> devices.Device: + def device(self) -> 'cirq.Device': return self._device def __hash__(self): @@ -116,7 +116,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']: self._all_qubits = super().all_qubits() return self._all_qubits - def all_operations(self) -> Iterator[ops.Operation]: + def all_operations(self) -> Iterator['cirq.Operation']: if self._all_operations is None: self._all_operations = tuple(super().all_operations()) return iter(self._all_operations) @@ -152,29 +152,29 @@ def all_measurement_key_names(self) -> AbstractSet[str]: def _measurement_key_names_(self) -> AbstractSet[str]: return self.all_measurement_key_names() - def __add__(self, other) -> 'FrozenCircuit': + def __add__(self, other) -> 'cirq.FrozenCircuit': return (self.unfreeze() + other).freeze() - def __radd__(self, other) -> 'FrozenCircuit': + def __radd__(self, other) -> 'cirq.FrozenCircuit': return (other + self.unfreeze()).freeze() # Needed for numpy to handle multiplication by np.int64 correctly. __array_priority__ = 10000 # TODO: handle multiplication / powers differently? - def __mul__(self, other) -> 'FrozenCircuit': + def __mul__(self, other) -> 'cirq.FrozenCircuit': return (self.unfreeze() * other).freeze() - def __rmul__(self, other) -> 'FrozenCircuit': + def __rmul__(self, other) -> 'cirq.FrozenCircuit': return (other * self.unfreeze()).freeze() - def __pow__(self, other) -> 'FrozenCircuit': + def __pow__(self, other) -> 'cirq.FrozenCircuit': try: return (self.unfreeze() ** other).freeze() except: return NotImplemented - def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit': + def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.FrozenCircuit': new_circuit = FrozenCircuit(device=self.device) new_circuit._moments = tuple(moments) return new_circuit @@ -183,12 +183,12 @@ def with_device( self, new_device: 'cirq.Device', qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e, - ) -> 'FrozenCircuit': + ) -> 'cirq.FrozenCircuit': return self.unfreeze().with_device(new_device, qubit_mapping).freeze() def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool - ) -> 'FrozenCircuit': + ) -> 'cirq.FrozenCircuit': return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze() def tetris_concat( diff --git a/cirq-core/cirq/circuits/optimization_pass.py b/cirq-core/cirq/circuits/optimization_pass.py index 41593e59b59..943823fdd6c 100644 --- a/cirq-core/cirq/circuits/optimization_pass.py +++ b/cirq-core/cirq/circuits/optimization_pass.py @@ -13,13 +13,11 @@ # limitations under the License. """Defines the OptimizationPass type.""" -from typing import Dict, Callable, Iterable, Optional, Sequence, TYPE_CHECKING, Tuple, cast - import abc from collections import defaultdict +from typing import Dict, Callable, Iterable, Optional, Sequence, TYPE_CHECKING, Tuple, cast from cirq import ops -from cirq.circuits.circuit import Circuit if TYPE_CHECKING: import cirq @@ -90,7 +88,7 @@ class PointOptimizer: def __init__( self, post_clean_up: Callable[ - [Sequence['cirq.Operation']], ops.OP_TREE + [Sequence['cirq.Operation']], 'cirq.OP_TREE' ] = lambda op_list: op_list, ) -> None: """Inits PointOptimizer. @@ -102,13 +100,13 @@ def __init__( """ self.post_clean_up = post_clean_up - def __call__(self, circuit: Circuit): + def __call__(self, circuit: 'cirq.Circuit'): return self.optimize_circuit(circuit) @abc.abstractmethod def optimization_at( - self, circuit: Circuit, index: int, op: 'cirq.Operation' - ) -> Optional[PointOptimizationSummary]: + self, circuit: 'cirq.Circuit', index: int, op: 'cirq.Operation' + ) -> Optional['cirq.PointOptimizationSummary']: """Describes how to change operations near the given location. For example, this method could realize that the given operation is an @@ -128,7 +126,7 @@ def optimization_at( change should be made. """ - def optimize_circuit(self, circuit: Circuit): + def optimize_circuit(self, circuit: 'cirq.Circuit'): frontier: Dict['Qid', int] = defaultdict(lambda: 0) i = 0 while i < len(circuit): # Note: circuit may mutate as we go. diff --git a/cirq-core/cirq/circuits/qasm_output.py b/cirq-core/cirq/circuits/qasm_output.py index 1024938086f..912c1c9d8d8 100644 --- a/cirq-core/cirq/circuits/qasm_output.py +++ b/cirq-core/cirq/circuits/qasm_output.py @@ -97,7 +97,7 @@ def _from_json_dict_(cls, theta: float, phi: float, lmda: float, **kwargs) -> 'Q @value.value_equality class QasmTwoQubitGate(ops.Gate): - def __init__(self, kak: linalg.KakDecomposition) -> None: + def __init__(self, kak: 'cirq.KakDecomposition') -> None: """A two qubit gate represented in QASM by the KAK decomposition. All angles are in half turns. Assumes a canonicalized KAK diff --git a/cirq-core/cirq/circuits/text_diagram_drawer.py b/cirq-core/cirq/circuits/text_diagram_drawer.py index fef54fb5e19..b25e262fdfb 100644 --- a/cirq-core/cirq/circuits/text_diagram_drawer.py +++ b/cirq-core/cirq/circuits/text_diagram_drawer.py @@ -24,6 +24,7 @@ Optional, Sequence, Tuple, + TYPE_CHECKING, Union, ) @@ -39,6 +40,9 @@ DOUBLED_BOX_CHARS, ) +if TYPE_CHECKING: + import cirq + _HorizontalLine = NamedTuple( 'HorizontalLine', [ @@ -196,7 +200,7 @@ def horizontal_line( x1, x2 = sorted([x1, x2]) self.horizontal_lines.append(_HorizontalLine(y, x1, x2, emphasize, doubled)) - def transpose(self) -> 'TextDiagramDrawer': + def transpose(self) -> 'cirq.TextDiagramDrawer': """Returns the same diagram, but mirrored across its diagonal.""" out = TextDiagramDrawer() out.entries = { @@ -367,14 +371,14 @@ def copy(self): horizontal_padding=self.horizontal_padding, ) - def shift(self, dx: int = 0, dy: int = 0) -> 'TextDiagramDrawer': + def shift(self, dx: int = 0, dy: int = 0) -> 'cirq.TextDiagramDrawer': self._transform_coordinates(lambda x, y: (x + dx, y + dy)) return self - def shifted(self, dx: int = 0, dy: int = 0) -> 'TextDiagramDrawer': + def shifted(self, dx: int = 0, dy: int = 0) -> 'cirq.TextDiagramDrawer': return self.copy().shift(dx, dy) - def superimpose(self, other: 'TextDiagramDrawer') -> 'TextDiagramDrawer': + def superimpose(self, other: 'cirq.TextDiagramDrawer') -> 'cirq.TextDiagramDrawer': self.entries.update(other.entries) self.horizontal_lines += other.horizontal_lines self.vertical_lines += other.vertical_lines @@ -382,13 +386,13 @@ def superimpose(self, other: 'TextDiagramDrawer') -> 'TextDiagramDrawer': self.vertical_padding.update(other.vertical_padding) return self - def superimposed(self, other: 'TextDiagramDrawer') -> 'TextDiagramDrawer': + def superimposed(self, other: 'cirq.TextDiagramDrawer') -> 'cirq.TextDiagramDrawer': return self.copy().superimpose(other) @classmethod def vstack( cls, - diagrams: Sequence['TextDiagramDrawer'], + diagrams: Sequence['cirq.TextDiagramDrawer'], padding_resolver: Optional[Callable[[Sequence[Optional[int]]], int]] = None, ): """Vertically stack text diagrams. @@ -429,7 +433,7 @@ def vstack( @classmethod def hstack( cls, - diagrams: Sequence['TextDiagramDrawer'], + diagrams: Sequence['cirq.TextDiagramDrawer'], padding_resolver: Optional[Callable[[Sequence[Optional[int]]], int]] = None, ): """Horizontally stack text diagrams.