-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add repeat
to the specification
#690
Conversation
I've updated the proposed specification to include a note advising conforming array libraries to include a warning regarding device synchronization if |
@leofang Would you mind giving this PR a review? I believe this PR addresses the concerns you raised in #654 (comment), but I want to confirm before merging. |
@leofang Pinging in case you missed the above. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pushing this, Athan. Sorry I missed the ping. Took a stab at it, no concerns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more question: Should we also add a note on "data-dependent output shapes" like what we do for unique*
/nonzero
?
@leofang Re: admonition. I am not certain. It's only when And for that case, we include a note regarding device synchronization. If we add a "data-dependent admonition" here, this would make this API optional, which I am not sure is desirable. As a point of reference, in |
Doesn't this API always have the output shape determined by the input data ( |
@leofang I added the data-dependent shape admonition. Given that JAX requires a Now that this has been added, I believe that this PR should be ready for another review. cc @rgommers |
Kinda sorta, but I think the "data-dependent shape" admonition is more aimed at the input values of the input array. E.g., the most common usage here will be with a literal int: I played with this a bit with JAX: >>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(3)
>>> jnp.repeat(x, 2)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> jnp.repeat(x, 2, total_repeat_length=x.size*2) # the documented way to allow JIT-ing
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> def func(x):
... return jnp.repeat(x, 2)
...
>>> func(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> # It's not actually needed to use `total_repeat_length` if `repeats` is a literal int:
>>> jax.jit(func)(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> # It is needed if we make `repeats` data-dependent:
>>> def func(x):
... return jnp.repeat(x, x[2])
...
>>> func(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> jax.jit(func)(x)
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)> A few conclusions:
# from scipy.signal tests
repeats[1::2] = x[1::2]
x = np.repeat(x, repeats)
# from scipy.integrate functionality
diff.data /= np.repeat(h, np.diff(diff.indptr)) My suggested resolution:
|
@rgommers We discussed making |
I'd say I agree with Leo's comments in that thread. There's just not much of a point of it being a sequence. It is not like you can do NumPy also documents it as an ndarray; the only reason a sequence works for NumPy is because it calls |
Q: Does it make sense to say "when the input is an array, this API has data-dependent output shape" and followed by the note? |
I think so. I'd generalize it slightly - maybe "unless the |
Okay. I've dropped support for sequences and updated the admonition. The admonition now only allows optional support for providing an array as the second argument; all conforming libraries must support providing an integer. As providing sequences is still controversial and can be added in a subsequent revision of the standard, I think it is fine to omit for the time being. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @kgryte. The repeats
treatment LGTM. Two other minor comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM now. I think this should be good to go; will aim to merge this at the end of today unless there are new comments.
Thanks @kgryte & all reviewers! |
This PR
repeat
to the array API specification.repeats
to be either anint
or anarray
. NumPy and other inspired libraries and TensorFlow support one-dimensional arrays. NumPy also supports lists and tuples. CuPy docs suggest support for only lists and tuples. PyTorch supports a one-dimensional array; however, there has been discussion (linked to in the linked RFC) preferring sequences over arrays due to synchronization issues. However, it's not clear that providing a sequence of integers is particularly common or useful. In this PR, I've chosen to explicitly typerepeats
to supportint
and array. Should sequences be considered acceptable, this can be revisited in a future revision of the Array API standard.repeats
argument.