From 729c09d9559e3117d3f9291c36247228811ced66 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 8 Apr 2022 08:54:33 -0700 Subject: [PATCH] Allow specifying initial state vector in DensityMatrixSimulator (#5223) This changes how ActOnDensityMatrixArgs is constructed to allow specifying the initial state as a state vector or state tensor, or as a density matrix or density tensor. Some of this could perhaps be moved into `cirq.to_valid_density_matrix` if people think that is a better place. Currently `to_valid_density_matrix` only handles 1D state vectors or 2D density matrices, not 2x2x..2 tensors in either case, but if we have the qid_shape we can tell handle these unambiguously. Fixes #3958 --- cirq-core/cirq/qis/states.py | 10 ++-- cirq-core/cirq/qis/states_test.py | 27 ++++++++++- .../cirq/sim/act_on_density_matrix_args.py | 6 ++- .../sim/act_on_density_matrix_args_test.py | 47 +++++++++++++++++++ 4 files changed, 85 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/qis/states.py b/cirq-core/cirq/qis/states.py index 8120fc0c8e5e..65f13562f20c 100644 --- a/cirq-core/cirq/qis/states.py +++ b/cirq-core/cirq/qis/states.py @@ -968,9 +968,13 @@ def to_valid_density_matrix( ValueError if the density_matrix_rep is not valid. """ qid_shape = _qid_shape_from_args(num_qubits, qid_shape) - if isinstance(density_matrix_rep, np.ndarray) and density_matrix_rep.ndim == 2: - validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol) - return density_matrix_rep + if isinstance(density_matrix_rep, np.ndarray): + N = np.prod(qid_shape, dtype=np.int64) + if len(qid_shape) > 1 and density_matrix_rep.shape == qid_shape * 2: + density_matrix_rep = density_matrix_rep.reshape((N, N)) + if density_matrix_rep.shape == (N, N): + validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol) + return density_matrix_rep state_vector = to_valid_state_vector( density_matrix_rep, len(qid_shape), qid_shape=qid_shape, dtype=dtype diff --git a/cirq-core/cirq/qis/states_test.py b/cirq-core/cirq/qis/states_test.py index 518fd51ec481..f4073b48715d 100644 --- a/cirq-core/cirq/qis/states_test.py +++ b/cirq-core/cirq/qis/states_test.py @@ -607,6 +607,21 @@ def test_to_valid_density_matrix_from_density_matrix(): assert_valid_density_matrix(np.diag([0.2, 0.8, 0, 0]), qid_shape=(4,)) +def test_to_valid_density_matrix_from_density_matrix_tensor(): + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + cirq.one_hot(shape=(2, 2, 2, 2, 2, 2), dtype=np.complex64), num_qubits=3 + ), + cirq.one_hot(shape=(8, 8), dtype=np.complex64), + ) + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + cirq.one_hot(shape=(2, 3, 4, 2, 3, 4), dtype=np.complex64), qid_shape=(2, 3, 4) + ), + cirq.one_hot(shape=(24, 24), dtype=np.complex64), + ) + + def test_to_valid_density_matrix_not_square(): with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.array([[1], [0]]), num_qubits=1) @@ -614,7 +629,7 @@ def test_to_valid_density_matrix_not_square(): def test_to_valid_density_matrix_size_mismatch_num_qubits(): with pytest.raises(ValueError, match='shape'): - cirq.to_valid_density_matrix(np.array([[1, 0], [0, 0]]), num_qubits=2) + cirq.to_valid_density_matrix(np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]]), num_qubits=2) with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.eye(4) / 4.0, num_qubits=1) @@ -690,6 +705,16 @@ def test_to_valid_density_matrix_from_state_vector(): ) +def test_to_valid_density_matrix_from_state_vector_tensor(): + np.testing.assert_almost_equal( + cirq.to_valid_density_matrix( + density_matrix_rep=np.array(np.full((2, 2), 0.5), dtype=np.complex64), + num_qubits=2, + ), + 0.25 * np.ones((4, 4)), + ) + + def test_to_valid_density_matrix_from_state_invalid_state(): with pytest.raises(ValueError, match="Invalid quantum state"): cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2) diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 0eebf2716a48..a2f591090dad 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -79,7 +79,11 @@ def create( ).reshape(qid_shape * 2) else: if qid_shape is not None: - density_matrix = initial_state.reshape(qid_shape * 2) + if dtype and initial_state.dtype != dtype: + initial_state = initial_state.astype(dtype) + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) else: density_matrix = initial_state if np.may_share_memory(density_matrix, initial_state): diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py index fd62c6934535..e82685bd452e 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py @@ -98,3 +98,50 @@ def test_with_qubits(): def test_qid_shape_error(): with pytest.raises(ValueError, match="qid_shape must be provided"): cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0) + + +def test_initial_state_vector(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_matrix(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_bad_shape(): + qubits = cirq.LineQubit.range(3) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64 + ) + + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid quantum state"): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64 + )