
Should get_namespace support more than arrays? #799

Open
NeilGirdhar opened this issue Apr 29, 2024 · 6 comments


@NeilGirdhar

NeilGirdhar commented Apr 29, 2024

Background

Consider a probability distribution library such as efax. With Array API support, each probability distribution class will contain a number of parameters, and it makes sense that they all come from the same namespace. Thus, a standard pattern for methods that use the parameters will be to query the namespace of self by passing all of the parameters to get_namespace.

This pattern is not unique to efax. I imagine it will pop up in SciPy's future distribution classes (which are being developed and will support the Array API). It could apply to any object exposing the JAX PyTree interface (see the registry), or more generally to any aggregate structure holding a homogeneous set of arrays.

Motivation

To simplify getting the namespace in functions that work with aggregate structures containing homogeneous sets of arrays.

Example

Suppose that Distribution is an aggregate structure containing array-valued parameters. Instead of:

```python
def f(x: Distribution, y: Distribution, z: Array):
    xp = x.get_namespace()  # Call method to get namespace.
    assert y.get_namespace() == xp  # Call method and check that it's the same namespace.
    assert get_namespace(z) == xp  # Check that it's the same namespace.
```

we would like to simply have:

```python
def f(x: Distribution, y: Distribution, z: Array):
    xp = get_namespace(x, y, z)  # One simple line
```

Proposal

Extend get_namespace(o) to first check for o.__namespace_arrays__. If o defines it, get_namespace calls it to obtain an iterable of arrays, which it then processes as before.

Thus, instead of aggregate structures providing a method that queries the namespace, like this function, we would have:

```python
@dataclass  # fields(self) requires Distribution to be a dataclass
class Distribution:
    def __namespace_arrays__(self) -> Iterable[Array]:
        return (getattr(self, field.name) for field in fields(self))
```

A simple recursive extension of get_namespace is illustrated here.

Alternative proposal

One alternative is to support __array_namespace__ on all inputs to get_namespace. Thus, we would have:

```python
@dataclass  # fields(self) requires Distribution to be a dataclass
class Distribution:
    def __array_namespace__(self, api_version: str, use_compat: bool) -> ArrayNameSpace:
        return get_namespace(*(getattr(self, field.name) for field in fields(self)),
                             api_version=api_version,
                             use_compat=use_compat)
```

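The problem with this is that it complicates extending the parameter specification of get_namespace: every aggregate's __array_namespace__ has to forward each parameter by hand, so adding a parameter to get_namespace means updating every such class.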
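For comparison, here is a minimal sketch of the alternative under the same assumptions. `get_namespace`, `ToyArray`, and `toy_xp` are hypothetical stand-ins, and only `api_version` is forwarded. The proposal's `use_compat` flag is left out for brevity.

```python
from dataclasses import dataclass, fields
from types import SimpleNamespace
from typing import Any, Optional


def get_namespace(*xs: Any, api_version: Optional[str] = None) -> Any:
    """Hypothetical get_namespace that only knows about __array_namespace__."""
    namespaces = []
    for x in xs:
        xp = x.__array_namespace__(api_version=api_version)
        if not any(xp is seen for seen in namespaces):
            namespaces.append(xp)
    if len(namespaces) != 1:
        raise TypeError(f"expected one array namespace, found {len(namespaces)}")
    return namespaces[0]


# Toy stand-ins so the sketch runs without any real array library.
toy_xp = SimpleNamespace(name="toy_xp")


class ToyArray:
    def __array_namespace__(self, api_version: Optional[str] = None) -> Any:
        return toy_xp


@dataclass
class Distribution:
    mean: Any
    variance: Any

    # The aggregate must re-forward every get_namespace parameter by hand.
    def __array_namespace__(self, api_version: Optional[str] = None) -> Any:
        return get_namespace(*(getattr(self, f.name) for f in fields(self)),
                             api_version=api_version)


x = Distribution(ToyArray(), ToyArray())
xp = get_namespace(x, ToyArray())
```

In this version, `get_namespace` needs no knowledge of aggregates. However, if it later gained another keyword, `Distribution.__array_namespace__` would drop that keyword unless it were updated too. That is the maintenance cost described above.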

@betatim
Member

betatim commented Apr 29, 2024

In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). https://github.com/scikit-learn/scikit-learn/blob/19c068f64249f95f745962b42a4dd581c7738218/sklearn/utils/_array_api.py#L473

Could you do something like that in efax? Or, put differently: aren't you going to end up with something like this sooner or later anyway, in which case it could also take care of "things that aren't arrays but contain them"?

@NeilGirdhar
Author

How is the linked function related? It doesn't deal with aggregate structures, which is the motivation for this proposal.

@asmeurer
Member

get_namespace isn't actually part of the array API; it's part of the compat library. The array API defines x.__array_namespace__. The compat library's get_namespace() (which is also called array_namespace()) is just a wrapper that calls this method and manually returns the compat-layer namespace when necessary. Maybe this should be made clearer in the documentation.
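To make the distinction concrete, here is a deliberately simplified, hypothetical model of such a wrapper. It is not array_api_compat's actual code. It prefers an input's own `__array_namespace__` method and otherwise falls back to a compat namespace registered for that type. `CompliantArray`, `LegacyArray`, `COMPAT_NAMESPACES`, and `simplified_array_namespace` are all illustrative names.

```python
from types import SimpleNamespace
from typing import Any

# Toy namespaces standing in for a native namespace and a compat-layer wrapper.
native_xp = SimpleNamespace(name="native_xp")
compat_xp = SimpleNamespace(name="compat_xp")


class CompliantArray:
    """Array type that implements the standard's __array_namespace__ method."""

    def __array_namespace__(self, api_version=None):
        return native_xp


class LegacyArray:
    """Array type without __array_namespace__ that needs a compat layer."""


# Hypothetical registry mapping non-compliant types to their compat namespace.
COMPAT_NAMESPACES = {LegacyArray: compat_xp}


def simplified_array_namespace(*xs: Any, api_version=None) -> Any:
    namespaces = []
    for x in xs:
        if hasattr(x, "__array_namespace__"):
            xp = x.__array_namespace__(api_version=api_version)  # standard path
        elif type(x) in COMPAT_NAMESPACES:
            xp = COMPAT_NAMESPACES[type(x)]  # compat fallback
        else:
            raise TypeError(f"unrecognized array input: {type(x).__name__}")
        if not any(xp is seen for seen in namespaces):
            namespaces.append(xp)
    if len(namespaces) != 1:
        raise TypeError(f"expected one array namespace, found {len(namespaces)}")
    return namespaces[0]
```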

I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?

> In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). scikit-learn/scikit-learn@19c068f/sklearn/utils/_array_api.py#L473

Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems it should be generally useful and easy to implement.

@NeilGirdhar
Author

> I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?

Yes, but it's just a question of convenience. Sometimes you have a method that accepts an aggregate object (say, self) and some arrays (say, x). You could expand the aggregate object into its component arrays and call get_namespace(self.a, self.b, self.c, x), but I'm proposing the convenience of get_namespace(self, x). It's just simpler.

@betatim
Member

betatim commented May 2, 2024

> Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems it should be generally useful and easy to implement.

If you want to, go for it. No strong feelings from my side. I have a slight preference for keeping the get_namespace in the compat library simple. I can see a future where it accumulates "all the useful things" from the various array-consuming libraries and then becomes quite unwieldy.

The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array consuming library having a custom version of get_namespace that implements things that are convenient for it. efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.

@NeilGirdhar
Author

> The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array consuming library having a custom version of get_namespace that implements things that are convenient for it.

Ah, right, that makes sense!

> efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.

Right, which is what I'm doing. The reason I suggested upstreaming aggregate structure support is in case there are ever functions that accept aggregate structure types from different libraries.
