bug: Installation with GPU support #102

Closed
seanjkanderson opened this issue Aug 22, 2022 · 6 comments · Fixed by #103
Labels
bug Something isn't working

Comments

@seanjkanderson

Bug Report

Moving this from Discussions, as it wasn't seeing much traffic there.
GPJax version: 0.4.10

Current behavior:
I can install GPJax with CPU-only support, but not with GPU support. If I try to install GPJax with CUDA support, either by following the current GPJax Read the Docs instructions or by first installing JAX according to the latest CUDA installation directions:

I've tried to find a CUDA-compatible build of jax==0.3.5, but haven't found one that my machine also supports.

I've also tried installing it in a Google Colab notebook with !pip install gpjax, and this disables GPU support because of the package versions specified in setup.py.

I think setup.py needs to require a newer JAX version (possibly with corresponding updates for changes in JAX since 0.3.5), unless I'm misunderstanding how to configure GPU support with GPJax.

Expected behavior:
I would expect GPJax to install and run with GPU support.

Steps to reproduce:

On Ubuntu 20.04.3 (I have tried Python 3.8 to 3.10), run the regression example, for instance:

from pprint import PrettyPrinter

import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.example_libraries import optimizers
from jax import jit
import optax as ox

pp = PrettyPrinter(indent=4)
key = jr.PRNGKey(123)


def pred_dist(xtest, y, training, final_params):
    latent_distribution = posterior(training, final_params)(xtest)
    predictive_distribution = likelihood(latent_distribution, final_params)
    return jnp.sum(predictive_distribution.log_prob(y))

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    n = 100
    noise = 0.3

    x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).sort().reshape(-1, 1)
    f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
    signal = f(x)
    y = signal + jr.normal(key, shape=signal.shape) * noise

    D = gpx.Dataset(X=x, y=y)

    xtest = jnp.linspace(-3.5, 3.5, 500).reshape(-1, 1)
    ytest = f(xtest)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(xtest, ytest, label="Latent function")
    ax.plot(x, y, "o", label="Observations")
    ax.legend(loc="best")

    kernel = gpx.RBF()
    prior = gpx.Prior(kernel=kernel)
    likelihood = gpx.Gaussian(num_datapoints=D.n)
    posterior = prior * likelihood
    params, trainable, constrainer, unconstrainer = gpx.initialise(posterior)
    pp.pprint(params)
    params = gpx.transform(params, unconstrainer)
    mll = jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))
    mll(params)

    opt = ox.adam(learning_rate=0.01)
    final_params = gpx.fit(
        mll,
        params,
        trainable,
        opt,
        n_iters=500,
    )

    final_params = gpx.transform(final_params, constrainer)
    pp.pprint(final_params)

    latent_dist = posterior(D, final_params)(xtest)
    predictive_dist = likelihood(latent_dist, final_params)

    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()
    # test gradient behavior with respect to posterior
    gtest = jax.grad(pred_dist)
    grad = gtest(xtest, y, D, final_params)

Running this returns the following warning and error:

/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/chex-0.1.3-py3.10.egg/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))
{   'kernel': {   'lengthscale': DeviceArray([1.], dtype=float64),
                  'variance': DeviceArray([1.], dtype=float64)},
    'likelihood': {'obs_noise': DeviceArray([1.], dtype=float64)},
    'mean_function': {}}
Traceback (most recent call last):
  File "/home/paperspace/pycharm_projects/exp_design/gpjax_demo.py", line 47, in <module>
    params = gpx.transform(params, unconstrainer)
  File "/home/paperspace/GPJax/gpjax/parameters.py", line 147, in transform
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 205, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 205, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/paperspace/GPJax/gpjax/parameters.py", line 147, in <lambda>
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/bijectors/lambda_bijector.py", line 112, in inverse
    return self._inverse(y)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/utils/transformations.py", line 119, in wrapped
    out = _interpret_inverse(jaxpr, consts, *args)
  File "/home/paperspace/anaconda3/envs/ml310/lib/python3.10/site-packages/distrax-0.1.2-py3.10.egg/distrax/_src/utils/transformations.py", line 250, in _interpret_inverse
    write(jax.core.unitvar, jax.core.unit)
AttributeError: module 'jax.core' has no attribute 'unitvar'

I can try to provide something more minimal, but I'm guessing I'm just missing something fundamental about the installation.
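One way to narrow down an install mismatch like this is to print the versions of the packages that appear in the traceback (the distrax frame references jax.core.unitvar, which the installed JAX evidently no longer provides). The sketch below uses only the standard library's importlib.metadata; the package list is an assumption drawn from the traceback, and the helper name installed_versions is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that appear in the traceback above (assumed to be the relevant ones)
PACKAGES = ("gpjax", "jax", "jaxlib", "distrax", "chex")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version string, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:8s} {ver or 'not installed'}")
```

Pasting this output into a bug report makes version conflicts between JAX and its dependents visible at a glance.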

In Colab, run
!pip install gpjax
then restart the runtime when prompted and run
import gpjax as gpx
At that point, it should indicate that there is no GPU/TPU found.
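Rather than relying on the startup warning, you can ask JAX directly which backend it selected. A minimal sketch, assuming only that JAX may or may not be importable; the helper name jax_backend is hypothetical. jax.default_backend() returns a platform name such as "cpu" or "gpu", so "cpu" on a GPU runtime suggests a CPU-only jaxlib is installed.

```python
def jax_backend():
    """Return JAX's default backend name (e.g. "cpu", "gpu"), or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed")
    elif backend == "cpu":
        # On a GPU runtime this usually means a CPU-only jaxlib is installed
        print("JAX is running on CPU only")
    else:
        print(f"JAX default backend: {backend}")
```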

Thanks!

@seanjkanderson seanjkanderson added the bug Something isn't working label Aug 22, 2022
@seanjkanderson seanjkanderson changed the title from bug: to bug: Installation with GPU support Aug 23, 2022
@thomaspinder thomaspinder mentioned this issue Aug 23, 2022
@thomaspinder thomaspinder reopened this Aug 23, 2022
@thomaspinder
Collaborator

thomaspinder commented Aug 23, 2022

Hi @seanjkanderson. Thanks for raising this - it's a particularly tricky issue that I've been trying to get to the bottom of for a while now. We've just released v0.4.12 which resolves both of your above issues. To be explicit - you can now run GPJax on Colab GPUs without having to restart the kernel post installation. I have put the below notebook together that will let you get started with this. Please do reach out with any further issues - I'd be interested to hear about your experiences, be it positive or negative!

https://colab.research.google.com/drive/1FWyhE8XLJbWaneDRUNzuZQvNueO-kLdc?usp=sharing

If you are happy that this resolves your issue, then please feel free to close this issue.

@seanjkanderson
Author

Thanks for the help @thomaspinder. That example works for me in Colab and on a VM running Ubuntu (and I was able to get it working in my project as well).

It seems you intentionally didn't use gpjax.fit(), since tree_map and InferenceState don't seem to be compatible at the moment. Is that right, or is there a new expected use of .fit()? Just curious, as I can use the training approach described in the notebook.

Cheers!

@thomaspinder
Collaborator

No problem @seanjkanderson!

fit still works - you just have to be sure to pass it params, not the ParameterState object. This is something we're looking to fix in the next few days, though. To run fit, you'd simply run:

inference_state = gpx.fit(mll, params, trainables, opt)
final_params_fit, history = inference_state.unpack()
final_params_fit = gpx.transform(final_params_fit, constrainer)

I have also added a cell to the above linked notebook with this code in - the only reason I left it out originally was to stay true to the README code.

@seanjkanderson
Author

Got it, my mistake was passing the ParameterState object to transform(). Thanks @thomaspinder!
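For anyone hitting the same mistake: the traceback earlier in this thread shows that gpx.transform in this version is a jax.tree_util.tree_map that pairs each leaf of params with the matching function in the transform map, so both arguments must have exactly the same tree structure. The sketch below reproduces that pattern with plain JAX; the toy constrainer (jnp.exp for every parameter) and the tuple standing in for a state object are hypothetical, not GPJax's actual ParameterState.

```python
import jax
import jax.numpy as jnp


def transform(params, transform_map):
    # Same tree_map pattern as gpjax/parameters.py in the traceback: apply each
    # function in transform_map to the matching leaf of params.
    return jax.tree_util.tree_map(lambda param, trans: trans(param), params, transform_map)


# Toy parameter dictionary mirroring the structure printed in the issue
params = {
    "kernel": {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])},
    "likelihood": {"obs_noise": jnp.array([1.0])},
}
# Hypothetical constrainer: exp maps unconstrained reals to positive values
constrainer = {
    "kernel": {"lengthscale": jnp.exp, "variance": jnp.exp},
    "likelihood": {"obs_noise": jnp.exp},
}

constrained = transform(params, constrainer)  # works: the tree structures match

# Hypothetical stand-in for a state object that wraps params with extra data
state = (params, {"history": jnp.zeros(3)})
try:
    transform(state, constrainer)
except (ValueError, TypeError) as err:
    # The tuple's structure does not match the dict-shaped transform map
    print("structure mismatch:", type(err).__name__)
```

Unpacking the state first (as in the fit snippet above) and passing only the params dictionary keeps the two trees aligned.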

@seanjkanderson
Author

Just got back to using GPJax on Colab and on a machine with GPU support. It seems that installing GPJax once again uninstalls the CUDA-enabled jaxlib. What's the workaround here, @thomaspinder?

Thanks!

Example section of the output from pip install gpjax:

Installing collected packages: typeguard, ml-collections, deprecation, jaxtyping, jaxlib, jax, chex, optax, distrax, jaxutils, jaxlinop, jaxkern, gpjax
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    Uninstalling typeguard-2.7.1:
      Successfully uninstalled typeguard-2.7.1
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.3.25+cuda11.cudnn805
    Uninstalling jaxlib-0.3.25+cuda11.cudnn805:
      Successfully uninstalled jaxlib-0.3.25+cuda11.cudnn805
  Attempting uninstall: jax
    Found existing installation: jax 0.3.25
    Uninstalling jax-0.3.25:
      Successfully uninstalled jax-0.3.25
Successfully installed chex-0.1.5 deprecation-2.1.0 distrax-0.1.2 gpjax-0.5.9 jax-0.4.2 jaxkern-0.0.5 jaxlib-0.4.2 jaxlinop-0.0.3 jaxtyping-0.2.11 jaxutils-0.0.8 ml-collections-0.1.0 optax-0.1.4 typeguard-2.13.3
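The jaxlib wheels in this log encode the CUDA build in a PEP 440 local version label (the +cuda11.cudnn805 suffix on the removed 0.3.25 build, absent from the installed 0.4.2), so after pip install gpjax you can check whether the CUDA build survived. This is a heuristic sketch: it assumes that naming convention, which may not hold for other jaxlib releases, and the function names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def is_cuda_jaxlib(version_string):
    """Heuristic: True if a jaxlib version string carries a CUDA local tag."""
    # PEP 440 local version labels follow "+", e.g. "0.3.25+cuda11.cudnn805"
    _, sep, local = version_string.partition("+")
    return bool(sep) and "cuda" in local.lower()


def installed_jaxlib_is_cuda():
    """Apply the heuristic to the installed jaxlib; None if jaxlib is absent."""
    try:
        return is_cuda_jaxlib(version("jaxlib"))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(is_cuda_jaxlib("0.3.25+cuda11.cudnn805"))  # True: the build pip removed
    print(is_cuda_jaxlib("0.4.2"))                   # False: the build pip installed
    print(installed_jaxlib_is_cuda())
```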

@seanjkanderson seanjkanderson reopened this Feb 2, 2023
@seanjkanderson
Author

seanjkanderson commented Feb 3, 2023

The workaround for using a colab notebook seems to be to use:

!pip install gpjax
!pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

where the second command reinstalls the CUDA build of jaxlib.
I would be interested to know if there's a convenient command for installing GPJax with CUDA support. I believe the installation guide says to set CUDA_VERSION=XX before installing GPJax, but this didn't seem to work for me in Colab.

2 participants