Skip to content

Commit

Permalink
singladispatchmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Nov 20, 2024
1 parent c557ca2 commit 42ee0e7
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyqtorch/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def sampled_expectation(
eigvec_state_prod = torch.flatten(eigvec_state_prod, start_dim=0, end_dim=-2).t()
probs = torch.pow(torch.abs(eigvec_state_prod), 2)
if circuit.readout_noise is not None:
batch_samples = circuit.readout_noise.apply_on_probas(probs, n_shots)
batch_samples = circuit.readout_noise.apply(probs, n_shots)

batch_sample_multinomial = torch.func.vmap(
lambda p: sample_multinomial(
Expand Down
2 changes: 1 addition & 1 deletion pyqtorch/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def sample(
if self.readout_noise is None:
return counters

return self.readout_noise.apply_on_counts(counters, n_shots)
return self.readout_noise.apply(counters, n_shots)


class DropoutQuantumCircuit(QuantumCircuit):
Expand Down
75 changes: 48 additions & 27 deletions pyqtorch/noise/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from abc import ABC
from collections import Counter
from enum import Enum
from functools import singledispatchmethod
from math import log

import torch
from torch import Tensor
from torch.distributions import normal, poisson, uniform

from pyqtorch.utils import OrderedCounter


class WhiteNoise(Enum):
"""White noise distributions."""
Expand Down Expand Up @@ -159,12 +158,8 @@ def create_confusion_matrices(noise_matrix: Tensor, error_probability: float) ->


class ReadoutInterface(ABC):
def apply_on_probas(self, batch_probs, n_shots: int) -> Tensor:
raise NotImplementedError

def apply_on_counts(
self, counters: list[Counter | OrderedCounter], n_shots
) -> list[Counter]:
@singledispatchmethod
def apply(self, input_to_corrupt, n_shots: int):
raise NotImplementedError


Expand Down Expand Up @@ -282,25 +277,30 @@ def create_noise_matrix(self, n_shots: int) -> Tensor:
self.confusion_matrix = confusion_matrices
return noise_matrix

def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor:
@singledispatchmethod
def apply(self, input_to_corrupt, n_shots: int):
raise NotImplementedError

@apply.register
def _(self, input_to_corrupt: Tensor, n_shots: int) -> Tensor:
"""Apply confusion matrix on probabilities.
Args:
batch_probs (Tensor): Batch of probability vectors.
n_shots (int, optional): Number of shots. Defaults to 1000.
input_to_corrupt (Tensor): Batch of probability vectors.
n_shots (int, optional): Number of shots.
Returns:
Tensor: Corrupted probabilities.
"""

# Create binary representations
n_states = batch_probs.shape[1]
n_states = input_to_corrupt.shape[1]

# Create binary representation of all states
state_indices = torch.arange(n_states, device=batch_probs.device)
state_indices = torch.arange(n_states, device=input_to_corrupt.device)
binary_repr = (
state_indices.unsqueeze(1)
>> torch.arange(self.n_qubits - 1, -1, -1, device=batch_probs.device)
>> torch.arange(self.n_qubits - 1, -1, -1, device=input_to_corrupt.device)
) & 1

# Get input and output bits for all qubits at once
Expand All @@ -314,21 +314,20 @@ def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor:
# Index into confusion matrix for all qubits at once
# Shape: (n_states_out, n_states_in, n_qubits)
qubit_transitions = confusion_matrices[
torch.arange(self.n_qubits, device=batch_probs.device),
torch.arange(self.n_qubits, device=input_to_corrupt.device),
output_bits,
input_bits,
]
transition_matrix = torch.prod(qubit_transitions, dim=-1)
output_probs = torch.matmul(batch_probs, transition_matrix.T)
output_probs = torch.matmul(input_to_corrupt, transition_matrix.T)
return output_probs

def apply_on_counts(
self, counters: list[Counter | OrderedCounter], n_shots: int = 1000
) -> list[Counter]:
@apply.register
def _(self, input_to_corrupt: list, n_shots: int) -> list[Counter]:
"""Apply readout on counters represented as Counters.
Args:
counters (list[Counter | OrderedCounter]): Samples of bit string as Counters.
input_to_corrupt (list[Counter | OrderedCounter]): Samples of bit string as Counters.
n_shots (int, optional): Number of shots to sample. Defaults to 1000.
Returns:
Expand All @@ -338,7 +337,7 @@ def apply_on_counts(
err_idx = torch.as_tensor(noise_matrix < self.error_probability)

corrupted_bitstrings = []
for counter in counters:
for counter in input_to_corrupt:
sample = sample_to_matrix(counter)
corrupted_bitstrings.append(
bs_bitflip_corruption(err_idx=err_idx, sample=sample)
Expand All @@ -365,17 +364,39 @@ def __init__(
self.n_qubits = int(log(confusion_matrix.size(0), 2))
self.seed = seed

def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor:
output_probs = batch_probs @ self.confusion_matrix.T
@singledispatchmethod
def apply(self, input_to_corrupt, n_shots: int):
raise NotImplementedError

@apply.register
def _(self, input_to_corrupt: Tensor, n_shots: int) -> Tensor:
"""Apply confusion matrix on probabilities.
Args:
input_to_corrupt (Tensor): Batch of probability vectors.
n_shots (int, optional): Number of shots.
Returns:
Tensor: Corrupted probabilities.
"""
output_probs = input_to_corrupt @ self.confusion_matrix.T
return output_probs

def apply_on_counts(
self, counters: list[Counter | OrderedCounter], n_shots: int = 1000
) -> list[Counter]:
@apply.register
def _(self, input_to_corrupt: list, n_shots: int) -> list[Counter]:
"""Apply readout on counters represented as Counters.
Args:
input_to_corrupt (list[Counter | OrderedCounter]): Samples of bit string as Counters.
n_shots (int, optional): Number of shots to sample. Defaults to 1000.
Returns:
list[Counter]: Samples of corrupted bit strings
"""
if self.seed is not None:
torch.manual_seed(self.seed)
corrupted_bitstrings = []
for counter in counters:
for counter in input_to_corrupt:
sample = sample_to_matrix(counter)
corrupted_bitstrings.append(
bs_confusion_corruption(self.confusion_matrix, sample)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_correlated_readout() -> None:

corr_readout = CorrelatedReadoutNoise(confusion_matrix, 0)
probas = torch.tensor([[0.4, 0.3, 0.2, 0.1]], dtype=torch.double)
out_probas = corr_readout.apply_on_probas(probas)
out_probas = corr_readout.apply(probas, n_shots=1000)
assert torch.allclose(
out_probas,
torch.tensor([[0.3830, 0.2900, 0.2060, 0.1210]], dtype=torch.float64),
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_readout_apply_probas() -> None:
n_shots = 1000
probas = torch.tensor([[0.4, 0.3, 0.2, 0.1]], dtype=torch.double)
readobj = ReadoutNoise(2, seed=0)
out_probas = readobj.apply_on_probas(probas)
out_probas = readobj.apply(probas, n_shots=1000)

assert torch.allclose(torch.sum(out_probas), torch.ones(1, dtype=out_probas.dtype))
assert torch.allclose(
Expand Down

0 comments on commit 42ee0e7

Please sign in to comment.