
Test model logp before starting any MCMC chains #4211

Merged: 10 commits into pymc-devs:master, Nov 27, 2020

Conversation

StephenHogg
Contributor

This PR addresses #4116 by making find_MAP and sample check their starting conditions before running any chains. I probably need to work out which linting settings this repo uses, because it seems like a fair bit of formatting has changed.
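
Roughly, the intended behaviour is something like the following (a hypothetical toy model for illustration; the exact error type and message were settled later in the thread):

import numpy as np
import pymc3 as pm

data = np.array([1.0, 2.0, np.nan])  # a NaN observation makes the model logp non-finite

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    pm.Normal("obs", mu=mu, sigma=1, observed=data)
    # With the check in place, this should fail immediately with an informative
    # error about the starting point, instead of erroring deep inside the chains.
    trace = pm.sample()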

@michaelosthege
Member

I'm sorry, but can you do the PR without applying black first? I appreciate the intent and it's fine to do at the end, but at this point it makes pinpointing the relevant changes close to impossible.

@MarcoGorelli
Contributor

MarcoGorelli commented Nov 9, 2020

I'm sorry, but can you do the PR without applying black first? I appreciate the intent and it's fine to do at the end, but at this point it makes pinpointing the relevant changes close to impossible.

I think the opposite has happened - black has already been applied to the entire codebase, and it looks like this PR has reverted some of its changes. For example, black wouldn't do this:

+ from .util import (chains_and_samples, dataset_to_point_dict, get_default_varnames, get_untransformed_name,
+                    is_transformed_name, update_start_vals)
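
For comparison, black would format that import with one name per line inside the parentheses, roughly:

from .util import (
    chains_and_samples,
    dataset_to_point_dict,
    get_default_varnames,
    get_untransformed_name,
    is_transformed_name,
    update_start_vals,
)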

@StephenHogg please see the Python Style guide for this repo

@twiecki changed the title from "PR to fix #4116" to "Test model logp before starting any MCMC chains" on Nov 9, 2020
@StephenHogg
Contributor Author

Sorry about this - had auto-linting on in my GUI and didn't realise. Have a look now, hopefully it's clearer.

for chain_start_vals in start:
    update_start_vals(chain_start_vals, model.test_point, model)

start_points = [start] if isinstance(start, dict) else start
Member

If I remember correctly, the downstream code treats a list of start as "start points for each chain", which could explain your index error.
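
In other words, the two accepted forms look something like this (an illustration of the convention; "mu" is a placeholder variable name):

# a single dict is one start point, re-used for every chain:
start = {"mu": 0.0}

# a list of dicts is one start point per chain, so its length must match the number of chains:
start = [{"mu": 0.0}, {"mu": 1.0}]  # e.g. for chains=2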

@StephenHogg
Contributor Author

StephenHogg commented Nov 10, 2020 via email

@michaelosthege
Member

The error you posted in #4116 could also be caused by invalid model test points. It could be that not all distributions have test points.
Try copying the model from the failing test case into a notebook to inspect it.
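
For example, something along these lines (here model is the pm.Model copied from the failing test case):

print(model.test_point)          # default starting values for each free variable
print(model.check_test_point())  # per-variable log-probability; look for -inf or nan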

@StephenHogg
Contributor Author

Just for clarity - what's the path forward here? Sorry for the bother

@ColCarroll left a comment
Member

I think running this check is generally a good idea, but I think it needs to be put in at the right place. If you look at sampling.py, this change will get rid of a bunch of downstream code (there are some if start is None checks), but I don't think in a great way.

What if we had a _check_start_point function, and it got called at the end of init_nuts? I think it would contain model.check_test_point, and the nice error messages, but it would not do anything to the start argument passed to it.
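
A sketch of that suggestion (names and message text are illustrative; SamplingError is assumed to come from pymc3.exceptions):

import numpy as np
from pymc3.exceptions import SamplingError

def _check_start_point(model, start=None):
    """Raise if the model logp is not finite at `start`, without modifying `start`."""
    point = start if start is not None else model.test_point
    elem_logp = model.check_test_point(test_point=point)  # per-variable logp as a pandas Series
    if not np.isfinite(elem_logp.sum()):
        raise SamplingError(
            "Initial evaluation of model at starting point failed!\n"
            f"Starting values:\n{point}\n\nLog-probability per variable:\n{elem_logp}"
        )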

Review comments on pymc3/sampling.py and pymc3/tuning/starting.py (outdated, resolved)
@ColCarroll
Member

Sorry, I just read the attached issue, and it seems like that was steering @StephenHogg to put the changes where they are. Interested to hear what you and @michaelosthege think!

@StephenHogg
Contributor Author

As a first time contributor I defer to Michael! :)

@michaelosthege
Member

As a first time contributor I defer to Michael! :)

I think Colin is right: the block could easily become its own function. That also makes it easier to test or improve.

@StephenHogg
Contributor Author

Ok - are you also saying the new function should be called at the end of init_nuts instead of where it is now?

@michaelosthege
Member

Ok - are you also saying the new function should be called at the end of init_nuts instead of where it is now?

No, the NUTS initialization often suffers from inf/NaN and I think it's more useful to check that before initialization.

@ColCarroll
Member

The "main path" logic right now is:

  • start goes into the function
  • init_nuts is called, and comes up with its own start_
    • usually this is by taking the test point and adding Uniform(-1, 1) noise (see the sketch after this list)
    • start_ in init_nuts is used to initialize a mean for a running variance estimate
  • start_ is overwritten in the main loop whenever start is not None
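
For reference, the jitter-based initialization mentioned above amounts to roughly this (a simplified sketch of what init="jitter+adapt_diag" does, with chains and model taken from the surrounding pm.sample call; not the exact init_nuts code):

import numpy as np

start_ = []
for _ in range(chains):
    # start from a copy of the model's test point
    point = {name: np.copy(val) for name, val in model.test_point.items()}
    for name, val in point.items():
        point[name] = val + 2 * np.random.rand(*np.shape(val)) - 1  # Uniform(-1, 1) jitter
    start_.append(point)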

I think Michael's right that start_ also needs to be checked, since it is secretly being used internally (which is bad, but shouldn't be fixed here). But looking at the code, I think it can be done right after returning from init_nuts, otherwise there would be a ton of checks.

Concretely, I'm suggesting (sketched in code after this list):

  • a function that checks a user-supplied start value very early (around line 460, where there's already a check on start)
  • again checking start_ after it returns from init_nuts (around line 490)
  • perhaps in a followup PR, having the function check other initialization schemes (there's an else branch for discrete models I've been ignoring...)
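
In sampling.py terms the placement would look roughly like this (heavily simplified; check_start_vals stands in for the helper being discussed):

if start is not None:
    check_start_vals(start, model)    # validate the user-supplied start early (around line 460)

start_, step = init_nuts(init=init, chains=chains, n_init=n_init, model=model)
check_start_vals(start_, model)       # validate the generated start_ as well (around line 490)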

@StephenHogg
Contributor Author

@michaelosthege any more thoughts on the above? Would like to make sure I'm clear about what I'm coding up before starting again

@michaelosthege
Member

@StephenHogg listen to Colin on this one. He's much more literate in what the NUTS code is actually doing. With those checks in their own function, you can run them before & after NUTS initialization.
It costs a few model evaluations, but hey, this is MCMC sampling and it gives us more interpretable errors.

@StephenHogg
Contributor Author

I've shifted this into a function called pm.util.check_start_vals. In the process, I've spotted (and fixed) a bug in model.check_test_point: the function's test_point argument was always being ignored.
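
The bug was roughly of this shape (a simplified sketch of the Model.check_test_point method, where Series is pandas.Series and np is numpy; not the verbatim diff):

def check_test_point(self, test_point=None, round_vals=2):
    if test_point is None:
        test_point = self.test_point
    return Series(
        # before the fix this evaluated RV.logp(self.test_point), silently
        # ignoring the test_point argument; now it uses the point passed in
        {RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
        name="Log-probability of test_point",
    )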

@StephenHogg
Contributor Author

StephenHogg commented Nov 14, 2020

Looking at the test output, it seems like a few other tests (e.g. test_nuts_error_reporting in test_hmc.py) are actually broken by the changes in this branch. Any guidance on how to handle this would be appreciated.

@michaelosthege added this to the 3.10 milestone on Nov 14, 2020
@StephenHogg
Contributor Author

Here's the output I get from pytest at this point, if that helps. Some of these failures are a bit mystifying; I'm not sure why I'd be getting a maximum recursion depth error on a test I haven't touched, for instance. I'll push one more change to format the error string a bit more nicely, but after that I think I'm stuck for now.

@ColCarroll left a comment
Member

This looks nice! I took a look at most of the test failures, and they're surprisingly helpful. Feel free to ping again if you need more help, but I think this is close:

  • Delete test_hmc.py:test_nuts_error_reporting. Your check is a better one for the same behavior.
  • test_sampling.py:test_deterministic_of_observed looks like a flake. Let's ignore that and hope it goes away. If it doesn't, make the rtol bigger.
  • test_examples.py::TestLatentOccupancy::test_run is interesting, and looks like a legit failure you found! In this case, the likelihood is passing parameters in the wrong order. It should be
    pm.ZeroInflatedPoisson("y", psi, theta, observed=y) (note that psi and theta are switched). I imagine it was passing because the multipart sampling got everything to a reasonable place.
  • Two failures in pymc3/tests/test_step.py can also either be deleted or ported to the new exception you throw -- it looks like we have a SamplingError defined, which may be a good, specific error to raise instead of a ValueError (see the sketch after this list).
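
For the test_step.py side, porting to the new exception could look something like this (an illustrative test body; the model and the matched message are placeholders, not the repo's actual tests):

import pytest
import pymc3 as pm
from pymc3.exceptions import SamplingError

def test_bad_init_fails_fast():
    with pm.Model():
        # transform=None keeps the invalid testval, so the model logp is -inf at the start
        pm.HalfNormal("a", sigma=1, testval=-1, transform=None)
        with pytest.raises(SamplingError) as error:
            pm.sample(init=None, chains=1, random_seed=1)
        error.match("Initial evaluation")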

Review comment on pymc3/util.py (outdated, resolved)
@StephenHogg
Contributor Author

StephenHogg commented Nov 15, 2020

The only thing still failing at this point is one test in test_step.py, which is expecting a ParallelSamplingError while the new check just raises a SamplingError. I'm happy to shift this over, but wanted to check that I'm not violating the spirit of the test?

Edit: the flaky test is also not passing here, though it definitely passes locally.

@StephenHogg
Contributor Author

StephenHogg commented Nov 17, 2020

Hi @ColCarroll - the only thing that still fails now is test_sampling.py:test_deterministic_of_observed when floatX is set to float32 (it doesn't appear to be a problem with 64-bit floats). I can increase the rtol here, but it's already 1e-3 - got any preference?
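
Concretely, the change would be something like this (illustrative; the array names are placeholders and 1e-2 is just an example value):

import numpy.testing as npt
import theano

# widen the tolerance only when running in float32
rtol = 1e-2 if theano.config.floatX == "float32" else 1e-3
npt.assert_allclose(expected_values, sampled_values, rtol=rtol)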

@ColCarroll
Member

This looks great! What if you loosen the tolerances on the test, but also open a bug and mention that it got worse when this PR was merged? That's very strange...

I think the last two things are:

  • Remove the draft status of this PR
  • Add a line to RELEASE_NOTES.md

@codecov

codecov bot commented Nov 19, 2020

Codecov Report

Merging #4211 (ddf9fc3) into master (2723b6c) will decrease coverage by 0.18%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4211      +/-   ##
==========================================
- Coverage   88.14%   87.95%   -0.19%     
==========================================
  Files          87       87              
  Lines       14243    14248       +5     
==========================================
- Hits        12554    12532      -22     
- Misses       1689     1716      +27     
Impacted Files Coverage Δ
pymc3/model.py 89.05% <ø> (-0.12%) ⬇️
pymc3/sampling.py 85.39% <100.00%> (-1.10%) ⬇️
pymc3/tuning/starting.py 83.72% <100.00%> (+1.90%) ⬆️
pymc3/step_methods/hmc/base_hmc.py 90.83% <0.00%> (-7.50%) ⬇️
pymc3/backends/report.py 90.90% <0.00%> (-2.10%) ⬇️
pymc3/parallel_sampling.py 86.79% <0.00%> (-1.58%) ⬇️
pymc3/step_methods/hmc/quadpotential.py 79.03% <0.00%> (-0.54%) ⬇️

@StephenHogg marked this pull request as ready for review on November 19, 2020
@twiecki
Member

twiecki commented Nov 25, 2020

Seems like there are still conflicts though:

@StephenHogg
Contributor Author

Seems like there are still conflicts though:

Yes, that's what I'm saying - I can either leave the conflict in, in which case I can't merge, or I can resolve the conflict in which case linting fails because there's an unneeded import. It's a Catch-22.

Review comment on pymc3/util.py (resolved)
@MarcoGorelli
Contributor

MarcoGorelli commented Nov 25, 2020

Seems like there are still conflicts though:

Yes, that's what I'm saying - I can either leave the conflict in, in which case I can't merge, or I can resolve the conflict in which case linting fails because there's an unneeded import. It's a Catch-22.

Shouldn't be a catch-22 😄

Can you try

git fetch --all --prune
git merge upstream/master

Then, in pymc3/util.py, you'll see something like

<<<<<<< HEAD
=======
from numpy import ndarray
>>>>>>> upstream/master

Change it to nothing - that is, delete the conflict markers and the incoming import line (i.e., choose the current changes and ignore the incoming ones).

Then,

git add -u
git commit
git push -u origin HEAD

For more on git, I heartily recommend the Pro Git book.

@StephenHogg
Contributor Author

As before - there's a mysterious new test failure

@ColCarroll
Member

Wow, CI got changed under you!

The fix to test_examples.py::TestLatentOccupancy::test_run got reverted during your merge, I think: pm.ZeroInflatedPoisson("y", theta, psi, observed=y) should be pm.ZeroInflatedPoisson("y", psi, theta, observed=y)

@StephenHogg
Contributor Author

This new error doesn't seem to have much to do with the code I wrote? Not sure, though

@MarcoGorelli
Contributor

This new error doesn't seem to have much to do with the code I wrote? Not sure, though

Can you check if that test passes when you run it locally?

pytest pymc3/tests/test_step.py::TestMLDA::test_acceptance_rate_against_coarseness

@StephenHogg
Contributor Author

@MarcoGorelli it passes locally; I had to update theano-pymc to version 1.0.11 to do it. Here's the output:

(pymc3) shogg@192:~/git/pymc3$ pytest pymc3/tests/test_step.py::TestMLDA::test_acceptance_rate_against_coarseness
================================================================================================================================== test session starts ==================================================================================================================================
platform darwin -- Python 3.7.5, pytest-5.0.1, py-1.9.0, pluggy-0.13.1
rootdir: /Users/shogg/git/pymc3, inifile: setup.cfg
collected 1 item                                                                                                                                                                                                                                                                        

pymc3/tests/test_step.py .                                                                                                                                                                                                                                                        [100%]

=================================================================================================================================== warnings summary ====================================================================================================================================
pymc3/tests/test_step.py::TestMLDA::test_acceptance_rate_against_coarseness
pymc3/tests/test_step.py::TestMLDA::test_acceptance_rate_against_coarseness
pymc3/tests/test_step.py::TestMLDA::test_acceptance_rate_against_coarseness
  /Users/shogg/git/pymc3/pymc3/step_methods/mlda.py:383: UserWarning: The MLDA implementation in PyMC3 is still immature. You should be particularly critical of its results.
    "The MLDA implementation in PyMC3 is still immature. You should be particularly critical of its results."

-- Docs: https://docs.pytest.org/en/latest/warnings.html
========================================================================================================================= 1 passed, 3 warnings in 13.26 seconds =========================================================================================================================

@StephenHogg
Contributor Author

Wait, all checks have passed now? Maybe the test was flaky?

@twiecki merged commit 22c079c into pymc-devs:master on Nov 27, 2020
@twiecki
Member

twiecki commented Nov 27, 2020

Yes, I think so. Thanks @StephenHogg!

@StephenHogg
Contributor Author

Whew, thanks

@ColCarroll
Member

+1 Thanks for sticking with us, @StephenHogg -- this was trickier than expected, but I think it will really improve lots of people's experiences.

ricardoV94 added a commit to ricardoV94/pymc that referenced this pull request Dec 2, 2020
Spaak pushed a commit that referenced this pull request Dec 5, 2020
* - Fix regression caused by #4211

* - Add test to make sure jitter is being applied to chains starting points by default

* - Import appropriate empty context for python < 3.7

* - Apply black formatting

* - Change the second check_start_vals to explicitly run on the newly assigned start variable.

* - Improve test documentation and add a new condition

* Use monkeypatch for more robust test

* - Black formatting, once again...