From 8708b9a01c3d504ce8879b0e0d1cc0d8c1fd1717 Mon Sep 17 00:00:00 2001 From: "Antoine (Tony) Bruguier" Date: Tue, 25 Jan 2022 22:57:51 -0800 Subject: [PATCH] Increase atol for sub_state_vector() (#4877) Attempts to fix #4786 --- cirq-core/cirq/linalg/transformations.py | 2 +- cirq-core/cirq/linalg/transformations_test.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 12322fe60cf..44b7977de01 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -396,7 +396,7 @@ def sub_state_vector( keep_indices: List[int], *, default: np.ndarray = RaiseValueErrorIfNotProvided, - atol: Union[int, float] = 1e-8, + atol: Union[int, float] = 1e-6, ) -> np.ndarray: r"""Attempts to factor a state vector into two parts and return one of them. diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index db13af73c12..d01e34cb898 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -609,3 +609,22 @@ def test_to_special(): su = cirq.to_special(u) assert not cirq.is_special_unitary(u) assert cirq.is_special_unitary(su) + + +def test_default_tolerance(): + a, b = cirq.LineQubit.range(2) + final_state_vector = ( + cirq.Simulator() + .simulate( + cirq.Circuit( + cirq.H(a), + cirq.H(b), + cirq.CZ(a, b), + cirq.measure(a), + ) + ) + .final_state_vector.reshape((2, 2)) + ) + # Here, we do NOT specify the default tolerance. It is merely to check that the default value + # is reasonable. + cirq.sub_state_vector(final_state_vector, [0])