Skip to content

Commit

Permalink
Use Numpy axis normalizations where possible (#419)
Browse files Browse the repository at this point in the history
* use numpy normalizers in array.py

* use numpy normalizers in linalg.py

* use numpy normalizers in sort.py
  • Loading branch information
bryevdv authored Jun 23, 2022
1 parent 658c7d7 commit ecb76ee
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 45 deletions.
41 changes: 15 additions & 26 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from collections.abc import Iterable
from functools import reduce, wraps
from inspect import signature
from typing import Callable, Optional, Set, Tuple, TypeVar
from typing import Callable, Optional, Set, TypeVar

import numpy as np
import pyarrow
from numpy.core.multiarray import normalize_axis_index
from numpy.core.numeric import normalize_axis_tuple
from typing_extensions import ParamSpec

from legate.core import Array
Expand Down Expand Up @@ -1782,10 +1784,8 @@ def take(self, indices, axis=None, out=None, mode="raise"):
if axis is None:
self = self.ravel()
axis = 0
elif axis < 0:
axis = self.ndim + axis
if axis < 0 or axis >= self.ndim:
raise ValueError("axis argument is out of bounds")
else:
axis = normalize_axis_index(axis, self.ndim)

# TODO remove "raise" logic when bounds check for advanced
# indexing is implementd
Expand Down Expand Up @@ -1935,7 +1935,7 @@ def choose(self, choices, out=None, mode="raise"):
def compress(self, condition, axis=None, out=None):
"""a.compress(self, condition, axis=None, out=None)
Return selected slices of an array along given axis..
Return selected slices of an array along given axis.
Refer to :func:`cunumeric.compress` for full documentation.
Expand All @@ -1959,9 +1959,12 @@ def compress(self, condition, axis=None, out=None):
category=RuntimeWarning,
)
condition = condition.astype(bool)

if axis is None:
axis = 0
a = self.ravel()
else:
axis = normalize_axis_index(axis, self.ndim)

if a.shape[axis] < condition.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -2530,7 +2533,7 @@ def getfield(self, dtype, offset=0):
"for ndarray.getfield"
)

def _convert_singleton_key(self, args: Tuple):
def _convert_singleton_key(self, args: tuple):
if len(args) == 0 and self.size == 1:
return (0,) * self.ndim
if len(args) == 1 and isinstance(args[0], int):
Expand Down Expand Up @@ -3070,9 +3073,7 @@ def squeeze(self, axis=None):
"all axis to squeeze must be less than ndim"
)
if self.shape[axis] != 1:
raise ValueError(
"axis to squeeze must have extent " "of one"
)
raise ValueError("axis to squeeze must have extent of one")
elif isinstance(axis, tuple):
for ax in axis:
if ax >= self.ndim:
Expand Down Expand Up @@ -3593,23 +3594,11 @@ def _perform_unary_reduction(
raise NotImplementedError(
"(arg)max/min not supported for complex-type arrays"
)
# Compute the output shape
axes = axis
if axes is None:
axes = tuple(range(src.ndim))
elif not isinstance(axes, tuple):
axes = (axes,)

if any(type(ax) != int for ax in axes):
raise TypeError(
"'axis' must be an integer or a tuple of integers, "
f"but got {axis}"
)

axes = tuple(ax + src.ndim if ax < 0 else ax for ax in axes)

if any(ax < 0 for ax in axes):
raise ValueError(f"Invalid 'axis' value {axis}")
if axis is None:
axes = tuple(range(src.ndim))
else:
axes = normalize_axis_tuple(axis, src.ndim)

out_shape = ()
for dim in range(src.ndim):
Expand Down
18 changes: 7 additions & 11 deletions cunumeric/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from cunumeric._ufunc.math import add, sqrt as _sqrt
from cunumeric.array import add_boilerplate, convert_to_cunumeric_ndarray
from cunumeric.module import dot, empty_like, eye, matmul, ndarray
from numpy.core.multiarray import ( # type: ignore [attr-defined]
normalize_axis_index,
)
from numpy.core.multiarray import normalize_axis_index # type: ignore
from numpy.core.numeric import normalize_axis_tuple # type: ignore

if TYPE_CHECKING:
import numpy.typing as npt
Expand Down Expand Up @@ -424,14 +423,11 @@ def norm(
ret = ret.reshape(ndim * [1])
return ret

# Normalize the `axis` argument to a tuple.
nd = x.ndim
if axis is None:
computed_axis = tuple(range(nd))
elif not isinstance(axis, tuple):
computed_axis = (axis,)
computed_axis = tuple(range(x.ndim))
else:
computed_axis = axis
computed_axis = normalize_axis_tuple(axis, x.ndim)

for ax in computed_axis:
if not isinstance(ax, int):
raise TypeError(
Expand Down Expand Up @@ -469,8 +465,8 @@ def norm(
return ret
elif len(computed_axis) == 2:
row_axis, col_axis = computed_axis
row_axis = normalize_axis_index(row_axis, nd)
col_axis = normalize_axis_index(col_axis, nd)
row_axis = normalize_axis_index(row_axis, x.ndim)
col_axis = normalize_axis_index(col_axis, x.ndim)
if row_axis == col_axis:
raise ValueError("Duplicate axes given")
if ord == 2:
Expand Down
15 changes: 7 additions & 8 deletions cunumeric/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from cunumeric.config import CuNumericOpCode
from numpy.core.multiarray import normalize_axis_index

from legate.core import types as ty

Expand All @@ -32,7 +33,7 @@ def sort_flattened(output, input, argsort, stable):


def sort_swapped(output, input, argsort, sort_axis, stable):
assert sort_axis < input.ndim - 1 and sort_axis >= 0
sort_axis = normalize_axis_index(sort_axis, input.ndim)

# swap axes
swapped = input.swapaxes(sort_axis, input.ndim - 1)
Expand Down Expand Up @@ -97,12 +98,10 @@ def sort(output, input, argsort, axis=-1, stable=False):
else:
if axis is None:
axis = 0
elif axis < 0:
axis = input.ndim + axis

if axis is not input.ndim - 1:
sort_swapped(output, input, argsort, axis, stable)

else:
# run actual sort task
axis = normalize_axis_index(axis, input.ndim)

if axis == input.ndim - 1:
sort_task(output, input, argsort, stable)
else:
sort_swapped(output, input, argsort, axis, stable)

0 comments on commit ecb76ee

Please sign in to comment.