From b903e1bebcdafa714414162b0ed85a430a7adfa9 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 15 Mar 2022 13:48:10 -0700 Subject: [PATCH] Allow partial state vector function to handle qudits (#5077) * Allow partial state vector function to handle qudits * nits * check norm * nit * simplify kron * lint * nits Co-authored-by: Cirq Bot --- cirq-core/cirq/linalg/transformations.py | 55 ++++++++++--------- cirq-core/cirq/linalg/transformations_test.py | 53 ++++++++++++++---- cirq-core/cirq/sim/state_vector_test.py | 4 +- 3 files changed, 75 insertions(+), 37 deletions(-) diff --git a/cirq-core/cirq/linalg/transformations.py b/cirq-core/cirq/linalg/transformations.py index 516ba3f2134..614a34940f1 100644 --- a/cirq-core/cirq/linalg/transformations.py +++ b/cirq-core/cirq/linalg/transformations.py @@ -340,10 +340,10 @@ def partial_trace_of_state_vector_as_mixture( ) -> Tuple[Tuple[float, np.ndarray], ...]: """Returns a mixture representing a state vector with only some qubits kept. - The input state vector must have shape `(2,) * n` or `(2 ** n)` where - `state_vector` is expressed over n qubits. States in the output mixture will - retain the same type of shape as the input state vector, either `(2 ** k)` - or `(2,) * k` where k is the number of qubits kept. + The input state vector can have any shape, but if it is one-dimensional it + will be interpreted as qubits, since that is the most common case, and fail + if the dimension is not size `2 ** n`. States in the output mixture will + retain the same type of shape as the input state vector. If the state vector cannot be factored into a pure state over `keep_indices` then eigendecomposition is used and the output mixture will not be unique. @@ -361,31 +361,30 @@ def partial_trace_of_state_vector_as_mixture( partial trace. Raises: - ValueError: if the input `state_vector` is not an array of length - `(2 ** n)` or a tensor with a shape of `(2,) * n` + ValueError: If the input `state_vector` is one dimension, but that + dimension size is not a power of two. + IndexError: If any indexes are out of range. """ + if state_vector.ndim == 1: + dims = int(np.log2(state_vector.size)) + if 2 ** dims != state_vector.size: + raise ValueError(f'Cannot infer underlying shape of {state_vector.shape}.') + state_vector = state_vector.reshape((2,) * dims) + ret_shape: Tuple[int, ...] = (2 ** len(keep_indices),) + else: + ret_shape = tuple(state_vector.shape[i] for i in keep_indices) + # Attempt to do efficient state factoring. try: - state = sub_state_vector( - state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol - ) - return ((1.0, state),) + state, _ = factor_state_vector(state_vector, keep_indices, atol=atol) + return ((1.0, state.reshape(ret_shape)),) except EntangledStateError: pass # Fall back to a (non-unique) mixture representation. - keep_dims = 1 << len(keep_indices) - ret_shape: Union[Tuple[int], Tuple[int, ...]] - if state_vector.shape == (state_vector.size,): - ret_shape = (keep_dims,) - elif all(e == 2 for e in state_vector.shape): - ret_shape = tuple(2 for _ in range(len(keep_indices))) - - rho = np.kron(np.conj(state_vector.reshape(-1, 1)).T, state_vector.reshape(-1, 1)).reshape( - (2, 2) * int(np.log2(state_vector.size)) - ) - keep_rho = partial_trace(rho, keep_indices).reshape((keep_dims,) * 2) + rho = np.outer(state_vector, np.conj(state_vector)).reshape(state_vector.shape * 2) + keep_rho = partial_trace(rho, keep_indices).reshape((np.prod(ret_shape),) * 2) eigvals, eigvecs = np.linalg.eigh(keep_rho) mixture = tuple(zip(eigvals, [vec.reshape(ret_shape) for vec in eigvecs.T])) return tuple([(float(p[0]), p[1]) for p in mixture if not protocols.approx_eq(p[0], 0.0)]) @@ -436,8 +435,10 @@ def sub_state_vector( The state vector expressed over the desired subset of qubits. Raises: - ValueError: if the `state_vector` is not of the correct shape or the - indices are not a valid subset of the input `state_vector`'s indices + ValueError: If the `state_vector` is not of the correct shape or the + indices are not a valid subset of the input `state_vector`'s + indices. + IndexError: If any indexes are out of range. EntangledStateError: If the result of factoring is not a pure state and `default` is not provided. @@ -581,7 +582,9 @@ def factor_state_vector( same order as the original state vector. Raises: - ValueError: If the tensor cannot be factored along an axes. + EntangledStateError: If the tensor is already in entangled state, and + the validate flag is set. + ValueError: If the tensor factorization fails for any other reason. """ n_axes = len(axes) t1 = np.moveaxis(t, axes, range(n_axes)) @@ -595,7 +598,9 @@ def factor_state_vector( if validate: t2 = state_vector_kronecker_product(extracted, remainder) if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol): - raise ValueError('The tensor cannot be factored by the requested axes') + if not np.isclose(np.linalg.norm(t1), 1): + raise ValueError('Input state must be normalized.') + raise EntangledStateError('The tensor cannot be factored by the requested axes') return extracted, remainder diff --git a/cirq-core/cirq/linalg/transformations_test.py b/cirq-core/cirq/linalg/transformations_test.py index d01e34cb898..3303aec9935 100644 --- a/cirq-core/cirq/linalg/transformations_test.py +++ b/cirq-core/cirq/linalg/transformations_test.py @@ -492,24 +492,17 @@ def test_partial_trace_of_state_vector_as_mixture_invalid_input(): with pytest.raises(ValueError, match='7'): cirq.partial_trace_of_state_vector_as_mixture(np.arange(7), [1, 2], atol=1e-8) - bad_shape = np.arange(16).reshape((2, 4, 2)) - with pytest.raises(ValueError, match='shaped'): - cirq.partial_trace_of_state_vector_as_mixture(bad_shape, [1], atol=1e-8) - bad_shape = np.arange(16).reshape((16, 1)) - with pytest.raises(ValueError, match='shaped'): - cirq.partial_trace_of_state_vector_as_mixture(bad_shape, [1], atol=1e-8) - with pytest.raises(ValueError, match='normalized'): cirq.partial_trace_of_state_vector_as_mixture(np.arange(8), [1], atol=1e-8) state = np.arange(8) / np.linalg.norm(np.arange(8)) - with pytest.raises(ValueError, match='2, 2'): + with pytest.raises(ValueError, match='repeated axis'): cirq.partial_trace_of_state_vector_as_mixture(state, [1, 2, 2], atol=1e-8) state = np.array([1, 0, 0, 0]).reshape((2, 2)) - with pytest.raises(ValueError, match='invalid'): + with pytest.raises(IndexError, match='out of range'): cirq.partial_trace_of_state_vector_as_mixture(state, [5], atol=1e-8) - with pytest.raises(ValueError, match='invalid'): + with pytest.raises(IndexError, match='out of range'): cirq.partial_trace_of_state_vector_as_mixture(state, [0, 1, 2], atol=1e-8) @@ -576,6 +569,38 @@ def test_partial_trace_of_state_vector_as_mixture_pure_result(): ) +def test_partial_trace_of_state_vector_as_mixture_pure_result_qudits(): + a = cirq.testing.random_superposition(2) + b = cirq.testing.random_superposition(3) + c = cirq.testing.random_superposition(4) + state = np.kron(np.kron(a, b), c).reshape((2, 3, 4)) + + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [0], atol=1e-8), + ((1.0, a),), + ) + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [1], atol=1e-8), + ((1.0, b),), + ) + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [2], atol=1e-8), + ((1.0, c),), + ) + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [0, 1], atol=1e-8), + ((1.0, np.kron(a, b).reshape((2, 3))),), + ) + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [0, 2], atol=1e-8), + ((1.0, np.kron(a, c).reshape((2, 4))),), + ) + assert mixtures_equal( + cirq.partial_trace_of_state_vector_as_mixture(state, [1, 2], atol=1e-8), + ((1.0, np.kron(b, c).reshape((3, 4))),), + ) + + def test_partial_trace_of_state_vector_as_mixture_mixed_result(): state = np.array([1, 0, 0, 1]) / np.sqrt(2) truth = ((0.5, np.array([1, 0])), (0.5, np.array([0, 1]))) @@ -604,6 +629,14 @@ def test_partial_trace_of_state_vector_as_mixture_mixed_result(): assert mixtures_equal(mixture, truth) +def test_partial_trace_of_state_vector_as_mixture_mixed_result_qudits(): + state = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]]) / np.sqrt(2) + truth = ((0.5, np.array([1, 0, 0])), (0.5, np.array([0, 0, 1]))) + for q1 in [0, 1]: + mixture = cirq.partial_trace_of_state_vector_as_mixture(state, [q1], atol=1e-8) + assert mixtures_equal(mixture, truth) + + def test_to_special(): u = cirq.testing.random_unitary(4) su = cirq.to_special(u) diff --git a/cirq-core/cirq/sim/state_vector_test.py b/cirq-core/cirq/sim/state_vector_test.py index 40f4e57c570..b0af39559cc 100644 --- a/cirq-core/cirq/sim/state_vector_test.py +++ b/cirq-core/cirq/sim/state_vector_test.py @@ -385,7 +385,7 @@ def test_factor_validation(): cirq.linalg.transformations.factor_state_vector(t, [1]) args.apply_operation(cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1))) t = args.create_merged_state().target_tensor - with pytest.raises(ValueError, match='factor'): + with pytest.raises(cirq.linalg.transformations.EntangledStateError): cirq.linalg.transformations.factor_state_vector(t, [0]) - with pytest.raises(ValueError, match='factor'): + with pytest.raises(cirq.linalg.transformations.EntangledStateError): cirq.linalg.transformations.factor_state_vector(t, [1])