Skip to content

Commit

Permalink
Fix more mypy --next type errors (quantumlib#5392)
Browse files Browse the repository at this point in the history
  • Loading branch information
dabacon authored May 24, 2022
1 parent 5f1f238 commit 4a369e2
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 27 deletions.
6 changes: 4 additions & 2 deletions cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ def value_of(
if isinstance(value, sympy.Pow) and len(value.args) == 2:
base = self.value_of(value.args[0], recursive)
exponent = self.value_of(value.args[1], recursive)
# Casts because numpy can handle expressions (by delegating to __pow__), but does
# not have signature that will support this.
if isinstance(base, numbers.Number):
return np.float_power(base, exponent)
return np.power(base, exponent)
return np.float_power(cast(complex, base), cast(complex, exponent))
return np.power(cast(complex, base), cast(complex, exponent))

if not isinstance(value, sympy.Basic):
# No known way to resolve this variable, return unchanged.
Expand Down
9 changes: 6 additions & 3 deletions cirq/transformers/eject_phased_paulis.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _try_get_known_phased_pauli(
elif (
isinstance(gate, ops.PhasedXZGate)
and not protocols.is_parameterized(gate.z_exponent)
and np.isclose(gate.z_exponent, 0)
and np.isclose(float(gate.z_exponent), 0)
):
e = gate.x_exponent
p = gate.axis_phase_exponent
Expand All @@ -336,9 +336,12 @@ def _try_get_known_z_half_turns(
g = op.gate
if (
isinstance(g, ops.PhasedXZGate)
and np.isclose(g.x_exponent, 0)
and np.isclose(g.axis_phase_exponent, 0)
and not protocols.is_parameterized(g.x_exponent)
and not protocols.is_parameterized(g.axis_phase_exponent)
and np.isclose(float(g.x_exponent), 0)
and np.isclose(float(g.axis_phase_exponent), 0)
):

h = g.z_exponent
elif isinstance(g, ops.ZPowGate):
h = g.exponent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def random_qubit_unitary(
rng: Random number generator to be used in sampling. Default is
numpy.random.
"""
rng = np.random if rng is None else rng
real_rng: np.random.RandomState = np.random if rng is None else rng

theta = np.arcsin(np.sqrt(rng.rand(*shape)))
phi_d = rng.rand(*shape) * np.pi * 2
phi_o = rng.rand(*shape) * np.pi * 2
theta = np.arcsin(np.sqrt(real_rng.rand(*shape)))
phi_d = real_rng.rand(*shape) * np.pi * 2
phi_o = real_rng.rand(*shape) * np.pi * 2

out = _single_qubit_unitary(theta, phi_d, phi_o)

if randomize_global_phase:
out = np.moveaxis(out, (-2, -1), (0, 1))
out *= np.exp(1j * np.pi * 2 * rng.rand(*shape))
out *= np.exp(1j * np.pi * 2 * real_rng.rand(*shape))
out = np.moveaxis(out, (0, 1), (-2, -1))
return out

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Attempt to tabulate single qubit gates required to generate a target 2Q gate
with a product A k A."""
from functools import reduce
from typing import Tuple, Sequence, List, NamedTuple
from typing import List, NamedTuple, Sequence, Tuple

from dataclasses import dataclass
import numpy as np
Expand Down Expand Up @@ -100,7 +100,7 @@ def compile_two_qubit_gate(self, unitary: np.ndarray) -> TwoQubitGateTabulationR
unitary = np.asarray(unitary)
kak_vec = cirq.kak_vector(unitary, check_preconditions=False)
infidelities = kak_vector_infidelity(kak_vec, self.kak_vecs, ignore_equivalent_vectors=True)
nearest_ind = infidelities.argmin()
nearest_ind = int(infidelities.argmin())

success = infidelities[nearest_ind] < self.max_expected_infidelity

Expand Down Expand Up @@ -483,13 +483,13 @@ def two_qubit_gate_product_tabulation(
else:
missed_points.append(missing_vec)

kak_vecs = np.array(kak_vecs)
kak_vecs_arr = np.array(kak_vecs)
summary += (
f'\nFraction of Weyl chamber reached with 2 gates and 3 gates '
f'(after patchup)'
f': {(len(kak_vecs) - 1) / num_mesh_points :.3f}'
f': {(len(kak_vecs_arr) - 1) / num_mesh_points :.3f}'
)

return TwoQubitGateTabulation(
base_gate, kak_vecs, sq_cycles, max_infidelity, summary, tuple(missed_points)
base_gate, kak_vecs_arr, sq_cycles, max_infidelity, summary, tuple(missed_points)
)
7 changes: 2 additions & 5 deletions cirq/value/duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,8 @@ def __init__(
else:
raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')

self._picos: Union[float, int, sympy.Expr] = (
picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
)
if isinstance(self._picos, np.number):
self._picos = float(self._picos)
val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._picos)
Expand Down
17 changes: 11 additions & 6 deletions cirq/vis/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ def integrated_histogram(
if isinstance(data, Mapping):
data = list(data.values())

data = [d for d in data if not np.isnan(d)]
n = len(data)
float_data = [float(d) for d in data if not np.isnan(float(d))]

n = len(float_data)

if not show_zero:
bin_values = np.linspace(0, 1, n + 1)
parameter_values = sorted(np.concatenate(([0], data)))
parameter_values = sorted(np.concatenate(([0], float_data)))
else:
bin_values = np.linspace(0, 1, n)
parameter_values = sorted(data)
parameter_values = sorted(float_data)
plot_options = {"where": 'post', "color": 'b', "linestyle": '-', "lw": 1.0, "ms": 0.0}
plot_options.update(kwargs)

Expand Down Expand Up @@ -127,15 +128,19 @@ def integrated_histogram(

if median_line:
set_line(
np.median(data),
np.median(float_data),
linestyle='--',
color=plot_options['color'],
alpha=0.5,
label=median_label,
)
if mean_line:
set_line(
np.mean(data), linestyle='-.', color=plot_options['color'], alpha=0.5, label=mean_label
np.mean(float_data),
linestyle='-.',
color=plot_options['color'],
alpha=0.5,
label=mean_label,
)
if show_plot:
fig.show()
Expand Down
2 changes: 1 addition & 1 deletion cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def plot_state_histogram(
elif isinstance(data, collections.Counter):
tick_label, values = zip(*sorted(data.items()))
else:
values = data
values = np.array(data)
if not tick_label:
tick_label = np.arange(len(values))
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
Expand Down

0 comments on commit 4a369e2

Please sign in to comment.