Skip to content

Commit

Permalink
Add count_nonzero extension (#6295)
Browse files Browse the repository at this point in the history
Co-authored by: @MarShaikh <[email protected]>
  • Loading branch information
raghuveerbhat authored Nov 10, 2022
1 parent 7c9cb4d commit 5bb74e5
Show file tree
Hide file tree
Showing 8 changed files with 446 additions and 7 deletions.
55 changes: 54 additions & 1 deletion ivy/array/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union
from typing import Optional, Union, Tuple

# local
import ivy
Expand Down Expand Up @@ -274,6 +274,59 @@ def exp2(
"""
return ivy.exp2(self._data, out=out)

def count_nonzero(
self: ivy.Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.count_nonzero. This method simply
wraps the function, and so the docstring for ivy.count_nonzero also applies to
this method with minimal changes.
Parameters
----------
self
input array for which to count non-zeros.
axis
optional axis or tuple of axes along which to count non-zeros. Default is
None, meaning that non-zeros will be counted along a flattened
version of the input array.
keepdims
optional, if this is set to True, the axes that are counted are left in the
result as dimensions with size one. With this option, the result
will broadcast correctly against the input array.
dtype
optional output dtype. Default is of type integer.
out
optional output array, for writing the result to.
Returns
-------
ret
Number of non-zero values in the array along a given axis. Otherwise,
the total number of non-zero values in the array is returned.
Examples
--------
>>> x = ivy.array([1, 2, 3])
>>> x.count_nonzero()
ivy.array(3)
>>> x = ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]])
>>> x.count_nonzero(axis=0)
ivy.array([[1, 2],
[2, 2]])
>>> x = ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]])
>>> x.count_nonzero(axis=(0,1), keepdims=True)
ivy.array([[[3, 4]]])
"""
return ivy.count_nonzero(
self._data, axis=axis, keepdims=keepdims, dtype=dtype, out=out
)

def nansum(
self: ivy.Array,
/,
Expand Down
183 changes: 182 additions & 1 deletion ivy/container/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Tuple

# local
import ivy
Expand Down Expand Up @@ -604,6 +604,187 @@ def exp2(
"""
return self.static_exp2(self, out=out)

@staticmethod
def static_count_nonzero(
a: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.count_nonzero. This method simply
wraps the function, and so the docstring for ivy.count_nonzero also applies
to this method with minimal changes.
Parameters
----------
a
container with the base input arrays.
axis
optional axis or tuple of axes along which to count non-zeros. Default is
None, meaning that non-zeros will be counted along a flattened
version of the input array.
keepdims
optional, if this is set to True, the axes that are counted are left in the
result as dimensions with size one. With this option, the result
will broadcast correctly against the input array.
dtype
optional output dtype. Default is of type integer.
key_chains
The key-chains to apply or not apply the method to. Default is None.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is True.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is False.
map_sequences
Whether to also map method to sequences (lists, tuples). Default is False.
out
optional output container, for writing the result to.
Returns
-------
ret
Container including number of non-zero values in the array along a
given axis. Otherwise, container with the total number of non-zero
values in the array is returned.
Examples
--------
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> ivy.Container.static_count_nonzero(x)
{
a: ivy.array(7),
b: ivy.array(7)
}
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> ivy.Container.static_count_nonzero(x, axis=0)
{
a: ivy.array([1, 2, 2, 2]),
b: ivy.array([[1, 2],
[2, 2]])
}
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> ivy.Container.static_count_nonzero(x, axis=(0,1), keepdims=True)
{
a: ivy.array([[7]]),
b: ivy.array([[[3, 4]]])
}
"""
return ContainerBase.multi_map_in_static_method(
"count_nonzero",
a,
axis=axis,
keepdims=keepdims,
dtype=dtype,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def count_nonzero(
self: ivy.Container,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.count_nonzero. This method
simply wraps the function, and so the docstring for ivy.count_nonzero also
applies to this method with minimal changes.
Parameters
----------
self
container with the base input arrays.
axis
optional axis or tuple of axes along which to count non-zeros. Default is
None, meaning that non-zeros will be counted along a flattened
version of the input array.
keepdims
optional, if this is set to True, the axes that are counted are left in the
result as dimensions with size one. With this option, the result
will broadcast correctly against the input array.
dtype
optional output dtype. Default is of type integer.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``
out
optional output container, for writing the result to.
Returns
-------
ret
Container including number of non-zero values in the array along a
given axis. Otherwise, container with the total number of non-zero
values in the array is returned.
Examples
--------
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> x.count_nonzero()
{
a: ivy.array(7),
b: ivy.array(7)
}
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> x.count_nonzero(axis=0)
{
a: ivy.array([1, 2, 2, 2]),
b: ivy.array([[1, 2],
[2, 2]])
}
>>> x = ivy.Container(a=ivy.array([[0, 1, 2, 3],[4, 5, 6, 7]]),\
b=ivy.array([[[0,1],[2,3]],[[4,5],[6,7]]]))
>>> x.count_nonzero(axis=(0,1), keepdims=True)
{
a: ivy.array([[7]]),
b: ivy.array([[[3, 4]]])
}
"""
return self.static_count_nonzero(
self,
axis=axis,
keepdims=keepdims,
dtype=dtype,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

@staticmethod
def static_nansum(
x: Union[ivy.Container, ivy.Array, ivy.NativeArray],
Expand Down
18 changes: 17 additions & 1 deletion ivy/functional/backends/jax/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Tuple
from ivy.functional.backends.jax import JaxArray
import jax.numpy as jnp

Expand Down Expand Up @@ -62,6 +62,22 @@ def exp2(
return jnp.exp2(x)


def count_nonzero(
a: JaxArray,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if isinstance(axis, list):
axis = tuple(axis)
if dtype is None:
return jnp.count_nonzero(a, axis=axis, keepdims=keepdims)
return jnp.array(jnp.count_nonzero(a, axis=axis, keepdims=keepdims), dtype=dtype)


def nansum(
x: JaxArray,
/,
Expand Down
22 changes: 21 additions & 1 deletion ivy/functional/backends/numpy/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Tuple
import numpy as np
from ivy.functional.backends.numpy.helpers import _scalar_output_to_0d_array

Expand Down Expand Up @@ -109,6 +109,26 @@ def exp2(
exp2.support_native_out = True


@_scalar_output_to_0d_array
def count_nonzero(
x: np.ndarray,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[np.dtype] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if isinstance(axis, list):
axis = tuple(axis)
if dtype is None:
return np.count_nonzero(x, axis=axis, keepdims=keepdims)
return np.array(np.count_nonzero(x, axis=axis, keepdims=keepdims), dtype=dtype)


count_nonzero.support_native_out = False


def nansum(
x: np.ndarray,
/,
Expand Down
18 changes: 17 additions & 1 deletion ivy/functional/backends/tensorflow/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional
from typing import Union, Optional, Tuple
import tensorflow as tf
from .. import backend_version

Expand Down Expand Up @@ -95,6 +95,22 @@ def exp2(
return tf.math.pow(2, x, name=None)


def count_nonzero(
a: Union[tf.Tensor, tf.Variable],
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[tf.DType] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if dtype is None:
return tf.math.count_nonzero(a, axis=axis, keepdims=keepdims, name=None)
return tf.math.count_nonzero(
a, axis=axis, keepdims=keepdims, dtype=dtype, name=None
)


def nansum(
x: Union[tf.Tensor, tf.Variable],
/,
Expand Down
32 changes: 31 additions & 1 deletion ivy/functional/backends/torch/experimental/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union
from typing import Optional, Union, Tuple
import torch

# local
Expand Down Expand Up @@ -109,6 +109,36 @@ def exp2(
exp2.support_native_out = True


def count_nonzero(
a: torch.Tensor,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: Optional[bool] = False,
dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(axis, list):
axis = tuple(axis)

def _dtype_count_nonzero(a, axis, dtype):
if dtype is None:
return torch.count_nonzero(a, dim=axis)
return torch.tensor(torch.count_nonzero(a, dim=axis), dtype=dtype)

x = _dtype_count_nonzero(a, axis, dtype)
if not keepdims:
return x
if isinstance(axis, tuple):
for d in sorted(axis, reverse=True):
x = x.unsqueeze(d)
return x
return x.unsqueeze(axis)


count_nonzero.support_native_out = False


def nansum(
x: torch.Tensor,
/,
Expand Down
Loading

0 comments on commit 5bb74e5

Please sign in to comment.