Skip to content

Commit

Permalink
rm tensor inputs in apply_counts
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Oct 21, 2024
1 parent 089c309 commit 2f5735b
Showing 1 changed file with 9 additions and 23 deletions.
32 changes: 9 additions & 23 deletions pyqtorch/noise/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor
from torch.distributions import normal, poisson, uniform

from pyqtorch.utils import OrderedCounter, counts_to_orderedcounter
from pyqtorch.utils import OrderedCounter


class WhiteNoise(Enum):
Expand Down Expand Up @@ -271,36 +271,22 @@ def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor:
return output_probs

def apply_on_counts(
self, counters: list[Counter | OrderedCounter] | Tensor, n_shots: int = 1000
) -> list[Counter] | Tensor:
"""Apply readout on counters represented as Counters or Tensors.
self, counters: list[Counter | OrderedCounter], n_shots: int = 1000
) -> list[Counter]:
"""Apply readout on counters represented as Counters.
Args:
counters (list[Counter | OrderedCounter] | Tensor): Samples of bit string as Counters.
counters (list[Counter | OrderedCounter]): Samples of bit string as Counters.
n_shots (int, optional): Number of shots to sample. Defaults to 1000.
Returns:
list[Counter] | Tensor: Samples of corrupted bit strings
list[Counter]: Samples of corrupted bit strings
"""
noise_matrix = self.create_noise_matrix(n_shots, False)
err_idx = torch.as_tensor(noise_matrix < self.error_probability) # type: ignore[operator]

corrupted_bitstrings = []
if isinstance(counters, Tensor):
for bincount in counters:
counter = counts_to_orderedcounter(bincount, self.n_qubits)
sample = sample_to_matrix(counter)
corrupted_counter = bs_corruption(err_idx=err_idx, sample=sample)
corrupted_bincount = torch.zeros_like(bincount)
for bitstring, count in corrupted_counter.items():
idx = int(bitstring, 2)
corrupted_bincount[idx] = count
corrupted_bitstrings.append(corrupted_bincount)
corrupted_bitstrings = torch.stack(corrupted_bitstrings)
else:
for counter in counters: # type: ignore[assignment]
sample = sample_to_matrix(counter)
corrupted_bitstrings.append(
bs_corruption(err_idx=err_idx, sample=sample)
)
for counter in counters: # type: ignore[assignment]
sample = sample_to_matrix(counter)
corrupted_bitstrings.append(bs_corruption(err_idx=err_idx, sample=sample))
return corrupted_bitstrings

0 comments on commit 2f5735b

Please sign in to comment.