Skip to content

Commit

Permalink
improved rvs implementations (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski authored Feb 9, 2024
1 parent 3ac5b8e commit 89dc80f
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 11 deletions.
3 changes: 1 addition & 2 deletions src/numba_stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def _ppf(p, loc, scale):
@_rvs_jit(2)
def _rvs(loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, loc, scale)
return loc + np.random.exponential(scale, size)


_generate_wrappers(globals())
3 changes: 1 addition & 2 deletions src/numba_stats/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def _ppf(p, loc, scale):
@_rvs_jit(2)
def _rvs(loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, loc, scale)
return np.random.laplace(loc, scale, size)


_generate_wrappers(globals())
3 changes: 1 addition & 2 deletions src/numba_stats/lognorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def _ppf(p, s, loc, scale):
@_rvs_jit(3, cache=False)
def _rvs(s, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, s, loc, scale)
return loc + scale * np.random.lognormal(0, s, size)


_generate_wrappers(globals())
5 changes: 2 additions & 3 deletions src/numba_stats/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def _ppf(p, df, loc, scale):
return scale * r + loc


@_rvs_jit(3, cache=False)
@_rvs_jit(3)
def _rvs(df, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, df, loc, scale)
return loc + scale * np.random.standard_t(df, size)


_generate_wrappers(globals())
3 changes: 1 addition & 2 deletions src/numba_stats/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def _ppf(p, a, w):
@_rvs_jit(2)
def _rvs(a, w, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, a, w)
return np.random.uniform(a, a + w, size)


_generate_wrappers(globals())

0 comments on commit 89dc80f

Please sign in to comment.