
RFC: API reorganization #30

Closed · tpapp opened this issue Dec 25, 2018 · 9 comments

tpapp (Owner) commented Dec 25, 2018

Motivation

DynamicHMC.jl was started in June 2017, and initially released in February 2018. Since then, the architecture and the API have only undergone minor changes. However, various use cases are stretching the API a bit, and it is time for a redesign and a rewrite of some internals. I am opening this issue to discuss these; feel free to use it as a wishlist, or to share your recommendations or use cases.

I intend to keep the focus of the package the same, ie as a building block for Bayesian inference using variants of the NUTS sampler. The user is still expected to provide a log density (with gradient).

I briefly discuss the changes I am proposing below.

Low-level implementation changes

These would be mostly invisible to the user.

Non-allocating leapfrog calculations

The most important one is probably reducing allocations by reusing the vectors for position and momentum. This has a tiny impact (for nontrivial models, calculating the log density dominates the cost), but it is low-hanging fruit and has some significance for high-dimensional models. I am undecided about this though, as I would have to trust the log density calculations not to mutate the position vector.

Ideas: perhaps make it optional, and allow SVector transparently?
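For concreteness, a minimal sketch of an in-place leapfrog step, assuming a unit mass matrix and a hypothetical in-place gradient ∇logπ! supplied by the user (neither the name nor the signature is part of the current API):

# In-place leapfrog step for H(q, p) = -logπ(q) + p'p / 2 (unit mass matrix).
# ∇logπ!(g, q) is assumed to write the gradient of the log density at q into g.
function leapfrog!(q, p, g, ∇logπ!, ϵ)
    ∇logπ!(g, q)
    @. p += (ϵ / 2) * g    # half step for the momentum
    @. q += ϵ * p          # full step for the position
    ∇logπ!(g, q)
    @. p += (ϵ / 2) * g    # second half step for the momentum
    q, p
end

All three updates reuse the buffers q, p, and g, which is exactly why the log density code would have to be trusted not to mutate or retain the position vector.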

EDIT: I abandoned this, because keeping the functional design (which is now generalized) makes it much easier to use multithreading in Julia 1.3.

Mid-level API

Flexible NUTS step implementation

The NUTS sampling step, currently implemented by DynamicHMC.NUTS_transition, reports the reason for termination, the newly drawn position, and the average acceptance rate. However, obtaining the whole trajectory with probabilities could be useful, eg for debugging issues like SciML/DiffEqBayes.jl#60, and also for pedagogical purposes (eg visualizing HMC trajectories); a sketch of one possible result type follows the list below.

The interface should allow users to experiment with different step sizes (also jitter), momentum, kinetic energy, and max_depth specifications, and to debug these. Eg if users learn that most steps terminate because of divergence, they should be able to

  1. plot example trajectories from a given point,
  2. experiment with various step sizes and kinetic energies.
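A minimal sketch of what such a richer transition result could look like (all names are invented for illustration, not the package API):

struct NUTSTransitionDiagnostics{T}
    position::Vector{T}             # the new draw
    termination::Symbol             # eg :divergence, :turning, :max_depth
    acceptance_rate::Float64        # average acceptance statistic of the trajectory
    trajectory::Vector{Vector{T}}   # all visited positions, for plotting and debugging
end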

Allow jittered stepsize ϵ

A core stepsize should be adapted, while the stepsize actually used for each transition is perturbed by a random jitter factor.
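One possible scheme, as an illustration only: adapt a core stepsize ϵ₀ and perturb it multiplicatively for each transition.

# Hypothetical jitter scheme: a symmetric multiplicative perturbation of the
# adapted core stepsize ϵ₀; `jitter` controls the width (0 disables it).
jittered_ϵ(ϵ₀, jitter) = ϵ₀ * exp(jitter * (2 * rand() - 1))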

Interface for iterative application

Sometimes a more granular interface would be useful for tuning and adaptation. In #28 we arrived at the interface

sample, new_state = mcmc_step(stage, nuts::NUTS, state)

for performing NUTS steps, with the idea that state could be something that is tuned (eg stepsize ϵ) in stage. Currently the API only exposes doing this for a pre-determined number of steps (see below).
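A hypothetical driver loop using this interface could look like the following (the names stage, nuts, and mcmc_step follow #28; everything else is a sketch):

# Run N transitions, threading the tuned state through; a sketch only.
function run_steps(stage, nuts, state, N)
    samples = Vector{Any}(undef, N)
    for i in 1:N
        sample, state = mcmc_step(stage, nuts, state)
        samples[i] = sample
    end
    samples, state
end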

High-level API

Logging

A de-facto standard seems to be emerging for progress meters via @logmsg, eg see Atom.jl and timholy/ProgressMeter.jl#102. Progress reports should use this. Cf #10.
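As an illustration of that convention (the exact log level and keys are not settled, so treat this as an assumption):

using Logging

function sample_with_progress(N)
    for i in 1:N
        # ... one NUTS transition would go here ...
        @logmsg LogLevel(-1) "sampling" progress = i / N
    end
end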

Interface for initialization and adaptation

I envision each adaptation step as a transformation from previous parameters of the algorithm (stepsize, kinetic energy) to new ones, using random realizations of MCMC draws, ie

sampler parameters, Hamiltonian ====================> new sampler parameters
                                  NUTS realizations

The user could be interested in

  1. the whole history of adaptation (currently possible by invoking steps manually),

  2. the posterior and adapted sampler (what is now returned by NUTS_init_tune_mcmc),

  3. just the posterior.

Targeting (2) as the default interface may have been a mistake: mostly I am interested in (3) when things go well, and (1) when they don't (cf #24, #9). Also, when samplers have to be parametrized manually, it would be useful to experiment with various initialization and adaptation strategies, eg

  1. picking the initial position by a crude or sophisticated maximization algorithm (addressing "discuss starting from the mode" #8 and "optimize before adapting stepsize" #25),

  2. more or less aggressive adaptation of the stepsize.

The proposed interface is the following: the user provides

  1. a chain of adaptation steps, eg as a Tuple,

  2. a parameter that specifies how much history should be kept.

Each step is applied to the previous state (initialized with nothing), with the target log density and the current parameters given, and returns a new set of parameters and an adaptation history (when required). The high-level interface can then pick what to keep and return.
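A sketch of the driver implied by this design; adapt, the stage types, and the return shapes are invented for illustration:

# Apply each adaptation stage to the previous parameters/state, starting
# from `nothing`, optionally collecting the history produced by each stage.
function run_adaptation(ℓ, stages::Tuple; keep_history = false)
    state = nothing
    history = keep_history ? Any[] : nothing
    for stage in stages
        state, h = adapt(stage, ℓ, state)   # new parameters and stage history
        keep_history && push!(history, h)
    end
    state, history
end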

tpapp (Owner) commented Jan 10, 2019

#1 and #35 should be addressed by this API rewrite.

zenna commented Jan 15, 2019

I wrote a long reply to this but didn't finish it, and it got lost in the ether.

So I'll just briefly add that (i) thanks for the hard work, this package is very useful, and (ii):

  • non-allocating operations are important
  • It would be nice to be able to turn off the printing
  • I had a difficult time getting around all the different parts of this package at first. From the stated goals I thought I would be able to simply supply a negative log density function / gradient, but instead I have to (i) use the transformation package to construct a transformation, (ii) construct a log density problem, (iii) do some magic with ADGradient. Unless there are strong reasons otherwise, I think these parts should be decoupled as much as possible and the API should accept standard objects.
  • I think you should use a callback system to allow the user to capture "extra" things. I can elaborate more if that's not clear.

Again, great work!

tpapp (Owner) commented Jan 15, 2019

Thanks for your comment.

Regarding the log density: you can just define a type Foo that supports

logdensity(::Type{LogDensityProblems.ValueGradient}, ::Foo, ::AbstractVector)

and use that. Or

logdensity(::Type{LogDensityProblems.Value}, ::Foo, ::AbstractVector)

and use one of the AD methods from that package. TransformVariables is just a convenience feature; it is not required. But perhaps I should make this clearer.
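For example, a minimal sketch of such a type (assuming the LogDensityProblems interface above; check the package docs before copying): a standard multivariate normal with a hand-coded gradient.

import LogDensityProblems
import LogDensityProblems: logdensity, ValueGradient

struct Foo
    dimension::Int
end

# Dimension of the problem, required by the interface.
LogDensityProblems.dimension(foo::Foo) = foo.dimension

# Standard normal log density and its gradient, returned together.
function logdensity(::Type{ValueGradient}, foo::Foo, x::AbstractVector)
    ValueGradient(-sum(abs2, x) / 2, -x)
end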

I would appreciate if you could elaborate about the callbacks.

ghost commented Jul 5, 2019

logdensity(::Type{LogDensityProblems.ValueGradient}, ::Foo, ::AbstractVector)

Is there an example that uses that? It would be very helpful; I'm struggling too...


ghost commented Jul 5, 2019

Ha! @zenna, it turns out I also need a callback when the sample is kept for the final tally; is this your use case?

I'm using MCMC to integrate over an internal variable, so I need to save internal state, but only for the samples that are kept. What's the best way to achieve that now?

zenna commented Jul 5, 2019

@GAIKA
The simplest way is to use a callback system, such as the one Flux uses. This would require @tpapp to update the interfaces to accept a callback function. In Flux it's something like

donothing(x...) = nothing

function train(step!, n; cb = donothing)
    for i in 1:n
        loss = step!(i)    # "dostuff": one training step (step! is a stand-in)
        cb(loss, i)        # let the caller observe loss and other data
    end
end

This works very well but has two major disadvantages:
(i) you have to thread this cb function through all your functions, which becomes cumbersome;
(ii) there may be internal variables you wish to capture that the callback never sees.

So I've been developing a combination of systems called Lens and Callbacks.

https://github.com/zenna/Lens.jl/pulls
https://github.com/zenna/Callbacks.jl

The high-level idea is that you annotate your code with lenses, and then you can execute the program in a lens-capturing context.

It works fine; I use it in Omega (https://github.com/zenna/Omega.jl). It avoids problem (i), but it still suffers from problem (ii). Also, there is some hacky global variable state that is undesirable.

The ultimately right solution is to use Cassette, which would solve both (i) and (ii). I have a version of Lens that runs using Cassette, but Cassette is not yet fast enough.

So to solve your problem right now: if it were me, I would fork DynamicHMC and capture the state either by passing in a callback or by using a Lens.

tpapp (Owner) commented Jul 6, 2019

@GAIKA: do you need an example of how to code a posterior for which you have the gradient available? Please open an issue at https://github.com/tpapp/DynamicHMCExamples.jl/.

@zenna: I don't think callbacks are the right approach. I would do this with an extra payload, which could be a thunk. But I am open to suggestions.
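To illustrate what I mean (names invented, just a sketch): the log density could return an extra payload alongside the value, as a thunk that is only forced for samples that are kept.

# Hypothetical payload-as-thunk: `internal` is only materialized when the
# caller forces the closure, eg for samples kept in the final tally.
function logdensity_and_payload(x)
    internal = sum(abs2, x)                  # stand-in for expensive internal state
    value = -internal / 2
    payload = () -> (internal = internal,)   # thunk, forced lazily
    value, payload
end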

It would be great to have a concrete example of a problem to focus the discussion. Please open another issue so that we can discuss it.

tpapp mentioned this issue Jul 28, 2019
tpapp mentioned this issue Aug 5, 2019
tpapp (Owner) commented Aug 19, 2019

This was mostly done by #44 and subsequent PRs, except for writing the docs (#62) and the jittered stepsize (for which I opened #61 as a reminder).

tpapp closed this as completed Aug 19, 2019