bug: Installation with GPU support #102
Comments
Hi @seanjkanderson. Thanks for raising this - it's a particularly tricky issue that I've been trying to get to the bottom of for a while now. We've just released v0.4.12, which resolves both of your above issues. To be explicit: you can now run GPJax on Colab GPUs without having to restart the kernel post-installation. I have put the below notebook together that will let you get started with this. Please do reach out with any further issues - I'd be interested to hear about your experiences, be it positive or negative!

https://colab.research.google.com/drive/1FWyhE8XLJbWaneDRUNzuZQvNueO-kLdc?usp=sharing

If you are happy that this resolves your issue, then please feel free to close it.
Thanks for the help @thomaspinder. That example works for me in Colab and on a VM running Ubuntu (and I was able to get it working in my project as well). It seems like you intentionally didn't use `gpjax.fit()`, as `tree_map` and `InferenceState` don't seem compatible at the moment? Is that right, or is there a new expected behavior/use of `.fit()`? Just curious, as I can use the training approach as described in the notebook. Cheers!
No problem @seanjkanderson!

```python
inference_state = gpx.fit(mll, params, trainables, opt)
final_params_fit, history = inference_state.unpack()
final_params_fit = gpx.transform(final_params_fit, constrainer)
```

I have also added a cell to the above linked notebook with this code in - the only reason I left it out originally was to stay true to the README code.
Got it, my mistake was passing the `ParameterState` object to `transform()`. Thanks @thomaspinder!
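For anyone hitting the same error: the distinction matters because `fit()` returns a wrapper object bundling the parameters with the training history, while `transform()` expects the raw parameter dict. A minimal pure-Python sketch of the unpack-then-transform pattern - these are hypothetical stand-ins, not the real GPJax API:

```python
from collections import namedtuple

# Hypothetical stand-in for GPJax's InferenceState: fit() bundles the
# learned params with the optimisation history in one object.
InferenceState = namedtuple("InferenceState", ["params", "history"])

def transform(params, constrainer):
    """Apply a constraining map to each raw parameter (stand-in for gpx.transform)."""
    return {k: constrainer(v) for k, v in params.items()}

state = InferenceState(params={"lengthscale": -2.0}, history=[10.0, 3.0, 1.2])

# Passing the whole state to transform() fails, since the state object has
# no .items() method -- unpack the raw params first, then transform them:
final_params, history = state.params, state.history
constrained = transform(final_params, lambda v: abs(v))
print(constrained)  # {'lengthscale': 2.0}
```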
Just got back to using GPJax on Colab/a machine with GPU support. It seems that the install for GPJax is back to uninstalling jaxlib with CUDA support. What's the workaround here @thomaspinder? Thanks! Example section of trace from
The workaround for using a colab notebook seems to be to use:
where the second reinstalls the cuda version of jaxlib.
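The commands themselves were lost from the thread in extraction. As an assumed reconstruction based on the surrounding context (the exact extras name and wheel index depend on your JAX and CUDA versions):

```shell
# Assumed reconstruction -- the original commands are missing from the thread.
# First install GPJax; its pinned dependencies pull in a CPU-only jaxlib:
pip install gpjax

# Then reinstall the CUDA build of jax/jaxlib on top, per the JAX install
# instructions, so the GPU-enabled wheel wins:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```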
Bug Report
Moving this from Discussions as it wasn't seeing much traffic there.

GPJax version: 0.4.10
Current behavior:
I can install GPJax with CPU only support but not with GPU support. If I try to install gpjax with cuda support, either based on the current gpjax readthedocs or by first installing Jax according to the latest install with cuda directions:
I've tried to find a matching CUDA-compatible version of `jax==0.3.5`, but haven't found a version that is also supported by my machine. I've also tried installing it on a Google Colab notebook with `!pip install gpjax`, and this disables GPU support based on the package versions specified in setup.py.

I'm thinking the setup.py file needs to be updated to a higher jax version, with corresponding updates for changes in jax since 0.3.5 - or I'm misunderstanding how to configure GPU support with GPJax.
Expected behavior:
I would expect that GPJax does have GPU support.
Steps to reproduce:
On Ubuntu 20.04.3 (I have tried Python 3.8-3.10), run the regression example, for instance:
Running this returns the following warning and error:
I can try to provide something more minimal, but I'm guessing I'm just missing something more fundamental on install.
In Colab:
1. Run `!pip install gpjax`
2. Reload the runtime as per the prompt
3. Run `import gpjax as gpx`

At that point it should indicate there is no GPU/TPU found.
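A quick way to confirm which backend JAX actually picked up after installation (and hence whether the CUDA-enabled jaxlib survived the `pip install`) is to query JAX directly:

```python
import jax

# Prints "gpu" when the CUDA jaxlib is active; "cpu" if pip replaced it
# with the CPU-only wheel while resolving gpjax's pinned dependencies.
print(jax.default_backend())

# Lists the devices JAX can see, e.g. [CpuDevice(id=0)] or a GPU device.
print(jax.devices())
```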
Thanks!