Function to optimize prior under constraints #5231

Merged · 54 commits · Jan 4, 2022
Changes from 46 commits
Commits
8ca3ded
Replace print statement by AttributeError
AlexAndorra Nov 30, 2021
9dc0096
pre-commit formatting
AlexAndorra Nov 30, 2021
9675e4f
Mention in release notes
AlexAndorra Nov 30, 2021
d132364
Handle 1-param and 3-param distributions
AlexAndorra Dec 1, 2021
6f9ccd4
Update tests
AlexAndorra Dec 1, 2021
fea6643
Fix some wording
AlexAndorra Dec 1, 2021
524a900
pre-commit formatting
AlexAndorra Dec 3, 2021
91174b9
Only raise UserWarning when mass_in_interval not optimal
AlexAndorra Dec 3, 2021
29741f1
Raise NotImplementedError for non-scalar params
AlexAndorra Dec 3, 2021
1ad4297
Remove pipe operator for old python versions
AlexAndorra Dec 3, 2021
a708e6d
Update tests
AlexAndorra Dec 3, 2021
e1c5125
Add test with discrete distrib & wrap in pytest.warns(None)
AlexAndorra Dec 7, 2021
bc9b543
Remove pipe operator for good
AlexAndorra Dec 7, 2021
18ad975
Fix TypeError in dist_params
AlexAndorra Dec 7, 2021
e92d6d8
Relax tolerance for tests
AlexAndorra Dec 7, 2021
94b406b
Force float64 config in find_optim_prior
AlexAndorra Dec 14, 2021
76dbb1f
Rename file name to func_utils.py
AlexAndorra Dec 14, 2021
53bfc00
Replace print statement by AttributeError
AlexAndorra Nov 30, 2021
77a0bb1
pre-commit formatting
AlexAndorra Nov 30, 2021
fd5f498
Mention in release notes
AlexAndorra Nov 30, 2021
171a4aa
Handle 1-param and 3-param distributions
AlexAndorra Dec 1, 2021
36b95cb
Update tests
AlexAndorra Dec 1, 2021
55138d9
Fix some wording
AlexAndorra Dec 1, 2021
4bed2cd
pre-commit formatting
AlexAndorra Dec 3, 2021
02d117b
Only raise UserWarning when mass_in_interval not optimal
AlexAndorra Dec 3, 2021
7742571
Raise NotImplementedError for non-scalar params
AlexAndorra Dec 3, 2021
8a6e0e7
Remove pipe operator for old python versions
AlexAndorra Dec 3, 2021
602391b
Update tests
AlexAndorra Dec 3, 2021
9bb14a3
Add test with discrete distrib & wrap in pytest.warns(None)
AlexAndorra Dec 7, 2021
ab0ef0f
Remove pipe operator for good
AlexAndorra Dec 7, 2021
58f5d56
Fix TypeError in dist_params
AlexAndorra Dec 7, 2021
a6c7f0d
Relax tolerance for tests
AlexAndorra Dec 7, 2021
c9c24d6
Force float64 config in find_optim_prior
AlexAndorra Dec 14, 2021
c75f8c9
Rename file name to func_utils.py
AlexAndorra Dec 14, 2021
3ffd7ff
Change optimization error function and refactor tests
ricardoV94 Dec 16, 2021
a1a6bdf
Use aesaraf.compile_pymc
ricardoV94 Dec 21, 2021
7cd0e55
Merge branch 'optim-prior' of https://github.com/pymc-devs/pymc into …
AlexAndorra Dec 22, 2021
1d868fa
Add and test AssertionError for mass value
AlexAndorra Dec 22, 2021
063bc96
Fix type error in warning message
AlexAndorra Dec 23, 2021
cb7908c
Split up Poisson test
AlexAndorra Dec 24, 2021
16ed438
Use scipy default for Exponential and reactivate tests
AlexAndorra Dec 24, 2021
1b84e18
Refactor Poisson tests
AlexAndorra Dec 24, 2021
6ea7861
Reduce Poisson test tol to 1% for float32
AlexAndorra Dec 25, 2021
d63b652
Remove Exponential logic
AlexAndorra Dec 27, 2021
37e6251
Rename function
AlexAndorra Dec 27, 2021
b912ac6
Refactor test functions names
AlexAndorra Dec 28, 2021
d4bce39
Use more precise exception for gradient
AlexAndorra Dec 30, 2021
9a51289
Don't catch TypeError
AlexAndorra Jan 3, 2022
90a88ff
Merge branch 'main' into optim-prior
AlexAndorra Jan 3, 2022
8b9ae6e
Remove specific Poisson test
AlexAndorra Jan 3, 2022
d53154a
Remove typo from old Poisson test
AlexAndorra Jan 3, 2022
1f42835
Put tests for constrained priors into their own file
AlexAndorra Jan 3, 2022
bad236c
Add code examples in docstrings
AlexAndorra Jan 4, 2022
d89e375
Merge branch 'main' into optim-prior
AlexAndorra Jan 4, 2022
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
@@ -114,7 +114,8 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
- Improve sampling, increase default number of particles [5229](https://github.com/pymc-devs/pymc3/pull/5229).
- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
- ...
- The new `pm.find_constrained_prior` function can be used to find optimized prior parameters of a distribution under some
  constraints (e.g. lower and upper bounds). See [#5231](https://github.com/pymc-devs/pymc/pull/5231).


### Internal changes
1 change: 1 addition & 0 deletions pymc/__init__.py
@@ -81,6 +81,7 @@ def __set_compiler_flags():
from pymc.distributions import *
from pymc.distributions import transforms
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.math import (
    expand_packed_triangular,
    invlogit,
138 changes: 138 additions & 0 deletions pymc/func_utils.py
@@ -0,0 +1,138 @@
# Copyright 2021 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from typing import Dict, Optional

import aesara.tensor as aet
import numpy as np

from scipy import optimize

import pymc as pm

__all__ = ["find_constrained_prior"]


def find_constrained_prior(
    distribution: pm.Distribution,
    lower: float,
    upper: float,
    init_guess: Dict[str, float],
    mass: float = 0.95,
    fixed_params: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
"""
Find optimal parameters to get `mass` % of probability
of `pm_dist` between `lower` and `upper`.
Note: only works for one- and two-parameter distributions, as there
are exactly two constraints. Fix some combination of parameters
if you want to use it on >=3-parameter distributions.

Parameters
----------
distribution : pm.Distribution
PyMC distribution you want to set a prior on.
Needs to have a ``logcdf`` method implemented in PyMC.
lower : float
Lower bound to get `mass` % of probability of `pm_dist`.
upper : float
Upper bound to get `mass` % of probability of `pm_dist`.
init_guess: Dict[str, float]
Initial guess for ``scipy.optimize.least_squares`` to find the
optimal parameters of `pm_dist` fitting the interval constraint.
Must be a dictionary with the name of the PyMC distribution's
parameter as keys and the initial guess as values.
mass: float, default to 0.95
Share of the probability mass we want between ``lower`` and ``upper``.
Defaults to 95%.
fixed_params: Dict[str, float], Optional, default None
Only used when `pm_dist` has at least three parameters.
Dictionary of fixed parameters, so that there are only 2 to optimize.
For instance, for a StudenT, you fix nu to a constant and get the optimized
mu and sigma.

Returns
-------
The optimized distribution parameters as a dictionary with the parameters'
name as key and the optimized value as value.

[Inline review thread on this docstring]
Member: Follow-up PR (not to be sadistic with this one): we should add a code example in the docstrings.

Contributor Author: I'm not preaching for my choir here, but I actually should add that here. Don't merge in the meantime. Will ping when done.

Contributor Author: Done @ricardoV94 -- we can merge once tests pass.

    """
    assert 0.01 <= mass <= 0.99, (
        "This function optimizes the mass of the given distribution +/- "
        f"1%, so `mass` has to be between 0.01 and 0.99. You provided {mass}."
    )

    # exit when any parameter is not scalar:
    if np.any(np.asarray(distribution.rv_op.ndims_params) != 0):
        raise NotImplementedError(
            "`pm.find_constrained_prior` does not work with non-scalar parameters yet.\n"
            "Feel free to open a pull request on PyMC repo if you really need this feature."
        )

    dist_params = aet.vector("dist_params")
    params_to_optim = {
        arg_name: dist_params[i] for arg_name, i in zip(init_guess.keys(), range(len(init_guess)))
    }

    if fixed_params is not None:
        params_to_optim.update(fixed_params)

    dist = distribution.dist(**params_to_optim)

    try:
        logcdf_lower = pm.logcdf(dist, pm.floatX(lower))
        logcdf_upper = pm.logcdf(dist, pm.floatX(upper))
    except AttributeError:
        raise AttributeError(
            f"You cannot use `find_constrained_prior` with {distribution} -- it doesn't have a logcdf "
            "method yet.\nOpen an issue or, even better, a pull request on PyMC repo if you really "
            "need it."
        )

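    # Optimization target: how far the probability mass currently between
    # `lower` and `upper` is from the requested `mass`.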
    cdf_error = (pm.math.exp(logcdf_upper) - pm.math.exp(logcdf_lower)) - mass
    cdf_error_fn = pm.aesaraf.compile_pymc([dist_params], cdf_error, allow_input_downcast=True)

    try:
        aesara_jac = pm.gradient(cdf_error, [dist_params])
        jac = pm.aesaraf.compile_pymc([dist_params], aesara_jac, allow_input_downcast=True)
    # when PyMC cannot compute the gradient
    # TODO: catch the specific gradient-not-implemented exceptions instead of bare Exception
    except Exception:

[Inline review thread on this except clause]
ricardoV94 (Member), Dec 29, 2021: I think the exception we want here is the one found in aesara.gradient.NullTypeGradError, as well as NotImplementedError.

Suggested change:
-    except Exception:
+    except (NotImplementedError, NullTypeGradError):
Contributor Author: The exception that's actually thrown is a TypeError, by the Aesara grad method (line 501 of aesara/gradient.py):

    if cost is not None and cost.ndim != 0:
        raise TypeError("cost must be a scalar.")

I added that error to the except clause and tests pass locally. aesara.gradient.NullTypeGradError and NotImplementedError don't seem to be raised, but I kept them in case they come up in other cases we may have forgotten.

ricardoV94 (Member), Dec 30, 2021: We shouldn't catch that TypeError. That means we produced a wrong input to aesara grad. The other two exceptions are the ones that (should) appear when a grad is not implemented for an Op.

AlexAndorra (Contributor, Author), Jan 1, 2022: That means we should make these two exceptions appear then, shouldn't we? Because they are not raised right now -- only the TypeError is raised. (Here's to my first GH comment of the year 🥂)

ricardoV94 (Member), Jan 1, 2022:
> That means we should make these two exceptions appear then, shouldn't we? Because they are not raised right now -- only the TypeError is raised.

Those two exceptions mean there is no grad implemented for some Op in the cdf, which can very well happen, and it's a good reason to silently default to the scipy approximation. The TypeError, on the other hand, should not be caught, as that means we did something wrong.

In fact, there was a point during this PR when it was always silently defaulting to the scipy approximation, because we were passing two values to grad and suppressing the TypeError.
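
For reference, a minimal sketch of the handling this thread converges on, matching the suggested change above and the `jac = "2-point"` fallback visible in the diff below (illustrative, not the exact merged code):

    from aesara.gradient import NullTypeGradError

    try:
        aesara_jac = pm.gradient(cdf_error, [dist_params])
        jac = pm.aesaraf.compile_pymc([dist_params], aesara_jac, allow_input_downcast=True)
    except (NotImplementedError, NullTypeGradError):
        # no symbolic gradient for some Op in the logcdf graph:
        # fall back to scipy's finite-difference approximation
        jac = "2-point"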

ricardoV94 (Member), Jan 3, 2022: That seems to be the case. Locally it is passing in float64/float32 at 1e-5 precision, so we don't need the separate test just for the Poisson anymore.

AlexAndorra (Contributor, Author): Ha ha damn, I just pushed without that change. Tests are indeed passing locally. Gonna refactor the tests and push again.

AlexAndorra (Contributor, Author): Hot damn, tests are passing locally 🔥 Pushed!
Why does the symbolic gradient help so much with numerical errors?

ricardoV94 (Member): Because the logcdf uses gammaincc, whose gradient is notoriously tricky. We somewhat recently added a numerically stable(r) implementation to Aesara.
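
A hypothetical spot-check of that claim (not from the PR), reusing the `pm.logcdf` and `pm.gradient` helpers that already appear in this diff -- the Gamma logcdf goes through gammaincc, so building its symbolic gradient exercises that Aesara implementation:

    import aesara.tensor as aet
    import pymc as pm

    alpha = aet.scalar("alpha")
    # the Gamma logcdf involves gammaincc; pm.gradient builds its symbolic gradient
    logcdf = pm.logcdf(pm.Gamma.dist(alpha=alpha, beta=1.0), 0.5)
    grad = pm.gradient(logcdf, [alpha])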


        jac = "2-point"

    opt = optimize.least_squares(cdf_error_fn, x0=list(init_guess.values()), jac=jac)
    if not opt.success:
        raise ValueError("Optimization of parameters failed.")

    # save optimal parameters
    opt_params = {
        param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
    }
    if fixed_params is not None:
        opt_params.update(fixed_params)

    # check mass in interval is not too far from `mass`
    opt_dist = distribution.dist(**opt_params)
    mass_in_interval = (
        pm.math.exp(pm.logcdf(opt_dist, upper)) - pm.math.exp(pm.logcdf(opt_dist, lower))
    ).eval()
    if np.abs(mass_in_interval - mass) > 0.01:
        warnings.warn(
            f"Final optimization has {(mass_in_interval if mass_in_interval.ndim < 1 else mass_in_interval[0]) * 100:.0f}% of probability mass between "
            f"{lower} and {upper} instead of the requested {mass * 100:.0f}%.\n"
            "You may need to use a more flexible distribution, change the fixed parameters in the "
            "`fixed_params` dictionary, or provide better initial guesses."
        )

    return opt_params
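
The code examples added to this docstring in commit bad236c are not visible in this 46-commit diff view; a minimal usage sketch, with values borrowed from the test cases below:

    import pymc as pm

    # Find Gamma parameters putting ~95% of the probability mass in [0.1, 0.4]
    opt_params = pm.find_constrained_prior(
        pm.Gamma,
        lower=0.1,
        upper=0.4,
        mass=0.95,
        init_guess={"alpha": 1, "beta": 10},
    )

    # For a three-parameter distribution like StudentT, fix one parameter
    # so only two remain to optimize:
    opt_params = pm.find_constrained_prior(
        pm.StudentT,
        lower=0.1,
        upper=0.4,
        mass=0.95,
        init_guess={"mu": 10, "sigma": 3},
        fixed_params={"nu": 7},
    )
    # opt_params maps parameter names to optimized values, e.g. {"mu": ..., "sigma": ...}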
127 changes: 127 additions & 0 deletions pymc/tests/test_util.py
@@ -142,3 +142,130 @@ def fn(a=UNSET):
    help(fn)
    captured = capsys.readouterr()
    assert "a=UNSET" in captured.out


@pytest.mark.parametrize(
    "distribution, lower, upper, init_guess, fixed_params",
    [
        (pm.Gamma, 0.1, 0.4, {"alpha": 1, "beta": 10}, {}),
        (pm.Normal, 155, 180, {"mu": 170, "sigma": 3}, {}),
        (pm.StudentT, 0.1, 0.4, {"mu": 10, "sigma": 3}, {"nu": 7}),
        (pm.StudentT, 0, 1, {"mu": 5, "sigma": 2, "nu": 7}, {}),
        # (pm.Exponential, 0, 1, {"lam": 1}, {}), PyMC Exponential gradient is failing miserably, need to figure out why
        (pm.HalfNormal, 0, 1, {"sigma": 1}, {}),
        (pm.Binomial, 0, 8, {"p": 0.5}, {"n": 10}),
    ],
)
@pytest.mark.parametrize("mass", [0.5, 0.75, 0.95])
def test_find_constrained_prior(distribution, lower, upper, init_guess, fixed_params, mass):
    with pytest.warns(None) as record:
        opt_params = pm.find_constrained_prior(
            distribution,
            lower=lower,
            upper=upper,
            mass=mass,
            init_guess=init_guess,
            fixed_params=fixed_params,
        )
    assert len(record) == 0

    opt_distribution = distribution.dist(**opt_params)
    mass_in_interval = (
        pm.math.exp(pm.logcdf(opt_distribution, upper))
        - pm.math.exp(pm.logcdf(opt_distribution, lower))
    ).eval()
    assert np.abs(mass_in_interval - mass) <= 1e-5


# test Poisson separately -- hard to optimize precisely when float32
@pytest.mark.parametrize(
    "lower, upper, init_guess",
    [
        (1, 15, {"mu": 10}),
        (19, 41, {"mu": 30}),
    ],
)
def test_constrained_prior_poisson(lower, upper, init_guess):
    distribution = pm.Poisson
    mass = 0.95
    with pytest.warns(None) as record:
        opt_params = pm.find_constrained_prior(
            distribution,
            lower=lower,
            upper=upper,
            init_guess=init_guess,
        )
    assert len(record) == 0

    opt_distribution = distribution.dist(**opt_params)
    mass_in_interval = (
        pm.math.exp(pm.logcdf(opt_distribution, upper))
        - pm.math.exp(pm.logcdf(opt_distribution, lower))
    ).eval()
    assert np.abs(mass_in_interval - mass) <= 1e-2  # reduce to 1% tolerance for float32


@pytest.mark.parametrize(
    "distribution, lower, upper, init_guess, fixed_params",
    [
        (pm.Gamma, 0.1, 0.4, {"alpha": 1}, {"beta": 10}),
        (pm.Exponential, 0.1, 1, {"lam": 1}, {}),
        (pm.Binomial, 0, 2, {"p": 0.8}, {"n": 10}),
    ],
)
def test_find_constrained_prior_error_too_large(
    distribution, lower, upper, init_guess, fixed_params
):
    with pytest.warns(UserWarning, match="instead of the requested 95%"):
        pm.find_constrained_prior(
            distribution,
            lower=lower,
            upper=upper,
            mass=0.95,
            init_guess=init_guess,
            fixed_params=fixed_params,
        )


def test_find_constrained_prior_input_errors():
    # missing param
    with pytest.raises(TypeError, match="required positional argument"):
        pm.find_constrained_prior(
            pm.StudentT,
            lower=0.1,
            upper=0.4,
            mass=0.95,
            init_guess={"mu": 170, "sigma": 3},
        )

    # mass too high
    with pytest.raises(AssertionError, match="has to be between 0.01 and 0.99"):
        pm.find_constrained_prior(
            pm.StudentT,
            lower=0.1,
            upper=0.4,
            mass=0.995,
            init_guess={"mu": 170, "sigma": 3},
            fixed_params={"nu": 7},
        )

    # mass too low
    with pytest.raises(AssertionError, match="has to be between 0.01 and 0.99"):
        pm.find_constrained_prior(
            pm.StudentT,
            lower=0.1,
            upper=0.4,
            mass=0.005,
            init_guess={"mu": 170, "sigma": 3},
            fixed_params={"nu": 7},
        )

    # non-scalar params
    with pytest.raises(NotImplementedError, match="does not work with non-scalar parameters yet"):
        pm.find_constrained_prior(
            pm.MvNormal,
            lower=0,
            upper=1,
            mass=0.95,
            init_guess={"mu": 5, "cov": np.asarray([[1, 0.2], [0.2, 1]])},
        )