Skip to content

Commit

Permalink
Allow partial state vector function to handle qudits (quantumlib#5077)
Browse files Browse the repository at this point in the history
* Allow partial state vector function to handle qudits

* nits

* check norm

* nit

* simplify kron

* lint

* nits

Co-authored-by: Cirq Bot <[email protected]>
  • Loading branch information
2 people authored and rht committed May 1, 2023
1 parent 4db45dc commit b903e1b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 37 deletions.
55 changes: 30 additions & 25 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,10 @@ def partial_trace_of_state_vector_as_mixture(
) -> Tuple[Tuple[float, np.ndarray], ...]:
"""Returns a mixture representing a state vector with only some qubits kept.
The input state vector must have shape `(2,) * n` or `(2 ** n)` where
`state_vector` is expressed over n qubits. States in the output mixture will
retain the same type of shape as the input state vector, either `(2 ** k)`
or `(2,) * k` where k is the number of qubits kept.
The input state vector can have any shape, but if it is one-dimensional it
will be interpreted as qubits, since that is the most common case, and fail
if the dimension is not size `2 ** n`. States in the output mixture will
retain the same type of shape as the input state vector.
If the state vector cannot be factored into a pure state over `keep_indices`
then eigendecomposition is used and the output mixture will not be unique.
Expand All @@ -361,31 +361,30 @@ def partial_trace_of_state_vector_as_mixture(
partial trace.
Raises:
ValueError: if the input `state_vector` is not an array of length
`(2 ** n)` or a tensor with a shape of `(2,) * n`
ValueError: If the input `state_vector` is one dimension, but that
dimension size is not a power of two.
IndexError: If any indexes are out of range.
"""

if state_vector.ndim == 1:
dims = int(np.log2(state_vector.size))
if 2 ** dims != state_vector.size:
raise ValueError(f'Cannot infer underlying shape of {state_vector.shape}.')
state_vector = state_vector.reshape((2,) * dims)
ret_shape: Tuple[int, ...] = (2 ** len(keep_indices),)
else:
ret_shape = tuple(state_vector.shape[i] for i in keep_indices)

# Attempt to do efficient state factoring.
try:
state = sub_state_vector(
state_vector, keep_indices, default=RaiseValueErrorIfNotProvided, atol=atol
)
return ((1.0, state),)
state, _ = factor_state_vector(state_vector, keep_indices, atol=atol)
return ((1.0, state.reshape(ret_shape)),)
except EntangledStateError:
pass

# Fall back to a (non-unique) mixture representation.
keep_dims = 1 << len(keep_indices)
ret_shape: Union[Tuple[int], Tuple[int, ...]]
if state_vector.shape == (state_vector.size,):
ret_shape = (keep_dims,)
elif all(e == 2 for e in state_vector.shape):
ret_shape = tuple(2 for _ in range(len(keep_indices)))

rho = np.kron(np.conj(state_vector.reshape(-1, 1)).T, state_vector.reshape(-1, 1)).reshape(
(2, 2) * int(np.log2(state_vector.size))
)
keep_rho = partial_trace(rho, keep_indices).reshape((keep_dims,) * 2)
rho = np.outer(state_vector, np.conj(state_vector)).reshape(state_vector.shape * 2)
keep_rho = partial_trace(rho, keep_indices).reshape((np.prod(ret_shape),) * 2)
eigvals, eigvecs = np.linalg.eigh(keep_rho)
mixture = tuple(zip(eigvals, [vec.reshape(ret_shape) for vec in eigvecs.T]))
return tuple([(float(p[0]), p[1]) for p in mixture if not protocols.approx_eq(p[0], 0.0)])
Expand Down Expand Up @@ -436,8 +435,10 @@ def sub_state_vector(
The state vector expressed over the desired subset of qubits.
Raises:
ValueError: if the `state_vector` is not of the correct shape or the
indices are not a valid subset of the input `state_vector`'s indices
ValueError: If the `state_vector` is not of the correct shape or the
indices are not a valid subset of the input `state_vector`'s
indices.
IndexError: If any indexes are out of range.
EntangledStateError: If the result of factoring is not a pure state and
`default` is not provided.
Expand Down Expand Up @@ -581,7 +582,9 @@ def factor_state_vector(
same order as the original state vector.
Raises:
ValueError: If the tensor cannot be factored along an axes.
EntangledStateError: If the tensor is already in entangled state, and
the validate flag is set.
ValueError: If the tensor factorization fails for any other reason.
"""
n_axes = len(axes)
t1 = np.moveaxis(t, axes, range(n_axes))
Expand All @@ -595,7 +598,9 @@ def factor_state_vector(
if validate:
t2 = state_vector_kronecker_product(extracted, remainder)
if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol):
raise ValueError('The tensor cannot be factored by the requested axes')
if not np.isclose(np.linalg.norm(t1), 1):
raise ValueError('Input state must be normalized.')
raise EntangledStateError('The tensor cannot be factored by the requested axes')
return extracted, remainder


Expand Down
53 changes: 43 additions & 10 deletions cirq-core/cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,24 +492,17 @@ def test_partial_trace_of_state_vector_as_mixture_invalid_input():
with pytest.raises(ValueError, match='7'):
cirq.partial_trace_of_state_vector_as_mixture(np.arange(7), [1, 2], atol=1e-8)

bad_shape = np.arange(16).reshape((2, 4, 2))
with pytest.raises(ValueError, match='shaped'):
cirq.partial_trace_of_state_vector_as_mixture(bad_shape, [1], atol=1e-8)
bad_shape = np.arange(16).reshape((16, 1))
with pytest.raises(ValueError, match='shaped'):
cirq.partial_trace_of_state_vector_as_mixture(bad_shape, [1], atol=1e-8)

with pytest.raises(ValueError, match='normalized'):
cirq.partial_trace_of_state_vector_as_mixture(np.arange(8), [1], atol=1e-8)

state = np.arange(8) / np.linalg.norm(np.arange(8))
with pytest.raises(ValueError, match='2, 2'):
with pytest.raises(ValueError, match='repeated axis'):
cirq.partial_trace_of_state_vector_as_mixture(state, [1, 2, 2], atol=1e-8)

state = np.array([1, 0, 0, 0]).reshape((2, 2))
with pytest.raises(ValueError, match='invalid'):
with pytest.raises(IndexError, match='out of range'):
cirq.partial_trace_of_state_vector_as_mixture(state, [5], atol=1e-8)
with pytest.raises(ValueError, match='invalid'):
with pytest.raises(IndexError, match='out of range'):
cirq.partial_trace_of_state_vector_as_mixture(state, [0, 1, 2], atol=1e-8)


Expand Down Expand Up @@ -576,6 +569,38 @@ def test_partial_trace_of_state_vector_as_mixture_pure_result():
)


def test_partial_trace_of_state_vector_as_mixture_pure_result_qudits():
a = cirq.testing.random_superposition(2)
b = cirq.testing.random_superposition(3)
c = cirq.testing.random_superposition(4)
state = np.kron(np.kron(a, b), c).reshape((2, 3, 4))

assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [0], atol=1e-8),
((1.0, a),),
)
assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [1], atol=1e-8),
((1.0, b),),
)
assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [2], atol=1e-8),
((1.0, c),),
)
assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [0, 1], atol=1e-8),
((1.0, np.kron(a, b).reshape((2, 3))),),
)
assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [0, 2], atol=1e-8),
((1.0, np.kron(a, c).reshape((2, 4))),),
)
assert mixtures_equal(
cirq.partial_trace_of_state_vector_as_mixture(state, [1, 2], atol=1e-8),
((1.0, np.kron(b, c).reshape((3, 4))),),
)


def test_partial_trace_of_state_vector_as_mixture_mixed_result():
state = np.array([1, 0, 0, 1]) / np.sqrt(2)
truth = ((0.5, np.array([1, 0])), (0.5, np.array([0, 1])))
Expand Down Expand Up @@ -604,6 +629,14 @@ def test_partial_trace_of_state_vector_as_mixture_mixed_result():
assert mixtures_equal(mixture, truth)


def test_partial_trace_of_state_vector_as_mixture_mixed_result_qudits():
state = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]]) / np.sqrt(2)
truth = ((0.5, np.array([1, 0, 0])), (0.5, np.array([0, 0, 1])))
for q1 in [0, 1]:
mixture = cirq.partial_trace_of_state_vector_as_mixture(state, [q1], atol=1e-8)
assert mixtures_equal(mixture, truth)


def test_to_special():
u = cirq.testing.random_unitary(4)
su = cirq.to_special(u)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/state_vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_factor_validation():
cirq.linalg.transformations.factor_state_vector(t, [1])
args.apply_operation(cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)))
t = args.create_merged_state().target_tensor
with pytest.raises(ValueError, match='factor'):
with pytest.raises(cirq.linalg.transformations.EntangledStateError):
cirq.linalg.transformations.factor_state_vector(t, [0])
with pytest.raises(ValueError, match='factor'):
with pytest.raises(cirq.linalg.transformations.EntangledStateError):
cirq.linalg.transformations.factor_state_vector(t, [1])

0 comments on commit b903e1b

Please sign in to comment.