[Bug] numpy Tests Fail when jax is Installed #76

Open
mabilton opened this issue Oct 4, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@mabilton

mabilton commented Oct 4, 2023

🐛 Bug

Some of the numpy backend unit tests fail if jax is installed, but pass when jax is not installed (i.e. these tests are 'flaky').

To reproduce

# Set-up
git clone https://github.com/wilson-labs/cola.git
cd cola
python -m venv venv
source venv/bin/activate
pip install -e ".[dev]"
pip install -r docs/requirements.txt
# Run tests without `jax` installed:
pytest -m "numpy" -k "test_unary"
# Re-run tests with `jax` installed:
pip install jax jaxlib
pytest -m "numpy" -k "test_unary"

Test results when jax is not installed:

============================= test session starts ==============================
platform linux -- Python 3.10.6, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/mabilton/cola
configfile: setup.cfg
plugins: anyio-4.0.0, cov-4.1.0
collected 460 items / 428 deselected / 32 selected

tests/linalg/test_unary.py ......F.........................              [100%]

=================================== FAILURES ===================================
...
=========================== short test summary info ============================
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'exp', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
=========== 1 failed, 31 passed, 428 deselected, 9 warnings in 1.95s ===========

Test results when jax is installed:

============================= test session starts ==============================
platform linux -- Python 3.10.6, pytest-7.4.2, pluggy-1.3.0
rootdir: /home/mabilton/cola
configfile: setup.cfg
plugins: anyio-4.0.0, cov-4.1.0
collected 460 items / 428 deselected / 32 selected

tests/linalg/test_unary.py ..FFF...........FF...FFF....FF..              [100%]

=================================== FAILURES ===================================
...
=========================== short test summary info ============================
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_lowertriangular', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_lowertriangular', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_kronsum', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_dense', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'square_dense', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_kron', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'exp', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_big', 'sqrt', marks=[MarkDecoratormark=Markname='big', args=, kwargs={}, MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_blockdiag', 'exp', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
FAILED tests/linalg/test_unary.py::test_unary[ParameterSetvalues='numpy', 'float64', 'psd_blockdiag', 'sqrt', marks=[MarkDecoratormark=Markname='numpy', args=, kwargs={}], id=None]
========= 10 failed, 22 passed, 428 deselected, 70 warnings in 15.03s ==========

Note that similar results are observed when pytest -m "numpy" -k "test_get_lu_from_tridiagonal" is run.

Expected Behavior

Ideally, unit tests should run in a predictable and consistent manner, with the result of a given test not depending on which optional dependencies the user may or may not have installed on their machine.

System information

  • Cola version: 0.0.4
  • jax version: 0.4.16
  • OS: Pop!_OS 22.04 LTS

Additional context

I encountered this issue when running the test suite for the first time before starting work on #75. It appears that the current CI workflow doesn't pick up on this problem because the numpy tests are only executed when jax is not installed.

From my own experiments, it seems that the source of the flaky-ness in these numpy tests is that cola.backends.get_library_fns correctly infers the back-end of a numpy array to be numpy_fns when jax is not installed, but incorrectly infers the back-end to be jax_fns when jax is installed. We can see why this occurs by considering the current implementation of get_library_fns:

def get_library_fns(dtype):
    try:
        from jax import numpy as jnp
        if dtype in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128, jnp.int32, jnp.int64]:
            from cola.backends import jax_fns as fns
            return fns
    except ImportError:
        pass
    ...
    if dtype in [np.float32, np.float64, np.complex64, np.complex128, np.int32, np.int64]:
        from cola.backends import np_fns as fns
        return fns
    raise ImportError("No supported array library found")

i.e. get_library_fns will infer the back-end to be jax if jax can be imported and if dtype matches with a jax.numpy type. Unfortunately, it turns out (much to my surprise) that jax.numpy types are basically just aliases for numpy types, which means that Python evaluates jax.numpy and numpy types as equal to one another:

import numpy as np
import jax.numpy as jnp
print(jnp.float32 == np.float32)
# Prints: True

This means that, whenever jax is installed, get_library_fns will always return jax_fns when given the dtype of a numpy array. Even more surprisingly, the dtype property of a jax.numpy array is not even guaranteed to be a jax.numpy type:

import jax.numpy as jnp
x = jnp.array([1.,2.,3.], dtype=jnp.float32)
print(type(x.dtype))
# Prints: <class 'numpy.dtype[float32]'>

I think these observations illustrate that the 'premise' behind the get_library_fns function (i.e. that you can determine which back-end to use purely based on the dtype property of an array) probably isn't sound.
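To illustrate the point, here is a minimal sketch of inferring the back-end from the array's class instead of its dtype. The function name infer_backend_name is hypothetical and not part of cola, and the check on the class's module name assumes that jax array classes are defined in a package whose top-level name is jax or jaxlib.

```python
import numpy as np


def infer_backend_name(array):
    """Guess the back-end from the array's class rather than its dtype.

    Hypothetical helper for illustration; not part of cola.
    """
    if isinstance(array, np.ndarray):
        return "numpy"
    # Assumption: jax array classes are defined in modules whose top-level
    # package is `jax` or `jaxlib`.
    root_package = type(array).__module__.split(".")[0]
    if root_package in ("jax", "jaxlib"):
        return "jax"
    raise TypeError(f"Unsupported array type: {type(array).__name__}")
```

Because the decision depends only on the array's class, a numpy array would map to "numpy" whether or not jax is installed.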

Proposed Solutions

Two potential fixes come to mind:

  1. Deprecate the get_library_fns function and replace it with a similar function that requires the user to explicitly name the back-end they wish to be returned.
  2. Add an additional flag to get_library_fns to 'force' it to return the numpy backend, even when jax is installed; this flag can then be used during the numpy tests to ensure that they're consistent.
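As a rough sketch of the first option, an explicit lookup could look something like the following. The function get_backend_fns, the registry dictionary, and the back-end names "numpy" and "jax" are all hypothetical; only the module paths are taken from the imports in the get_library_fns snippet above.

```python
import importlib

# Hypothetical registry of back-end names. The module paths mirror the imports
# in the get_library_fns snippet; the keys are assumed names.
BACKEND_MODULES = {
    "numpy": "cola.backends.np_fns",
    "jax": "cola.backends.jax_fns",
}


def get_backend_fns(name, registry=BACKEND_MODULES):
    """Return the back-end module registered under `name` (hypothetical helper)."""
    if name not in registry:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {sorted(registry)}"
        )
    # The import only happens for the back-end that was explicitly requested.
    return importlib.import_module(registry[name])
```

With this shape, the numpy tests could request "numpy" by name and would no longer depend on whether jax happens to be importable.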

I'm more than happy to work on this issue, but it would be great to hear what others think about all this first. Thanks in advance for any help.

Cheers,
Matt.

@mabilton mabilton added the bug Something isn't working label Oct 4, 2023
@mfinzi
Collaborator

mfinzi commented Oct 4, 2023

Hi Matt,

That's right, the numpy functionality will not work properly if jax is installed. This is something we are aware of (the test suite only covers the other cases), but to some extent we are still trying to decide whether this should be considered the correct behavior or not. And of course, whatever decision we make should be detailed more prominently in the docs.

A fair amount of the functionality will not work with numpy (e.g. autograd enabled transposes, vmapped linear operator for unary function application), so if jax is properly installed we were figuring that the user would want that functionality enabled even if they were just using numpy arrays to construct their operators.
In some ways we are still trying to think about how we want to position the limited numpy support.

Do you have a sense for some important use cases where one would want to use numpy even when jax is installed?

@mabilton
Author

mabilton commented Oct 7, 2023

Hey @mfinzi.

I certainly agree that for the majority of users, it makes sense to 'automatically switch' to the jax backend from the numpy backend when possible. In saying that, would there be any harm in allowing users to explicitly disable this 'automatic switching', particularly if doing so would allow us to write tests that run a bit more consistently?

As a (somewhat contrived) example of where this 'automatic backend switching' might cause trouble, suppose another library chose to use cola to perform some linear algebra operations, but also decided to implement 'jax-unfriendly' computations (e.g. computations involving dynamic array shapes, or performing an in-place array update) later on in their code. If the maintainers of this library don't have jax installed, everything will appear fine when they use numpy arrays. If a user of this library, however, also happens to have jax installed on their system, the later 'jax-unfriendly' steps would likely throw an error (unless the maintainers of the library explicitly converted everything back to np.arrays after the steps involving cola, which might be a bit tedious). As I say, a bit of a silly example, but I think the general point about 'silently' converting numpy arrays to jax arrays still stands.
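A small sketch of the kind of 'jax-unfriendly' step described above: an in-place item assignment works on a numpy array, whereas immutable array types (jax arrays among them) raise a TypeError. The helper name supports_inplace_update is hypothetical and exists only for illustration.

```python
import numpy as np


def supports_inplace_update(array):
    """Return True if `array` accepts an in-place item assignment.

    Hypothetical helper: it rewrites the first element with its own value,
    so a mutable array is left unchanged.
    """
    try:
        array[0] = array[0]
    except TypeError:
        # Immutable array types, such as jax arrays, reject item assignment.
        return False
    return True


x = np.arange(3.0)
print(supports_inplace_update(x))
```

Downstream code that relies on this kind of assignment would start failing if the arrays it gets back were silently converted to jax arrays.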

Just a couple of other quick thoughts:

  1. If we stick with the current behavior of get_library_fns, I think it would be useful to add a brief comment explaining that the jax backend will be returned if a numpy dtype is passed but jax is installed, since I don't think this is obvious from the function definition.
  2. If we choose to stick with the current testing behavior (i.e. some numpy tests failing if jax is installed), it might be worth checking out whether something like pytest.mark.skipif can be used to automatically 'turn off' those tests that are expected to fail.
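For the second point, a minimal sketch of what this might look like is given below. The marker name skip_if_jax_installed and the placeholder test are hypothetical; the only assumption about the environment is that jax counts as 'installed' when Python can locate the package.

```python
import importlib.util

import pytest

# True when the jax package can be located in the current environment.
JAX_INSTALLED = importlib.util.find_spec("jax") is not None

# Reusable marker (hypothetical name) for numpy tests that are expected to
# fail when jax is present.
skip_if_jax_installed = pytest.mark.skipif(
    JAX_INSTALLED,
    reason="numpy dtypes resolve to the jax backend when jax is installed",
)


@skip_if_jax_installed
def test_numpy_only_behaviour():
    # Placeholder body; a real test would exercise the numpy backend here.
    assert True
```

Skipped tests are reported along with the given reason, so the behavior would at least be visible in the test output rather than showing up as unexplained failures.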

Thanks for your help on this.

Cheers,
Matt.
