Skip to content

Commit

Permalink
[Performance, Refactor] Operator multiplication without explicit iden…
Browse files Browse the repository at this point in the history
…tities (#268)

Refactors the `operator_product` function that does not require explicit
identity padding. Also refactors the `apply_density_mat` into a new
`apply_operator_dm` without explicit identities. The logic of both is
similar, and the logic of `apply_operator_dm` is very similar to
`apply_operator_permute`, but keeping everything separate for now.

The reason not to use `operator_product` inside `apply_operator_dm` is
to avoid permuting the qubits twice.

The new functions are similar in performance for small number of qubits.
For 5 to 9 qubits they become a few times faster. For 12 qubits the
speedup is already ~10x.
  • Loading branch information
jpmoutinho authored Aug 26, 2024
1 parent 59ba67c commit bde458d
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 91 deletions.
135 changes: 103 additions & 32 deletions pyqtorch/apply.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from math import log2
from string import ascii_letters as ABC

from numpy import array
from numpy.typing import NDArray
from torch import Tensor, einsum

from pyqtorch.matrices import _dagger
from pyqtorch.utils import DensityMatrix, permute_state
from pyqtorch.utils import DensityMatrix, permute_basis, permute_state

ABC_ARRAY: NDArray = array(list(ABC))

Expand Down Expand Up @@ -88,45 +89,115 @@ def apply_operator_permute(
return permute_state(result, support_perm, inv=True)


def apply_density_mat(op: Tensor, density_matrix: DensityMatrix) -> DensityMatrix:
def apply_operator_dm(
state: DensityMatrix,
operator: Tensor,
qubit_support: tuple[int, ...] | list[int],
) -> Tensor:
"""
Apply an operator to a density matrix, i.e., compute:
op1 * density_matrix * op1_dagger.
Apply an operator to a density matrix on a given qubit suport, i.e., compute:
OP.DM.OP.dagger()
Args:
op (Tensor): The operator to apply.
density_matrix (DensityMatrix): The density matrix.
state: State to operate on.
operator: Tensor to contract over 'state'.
qubit_support: Tuple of qubits on which to apply the 'operator' to.
Returns:
DensityMatrix: The resulting density matrix after applying the operator and its dagger.
DensityMatrix: The resulting density matrix after applying the operator.
"""
batch_size_op = op.size(-1)
batch_size_dm = density_matrix.size(-1)
if batch_size_dm > batch_size_op:
# The other condition is impossible because
# operators are always initialized with batch_size = 1.
op = op.repeat(1, 1, batch_size_dm)
return einsum("ijb,jkb,klb->ilb", op, density_matrix, _dagger(op))

if not isinstance(state, DensityMatrix):
raise TypeError("Function apply_operator_dm requires a density matrix state.")

n_qubits = int(log2(state.size()[0]))
n_support = len(qubit_support)
batch_size = max(state.size(-1), operator.size(-1))
full_support = tuple(range(n_qubits))
support_perm = tuple(sorted(qubit_support)) + tuple(
set(full_support) - set(qubit_support)
)
state = permute_basis(state, support_perm)

def operator_product(op1: Tensor, op2: Tensor) -> Tensor:
# There is probably a smart way to represent the lines below in a single einsum...
state = state.reshape(
[2**n_support, (2 ** (2 * n_qubits - n_support)), state.size(-1)]
)
state = einsum("ijb,jkb->ikb", operator, state).reshape(
[2**n_qubits, 2**n_qubits, batch_size]
)
state = _dagger(state).reshape(
[2**n_support, (2 ** (2 * n_qubits - n_support)), state.size(-1)]
)
state = _dagger(
einsum("ijb,jkb->ikb", operator, state).reshape(
[2**n_qubits, 2**n_qubits, batch_size]
)
)
return permute_basis(state, support_perm, inv=True)


def operator_product(
op1: Tensor,
supp1: tuple[int, ...],
op2: Tensor,
supp2: tuple[int, ...],
) -> Tensor:
"""
Compute the product of two operators.
`torch.bmm` is not suitable for our purposes because,
in our convention, the batch_size is in the last dimension.
Operator product based on block matrix multiplication.
Args:
op1 (Tensor): The first operator.
op2 (Tensor): The second operator.
Returns:
Tensor: The product of the two operators.
E.g., for some small operator S and a big operator with 4 partitions A, B, C, D:
|S 0|.|A B| = |S.A S.B|
|0 S| |C D| |S.C S.D|
or
|A B|.|S 0| = |A.S B.S|
|C D|.|0 S| |C.S D.S|
The same generalizes for different sizes of the big operator. Note that the block
diagonal matrix is never computed. Instead, the big operator is permuted and
reshaped into a wide matrix:
W = [A B C D]
And then the result is computed as S.W, reshaped back into a square matrix, and
permuted back into the original ordering.
"""
# ? Should we continue to adjust the batch here?
# ? as now all gates are init with batch_size=1.
batch_size_1 = op1.size(-1)
batch_size_2 = op2.size(-1)
if batch_size_1 > batch_size_2:
op2 = op2.repeat(1, 1, batch_size_1)[:, :, :batch_size_1]
elif batch_size_2 > batch_size_1:
op1 = op1.repeat(1, 1, batch_size_2)[:, :, :batch_size_2]
return einsum("ijb,jkb->ikb", op1, op2)

if supp1 == supp2:
return einsum("ijb,jkb->ikb", op1, op2)

if len(supp1) < len(supp2):
small_op, small_supp = op1, supp1
big_op, big_supp = op2, supp2
transpose = False
else:
small_op, small_supp = _dagger(op2), supp2
big_op, big_supp = _dagger(op1), supp1
transpose = True

if not set(small_supp).issubset(set(big_supp)):
raise ValueError("Operator product requires overlapping qubit supports.")

n_big, n_small = len(big_supp), len(small_supp)
batch_big, batch_small = big_op.size(-1), small_op.size(-1)
batch_size = max(batch_big, batch_small)

# Permute the large operator and reshape into a wide matrix
support_perm = tuple(sorted(small_supp)) + tuple(set(big_supp) - set(small_supp))
big_op = permute_basis(big_op, support_perm)
big_op = big_op.reshape([2**n_small, (2 ** (2 * n_big - n_small)), batch_big])

# Compute S.W and reshape back to square
result = einsum("ijb,jkb->ikb", small_op, big_op).reshape(
[2**n_big, 2**n_big, batch_size]
)

# Apply the inverse qubit permutation
if transpose:
return _dagger(permute_basis(result, support_perm, inv=True))
else:
return permute_basis(result, support_perm, inv=True)
41 changes: 6 additions & 35 deletions pyqtorch/noise/gates.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from math import log2, sqrt
from math import sqrt
from typing import Any

import torch
from torch import Tensor

from pyqtorch.apply import apply_density_mat
from pyqtorch.apply import apply_operator_dm
from pyqtorch.embed import Embedding
from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, XMAT, YMAT, ZMAT
from pyqtorch.utils import DensityMatrix, density_mat, qubit_support_as_tuple
Expand All @@ -29,16 +29,14 @@ def __init__(
self.error_probabilities: tuple[float, ...] | float = error_probabilities

def extra_repr(self) -> str:
return f"target: {self.qubit_support}, prob: {self.probabilities}"
return f"target: {self.qubit_support}, prob: {self.error_probabilities}"

@property
def kraus_operators(self) -> list[Tensor]:
return [getattr(self, f"kraus_{i}") for i in range(len(self._buffers))]

def _tensor(
def tensor(
self,
values: dict[str, Tensor] | Tensor = dict(),
embedding: Embedding | None = None,
) -> list[Tensor]:
# Since PyQ expects tensor.Size = [2**n_qubits, 2**n_qubits,batch_size].
return [kraus_op.unsqueeze(2) for kraus_op in self.kraus_operators]
Expand Down Expand Up @@ -69,10 +67,9 @@ def forward(
"""
if not isinstance(state, DensityMatrix):
state = density_mat(state)
n_qubits = int(log2(state.size(1)))
rho_evols: list[Tensor] = []
for kraus in self.tensor(values, n_qubits):
rho_evol: Tensor = apply_density_mat(kraus, state)
for kraus in self.tensor():
rho_evol: Tensor = apply_operator_dm(state, kraus, self.qubit_support)
rho_evols.append(rho_evol)
rho_final: Tensor = torch.stack(rho_evols, dim=0)
rho_final = torch.sum(rho_final, dim=0)
Expand All @@ -92,32 +89,6 @@ def to(self, *args: Any, **kwargs: Any) -> Noise:
self._dtype = self.kraus_0.dtype
return self

def tensor(
self, values: dict[str, Tensor] = dict(), n_qubits: int = 1
) -> list[Tensor]:
block_mats = self._tensor(values)
mats = []
for blockmat in block_mats:
if n_qubits == 1:
mats.append(blockmat)
else:
full_sup = tuple(i for i in range(n_qubits))
support = tuple(sorted(self.qubit_support))
mat = (
IMAT.clone().to(self.device, self.dtype).unsqueeze(2)
if support[0] != full_sup[0]
else blockmat
)
for i in full_sup[1:]:
if i == support[0]:
other = blockmat
mat = torch.kron(mat.contiguous(), other.contiguous())
elif i not in support:
other = IMAT.clone().to(self.device, self.dtype).unsqueeze(2)
mat = torch.kron(mat.contiguous(), other.contiguous())
mats.append(mat)
return mats


class BitFlip(Noise):
"""
Expand Down
18 changes: 9 additions & 9 deletions pyqtorch/quantum_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from torch import Tensor

from pyqtorch.apply import apply_density_mat, apply_operator
from pyqtorch.apply import apply_operator, apply_operator_dm
from pyqtorch.embed import Embedding
from pyqtorch.matrices import _dagger
from pyqtorch.noise import NoiseProtocol, _repr_noise
Expand Down Expand Up @@ -344,10 +344,8 @@ def _forward(
embedding: Embedding | None = None,
) -> Tensor:
if isinstance(state, DensityMatrix):
n_qubits = int(log2(state.size(1)))
full_support = tuple(range(n_qubits))
return apply_density_mat(
self.tensor(values, embedding, full_support=full_support), state
return apply_operator_dm(
state, self.tensor(values, embedding), self.qubit_support
)
else:
return apply_operator(
Expand All @@ -362,13 +360,14 @@ def _noise_forward(
values: dict[str, Tensor] | Tensor = dict(),
embedding: Embedding | None = None,
) -> Tensor:

if not isinstance(state, DensityMatrix):
state = density_mat(state)
n_qubits = int(log2(state.size(1)))
full_support = tuple(range(n_qubits))
state = apply_density_mat(
self.tensor(values, embedding, full_support=full_support), state

state = apply_operator_dm(
state, self.tensor(values, embedding), self.qubit_support
)

if isinstance(self.noise, dict):
for noise_instance in self.noise.values():
protocol = noise_instance.protocol_to_gate()
Expand All @@ -382,6 +381,7 @@ def _noise_forward(
)
state = noise_gate(state, values)
return state

elif isinstance(self.noise, NoiseProtocol):
protocol = self.noise.protocol_to_gate()
noise_gate = protocol(
Expand Down
27 changes: 16 additions & 11 deletions tests/test_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import Tensor

from pyqtorch.apply import apply_density_mat, operator_product
from pyqtorch.apply import apply_operator_dm, operator_product
from pyqtorch.circuit import QuantumCircuit
from pyqtorch.matrices import (
HMAT,
Expand Down Expand Up @@ -72,20 +72,22 @@ def test_operator_product(
random_unitary_gate: Primitive | Parametric,
n_qubits: int,
) -> None:
batch_size_1 = torch.randint(low=1, high=5, size=(1,)).item()
batch_size_2 = torch.randint(low=1, high=5, size=(1,)).item()
max_batch = max(batch_size_2, batch_size_1)
batch_size = torch.randint(low=2, high=5, size=(1,)).item()
values = {"theta": torch.rand(1)}
op = random_unitary_gate.tensor(values=values, full_support=tuple(range(n_qubits)))
full_support = tuple(range(n_qubits))
op = random_unitary_gate.tensor(values=values, full_support=full_support)
op_mul = operator_product(
op.repeat(1, 1, batch_size_1), _dagger(op.repeat(1, 1, batch_size_2))
op1=op.repeat(1, 1, batch_size),
supp1=full_support,
op2=_dagger(op),
supp2=full_support,
)
assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, max_batch])
assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
assert torch.allclose(
op_mul,
torch.eye(2**n_qubits, dtype=torch.cdouble)
.unsqueeze(2)
.repeat(1, 1, max_batch),
.repeat(1, 1, batch_size),
)


Expand All @@ -97,11 +99,14 @@ def test_apply_density_mat(
random_input_dm: DensityMatrix,
) -> None:
values = {"theta": torch.rand(1)}
op = random_unitary_gate.tensor(values=values, full_support=tuple(range(n_qubits)))
full_support = tuple(range(n_qubits))
op = random_unitary_gate
rho = random_input_dm
rho_evol = apply_density_mat(op, rho)
op_mat = op.tensor(values=values)
rho_evol = apply_operator_dm(rho, op_mat, op.qubit_support)
assert rho_evol.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
rho_expected = operator_product(op, operator_product(rho, _dagger(op)))
mul1 = operator_product(rho, full_support, _dagger(op_mat), op.qubit_support)
rho_expected = operator_product(op_mat, op.qubit_support, mul1, full_support)
assert torch.allclose(rho_evol, rho_expected)


Expand Down
Loading

0 comments on commit bde458d

Please sign in to comment.