From c6bbce355b485af9ca7d33797ba3f89fe1543aae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20P=2E=20Moutinho?=
Date: Wed, 14 Aug 2024 08:21:38 +0200
Subject: [PATCH] [Bug, Feature] Fix permute basis and add permutation based apply operator (#262)

A few things that came up while working on the noise improvements.

- The `permute_basis` was not doing the correct permutation, but this went
  unnoticed because `expand_operator` was compensating for it.
- I added an alternative `apply_operator_permute`, which is just as fast as the
  current `apply_operator` but can serve as a basis for future changes. For now
  we simply keep it unused. It could be useful if we wish to make PyQ follow the
  more standard state shape of `[batch_size, 2**n_qubits]` instead of
  `[2] * n_qubits + [batch_size]`. I wrote it because the logic I am working on
  for `apply_density_mat` follows a similar pattern, and doing it first for the
  regular `apply_operator` was easier.
---
 pyqtorch/api.py                    |  1 -
 pyqtorch/apply.py                  | 59 ++++++++++++++++++++++--------
 pyqtorch/composite/sequence.py     |  2 +-
 pyqtorch/hamiltonians/evolution.py |  4 +-
 pyqtorch/quantum_operation.py      |  1 -
 pyqtorch/utils.py                  | 43 ++++++++++++++++++++--
 tests/helpers.py                   |  8 ++--
 tests/test_tensor.py               | 11 +++++-
 8 files changed, 99 insertions(+), 30 deletions(-)

diff --git a/pyqtorch/api.py b/pyqtorch/api.py
index 124e7158..d2da8bf8 100644
--- a/pyqtorch/api.py
+++ b/pyqtorch/api.py
@@ -168,7 +168,6 @@ def sampled_expectation(
         state,
         eigvecs.T.conj(),
         tuple(range(n_qubits)),
-        n_qubits=circuit.n_qubits,
     )
     eigvec_state_prod = torch.flatten(eigvec_state_prod, start_dim=0, end_dim=-2).t()
     probs = torch.pow(torch.abs(eigvec_state_prod), 2)
diff --git a/pyqtorch/apply.py b/pyqtorch/apply.py
index b1539c52..aed8bc0b 100644
--- a/pyqtorch/apply.py
+++ b/pyqtorch/apply.py
@@ -7,7 +7,7 @@
 from torch import Tensor, einsum
 
 from pyqtorch.matrices import _dagger
-from pyqtorch.utils import DensityMatrix
+from pyqtorch.utils import DensityMatrix, permute_state
 
 ABC_ARRAY: NDArray = array(list(ABC))
 
@@ -15,9 +15,7 @@
 def apply_operator(
     state: Tensor,
     operator: Tensor,
-    qubits: tuple[int, ...] | list[int],
-    n_qubits: int | None = None,
-    batch_size: int | None = None,
+    qubit_support: tuple[int, ...] | list[int],
 ) -> Tensor:
     """Applies an operator, i.e. a single tensor of shape [2, 2, ...], on a given state
     of shape [2 for _ in range(n_qubits)] for a given set of (target and control) qubits.
@@ -32,33 +30,64 @@ def apply_operator(
     Arguments:
         state: State to operate on.
         operator: Tensor to contract over 'state'.
-        qubits: Tuple of qubits on which to apply the 'operator' to.
-        n_qubits: The number of qubits of the full system.
-        batch_size: Batch size of either state and or operators.
+        qubit_support: Tuple of qubits on which to apply the 'operator'.
 
     Returns:
         State after applying 'operator'.
""" - qubits = list(qubits) - if n_qubits is None: - n_qubits = len(state.size()) - 1 - if batch_size is None: - batch_size = state.size(-1) - n_support = len(qubits) + qubit_support = list(qubit_support) + n_qubits = len(state.size()) - 1 + n_support = len(qubit_support) n_state_dims = n_qubits + 1 operator = operator.view([2] * n_support * 2 + [operator.size(-1)]) in_state_dims = ABC_ARRAY[0:n_state_dims].copy() operator_dims = ABC_ARRAY[n_state_dims : n_state_dims + 2 * n_support + 1].copy() - operator_dims[n_support : 2 * n_support] = in_state_dims[qubits] + operator_dims[n_support : 2 * n_support] = in_state_dims[qubit_support] operator_dims[-1] = in_state_dims[-1] out_state_dims = in_state_dims.copy() - out_state_dims[qubits] = operator_dims[0:n_support] + out_state_dims[qubit_support] = operator_dims[0:n_support] operator_dims, in_state_dims, out_state_dims = list( map(lambda e: "".join(list(e)), [operator_dims, in_state_dims, out_state_dims]) ) return einsum(f"{operator_dims},{in_state_dims}->{out_state_dims}", operator, state) +def apply_operator_permute( + state: Tensor, + operator: Tensor, + qubit_support: tuple[int, ...] | list[int], +) -> Tensor: + """NOTE: Currently not being used. + + Alternative apply operator function with a logic based on state permutations. + Seems to be as fast as the current `apply_operator`. To be saved for now, we + may want to switch to this one in the future if we wish to remove the state + [2] * n_qubits shape and make the batch dimension the first one as the typical + torch convention. + + Arguments: + state: State to operate on. + operator: Tensor to contract over 'state'. + qubit_support: Tuple of qubits on which to apply the 'operator' to. + + Returns: + State after applying 'operator'. + """ + n_qubits = len(state.size()) - 1 + n_support = len(qubit_support) + batch_size = max(state.size(-1), operator.size(-1)) + full_support = tuple(range(n_qubits)) + support_perm = list(sorted(qubit_support)) + list( + set(full_support) - set(qubit_support) + ) + state = permute_state(state, support_perm) + state = state.reshape([2**n_support, 2 ** (n_qubits - n_support), state.size(-1)]) + result = einsum("ijb,jkb->ikb", operator, state).reshape( + [2] * n_qubits + [batch_size] + ) + return permute_state(result, support_perm, inv=True) + + def apply_density_mat(op: Tensor, density_matrix: DensityMatrix) -> DensityMatrix: """ Apply an operator to a density matrix, i.e., compute: diff --git a/pyqtorch/composite/sequence.py b/pyqtorch/composite/sequence.py index 97174300..02bda51e 100644 --- a/pyqtorch/composite/sequence.py +++ b/pyqtorch/composite/sequence.py @@ -85,7 +85,7 @@ def __init__(self, operations: list[Module]): @property def qubit_support(self) -> tuple: - return self._qubit_support + return tuple(sorted(self._qubit_support)) def __iter__(self) -> Iterator: return iter(self.operations) diff --git a/pyqtorch/hamiltonians/evolution.py b/pyqtorch/hamiltonians/evolution.py index 5fe1a70a..ad25e27c 100644 --- a/pyqtorch/hamiltonians/evolution.py +++ b/pyqtorch/hamiltonians/evolution.py @@ -330,9 +330,7 @@ def forward( return apply_operator( state=state, operator=evolved_op, - qubits=self.qubit_support, - n_qubits=len(state.size()) - 1, - batch_size=evolved_op.shape[BATCH_DIM], + qubit_support=self.qubit_support, ) def tensor( diff --git a/pyqtorch/quantum_operation.py b/pyqtorch/quantum_operation.py index 7ae6aced..b95ffb7b 100644 --- a/pyqtorch/quantum_operation.py +++ b/pyqtorch/quantum_operation.py @@ -354,7 +354,6 @@ def _forward( state, 
             self.tensor(values, embedding),
             self.qubit_support,
-            len(state.size()) - 1,
         )
 
     def _noise_forward(
diff --git a/pyqtorch/utils.py b/pyqtorch/utils.py
index 1e178337..f331ef80 100644
--- a/pyqtorch/utils.py
+++ b/pyqtorch/utils.py
@@ -395,6 +395,7 @@ def expand_operator(
     by explicitly filling in identity matrices on all remaining qubits.
     """
     full_support = tuple(sorted(full_support))
+    qubit_support = tuple(sorted(qubit_support))
     if not set(qubit_support).issubset(set(full_support)):
         raise ValueError(
             "Expanding tensor operation requires a `full_support` argument "
@@ -405,7 +406,7 @@ def expand_operator(
             qubit_support += (i,)
             other = IMAT.clone().to(device=device, dtype=dtype).unsqueeze(2)
             operator = torch.kron(operator.contiguous(), other)
-    operator = permute_basis(operator, qubit_support)
+    operator = permute_basis(operator, qubit_support, inv=True)
     return operator


@@ -445,27 +446,61 @@ def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor:
     return operator
 
 
-def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor:
+def permute_state(
+    state: Tensor, qubit_support: tuple | list, inv: bool = False
+) -> Tensor:
+    """Takes a state tensor and permutes the qubit amplitudes
+    according to the order of the qubit support.
+
+    Args:
+        state (Tensor): State to permute over.
+        qubit_support (tuple): Qubit support.
+        inv (bool): Applies the inverse permutation instead.
+
+    Returns:
+        Tensor: Permuted state.
+    """
+    if tuple(qubit_support) == tuple(sorted(qubit_support)):
+        return state
+
+    ordered_support = argsort(qubit_support)
+    ranked_support = argsort(ordered_support)
+
+    perm = list(ranked_support) + [len(qubit_support)]
+
+    if inv:
+        perm = np.argsort(perm).tolist()
+
+    return state.permute(perm)
+
+
+def permute_basis(operator: Tensor, qubit_support: tuple, inv: bool = False) -> Tensor:
     """Takes an operator tensor and permutes the rows and columns
     according to the order of the qubit support.
 
     Args:
         operator (Tensor): Operator to permute over.
         qubit_support (tuple): Qubit support.
+        inv (bool): Applies the inverse permutation instead.
 
     Returns:
         Tensor: Permuted operator.
     """
     ordered_support = argsort(qubit_support)
+    ranked_support = argsort(ordered_support)
     n_qubits = len(qubit_support)
-    if all(a == b for a, b in zip(ordered_support, list(range(n_qubits)))):
+    if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))):
         return operator
     batch_size = operator.size(-1)
     operator = operator.view([2] * 2 * n_qubits + [batch_size])
     perm = list(
-        tuple(ordered_support) + tuple(ordered_support + n_qubits) + (2 * n_qubits,)
+        tuple(ranked_support) + tuple(ranked_support + n_qubits) + (2 * n_qubits,)
     )
+
+    if inv:
+        perm = np.argsort(perm).tolist()
+
     return operator.permute(perm).reshape([2**n_qubits, 2**n_qubits, batch_size])
diff --git a/tests/helpers.py b/tests/helpers.py
index 3a1df9f2..cdb30fe9 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from pyqtorch.apply import apply_operator
+from pyqtorch.apply import apply_operator, apply_operator_permute
 from pyqtorch.composite import Add, Scale, Sequence
 from pyqtorch.primitives import (
     OPS_1Q,
@@ -24,6 +24,7 @@ def calc_mat_vec_wavefunction(
     init_state: torch.Tensor,
     values: dict = dict(),
     full_support: tuple | None = None,
+    use_permute: bool = False,
 ) -> torch.Tensor:
     """Get the result of applying the matrix representation of a block
     to an initial state.
@@ -38,10 +39,11 @@ def calc_mat_vec_wavefunction(
     """
     mat = block.tensor(values=values, full_support=full_support)
     qubit_support = block.qubit_support if full_support is None else full_support
-    return apply_operator(
+    apply_func = apply_operator_permute if use_permute else apply_operator
+    return apply_func(
         init_state,
         mat,
-        qubits=qubit_support,
+        qubit_support,
     )
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 503920d2..1dbd2b80 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -49,10 +49,13 @@ def test_digital_tensor(n_qubits: int, batch_size: int, use_full_support: bool)
     assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)
 
 
+@pytest.mark.parametrize("use_permute", [True, False])
 @pytest.mark.parametrize("use_full_support", [True, False])
 @pytest.mark.parametrize("n_qubits", [4, 5])
 @pytest.mark.parametrize("batch_size", [1, 5])
-def test_param_tensor(n_qubits: int, batch_size: int, use_full_support: bool) -> None:
+def test_param_tensor(
+    n_qubits: int, batch_size: int, use_full_support: bool, use_permute: bool
+) -> None:
     """
     Goes through all parametric gates and tests their application to a random state
     in comparison with the `tensor` method, either using just the qubit support of the gate
@@ -68,7 +71,11 @@ def test_param_tensor(n_qubits: int, batch_size: int, use_full_support: bool) ->
     psi_star = op_concrete(psi_init, values)
     full_support = tuple(range(n_qubits)) if use_full_support else None
     psi_expected = calc_mat_vec_wavefunction(
-        op_concrete, psi_init, values=values, full_support=full_support
+        op_concrete,
+        psi_init,
+        values=values,
+        full_support=full_support,
+        use_permute=use_permute,
     )
     assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)
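
As a quick sanity check of the two changes above, here is a minimal sketch (not part of the patch) that exercises the fixed `permute_basis` and the new `apply_operator_permute` directly. It assumes pyqtorch with this patch applied, the operator layout used throughout the diff ([2**n_qubits, 2**n_qubits, batch_size], with the leading kron factor corresponding to the first qubit of the support), and the reading that the forward permutation of `permute_basis` reorders an operator written in sorted qubit order into the given support order, with `inv=True` undoing it (which is what `expand_operator` relies on).

import torch

from pyqtorch.apply import apply_operator, apply_operator_permute
from pyqtorch.utils import permute_basis

I2 = torch.eye(2, dtype=torch.cdouble)
X = torch.tensor([[0, 1], [1, 0]], dtype=torch.cdouble)

# X on qubit 2 of a 3-qubit register, written in sorted qubit order
# (qubit 0 is the leftmost kron factor), with a trailing batch dimension of 1.
op_sorted = torch.kron(torch.kron(I2, I2), X).unsqueeze(-1)

# Reordering the basis to the support order (2, 0, 1) should move X to the
# leftmost factor; the old argsort-based permutation does not, because for a
# 3-cycle the permutation and its inverse differ.
expected_op = torch.kron(torch.kron(X, I2), I2).unsqueeze(-1)
permuted = permute_basis(op_sorted, (2, 0, 1))
assert torch.allclose(permuted, expected_op)

# inv=True applies the inverse permutation and recovers the original operator.
assert torch.allclose(permute_basis(permuted, (2, 0, 1), inv=True), op_sorted)

# apply_operator and apply_operator_permute should agree on a sorted support:
# a CNOT with control qubit 0 and target qubit 2 maps |100> to |101>.
state = torch.zeros([2] * 3 + [1], dtype=torch.cdouble)
state[1, 0, 0, 0] = 1.0  # |100>, batch size 1
cnot = torch.tensor(
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=torch.cdouble
).unsqueeze(-1)
out_einsum = apply_operator(state, cnot, (0, 2))
out_permute = apply_operator_permute(state, cnot, (0, 2))
expected_state = torch.zeros_like(state)
expected_state[1, 0, 1, 0] = 1.0  # |101>
assert torch.allclose(out_einsum, expected_state)
assert torch.allclose(out_permute, expected_state)

For two-qubit supports the permutation and its inverse coincide (a transposition is an involution), so the three-qubit cyclic support above is the smallest case that tells the old and new `permute_basis` apart.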