Skip to content

Commit

Permalink
Resolve numpy 1.25 issues (#973)
Browse files Browse the repository at this point in the history
* fix bare ._cunumeric usage

* fix deprecated .product

* Fix (most) find_common_type instances

* pass scalar values, not types

* fix ufunc usage

* re-enable scalar binop type tests
  • Loading branch information
bryevdv authored Jun 23, 2023
1 parent fcba038 commit 0d46e7c
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 33 deletions.
22 changes: 2 additions & 20 deletions cunumeric/_ufunc/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,33 +557,15 @@ def _find_common_type(
if len(unique_dtypes) == 1 and all_ndarray:
return arrs[0].dtype

# FIXME: The following is a miserable attempt to implement type
# coercion rules that try to match NumPy's rules for a subset of cases;
# for the others, cuNumeric computes a type different from what
# NumPy produces for the same operands. Type coercion rules shouldn't
# be this difficult to imitate...

all_scalars = all(arr.ndim == 0 for arr in arrs)
all_arrays = all(arr.ndim > 0 for arr in arrs)
kinds = set(arr.dtype.kind for arr in arrs)
lossy_conversion = ("i" in kinds or "u" in kinds) and (
"f" in kinds or "c" in kinds
)
use_min_scalar = not (all_scalars or all_arrays or lossy_conversion)

scalar_types = []
array_types = []
for arr, orig_arg in zip(arrs, orig_args):
if arr.ndim == 0:
scalar_types.append(
np.dtype(np.min_scalar_type(orig_arg))
if use_min_scalar
else arr.dtype
)
scalar_types.append(orig_arg)
else:
array_types.append(arr.dtype)

return np.find_common_type(array_types, scalar_types)
return np.result_type(*array_types, *scalar_types)

def _resolve_dtype(
self,
Expand Down
15 changes: 8 additions & 7 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
UnaryOpCode,
UnaryRedCode,
)
from .coverage import FALLBACK_WARNING, clone_class
from .coverage import FALLBACK_WARNING, clone_class, is_implemented
from .runtime import runtime
from .types import NdShape
from .utils import deep_apply, dot_modes, to_core_dtype
Expand Down Expand Up @@ -331,7 +331,7 @@ def __array_function__(
# arguments. Conversely, if the user calls `cn.foo(x, bar=True)`
# directly, that means they requested the cuNumeric implementation
# specifically, and the `NotImplementedError` should not be hidden.
if cn_func._cunumeric.implemented:
if is_implemented(cn_func):
try:
return cn_func(*args, **kwargs)
except NotImplementedError:
Expand Down Expand Up @@ -2000,9 +2000,10 @@ def choose(
choices = tuple(choices)
is_tuple = isinstance(choices, tuple)
if is_tuple:
n = len(choices)
if (n := len(choices)) == 0:
raise ValueError("invalid entry in choice array")
dtypes = [ch.dtype for ch in choices]
ch_dtype = np.find_common_type(dtypes, [])
ch_dtype = np.result_type(*dtypes)
choices = tuple(
convert_to_cunumeric_ndarray(choices[i]).astype(ch_dtype)
for i in range(n)
Expand Down Expand Up @@ -2776,7 +2777,7 @@ def fft(
else:
norm_shape = out.shape
norm_shape_along_axes = [norm_shape[ax] for ax in fft_axes]
factor = np.product(norm_shape_along_axes)
factor = np.prod(norm_shape_along_axes)
if fft_norm == FFTNormalization.ORTHOGONAL:
factor = np.sqrt(factor)
return out / factor
Expand Down Expand Up @@ -3403,7 +3404,7 @@ def searchsorted(
a = self
# in case we have different dtypes we ned to find a common type
if a.dtype != v_ndarray.dtype:
ch_dtype = np.find_common_type([a.dtype, v_ndarray.dtype], [])
ch_dtype = np.result_type(a.dtype, v_ndarray.dtype)

if v_ndarray.dtype != ch_dtype:
v_ndarray = v_ndarray.astype(ch_dtype)
Expand Down Expand Up @@ -3899,7 +3900,7 @@ def find_common_type(*args: Any) -> np.dtype[Any]:
scalar_types.append(array.dtype)
else:
array_types.append(array.dtype)
return np.find_common_type(array_types, scalar_types)
return np.find_common_type(array_types, scalar_types) # type: ignore

def _maybe_convert(self, dtype: np.dtype[Any], hints: Any) -> ndarray:
if self.dtype == dtype:
Expand Down
2 changes: 1 addition & 1 deletion cunumeric/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def _solve(
if b.dtype.kind not in ("f", "c"):
b = b.astype("float64")
if a.dtype != b.dtype:
dtype = np.find_common_type([a.dtype, b.dtype], [])
dtype = np.result_type(a.dtype, b.dtype)
a = a.astype(dtype)
b = b.astype(dtype)

Expand Down
6 changes: 3 additions & 3 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def arange(
step = 1

if dtype is None:
dtype = np.find_common_type([], [type(start), type(stop), type(step)])
dtype = np.result_type(start, stop, step)
else:
dtype = np.dtype(dtype)

Expand Down Expand Up @@ -740,7 +740,7 @@ def linspace(
raise ValueError("Number of samples, %s, must be non-negative." % num)
div = (num - 1) if endpoint else num

common_kind = np.find_common_type((start.dtype, stop.dtype), ()).kind
common_kind = np.result_type(start.dtype, stop.dtype).kind
dt = np.complex128 if common_kind == "c" else np.float64
if dtype is None:
dtype = dt
Expand Down Expand Up @@ -1516,7 +1516,7 @@ def check_shape_dtype_without_axis(

# Cast arrays with the passed arguments (dtype, casting)
if dtype is None:
dtype = np.find_common_type((inp.dtype for inp in inputs), [])
dtype = np.result_type(*[inp.dtype for inp in inputs])
else:
dtype = np.dtype(dtype)

Expand Down
2 changes: 0 additions & 2 deletions tests/integration/test_binary_op_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ def test_array_array(lhs_np, rhs_np, lhs_num, rhs_num):
print(f"NumPy type: {out_np.dtype}, cuNumeric type: {out_num.dtype}")


# not all of these currently pass (see not above)
@pytest.mark.xfail
@pytest.mark.parametrize(
"lhs_np, rhs_np, lhs_num, rhs_num", generate_array_scalar_cases(), ids=str
)
Expand Down

0 comments on commit 0d46e7c

Please sign in to comment.