first typed up version
kratsg committed Aug 13, 2022
1 parent 137fc26 commit b8a7eab
Showing 1 changed file with 96 additions and 54 deletions.
150 changes: 96 additions & 54 deletions src/pyhf/tensor/numpy_backend.py
@@ -5,56 +5,74 @@
from scipy import special
from scipy.stats import norm, poisson

from typing import TypeVar, Callable, Literal, Sequence, Generic, Tuple, Mapping
from numpy.typing import (
NDArray,
NBitBase,
ArrayLike,
)

Shape = Tuple[int, ...]

log = logging.getLogger(__name__)

T = TypeVar("T", bound=NBitBase)
Tensor = NDArray[np.floating[T]]
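`T` is bounded by `NBitBase`, so `Tensor[T]` reads as "a float array of some bit width", and a function annotated with it is meant to carry the caller's precision through to its return type. A minimal sketch of the pattern (illustrative, not part of the diff; assumes the module context above):

def scaled(tensor: Tensor[T], factor: float) -> Tensor[T]:
    # elementwise scaling; the runtime dtype (and so the bit width) is unchanged
    return tensor * factor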


class _BasicPoisson:
def __init__(self, rate):
class _BasicPoisson(Generic[T]):
def __init__(self, rate: Tensor[T]):
self.rate = rate

def sample(self, sample_shape):
def sample(self, sample_shape: NDArray[np.integer[T]]) -> Tensor[T]:
return poisson(self.rate).rvs(size=sample_shape + self.rate.shape)

def log_prob(self, value):
def log_prob(self, value: NDArray[np.number[T]]) -> Tensor[T]:
tensorlib = numpy_backend()
return tensorlib.poisson_logpdf(value, self.rate)


class _BasicNormal:
def __init__(self, loc, scale):
class _BasicNormal(Generic[T]):
def __init__(self, loc: NDArray[np.number[T]], scale: NDArray[np.number[T]]):
self.loc = loc
self.scale = scale

def sample(self, sample_shape):
def sample(self, sample_shape: NDArray[np.integer[T]]) -> Tensor[T]:
return norm(self.loc, self.scale).rvs(size=sample_shape + self.loc.shape)

def log_prob(self, value):
def log_prob(self, value: NDArray[np.number[T]]) -> Tensor[T]:
tensorlib = numpy_backend()
return tensorlib.normal_logpdf(value, self.loc, self.scale)
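A usage sketch for the two wrappers (illustrative, not from the diff). Note that `sample` concatenates `sample_shape` with the parameter shape, so a plain tuple is what ultimately reaches scipy's `rvs`:

pois = _BasicPoisson(np.asarray([5.0, 10.0]))
pois.sample((3,))   # shape (3, 2): three draws per rate
gauss = _BasicNormal(np.asarray([0.0]), np.asarray([1.0]))
gauss.sample((2,))  # shape (2, 1)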


class numpy_backend:
class numpy_backend(Generic[T]):
"""NumPy backend for pyhf"""

__slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

def __init__(self, **kwargs):
def __init__(self, **kwargs: dict[str, str]):
self.name = 'numpy'
self.precision = kwargs.get('precision', '64b')
self.dtypemap = {
self.dtypemap: Mapping[
Literal['float', 'int', 'bool'], np.floating[T] | np.integer[T] | np.bool_
] = {
'float': np.float64 if self.precision == '64b' else np.float32,
'int': np.int64 if self.precision == '64b' else np.int32,
'bool': np.bool_,
}
self.default_do_grad = False
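The `precision` keyword fixes the dtype family once at construction; a short sketch:

backend = numpy_backend(precision='32b')
assert backend.dtypemap['float'] is np.float32
assert backend.dtypemap['int'] is np.int32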

def _setup(self):
def _setup(self) -> None:
"""
Run any global setups for the numpy lib.
"""

def clip(self, tensor_in, min_value, max_value):
def clip(
self,
tensor_in: Tensor[T],
min_value: np.integer[T] | np.floating[T],
max_value: np.integer[T] | np.floating[T],
) -> Tensor[T]:
"""
Clips (limits) the tensor values to be within a specified min and max.
@@ -76,7 +94,7 @@ def clip(self, tensor_in, min_value, max_value):
"""
return np.clip(tensor_in, min_value, max_value)

def erf(self, tensor_in):
def erf(self, tensor_in: Tensor[T]) -> Tensor[T]:
"""
The error function of complex argument.
@@ -96,7 +114,7 @@ def erf(self, tensor_in):
"""
return special.erf(tensor_in)

def erfinv(self, tensor_in):
def erfinv(self, tensor_in: Tensor[T]) -> Tensor[T]:
"""
The inverse of the error function of complex argument.
@@ -116,7 +134,7 @@ def erfinv(self, tensor_in):
"""
return special.erfinv(tensor_in)

def tile(self, tensor_in, repeats):
def tile(self, tensor_in: Tensor[T], repeats: int | Sequence[int]) -> Tensor[T]:
"""
Repeat tensor data along a specific dimension
@@ -138,7 +156,12 @@ def tile(self, tensor_in, repeats):
"""
return np.tile(tensor_in, repeats)

def conditional(self, predicate, true_callable, false_callable):
def conditional(
self,
predicate: NDArray[np.bool_],
true_callable: Callable[[], Tensor[T]],
false_callable: Callable[[], Tensor[T]],
) -> Tensor[T]:
"""
Runs a callable conditional on the boolean value of the evaluation of a predicate
@@ -162,27 +185,29 @@ def conditional(self, predicate, true_callable, false_callable):
"""
return true_callable() if predicate else false_callable()
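Both branches are wrapped in callables so that only the selected branch is evaluated; an illustrative sketch:

backend = numpy_backend()
result = backend.conditional(
    np.asarray(True),
    lambda: np.asarray([1.0]),  # runs only when the predicate is truthy
    lambda: np.asarray([0.0]),  # runs only otherwise
)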

def tolist(self, tensor_in):
def tolist(self, tensor_in: ArrayLike) -> list[T]:
try:
return tensor_in.tolist()
except AttributeError:
if isinstance(tensor_in, list):
return tensor_in
raise

def outer(self, tensor_in_1, tensor_in_2):
def outer(self, tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]) -> Tensor[T]:
return np.outer(tensor_in_1, tensor_in_2)

def gather(self, tensor, indices):
def gather(self, tensor: Tensor[T], indices: NDArray[np.integer[T]]) -> Tensor[T]:
return tensor[indices]

def boolean_mask(self, tensor, mask):
def boolean_mask(self, tensor: Tensor[T], mask: NDArray[np.bool_]) -> Tensor[T]:
return tensor[mask]

def isfinite(self, tensor):
def isfinite(self, tensor: Tensor[T]) -> NDArray[np.bool_]:
return np.isfinite(tensor)
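These are thin wrappers over NumPy's fancy indexing, boolean indexing, and finiteness predicate; a quick sketch:

backend = numpy_backend()
t = np.asarray([1.0, 2.0, 3.0, 4.0])
backend.gather(t, np.asarray([3, 0]))                            # array([4., 1.])
backend.boolean_mask(t, np.asarray([True, False, True, False]))  # array([1., 3.])
backend.isfinite(np.asarray([1.0, np.inf]))                      # array([ True, False])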

def astensor(self, tensor_in, dtype='float'):
def astensor(
self, tensor_in: ArrayLike, dtype: Literal['float'] = 'float'
) -> Tensor[T]:
"""
Convert to a NumPy array.
@@ -213,18 +238,20 @@ def astensor(self, tensor_in, dtype='float'):

return np.asarray(tensor_in, dtype=dtype)
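A sketch of the default coercion path:

backend = numpy_backend()
t = backend.astensor([[1, 2], [3, 4]])
t.dtype  # dtype('float64') under the default '64b' precision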

def sum(self, tensor_in, axis=None):
def sum(self, tensor_in: Tensor[T], axis: int | None = None) -> Tensor[T]:
return np.sum(tensor_in, axis=axis)

def product(self, tensor_in, axis=None):
def product(self, tensor_in: Tensor[T], axis: int | None = None) -> Tensor[T]:
return np.product(tensor_in, axis=axis)

def abs(self, tensor):
def abs(self, tensor: Tensor[T]) -> Tensor[T]:
return np.abs(tensor)

def ones(self, shape, dtype="float"):
def ones(
self, shape: Shape, dtype_str: Literal["float", "int", "bool"] = "float"
) -> Tensor[T]:
try:
dtype = self.dtypemap[dtype]
dtype = self.dtypemap[dtype_str]
except KeyError:
log.error(
f"Invalid dtype: dtype must be one of {list(self.dtypemap.keys())}.",
@@ -234,9 +261,11 @@ def ones(self, shape, dtype="float"):

return np.ones(shape, dtype=dtype)

def zeros(self, shape, dtype="float"):
def zeros(
self, shape: Shape, dtype_str: Literal["float", "int", "bool"] = "float"
) -> Tensor[T]:
try:
dtype = self.dtypemap[dtype]
dtype = self.dtypemap[dtype_str]
except KeyError:
log.error(
f"Invalid dtype: dtype must be one of {list(self.dtypemap.keys())}.",
@@ -246,22 +275,30 @@ def power(self, tensor_in_1, tensor_in_2):

return np.zeros(shape, dtype=dtype)
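The `dtype` -> `dtype_str` rename gives the string key and the resolved NumPy dtype separately typed names, since rebinding one variable from `Literal[...]` to a dtype object would not type-check cleanly (that motivation is an inference from the diff, not stated in it). Note the keyword argument name changes for callers:

backend = numpy_backend()
ones_int = backend.ones((2, 3), dtype_str="int")  # np.int64 under '64b' precision
zeros = backend.zeros((4,))                       # float64 by default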

def power(self, tensor_in_1, tensor_in_2):
def power(self, tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]) -> Tensor[T]:
return np.power(tensor_in_1, tensor_in_2)

def sqrt(self, tensor_in):
def sqrt(self, tensor_in: Tensor[T]) -> Tensor[T]:
return np.sqrt(tensor_in)

def divide(self, tensor_in_1, tensor_in_2):
def divide(self, tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]) -> Tensor[T]:
return np.divide(tensor_in_1, tensor_in_2)

def log(self, tensor_in):
def log(self, tensor_in: Tensor[T]) -> Tensor[T]:
return np.log(tensor_in)

def exp(self, tensor_in):
def exp(self, tensor_in: Tensor[T]) -> Tensor[T]:
return np.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
def percentile(
self,
tensor_in: Tensor[T],
q: Tensor[T],
axis: None | Shape = None,
interpolation: Literal[
"linear", "lower", "higher", "midpoint", "nearest"
] = "linear",
) -> Tensor[T]:
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.
@@ -297,15 +334,18 @@ def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
NumPy ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)
# see numpy/numpy#22125
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation) # type: ignore[call-overload]
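Context for the suppressed overload (an assumption based on the referenced numpy/numpy#22125): NumPy 1.22 renamed `interpolation=` to `method=`, so newer stubs reject the old keyword even though it still works at runtime with a deprecation warning. Behavior sketch:

backend = numpy_backend()
backend.percentile(np.asarray([[10.0, 7.0, 4.0], [3.0, 2.0, 1.0]]), 50)  # 3.5, the median of the flattened tensor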

def stack(self, sequence, axis=0):
def stack(self, sequence: Sequence[Tensor[T]], axis: int = 0) -> Tensor[T]:
return np.stack(sequence, axis=axis)

def where(self, mask, tensor_in_1, tensor_in_2):
def where(
self, mask: NDArray[np.bool_], tensor_in_1: Tensor[T], tensor_in_2: Tensor[T]
) -> Tensor[T]:
return np.where(mask, tensor_in_1, tensor_in_2)

def concatenate(self, sequence, axis=0):
def concatenate(self, sequence: Tensor[T], axis: None | int = 0) -> Tensor[T]:
"""
Join a sequence of arrays along an existing axis.
@@ -319,7 +359,7 @@ def concatenate(self, sequence, axis=0):
"""
return np.concatenate(sequence, axis=axis)

def simple_broadcast(self, *args):
def simple_broadcast(self, *args: Sequence[Tensor[T]]) -> Sequence[Tensor[T]]:
"""
Broadcast a sequence of 1 dimensional arrays.
@@ -341,13 +381,13 @@ def simple_broadcast(self, *args):
"""
return np.broadcast_arrays(*args)
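A broadcast sketch:

backend = numpy_backend()
a, b = backend.simple_broadcast(np.asarray([1.0]), np.asarray([2.0, 3.0, 4.0]))
# a -> array([1., 1., 1.]); b comes back unchanged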

def shape(self, tensor):
def shape(self, tensor: Tensor[T]) -> Shape:
return tensor.shape

def reshape(self, tensor, newshape):
def reshape(self, tensor: Tensor[T], newshape: Shape) -> Tensor[T]:
return np.reshape(tensor, newshape)

def ravel(self, tensor):
def ravel(self, tensor: Tensor[T]) -> Tensor[T]:
"""
Return a flattened view of the tensor, not a copy.
@@ -367,7 +407,7 @@ def ravel(self, tensor):
"""
return np.ravel(tensor)

def einsum(self, subscripts, *operands):
def einsum(self, subscripts: str, *operands: Sequence[Tensor[T]]) -> Tensor[T]:
"""
Evaluates the Einstein summation convention on the operands.
@@ -386,10 +426,10 @@ def einsum(self, subscripts, *operands):
"""
return np.einsum(subscripts, *operands)

def poisson_logpdf(self, n, lam):
def poisson_logpdf(self, n: Tensor[T], lam: Tensor[T]) -> Tensor[T]:
return xlogy(n, lam) - lam - gammaln(n + 1.0)
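The `xlogy`/`gammaln` form is the continuous extension of the Poisson log-PMF,

\log \mathrm{Pois}(n \mid \lambda) = n \log \lambda - \lambda - \log \Gamma(n + 1),

where `xlogy(n, lam)` returns 0 at n = 0 even when lam = 0, avoiding the 0 * (-inf) NaN that a naive `n * log(lam)` would produce.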

def poisson(self, n, lam):
def poisson(self, n: Tensor[T], lam: Tensor[T]) -> Tensor[T]:
r"""
The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
to the probability mass function of the Poisson distribution evaluated
@@ -433,7 +473,7 @@ def poisson(self, n, lam):
lam = np.asarray(lam)
return np.exp(xlogy(n, lam) - lam - gammaln(n + 1.0))

def normal_logpdf(self, x, mu, sigma):
def normal_logpdf(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> Tensor[T]:
# this is much faster than
# norm.logpdf(x, loc=mu, scale=sigma)
# https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
@@ -446,7 +486,7 @@ def normal_logpdf(self, x, mu, sigma):
# def normal_logpdf(self, x, mu, sigma):
# return norm.logpdf(x, loc=mu, scale=sigma)
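The method body is folded out of this hunk; per the linked Code Review thread, the speedup comes from evaluating the closed form directly rather than dispatching through `scipy.stats`. A sketch of that approach (not necessarily the exact folded code):

def normal_logpdf_sketch(x, mu, sigma):
    # log N(x | mu, sigma) = -log(sigma * sqrt(2*pi)) - (x - mu)**2 / (2 * sigma**2)
    prefactor = -np.log(sigma * np.sqrt(2 * np.pi))
    summand = -np.square((x - mu) / (np.sqrt(2) * sigma))
    return prefactor + summand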

def normal(self, x, mu, sigma):
def normal(self, x: Tensor[T], mu: Tensor[T], sigma: Tensor[T]) -> Tensor[T]:
r"""
The probability density function of the Normal distribution evaluated
at :code:`x` given parameters of mean of :code:`mu` and standard deviation
@@ -474,7 +514,9 @@ def normal(self, x, mu, sigma):
"""
return norm.pdf(x, loc=mu, scale=sigma)

def normal_cdf(self, x, mu=0, sigma=1):
def normal_cdf(
self, x: Tensor[T], mu: float | Tensor[T] = 0, sigma: float | Tensor[T] = 1
) -> Tensor[T]:
"""
The cumulative distribution function for the Normal distribution
@@ -498,7 +540,7 @@ def normal_cdf(self, x, mu=0, sigma=1):
"""
return norm.cdf(x, loc=mu, scale=sigma)

def poisson_dist(self, rate):
def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson[T]:
r"""
The Poisson distribution with rate parameter :code:`rate`.
@@ -519,7 +561,7 @@ def poisson_dist(self, rate):
"""
return _BasicPoisson(rate)

def normal_dist(self, mu, sigma):
def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal[T]:
r"""
The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.
@@ -543,7 +585,7 @@ def normal_dist(self, mu, sigma):
"""
return _BasicNormal(mu, sigma)
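A round-trip sketch through the constructors:

backend = numpy_backend()
dist = backend.normal_dist(np.asarray([0.0]), np.asarray([1.0]))
dist.log_prob(np.asarray([0.0]))  # approximately -0.9189, i.e. -0.5 * log(2 * pi)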

def to_numpy(self, tensor_in):
def to_numpy(self, tensor_in: Tensor[T]) -> Tensor[T]:
"""
Return the input tensor as it already is a :class:`numpy.ndarray`.
This API exists only for ``pyhf.tensorlib`` compatibility.
@@ -571,7 +613,7 @@ def to_numpy(self, tensor_in):
"""
return tensor_in

def transpose(self, tensor_in):
def transpose(self, tensor_in: Tensor[T]) -> Tensor[T]:
"""
Transpose the tensor.
