Skip to content

Commit

Permalink
Sanitize type annotations in cirq.circuits (quantumlib#4776)
Browse files Browse the repository at this point in the history
Follow-up to the cirq.sim one. Note this reduces things that cause recursive dependencies as well, so watch out for this in code reviews.
  • Loading branch information
daxfohl authored Dec 23, 2021
1 parent 99a8076 commit c8ee329
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 65 deletions.
38 changes: 20 additions & 18 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def moments(self) -> Sequence['cirq.Moment']:

@property
@abc.abstractmethod
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
pass

def freeze(self) -> 'cirq.FrozenCircuit':
Expand Down Expand Up @@ -589,7 +589,7 @@ def findall_operations_until_blocked(
self,
start_frontier: Dict['cirq.Qid', int],
is_blocker: Callable[['cirq.Operation'], bool] = lambda op: False,
) -> List[Tuple[int, ops.Operation]]:
) -> List[Tuple[int, 'cirq.Operation']]:
"""Finds all operations until a blocking operation is hit.
An operation is considered blocking if
Expand Down Expand Up @@ -740,7 +740,7 @@ def findall_operations(

def findall_operations_with_gate_type(
self, gate_type: Type[T_DESIRED_GATE_TYPE]
) -> Iterable[Tuple[int, ops.GateOperation, T_DESIRED_GATE_TYPE]]:
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
"""Find the locations of all gate operations of a given type.
Args:
Expand Down Expand Up @@ -852,7 +852,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
"""Returns the qubits acted upon by Operations in this circuit."""
return frozenset(q for m in self.moments for q in m.qubits)

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
"""Iterates over the operations applied by this circuit.
Operations from earlier moments will be iterated over first. Operations
Expand Down Expand Up @@ -1162,7 +1162,7 @@ def to_text_diagram_drawer(
get_circuit_diagram_info: Optional[
Callable[['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo']
] = None,
) -> TextDiagramDrawer:
) -> 'cirq.TextDiagramDrawer':
"""Returns a TextDiagramDrawer with the circuit drawn into it.
Args:
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def _to_qasm_output(
header: Optional[str] = None,
precision: int = 10,
qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT,
) -> QasmOutput:
) -> 'cirq.QasmOutput':
"""Returns a QASM object equivalent to the circuit.
Args:
Expand All @@ -1273,7 +1273,7 @@ def _to_qasm_output(

def _to_quil_output(
self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
) -> QuilOutput:
) -> 'cirq.QuilOutput':
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return QuilOutput(operations=self.all_operations(), qubits=qubits)

Expand Down Expand Up @@ -1697,23 +1697,23 @@ def __init__(
self.append(contents, strategy=strategy)

@property
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
return self._device

@device.setter
def device(self, new_device: 'cirq.Device') -> None:
new_device.validate_circuit(self)
self._device = new_device

def __copy__(self) -> 'Circuit':
def __copy__(self) -> 'cirq.Circuit':
return self.copy()

def copy(self) -> 'Circuit':
def copy(self) -> 'cirq.Circuit':
copied_circuit = Circuit(device=self._device)
copied_circuit._moments = self._moments[:]
return copied_circuit

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit':
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.Circuit':
new_circuit = Circuit(device=self.device)
new_circuit._moments = list(moments)
return new_circuit
Expand Down Expand Up @@ -1793,7 +1793,7 @@ def __rmul__(self, repetitions: INT_TYPE):
return NotImplemented
return self * int(repetitions)

def __pow__(self, exponent: int) -> 'Circuit':
def __pow__(self, exponent: int) -> 'cirq.Circuit':
"""A circuit raised to a power, only valid for exponent -1, the inverse.
This will fail if anything other than -1 is passed to the Circuit by
Expand All @@ -1819,7 +1819,7 @@ def with_device(
self,
new_device: 'cirq.Device',
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
) -> 'Circuit':
) -> 'cirq.Circuit':
"""Maps the current circuit onto a new device, and validates.
Args:
Expand Down Expand Up @@ -2296,7 +2296,9 @@ def clear_operations_touching(
if 0 <= k < len(self._moments):
self._moments[k] = self._moments[k].without_operations_touching(qubits)

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Circuit':
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'cirq.Circuit':
resolved_moments = []
for moment in self:
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
Expand Down Expand Up @@ -2391,7 +2393,7 @@ def _draw_moment_annotations(
col: int,
use_unicode_characters: bool,
label_map: Dict['cirq.LabelEntity', int],
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
precision: Optional[int],
get_circuit_diagram_info: Callable[
['cirq.Operation', 'cirq.CircuitDiagramInfoArgs'], 'cirq.CircuitDiagramInfo'
Expand Down Expand Up @@ -2421,7 +2423,7 @@ def _draw_moment_in_diagram(
moment: 'cirq.Moment',
use_unicode_characters: bool,
label_map: Dict['cirq.LabelEntity', int],
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
precision: Optional[int],
moment_groups: List[Tuple[int, int]],
get_circuit_diagram_info: Optional[
Expand Down Expand Up @@ -2542,7 +2544,7 @@ def _formatted_phase(coefficient: complex, unicode: bool, precision: Optional[in
def _draw_moment_groups_in_diagram(
moment_groups: List[Tuple[int, int]],
use_unicode_characters: bool,
out_diagram: TextDiagramDrawer,
out_diagram: 'cirq.TextDiagramDrawer',
):
out_diagram.insert_empty_rows(0)
h = out_diagram.height()
Expand Down Expand Up @@ -2572,7 +2574,7 @@ def _draw_moment_groups_in_diagram(


def _apply_unitary_circuit(
circuit: AbstractCircuit,
circuit: 'cirq.AbstractCircuit',
state: np.ndarray,
qubits: Tuple['cirq.Qid', ...],
dtype: Type[np.number],
Expand Down
22 changes: 11 additions & 11 deletions cirq/circuits/circuit_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self,
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
incoming_graph_data: Any = None,
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
) -> None:
"""Initializes a CircuitDag.
Expand All @@ -100,7 +100,7 @@ def make_node(op: 'cirq.Operation') -> Unique:

@staticmethod
def from_circuit(
circuit: circuit.Circuit,
circuit: 'cirq.Circuit',
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
) -> 'CircuitDag':
return CircuitDag.from_ops(
Expand All @@ -111,7 +111,7 @@ def from_circuit(
def from_ops(
*operations: 'cirq.OP_TREE',
can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits,
device: devices.Device = devices.UNCONSTRAINED_DEVICE,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
) -> 'CircuitDag':
dag = CircuitDag(can_reorder=can_reorder, device=device)
for op in ops.flatten_op_tree(operations):
Expand Down Expand Up @@ -147,21 +147,21 @@ def __ne__(self, other):

__hash__ = None # type: ignore

def ordered_nodes(self) -> Iterator[Unique[ops.Operation]]:
def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]:
if not self.nodes():
return
g = self.copy()

def get_root_node(some_node: Unique[ops.Operation]) -> Unique[ops.Operation]:
def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']:
pred = g.pred
while pred[some_node]:
some_node = next(iter(pred[some_node]))
return some_node

def get_first_node() -> Unique[ops.Operation]:
def get_first_node() -> Unique['cirq.Operation']:
return get_root_node(next(iter(g.nodes())))

def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Operation]:
def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique['cirq.Operation']:
if succ:
return get_root_node(next(iter(succ)))

Expand All @@ -178,20 +178,20 @@ def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique[ops.Oper

node = get_next_node(succ)

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
return (node.val for node in self.ordered_nodes())

def all_qubits(self):
return frozenset(q for node in self.nodes for q in node.val.qubits)

def to_circuit(self) -> circuit.Circuit:
def to_circuit(self) -> 'cirq.Circuit':
return circuit.Circuit(
self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST, device=self.device
)

def findall_nodes_until_blocked(
self, is_blocker: Callable[[ops.Operation], bool]
) -> Iterator[Unique[ops.Operation]]:
self, is_blocker: Callable[['cirq.Operation'], bool]
) -> Iterator[Unique['cirq.Operation']]:
"""Finds all nodes before blocking ones.
Args:
Expand Down
22 changes: 12 additions & 10 deletions cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def __post_init__(self):
# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

def base_operation(self) -> 'CircuitOperation':
def base_operation(self) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with only the wrapped circuit.
Key and qubit mappings, parameter values, and repetitions are not copied.
"""
return CircuitOperation(self.circuit)

def replace(self, **changes) -> 'CircuitOperation':
def replace(self, **changes) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with the specified changes."""
return dataclasses.replace(self, **changes)

Expand Down Expand Up @@ -435,7 +435,7 @@ def repeat(

return self.replace(repetitions=final_repetitions, repetition_ids=repetition_ids)

def __pow__(self, power: int) -> 'CircuitOperation':
def __pow__(self, power: int) -> 'cirq.CircuitOperation':
return self.repeat(power)

def _with_key_path_(self, path: Tuple[str, ...]):
Expand All @@ -462,13 +462,13 @@ def _with_rescoped_keys_(
def with_key_path(self, path: Tuple[str, ...]):
return self._with_key_path_(path)

def with_repetition_ids(self, repetition_ids: List[str]) -> 'CircuitOperation':
def with_repetition_ids(self, repetition_ids: List[str]) -> 'cirq.CircuitOperation':
return self.replace(repetition_ids=repetition_ids)

def with_qubit_mapping(
self,
qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
) -> 'CircuitOperation':
) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated qubit mapping.
Users should pass either 'qubit_map' or 'transform' to this method.
Expand Down Expand Up @@ -509,7 +509,7 @@ def with_qubit_mapping(
)
return new_op

def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated qubit mapping.
Args:
Expand All @@ -529,7 +529,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'CircuitOperation':
raise ValueError(f'Expected {expected} qubits, got {len(new_qubits)}.')
return self.with_qubit_mapping(dict(zip(self.qubits, new_qubits)))

def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOperation':
def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated key mapping.
Args:
Expand Down Expand Up @@ -563,10 +563,12 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
)
return new_op

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'CircuitOperation':
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.CircuitOperation':
return self.with_measurement_key_mapping(key_map)

def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'CircuitOperation':
def with_params(
self, param_values: 'cirq.ParamResolverOrSimilarType'
) -> 'cirq.CircuitOperation':
"""Returns a copy of this operation with an updated ParamResolver.
Note that any resulting parameter mappings with no corresponding
Expand All @@ -592,7 +594,7 @@ def with_params(self, param_values: study.ParamResolverOrSimilarType) -> 'Circui
# TODO: handle recursive parameter resolution gracefully
def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'CircuitOperation':
) -> 'cirq.CircuitOperation':
if recursive:
raise ValueError(
'Recursive resolution of CircuitOperation parameters is prohibited. '
Expand Down
20 changes: 10 additions & 10 deletions cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def moments(self) -> Sequence['cirq.Moment']:
return self._moments

@property
def device(self) -> devices.Device:
def device(self) -> 'cirq.Device':
return self._device

def __hash__(self):
Expand Down Expand Up @@ -116,7 +116,7 @@ def all_qubits(self) -> FrozenSet['cirq.Qid']:
self._all_qubits = super().all_qubits()
return self._all_qubits

def all_operations(self) -> Iterator[ops.Operation]:
def all_operations(self) -> Iterator['cirq.Operation']:
if self._all_operations is None:
self._all_operations = tuple(super().all_operations())
return iter(self._all_operations)
Expand Down Expand Up @@ -152,29 +152,29 @@ def all_measurement_key_names(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> AbstractSet[str]:
return self.all_measurement_key_names()

def __add__(self, other) -> 'FrozenCircuit':
def __add__(self, other) -> 'cirq.FrozenCircuit':
return (self.unfreeze() + other).freeze()

def __radd__(self, other) -> 'FrozenCircuit':
def __radd__(self, other) -> 'cirq.FrozenCircuit':
return (other + self.unfreeze()).freeze()

# Needed for numpy to handle multiplication by np.int64 correctly.
__array_priority__ = 10000

# TODO: handle multiplication / powers differently?
def __mul__(self, other) -> 'FrozenCircuit':
def __mul__(self, other) -> 'cirq.FrozenCircuit':
return (self.unfreeze() * other).freeze()

def __rmul__(self, other) -> 'FrozenCircuit':
def __rmul__(self, other) -> 'cirq.FrozenCircuit':
return (other * self.unfreeze()).freeze()

def __pow__(self, other) -> 'FrozenCircuit':
def __pow__(self, other) -> 'cirq.FrozenCircuit':
try:
return (self.unfreeze() ** other).freeze()
except:
return NotImplemented

def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit':
def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'cirq.FrozenCircuit':
new_circuit = FrozenCircuit(device=self.device)
new_circuit._moments = tuple(moments)
return new_circuit
Expand All @@ -183,12 +183,12 @@ def with_device(
self,
new_device: 'cirq.Device',
qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e,
) -> 'FrozenCircuit':
) -> 'cirq.FrozenCircuit':
return self.unfreeze().with_device(new_device, qubit_mapping).freeze()

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'FrozenCircuit':
) -> 'cirq.FrozenCircuit':
return self.unfreeze()._resolve_parameters_(resolver, recursive).freeze()

def tetris_concat(
Expand Down
Loading

0 comments on commit c8ee329

Please sign in to comment.