From 42ee0e7a76bec8e62750dab5fa1ee52ca1956774 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Wed, 20 Nov 2024 13:32:14 +0100 Subject: [PATCH] singladispatchmethod --- pyqtorch/api.py | 2 +- pyqtorch/circuit.py | 2 +- pyqtorch/noise/readout.py | 75 +++++++++++++++++++++++++-------------- tests/test_readout.py | 4 +-- 4 files changed, 52 insertions(+), 31 deletions(-) diff --git a/pyqtorch/api.py b/pyqtorch/api.py index 6b01b6bd..9e2b530b 100644 --- a/pyqtorch/api.py +++ b/pyqtorch/api.py @@ -174,7 +174,7 @@ def sampled_expectation( eigvec_state_prod = torch.flatten(eigvec_state_prod, start_dim=0, end_dim=-2).t() probs = torch.pow(torch.abs(eigvec_state_prod), 2) if circuit.readout_noise is not None: - batch_samples = circuit.readout_noise.apply_on_probas(probs, n_shots) + batch_samples = circuit.readout_noise.apply(probs, n_shots) batch_sample_multinomial = torch.func.vmap( lambda p: sample_multinomial( diff --git a/pyqtorch/circuit.py b/pyqtorch/circuit.py index 2007733d..f48d521a 100644 --- a/pyqtorch/circuit.py +++ b/pyqtorch/circuit.py @@ -109,7 +109,7 @@ def sample( if self.readout_noise is None: return counters - return self.readout_noise.apply_on_counts(counters, n_shots) + return self.readout_noise.apply(counters, n_shots) class DropoutQuantumCircuit(QuantumCircuit): diff --git a/pyqtorch/noise/readout.py b/pyqtorch/noise/readout.py index bc8a48a4..ccf894b3 100644 --- a/pyqtorch/noise/readout.py +++ b/pyqtorch/noise/readout.py @@ -3,14 +3,13 @@ from abc import ABC from collections import Counter from enum import Enum +from functools import singledispatchmethod from math import log import torch from torch import Tensor from torch.distributions import normal, poisson, uniform -from pyqtorch.utils import OrderedCounter - class WhiteNoise(Enum): """White noise distributions.""" @@ -159,12 +158,8 @@ def create_confusion_matrices(noise_matrix: Tensor, error_probability: float) -> class ReadoutInterface(ABC): - def apply_on_probas(self, batch_probs, n_shots: int) -> Tensor: - raise NotImplementedError - - def apply_on_counts( - self, counters: list[Counter | OrderedCounter], n_shots - ) -> list[Counter]: + @singledispatchmethod + def apply(self, input_to_corrupt, n_shots: int): raise NotImplementedError @@ -282,25 +277,30 @@ def create_noise_matrix(self, n_shots: int) -> Tensor: self.confusion_matrix = confusion_matrices return noise_matrix - def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor: + @singledispatchmethod + def apply(self, input_to_corrupt, n_shots: int): + raise NotImplementedError + + @apply.register + def _(self, input_to_corrupt: Tensor, n_shots: int) -> Tensor: """Apply confusion matrix on probabilities. Args: - batch_probs (Tensor): Batch of probability vectors. - n_shots (int, optional): Number of shots. Defaults to 1000. + input_to_corrupt (Tensor): Batch of probability vectors. + n_shots (int, optional): Number of shots. Returns: Tensor: Corrupted probabilities. """ # Create binary representations - n_states = batch_probs.shape[1] + n_states = input_to_corrupt.shape[1] # Create binary representation of all states - state_indices = torch.arange(n_states, device=batch_probs.device) + state_indices = torch.arange(n_states, device=input_to_corrupt.device) binary_repr = ( state_indices.unsqueeze(1) - >> torch.arange(self.n_qubits - 1, -1, -1, device=batch_probs.device) + >> torch.arange(self.n_qubits - 1, -1, -1, device=input_to_corrupt.device) ) & 1 # Get input and output bits for all qubits at once @@ -314,21 +314,20 @@ def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor: # Index into confusion matrix for all qubits at once # Shape: (n_states_out, n_states_in, n_qubits) qubit_transitions = confusion_matrices[ - torch.arange(self.n_qubits, device=batch_probs.device), + torch.arange(self.n_qubits, device=input_to_corrupt.device), output_bits, input_bits, ] transition_matrix = torch.prod(qubit_transitions, dim=-1) - output_probs = torch.matmul(batch_probs, transition_matrix.T) + output_probs = torch.matmul(input_to_corrupt, transition_matrix.T) return output_probs - def apply_on_counts( - self, counters: list[Counter | OrderedCounter], n_shots: int = 1000 - ) -> list[Counter]: + @apply.register + def _(self, input_to_corrupt: list, n_shots: int) -> list[Counter]: """Apply readout on counters represented as Counters. Args: - counters (list[Counter | OrderedCounter]): Samples of bit string as Counters. + input_to_corrupt (list[Counter | OrderedCounter]): Samples of bit string as Counters. n_shots (int, optional): Number of shots to sample. Defaults to 1000. Returns: @@ -338,7 +337,7 @@ def apply_on_counts( err_idx = torch.as_tensor(noise_matrix < self.error_probability) corrupted_bitstrings = [] - for counter in counters: + for counter in input_to_corrupt: sample = sample_to_matrix(counter) corrupted_bitstrings.append( bs_bitflip_corruption(err_idx=err_idx, sample=sample) @@ -365,17 +364,39 @@ def __init__( self.n_qubits = int(log(confusion_matrix.size(0), 2)) self.seed = seed - def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor: - output_probs = batch_probs @ self.confusion_matrix.T + @singledispatchmethod + def apply(self, input_to_corrupt, n_shots: int): + raise NotImplementedError + + @apply.register + def _(self, input_to_corrupt: Tensor, n_shots: int) -> Tensor: + """Apply confusion matrix on probabilities. + + Args: + input_to_corrupt (Tensor): Batch of probability vectors. + n_shots (int, optional): Number of shots. + + Returns: + Tensor: Corrupted probabilities. + """ + output_probs = input_to_corrupt @ self.confusion_matrix.T return output_probs - def apply_on_counts( - self, counters: list[Counter | OrderedCounter], n_shots: int = 1000 - ) -> list[Counter]: + @apply.register + def _(self, input_to_corrupt: list, n_shots: int) -> list[Counter]: + """Apply readout on counters represented as Counters. + + Args: + input_to_corrupt (list[Counter | OrderedCounter]): Samples of bit string as Counters. + n_shots (int, optional): Number of shots to sample. Defaults to 1000. + + Returns: + list[Counter]: Samples of corrupted bit strings + """ if self.seed is not None: torch.manual_seed(self.seed) corrupted_bitstrings = [] - for counter in counters: + for counter in input_to_corrupt: sample = sample_to_matrix(counter) corrupted_bitstrings.append( bs_confusion_corruption(self.confusion_matrix, sample) diff --git a/tests/test_readout.py b/tests/test_readout.py index 9b693e3d..3ecd0194 100644 --- a/tests/test_readout.py +++ b/tests/test_readout.py @@ -126,7 +126,7 @@ def test_correlated_readout() -> None: corr_readout = CorrelatedReadoutNoise(confusion_matrix, 0) probas = torch.tensor([[0.4, 0.3, 0.2, 0.1]], dtype=torch.double) - out_probas = corr_readout.apply_on_probas(probas) + out_probas = corr_readout.apply(probas, n_shots=1000) assert torch.allclose( out_probas, torch.tensor([[0.3830, 0.2900, 0.2060, 0.1210]], dtype=torch.float64), @@ -192,7 +192,7 @@ def test_readout_apply_probas() -> None: n_shots = 1000 probas = torch.tensor([[0.4, 0.3, 0.2, 0.1]], dtype=torch.double) readobj = ReadoutNoise(2, seed=0) - out_probas = readobj.apply_on_probas(probas) + out_probas = readobj.apply(probas, n_shots=1000) assert torch.allclose(torch.sum(out_probas), torch.ones(1, dtype=out_probas.dtype)) assert torch.allclose(