From 2f5735b6c871fb2f4f9159fb433a62930632c707 Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Mon, 21 Oct 2024 11:55:40 +0200 Subject: [PATCH] rm tensor inputs in apply_counts --- pyqtorch/noise/readout.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/pyqtorch/noise/readout.py b/pyqtorch/noise/readout.py index 5c1108c1..aa415198 100644 --- a/pyqtorch/noise/readout.py +++ b/pyqtorch/noise/readout.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.distributions import normal, poisson, uniform -from pyqtorch.utils import OrderedCounter, counts_to_orderedcounter +from pyqtorch.utils import OrderedCounter class WhiteNoise(Enum): @@ -271,36 +271,22 @@ def apply_on_probas(self, batch_probs: Tensor, n_shots: int = 1000) -> Tensor: return output_probs def apply_on_counts( - self, counters: list[Counter | OrderedCounter] | Tensor, n_shots: int = 1000 - ) -> list[Counter] | Tensor: - """Apply readout on counters represented as Counters or Tensors. + self, counters: list[Counter | OrderedCounter], n_shots: int = 1000 + ) -> list[Counter]: + """Apply readout on counters represented as Counters. Args: - counters (list[Counter | OrderedCounter] | Tensor): Samples of bit string as Counters. + counters (list[Counter | OrderedCounter]): Samples of bit string as Counters. n_shots (int, optional): Number of shots to sample. Defaults to 1000. Returns: - list[Counter] | Tensor: Samples of corrupted bit strings + list[Counter]: Samples of corrupted bit strings """ noise_matrix = self.create_noise_matrix(n_shots, False) err_idx = torch.as_tensor(noise_matrix < self.error_probability) # type: ignore[operator] corrupted_bitstrings = [] - if isinstance(counters, Tensor): - for bincount in counters: - counter = counts_to_orderedcounter(bincount, self.n_qubits) - sample = sample_to_matrix(counter) - corrupted_counter = bs_corruption(err_idx=err_idx, sample=sample) - corrupted_bincount = torch.zeros_like(bincount) - for bitstring, count in corrupted_counter.items(): - idx = int(bitstring, 2) - corrupted_bincount[idx] = count - corrupted_bitstrings.append(corrupted_bincount) - corrupted_bitstrings = torch.stack(corrupted_bitstrings) - else: - for counter in counters: # type: ignore[assignment] - sample = sample_to_matrix(counter) - corrupted_bitstrings.append( - bs_corruption(err_idx=err_idx, sample=sample) - ) + for counter in counters: # type: ignore[assignment] + sample = sample_to_matrix(counter) + corrupted_bitstrings.append(bs_corruption(err_idx=err_idx, sample=sample)) return corrupted_bitstrings