Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize type annotations in cirq.circuits #4776

Merged
merged 6 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def moments(self) -> Sequence['cirq.Moment']:

@property
@abc.abstractmethod
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
pass

def freeze(self) -> 'cirq.FrozenCircuit':
Expand Down Expand Up @@ -589,7 +589,7 @@ def findall_operations_until_blocked(
self,
start_frontier: Dict['cirq.Qid', int],
is_blocker: Callable[['cirq.Operation'], bool] = lambda op: False,
) -> List[Tuple[int, ops.Operation]]:
) -> List[Tuple[int, 'cirq.Operation']]:
"""Finds all operations until a blocking operation is hit.

An operation is considered blocking if
Expand Down Expand Up @@ -740,7 +740,7 @@ def findall_operations(

def findall_operations_with_gate_type(
self, gate_type: Type[T_DESIRED_GATE_TYPE]
) -> Iterable[Tuple[int, ops.GateOperation, T_DESIRED_GATE_TYPE]]:
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
"""Find the locations of all gate operations of a given type.

Args:
Expand Down Expand Up @@ -852,7 +852,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
"""Returns the qubits acted upon by Operations in this circuit."""
return frozenset(q for m in self.moments for q in m.qubits)

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
"""Iterates over the operations applied by this circuit.

Operations from earlier moments will be iterated over first. Operations
Expand Down Expand Up @@ -1162,7 +1162,7 @@ def to_text_diagram_drawer(
get_circuit_diagram_info: Optional[
Callable[['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo']
] = None,
) -> TextDiagramDrawer:
) -> 'cirq.TextDiagramDrawer':
"""Returns a TextDiagramDrawer with the circuit drawn into it.

Args:
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def _to_qasm_output(
header: Optional[str] = None,
precision: int = 10,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
) -> QasmOutput:
) -> 'cirq.QasmOutput':
"""Returns a QASM object equivalent to the circuit.

Args:
Expand All @@ -1273,7 +1273,7 @@ def _to_qasm_output(

def _to_quil_output(
self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
) -> QuilOutput:
) -> 'cirq.QuilOutput':
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return QuilOutput(operations=self.all_operations(), qubits=qubits)

Expand Down Expand Up @@ -1697,23 +1697,23 @@ def __init__(
self.append(contents, strategy=strategy)

@property
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
return self._device

@device.setter
def device(self, new_device: 'cirq.Device') -> None:
new_device.validate_circuit(self)
self._device = new_device

def __copy__(self) -> 'Circuit':
def __copy__(self) -> 'cirq.Circuit':
return self.copy()

def copy(self) -> 'Circuit':
def copy(self) -> 'cirq.Circuit':
copied_circuit = Circuit(device=self._device)
copied_circuit._moments = self._moments[:]
return copied_circuit

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit':
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.Circuit':
new_circuit = Circuit(device=self.device)
new_circuit._moments = list(moments)
return new_circuit
Expand Down Expand Up @@ -1793,7 +1793,7 @@ def __rmul__(self, repetitions: INT_TYPE):
return NotImplemented
return self * int(repetitions)

def __pow__(self, exponent: int) -> 'Circuit':
def __pow__(self, exponent: int) -> 'cirq.Circuit':
"""A circuit raised to a power, only valid for exponent -1, the inverse.

This will fail if anything other than -1 is passed to the Circuit by
Expand All @@ -1819,7 +1819,7 @@ def with_device(
self,
new_device: 'cirq.Device',
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
) -> 'Circuit':
) -> 'cirq.Circuit':
"""Maps the current circuit onto a new device, and validates.

Args:
Expand Down Expand Up @@ -2296,7 +2296,9 @@ def clear_operations_touching(
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Circuit':
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Circuit':
resolved_moments = []
for moment in self:
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
Expand Down Expand Up @@ -2391,7 +2393,7 @@ def _draw_moment_annotations(
col: int,
use_unicode_characters: bool,
label_map: Dict['cirq.LabelEntity', int],
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
precision: Optional[int],
get_circuit_diagram_info: Callable[
['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo'
Expand Down Expand Up @@ -2421,7 +2423,7 @@ def _draw_moment_in_diagram(
moment: 'cirq.Moment',
use_unicode_characters: bool,
label_map: Dict['cirq.LabelEntity', int],
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
precision: Optional[int],
moment_groups: List[Tuple[int, int]],
get_circuit_diagram_info: Optional[
Expand Down Expand Up @@ -2542,7 +2544,7 @@ def _formatted_phase(coefficient: complex, unicode: bool, precision: Optional[in
def _draw_moment_groups_in_diagram(
moment_groups: List[Tuple[int, int]],
use_unicode_characters: bool,
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
):
out_diagram.insert_empty_rows(0)
h = out_diagram.height()
Expand Down Expand Up @@ -2572,7 +2574,7 @@ def _draw_moment_groups_in_diagram(


def _apply_unitary_circuit(
circuit: AbstractCircuit,
circuit: 'cirq.AbstractCircuit',
state: np.ndarray,
qubits: Tuple['cirq.Qid', ...],
dtype: Type[np.number],
Expand Down
22 changes: 11 additions & 11 deletions cirq-core/cirq/circuits/circuit_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self,
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
incoming_graph_data: Any = None,
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
) -> None:
"""Initializes a CircuitDag.

Expand All @@ -100,7 +100,7 @@ def make_node(op: 'cirq.Operation') -> Unique:

@staticmethod
def from_circuit(
circuit: circuit.Circuit,
circuit: 'cirq.Circuit',
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
) -> 'CircuitDag':
return CircuitDag.from_ops(
Expand All @@ -111,7 +111,7 @@ def from_circuit(
def from_ops(
*operations: 'cirq.OP_TREE',
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
) -> 'CircuitDag':
dag = CircuitDag(can_reorder=can_reorder, device=device)
for op in ops.flatten_op_tree(operations):
Expand Down Expand Up @@ -147,21 +147,21 @@ def __ne__(self, other):

__hash__ = None # type: ignore

def ordered_nodes(self) -> Iterator[Unique[ops.Operation]]:
def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]:
if not self.nodes():
return
g = self.copy()

def get_root_node(some_node: Unique[ops.Operation]) -> Unique[ops.Operation]:
def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']:
pred = g.pred
while pred[some_node]:
some_node = next(iter(pred[some_node]))
return some_node

def get_first_node() -> Unique[ops.Operation]:
def get_first_node() -> Unique['cirq.Operation']:
return get_root_node(next(iter(g.nodes())))

def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Operation]:
def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique['cirq.Operation']:
if succ:
return get_root_node(next(iter(succ)))

Expand All @@ -178,20 +178,20 @@ def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Oper

node = get_next_node(succ)

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
return (node.val for node in self.ordered_nodes())

def all_qubits(self):
return frozenset(q for node in self.nodes for q in node.val.qubits)

def to_circuit(self) -> circuit.Circuit:
def to_circuit(self) -> 'cirq.Circuit':
return circuit.Circuit(
self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST, device=self.device
)

def findall_nodes_until_blocked(
self, is_blocker: Callable[[ops.Operation], bool]
) -> Iterator[Unique[ops.Operation]]:
self, is_blocker: Callable[['cirq.Operation'], bool]
) -> Iterator[Unique['cirq.Operation']]:
"""Finds all nodes before blocking ones.

Args:
Expand Down
22 changes: 12 additions & 10 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def __post_init__(self):
# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

def base_operation(self) -> 'CircuitOperation':
def base_operation(self) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with only the wrapped circuit.

Key and qubit mappings, parameter values, and repetitions are not copied.
"""
return CircuitOperation(self.circuit)

def replace(self, **changes) -> 'CircuitOperation':
def replace(self, **changes) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with the specified changes."""
return dataclasses.replace(self, **changes)

Expand Down Expand Up @@ -435,7 +435,7 @@ def repeat(

return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: int) -> 'CircuitOperation':
def __pow__(self, power: int) -> 'cirq.CircuitOperation':
return self.repeat(power)

def _with_key_path_(self, path: Tuple[str, ...]):
Expand All @@ -462,13 +462,13 @@ def _with_rescoped_keys_(
def with_key_path(self, path: Tuple[str, ...]):
return self._with_key_path_(path)

def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
def with_repetition_ids(self, repetition_ids: List[str]) -> 'cirq.CircuitOperation':
return self.replace(repetition_ids=repetition_ids)

def with_qubit_mapping(
self,
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
) -> 'CircuitOperation':
) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated qubit mapping.

Users should pass either 'qubit_map' or 'transform' to this method.
Expand Down Expand Up @@ -509,7 +509,7 @@ def with_qubit_mapping(
)
return new_op

def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated qubit mapping.

Args:
Expand All @@ -529,7 +529,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
raise ValueError(f'Expected {expected} qubits, got {len(new_qubits)}.')
return self.with_qubit_mapping(dict(zip(self.qubits, new_qubits)))

def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOperation':
def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated key mapping.

Args:
Expand Down Expand Up @@ -563,10 +563,12 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
)
return new_op

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'CircuitOperation':
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
return self.with_measurement_key_mapping(key_map)

def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'CircuitOperation':
def with_params(
self, param_values: 'cirq.ParamResolverOrSimilarType'
) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated ParamResolver.

Note that any resulting parameter mappings with no corresponding
Expand All @@ -592,7 +594,7 @@ def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'Circui
# TODO: handle recursive parameter resolution gracefully
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'CircuitOperation':
) -> 'cirq.CircuitOperation':
if recursive:
raise ValueError(
'Recursive resolution of CircuitOperation parameters is prohibited. '
Expand Down
20 changes: 10 additions & 10 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def moments(self) -> Sequence['cirq.Moment']:
return self._moments

@property
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
return self._device

def __hash__(self):
Expand Down Expand Up @@ -116,7 +116,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
self._all_qubits = super().all_qubits()
return self._all_qubits

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
if self._all_operations is None:
self._all_operations = tuple(super().all_operations())
return iter(self._all_operations)
Expand Down Expand Up @@ -152,29 +152,29 @@ def all_measurement_key_names(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> AbstractSet[str]:
return self.all_measurement_key_names()

def __add__(self, other) -> 'FrozenCircuit':
def __add__(self, other) -> 'cirq.FrozenCircuit':
return (self.unfreeze() + other).freeze()

def __radd__(self, other) -> 'FrozenCircuit':
def __radd__(self, other) -> 'cirq.FrozenCircuit':
return (other + self.unfreeze()).freeze()

# Needed for numpy to handle multiplication by np.int64 correctly.
__array_priority__ = 10000

# TODO: handle multiplication / powers differently?
def __mul__(self, other) -> 'FrozenCircuit':
def __mul__(self, other) -> 'cirq.FrozenCircuit':
return (self.unfreeze() * other).freeze()

def __rmul__(self, other) -> 'FrozenCircuit':
def __rmul__(self, other) -> 'cirq.FrozenCircuit':
return (other * self.unfreeze()).freeze()

def __pow__(self, other) -> 'FrozenCircuit':
def __pow__(self, other) -> 'cirq.FrozenCircuit':
try:
return (self.unfreeze() ** other).freeze()
except:
return NotImplemented

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.FrozenCircuit':
new_circuit = FrozenCircuit(device=self.device)
new_circuit._moments = tuple(moments)
return new_circuit
Expand All @@ -183,12 +183,12 @@ def with_device(
self,
new_device: 'cirq.Device',
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
) -> 'FrozenCircuit':
) -> 'cirq.FrozenCircuit':
return self.unfreeze().with_device(new_device, qubit_mapping).freeze()

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'FrozenCircuit':
) -> 'cirq.FrozenCircuit':
return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze()

def tetris_concat(
Expand Down
Loading