I'm running a TPU v3-8 VM on Google Cloud. On the VM I installed JAX with pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html.
Unfortunately, calling jax.device_count() prints the warning No GPU/TPU found, falling back to CPU. The same happens with pip install jax==0.2.12. It only works when I install with pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html, which pulls in the newest JAX version. As far as I can see, fine-tuning requires JAX version 0.2.12 or 0.2.16.
How can I get it running with one of these versions?
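To compare the different installs, a small check script can report which backend JAX actually picked up after each one. This is a minimal sketch that assumes only the public jax.__version__, jax.devices() and jax.device_count(), plus the platform and device_kind attributes of a JAX device object. The helper name describe_backend is hypothetical. On a working v3-8 you would expect the platform to be "tpu" with 8 devices (4 chips with 2 cores each). A platform of "cpu" means the TPU runtime was not found or failed to load.

```python
import jax


def describe_backend():
    """Hypothetical helper: summarize which backend JAX initialized."""
    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        # platform is "tpu", "gpu" or "cpu"; "cpu" here means JAX fell back
        "platform": devices[0].platform,
        "device_count": jax.device_count(),
        # e.g. a TPU generation string on TPU, or "cpu" on the CPU fallback
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


if __name__ == "__main__":
    info = describe_backend()
    for key, value in info.items():
        print(f"{key}: {value}")
    if info["platform"] == "cpu":
        print("JAX is running on CPU: the TPU runtime was not picked up.")
```

Running it once per pinned install (0.2.12, 0.2.16, and the newest release) shows directly which combination of jax and its TPU runtime exposes the TPU cores.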