[Bug] numpy Tests Fail when jax is Installed #76
Comments
Hi Matt,
That's right, the numpy functionality will not work properly if jax is installed. We are aware of this (the test suite only covers the other cases), but to some extent we are still trying to decide whether this should be considered the correct behavior or not. Whatever decision we make, it should be detailed more prominently in the docs. A fair amount of the functionality will not work with numpy (e.g. autograd-enabled transposes, vmapped linear operators for unary function application), so if jax is properly installed we figured that the user would want that functionality enabled even if they were just using numpy arrays to construct their operators. Do you have a sense for some important use cases where one would want to use numpy even when jax is installed?
Hey @mfinzi. I certainly agree that for the majority of users, it makes sense to 'automatically switch' to the
As a (somewhat contrived) example of where this 'automatic backend switching' might cause trouble, suppose another library chose to use
Just a couple of other quick thoughts:
Thanks for your help on this. Cheers,
🐛 Bug
Some of the numpy backend unit tests fail if jax is installed, but pass when jax is not installed (i.e. these tests are 'flaky').
To reproduce
Test results when jax is not installed:
Test results when jax is installed:
Note that similar results are observed when pytest -m "numpy" -k "test_get_lu_from_tridiagonal" is run.
Expected Behavior
Ideally, unit tests should run in a predictable and consistent manner, with the result of a given test not depending on which optional dependencies the user may or may not have installed on their machine.
System information
jax version: 0.4.16
Additional context
I encountered this issue when running the test suite for the first time before starting work on #75. It appears that the current CI workflow doesn't pick up on this problem because the numpy tests are only executed when jax is not installed.
From my own experiments, it seems that the source of the flakiness in these numpy tests is that cola.backends.get_library_fns correctly infers the back-end of a numpy array to be numpy_fns when jax is not installed, but incorrectly infers the back-end to be jax_fns when jax is installed. We can see why this occurs by considering the current implementation of get_library_fns:
i.e. get_library_fns will infer the back-end to be jax if jax can be imported and if dtype matches a jax.numpy type. Unfortunately, it turns out (much to my surprise) that jax.numpy types are basically just aliases for numpy types, which means that Python evaluates jax.numpy and numpy types as equal to one another:
This means get_library_fns will always return jax_fns when provided with a numpy array if jax is installed. Even more surprisingly, the dtype property of a jax.numpy array is not even guaranteed to be a jax.numpy type:
I think these observations illustrate that the 'premise' behind the get_library_fns function (i.e. that you can determine which back-end to use purely based on the dtype property of an array) probably isn't sound.
Proposed Solutions
Two potential fixes come to mind:
1. Remove the get_library_fns function and replace it with a similar function that requires the user to explicitly name the back-end they wish to be returned.
2. Add a flag to get_library_fns to 'force' it to return the numpy backend, even when jax is installed; this flag can then be used during the numpy tests to ensure that they're consistent.
I'm more than happy to work on this issue, but it would be great to hear what others think about all this first. Thanks in advance for any help.
Cheers,
Matt.