diff --git a/cirq/__init__.py b/cirq/__init__.py index 0e1c8bf5474..2c2f5d0f571 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -246,6 +246,7 @@ chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, Duration, + LinearDict, PeriodicValue, Timestamp, validate_probability, diff --git a/cirq/value/__init__.py b/cirq/value/__init__.py index 0511520f49b..c6c6643a19c 100644 --- a/cirq/value/__init__.py +++ b/cirq/value/__init__.py @@ -20,6 +20,9 @@ from cirq.value.duration import ( Duration, ) +from cirq.value.linear_dict import ( + LinearDict, +) from cirq.value.probability import ( validate_probability, ) diff --git a/cirq/value/linear_dict.py b/cirq/value/linear_dict.py new file mode 100644 index 00000000000..30498da2f9e --- /dev/null +++ b/cirq/value/linear_dict.py @@ -0,0 +1,266 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Linear combination represented as mapping of things to coefficients.""" + +from typing import (Any, Callable, Dict, ItemsView, Iterable, Iterator, + KeysView, Mapping, overload, Tuple, TypeVar, Union, + ValuesView) + +Scalar = Union[complex, float] +TVector = TypeVar('TVector') + +_TDefault = TypeVar('_TDefault') + + +class LinearDict(Dict[TVector, Scalar]): + """Represents linear combination of things. + + LinearDict implements the basic linear algebraic operations of vector + addition and scalar multiplication for linear combinations of abstract + vectors. Keys represent the vectors, values represent their coefficients. + The only requirement on the keys is that they be hashable (i.e. are + immutable and implement __hash__ and __eq__ with equal objects hashing + to equal values). + + A consequence of treating keys as opaque is that all relationships between + the keys other than equality are ignored. In particular, keys are allowed + to be linearly dependent. + """ + def __init__(self, terms: Mapping[TVector, Scalar]) -> None: + """Initializes linear combination from a collection of terms. + + Args: + terms: Mapping of abstract vectors to coefficients in the linear + combination being initialized. + """ + super().__init__() + self.update(terms) + + @classmethod + def fromkeys(cls, vectors, coefficient=0): + return LinearDict(dict.fromkeys(vectors, complex(coefficient))) + + def clean(self, *, atol: float=1e-9) -> 'LinearDict': + """Remove terms with coefficients of absolute value atol or less.""" + negligible = [v for v, c in super().items() if abs(c) <= atol] + for v in negligible: + del self[v] + return self + + def copy(self) -> 'LinearDict': + return LinearDict(super().copy()) + + def keys(self) -> KeysView[TVector]: + snapshot = self.copy().clean(atol=0) + return super(LinearDict, snapshot).keys() + + def values(self) -> ValuesView[Scalar]: + snapshot = self.copy().clean(atol=0) + return super(LinearDict, snapshot).values() + + def items(self) -> ItemsView[TVector, Scalar]: + snapshot = self.copy().clean(atol=0) + return super(LinearDict, snapshot).items() + + # pylint: disable=function-redefined + @overload + def update(self, other: Mapping[TVector, Scalar], **kwargs: Scalar) -> None: + pass + + @overload + def update(self, + other: Iterable[Tuple[TVector, Scalar]], + **kwargs: Scalar) -> None: + pass + + @overload + def update(self, *args: Any, **kwargs: Scalar) -> None: + pass + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + self.clean(atol=0) + + @overload + def get(self, vector: TVector) -> Scalar: + pass + + @overload + def get(self, vector: TVector, default: _TDefault + ) -> Union[Scalar, _TDefault]: + pass + + def get(self, vector, default=0): + if super().get(vector, 0) == 0: + return default + return super().get(vector) + # pylint: enable=function-redefined + + def __contains__(self, vector: Any) -> bool: + return super().__contains__(vector) and super().__getitem__(vector) != 0 + + def __getitem__(self, vector: TVector) -> Scalar: + return super().get(vector, 0) + + def __setitem__(self, vector: TVector, coefficient: Scalar) -> None: + if coefficient != 0: + super().__setitem__(vector, coefficient) + return + if super().__contains__(vector): + super().__delitem__(vector) + + def __iter__(self) -> Iterator[TVector]: + snapshot = self.copy().clean(atol=0) + return super(LinearDict, snapshot).__iter__() + + def __len__(self) -> int: + return len([v for v, c in self.items() if c != 0]) + + def __iadd__(self, other: 'LinearDict') -> 'LinearDict': + for vector, other_coefficient in other.items(): + old_coefficient = super().get(vector, 0) + new_coefficient = old_coefficient + other_coefficient + super().__setitem__(vector, new_coefficient) + self.clean(atol=0) + return self + + def __add__(self, other: 'LinearDict') -> 'LinearDict': + result = self.copy() + result += other + return result + + def __isub__(self, other: 'LinearDict') -> 'LinearDict': + for vector, other_coefficient in other.items(): + old_coefficient = super().get(vector, 0) + new_coefficient = old_coefficient - other_coefficient + super().__setitem__(vector, new_coefficient) + self.clean(atol=0) + return self + + def __sub__(self, other: 'LinearDict') -> 'LinearDict': + result = self.copy() + result -= other + return result + + def __neg__(self) -> 'LinearDict': + return LinearDict({v: -c for v, c in self.items()}) + + def __imul__(self, a: Scalar) -> 'LinearDict': + for vector in self: + self[vector] *= a + self.clean(atol=0) + return self + + def __mul__(self, a: Scalar) -> 'LinearDict': + result = self.copy() + result *= a + return result + + def __rmul__(self, a: Scalar) -> 'LinearDict': + return self.__mul__(a) + + def __truediv__(self, a: Scalar) -> 'LinearDict': + return self.__mul__(1 / a) + + def __bool__(self) -> bool: + return not all(c == 0 for c in self.values()) + + def __eq__(self, other: Any) -> bool: + """Checks whether two linear combinations are exactly equal. + + Presence or absence of terms with coefficients exactly equal to + zero does not affect outcome. + + Not appropriate for most practical purposes due to sensitivity to + numerical error in floating point coefficients. Use cirq.approx_eq() + instead. + """ + if not isinstance(other, LinearDict): + return NotImplemented + + all_vs = set(self.keys()) | set(other.keys()) + return all(self[v] == other[v] for v in all_vs) + + def __ne__(self, other: Any) -> bool: + """Checks whether two linear combinations are not exactly equal. + + See __eq__(). + """ + if not isinstance(other, LinearDict): + return NotImplemented + + return not self == other + + def _approx_eq_(self, other: Any, atol: float) -> bool: + """Checks whether two linear combinations are approximately equal.""" + if not isinstance(other, LinearDict): + return NotImplemented + + all_vs = set(self.keys()) | set(other.keys()) + return all(abs(self[v] - other[v]) < atol for v in all_vs) + + @staticmethod + def _format_coefficient(format_spec: str, coefficient: Scalar) -> str: + coefficient = complex(coefficient) + real_str = '{:{fmt}}'.format(coefficient.real, fmt=format_spec) + imag_str = '{:{fmt}}'.format(coefficient.imag, fmt=format_spec) + if float(real_str) == 0 and float(imag_str) == 0: + return '' + if float(imag_str) == 0: + return real_str + if float(real_str) == 0: + return imag_str + 'j' + if real_str[0] == '-' and imag_str[0] == '-': + return '-({}+{}j)'.format(real_str[1:], imag_str[1:]) + if imag_str[0] in ['+', '-']: + return '({}{}j)'.format(real_str, imag_str) + return '({}+{}j)'.format(real_str, imag_str) + + @staticmethod + def _format_term(format_spec: str, + vector: TVector, + coefficient: Scalar) -> str: + coefficient_str = LinearDict._format_coefficient( + format_spec, coefficient) + if not coefficient_str: + return coefficient_str + result = '{}*{!s}'.format(coefficient_str, vector) + if result[0] in ['+', '-']: + return result + return '+' + result + + def __format__(self, format_spec: str) -> str: + formatted_terms = [self._format_term(format_spec, v, self[v]) + for v in sorted(self.keys())] + s = ''.join(formatted_terms) + if not s: + return '{:{fmt}}'.format(0, fmt=format_spec) + if s[0] == '+': + return s[1:] + return s + + def __repr__(self) -> str: + coefficients = dict(self) + return 'cirq.LinearDict({!r})'.format(coefficients) + + def __str__(self): + return self.__format__('.3f') + + def _repr_pretty_(self, p: Any, cycle: bool) -> None: + if cycle: + p.text('LinearDict(...)') + else: + p.text(str(self)) + diff --git a/cirq/value/linear_dict_test.py b/cirq/value/linear_dict_test.py new file mode 100644 index 00000000000..fb0737093a7 --- /dev/null +++ b/cirq/value/linear_dict_test.py @@ -0,0 +1,436 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import cirq + + +@pytest.mark.parametrize('keys, coefficient, terms_expected', ( + ((), 10, {}), + (('X',), 2, {'X': 2}), + (('a', 'b', 'c', 'd'), 0.5, {'a': 0.5, 'b': 0.5, 'c': 0.5, 'd': 0.5}), + (('b', 'c', 'd', 'e'), -2j, {'b': -2j, 'c': -2j, 'd': -2j, 'e': -2j}), +)) +def test_fromkeys(keys, coefficient, terms_expected): + actual = cirq.LinearDict.fromkeys(keys, coefficient) + expected = cirq.LinearDict(terms_expected) + assert actual == expected + assert expected == actual + + +@pytest.mark.parametrize('terms, atol, terms_expected', ( + ({'X': 1, 'Y': 2, 'Z': 3}, 2, {'Z': 3}), + ({'X': 0.1, 'Y': 1, 'Z': 10}, 1e-3, {'X': 0.1, 'Y': 1, 'Z': 10}), + ({'X': 1e-10, 'H': 1e-11}, 1e-9, {}), + ({}, 1, {}), +)) +def test_clean(terms, atol, terms_expected): + linear_dict = cirq.LinearDict(terms) + linear_dict.clean(atol=atol) + expected = cirq.LinearDict(terms_expected) + assert linear_dict == expected + assert expected == linear_dict + + +@pytest.mark.parametrize('terms', ( + {'X': 1j/2}, {'X': 1, 'Y': 2, 'Z': 3}, {}, +)) +def test_copy(terms): + linear_dict = cirq.LinearDict(terms) + assert type(linear_dict.copy()) == cirq.LinearDict + assert linear_dict.copy() == linear_dict + assert linear_dict == linear_dict.copy() + + +@pytest.mark.parametrize('terms, expected_keys', ( + ({}, ()), + ({'X': 0}, ()), + ({'X': 0.1}, ('X',)), + ({'X': -1, 'Y': 0, 'Z': 1}, ('X', 'Z')), +)) +def test_keys(terms, expected_keys): + linear_dict = cirq.LinearDict(terms) + assert tuple(sorted(linear_dict.keys())) == expected_keys + + +@pytest.mark.parametrize('terms, expected_values', ( + ({}, ()), + ({'X': 0}, ()), + ({'X': 0.1}, (0.1,)), + ({'X': -1, 'Y': 0, 'Z': 1}, (-1, 1)), +)) +def test_values(terms, expected_values): + linear_dict = cirq.LinearDict(terms) + assert tuple(sorted(linear_dict.values())) == expected_values + + +@pytest.mark.parametrize('terms, expected_items', ( + ({}, ()), + ({'X': 0}, ()), + ({'X': 0.1}, (('X', 0.1),)), + ({'X': -1, 'Y': 0, 'Z': 1}, (('X', -1), ('Z', 1))), +)) +def test_items(terms, expected_items): + linear_dict = cirq.LinearDict(terms) + assert tuple(sorted(linear_dict.items())) == expected_items + + +@pytest.mark.parametrize('terms_1, terms_2, terms_expected', ( + ({}, {}, {}), + ({}, {'X': 0.1}, {'X': 0.1}), + ({'X': 1}, {'Y': 2}, {'X': 1, 'Y': 2}), + ({'X': 1}, {'X': 4}, {'X': 4}), + ({'X': 1, 'Y': 2}, {'Y': -2}, {'X': 1, 'Y': -2}), +)) +def test_update(terms_1, terms_2, terms_expected): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + linear_dict_1.update(linear_dict_2) + expected = cirq.LinearDict(terms_expected) + assert linear_dict_1 == expected + assert expected == linear_dict_1 + + +@pytest.mark.parametrize('terms, vector, expected_coefficient', ( + ({}, '', 0), + ({}, 'X', 0), + ({'X': 0}, 'X', 0), + ({'X': -1j}, 'X', -1j), + ({'X': 1j}, 'Y', 0), +)) +def test_get(terms, vector, expected_coefficient): + linear_dict = cirq.LinearDict(terms) + actual_coefficient = linear_dict.get(vector) + assert actual_coefficient == expected_coefficient + + +@pytest.mark.parametrize('terms, vector, expected', ( + ({}, 'X', False), + ({'X': 0}, 'X', False), + ({'X': 0.1}, 'X', True), + ({'X': 1, 'Y': -1}, 'Y', True), +)) +def test_contains(terms, vector, expected): + linear_dict = cirq.LinearDict(terms) + actual = vector in linear_dict + assert actual == expected + + +@pytest.mark.parametrize('terms, vector, expected_coefficient', ( + ({}, 'X', 0), + ({'X': 1}, 'X', 1), + ({'Y': 1}, 'X', 0), + ({'X': 2, 'Y': 3}, 'X', 2), + ({'X': 1, 'Y': 2}, 'Z', 0), +)) +def test_getitem(terms, vector, expected_coefficient): + linear_dict = cirq.LinearDict(terms) + actual_coefficient = linear_dict[vector] + assert actual_coefficient == expected_coefficient + + +@pytest.mark.parametrize('terms, vector, coefficient, terms_expected', ( + ({}, 'X', 0, {}), + ({}, 'X', 1, {'X': 1}), + ({'X': 1}, 'X', 2, {'X': 2}), + ({'X': 1, 'Y': 3}, 'X', 2, {'X': 2, 'Y': 3}), + ({'X': 1, 'Y': 2}, 'X', 0, {'Y': 2}), +)) +def test_setitem(terms, vector, coefficient, terms_expected): + linear_dict = cirq.LinearDict(terms) + linear_dict[vector] = coefficient + expected = cirq.LinearDict(terms_expected) + assert linear_dict == expected + assert expected == linear_dict + + +def test_addition_in_iteration(): + linear_dict = cirq.LinearDict({'a': 2, 'b': 1, 'c': 0, 'd': -1, 'e': -2}) + for v in linear_dict: + linear_dict[v] += 1 + assert linear_dict == cirq.LinearDict( + {'a': 3, 'b': 2, 'c': 0, 'd': 0, 'e': -1}) + assert linear_dict == cirq.LinearDict({'a': 3, 'b': 2, 'e': -1}) + + +def test_multiplication_in_iteration(): + linear_dict = cirq.LinearDict({'u': 2, 'v': 1, 'w': -1}) + for v, c in linear_dict.items(): + if c > 0: + linear_dict[v] *= 0 + assert linear_dict == cirq.LinearDict({'u': 0, 'v': 0, 'w': -1}) + assert linear_dict == cirq.LinearDict({'w': -1}) + + +@pytest.mark.parametrize('terms, expected_length', ( + ({}, 0), + ({'X': 0}, 0), + ({'X': 0.1}, 1), + ({'X': 1, 'Y': -2j}, 2), + ({'X': 0, 'Y': 1}, 1) +)) +def test_len(terms, expected_length): + linear_dict = cirq.LinearDict(terms) + assert len(linear_dict) == expected_length + + +@pytest.mark.parametrize('terms_1, terms_2, terms_expected', ( + ({}, {}, {}), + ({}, {'X': 0.1}, {'X': 0.1}), + ({'X': 1}, {'Y': 2}, {'X': 1, 'Y': 2}), + ({'X': 1}, {'X': 1}, {'X': 2}), + ({'X': 1, 'Y': 2}, {'Y': -2}, {'X': 1}), +)) +def test_vector_addition(terms_1, terms_2, terms_expected): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + actual_1 = linear_dict_1 + linear_dict_2 + actual_2 = linear_dict_1 + actual_2 += linear_dict_2 + expected = cirq.LinearDict(terms_expected) + assert actual_1 == expected + assert actual_2 == expected + assert actual_1 == actual_2 + + +@pytest.mark.parametrize('terms_1, terms_2, terms_expected', ( + ({}, {}, {}), + ({'a': 2}, {'a': 2}, {}), + ({'a': 3}, {'a': 2}, {'a': 1}), + ({'X': 1}, {'Y': 2}, {'X': 1, 'Y': -2}), + ({'X': 1}, {'X': 1}, {}), + ({'X': 1, 'Y': 2}, {'Y': 2}, {'X': 1}), + ({'X': 1, 'Y': 2}, {'Y': 3}, {'X': 1, 'Y': -1}), +)) +def test_vector_subtraction(terms_1, terms_2, terms_expected): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + actual_1 = linear_dict_1 - linear_dict_2 + actual_2 = linear_dict_1 + actual_2 -= linear_dict_2 + expected = cirq.LinearDict(terms_expected) + assert actual_1 == expected + assert actual_2 == expected + assert actual_1 == actual_2 + + +@pytest.mark.parametrize('terms, terms_expected', ( + ({}, {}), + ({'key': 1}, {'key': -1}), + ({'1': 10, '2': -20}, {'1': -10, '2': 20}), +)) +def test_vector_negation(terms, terms_expected): + linear_dict = cirq.LinearDict(terms) + actual = -linear_dict + expected = cirq.LinearDict(terms_expected) + assert actual == expected + assert expected == actual + + +@pytest.mark.parametrize('scalar, terms, terms_expected', ( + (2, {}, {}), + (2, {'X': 1, 'Y': -2}, {'X': 2, 'Y': -4}), + (0, {'abc': 10, 'def': 20}, {}), + (1j, {'X': 4j}, {'X': -4}), + (-1, {'a': 10, 'b': -20}, {'a': -10, 'b': 20}), +)) +def test_scalar_multiplication(scalar, terms, terms_expected): + linear_dict = cirq.LinearDict(terms) + actual_1 = scalar * linear_dict + actual_2 = linear_dict * scalar + expected = cirq.LinearDict(terms_expected) + assert actual_1 == expected + assert actual_2 == expected + assert actual_1 == actual_2 + + +@pytest.mark.parametrize('scalar, terms, terms_expected', ( + (2, {}, {}), + (2, {'X': 6, 'Y': -2}, {'X': 3, 'Y': -1}), + (1j, {'X': 1, 'Y': 1j}, {'X': -1j, 'Y': 1}), + (-1, {'a': 10, 'b': -20}, {'a': -10, 'b': 20}), +)) +def test_scalar_division(scalar, terms, terms_expected): + linear_dict = cirq.LinearDict(terms) + actual = linear_dict / scalar + expected = cirq.LinearDict(terms_expected) + assert actual == expected + assert expected == actual + + +@pytest.mark.parametrize('expression, expected', ( + ((cirq.LinearDict({'X': 10}) + cirq.LinearDict({'X': 10, 'Y': -40})) / 20, + cirq.LinearDict({'X': 1, 'Y': -2})), + (cirq.LinearDict({'a': -2}) + 2 * cirq.LinearDict({'a': 1}), + cirq.LinearDict({})), + (cirq.LinearDict({'b': 2}) - 2 * cirq.LinearDict({'b': 1}), + cirq.LinearDict({})), +)) +def test_expressions(expression, expected): + assert expression == expected + assert not expression != expected + assert cirq.approx_eq(expression, expected) + + +@pytest.mark.parametrize('terms, bool_value', ( + ({}, False), + ({'X': 0}, False), + ({'Z': 1e-12}, True), + ({'Y': 1}, True), +)) +def test_bool(terms, bool_value): + linear_dict = cirq.LinearDict(terms) + assert bool(linear_dict) == bool_value + + +@pytest.mark.parametrize('terms_1, terms_2', ( + ({}, {}), + ({}, {'X': 0}), + ({'X': 0.0}, {'Y': 0.0}), + ({'a': 1}, {'a': 1, 'b': 0}), +)) +def test_equal(terms_1, terms_2): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + assert linear_dict_1 == linear_dict_2 + assert linear_dict_2 == linear_dict_1 + assert not linear_dict_1 != linear_dict_2 + assert not linear_dict_2 != linear_dict_1 + + +@pytest.mark.parametrize('terms_1, terms_2', ( + ({}, {'a': 1}), + ({'X': 1e-12}, {'X': 0}), + ({'X': 0.0}, {'Y': 0.1}), + ({'X': 1}, {'X': 1, 'Z': 1e-12}), +)) +def test_unequal(terms_1, terms_2): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + assert linear_dict_1 != linear_dict_2 + assert linear_dict_2 != linear_dict_1 + assert not linear_dict_1 == linear_dict_2 + assert not linear_dict_2 == linear_dict_1 + + +@pytest.mark.parametrize('terms_1, terms_2', ( + ({}, {'X': 1e-9}), + ({'X': 1e-12}, {'X': 0}), + ({'X': 5e-10}, {'Y': 2e-11}), + ({'X': 1.000000001}, {'X': 1, 'Z': 0}), +)) +def test_approximately_equal(terms_1, terms_2): + linear_dict_1 = cirq.LinearDict(terms_1) + linear_dict_2 = cirq.LinearDict(terms_2) + assert cirq.approx_eq(linear_dict_1, linear_dict_2) + assert cirq.approx_eq(linear_dict_2, linear_dict_1) + + +@pytest.mark.parametrize('a, b', ( + (cirq.LinearDict({}), None), + (cirq.LinearDict({'X': 0}), 0), + (cirq.LinearDict({'Y': 1}), 1), + (cirq.LinearDict({'Z': 1}), 1j), + (cirq.LinearDict({'I': 1}), 'I'), +)) +def test_incomparable(a, b): + assert a.__eq__(b) is NotImplemented + assert a.__ne__(b) is NotImplemented + assert a._approx_eq_(b, atol=1e-9) is NotImplemented + + +@pytest.mark.parametrize('terms, fmt, expected_string', ( + ({}, '{}', '0'), + ({}, '{:.2f}', '0.00'), + ({}, '{:.2e}', '0.00e+00'), + ({'X': 2**-10}, '{:.2f}', '0.00'), + ({'X': 1/100}, '{:.2e}', '1.00e-02*X'), + ({'X': 1j*2**-10}, '{:.2f}', '0.00'), + ({'X': 1j*2**-10}, '{:.3f}', '0.001j*X'), + ({'X': 2j, 'Y': -3}, '{:.2f}', '2.00j*X-3.00*Y'), + ({'X': -2j, 'Y': 3}, '{:.2f}', '-2.00j*X+3.00*Y'), + ({'X': np.sqrt(1j)}, '{:.3f}', '(0.707+0.707j)*X'), + ({'X': np.sqrt(-1j)}, '{:.3f}', '(0.707-0.707j)*X'), + ({'X': -np.sqrt(-1j)}, '{:.3f}', '(-0.707+0.707j)*X'), + ({'X': -np.sqrt(1j)}, '{:.3f}', '-(0.707+0.707j)*X'), + ({'X': 1, 'Y': -1, 'Z': 1j}, '{:.5f}', '1.00000*X-1.00000*Y+1.00000j*Z'), + ({'X': 2, 'Y': -0.0001}, '{:.4f}', '2.0000*X-0.0001*Y'), + ({'X': 2, 'Y': -0.0001}, '{:.3f}', '2.000*X'), + ({'X': 2, 'Y': -0.0001}, '{:.1e}', '2.0e+00*X-1.0e-04*Y'), +)) +def test_format(terms, fmt, expected_string): + linear_dict = cirq.LinearDict(terms) + actual_string = fmt.format(linear_dict) + assert actual_string.replace(' ', '') == expected_string.replace(' ', '') + + +@pytest.mark.parametrize('terms', ( + ({}, {'X': 1}, {'X': 2, 'Y': 3}, {'X': 1.23456789e-12}) +)) +def test_repr(terms): + original = cirq.LinearDict(terms) + print(repr(original)) + recovered = eval(repr(original)) + assert original == recovered + assert recovered == original + + +@pytest.mark.parametrize('terms, string', ( + ({}, '0.000'), + ({'X': 1.5, 'Y': 1e-5}, '1.500*X'), + ({'Y': 2}, '2.000*Y'), + ({'X': 1, 'Y': -1j}, '1.000*X-1.000j*Y'), + ({'X': np.sqrt(3)/3, 'Y': np.sqrt(3)/3, 'Z': np.sqrt(3)/3}, + '0.577*X+0.577*Y+0.577*Z'), + ({'I': np.sqrt(1j)}, '(0.707+0.707j)*I'), + ({'X': np.sqrt(-1j)}, '(0.707-0.707j)*X'), + ({'X': -np.sqrt(-1j)}, '(-0.707+0.707j)*X'), + ({'X': -np.sqrt(1j)}, '-(0.707+0.707j)*X'), + ({'X': -2, 'Y': -3}, '-2.000*X-3.000*Y'), + ({'X': -2j, 'Y': -3}, '-2.000j*X-3.000*Y'), + ({'X': -2j, 'Y': -3j}, '-2.000j*X-3.000j*Y'), +)) +def test_str(terms, string): + linear_dict = cirq.LinearDict(terms) + assert str(linear_dict).replace(' ', '') == string.replace(' ', '') + + +class FakePrinter: + def __init__(self): + self.buffer = '' + + def text(self, s: str) -> None: + self.buffer += s + + def reset(self) -> None: + self.buffer = '' + + +@pytest.mark.parametrize('terms', ( + {}, {'Y': 2}, {'X': 1, 'Y': -1j}, + {'X': np.sqrt(3)/3, 'Y': np.sqrt(3)/3, 'Z': np.sqrt(3)/3}, + {'I': np.sqrt(1j)}, {'X': np.sqrt(-1j)}, +)) +def test_repr_pretty(terms): + printer = FakePrinter() + linear_dict = cirq.LinearDict(terms) + + linear_dict._repr_pretty_(printer, False) + assert printer.buffer.replace(' ', '') == str(linear_dict).replace(' ', '') + + printer.reset() + linear_dict._repr_pretty_(printer, True) + assert printer.buffer == 'LinearDict(...)'