Add classes for handling posteriors during the decision making loop #362

Conversation

Thomas-Christie
Contributor

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Added a PosteriorHandler class which can be used to obtain and update posteriors during the decision making loop as the number of datapoints in the training dataset changes (which is necessary for instantiating a likelihood object).

Also added an AbstractPosteriorOptimizer abstract base class for optimizing posterior hyperparameters during the decision making loop. This allows for more customisability when optimizing posteriors. Added a concrete implementation in the form of an AdamPosteriorOptimizer which can be used to optimize posteriors with the Adam optimizer provided by optax.

Issue Number: N/A

Collaborator

@thomaspinder thomaspinder left a comment

I've attempted a review. However, at this point I'm having a pretty hard time imagining how all the moving parts of the decision making framework work together. A notebook, or even examples in the docstrings, would be really helpful. My concern is that things are becoming slightly too abstracted and taking the necessary level of control away from the user. That said, I may be wrong, hence asking for some example(s).



@dataclass
class AdamPosteriorOptimizer(AbstractPosteriorOptimizer):
Collaborator

This feels a level of abstraction too far. What if I want to use adamw or sgd? With this abstraction I'd have to subclass AbstractPosteriorOptimizer just to change the optimiser. Could we instead structure it as OptaxPosteriorOptimizer, or something similar, where the user supplies the relevant optimiser? That feels a bit more general.
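
To make the suggestion concrete, here is a minimal sketch of the "user supplies the optimiser" shape. Everything in it is hypothetical: `SuppliedOptimizerSketch`, `sgd`, `sign_sgd` and `toy_grad` are illustrative names, the optimisers are plain Python callables standing in for whatever optax optimiser a user would pass, and the objective is a toy quadratic rather than a GP marginal likelihood.

```python
from dataclasses import dataclass
from typing import Callable, Dict

Params = Dict[str, float]
# Stand-in for an optimiser: any function mapping (params, grads) -> new params.
UpdateRule = Callable[[Params, Params], Params]


def sgd(learning_rate: float) -> UpdateRule:
    """Plain gradient descent step (illustrative, not the optax API)."""

    def update(params: Params, grads: Params) -> Params:
        return {k: params[k] - learning_rate * grads[k] for k in params}

    return update


def sign_sgd(learning_rate: float) -> UpdateRule:
    """Steps by a fixed amount in the direction opposite to the gradient sign."""

    def update(params: Params, grads: Params) -> Params:
        def sign(g: float) -> int:
            return (g > 0) - (g < 0)

        return {k: params[k] - learning_rate * sign(grads[k]) for k in params}

    return update


@dataclass
class SuppliedOptimizerSketch:
    """Hypothetical optimiser wrapper: the update rule is injected, not subclassed."""

    optim: UpdateRule
    num_iters: int

    def optimize(self, params: Params, grad_fn: Callable[[Params], Params]) -> Params:
        for _ in range(self.num_iters):
            params = self.optim(params, grad_fn(params))
        return params


def toy_grad(params: Params) -> Params:
    # Gradient of the toy objective (lengthscale - 2)^2.
    return {"lengthscale": 2.0 * (params["lengthscale"] - 2.0)}


start = {"lengthscale": 0.5}
# Changing the optimiser is a constructor argument, not a new subclass.
fitted_sgd = SuppliedOptimizerSketch(optim=sgd(0.1), num_iters=200).optimize(start, toy_grad)
fitted_sign = SuppliedOptimizerSketch(optim=sign_sgd(0.01), num_iters=200).optimize(start, toy_grad)
```

With this shape, switching from one optimiser to another only changes the value passed as `optim`; the same two arguments, `optim` and `num_iters`, could equally be accepted by the `PosteriorHandler` directly.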

Collaborator

Further, what's the benefit of this object, given that it just calls fit?

Contributor

I agree with @thomaspinder. I think you just need to pass optim and num_iters into the PosteriorHandler.

Contributor Author

Yes - I had originally kept the optimisation logic separate in anticipation of moving towards fitting models with L-BFGS further down the line. However, I think I will move this logic to the PosteriorHandler class, and accept that I may just have to change this class a bit once I switch to using L-BFGS.

Contributor

@henrymoss henrymoss left a comment

LGTM. I agree that the posterior optimiser seems overkill! Impressive battery of tests, though.

key: A JAX PRNG key which is used for optimizing the posterior
hyperparameters.
"""
posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)
Contributor

Nice, so this steals the optimized kernel params from the last go?

Contributor Author

Yes, this covers the scenario where e.g. a user wishes to run the decision making loop but doesn't want to optimise the posterior on each iteration of the loop. With this logic they can still update the likelihood to reflect the change in dataset size, without changing the prior parameters.
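
As a rough illustration of that behaviour, the sketch below rebuilds only the likelihood for the new dataset size and reuses the prior from the previous posterior. The `Prior`, `Likelihood` and `Posterior` classes are toy stand-ins written for this example, not the GPJax classes, and `update_posterior` is a hypothetical free function mirroring the line under review.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Likelihood:
    """Toy likelihood that only records the dataset size it was built for."""

    num_datapoints: int


@dataclass(frozen=True)
class Prior:
    """Toy prior holding a single (already optimised) kernel hyperparameter."""

    lengthscale: float

    def __mul__(self, likelihood: "Likelihood") -> "Posterior":
        # Mirrors the `prior * likelihood` construction used in the diff.
        return Posterior(prior=self, likelihood=likelihood)


@dataclass(frozen=True)
class Posterior:
    prior: Prior
    likelihood: Likelihood


def update_posterior(
    previous_posterior: Posterior,
    likelihood_builder: Callable[[int], Likelihood],
    n: int,
) -> Posterior:
    """Rebuild the likelihood for `n` datapoints, keeping the previous prior as-is."""
    return previous_posterior.prior * likelihood_builder(n)


# Pretend this posterior was optimised on 10 datapoints in an earlier iteration.
optimized = Prior(lengthscale=0.37) * Likelihood(num_datapoints=10)

# One datapoint is added; no re-optimisation is performed.
updated = update_posterior(optimized, lambda n: Likelihood(num_datapoints=n), 11)
```

Here `updated.prior` is the same object as `optimized.prior`, so the previously optimised hyperparameters carry over, while `updated.likelihood` reflects the new dataset size.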



Removed the `PosteriorOptimizer` classes and included posterior
optimization logic in the `PosteriorHandler` class.
@thomaspinder
Collaborator

The PR looks good. However, I think #364 will affect the structure of this PR - we should be careful with this before merging into main.

@Thomas-Christie Thomas-Christie merged commit 84da3cd into JaxGaussianProcesses:tchristie/bo Aug 24, 2023