Skip to content

Commit

Permalink
Revert to having State.sample() as the abstractmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
HGSilveri committed Dec 12, 2024
1 parent 495c325 commit ab1dfab
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 102 deletions.
126 changes: 27 additions & 99 deletions pulser-core/pulser/backend/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Mapping, Sequence
from itertools import product
from typing import Any, Generic, Literal, Protocol, Type, TypeVar, Union
from collections.abc import Sequence
from typing import Generic, Literal, Type, TypeVar, Union

import numpy as np

Expand All @@ -31,17 +30,6 @@
StateType = TypeVar("StateType", bound="State")


class _ProbabilityType(Protocol):
"""A protocol for probability types.
Defines only the methods needed to correctly type hint State.
"""

def __float__(self) -> float: ...

def __add__(self, other: Any) -> float: ...


class State(ABC, Generic[ArgScalarType, ReturnScalarType]):
"""Base class enforcing an API for quantum states.
Expand All @@ -67,21 +55,35 @@ def eigenstates(self) -> tuple[Eigenstate, ...]:
"""
return tuple(self._eigenstates)

@property
def basis_states(self) -> tuple[str, ...]:
"""The basis states combinations, in order."""
return tuple(
map(
"".join,
product("".join(self.eigenstates), repeat=self.n_qudits),
)
)

@property
def qudit_dim(self) -> int:
"""The dimensions (ie number of eigenstates) of a qudit."""
return len(self.eigenstates)

def get_basis_state_from_index(self, index: int) -> str:
"""Generates a basis state combination from its index in the state.
Assumes a state vector representation regardless of the actual support
of the state.
Args:
index: The position of the state in a state vector.
Returns:
The basis state combination for the desired index.
"""
if index < 0:
raise ValueError(
f"'index' must be a non-negative integer;"
f" got {index} instead."
)
return "".join(
self.eigenstates[int(dig)]
for dig in np.base_repr(index, base=self.qudit_dim).zfill(
self.n_qudits
)
)

@abstractmethod
def overlap(self: StateType, other: StateType, /) -> ReturnScalarType:
"""Compute the overlap between this state and another of the same type.
Expand All @@ -98,49 +100,6 @@ def overlap(self: StateType, other: StateType, /) -> ReturnScalarType:
pass

@abstractmethod
def probabilities(
self, *, cutoff: float = 1e-10
) -> Mapping[str, _ProbabilityType]:
"""Extracts the probabilties of measuring each basis state combination.
Args:
cutoff: The value below which a probability is considered to be
zero.
Returns:
A mapping between basis state combinations and their respective
probabilities.
"""
pass

def bitstring_probabilities(
self, *, one_state: Eigenstate | None = None, cutoff: float = 1e-10
) -> Mapping[str, _ProbabilityType]:
"""Extracts the probabilties of measuring each bitstring.
Args:
one_state: The eigenstate that measures to 1. Can be left undefined
if the state's eigenstates form a known eigenbasis with a
defined "one state".
cutoff: The value below which a probability is considered to be
zero.
Returns:
A mapping between bitstrings and their respective probabilities.
"""
one_state = one_state or self.infer_one_state()
zero_states = set(self.eigenstates) - {one_state}
probs = self.probabilities(cutoff=cutoff)
bitstring_probs: dict[str, _ProbabilityType] = {}
for state_str in probs:
bitstring = state_str.replace(one_state, "1")
for s_ in zero_states:
bitstring = bitstring.replace(s_, "0")
# Avoid defaultdict for typing reasons
curr_val = bitstring_probs.setdefault(bitstring, 0.0)
bitstring_probs[bitstring] = probs[state_str] + curr_val
return dict(bitstring_probs)

def sample(
self,
*,
Expand All @@ -162,38 +121,7 @@ def sample(
Returns:
The measured bitstrings, by count.
"""
bitstring_probs = self.bitstring_probabilities(
one_state=one_state, cutoff=1 / (1000 * num_shots)
)
bitstrings = np.array(list(bitstring_probs))
probs = np.array(list(map(float, bitstring_probs.values())))
dist = np.random.multinomial(num_shots, probs)
# Filter out bitstrings without counts
non_zero_counts = dist > 0
bitstrings = bitstrings[non_zero_counts]
dist = dist[non_zero_counts]
if p_false_pos == 0.0 and p_false_neg == 0.0:
return Counter(dict(zip(bitstrings, dist)))

# Convert bitstrings to a 2D array
bitstr_arr = np.array([list(bs) for bs in bitstrings], dtype=int)
# If 1 is measured, flip_prob=p_false_neg else flip_prob=p_false_pos
flip_probs = np.where(bitstr_arr == 1, p_false_neg, p_false_pos)
# Repeat flip_probs of a bitstring as many times as it was measured
flip_probs_repeated = np.repeat(flip_probs, dist, axis=0)
# Generate random matrix of same shape
random_matrix = np.random.uniform(size=flip_probs_repeated.shape)
# Compare random matrix with flip probabilities to get the flips
flips = random_matrix < flip_probs_repeated
# Apply the flips with an XOR between original array and flips
new_bitstrings = bitstr_arr.repeat(dist, axis=0) ^ flips

# Count all the new_bitstrings
# Not converting to str right away because tuple indexing is faster
new_counts: Counter = Counter(map(tuple, new_bitstrings))
return Counter(
{"".join(map(str, k)): v for k, v in new_counts.items()}
)
pass

@classmethod
@abstractmethod
Expand Down
93 changes: 90 additions & 3 deletions pulser-simulation/pulser_simulation/qutip_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import math
from collections import Counter, defaultdict
from collections.abc import Collection, Mapping, Sequence
from typing import SupportsComplex, Type, TypeVar

Expand Down Expand Up @@ -64,6 +65,10 @@ def n_qudits(self) -> int:
"""The number of qudits in the state."""
return round(math.log(self._state.shape[0], self.qudit_dim))

def to_qobj(self) -> qutip.Qobj:
"""Returns a copy of the state's Qobj representation."""
return self._state.copy()

def overlap(self, other: QutipState) -> float:
"""Compute the overlap between this state and another of the same type.
Expand All @@ -88,7 +93,7 @@ def overlap(self, other: QutipState) -> float:
raise ValueError(
"Can't calculate the overlap between a state with "
f"{self.n_qudits} {self.qudit_dim}-dimensional qudits and "
f"another with {self.n_qudits} {self.qudit_dim}-dimensional"
f"another with {other.n_qudits} {other.qudit_dim}-dimensional "
"qudits."
)
if self.eigenstates != other.eigenstates:
Expand Down Expand Up @@ -119,9 +124,91 @@ def probabilities(self, *, cutoff: float = 1e-10) -> dict[str, float]:
probs = np.abs(self._state.diag())
else:
probs = (np.abs(self._state.full()) ** 2).flatten()
non_zero = probs > cutoff
non_zero = np.argwhere(probs > cutoff).flatten()
return dict(
zip(np.array(self.basis_states)[non_zero], probs[non_zero])
zip(
map(self.get_basis_state_from_index, non_zero), probs[non_zero]
)
)

def bitstring_probabilities(
self, *, one_state: Eigenstate | None = None, cutoff: float = 1e-10
) -> Mapping[str, float]:
"""Extracts the probabilties of measuring each bitstring.
Args:
one_state: The eigenstate that measures to 1. Can be left undefined
if the state's eigenstates form a known eigenbasis with a
defined "one state".
cutoff: The value below which a probability is considered to be
zero.
Returns:
A mapping between bitstrings and their respective probabilities.
"""
one_state = one_state or self.infer_one_state()
zero_states = set(self.eigenstates) - {one_state}
probs = self.probabilities(cutoff=cutoff)
bitstring_probs: dict[str, float] = defaultdict(float)
for state_str in probs:
bitstring = state_str.replace(one_state, "1")
for s_ in zero_states:
bitstring = bitstring.replace(s_, "0")
bitstring_probs[bitstring] += probs[state_str]
return dict(bitstring_probs)

def sample(
self,
*,
num_shots: int,
one_state: Eigenstate | None = None,
p_false_pos: float = 0.0,
p_false_neg: float = 0.0,
) -> Counter[str]:
"""Sample bitstrings from the state, taking into account error rates.
Args:
num_shots: How many bitstrings to sample.
one_state: The eigenstate that measures to 1. Can be left undefined
if the state's eigenstates form a known eigenbasis with a
defined "one state".
p_false_pos: The rate at which a 0 is read as a 1.
p_false_neg: The rate at which a 1 is read as a 0.
Returns:
The measured bitstrings, by count.
"""
bitstring_probs = self.bitstring_probabilities(
one_state=one_state, cutoff=1 / (1000 * num_shots)
)
bitstrings = np.array(list(bitstring_probs))
probs = np.array(list(map(float, bitstring_probs.values())))
dist = np.random.multinomial(num_shots, probs)
# Filter out bitstrings without counts
non_zero_counts = dist > 0
bitstrings = bitstrings[non_zero_counts]
dist = dist[non_zero_counts]
if p_false_pos == 0.0 and p_false_neg == 0.0:
return Counter(dict(zip(bitstrings, dist)))

# Convert bitstrings to a 2D array
bitstr_arr = np.array([list(bs) for bs in bitstrings], dtype=int)
# If 1 is measured, flip_prob=p_false_neg else flip_prob=p_false_pos
flip_probs = np.where(bitstr_arr == 1, p_false_neg, p_false_pos)
# Repeat flip_probs of a bitstring as many times as it was measured
flip_probs_repeated = np.repeat(flip_probs, dist, axis=0)
# Generate random matrix of same shape
random_matrix = np.random.uniform(size=flip_probs_repeated.shape)
# Compare random matrix with flip probabilities to get the flips
flips = random_matrix < flip_probs_repeated
# Apply the flips with an XOR between original array and flips
new_bitstrings = bitstr_arr.repeat(dist, axis=0) ^ flips

# Count all the new_bitstrings
# Not converting to str right away because tuple indexing is faster
new_counts: Counter = Counter(map(tuple, new_bitstrings))
return Counter(
{"".join(map(str, k)): v for k, v in new_counts.items()}
)

@classmethod
Expand Down

0 comments on commit ab1dfab

Please sign in to comment.