From bde458d33da82886a95950cd6bb7f009c759e6ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20P=2E=20Moutinho?=
Date: Mon, 26 Aug 2024 10:18:49 +0200
Subject: [PATCH] [Performance, Refactor] Operator multiplication without explicit identities (#268)

Refactors the `operator_product` function so that it no longer requires
explicit identity padding. Also refactors `apply_density_mat` into a new
`apply_operator_dm`, again without explicit identities. The logic of both
is similar, and the logic of `apply_operator_dm` closely mirrors
`apply_operator_permute`, but everything is kept separate for now. The
reason not to use `operator_product` inside `apply_operator_dm` is to
avoid permuting the qubits twice.

The new functions perform similarly for a small number of qubits. For 5
to 9 qubits they become a few times faster, and for 12 qubits the speedup
is already ~10x.
---
 pyqtorch/apply.py             | 135 ++++++++++++++++++++++++++--------
 pyqtorch/noise/gates.py       |  41 ++---------
 pyqtorch/quantum_operation.py |  18 ++---
 tests/test_noise.py           |  27 ++++---
 tests/test_tensor.py          |  29 +++++++-
 5 files changed, 159 insertions(+), 91 deletions(-)

diff --git a/pyqtorch/apply.py b/pyqtorch/apply.py
index aed8bc0b..29bae46e 100644
--- a/pyqtorch/apply.py
+++ b/pyqtorch/apply.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from math import log2
 from string import ascii_letters as ABC
 
 from numpy import array
@@ -7,7 +8,7 @@
 from torch import Tensor, einsum
 
 from pyqtorch.matrices import _dagger
-from pyqtorch.utils import DensityMatrix, permute_state
+from pyqtorch.utils import DensityMatrix, permute_basis, permute_state
 
 ABC_ARRAY: NDArray = array(list(ABC))
 
@@ -88,45 +89,115 @@ def apply_operator_permute(
     return permute_state(result, support_perm, inv=True)
 
 
-def apply_density_mat(op: Tensor, density_matrix: DensityMatrix) -> DensityMatrix:
+def apply_operator_dm(
+    state: DensityMatrix,
+    operator: Tensor,
+    qubit_support: tuple[int, ...] | list[int],
+) -> Tensor:
     """
-    Apply an operator to a density matrix, i.e., compute:
-    op1 * density_matrix * op1_dagger.
+    Apply an operator to a density matrix on a given qubit support, i.e., compute:
+
+    OP.DM.OP.dagger()
 
     Args:
-        op (Tensor): The operator to apply.
-        density_matrix (DensityMatrix): The density matrix.
+        state: State to operate on.
+        operator: Tensor to contract over 'state'.
+        qubit_support: Tuple of qubits on which to apply the 'operator'.
 
     Returns:
-        DensityMatrix: The resulting density matrix after applying the operator and its dagger.
+        DensityMatrix: The resulting density matrix after applying the operator.
     """
-    batch_size_op = op.size(-1)
-    batch_size_dm = density_matrix.size(-1)
-    if batch_size_dm > batch_size_op:
-        # The other condition is impossible because
-        # operators are always initialized with batch_size = 1.
-        op = op.repeat(1, 1, batch_size_dm)
-    return einsum("ijb,jkb,klb->ilb", op, density_matrix, _dagger(op))
+    if not isinstance(state, DensityMatrix):
+        raise TypeError("Function apply_operator_dm requires a density matrix state.")
+
+    n_qubits = int(log2(state.size()[0]))
+    n_support = len(qubit_support)
+    batch_size = max(state.size(-1), operator.size(-1))
+    full_support = tuple(range(n_qubits))
+    support_perm = tuple(sorted(qubit_support)) + tuple(
+        set(full_support) - set(qubit_support)
+    )
+    state = permute_basis(state, support_perm)
 
-def operator_product(op1: Tensor, op2: Tensor) -> Tensor:
+    # There is probably a smart way to represent the lines below in a single einsum...
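+    # Note: the four reshape/einsum steps below compute OP.DM.OP.dagger() in
+    # two passes. Each pass flattens the permuted state into a wide matrix of
+    # shape (2**n_support, 2**(2*n_qubits - n_support), batch) and contracts
+    # 'operator' into the row index; the daggers in between make the second
+    # pass act as the right-multiplication by the operator's adjoint.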
+    state = state.reshape(
+        [2**n_support, (2 ** (2 * n_qubits - n_support)), state.size(-1)]
+    )
+    state = einsum("ijb,jkb->ikb", operator, state).reshape(
+        [2**n_qubits, 2**n_qubits, batch_size]
+    )
+    state = _dagger(state).reshape(
+        [2**n_support, (2 ** (2 * n_qubits - n_support)), state.size(-1)]
+    )
+    state = _dagger(
+        einsum("ijb,jkb->ikb", operator, state).reshape(
+            [2**n_qubits, 2**n_qubits, batch_size]
+        )
+    )
+    return permute_basis(state, support_perm, inv=True)
+
+
+def operator_product(
+    op1: Tensor,
+    supp1: tuple[int, ...],
+    op2: Tensor,
+    supp2: tuple[int, ...],
+) -> Tensor:
     """
-    Compute the product of two operators.
-    `torch.bmm` is not suitable for our purposes because,
-    in our convention, the batch_size is in the last dimension.
+    Operator product based on block matrix multiplication.
 
-    Args:
-        op1 (Tensor): The first operator.
-        op2 (Tensor): The second operator.
-    Returns:
-        Tensor: The product of the two operators.
+
+    E.g., for some small operator S and a big operator with 4 partitions A, B, C, D:
+
+    |S 0|.|A B| = |S.A S.B|
+    |0 S| |C D|   |S.C S.D|
+
+    or
+
+    |A B|.|S 0| = |A.S B.S|
+    |C D| |0 S|   |C.S D.S|
+
+    The same generalizes to other sizes of the big operator. Note that the block
+    diagonal matrix is never computed. Instead, the big operator is permuted and
+    reshaped into a wide matrix:
+
+    W = [A B C D]
+
+    The result is then computed as S.W, reshaped back into a square matrix, and
+    permuted back into the original ordering.
     """
-    # ? Should we continue to adjust the batch here?
-    # ? as now all gates are init with batch_size=1.
-    batch_size_1 = op1.size(-1)
-    batch_size_2 = op2.size(-1)
-    if batch_size_1 > batch_size_2:
-        op2 = op2.repeat(1, 1, batch_size_1)[:, :, :batch_size_1]
-    elif batch_size_2 > batch_size_1:
-        op1 = op1.repeat(1, 1, batch_size_2)[:, :, :batch_size_2]
-    return einsum("ijb,jkb->ikb", op1, op2)
+
+    if supp1 == supp2:
+        return einsum("ijb,jkb->ikb", op1, op2)
+
+    if len(supp1) < len(supp2):
+        small_op, small_supp = op1, supp1
+        big_op, big_supp = op2, supp2
+        transpose = False
+    else:
+        small_op, small_supp = _dagger(op2), supp2
+        big_op, big_supp = _dagger(op1), supp1
+        transpose = True
+
+    if not set(small_supp).issubset(set(big_supp)):
+        raise ValueError("Operator product requires overlapping qubit supports.")
+
+    n_big, n_small = len(big_supp), len(small_supp)
+    batch_big, batch_small = big_op.size(-1), small_op.size(-1)
+    batch_size = max(batch_big, batch_small)
+
+    # Permute the large operator and reshape into a wide matrix
+    support_perm = tuple(sorted(small_supp)) + tuple(set(big_supp) - set(small_supp))
+    big_op = permute_basis(big_op, support_perm)
+    big_op = big_op.reshape([2**n_small, (2 ** (2 * n_big - n_small)), batch_big])
+
+    # Compute S.W and reshape back to square
+    result = einsum("ijb,jkb->ikb", small_op, big_op).reshape(
+        [2**n_big, 2**n_big, batch_size]
+    )
+
+    # Apply the inverse qubit permutation
+    if transpose:
+        return _dagger(permute_basis(result, support_perm, inv=True))
+    else:
+        return permute_basis(result, support_perm, inv=True)
diff --git a/pyqtorch/noise/gates.py b/pyqtorch/noise/gates.py
index 4cb87a57..08098d2e 100644
--- a/pyqtorch/noise/gates.py
+++ b/pyqtorch/noise/gates.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
-from math import log2, sqrt
+from math import sqrt
 from typing import Any
 
 import torch
 from torch import Tensor
 
-from pyqtorch.apply import apply_density_mat
+from pyqtorch.apply import apply_operator_dm
 from pyqtorch.embed import Embedding
 from pyqtorch.matrices import DEFAULT_MATRIX_DTYPE, IMAT, XMAT, YMAT, ZMAT
 from pyqtorch.utils import DensityMatrix, density_mat, qubit_support_as_tuple
@@ -29,16 +29,14 @@ def __init__(
         self.error_probabilities: tuple[float, ...] | float = error_probabilities
 
     def extra_repr(self) -> str:
-        return f"target: {self.qubit_support}, prob: {self.probabilities}"
+        return f"target: {self.qubit_support}, prob: {self.error_probabilities}"
 
     @property
     def kraus_operators(self) -> list[Tensor]:
         return [getattr(self, f"kraus_{i}") for i in range(len(self._buffers))]
 
-    def _tensor(
+    def tensor(
         self,
-        values: dict[str, Tensor] | Tensor = dict(),
-        embedding: Embedding | None = None,
     ) -> list[Tensor]:
         # Since PyQ expects tensor.Size = [2**n_qubits, 2**n_qubits,batch_size].
         return [kraus_op.unsqueeze(2) for kraus_op in self.kraus_operators]
@@ -69,10 +67,9 @@ def forward(
         """
         if not isinstance(state, DensityMatrix):
             state = density_mat(state)
-        n_qubits = int(log2(state.size(1)))
         rho_evols: list[Tensor] = []
-        for kraus in self.tensor(values, n_qubits):
-            rho_evol: Tensor = apply_density_mat(kraus, state)
+        for kraus in self.tensor():
+            rho_evol: Tensor = apply_operator_dm(state, kraus, self.qubit_support)
             rho_evols.append(rho_evol)
         rho_final: Tensor = torch.stack(rho_evols, dim=0)
         rho_final = torch.sum(rho_final, dim=0)
@@ -92,32 +89,6 @@ def to(self, *args: Any, **kwargs: Any) -> Noise:
         self._dtype = self.kraus_0.dtype
         return self
 
-    def tensor(
-        self, values: dict[str, Tensor] = dict(), n_qubits: int = 1
-    ) -> list[Tensor]:
-        block_mats = self._tensor(values)
-        mats = []
-        for blockmat in block_mats:
-            if n_qubits == 1:
-                mats.append(blockmat)
-            else:
-                full_sup = tuple(i for i in range(n_qubits))
-                support = tuple(sorted(self.qubit_support))
-                mat = (
-                    IMAT.clone().to(self.device, self.dtype).unsqueeze(2)
-                    if support[0] != full_sup[0]
-                    else blockmat
-                )
-                for i in full_sup[1:]:
-                    if i == support[0]:
-                        other = blockmat
-                        mat = torch.kron(mat.contiguous(), other.contiguous())
-                    elif i not in support:
-                        other = IMAT.clone().to(self.device, self.dtype).unsqueeze(2)
-                        mat = torch.kron(mat.contiguous(), other.contiguous())
-                mats.append(mat)
-        return mats
-
 
 class BitFlip(Noise):
     """
diff --git a/pyqtorch/quantum_operation.py b/pyqtorch/quantum_operation.py
index 99fe50c8..f9c8660a 100644
--- a/pyqtorch/quantum_operation.py
+++ b/pyqtorch/quantum_operation.py
@@ -9,7 +9,7 @@
 import torch
 from torch import Tensor
 
-from pyqtorch.apply import apply_density_mat, apply_operator
+from pyqtorch.apply import apply_operator, apply_operator_dm
 from pyqtorch.embed import Embedding
 from pyqtorch.matrices import _dagger
 from pyqtorch.noise import NoiseProtocol, _repr_noise
@@ -344,10 +344,8 @@ def _forward(
         embedding: Embedding | None = None,
     ) -> Tensor:
         if isinstance(state, DensityMatrix):
-            n_qubits = int(log2(state.size(1)))
-            full_support = tuple(range(n_qubits))
-            return apply_density_mat(
-                self.tensor(values, embedding, full_support=full_support), state
+            return apply_operator_dm(
+                state, self.tensor(values, embedding), self.qubit_support
             )
         else:
             return apply_operator(
@@ -362,13 +360,14 @@ def _noise_forward(
         values: dict[str, Tensor] | Tensor = dict(),
         embedding: Embedding | None = None,
     ) -> Tensor:
+
         if not isinstance(state, DensityMatrix):
             state = density_mat(state)
-        n_qubits = int(log2(state.size(1)))
-        full_support = tuple(range(n_qubits))
-        state = apply_density_mat(
-            self.tensor(values, embedding, full_support=full_support), state
+
+        state = apply_operator_dm(
+            state, self.tensor(values, embedding), self.qubit_support
         )
+
         if isinstance(self.noise, dict):
             for noise_instance in self.noise.values():
                 protocol = noise_instance.protocol_to_gate()
@@ -382,6 +381,7 @@ def _noise_forward(
                 )
                 state = noise_gate(state, values)
             return state
+
         elif isinstance(self.noise, NoiseProtocol):
             protocol = self.noise.protocol_to_gate()
             noise_gate = protocol(
diff --git a/tests/test_noise.py b/tests/test_noise.py
index 67c4b1b8..5a0cbb64 100644
--- a/tests/test_noise.py
+++ b/tests/test_noise.py
@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor
 
-from pyqtorch.apply import apply_density_mat, operator_product
+from pyqtorch.apply import apply_operator_dm, operator_product
 from pyqtorch.circuit import QuantumCircuit
 from pyqtorch.matrices import (
     HMAT,
@@ -72,20 +72,22 @@ def test_operator_product(
     random_unitary_gate: Primitive | Parametric,
     n_qubits: int,
 ) -> None:
-    batch_size_1 = torch.randint(low=1, high=5, size=(1,)).item()
-    batch_size_2 = torch.randint(low=1, high=5, size=(1,)).item()
-    max_batch = max(batch_size_2, batch_size_1)
+    batch_size = torch.randint(low=2, high=5, size=(1,)).item()
     values = {"theta": torch.rand(1)}
-    op = random_unitary_gate.tensor(values=values, full_support=tuple(range(n_qubits)))
+    full_support = tuple(range(n_qubits))
+    op = random_unitary_gate.tensor(values=values, full_support=full_support)
     op_mul = operator_product(
-        op.repeat(1, 1, batch_size_1), _dagger(op.repeat(1, 1, batch_size_2))
+        op1=op.repeat(1, 1, batch_size),
+        supp1=full_support,
+        op2=_dagger(op),
+        supp2=full_support,
     )
-    assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, max_batch])
+    assert op_mul.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
     assert torch.allclose(
         op_mul,
         torch.eye(2**n_qubits, dtype=torch.cdouble)
         .unsqueeze(2)
-        .repeat(1, 1, max_batch),
+        .repeat(1, 1, batch_size),
     )
 
 
@@ -97,11 +99,14 @@ def test_apply_density_mat(
     random_input_dm: DensityMatrix,
 ) -> None:
     values = {"theta": torch.rand(1)}
-    op = random_unitary_gate.tensor(values=values, full_support=tuple(range(n_qubits)))
+    full_support = tuple(range(n_qubits))
+    op = random_unitary_gate
     rho = random_input_dm
-    rho_evol = apply_density_mat(op, rho)
+    op_mat = op.tensor(values=values)
+    rho_evol = apply_operator_dm(rho, op_mat, op.qubit_support)
     assert rho_evol.size() == torch.Size([2**n_qubits, 2**n_qubits, batch_size])
-    rho_expected = operator_product(op, operator_product(rho, _dagger(op)))
+    mul1 = operator_product(rho, full_support, _dagger(op_mat), op.qubit_support)
+    rho_expected = operator_product(op_mat, op.qubit_support, mul1, full_support)
     assert torch.allclose(rho_evol, rho_expected)
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index f98ca8ae..fee44fbb 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -32,6 +32,7 @@
 from pyqtorch.utils import (
     ATOL,
     RTOL,
+    density_mat,
     permute_basis,
     random_state,
 )
@@ -39,12 +40,17 @@
 pi = torch.tensor(torch.pi)
 
 
+@pytest.mark.parametrize("use_dm", [True, False])
 @pytest.mark.parametrize("use_permute", [True, False])
 @pytest.mark.parametrize("use_full_support", [True, False])
 @pytest.mark.parametrize("n_qubits", [4, 5])
 @pytest.mark.parametrize("batch_size", [1, 5])
 def test_digital_tensor(
-    n_qubits: int, batch_size: int, use_full_support: bool, use_permute: bool
+    n_qubits: int,
+    batch_size: int,
+    use_full_support: bool,
+    use_permute: bool,
+    use_dm: bool,
 ) -> None:
     """
     Goes through all non-parametric gates and tests their application to a random state
@@ -56,20 +62,30 @@
         supp = get_op_support(op, n_qubits)
         op_concrete = op(*supp)
         psi_init = random_state(n_qubits, batch_size)
-        psi_star = op_concrete(psi_init)
+        if use_dm:
+            psi_star = op_concrete(density_mat(psi_init))
+        else:
+            psi_star = op_concrete(psi_init)
         full_support = tuple(range(n_qubits)) if use_full_support else None
         psi_expected = calc_mat_vec_wavefunction(
             op_concrete, psi_init, full_support=full_support, use_permute=use_permute
         )
+        if use_dm:
+            psi_expected = density_mat(psi_expected)
         assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)
 
 
+@pytest.mark.parametrize("use_dm", [True, False])
 @pytest.mark.parametrize("use_permute", [True, False])
 @pytest.mark.parametrize("use_full_support", [True, False])
 @pytest.mark.parametrize("n_qubits", [4, 5])
 @pytest.mark.parametrize("batch_size", [1, 5])
 def test_param_tensor(
-    n_qubits: int, batch_size: int, use_full_support: bool, use_permute: bool
+    n_qubits: int,
+    batch_size: int,
+    use_full_support: bool,
+    use_permute: bool,
+    use_dm: bool,
 ) -> None:
     """
     Goes through all parametric gates and tests their application to a random state
@@ -83,7 +99,10 @@
         op_concrete = op(*supp, *params)
         psi_init = random_state(n_qubits)
         values = {param: torch.rand(batch_size) for param in params}
-        psi_star = op_concrete(psi_init, values)
+        if use_dm:
+            psi_star = op_concrete(density_mat(psi_init), values)
+        else:
+            psi_star = op_concrete(psi_init, values)
         full_support = tuple(range(n_qubits)) if use_full_support else None
         psi_expected = calc_mat_vec_wavefunction(
             op_concrete,
             psi_init,
             values,
             full_support=full_support,
             use_permute=use_permute,
         )
@@ -92,6 +111,8 @@
+        if use_dm:
+            psi_expected = density_mat(psi_expected)
         assert torch.allclose(psi_star, psi_expected, rtol=RTOL, atol=ATOL)
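
For intuition, the block-multiplication identity that `operator_product` relies on can be checked with a few lines of standalone torch code. This is a minimal sketch, not part of the patch: the names `S` and `B`, the 1-qubit-on-leading-qubit layout, and the omission of the trailing batch dimension are all illustrative assumptions.

    import torch

    # Sketch: a 1-qubit operator S acting on the leading qubit of a 2-qubit
    # operator B. Explicit identity padding, kron(S, I) @ B, should match S
    # multiplied into B reshaped as a wide matrix (the docstring's
    # W = [A B C D], up to row ordering), with no kron ever computed.
    S = torch.randn(2, 2, dtype=torch.cdouble)  # small operator
    B = torch.randn(4, 4, dtype=torch.cdouble)  # big operator

    padded = torch.kron(S, torch.eye(2, dtype=torch.cdouble)) @ B
    wide = (S @ B.reshape(2, 8)).reshape(4, 4)  # block trick

    assert torch.allclose(padded, wide)

For a small support that is not already the leading set of qubits, the same check holds after first permuting the big operator's basis, which is exactly the `permute_basis` step in the patched `operator_product`.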