-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial implementation of the Python Array API standard #16099
Conversation
2932adb
to
87ff9c7
Compare
596164b
to
b35c628
Compare
It looks like there's a bug in the 2022.12 release of array-api; I filed an issue here: data-apis/array-api#631 |
Nice to see this work. Feel free to reach out if you have any questions on the spec or the test suite. |
I notice that you're using a separate namespace here. I'm curious what your rationale is for that. Is it mostly just so that you can experiment without having to worry about breaking things? I would recommend aiming to make the main I should also point out that we have the array-api-compat library, https://github.com/data-apis/array-api-compat, which can be used to provide a compatibility layer if there are places where JAX deviates from the standard and cannot easily change because of backwards compatibility concerns. It already supports NumPy, CuPy, and PyTorch, and is being used by scikit-learn and SciPy (and hopefully others soon). |
I'm using a separate namespace for ease of experimentation. Making the main namespace compliant will be a much larger project because of the number of existing behaviors that will have to be deprecated. (I suspect this is the same reason e.g. |
In any case, it's not clear that JAX can be made compatible with the array-api at all, because nearly every test in Is mutation a necessary feature of the array API standard? If so then JAX is basically disqualified entirely. If not, then perhaps this is a bug in |
We are running into other issues as well; e.g. In the array API standard, |
Mutation is not a requirement of the array API. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html. This was done specifically so that libraries like JAX can be made compliant. However, it is true that the test suite currently expects mutation to work. This is the first time someone has tried to implement the array API on a library that doesn't allow mutation. So we're going to have to do some work to update it to only do mutation in places where it is actually required.
That's a harder thing for sure (the good news is that at least this shouldn't be an issue for the test suite, as device support isn't really tested right now). PyTorch has a similar issue for The approach we're using for that is to define a |
CCing @honno on the test suite stuff. Fixing things to not require mutation could be a hard problem, given that it's part of hypothesis itself (maybe hypothesis should instead generate a list of lists and pass it to Does your separate namespace here somehow enable mutation? As far as I can tell, you're just reusing the same JAX array object. If not, the test suite will still be unusable for you in its current state. What you're doing here is somewhat akin to what might go in |
No, there's no meaningful way to make JAX programs support in-place array mutation.
Sounds good! I'd propose working off this branch for now, and long-term we can discuss what makes the most sense. Let me know if there's anything I can do to help get the test suite working for non-mutable objects. |
@shoyer pointed out that we could fix the At the very least it gives a viable deprecation path if we want to make that change to the core API. |
Assuming your |
Note there's no specified way in the array API to check what dtypes a given namespace supports. Currently in Hypothesis and A dirty "fix"—is there a world where Also commented on this in data-apis/array-api#499 (comment) |
This is what I would recommend. Or at least make them error when they are used. JAX creation functions returning Making things work for array API consumers should be the top priority. We can always adjust the test suite. |
For what it's worth - my solution to this has been to only run the tests with |
I guess libraries like scikit-learn and scipy would need to discuss whether this sort of solution works for them (or indeed whether this is an actual problem at all for them; I guess it depends if there's anywhere where they require float64). The problem with an environment variable is that it can only be set by an end-user. |
I opened an upstream issue to track test suite support for non-mutation data-apis/array-api-tests#188 |
Has the Array API team considered adding try:
rm = result.mutable()
rm[i] = val
except Exception as e:
... and adding a |
To be clear, my statement "I think we could make it work in JAX" in the linked comment was my own opinion, not one vetted with the JAX team. In principle, the object returned by @dataclasses.dataclass
class Mutable:
value: jax.Array
def __setitem__(self, key, value):
self.value = self.value.at[key].set(value)
def __jax_array__(self):
return self.value
# TODO(shoyer): implement array methods That said, I'm not sure it would actually be worth the cognitive overhead of adding this into JAX. It might be just as sane to imagine adding
|
I think it might be worth elaborating on your idea a bit more completely to really evaluate it. Yes, the mutable method is additional cognitive overhead, but it saves the cognitive overhead of using the |
One wart here is that the array API specifies that We work around this because the array API defines the expected device type as the type of the object returned by A bit messy, but it works, and avoids us having to deprecate the |
The compat library already has a
The object returned by |
Oh well, there goes my clever hack... how bad is it in practice if |
Well right now everyone using the array API is using the device() helper function from the compat library, because numpy doesn't even have |
What about adding a dummy |
Yeah, I think that would be a good way forward. A deeper problem, though, is that in general jax arrays can be sharded across multiple devices, which seems to conflict with the core assumption of the Array API that each array lives on a single device. |
I think it's DLPack assuming a single-device setting. For Array API, |
Complicating this, it looks like there is a TODO to remove the existing |
ef5bed7
to
c35eac2
Compare
f11e740
to
421df9f
Compare
Part of #18353
Usage:
And then
xp
is the Array API namespace backed by JAX.This initial implementation still has some missing features (see
array-api-skips.txt
for examples of known failures) and so I'm not yet making it available viajax.Array.__array_namespace__
unlessjax.experimental.array_api
is first explicitly imported.