Add classes for handling posteriors during the decision making loop #362
Conversation
Added a `PosteriorHandler` class which can be used to obtain and update posteriors during the decision making loop as the number of datapoints in the training dataset changes (which is necessary for instantiating a likelihood object). Also added an `AbstractPosteriorOptimizer` abstract base class for optimizing posterior hyperparameters during the decision making loop. This allows for more customisability when optimising posteriors. Added a concrete implementation in the form of an `AdamPosteriorOptimizer` which can be used to optimize posteriors with the Adam optimizer provided by optax.
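To make the intended workflow concrete, here is a rough usage sketch based on the description above. It assumes `PosteriorHandler` and `AdamPosteriorOptimizer` are in scope (imported from the module this PR adds); the constructor arguments and `get_posterior` method shown are illustrative assumptions rather than the reviewed API:

```python
import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx

# Toy data standing in for the observations gathered so far in the loop.
x = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
y = jnp.sin(10.0 * x)
dataset = gpx.Dataset(X=x, y=y)
key = jr.PRNGKey(0)

# A likelihood builder lets the handler rebuild the likelihood whenever the
# number of datapoints changes, which a Gaussian likelihood needs to know.
likelihood_builder = lambda n: gpx.likelihoods.Gaussian(num_datapoints=n)

# Hypothetical construction of the classes added in this PR (argument names
# such as learning_rate / num_iters are assumptions for illustration only).
optimizer = AdamPosteriorOptimizer(num_iters=1000, learning_rate=1e-2)
posterior_handler = PosteriorHandler(
    prior=gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()),
    likelihood_builder=likelihood_builder,
    optimizer=optimizer,
)

# As the loop gathers more data, the handler rebuilds (and optionally
# re-optimises) the posterior against the enlarged dataset.
posterior = posterior_handler.get_posterior(dataset, optimize=True, key=key)
```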
I've attempted a review. However, I'm having a pretty hard time at this point imagining how all the moving parts of the decision making framework work together. A notebook, or even examples in the docstrings, would be really helpful. My concern is that things are becoming slightly too abstracted and taking the necessary level of control away from the user. That being said, I may be wrong, hence asking for some example(s).
```python
@dataclass
class AdamPosteriorOptimizer(AbstractPosteriorOptimizer):
```
This feels a level of abstraction too far. What if I want to use `adamw` or `sgd`? With this abstraction I'd have to subclass `AbstractPosteriorOptimizer` just to change the optimiser. Could we instead structure it as `OptaxPosteriorOptimizer` or something similar, where the user supplies the relevant optimiser? It feels a bit more general.
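A minimal sketch of the suggested shape, assuming the optimiser ultimately wraps `gpx.fit` and the user supplies any optax `GradientTransformation` (the objective choice below is also an assumption):

```python
from dataclasses import dataclass

import optax as ox
import gpjax as gpx


@dataclass
class OptaxPosteriorOptimizer:  # hypothetical class, as suggested above
    """Optimise posterior hyperparameters with a user-supplied optax optimiser."""

    optim: ox.GradientTransformation
    num_iters: int

    def optimize(self, posterior, dataset, key):
        # Maximise the marginal log-likelihood by minimising its negative;
        # the objective is assumed here, not prescribed by the PR.
        objective = gpx.objectives.ConjugateMLL(negative=True)
        opt_posterior, _ = gpx.fit(
            model=posterior,
            objective=objective,
            train_data=dataset,
            optim=self.optim,
            num_iters=self.num_iters,
            key=key,
        )
        return opt_posterior


# Any optax optimiser can now be swapped in without subclassing:
adam = OptaxPosteriorOptimizer(optim=ox.adam(1e-2), num_iters=1000)
adamw = OptaxPosteriorOptimizer(optim=ox.adamw(1e-2), num_iters=1000)
```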
Further, what's the benefit of this object, given that it just calls `fit`?
I agree with @thomaspinder. I think you just need to pass `optim` and `num_iters` into the `PosteriorHandler`.
Yes - I had originally kept the optimisation logic separate in anticipation of moving towards fitting models with L-BFGS further down the line. However, I think I will move this logic to the `PosteriorHandler` class, and accept that I may just have to change this class a bit once I switch to using L-BFGS.
LGTM. I agree that the posterior optimiser seems overkill! Impressive battery of tests, though.
```python
        key: A JAX PRNG key which is used for optimizing the posterior
            hyperparameters.
        """
        posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)
```
Nice, so this steals the optimized kernel params from the last go?
Yes, this is included to cover the scenario where, for example, a user wishes to run the decision making loop but doesn't want to optimise the posterior on each iteration of the loop. With this logic they can still update the likelihood, to reflect the change in dataset size, without changing the prior parameters.
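Concretely, the non-optimising update path mirrors the line quoted from the diff above; a small sketch (the function name and arguments are illustrative):

```python
import gpjax as gpx

# e.g. a Gaussian likelihood whose num_datapoints tracks the dataset size.
likelihood_builder = lambda n: gpx.likelihoods.Gaussian(num_datapoints=n)


def update_posterior_without_optimization(previous_posterior, dataset, likelihood_builder):
    # Keep the prior (and its already-optimised kernel/mean parameters) as-is;
    # only the likelihood is rebuilt so it reflects the new number of datapoints.
    # In GPJax, multiplying a prior by a likelihood constructs the posterior.
    return previous_posterior.prior * likelihood_builder(dataset.n)
```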
Removed the `PosteriorOptimizer` classes and included posterior optimization logic in the `PosteriorHandler` class.
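A sketch of what the revised `PosteriorHandler` might look like with the optimisation logic folded in; field and method names here are assumed from the discussion rather than copied from the final code:

```python
from dataclasses import dataclass
from typing import Callable

import optax as ox
import gpjax as gpx


@dataclass
class PosteriorHandler:  # sketch of the revised design
    prior: gpx.gps.AbstractPrior
    likelihood_builder: Callable[[int], gpx.likelihoods.AbstractLikelihood]
    optimization_objective: gpx.objectives.AbstractObjective
    optimizer: ox.GradientTransformation
    num_optimization_iters: int

    def get_posterior(self, dataset: gpx.Dataset, optimize: bool, key=None):
        # Build the posterior with a likelihood sized to the current dataset.
        posterior = self.prior * self.likelihood_builder(dataset.n)
        if optimize:
            # Optimisation now lives in the handler: a single gpx.fit call
            # using the user-supplied optax optimiser and iteration count.
            posterior, _ = gpx.fit(
                model=posterior,
                objective=self.optimization_objective,
                train_data=dataset,
                optim=self.optimizer,
                num_iters=self.num_optimization_iters,
                key=key,
            )
        return posterior
```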
The PR looks good. However, I think #364 will affect the structure of this PR - we should be careful with this before merging.
Type of changes

Checklist

- Run `poetry run pre-commit run --all-files --show-diff-on-failure` before committing.

Description

Added a `PosteriorHandler` class which can be used to obtain and update posteriors during the decision making loop as the number of datapoints in the training dataset changes (which is necessary for instantiating a likelihood object). Also added an `AbstractPosteriorOptimizer` abstract base class for optimizing posterior hyperparameters during the decision making loop. This allows for more customisability when optimizing posteriors. Added a concrete implementation in the form of an `AdamPosteriorOptimizer` which can be used to optimize posteriors with the Adam optimizer provided by optax.

Issue Number: N/A