Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update files to conform to new mypy standard #6662

Merged
merged 7 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ def test_repeat(add_measurements: bool, use_default_ids_for_initial_rep: bool) -
_ = op_base.repeat()

with pytest.raises(TypeError, match='Only integer or sympy repetitions are allowed'):
_ = op_base.repeat(1.3) # type: ignore[arg-type]
assert op_base.repeat(3.00000000001).repetitions == 3 # type: ignore[arg-type]
assert op_base.repeat(2.99999999999).repetitions == 3 # type: ignore[arg-type]
_ = op_base.repeat(1.3)
assert op_base.repeat(3.00000000001).repetitions == 3
assert op_base.repeat(2.99999999999).repetitions == 3


@pytest.mark.parametrize('add_measurements', [True, False])
Expand Down
7 changes: 4 additions & 3 deletions cirq-core/cirq/experiments/qubit_characterizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import functools

from typing import (
Any,
cast,
Any,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -107,7 +107,6 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes:
show_plot = not ax
if not ax:
fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover
ax = cast(plt.Axes, ax) # pragma: no cover
ax.set_ylim((0.0, 1.0)) # pragma: no cover
ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro', label='data', **plot_kwargs)
x = np.linspace(self._num_cfds_seq[0], self._num_cfds_seq[-1], 100)
Expand Down Expand Up @@ -304,7 +303,9 @@ def plot(self, axes: Optional[List[plt.Axes]] = None, **plot_kwargs: Any) -> Lis
"""
show_plot = axes is None
if axes is None:
fig, axes = plt.subplots(1, 2, figsize=(12.0, 5.0), subplot_kw={'projection': '3d'})
fig, axes_v = plt.subplots(1, 2, figsize=(12.0, 5.0), subplot_kw={'projection': '3d'})
axes_v = cast(np.ndarray, axes_v)
axes = list(axes_v)
elif len(axes) != 2:
raise ValueError('A TomographyResult needs 2 axes to plot.')
mat = self._density_matrix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Single qubit readout experiments using parallel or isolated statistics."""
import dataclasses
import time
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING
from typing import cast, Any, Dict, Iterable, List, Optional, TYPE_CHECKING

import sympy
import numpy as np
Expand Down Expand Up @@ -77,8 +77,9 @@ def plot_heatmap(
"""

if axs is None:
_, axs = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4))

_, axs_v = plt.subplots(1, 2, dpi=200, facecolor='white', figsize=(12, 4))
axs_v = cast(np.ndarray, axs_v)
axs = cast(tuple[plt.Axes, plt.Axes], (axs_v[0], axs_v[1]))
else:
if (
not isinstance(axs, (tuple, list, np.ndarray))
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/interop/quirk/cells/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def apply(op: Union[str, _HangingNode]) -> None:
a = vals.pop()
# Note: vals seems to be _HangingToken
# func operates on _ResolvedTokens. Ignoring type issues for now.
vals.append(op.func(a, b)) # type: ignore[arg-type]
vals.append(op.func(a, b))

def close_paren() -> None:
while True:
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ def test_cphase_unitary(angle_rads, expected_unitary):
np.testing.assert_allclose(cirq.unitary(cirq.cphase(angle_rads)), expected_unitary)


# TODO(#6663): fix this use case.
@pytest.mark.xfail
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why this stopped working? If just the latter assertion is broken can we break it out into it's own test and mark it as failed / track a fix for it, WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like either comparison between sympy numbers and floats changed or something changed that changes the functions called during comparison ... either way I opened #6663 to fix it.

def test_parameterized_cphase():
assert cirq.cphase(sympy.pi) == cirq.CZ
assert cirq.cphase(sympy.pi / 2) == cirq.CZ**0.5
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class GlobalPhaseGate(raw_types.Gate):
def __init__(self, coefficient: 'cirq.TParamValComplex', atol: float = 1e-8) -> None:
if not isinstance(coefficient, sympy.Basic):
if abs(1 - abs(coefficient)) > atol: # type: ignore[operator]
if abs(1 - abs(coefficient)) > atol:
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
self._coefficient = coefficient

Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,10 @@ def test_run_param_resolver(dtype: Type[np.complexfloating], split: bool):
cirq.measure(q1),
)
param_resolver = {'b0': b0, 'b1': b1}
result = simulator.run(circuit, param_resolver=param_resolver) # type: ignore
result = simulator.run(circuit, param_resolver=param_resolver)
np.testing.assert_equal(result.measurements, {'q(0)': [[b0]], 'q(1)': [[b1]]})
# pylint: disable=line-too-long
np.testing.assert_equal(result.params, cirq.ParamResolver(param_resolver)) # type: ignore
np.testing.assert_equal(result.params, cirq.ParamResolver(param_resolver))


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,11 +498,11 @@ def test_simulate_param_resolver(dtype: Type[np.complexfloating], split: bool):
(cirq.X ** sympy.Symbol('b0'))(q0), (cirq.X ** sympy.Symbol('b1'))(q1)
)
resolver = {'b0': b0, 'b1': b1}
result = simulator.simulate(circuit, param_resolver=resolver) # type: ignore
result = simulator.simulate(circuit, param_resolver=resolver)
expected_state = np.zeros(shape=(2, 2))
expected_state[b0][b1] = 1.0
np.testing.assert_equal(result.final_state_vector, np.reshape(expected_state, 4))
assert result.params == cirq.ParamResolver(resolver) # type: ignore
assert result.params == cirq.ParamResolver(resolver)
assert len(result.measurements) == 0


Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/study/flatten_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ def __init__(
params = param_dict if param_dict else {}
# TODO: Support complex values for typing below.
symbol_params: resolver.ParamDictType = {
_ensure_not_str(param): _ensure_not_str(val) # type: ignore[misc]
for param, val in params.items()
_ensure_not_str(param): _ensure_not_str(val) for param, val in params.items()
}
super().__init__(symbol_params)
if get_param_name is None:
Expand Down
12 changes: 5 additions & 7 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ def value_of(
if isinstance(param_value, str):
param_value = sympy.Symbol(param_value)
elif not isinstance(param_value, sympy.Basic):
return value # type: ignore[return-value]
return value
if recursive:
param_value = self._value_of_recursive(value)
return param_value # type: ignore[return-value]
return param_value

if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
Expand Down Expand Up @@ -207,7 +207,7 @@ def _value_of_recursive(self, value: 'cirq.TParamKey') -> 'cirq.TParamValComplex

# There isn't a full evaluation for 'value' yet. Until it's ready,
# map value to None to identify loops in component evaluation.
self._deep_eval_map[value] = _RECURSION_FLAG # type: ignore
self._deep_eval_map[value] = _RECURSION_FLAG

v = self.value_of(value, recursive=False)
if v == value:
Expand All @@ -220,10 +220,8 @@ def _resolve_parameters_(self, resolver: 'ParamResolver', recursive: bool) -> 'P
new_dict: Dict['cirq.TParamKey', Union[float, str, sympy.Symbol, sympy.Expr]] = {
k: k for k in resolver
}
new_dict.update({k: self.value_of(k, recursive) for k in self}) # type: ignore[misc]
new_dict.update(
{k: resolver.value_of(v, recursive) for k, v in new_dict.items()} # type: ignore[misc]
)
new_dict.update({k: self.value_of(k, recursive) for k in self})
new_dict.update({k: resolver.value_of(v, recursive) for k, v in new_dict.items()})
if recursive and self._param_dict:
new_resolver = ParamResolver(cast(ParamDictType, new_dict))
# Resolve down to single-step mappings.
Expand Down
6 changes: 1 addition & 5 deletions cirq-core/cirq/study/sweepable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ def to_sweeps(sweepable: Sweepable, metadata: Optional[dict] = None) -> List[Swe
product_sweep = dict_to_product_sweep(sweepable)
return [_resolver_to_sweep(resolver, metadata) for resolver in product_sweep]
if isinstance(sweepable, Iterable) and not isinstance(sweepable, str):
return [
sweep
for item in sweepable
for sweep in to_sweeps(item, metadata) # type: ignore[arg-type]
]
return [sweep for item in sweepable for sweep in to_sweeps(item, metadata)]
raise TypeError(f'Unrecognized sweepable type: {type(sweepable)}.\nsweepable: {sweepable}')


Expand Down
5 changes: 1 addition & 4 deletions cirq-core/cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,7 @@ def dict_to_product_sweep(factor_dict: ProductOrZipSweepLike) -> Product:
Cartesian product of the sweeps.
"""
return Product(
*(
Points(k, v if isinstance(v, Sequence) else [v]) # type: ignore
for k, v in factor_dict.items()
)
*(Points(k, v if isinstance(v, Sequence) else [v]) for k, v in factor_dict.items())
)


Expand Down
2 changes: 0 additions & 2 deletions cirq-core/cirq/vis/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def plot(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(figsize=(8, 8))
ax = cast(plt.Axes, ax)
original_config = copy.deepcopy(self._config)
self.update_config(**kwargs)

Expand Down Expand Up @@ -416,7 +415,6 @@ def plot(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(figsize=(8, 8))
ax = cast(plt.Axes, ax)
original_config = copy.deepcopy(self._config)
self.update_config(**kwargs)
qubits = set([q for qubits in self._value_map.keys() for q in qubits])
Expand Down
3 changes: 1 addition & 2 deletions cirq-core/cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Tool to visualize the results of a study."""

from typing import cast, Optional, Sequence, SupportsFloat, Union
from typing import Optional, Sequence, SupportsFloat, Union
import collections
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -87,7 +87,6 @@ def plot_state_histogram(
show_fig = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
ax = cast(plt.Axes, ax)
if isinstance(data, result.Result):
values = get_state_histogram(data)
elif isinstance(data, collections.Counter):
Expand Down
1 change: 0 additions & 1 deletion cirq-google/cirq_google/engine/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def plot_histograms(
show_plot = not ax
if not ax:
fig, ax = plt.subplots(1, 1)
ax = cast(plt.Axes, ax)

if isinstance(keys, str):
keys = [keys]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def circuit_proto(json: Dict, qubits: List[str]):
op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}),
),
(
cirq.XPowGate(exponent=np.double(0.125))(Q1), # type: ignore
cirq.XPowGate(exponent=np.double(0.125))(Q1),
op_proto({'xpowgate': {'exponent': {'float_value': 0.125}}, 'qubit_constant_index': [0]}),
),
(
cirq.XPowGate(exponent=np.short(1))(Q1), # type: ignore
cirq.XPowGate(exponent=np.short(1))(Q1),
op_proto({'xpowgate': {'exponent': {'float_value': 1.0}}, 'qubit_constant_index': [0]}),
),
(
Expand Down
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/serialization/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,5 @@ def from_proto(
)

return cirq.CircuitOperation(
circuit, repetitions, qubit_map, measurement_key_map, arg_map, rep_ids # type: ignore
circuit, repetitions, qubit_map, measurement_key_map, arg_map, rep_ids
)
4 changes: 2 additions & 2 deletions dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
exclude = dev_tools/modules_test_data/.*/setup\.py
exclude = dev_tools/modules_test_data/.*/setup\.py|cirq-rigetti/(?!__init__.)*\.py # Temporarily exclude cirq-rigetti (see #6661)
show_error_codes = true
plugins = duet.typing
warn_unused_ignores = true
Expand All @@ -12,7 +12,7 @@ ignore_missing_imports = true
# 3rd-party libs for which we don't have stubs

# Google
[mypy-google.api_core.*,google.auth.*,google.colab.*,google.cloud.*]
[mypy-google.api_core.*,google.auth.*,google.colab.*,google.cloud.*,google.oauth2.*]
follow_imports = silent
ignore_missing_imports = true

Expand Down