From 6f97ef42b8fc3d3b57ffb72740bb958bc32ca6ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20P=2E=20Moutinho?= <56390829+jpmoutinho@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:01:01 +0200 Subject: [PATCH] fix_permute_basis --- pyqtorch/composite/sequence.py | 2 +- pyqtorch/utils.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pyqtorch/composite/sequence.py b/pyqtorch/composite/sequence.py index 97174300..02bda51e 100644 --- a/pyqtorch/composite/sequence.py +++ b/pyqtorch/composite/sequence.py @@ -85,7 +85,7 @@ def __init__(self, operations: list[Module]): @property def qubit_support(self) -> tuple: - return self._qubit_support + return tuple(sorted(self._qubit_support)) def __iter__(self) -> Iterator: return iter(self.operations) diff --git a/pyqtorch/utils.py b/pyqtorch/utils.py index 1e178337..11b141a5 100644 --- a/pyqtorch/utils.py +++ b/pyqtorch/utils.py @@ -395,6 +395,7 @@ def expand_operator( by explicitly filling in identity matrices on all remaining qubits. """ full_support = tuple(sorted(full_support)) + qubit_support = tuple(sorted(qubit_support)) if not set(qubit_support).issubset(set(full_support)): raise ValueError( "Expanding tensor operation requires a `full_support` argument " @@ -405,7 +406,7 @@ def expand_operator( qubit_support += (i,) other = IMAT.clone().to(device=device, dtype=dtype).unsqueeze(2) operator = torch.kron(operator.contiguous(), other) - operator = permute_basis(operator, qubit_support) + operator = permute_basis(operator, qubit_support, inv=True) return operator @@ -445,7 +446,7 @@ def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor: return operator -def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor: +def permute_basis(operator: Tensor, qubit_support: tuple, inv=False) -> Tensor: """Takes an operator tensor and permutes the rows and columns according to the order of the qubit support. @@ -457,15 +458,20 @@ def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor: Tensor: Permuted operator. """ ordered_support = argsort(qubit_support) + ranked_support = argsort(ordered_support) n_qubits = len(qubit_support) - if all(a == b for a, b in zip(ordered_support, list(range(n_qubits)))): + if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))): return operator batch_size = operator.size(-1) operator = operator.view([2] * 2 * n_qubits + [batch_size]) perm = list( - tuple(ordered_support) + tuple(ordered_support + n_qubits) + (2 * n_qubits,) + tuple(ranked_support) + tuple(ranked_support + n_qubits) + (2 * n_qubits,) ) + + if inv: + perm = np.argsort(perm).tolist() + return operator.permute(perm).reshape([2**n_qubits, 2**n_qubits, batch_size])