Skip to content

Commit

Permalink
More numpy types (quantumlib#5683)
Browse files Browse the repository at this point in the history
12 errors left

Part of quantumlib#3767
  • Loading branch information
vtomole authored and rht committed May 1, 2023
1 parent 51ebbba commit 05790ed
Show file tree
Hide file tree
Showing 14 changed files with 38 additions and 33 deletions.
11 changes: 8 additions & 3 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Any,
Callable,
Mapping,
MutableSequence,
cast,
Dict,
FrozenSet,
Expand Down Expand Up @@ -1462,7 +1463,7 @@ def concat_ragged(

# Allocate a buffer large enough to append and prepend all the circuits.
pad_len = sum(len(c) for c in circuits) - n_acc
buffer = np.zeros(shape=pad_len * 2 + n_acc, dtype=object)
buffer: MutableSequence['cirq.Moment'] = [cirq.Moment()] * (pad_len * 2 + n_acc)

# Put the initial circuit in the center of the buffer.
offset = pad_len
Expand Down Expand Up @@ -1601,7 +1602,11 @@ def _overlap_collision_time(


def _concat_ragged_helper(
c1_offset: int, n1: int, buf: np.ndarray, c2: Sequence['cirq.Moment'], align: 'cirq.Alignment'
c1_offset: int,
n1: int,
buf: MutableSequence['cirq.Moment'],
c2: Sequence['cirq.Moment'],
align: 'cirq.Alignment',
) -> Tuple[int, int]:
n2 = len(c2)
shift = _overlap_collision_time(buf[c1_offset : c1_offset + n1], c2, align)
Expand Down Expand Up @@ -2369,7 +2374,7 @@ def _resolve_parameters_(
return Circuit(resolved_moments)

@property
def moments(self):
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def with_noise(self, noise: 'cirq.NOISE_MODEL_LIKE') -> 'cirq.Circuit':
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def repeat(
# As CircuitOperation is immutable, this can safely return the original.
return self

expected_repetition_id_length = abs(repetitions)
expected_repetition_id_length: int = np.abs(repetitions)

if repetition_ids is None:
if self.use_repetition_ids:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/quimb/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def circuit_to_tensors(

for moment in circuit.moments:
for op in moment.operations:
assert op.gate._has_unitary_()
assert cirq.has_unitary(op.gate)
start_inds = [f'i{qubit_frontier[q]}_q{q}' for q in op.qubits]
for q in op.qubits:
qubit_frontier[q] += 1
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ def _expectation_from_density_matrix_no_validation(

while any(result.shape):
result = np.trace(result, axis1=0, axis2=len(result.shape) // 2)
return result * self.coefficient

return float(result * self.coefficient)

def zip_items(
self, other: 'cirq.PauliString[TKey]'
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/circuit_diagram_info_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def format_radians(self, radians: Union[sympy.Basic, int, float]) -> str:
return '0'
if radians == -np.pi:
return '-' + unit
if self.precision is not None:
if self.precision is not None and not isinstance(radians, sympy.Basic):
quantity = self.format_real(radians / np.pi)
return quantity + unit
return repr(radians)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/protocols/trace_distance_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, TypeVar, Optional, Sequence
from typing import Any, TypeVar, Optional, Sequence, Union

import numpy as np
from typing_extensions import Protocol
Expand Down Expand Up @@ -109,7 +109,7 @@ def _strat_distance_from_unitary(val: Any) -> Optional[float]:
return trace_distance_from_angle_list(np.angle(np.linalg.eigvals(u)))


def trace_distance_from_angle_list(angle_list: Sequence[float]) -> float:
def trace_distance_from_angle_list(angle_list: Union[Sequence[float], np.ndarray]) -> float:
"""Given a list of arguments of the eigenvalues of a unitary matrix,
calculates the trace distance bound of the unitary effect.
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/qis/clifford_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cirq import protocols
from cirq._compat import proper_repr
from cirq.qis import quantum_state_representation
from cirq.value import big_endian_int_to_digits, linear_dict
from cirq.value import big_endian_int_to_digits, linear_dict, random_state

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -509,7 +509,7 @@ def destabilizers(self) -> List['cirq.DensePauliString']:
generators above generate the full Pauli group on n qubits."""
return [self._row_to_dense_pauli(i) for i in range(self.n)]

def _measure(self, q, prng: np.random.RandomState = np.random) -> int:
def _measure(self, q, prng: np.random.RandomState) -> int:
"""Performs a projective measurement on the q'th qubit.
Returns: the result (0 or 1) of the measurement.
Expand Down Expand Up @@ -651,6 +651,6 @@ def apply_global_phase(self, coefficient: linear_dict.Scalar):
pass

def measure(
self, axes: Sequence[int], seed: Optional['cirq.RANDOM_STATE_OR_SEED_LIKE'] = None
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
return [self._measure(axis, seed) for axis in axes]
return [self._measure(axis, random_state.parse_random_state(seed)) for axis in axes]
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import cirq
from cirq import protocols, qis, value
from cirq.value import big_endian_int_to_digits
from cirq.value import big_endian_int_to_digits, random_state


@value.value_equality
Expand Down Expand Up @@ -388,7 +388,7 @@ def apply_global_phase(self, coefficient: value.Scalar):
def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
return [self._measure(axis, seed) for axis in axes]
return [self._measure(axis, random_state.parse_random_state(seed)) for axis in axes]


def _phase(exponent, global_shift):
Expand Down
17 changes: 8 additions & 9 deletions cirq-core/cirq/sim/density_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def sample_density_matrix(
qid_shape = (2,) * num_qubits
else:
_validate_density_matrix_qid_shape(density_matrix, qid_shape)
num_qubits = len(qid_shape)
meas_shape = _indices_shape(qid_shape, indices)

if repetitions == 0 or len(indices) == 0:
Expand Down Expand Up @@ -139,16 +138,16 @@ def measure_density_matrix(
qid_shape = (2,) * num_qubits
else:
_validate_density_matrix_qid_shape(density_matrix, qid_shape)
num_qubits = len(qid_shape)
meas_shape = _indices_shape(qid_shape, indices)

arrout: np.ndarray = (
np.copy(density_matrix)
if out is None
else density_matrix
if out is density_matrix
else (np.copyto(dst=out, src=density_matrix), out)[-1]
)
arrout: np.ndarray
if out is None:
arrout = np.copy(density_matrix)
elif out is density_matrix:
arrout = density_matrix
else:
np.copyto(dst=out, src=density_matrix)
arrout = out

if len(indices) == 0:
return ([], arrout)
Expand Down
8 changes: 3 additions & 5 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state vector."""

from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union

import numpy as np
Expand Down Expand Up @@ -225,13 +224,12 @@ def prepare_into_buffer(k: int):
e.reshape(shape * 2).astype(self._state_vector.dtype) for e in kraus_operators
]
p = prng.random()
weight = None
fallback_weight = 0
fallback_weight = 0.0
fallback_weight_index = 0
index = None

for index in range(len(kraus_tensors)):
prepare_into_buffer(index)
weight = np.linalg.norm(self._buffer) ** 2
weight = float(np.linalg.norm(self._buffer) ** 2)

if weight > fallback_weight:
fallback_weight_index = index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Union, Sequence, Optional

import numpy as np
from cirq.value import random_state

_RealArraylike = Union[np.ndarray, float]

Expand Down Expand Up @@ -58,7 +59,7 @@ def random_qubit_unitary(
rng: Random number generator to be used in sampling. Default is
numpy.random.
"""
real_rng: np.random.RandomState = np.random if rng is None else rng
real_rng = random_state.parse_random_state(rng)

theta = np.arcsin(np.sqrt(real_rng.rand(*shape)))
phi_d = real_rng.rand(*shape) * np.pi * 2
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/value/type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Union

import sympy

from cirq._doc import document
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/vis/state_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def plot_state_histogram(
tick_label, values = zip(*sorted(data.items()))
else:
values = np.array(data)
if not tick_label:
tick_label = np.arange(len(values))
if tick_label is None:
tick_label = [str(i) for i in range(len(values))]
ax.bar(np.arange(len(values)), values, tick_label=tick_label)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
Expand Down
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/engine/virtual_engine_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def create_default_noisy_quantum_virtual_machine(
try: # coverage: ignore
import qsimcirq # type: ignore

simulator_class = qsimcirq.Simulator # coverage: ignore
simulator_class = qsimcirq.QSimSimulator # coverage: ignore
except ImportError:
simulator_class = cirq.Simulator # coverage: ignore

Expand Down

0 comments on commit 05790ed

Please sign in to comment.