Skip to content

Commit

Permalink
fix_permute_basis
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmoutinho committed Aug 13, 2024
1 parent bc1b9a3 commit 6f97ef4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyqtorch/composite/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, operations: list[Module]):

@property
def qubit_support(self) -> tuple:
return self._qubit_support
return tuple(sorted(self._qubit_support))

def __iter__(self) -> Iterator:
return iter(self.operations)
Expand Down
14 changes: 10 additions & 4 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def expand_operator(
by explicitly filling in identity matrices on all remaining qubits.
"""
full_support = tuple(sorted(full_support))
qubit_support = tuple(sorted(qubit_support))
if not set(qubit_support).issubset(set(full_support)):
raise ValueError(
"Expanding tensor operation requires a `full_support` argument "
Expand All @@ -405,7 +406,7 @@ def expand_operator(
qubit_support += (i,)
other = IMAT.clone().to(device=device, dtype=dtype).unsqueeze(2)
operator = torch.kron(operator.contiguous(), other)
operator = permute_basis(operator, qubit_support)
operator = permute_basis(operator, qubit_support, inv=True)
return operator


Expand Down Expand Up @@ -445,7 +446,7 @@ def promote_operator(operator: Tensor, target: int, n_qubits: int) -> Tensor:
return operator


def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor:
def permute_basis(operator: Tensor, qubit_support: tuple, inv=False) -> Tensor:
"""Takes an operator tensor and permutes the rows and
columns according to the order of the qubit support.
Expand All @@ -457,15 +458,20 @@ def permute_basis(operator: Tensor, qubit_support: tuple) -> Tensor:
Tensor: Permuted operator.
"""
ordered_support = argsort(qubit_support)
ranked_support = argsort(ordered_support)
n_qubits = len(qubit_support)
if all(a == b for a, b in zip(ordered_support, list(range(n_qubits)))):
if all(a == b for a, b in zip(ranked_support, list(range(n_qubits)))):
return operator
batch_size = operator.size(-1)
operator = operator.view([2] * 2 * n_qubits + [batch_size])

perm = list(
tuple(ordered_support) + tuple(ordered_support + n_qubits) + (2 * n_qubits,)
tuple(ranked_support) + tuple(ranked_support + n_qubits) + (2 * n_qubits,)
)

if inv:
perm = np.argsort(perm).tolist()

return operator.permute(perm).reshape([2**n_qubits, 2**n_qubits, batch_size])


Expand Down

0 comments on commit 6f97ef4

Please sign in to comment.