dev: let CoLA decide how to solve #475

Closed
theorashid opened this issue Sep 6, 2024 · 4 comments
Labels: enhancement (New feature or request)

Comments

@theorashid
Contributor

Hey everyone,

I'm running an exact GP on a large dataset (16.5k observations) and am hitting memory issues, as well as numerical-stability issues with `solve`. These issues do not occur when I use a smaller subset of the data. I could bump the memory on my server/GPU and see if that helps, but I'm wondering whether there's another option. This specific code path uses the `.sample_approx()` method.
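
For a sense of scale, here is a back-of-the-envelope sketch of the storage for a dense N × N `Sigma` at N = 16,500. It assumes `Sigma` is materialised as a dense matrix and that a Cholesky solve stores a factor of the same size; it ignores intermediate buffers and the right-hand-side columns, and it is not a diagnosis of the actual out-of-memory error.

```python
def dense_matrix_bytes(n: int, bytes_per_entry: int = 8) -> int:
    """Bytes needed to store a dense n x n matrix."""
    return n * n * bytes_per_entry


N = 16_500  # number of observations in this issue
sigma_f64 = dense_matrix_bytes(N, 8)  # float64 entries (e.g. with jax_enable_x64)
sigma_f32 = dense_matrix_bytes(N, 4)  # float32 entries

print(f"Sigma, float64: {sigma_f64 / 1e9:.2f} GB")  # -> Sigma, float64: 2.18 GB
# Assumption: the Cholesky factor is another dense n x n array of the same dtype.
print(f"Sigma + Cholesky factor, float64: {2 * sigma_f64 / 1e9:.2f} GB")  # -> 4.36 GB
print(f"Sigma, float32: {sigma_f32 / 1e9:.2f} GB")  # -> Sigma, float32: 1.09 GB
```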

The relevant lines where I get the error are:

```python
canonical_weights = solve(
    Sigma,
    y + eps - jnp.inner(Phi, fourier_weights),
    Cholesky(),
)  # [N, B]
```

In this case, we force CoLA to use the `Cholesky()` algorithm. However, CoLA has two algorithms for a PSD `A`, as the docstring for its `Auto` dispatch in `inv` shows:

```python
def inv(A: LinearOperator, alg: Auto):
    """Auto:
    - if A is PSD and small, use Cholesky
    - if A is PSD and large, use CG
    - if A is not PSD and small, use LU
    - if A is not PSD and large, use GMRES
    """
```

So instead, we could wrap `Sigma` in `cola.PSD` and let CoLA choose the best algorithm. I'm open to being shut down if there is a particular reason for using Cholesky, or if you think something else is causing the out-of-memory errors.
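
The four cases in the `Auto` docstring quoted above can be mimicked with plain `jax.scipy` solvers. This is a minimal sketch, not CoLA's implementation: `auto_solve` is a hypothetical helper, the caller passes `is_psd` instead of a CoLA annotation, and the `small_n` cut-off of 1,000 is a made-up placeholder for whatever heuristic CoLA actually uses.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.sparse.linalg import cg, gmres

# Hypothetical "small vs large" cut-off; CoLA's real heuristic may differ.
SMALL_N = 1_000


def auto_solve(A, b, is_psd, small_n=SMALL_N):
    """Solve A x = b, picking a method by the four cases in CoLA's Auto docstring."""
    n = A.shape[0]
    if is_psd and n <= small_n:
        # PSD and small: dense Cholesky factorisation (O(n^3) time, stores an n x n factor).
        return cho_solve(cho_factor(A, lower=True), b)
    if is_psd:
        # PSD and large: conjugate gradients, which only needs products A @ v.
        x, _ = cg(lambda v: A @ v, b)
        return x
    if n <= small_n:
        # Not PSD and small: dense LU-based solve.
        return jnp.linalg.solve(A, b)
    # Not PSD and large: GMRES, which also only needs products A @ v.
    x, _ = gmres(lambda v: A @ v, b)
    return x
```

CG avoids forming the n × n Cholesky factor, but it returns an iterative, tolerance-controlled answer whose convergence depends on the conditioning of `Sigma`; whether it also avoids storing `Sigma` itself depends on the operator being kept lazy rather than dense.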


GPJax version (despite the NNX update, I don't think it matters in this case):

`gpjax==0.8.2`

@theorashid added the enhancement (New feature or request) label on Sep 6, 2024
@thomaspinder
Collaborator

Hey! This is a good point. We enforced the use of Cholesky because it's a bit disconcerting when things change under the hood without the user's knowledge. However, we could make the solver an optional argument to the function, with Cholesky as the default. Would that solve your issue?
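
A minimal sketch of an optional solver argument that defaults to Cholesky, written with plain `jax.scipy` solvers so it runs without CoLA. The names `solve_canonical_weights`, `cholesky_solve`, `cg_solve`, and the `solver` parameter are hypothetical; in GPJax the argument would more likely be a CoLA algorithm object such as `Cholesky()` passed through to `solve`, so this is not necessarily how GPJax implements it.

```python
from typing import Callable

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.sparse.linalg import cg

# A solver maps (PSD matrix, right-hand side) -> solution.
Solver = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]


def cholesky_solve(A, b):
    # Direct solve through a Cholesky factor of the PSD matrix A.
    return cho_solve(cho_factor(A, lower=True), b)


def cg_solve(A, b):
    # Iterative conjugate-gradient solve using only products A @ v.
    x, _ = cg(lambda v: A @ v, b)
    return x


def solve_canonical_weights(Sigma, residual, solver: Solver = cholesky_solve):
    # Hypothetical stand-in for the solve in sample_approx: the solver is an explicit
    # argument whose default is Cholesky, so the method never changes silently.
    return solver(Sigma, residual)
```

`solve_canonical_weights(Sigma, r)` keeps the current Cholesky behaviour, while `solve_canonical_weights(Sigma, r, solver=cg_solve)` opts in to CG explicitly.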

@theorashid
Contributor Author

I think so, but from a GPJax API-maintenance point of view, if you're committed to using CoLA (or a fork of it) for the linear algebra, it's probably easier to just mark `Sigma` as PSD and let CoLA decide on the solve method. It might even pay off if they optimise the solve API further, although I doubt they will come up with new maths beyond `Cholesky()` or `CG()`.

@thomaspinder
Collaborator

thomaspinder commented Sep 13, 2024

Yes. In the spirit of flexibility, I would make the solver an optional argument. Given how important the solver is, I'd prefer that users are never surprised by how the solve is done.

I am OK with making this change to the codebase.

@theorashid would you be willing to open a PR for this? I can do it, but not in the next few weeks.

Related issue #381

@thomaspinder
Collaborator

Resolved in #478
