Skip to content

Commit

Permalink
[Bug, Feature] Fix permute basis and add permutation based apply oper…
Browse files Browse the repository at this point in the history
…ator (#262)

A few things that came up while working on the improvement to the noise.

- The `permute_basis` was not doing the correct permutation, but it was
going unnoticed because the `expand_operator` was compensating for it.
- I created an alternative `apply_operator_permute` which is just as
fast as the `apply_operator` function, but can serve as a basis for
future changes. For now we can simply save it. It could be useful if we
wish to make PyQ follow the more standard state shape of `[batch_size,
2**n_qubits]` instead of the `[2] * n_qubits + [batch_size]`. The reason
I wrote it is because the logic I am working on for the
`apply_density_mat` follows something similar, but doing it first for
the normal `apply_operator` was easier.
  • Loading branch information
jpmoutinho authored Aug 14, 2024
1 parent 26eeeaa commit c6bbce3
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 30 deletions.
1 change: 0 additions & 1 deletion pyqtorch/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def sampled_expectation(
state,
eigvecs.T.conj(),
tuple(range(n_qubits)),
n_qubits=circuit.n_qubits,
)
eigvec_state_prod = torch.flatten(eigvec_state_prod, start_dim=0, end_dim=-2).t()
probs = torch.pow(torch.abs(eigvec_state_prod), 2)
Expand Down
59 changes: 44 additions & 15 deletions pyqtorch/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@
from torch import Tensor, einsum

from pyqtorch.matrices import _dagger
from pyqtorch.utils import DensityMatrix
from pyqtorch.utils import DensityMatrix, permute_state

ABC_ARRAY: NDArray = array(list(ABC))


def apply_operator(
state: Tensor,
operator: Tensor,
qubits: tuple[int, ...] | list[int],
n_qubits: int | None = None,
batch_size: int | None = None,
qubit_support: tuple[int, ...] | list[int],
) -> Tensor:
"""Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.
Expand All @@ -32,33 +30,64 @@ def apply_operator(
Arguments:
state: State to operate on.
operator: Tensor to contract over 'state'.
qubits: Tuple of qubits on which to apply the 'operator' to.
n_qubits: The number of qubits of the full system.
batch_size: Batch size of either state and or operators.
qubit_support: Tuple of qubits on which to apply the 'operator' to.
Returns:
State after applying 'operator'.
"""
qubits = list(qubits)
if n_qubits is None:
n_qubits = len(state.size()) - 1
if batch_size is None:
batch_size = state.size(-1)
n_support = len(qubits)
qubit_support = list(qubit_support)
n_qubits = len(state.size()) - 1
n_support = len(qubit_support)
n_state_dims = n_qubits + 1
operator = operator.view([2] * n_support * 2 + [operator.size(-1)])
in_state_dims = ABC_ARRAY[0:n_state_dims].copy()
operator_dims = ABC_ARRAY[n_state_dims : n_state_dims + 2 * n_support + 1].copy()
operator_dims[n_support : 2 * n_support] = in_state_dims[qubits]
operator_dims[n_support : 2 * n_support] = in_state_dims[qubit_support]
operator_dims[-1] = in_state_dims[-1]
out_state_dims = in_state_dims.copy()
out_state_dims[qubits] = operator_dims[0:n_support]
out_state_dims[qubit_support] = operator_dims[0:n_support]
operator_dims, in_state_dims, out_state_dims = list(
map(lambda e: "".join(list(e)), [operator_dims, in_state_dims, out_state_dims])
)
return einsum(f"{operator_dims},{in_state_dims}->{out_state_dims}", operator, state)


def apply_operator_permute(
state: Tensor,
operator: Tensor,
qubit_support: tuple[int, ...] | list[int],
) -> Tensor:
"""NOTE: Currently not being used.
Alternative apply operator function with a logic based on state permutations.
Seems to be as fast as the current `apply_operator`. To be saved for now, we
may want to switch to this one in the future if we wish to remove the state
[2] * n_qubits shape and make the batch dimension the first one as the typical
torch convention.
Arguments:
state: State to operate on.
operator: Tensor to contract over 'state'.
qubit_support: Tuple of qubits on which to apply the 'operator' to.
Returns:
State after applying 'operator'.
"""
n_qubits = len(state.size()) - 1
n_support = len(qubit_support)
batch_size = max(state.size(-1), operator.size(-1))
full_support = tuple(range(n_qubits))
support_perm = list(sorted(qubit_support)) + list(
set(full_support) - set(qubit_support)
)
state = permute_state(state, support_perm)
state = state.reshape([2**n_support, 2 ** (n_qubits - n_support), state.size(-1)])
result = einsum("ijb,jkb->ikb", operator, state).reshape(
[2] * n_qubits + [batch_size]
)
return permute_state(result, support_perm, inv=True)


def apply_density_mat(op: Tensor, density_matrix: DensityMatrix) -> DensityMatrix:
"""
Apply an operator to a density matrix, i.e., compute:
Expand Down
2 changes: 1 addition & 1 deletion pyqtorch/composite/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, operations: list[Module]):

@property
def qubit_support(self) -> tuple:
return self._qubit_support
return tuple(sorted(self._qubit_support))

def __iter__(self) -> Iterator:
return iter(self.operations)
Expand Down
4 changes: 1 addition & 3 deletions pyqtorch/hamiltonians/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,7 @@ def forward(
return apply_operator(
state=state,
operator=evolved_op,
qubits=self.qubit_support,
n_qubits=len(state.size()) - 1,
batch_size=evolved_op.shape[BATCH_DIM],
qubit_support=self.qubit_support,
)

def tensor(
Expand Down
1 change: 0 additions & 1 deletion pyqtorch/quantum_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def _forward(
state,
self.tensor(values, embedding),
self.qubit_support,
len(state.size()) - 1,
)

def _noise_forward(
Expand Down
43 changes: 39 additions & 4 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def expand_operator(
by explicitly filling in identity matrices on all remaining qubits.
"""
full_support = tuple(sorted(full_support))
qubit_support = tuple(sorted(qubit_support))
if not set(qubit_support).issubset(set(full_support)):
raise ValueError(
"Expanding tensor operation requires a `full_support` argument "
Expand All @@ -405,7 +406,7 @@ def expand_operator(
qubit_support += (i,)
other = IMAT.clone().to(device=device, dtype=dtype).unsqueeze(2)
operator = torch.kron(operator.contiguous(), other)
operator = permute_basis(operator, qubit_support)
operator = permute_basis(operator, qubit_support, inv=True)
return operator


Expand Down Expand Up @@ -445,27 +446,61 @@ def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor:
return operator


def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor:
def permute_state(
state: Tensor, qubit_support: tuple | list, inv: bool = False
) -> Tensor:
"""Takes a state tensor and permutes the qubit amplitudes
according to the order of the qubit support.
Args:
state (Tensor): State to permute over.
qubit_support (tuple): Qubit support.
inv (bool): Applies the inverse permutation instead.
Returns:
Tensor: Permuted state.
"""
if tuple(qubit_support) == tuple(sorted(qubit_support)):
return state

ordered_support = argsort(qubit_support)
ranked_support = argsort(ordered_support)

perm = list(ranked_support) + [len(qubit_support)]

if inv:
perm = np.argsort(perm).tolist()

return state.permute(perm)


def permute_basis(operator: Tensor, qubit_support: tuple, inv: bool = False) -> Tensor:
"""Takes an operator tensor and permutes the rows and
columns according to the order of the qubit support.
Args:
operator (Tensor): Operator to permute over.
qubit_support (tuple): Qubit support.
inv (bool): Applies the inverse permutation instead.
Returns:
Tensor: Permuted operator.
"""
ordered_support = argsort(qubit_support)
ranked_support = argsort(ordered_support)
n_qubits = len(qubit_support)
if all(a == b for a, b in zip(ordered_support, list(range(n_qubits)))):
if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))):
return operator
batch_size = operator.size(-1)
operator = operator.view([2] * 2 * n_qubits + [batch_size])

perm = list(
tuple(ordered_support) + tuple(ordered_support + n_qubits) + (2 * n_qubits,)
tuple(ranked_support) + tuple(ranked_support + n_qubits) + (2 * n_qubits,)
)

if inv:
perm = np.argsort(perm).tolist()

return operator.permute(perm).reshape([2**n_qubits, 2**n_qubits, batch_size])


Expand Down
8 changes: 5 additions & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from pyqtorch.apply import apply_operator
from pyqtorch.apply import apply_operator, apply_operator_permute
from pyqtorch.composite import Add, Scale, Sequence
from pyqtorch.primitives import (
OPS_1Q,
Expand All @@ -24,6 +24,7 @@ def calc_mat_vec_wavefunction(
init_state: torch.Tensor,
values: dict = dict(),
full_support: tuple | None = None,
use_permute: bool = False,
) -> torch.Tensor:
"""Get the result of applying the matrix representation of a block to an initial state.
Expand All @@ -38,10 +39,11 @@ def calc_mat_vec_wavefunction(
"""
mat = block.tensor(values=values, full_support=full_support)
qubit_support = block.qubit_support if full_support is None else full_support
return apply_operator(
apply_func = apply_operator_permute if use_permute else apply_operator
return apply_func(
init_state,
mat,
qubits=qubit_support,
qubit_support,
)


Expand Down
11 changes: 9 additions & 2 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ def test_digital_tensor(n_qubits: int, batch_size: int, use_full_support: bool)
assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)


@pytest.mark.parametrize("use_permute", [True, False])
@pytest.mark.parametrize("use_full_support", [True, False])
@pytest.mark.parametrize("n_qubits", [4, 5])
@pytest.mark.parametrize("batch_size", [1, 5])
def test_param_tensor(n_qubits: int, batch_size: int, use_full_support: bool) -> None:
def test_param_tensor(
n_qubits: int, batch_size: int, use_full_support: bool, use_permute: bool
) -> None:
"""
Goes through all parametric gates and tests their application to a random state
in comparison with the `tensor` method, either using just the qubit support of the gate
Expand All @@ -68,7 +71,11 @@ def test_param_tensor(n_qubits: int, batch_size: int, use_full_support: bool) ->
psi_star = op_concrete(psi_init, values)
full_support = tuple(range(n_qubits)) if use_full_support else None
psi_expected = calc_mat_vec_wavefunction(
op_concrete, psi_init, values=values, full_support=full_support
op_concrete,
psi_init,
values=values,
full_support=full_support,
use_permute=use_permute,
)
assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)

Expand Down

0 comments on commit c6bbce3

Please sign in to comment.