Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pm.Simulator (2nd attempt) #4877

Closed
wants to merge 2 commits into from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jul 23, 2021

The previous attempt at this in #4802 revealed big issues between pickling and dynamically created classes (which subsist even after #4858). The alternative presented in this PR is a bit more cumbersome from the user-side but at least avoids these issues.

Here is a minimal example demonstrating the new API:

import numpy as np
import pymc3 as pm

data = np.random.normal(0, 1, size=10)

def my_simulator_fn(rng, loc, scale, size):
    return rng.normal(loc, scale, size=size)

class MySimulatorRV(pm.SimulatorRV):
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    fn = my_simulator_fn
    distance = "gaussian"
    sum_stat = "sort"

my_simulator = MySimulatorRV()

with pm.Model() as m:
    simulator = pm.Simulator("sim", my_simulator, 0, 1, epsilon=1.0, observed=data)

If anyone has better suggestions about how to avoid pickling issues when creating dynamic classes please let me know! For example the classes used in these tests had to be defined outside of TestSimulator.setup_class.

@codecov
Copy link

codecov bot commented Jul 23, 2021

Codecov Report

Merging #4877 (572cab6) into main (819f045) will increase coverage by 0.37%.
The diff coverage is 96.38%.

❗ Current head 572cab6 differs from pull request most recent head 17d9f2e. Consider uploading reports for the commit 17d9f2e to get more accurate results
Impacted file tree graph

@@            Coverage Diff             @@
##             main    #4877      +/-   ##
==========================================
+ Coverage   73.16%   73.53%   +0.37%     
==========================================
  Files          86       86              
  Lines       13838    13809      -29     
==========================================
+ Hits        10125    10155      +30     
+ Misses       3713     3654      -59     
Impacted Files Coverage Δ
pymc3/distributions/simulator.py 88.60% <95.62%> (+62.74%) ⬆️
pymc3/aesaraf.py 91.34% <100.00%> (+0.07%) ⬆️
pymc3/distributions/__init__.py 100.00% <100.00%> (ø)
pymc3/distributions/distribution.py 84.12% <100.00%> (+3.69%) ⬆️
pymc3/smc/sample_smc.py 96.87% <100.00%> (+4.33%) ⬆️
pymc3/smc/smc.py 99.31% <100.00%> (+26.71%) ⬆️
pymc3/distributions/multivariate.py 63.84% <0.00%> (-7.62%) ⬇️
pymc3/tests/conftest.py 88.23% <0.00%> (-2.25%) ⬇️
pymc3/step_methods/hmc/base_hmc.py 90.24% <0.00%> (-0.82%) ⬇️

@ricardoV94 ricardoV94 force-pushed the restore_abc_alt branch 2 times, most recently from 8d72685 to 293aac1 Compare July 23, 2021 14:05
pymc3/aesaraf.py Outdated Show resolved Hide resolved
pymc3/aesaraf.py Show resolved Hide resolved
pymc3/distributions/simulator.py Show resolved Hide resolved
pymc3/distributions/simulator.py Outdated Show resolved Hide resolved
pymc3/tests/test_smc.py Show resolved Hide resolved
@ricardoV94
Copy link
Member Author

I decided to leave epsilon as an explicit argument to make it easier to manipulate (e.g., if we want our samplers to tune it), but can be a class attribute like distance and sum_stat if we don't think that will be ever used.

I can see it as being confusing like this...

@junpenglao
Copy link
Member

Instead of subclassing and initializing the subclass, would something like:

my_simulator = pm.SimulatorRV(
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    fn = my_simulator_fn
    distance = "gaussian"
    sum_stat = "sort"
)

works?

@michaelosthege
Copy link
Member

Instead of subclassing and initializing the subclass, would something like:

my_simulator = pm.SimulatorRV(
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    fn = my_simulator_fn
    distance = "gaussian"
    sum_stat = "sort"
)

works?

The thing needs to be a class, so pm.make_simulator() would have to create a class inside and return that. But these kinds of class definitions often cause problems with pickling.
But I agree that the API looks much nicer. @ricardoV94 what do you think?

@junpenglao
Copy link
Member

The thing needs to be a class, so pm.make_simulator() would have to create a class inside and return that. But these kinds of class definitions often cause problems with pickling.

I am assuming that as long as pm.make_simulator() is called OUTSIDE of the with context we should be able to pickle it - @ricardoV94 do we have some minimal reproducible example that is easy to play with?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jul 26, 2021

The thing needs to be a class, so pm.make_simulator() would have to create a class inside and return that. But these kinds of class definitions often cause problems with pickling.

I am assuming that as long as pm.make_simulator() is called OUTSIDE of the with context we should be able to pickle it - @ricardoV94 do we have some minimal reproducible example that is easy to play with?

Here is a minimal gist that implements the current API and has some checks: https://gist.github.com/ricardoV94/b632085b20be716b87fd146609168090

And here is another gist with a previous iteration on these ideas (using pickle instead of the more flexible cloudpickle that we are now using in PyMC3): https://gist.github.com/ricardoV94/2bb59a2ac18a29f501f5511c9671ebbc

Copy link
Member

@aloctavodia aloctavodia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a few nitpicks

@@ -157,6 +147,29 @@ def sample_smc(
%282007%29133:7%28816%29>`__
"""

if kernel is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be deprecated, we still want to have kernels.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I know. We should just deprecate the "keywords" for the time being.

DeprecationWarning,
stacklevel=2,
)
if save_sim_data is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still want this, as doing pm.sample_posterior_predictive could be potentially too expensive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer have access to the simulated data in the logp graph

return posterior, {modelcontext(model).observed_RVs[0].name: np.array(sim_data)}
else:
return posterior
return idata if return_inferencedata else trace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when is the trace going to be deprecated? maybe we could only return inferencedata

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would only deprecate it when it is deprecated in pm.sample

def _sum_stat(cls, value):
return cls.sum_stat(value)


class Simulator(NoDistribution):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick, if we are going to have a Simulator and a SimulatorRV, maybe the first one should be renamed to PseudoLikelihood, SimulatedLikelihood AbcLikelihood or something similar. One counterargument to this proposal, is that if we are going to distinguish between simulator and pseudolikelihood, the distance and summary statistics should be part of the later not the former.

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that was one of the ideas, the SimulatorRV would simply be concerned with the random draws and the Pseudolikelihood would take care of the logp factor.

The downside is that we don't yet know how to create a "dynamic" logp using the optional user defined sum_stat and distance functions. It means we might need to have users subclass not only the SimulatorRV but also the Pseudolikelihood if they want to use non-default functions. If someone figures out #4831, then we could simply copy their strategy.

Defining the functions in the SimulatorRV was just an ugly hack to avoid forcing users to create two new classes...

Comment on lines +368 to +371
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What conclusion should a user make from from this warning?
Is it serious? If so we should raise. Otherwise maybe just _log.warn()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably raise

@lucianopaz
Copy link
Contributor

Instead of subclassing and initializing the subclass, would something like:

my_simulator = pm.SimulatorRV(
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    fn = my_simulator_fn
    distance = "gaussian"
    sum_stat = "sort"
)

works?

The thing needs to be a class, so pm.make_simulator() would have to create a class inside and return that. But these kinds of class definitions often cause problems with pickling.
But I agree that the API looks much nicer. @ricardoV94 what do you think?

Could this be made to work with pickle if we defined SimulatorRV as a metaclass or provide a __reduce__ method to it? Something like what's done here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
help wanted SMC Sequential Monte Carlo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants