Originally posted by patel-zeel November 1, 2022
Hi,
I was recently comparing my own implementation of a method with GPJAX and noticed that my results changed completely after `import gpjax as gpx`. It took me a while to figure out that the cause was a change in floating-point precision (from 32-bit to 64-bit). Since `jax.random` produces different samples at different precisions, I was getting completely different results (please refer to the code below). Would it be better not to enable 64-bit mode by default, to avoid such surprises? Another way to handle this might be to call `jax.config.update("jax_enable_x64", True)` in all documentation examples, so that it becomes a habit for GPJAX users.
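A minimal sketch of the effect described above. It assumes only the public `jax.config.update` and `jax.random` APIs, and the helper `sample_normal` is hypothetical (not part of GPJAX or the original report). It draws samples from the same `PRNGKey` seed twice, first with 64-bit mode off and then with it on, and compares them.

```python
import jax
import numpy as np


def sample_normal(enable_x64: bool, seed: int = 0, n: int = 5) -> np.ndarray:
    """Hypothetical helper: draw n standard-normal samples with x64 mode on or off."""
    # Toggle JAX's global precision flag (per this report, importing gpjax turns it on).
    jax.config.update("jax_enable_x64", enable_x64)
    key = jax.random.PRNGKey(seed)
    # With no explicit dtype, jax.random.normal uses the default float dtype:
    # float32 when x64 is disabled, float64 when it is enabled.
    samples = jax.random.normal(key, (n,))
    return np.asarray(samples)


x32 = sample_normal(enable_x64=False)
x64 = sample_normal(enable_x64=True)
print(x32.dtype, x32)
print(x64.dtype, x64)
# The same seed gives different numbers, not the same numbers at higher precision.
print("close:", np.allclose(x32, x64, atol=1e-3))
```

Because the flag is global, whichever call to `jax.config.update` runs last decides the precision for the rest of the program. Setting it explicitly at the top of a script makes the random draws independent of whether `gpjax` is imported.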
Discussed in #126