Skip to content

Commit

Permalink
got the logcdf working without refactoring so the example code works
Browse files Browse the repository at this point in the history
Update CODE_OF_CONDUCT.md

Changed contact email

ChiSquared now returns a Gamma random variable (pymc-devs#7007)

Update devcontainer (pymc-devs#7017)

* Update dev container on release and schedule

* Use latest version of upstream dev container image

* Remove pre-commit cache hack

* Add Jupyter extension to container

Rename RVTransform to Transform

Remove duplicate ChainTransform

Move ZeroSumTransform methods inside respective class

Remove deprecated function rvs_to_value_vars

Merge functionality of pytensorf and logprob/utils

Also fixes circular imports

Deprecate unused function walk_model

update theme (pymc-devs#7018)

* update theme version

* update conf.py

Fix docs formatting in `shape_utils` (pymc-devs#7025)

* Small docs update shape_utils.py

Fix docs markup

* Update to_tuple docs shape_utils.py

* Types

---------

Co-authored-by: Denis Kataev <[email protected]>

Use PyTensor StudentT RV

Update GOVERNANCE.md (pymc-devs#7031)

* Update GOVERNANCE.md

* Update GOVERNANCE.md

Co-authored-by: Purna Chandra Mansingh <[email protected]>

---------

Co-authored-by: Purna Chandra Mansingh <[email protected]>

Remove deprecated model methods

Deprecate Model.model property

Deprecate pytensor_config

Bump Pytensor dependency to >=2.18.1,<2.19

Better float32 sampling support for TruncatedNormal (pymc-devs#7026)

* manually inv transform; force rng same type

* always upcast f64 and downcast to dtype of param

* add comment

* use class attr dtype

* need else stmt for dtype

* actually no need to downcast in this method

* rm unused import

Another small fixes in docs and imports (pymc-devs#7030)

* Remove old import from autosummary doc pages

* Fix small error

* Correct place for model_to_networkx & model_to_graphviz

* Correct currentmodule source for pytensorf.rst

* Update docs/source/api/model/core.rst

Co-authored-by: Thomas Wiecki <[email protected]>

* Fix sampling & jax docs imports

---------

Co-authored-by: Denis Kataev <[email protected]>
Co-authored-by: Thomas Wiecki <[email protected]>

Update installation instructions to remove incompatibility of JAX on Windows
  • Loading branch information
Luke LB committed Nov 29, 2023
1 parent 44e1c5e commit bf619e0
Show file tree
Hide file tree
Showing 49 changed files with 619 additions and 1,073 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/devcontainer-docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ name: devcontainer-docker-image

on:
workflow_dispatch:
schedule:
- cron: "48 19 * * 5" # Fridays at 19:48 UTC
release:
types: [published]

env:
REGISTRY: ghcr.io
Expand Down
1 change: 1 addition & 0 deletions .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ vscode:
- eamodio.gitlens
- ms-python.python
- ms-pyright.pyright
- ms-toolsai.jupyter
- donjayamanne.githistory

github:
Expand Down
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ further defined and clarified by project maintainers.

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting PyMC developer Christopher Fonnesbeck via email
(chris.fonnesbeck@vanderbilt.edu) or phone (615-955-0380). Alternatively, you
(fonnesbeck@gmail.com) or phone (615-955-0380). Alternatively, you
may also contact NumFOCUS Executive Director Leah Silen (512-222-5449), as PyMC
is a member of NumFOCUS and subscribes to their code of conduct as a
precondition for continued membership. All complaints will be reviewed and
Expand Down
1 change: 1 addition & 0 deletions GOVERNANCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ Contributors don't need to be part of any dedicated team.
* Michael Osthege (dev)
* Oriol Abril-Pla (docs, community)
* Osvaldo Martin (dev, docs)
* Purna Chandra Mansingh (community)
* Ravin Kumar (dev, community, docs)
* Reshama Shaikh (community - PyMC Labs)
* Ricardo Vieira (dev, community)
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- scipy>=1.4.1
- typing-extensions>=3.7.4
Expand All @@ -24,7 +24,7 @@ dependencies:
- numpydoc
- polyagamma
- pre-commit>=2.8.0
- pymc-sphinx-theme==0.13
- pymc-sphinx-theme==0.14
- sphinx-copybutton
- sphinx-design
- sphinx-notfound-page
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies:
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- pytensor>=2.18.1,<2.19
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
4 changes: 0 additions & 4 deletions docs/source/api/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ Functions exposed in pymc.math
sgn
ceil
floor
det
matrix_inverse
extract_diag
matrix_dot
trace
sigmoid
logsumexp
invlogit
Expand Down
13 changes: 11 additions & 2 deletions docs/source/api/model/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ Model creation and inspection
:toctree: generated/

Model
model_to_graphviz
model_to_networkx
modelcontext

Others
Expand All @@ -22,3 +20,14 @@ Others
set_data
Point
compile_fn


Graph visualization
-------------------

.. currentmodule:: pymc.model_graph
.. autosummary::
:toctree: generated/

model_to_networkx
model_to_graphviz
2 changes: 1 addition & 1 deletion docs/source/api/pytensorf.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PyTensor utils
**************

.. currentmodule:: pymc
.. currentmodule:: pymc.pytensorf

.. autosummary::
:toctree: generated/
Expand Down
30 changes: 22 additions & 8 deletions docs/source/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,39 @@ Samplers
This submodule contains functions for MCMC and forward sampling.


.. currentmodule:: pymc
.. currentmodule:: pymc.sampling.forward

.. autosummary::
:toctree: generated/

sample
sample_prior_predictive
sample_posterior_predictive
sample_posterior_predictive_w
sampling.jax.sample_blackjax_nuts
sampling.jax.sample_numpyro_nuts
init_nuts
draw


.. currentmodule:: pymc.sampling.mcmc

.. autosummary::
:toctree: generated/

sample
init_nuts

.. currentmodule:: pymc.sampling.jax

.. autosummary::
:toctree: generated/

sample_blackjax_nuts
sample_numpyro_nuts


Step methods
************

.. currentmodule:: pymc

HMC family
----------
.. currentmodule:: pymc.step_methods.hmc

.. autosummary::
:toctree: generated/
Expand All @@ -34,6 +46,7 @@ HMC family

Metropolis family
-----------------
.. currentmodule:: pymc.step_methods

.. autosummary::
:toctree: generated/
Expand All @@ -53,6 +66,7 @@ Metropolis family

Other step methods
------------------
.. currentmodule:: pymc.step_methods

.. autosummary::
:toctree: generated/
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
myst_substitutions = {
"version_slug": rtd_version,
}
myst_heading_anchors = None
myst_heading_anchors = 0

v3_example_tutorials = [
"case_studies/BEST",
Expand Down
4 changes: 0 additions & 4 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ Similarly, to use BlackJAX sampler instead:
conda install blackjax
```

Note that JAX is not directly supported on Windows systems at the moment.

## Nutpie sampling

You can also enable sampling with [nutpie](https://github.com/pymc-devs/nutpie).
Expand All @@ -41,5 +39,3 @@ Nutpie uses numba as the compiler and a sampler written in Rust for faster perfo
```console
conda install -c conda-forge nutpie
```

Unlike JAX, nutpie is directly supported on Windows.
93 changes: 55 additions & 38 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
BetaRV,
_gamma,
cauchy,
chisquare,
exponential,
gumbel,
halfcauchy,
Expand All @@ -49,6 +48,7 @@
lognormal,
normal,
pareto,
t,
triangular,
uniform,
vonmises,
Expand Down Expand Up @@ -571,11 +571,13 @@ def rng_fn(
upper: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
# Upcast to float64. (Caller will downcast to desired dtype if needed)
# (Work-around for https://github.com/scipy/scipy/issues/15928)
return stats.truncnorm.rvs(
a=(lower - mu) / sigma,
b=(upper - mu) / sigma,
loc=mu,
scale=sigma,
a=((lower - mu) / sigma).astype("float64"),
b=((upper - mu) / sigma).astype("float64"),
loc=(mu).astype("float64"),
scale=(sigma).astype("float64"),
size=size,
random_state=rng,
)
Expand Down Expand Up @@ -744,6 +746,42 @@ def logp(value, mu, sigma, lower, upper):

return logp

def logcdf(value, mu, sigma, lower, upper):
logcdf = normal_lccdf(mu, sigma, value)
lower_logcdf = normal_lccdf(mu, sigma, lower)
upper_logcdf = normal_lccdf(mu, sigma, upper)

is_lower_bounded = not (
isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))
)
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

lognorm = 0
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
elif is_upper_bounded:
lognorm = upper_logcdf

logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
logcdf_trunc = logcdf_numerator - lognorm

if is_lower_bounded:
logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)

if is_upper_bounded:
logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)

if is_lower_bounded and is_upper_bounded:
logcdf_trunc = check_parameters(
logcdf_trunc,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logcdf_trunc


@_default_transform.register(TruncatedNormal)
def truncated_normal_default_transform(op, rv):
Expand Down Expand Up @@ -1744,21 +1782,6 @@ def icdf(value, mu, sigma):
Lognormal = LogNormal


class StudentTRV(RandomVariable):
name = "studentt"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "floatX"
_print_name = ("StudentT", "\\operatorname{StudentT}")

@classmethod
def rng_fn(cls, rng, nu, mu, sigma, size=None) -> np.ndarray:
return np.asarray(stats.t.rvs(nu, mu, sigma, size=size, random_state=rng))


studentt = StudentTRV()


class StudentT(Continuous):
r"""
Student's T log-likelihood.
Expand Down Expand Up @@ -1823,7 +1846,7 @@ class StudentT(Continuous):
with pm.Model():
x = pm.StudentT('x', nu=15, mu=0, lam=1/23)
"""
rv_op = studentt
rv_op = t

@classmethod
def dist(cls, nu, mu=0, *, sigma=None, lam=None, **kwargs):
Expand Down Expand Up @@ -2384,16 +2407,21 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(PositiveContinuous):
class ChiSquared:
r"""
:math:`\chi^2` log-likelihood.
This is the distribution from the sum of the squares of :math:`\nu` independent standard normal random variables or a special
case of the gamma distribution with :math:`\alpha = \nu/2` and :math:`\beta = 1/2`.
The pdf of this distribution is
.. math::
f(x \mid \nu) = \frac{x^{(\nu-2)/2}e^{-x/2}}{2^{\nu/2}\Gamma(\nu/2)}
Read more about the :math:`\chi^2` distribution at https://en.wikipedia.org/wiki/Chi-squared_distribution
.. plot::
:context: close-figs
Expand Down Expand Up @@ -2423,24 +2451,13 @@ class ChiSquared(PositiveContinuous):
nu : tensor_like of float
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

@classmethod
def dist(cls, nu, *args, **kwargs):
nu = pt.as_tensor_variable(floatX(nu))
return super().dist([nu], *args, **kwargs)

def moment(rv, size, nu):
moment = nu
if not rv_size_is_none(size):
moment = pt.full(size, moment)
return moment
def __new__(cls, name, nu, **kwargs):
return Gamma(name, alpha=nu / 2, beta=1 / 2, **kwargs)

def logp(value, nu):
return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)

def logcdf(value, nu):
return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
@classmethod
def dist(cls, nu, **kwargs):
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


# TODO: Remove this once logp for multiplication is working!
Expand Down
Loading

0 comments on commit bf619e0

Please sign in to comment.