diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index d59a9eec751e..876b17dbe8b4 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -114,6 +114,7 @@ ceil as ceil, clip as clip, conj as conj, + copysign as copysign, cos as cos, cosh as cosh, divide as divide, @@ -139,6 +140,8 @@ logical_not as logical_not, logical_or as logical_or, logical_xor as logical_xor, + maximum as maximum, + minimum as minimum, multiply as multiply, negative as negative, not_equal as not_equal, @@ -148,6 +151,7 @@ remainder as remainder, round as round, sign as sign, + signbit as signbit, sin as sin, sinh as sinh, sqrt as sqrt, @@ -168,7 +172,9 @@ concat as concat, expand_dims as expand_dims, flip as flip, + moveaxis as moveaxis, permute_dims as permute_dims, + repeat as repeat, reshape as reshape, roll as roll, squeeze as squeeze, @@ -179,6 +185,7 @@ argmax as argmax, argmin as argmin, nonzero as nonzero, + searchsorted as searchsorted, where as where, ) diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index cafd0371e475..c34e9d93cfb0 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -17,7 +17,6 @@ result_type as _result_type, isdtype as _isdtype, ) -import numpy as np def _promote_dtypes(name, *args): @@ -148,6 +147,11 @@ def conj(x, /): return jax.numpy.conj(x) +def copysign(x1, x2, /): + """Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1.""" + return jax.numpy.copysign(x1, x2) + + def cos(x, /): """Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x.""" x, = _promote_dtypes("cos", x) @@ -300,6 +304,18 @@ def logical_xor(x1, x2, /): return jax.numpy.logical_xor(x1, x2) +def maximum(x1, x2, /): + """Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("maximum", x1, x2) + return jax.numpy.maximum(x1, x2) + + +def minimum(x1, x2, /): + """Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("minimum", x1, x2) + return jax.numpy.minimum(x1, x2) + + def multiply(x1, x2, /): """Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" x1, x2 = _promote_dtypes("multiply", x1, x2) @@ -356,6 +372,11 @@ def sign(x, /): return jax.numpy.sign(x) +def signbit(x, /): + """Determines whether the sign bit is set for each element x_i of the input array x.""" + return jax.numpy.signbit(x) + + def sin(x, /): """Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x.""" x, = _promote_dtypes("sin", x) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index bac3afd67472..fdc83fc83090 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -47,11 +47,21 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: return jax.numpy.flip(x, axis=axis) +def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array: + """Moves array axes (dimensions) to new positions, while leaving other axes in their original positions.""" + return jax.numpy.moveaxis(x, source, destination) + + def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: """Permutes the axes (dimensions) of an array x.""" return jax.numpy.permute_dims(x, axes=axes) +def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + """Repeats each element of an array a specified number of times on a per-element basis.""" + return jax.numpy.repeat(x, repeats=repeats, axis=axis) + + def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused diff --git a/jax/experimental/array_api/_searching_functions.py b/jax/experimental/array_api/_searching_functions.py index 8357ae3eae86..f329e4add813 100644 --- a/jax/experimental/array_api/_searching_functions.py +++ b/jax/experimental/array_api/_searching_functions.py @@ -33,6 +33,15 @@ def nonzero(x, /): return jax.numpy.nonzero(x) +def searchsorted(x1, x2, /, *, side='left', sorter=None): + """ + Finds the indices into x1 such that, if the corresponding elements in x2 + were inserted before the indices, the order of x1, when sorted in ascending + order, would be preserved. + """ + return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter) + + def where(condition, x1, x2, /): """Returns elements chosen from x1 or x2 depending on condition.""" dtype = _result_type(x1, x2) diff --git a/tests/array_api_test.py b/tests/array_api_test.py index b11fc35845ff..9871c100b3ec 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -65,6 +65,7 @@ 'complex64', 'concat', 'conj', + 'copysign', 'cos', 'cosh', 'divide', @@ -115,9 +116,12 @@ 'matmul', 'matrix_transpose', 'max', + 'maximum', 'mean', 'meshgrid', 'min', + 'minimum', + 'moveaxis', 'multiply', 'nan', 'negative', @@ -133,11 +137,14 @@ 'prod', 'real', 'remainder', + 'repeat', 'reshape', 'result_type', 'roll', 'round', + 'searchsorted', 'sign', + 'signbit', 'sin', 'sinh', 'sort',