diff --git a/cunumeric/_ufunc/ufunc.py b/cunumeric/_ufunc/ufunc.py
index 57fc146b5..709edffef 100644
--- a/cunumeric/_ufunc/ufunc.py
+++ b/cunumeric/_ufunc/ufunc.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 #
 
+import numpy as np
 from cunumeric.array import (
     broadcast_shapes,
     convert_to_cunumeric_ndarray,
     ndarray,
 )
-from numpy import can_cast as np_can_cast, dtype as np_dtype
 
 _UNARY_DOCSTRING_TEMPLATE = """{}
 
@@ -132,7 +132,7 @@ def _maybe_cast_input(self, arr, to_dtype, casting):
         if arr.dtype == to_dtype:
             return arr
 
-        if not np_can_cast(arr.dtype, to_dtype, casting=casting):
+        if not np.can_cast(arr.dtype, to_dtype, casting=casting):
             raise TypeError(
                 f"Cannot cast ufunc '{self._name}' input from "
                 f"{arr.dtype} to {to_dtype} with casting rule '{casting}'"
@@ -168,17 +168,17 @@ def ntypes(self):
 
     def _resolve_dtype(self, arr, casting, precision_fixed):
         if arr.dtype.char in self._types:
-            return arr, np_dtype(self._types[arr.dtype.char])
+            return arr, np.dtype(self._types[arr.dtype.char])
 
         if arr.dtype in self._resolution_cache:
             to_dtype = self._resolution_cache[arr.dtype]
             arr = arr.astype(to_dtype)
-            return arr, np_dtype(self._types[to_dtype.char])
+            return arr, np.dtype(self._types[to_dtype.char])
 
         chosen = None
         if not precision_fixed:
             for in_ty in self._types.keys():
-                if np_can_cast(arr.dtype, in_ty):
+                if np.can_cast(arr.dtype, in_ty):
                     chosen = in_ty
                     break
 
@@ -188,10 +188,10 @@ def _resolve_dtype(self, arr, casting, precision_fixed):
                 "for the given casting"
             )
 
-        to_dtype = np_dtype(chosen)
+        to_dtype = np.dtype(chosen)
         self._resolution_cache[arr.dtype] = to_dtype
 
-        return arr.astype(to_dtype), np_dtype(self._types[to_dtype.char])
+        return arr.astype(to_dtype), np.dtype(self._types[to_dtype.char])
 
     def __call__(
         self,
@@ -244,7 +244,7 @@ def __call__(
             out = result
         else:
             if out.dtype != res_dtype:
-                if not np_can_cast(res_dtype, out.dtype, casting=casting):
+                if not np.can_cast(res_dtype, out.dtype, casting=casting):
                     raise TypeError(
                         f"Cannot cast ufunc '{self._name}' output from "
                         f"{res_dtype} to {out.dtype} with casting rule "
@@ -296,26 +296,68 @@ def types(self):
     def ntypes(self):
         return len(self._types)
 
-    def _resolve_dtype(self, arrs, casting, precision_fixed):
-        common_dtype = ndarray.find_common_type(*arrs)
+    @staticmethod
+    def _find_common_type(arrs, orig_args):
+        """Compute the common dtype used to promote the operands.
+
+        ``arrs`` holds the converted operands and ``orig_args`` the values
+        they were converted from, in the same order. Operands with
+        ``ndim == 0`` are classified as scalars, all others as arrays, and
+        the two dtype lists are combined with ``np.find_common_type``.
+        """
+        # FIXME: The following is a miserable attempt to implement type
+        # coercion rules that try to match NumPy's rules for a subset of cases;
+        # for the others, cuNumeric computes a type different from what
+        # NumPy produces for the same operands. Type coercion rules shouldn't
+        # be this difficult to imitate...
+
+        all_scalars = all(arr.ndim == 0 for arr in arrs)
+        all_arrays = all(arr.ndim > 0 for arr in arrs)
+        kinds = {arr.dtype.kind for arr in arrs}
+        lossy_conversion = ("i" in kinds or "u" in kinds) and (
+            "f" in kinds or "c" in kinds
+        )
+        # Use np.min_scalar_type of the original argument for scalars only
+        # when scalars and arrays are mixed and integer kinds are not mixed
+        # with floating-point or complex kinds.
+        use_min_scalar = not (all_scalars or all_arrays or lossy_conversion)
+
+        scalar_types = []
+        array_types = []
+        for arr, orig_arg in zip(arrs, orig_args):
+            if arr.ndim == 0:
+                scalar_types.append(
+                    np.dtype(np.min_scalar_type(orig_arg))
+                    if use_min_scalar
+                    else arr.dtype
+                )
+            else:
+                array_types.append(arr.dtype)
+
+        # NOTE(review): np.find_common_type is deprecated since NumPy 1.25
+        # and removed in NumPy 2.0; replace it before upgrading NumPy.
+        return np.find_common_type(array_types, scalar_types)
+
+    def _resolve_dtype(self, arrs, orig_args, casting, precision_fixed):
+        common_dtype = self._find_common_type(arrs, orig_args)
 
         key = (common_dtype.char, common_dtype.char)
         if key in self._types:
             arrs = [arr.astype(common_dtype) for arr in arrs]
-            return arrs, np_dtype(self._types[key])
+            return arrs, np.dtype(self._types[key])
 
         if key in self._resolution_cache:
             to_dtypes = self._resolution_cache[key]
             arrs = [
                 arr.astype(to_dtype) for arr, to_dtype in zip(arrs, to_dtypes)
             ]
-            return arrs, np_dtype(self._types[to_dtypes])
+            return arrs, np.dtype(self._types[to_dtypes])
 
         chosen = None
         if not precision_fixed:
             for in_dtypes in self._types.keys():
                 if all(
-                    np_can_cast(arr.dtype, to_dtype)
+                    np.can_cast(arr.dtype, to_dtype)
                     for arr, to_dtype in zip(arrs, in_dtypes)
                 ):
                     chosen = in_dtypes
@@ -330,7 +372,7 @@ def _resolve_dtype(self, arrs, casting, precision_fixed):
         self._resolution_cache[key] = chosen
         arrs = [arr.astype(to_dtype) for arr, to_dtype in zip(arrs, chosen)]
 
-        return arrs, np_dtype(self._types[chosen])
+        return arrs, np.dtype(self._types[chosen])
 
     def __call__(
         self,
@@ -343,7 +385,8 @@ def __call__(
         dtype=None,
         **kwargs,
     ):
-        arrs = [convert_to_cunumeric_ndarray(arr) for arr in (x1, x2)]
+        orig_args = (x1, x2)
+        arrs = [convert_to_cunumeric_ndarray(arr) for arr in orig_args]
 
         if out is not None:
             if isinstance(out, tuple):
@@ -384,7 +427,9 @@ def __call__(
         # Resolve the dtype to use for the computation and cast the input
         # if necessary. If the dtype is already fixed by the caller,
         # the dtype must be one of the dtypes supported by this operation.
-        arrs, res_dtype = self._resolve_dtype(arrs, casting, precision_fixed)
+        arrs, res_dtype = self._resolve_dtype(
+            arrs, orig_args, casting, precision_fixed
+        )
 
         if out is None:
             result = ndarray(
@@ -393,7 +438,7 @@ def __call__(
             out = result
         else:
             if out.dtype != res_dtype:
-                if not np_can_cast(res_dtype, out.dtype, casting=casting):
+                if not np.can_cast(res_dtype, out.dtype, casting=casting):
                     raise TypeError(
                         f"Cannot cast ufunc '{self._name}' output from "
                         f"{res_dtype} to {out.dtype} with casting rule "
diff --git a/tests/binary_op_typing.py b/tests/binary_op_typing.py
new file mode 100644
index 000000000..0877ca9cc
--- /dev/null
+++ b/tests/binary_op_typing.py
@@ -0,0 +1,147 @@
+# Copyright 2021-2022 NVIDIA Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from itertools import product
+
+import numpy as np
+
+import cunumeric as num
+
+
+def value_type(obj):
+    """Describe an operand as a scalar, a 0-d array, or an array."""
+    if np.isscalar(obj):
+        return "scalar"
+    elif obj.ndim == 0:
+        return f"{obj.dtype} 0d array"
+    else:
+        return f"{obj.dtype} array"
+
+
+def test(lhs_np, rhs_np, lhs_num, rhs_num):
+    """Assert that np.add and num.add agree on the result dtype.
+
+    Prints the kinds of the two operands and, on a mismatch, the NumPy
+    operands and both result dtypes before raising AssertionError.
+    """
+    print(f"{value_type(lhs_np)} x {value_type(rhs_np)}")
+
+    out_np = np.add(lhs_np, rhs_np)
+    out_num = num.add(lhs_num, rhs_num)
+
+    if out_np.dtype != out_num.dtype:
+        print("LHS")
+        print(lhs_np)
+        print("RHS")
+        print(rhs_np)
+        print(f"NumPy type: {out_np.dtype}, cuNumeric type: {out_num.dtype}")
+        raise AssertionError("NumPy and cuNumeric result dtypes differ")
+
+
+def run_all_tests():
+    """Run the dtype comparison over a fixed set of operand combinations.
+
+    Covers array x array and 0-d array x 0-d array for each pair of type
+    codes, and array x Python scalar for each type code.
+    """
+    types = [
+        "b",
+        "B",
+        "h",
+        "H",
+        "i",
+        "I",
+        "l",
+        "L",
+        "q",
+        "Q",
+        "e",
+        "f",
+        "d",
+        "F",
+        "D",
+    ]
+    array_values = [[1]]
+    scalar_values = [1, -1, 1.0, 1e-50, 1j]
+
+    for idx, lhs_type in enumerate(types):
+        for rhs_type in types[idx:]:
+            for lhs_value, rhs_value in product(array_values, array_values):
+                lhs_np = np.array(lhs_value, dtype=lhs_type)
+                rhs_np = np.array(rhs_value, dtype=rhs_type)
+
+                lhs_num = num.array(lhs_np)
+                rhs_num = num.array(rhs_np)
+
+                test(lhs_np, rhs_np, lhs_num, rhs_num)
+
+            for lhs_value, rhs_value in product(scalar_values, scalar_values):
+                try:
+                    lhs_np = np.array(lhs_value, dtype=lhs_type)
+                    rhs_np = np.array(rhs_value, dtype=rhs_type)
+
+                    lhs_num = num.array(lhs_np)
+                    rhs_num = num.array(rhs_np)
+
+                    test(lhs_np, rhs_np, lhs_num, rhs_num)
+                except TypeError:
+                    # NOTE(review): presumably skips values such as 1j that
+                    # the requested dtype cannot hold -- confirm.
+                    pass
+
+    for ty in types:
+        for array, scalar in product(array_values, scalar_values):
+            array_np = np.array(array, dtype=ty)
+            array_num = num.array(array_np)
+
+            test(array_np, scalar, array_num, scalar)
+
+    # TODO: NumPy's type coercion rules are confusing at best and impossible
+    # for any human being to understand in my opinion. My attempt to make
+    # sense of it for the past two days failed miserably. I managed to make
+    # the code somewhat compatible with NumPy for cases where Python scalars
+    # are passed.
+    #
+    # If anyone can do a better job than me and finally make cuNumeric
+    # implement the same typing rules, please put these tests back.
+    #
+    # for idx, lhs_type in enumerate(types):
+    #     for rhs_type in types[idx:]:
+    #         for array, scalar in product(array_values, scalar_values):
+    #             try:
+    #                 lhs_np = np.array(array, dtype=lhs_type)
+    #                 rhs_np = np.array(scalar, dtype=rhs_type)
+
+    #                 lhs_num = num.array(lhs_np)
+    #                 rhs_num = num.array(rhs_np)
+
+    #                 test(lhs_np, rhs_np, lhs_num, rhs_num)
+    #             except TypeError:
+    #                 pass
+
+    #             try:
+    #                 lhs_np = np.array(scalar, dtype=lhs_type)
+    #                 rhs_np = np.array(array, dtype=rhs_type)
+
+    #                 lhs_num = num.array(lhs_np)
+    #                 rhs_num = num.array(rhs_np)
+
+    #                 test(lhs_np, rhs_np, lhs_num, rhs_num)
+    #             except TypeError:
+    #                 pass
+
+
+if __name__ == "__main__":
+    run_all_tests()