diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index aa509b192cfd..687c491eea12 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -13,6 +13,7 @@ # limitations under the License. """Objects and methods for acting efficiently on a density matrix.""" +import math from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Type, Union import numpy as np @@ -79,7 +80,26 @@ def create( ).reshape(qid_shape * 2) else: if qid_shape is not None: - density_matrix = initial_state.reshape(qid_shape * 2) + qid_size = math.prod(qid_shape) + shape = initial_state.shape + if shape == qid_shape or shape == (qid_size,): + if len(shape) != 1: + initial_state = initial_state.reshape((qid_size,)) + elif shape == qid_shape * 2 or shape == (qid_size, qid_size): + if len(shape) != 2: + initial_state = initial_state.reshape((qid_size, qid_size)) + if dtype and initial_state.dtype != dtype: + # Convert type because to_valid_density_matrix does not convert dtype. + initial_state = initial_state.astype(dtype) + else: + raise ValueError( + f'Invalid initial state. Expected state vector of shape {(qid_size,)} ' + f'or density matrix of shape {(qid_size, qid_size)}; ' + f'got {initial_state.shape}.' + ) + density_matrix = qis.to_valid_density_matrix( + initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype + ).reshape(qid_shape * 2) else: density_matrix = initial_state if np.may_share_memory(density_matrix, initial_state): diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py index 950c58cabff6..4450b77d51b1 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args_test.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args_test.py @@ -105,3 +105,50 @@ def test_with_qubits(): def test_qid_shape_error(): with pytest.raises(ValueError, match="qid_shape must be provided"): cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0) + + +def test_initial_state_vector(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_matrix(): + qubits = cirq.LineQubit.range(3) + args = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64 + ) + assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + args2 = cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64 + ) + assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2) + + +def test_initial_state_bad_shape(): + qubits = cirq.LineQubit.range(3) + with pytest.raises(ValueError, match="Invalid initial state."): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid initial state."): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64 + ) + + with pytest.raises(ValueError, match="Invalid initial state."): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64 + ) + with pytest.raises(ValueError, match="Invalid initial state."): + cirq.ActOnDensityMatrixArgs( + qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64 + )