From 77ab09c9b470e6a7994f70ac75106d1d1431ccd7 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Thu, 7 Apr 2022 14:30:56 -0700 Subject: [PATCH 1/2] Allow specifying initial state/matrix tensor in to_valid_density_matrix This changes cirq.to_valid_density_matrix to allow converting a state or density matrix in tensor form (shape == (2,) * num_qubits or shape == (2,) * (2 * num_qubits), respectively). Also fixes the ActOnDensityMatrixArgs construction logic to always call into cirq.to_valid_density_matrix so that these forms are also accepted as the initial_state when doing a density matrix simulation. Fixes #3958 --- cirq-core/cirq/qis/states.py | 11 +++-- cirq-core/cirq/qis/states_test.py | 27 ++++++++++- .../cirq/sim/act_on_density_matrix_args.py | 6 ++- .../sim/act_on_density_matrix_args_test.py | 47 +++++++++++++++++++ 4 files changed, 86 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/qis/states.py b/cirq-core/cirq/qis/states.py index 8120fc0c8e5..38b7aab179d 100644 --- a/cirq-core/cirq/qis/states.py +++ b/cirq-core/cirq/qis/states.py @@ -26,6 +26,7 @@ Union, ) import itertools +import math import numpy as np @@ -968,9 +969,13 @@ def to_valid_density_matrix( ValueError if the density_matrix_rep is not valid. """ qid_shape = _qid_shape_from_args(num_qubits, qid_shape) - if isinstance(density_matrix_rep, np.ndarray) and density_matrix_rep.ndim == 2: - validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol) - return density_matrix_rep + if isinstance(density_matrix_rep, np.ndarray): + N = math.prod(qid_shape) + if len(qid_shape) > 1 and density_matrix_rep.shape == qid_shape * 2: + density_matrix_rep = density_matrix_rep.reshape((N, N)) + if density_matrix_rep.shape == (N, N): + validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol) + return density_matrix_rep state_vector = to_valid_state_vector( density_matrix_rep, len(qid_shape), qid_shape=qid_shape, dtype=dtype diff --git a/cirq-core/cirq/qis/states_test.py b/cirq-core/cirq/qis/states_test.py index 518fd51ec48..f4073b48715 100644 --- a/cirq-core/cirq/qis/states_test.py +++ b/cirq-core/cirq/qis/states_test.py @@ -607,6 +607,21 @@ def test_to_valid_density_matrix_from_density_matrix(): assert_valid_density_matrix(np.diag([0.2, 0.8, 0, 0]), qid_shape=(4,)) +def test_to_valid_density_matrix_from_density_matrix_tensor(): + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + cirq.one_hot(shape=(2, 2, 2, 2, 2, 2), dtype=np.complex64), num_qubits=3 + ), + cirq.one_hot(shape=(8, 8), dtype=np.complex64), + ) + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + cirq.one_hot(shape=(2, 3, 4, 2, 3, 4), dtype=np.complex64), qid_shape=(2, 3, 4) + ), + cirq.one_hot(shape=(24, 24), dtype=np.complex64), + ) + + def test_to_valid_density_matrix_not_square(): with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.array([[1], [0]]), num_qubits=1) @@ -614,7 +629,7 @@ def test_to_valid_density_matrix_not_square(): def test_to_valid_density_matrix_size_mismatch_num_qubits(): with pytest.raises(ValueError, match='shape'): - cirq.to_valid_density_matrix(np.array([[1, 0], [0, 0]]), num_qubits=2) + cirq.to_valid_density_matrix(np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]]), num_qubits=2) with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.eye(4) / 4.0, num_qubits=1) @@ -690,6 +705,16 @@ def test_to_valid_density_matrix_from_state_vector(): ) +def test_to_valid_density_matrix_from_state_vector_tensor(): + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + density_matrix_rep=np.array(np.full((2, 2), 0.5), dtype=np.complex64), + num_qubits=2, + ), + 0.25 * np.ones((4, 4)), + ) + + def test_to_valid_density_matrix_from_state_invalid_state(): with pytest.raises(ValueError, match="Invalid quantum state"): cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2) diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index aa509b192cf..873549a39fd 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -79,7 +79,11 @@ def create( ).reshape(qid_shape * 2) else: if qid_shape is not None: - density_matrix = initial_state.reshape(qid_shape * 2) + if dtype and initial_state.dtype != dtype: + initial_state = initial_state.astype(dtype) + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) else: density_matrix = initial_state if np.may_share_memory(density_matrix, initial_state): diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py index 950c58cabff..32c73b92683 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py @@ -105,3 +105,50 @@ def test_with_qubits(): def test_qid_shape_error(): with pytest.raises(ValueError, match="qid_shape must be provided"): cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0) + + +def test_initial_state_vector(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_matrix(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_bad_shape(): + qubits = cirq.LineQubit.range(3) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64 + ) + + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64 + ) From 2ca9bbca20856bc1c461ffae48fab173b66a3ffa Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Thu, 7 Apr 2022 22:53:43 -0700 Subject: [PATCH 2/2] Use np.prod instead of math.prod --- cirq-core/cirq/qis/states.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/qis/states.py b/cirq-core/cirq/qis/states.py index 38b7aab179d..65f13562f20 100644 --- a/cirq-core/cirq/qis/states.py +++ b/cirq-core/cirq/qis/states.py @@ -26,7 +26,6 @@ Union, ) import itertools -import math import numpy as np @@ -970,7 +969,7 @@ def to_valid_density_matrix( """ qid_shape = _qid_shape_from_args(num_qubits, qid_shape) if isinstance(density_matrix_rep, np.ndarray): - N = math.prod(qid_shape) + N = np.prod(qid_shape, dtype=np.int64) if len(qid_shape) > 1 and density_matrix_rep.shape == qid_shape * 2: density_matrix_rep = density_matrix_rep.reshape((N, N)) if density_matrix_rep.shape == (N, N):