Skip to content

Commit

Permalink
feat(ops): support np.argpartition (#19588)
Browse files Browse the repository at this point in the history
* feat(ops): support np.argpartition

* updated documentation, type-casting, and tf implementation

* fixed tf implementation

* added torch cast to int32

* updated torch type and API generated files

* added torch output type cast
  • Loading branch information
lpizzinidev authored Apr 27, 2024
1 parent fe03ca5 commit 688daa5
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
from keras.src.ops.numpy import arctanh
from keras.src.ops.numpy import argmax
from keras.src.ops.numpy import argmin
from keras.src.ops.numpy import argpartition
from keras.src.ops.numpy import argsort
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras.src.ops.numpy import arctanh
from keras.src.ops.numpy import argmax
from keras.src.ops.numpy import argmin
from keras.src.ops.numpy import argpartition
from keras.src.ops.numpy import argsort
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
from keras.src.ops.numpy import arctanh
from keras.src.ops.numpy import argmax
from keras.src.ops.numpy import argmin
from keras.src.ops.numpy import argpartition
from keras.src.ops.numpy import argsort
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras.src.ops.numpy import arctanh
from keras.src.ops.numpy import argmax
from keras.src.ops.numpy import argmin
from keras.src.ops.numpy import argpartition
from keras.src.ops.numpy import argsort
from keras.src.ops.numpy import array
from keras.src.ops.numpy import average
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,7 @@ def select(condlist, choicelist, default=0):
def slogdet(x):
x = convert_to_tensor(x)
return tuple(jnp.linalg.slogdet(x))


def argpartition(x, kth, axis=-1):
return jnp.argpartition(x, kth, axis)
4 changes: 4 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,7 @@ def select(condlist, choicelist, default=0):

def slogdet(x):
return tuple(np.linalg.slogdet(x))


def argpartition(x, kth, axis=-1):
return np.argpartition(x, kth, axis).astype("int32")
21 changes: 21 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,3 +2435,24 @@ def select(condlist, choicelist, default=0):
def slogdet(x):
x = convert_to_tensor(x)
return tuple(tf.linalg.slogdet(x))


def argpartition(x, kth, axis=-1):
x = convert_to_tensor(x, tf.int32)

x = swapaxes(x, axis, -1)
bottom_ind = tf.math.top_k(-x, kth + 1).indices

n = tf.shape(x)[-1]

mask = tf.reduce_sum(tf.one_hot(bottom_ind, n, dtype=tf.int32), axis=0)

indices = tf.where(mask)
updates = tf.squeeze(tf.zeros(tf.shape(indices)[0], dtype=tf.int32))

final_mask = tf.tensor_scatter_nd_update(x, indices, updates)

top_ind = tf.math.top_k(final_mask, tf.shape(x)[-1] - kth - 1).indices

out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1)
return swapaxes(out, -1, axis)
20 changes: 20 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,3 +1613,23 @@ def select(condlist, choicelist, default=0):
def slogdet(x):
x = convert_to_tensor(x)
return tuple(torch.linalg.slogdet(x))


def argpartition(x, kth, axis=-1):
x = convert_to_tensor(x, "int32")

x = torch.transpose(x, axis, -1)
bottom_ind = torch.topk(-x, kth + 1)[1]

def set_to_zero(a, i):
a[i] = 0
return a

for _ in range(x.dim() - 1):
set_to_zero = torch.vmap(set_to_zero)
proxy = set_to_zero(torch.ones(x.shape, dtype=torch.int32), bottom_ind)

top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1]

out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1)
return cast(torch.transpose(out, -1, axis), "int32")
42 changes: 42 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6191,3 +6191,45 @@ def slogdet(x):
if any_symbolic_tensors((x,)):
return Slogdet().symbolic_call(x)
return backend.numpy.slogdet(x)


class Argpartition(Operation):
def __init__(self, kth, axis=-1):
super().__init__()
if not isinstance(kth, int):
raise ValueError("kth must be an integer. Received:" f"kth = {kth}")
self.kth = kth
self.axis = axis

def call(self, x):
return backend.numpy.argpartition(x, kth=self.kth, axis=self.axis)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype="int32")


@keras_export(["keras.ops.argpartition", "keras.ops.numpy.argpartition"])
def argpartition(x, kth, axis=-1):
"""Performs an indirect partition along the given axis.
It returns an array
of indices of the same shape as `x` that index data along the given axis
in partitioned order.
Args:
a: Array to sort.
kth: Element index to partition by.
The k-th element will be in its final sorted position and all
smaller elements will be moved before it and all larger elements
behind it. The order of all elements in the partitions is undefined.
If provided with a sequence of k-th it will partition all of them
into their sorted position at once.
axis: Axis along which to sort. The default is -1 (the last axis).
If `None`, the flattened array is used.
Returns:
Array of indices that partition `x` along the specified `axis`.
"""
if any_symbolic_tensors((x,)):
return Argpartition(kth, axis).symbolic_call(x)
return backend.numpy.argpartition(x, kth, axis)
48 changes: 48 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,14 @@ def test_vstack(self):
y = KerasTensor((None, None))
self.assertEqual(knp.vstack([x, y]).shape, (None, 3))

def test_argpartition(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.argpartition(x, 3).shape, (None, 3))
self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (None, 3))

with self.assertRaises(ValueError):
knp.argpartition(x, (1, 3))


class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
def test_mean(self):
Expand Down Expand Up @@ -1981,6 +1989,14 @@ def test_vstack(self):
y = KerasTensor((2, 3))
self.assertEqual(knp.vstack([x, y]).shape, (4, 3))

def test_argpartition(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.argpartition(x, 3).shape, (2, 3))
self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (2, 3))

with self.assertRaises(ValueError):
knp.argpartition(x, (1, 3))


class NumpyTwoInputOpsCorretnessTest(testing.TestCase, parameterized.TestCase):
def test_add(self):
Expand Down Expand Up @@ -4303,6 +4319,19 @@ def myfunc(a, b):
out, np.vectorize(np.diag, signature="(d,d)->(d)")(np.eye(4))
)

def test_argpartition(self):
x = np.array([3, 4, 2, 1])
self.assertAllClose(knp.argpartition(x, 2), np.argpartition(x, 2))
self.assertAllClose(knp.Argpartition(2)(x), np.argpartition(x, 2))

x = np.array([[3, 4, 2], [1, 3, 1]])
self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1))
self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1))

x = np.array([[[3, 4], [2, 3]], [[1, 2], [0, 1]]])
self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1))
self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1))


class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
def test_ones(self):
Expand Down Expand Up @@ -5402,6 +5431,25 @@ def test_argmin(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_argpartition(self, dtype):
import jax.numpy as jnp

if dtype == "bool":
self.skipTest("argpartition doesn't support bool dtype")

x = knp.array([1, 2, 3], dtype=dtype)
x_jax = jnp.array([1, 2, 3], dtype=dtype)
expected_dtype = standardize_dtype(jnp.argpartition(x_jax, 1).dtype)

self.assertEqual(
standardize_dtype(knp.argpartition(x, 1).dtype), expected_dtype
)
self.assertEqual(
standardize_dtype(knp.Argpartition(1).symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_argsort(self, dtype):
import jax.numpy as jnp
Expand Down

0 comments on commit 688daa5

Please sign in to comment.