
Higher level interface for predictions and arviz integration #349

Closed
neerajprad opened this issue Sep 24, 2019 · 10 comments · Fixed by #369

neerajprad commented Sep 24, 2019

ArviZ integration

In arviz-devs/arviz#811, I already made an attempt at this integration. As demonstrated in the InferenceData introduction, an InferenceData object has several groups (prior/posterior/predictive/log_lik/sample_stats/observed_data/...). Currently, users can do something like the NumPyro section in the ArviZ cookbook. az.from_numpyro requires 3 args: posterior (which is mcmc), prior (i.e. prior predictive samples), and posterior_predictive. Currently, we lack a handy utility for prior predictive sampling.

Internally, we lack the observed_data and log_lik groups. The reason is that, given an MCMC instance, we do not store args and kwargs, so even though we can access the model, we are still unable to run it from arviz. observed_data is required for plot_ppc or plot_loo_pit. log_lik is useful for many things, including some Bayesian model comparison criteria.

Solutions:

  • store the args and kwargs of the run method in the MCMC state.
  • add prior_predictive and log_lik_predictive utilities.
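To make the second bullet concrete, here is a minimal pure-Python sketch of a prior predictive utility built on the trace/substitute pattern (the helper names and toy model are illustrative, not numpyro's actual API):

```python
import random

def trace_model(model, substitutions=None):
    """Run `model` once, recording every sample site. Sites present in
    `substitutions` are fixed to the given value instead of being drawn
    (roughly what numpyro's substitute/trace handlers do)."""
    substitutions = substitutions or {}
    tr = {}

    def sample(name, draw_fn):
        value = substitutions.get(name, draw_fn())
        tr[name] = value
        return value

    model(sample)
    return tr

# Toy model: latent mean `mu`, observed site `obs`.
def toy_model(sample):
    mu = sample("mu", lambda: random.gauss(0.0, 1.0))
    sample("obs", lambda: random.gauss(mu, 1.0))

def prior_predictive(model, num_samples):
    # No substitutions: every site, latent and observed, is drawn from
    # the prior, which is exactly the prior predictive distribution.
    return [trace_model(model) for _ in range(num_samples)]

draws = prior_predictive(toy_model, 100)
```

A log_lik_predictive utility could reuse the same tracer, substituting posterior draws for the latent sites and evaluating the density at the observed site.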

SVI

I think the current predictive utility is enough to get samples from the guide (similar to the sample_posterior method in autoguide).

What predictive looks like in other frameworks

Stan has a generated quantities block, which is used for

  • forward sampling to generate simulated data for model testing,
  • generating predictions for new data,
  • calculating posterior event probabilities, including multiple comparisons, sign tests, etc.,
  • calculating posterior expectations,
  • transforming parameters for reporting,
  • applying full Bayesian decision theory,
  • calculating log likelihoods, deviances, etc. for model comparison.

PyMC3 has deterministic nodes to keep track of transformed variables, and posterior_predictive to get values at observed nodes.

Typically, for prediction/forecasting, I imagine we will do something like:

samples = mcmc.get_samples()

def forecasting_model(*args, **kwargs):
    # Reuse the inference model, then derive forecast quantities from its output.
    vars_from_model = model(*args, **kwargs)
    forecast_values = do_something_else(vars_from_model)
    return forecast_values

forecast_values = predictive(rng, forecasting_model, samples)

I think we can do all the above tasks by either

  • having a _RETURN node in the trace, as in Pyro, and using the predictive utility to get the value of this node, or
  • having a flag include_returned_values in predictive (or another predictive utility for this purpose) which runs condition(model, predictive_sites + posterior_sites)(*args, **kwargs) to get those returned values. I prefer this. The minor drawback of this approach is that we run the model twice: once to get predictive samples, and once to get the returned values.
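The second option can be sketched in pure Python (illustrative helper names, not numpyro's API): the model is replayed with both posterior and predictive sites fixed, and its return value is captured on that second pass.

```python
import random

def run_conditioned(model, values):
    """Replay `model` with every site named in `values` fixed (roughly
    what condition(model, ...) does) and capture the returned values."""
    def sample(name, draw_fn):
        return values.get(name, draw_fn())
    return model(sample)

# Toy model whose return value is a derived quantity, not a trace site.
def toy_model(sample):
    mu = sample("mu", lambda: random.gauss(0.0, 1.0))
    y = sample("obs", lambda: random.gauss(mu, 1.0))
    return {"doubled_obs": 2 * y}

# First run (not shown): draw predictive samples given posterior draws.
# Second run: condition on posterior + predictive sites together and
# read off the model's returned values.
posterior_sample = {"mu": 0.5}
predictive_sample = {"obs": 1.25}
returned = run_conditioned(toy_model, {**posterior_sample, **predictive_sample})
```

Since every site is substituted on the second pass, the replay is deterministic; only the model body's extra computation is paid again.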
@neerajprad

The reason is that, given an MCMC instance, we do not store args and kwargs, so even though we can access the model, we are still unable to run it from arviz.

Could you point to the arviz code that needs access to args / kwargs to run it? I wonder if there is a way to include it in the wrapper for arviz rather than storing it in MCMC.

having _RETURN node in trace as in Pyro and using predictive util to get values of this node

Could you clarify why we need this? This value would have been recorded at some observed site like pyro.sample('obs', .., obs=y). Is it not possible to extract it from the trace instead? For instance, your forecasting model above is doing something on top of the model over which we run inference. A concrete example would help.

make prior_predictive, log_like_predictive utilities.

These should be straightforward to add once we clarify the issues above. I think predictive can simply be extended, because when posterior_samples = {}, we will not be conditioning on any latent sites and will effectively be sampling from the prior predictive.
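Why an empty posterior_samples dict falls back to prior predictive sampling can be shown in one line (a stand-in for the substitution logic, not numpyro's actual code):

```python
def substitute(name, draw_from_prior, posterior_samples):
    # With posterior_samples == {}, .get() always misses, so every site
    # is drawn from the prior and predictive degenerates to the prior
    # predictive; with a non-empty dict, latent sites are conditioned.
    return posterior_samples.get(name, draw_from_prior())

assert substitute("mu", lambda: -1.0, {}) == -1.0           # prior draw (stub)
assert substitute("mu", lambda: -1.0, {"mu": 3.0}) == 3.0   # conditioned
```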

fehiepsi commented Sep 25, 2019 via email

@neerajprad

Otherwise, numpyro users will need to compute obs_data and log_likelihood themselves and provide them to arviz through the from_numpyro method. That's fine but seems redundant to me.

I think we should provide those utility functions, but I'm trying to understand the constraints here. Based on what you say, we could provide a log_likelihood function that depends on args and kwargs, but I'm guessing that arviz requires this function to depend only on the model?

The generated quantities block in Stan best illustrates what I need: getting some inferred values (either stochastic or not) from posterior samples.

Can we use something like numpyro.sample('log_y', dist.Delta(v=trans_x), obs=trans_x) to record these values? We can probably have a helper function that adds the obs kwarg directly.
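The Delta trick works because a point mass evaluated at its own location contributes zero log-density, so recording a value this way leaves the joint density (and hence HMC's potential energy) untouched. A toy sketch of that property (hypothetical Delta class, not numpyro's implementation):

```python
import math

class Delta:
    """Point-mass distribution concentrated at `v`."""
    def __init__(self, v):
        self.v = v

    def sample(self):
        return self.v

    def log_prob(self, value):
        # log(1) = 0 at the point mass, -inf everywhere else.
        return 0.0 if value == self.v else -math.inf

trans_x = math.log(2.0)
site = Delta(v=trans_x)

# With obs=trans_x, the site is recorded as observed in the trace and
# adds exactly zero to the model's joint log-density.
contribution = site.log_prob(trans_x)
```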

fehiepsi commented Sep 25, 2019

arviz requires this function to only depend on the model

I guess I have given a bad explanation. arviz only needs access to the log_likelihood value at the observed node. So there are two ways to do it:

  • compute the value in numpyro: this requires users to provide the log_likelihood value to the az.from_numpyro method.
  • compute the value inside az.from_numpyro: this requires access to args and kwargs so that the model can be traced to get the log_likelihood.

For both ways, having a log_lik_predictive utility is helpful:

  • in the first case, it helps users compute the log_likelihood value themselves;
  • in the second case, it avoids duplicating handler code in arviz (inside az.from_numpyro, we would just call log_lik_predictive(rng, mcmc.sampler.model, *args, **kwargs)).

My proposal is to store args and kwargs in MCMC to support the second way.

Can we use something like numpyro.sample('log_y', dist.Delta(v=trans_x), obs=trans_x) to record these values?

Thanks, I didn't know that we could use a Delta distribution for this purpose! :D That would be more than enough! (obs=trans_x might not be necessary unless we want to run some inference with forecast_model.)

@neerajprad

in the second case, it avoids duplicating handler code in arviz (inside az.from_numpyro, we would just call log_lik_predictive(rng, mcmc.sampler.model, *args, **kwargs)).

Sure, that sounds reasonable. I suppose we can store the args and kwargs in MCMC and provide a log_likelihood(model, posterior_samples, *args, **kwargs) function.

Hmm.. obs= will be necessary, I think; otherwise we'll run HMC on those delta distributions, and the proposals will forever be rejected.
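Such a log_likelihood function could, for each posterior draw, evaluate the observed site's log-density under that draw. A pure-Python sketch with a Normal(mu, 1) observation site (the model and names are illustrative, not the proposed numpyro API):

```python
import math

def normal_logpdf(x, loc, scale=1.0):
    # Log-density of Normal(loc, scale) at x.
    return -0.5 * math.log(2 * math.pi * scale ** 2) - (x - loc) ** 2 / (2 * scale ** 2)

def log_likelihood(posterior_samples, obs):
    """Per-draw log-density of `obs` at the observed site -- the
    log_lik group that arviz expects, one entry per posterior sample."""
    return [normal_logpdf(obs, mu) for mu in posterior_samples["mu"]]

log_lik = log_likelihood({"mu": [0.0, 1.0]}, obs=0.0)
```

A draw with mu closer to the observation gets a higher log-likelihood, which is exactly the per-sample quantity needed for loo/waic style model comparison.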

@fehiepsi

obs= will be necessary I think

Sorry, my bad. I should have distinguished between model and forecast_model. The former is the one on which we run MCMC, while the latter is just used for prediction/forecasting. But as you said, having obs=... will be helpful if we want to use model for both inference and forecasting (but doing so will slow down MCMC I guess).

In addition, I now think that having a return node which stores a pytree is a bad idea, because we can't run summary on that node. Your Delta solution records return values under separate keys, which is much better!

@neerajprad

But as you said, having obs=... will be helpful if we want to use model for both inference and forecasting (but doing so will slow down MCMC I guess).

I don't think there should be any significant slowdown, since the delta won't add anything to the PE term, but adding the same obs keyword isn't a very elegant solution. Do you expect users to write this forecast function?

fehiepsi commented Sep 25, 2019

I don't mean that the slowdown is caused by using a Delta distribution. I mean the slowdown is caused by calculating the extra terms (e.g. log(y)) which do not contribute to the joint density computation. But if those extra calculations go inside an if forecast block, then there won't be any slowdown at all. For complicated models, I still prefer to create a separate model for prediction, even though there will be duplicated code between model and forecast_model. :)
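The if forecast pattern mentioned here can be sketched as follows (toy model body and names are illustrative):

```python
def toy_model(mu, forecast=False):
    """Extra derived quantities (e.g. a log(y)-style reporting transform)
    are computed only at prediction time, so MCMC's repeated log-density
    evaluations never pay for them."""
    extras = {}
    if forecast:
        extras["mu_squared"] = mu ** 2  # stand-in for the extra terms
    return mu, extras

# During inference: extras are skipped entirely.
_, during_inference = toy_model(2.0)

# During forecasting: the same model body also yields the derived values.
_, during_forecast = toy_model(2.0, forecast=True)
```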

Do you expect the users to write this forecast function?

I am not sure about other users, but it is enough for my purpose. For example, in the ts forecasting tutorial, I used the low-level handlers substitute, seed, and trace for forecasting. I think that with your Delta solution, all I need is to create a forecast_model and use the predictive utility. :D

@neerajprad

Are there any remaining todos on this one?

@fehiepsi

Yeah, I need to verify that we fully support arviz. I will make a notebook for verification tomorrow.
