Skip to content

Commit

Permalink
Merge pull request #264 from magnatelee/type-coercion-fix
Browse files Browse the repository at this point in the history
Revise type coercion
  • Loading branch information
magnatelee authored Apr 6, 2022
2 parents 5708ea8 + 1d0a6bd commit aed0a54
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 17 deletions.
67 changes: 50 additions & 17 deletions cunumeric/_ufunc/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
#

import numpy as np
from cunumeric.array import (
broadcast_shapes,
convert_to_cunumeric_ndarray,
ndarray,
)
from numpy import can_cast as np_can_cast, dtype as np_dtype

_UNARY_DOCSTRING_TEMPLATE = """{}
Expand Down Expand Up @@ -132,7 +132,7 @@ def _maybe_cast_input(self, arr, to_dtype, casting):
if arr.dtype == to_dtype:
return arr

if not np_can_cast(arr.dtype, to_dtype, casting=casting):
if not np.can_cast(arr.dtype, to_dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' input from "
f"{arr.dtype} to {to_dtype} with casting rule '{casting}'"
Expand Down Expand Up @@ -168,17 +168,17 @@ def ntypes(self):

def _resolve_dtype(self, arr, casting, precision_fixed):
if arr.dtype.char in self._types:
return arr, np_dtype(self._types[arr.dtype.char])
return arr, np.dtype(self._types[arr.dtype.char])

if arr.dtype in self._resolution_cache:
to_dtype = self._resolution_cache[arr.dtype]
arr = arr.astype(to_dtype)
return arr, np_dtype(self._types[to_dtype.char])
return arr, np.dtype(self._types[to_dtype.char])

chosen = None
if not precision_fixed:
for in_ty in self._types.keys():
if np_can_cast(arr.dtype, in_ty):
if np.can_cast(arr.dtype, in_ty):
chosen = in_ty
break

Expand All @@ -188,10 +188,10 @@ def _resolve_dtype(self, arr, casting, precision_fixed):
"for the given casting"
)

to_dtype = np_dtype(chosen)
to_dtype = np.dtype(chosen)
self._resolution_cache[arr.dtype] = to_dtype

return arr.astype(to_dtype), np_dtype(self._types[to_dtype.char])
return arr.astype(to_dtype), np.dtype(self._types[to_dtype.char])

def __call__(
self,
Expand Down Expand Up @@ -244,7 +244,7 @@ def __call__(
out = result
else:
if out.dtype != res_dtype:
if not np_can_cast(res_dtype, out.dtype, casting=casting):
if not np.can_cast(res_dtype, out.dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' output from "
f"{res_dtype} to {out.dtype} with casting rule "
Expand Down Expand Up @@ -296,26 +296,56 @@ def types(self):
def ntypes(self):
return len(self._types)

def _resolve_dtype(self, arrs, casting, precision_fixed):
common_dtype = ndarray.find_common_type(*arrs)
@staticmethod
def _find_common_type(arrs, orig_args):
    """Pick the common dtype used to coerce the operands of a binary ufunc.

    Parameters
    ----------
    arrs
        The operands after conversion to arrays; only ``ndim`` and
        ``dtype`` are read from each of them.
    orig_args
        The operands exactly as the caller passed them, paired
        positionally with ``arrs``.

    Returns
    -------
    The result of ``np.find_common_type`` applied to the dtypes of the
    operands with ``ndim > 0`` (as array types) and the dtypes chosen for
    the 0-d operands (as scalar types).
    """
    # FIXME: This only approximates NumPy's type coercion rules and matches
    # them for a subset of cases; for the remaining cases cuNumeric computes
    # a type different from what NumPy produces for the same operands.
    #
    # NOTE(review): np.find_common_type is deprecated in newer NumPy
    # releases -- confirm the NumPy versions this module must support.

    zero_dim_flags = [arr.ndim == 0 for arr in arrs]
    kinds = {arr.dtype.kind for arr in arrs}

    # Integer operands combined with floating point or complex operands.
    lossy_conversion = bool(kinds & {"i", "u"}) and bool(kinds & {"f", "c"})

    # The minimal scalar type is used only when 0-d and n-d operands are
    # mixed and no integer/inexact combination is involved.
    mixed_ranks = any(zero_dim_flags) and not all(zero_dim_flags)
    use_min_scalar = mixed_ranks and not lossy_conversion

    pairs = list(zip(arrs, orig_args))
    scalar_types = [
        np.dtype(np.min_scalar_type(orig_arg)) if use_min_scalar else arr.dtype
        for arr, orig_arg in pairs
        if arr.ndim == 0
    ]
    array_types = [arr.dtype for arr, _ in pairs if arr.ndim != 0]

    return np.find_common_type(array_types, scalar_types)

def _resolve_dtype(self, arrs, orig_args, casting, precision_fixed):
common_dtype = self._find_common_type(arrs, orig_args)

key = (common_dtype.char, common_dtype.char)
if key in self._types:
arrs = [arr.astype(common_dtype) for arr in arrs]
return arrs, np_dtype(self._types[key])
return arrs, np.dtype(self._types[key])

if key in self._resolution_cache:
to_dtypes = self._resolution_cache[key]
arrs = [
arr.astype(to_dtype) for arr, to_dtype in zip(arrs, to_dtypes)
]
return arrs, np_dtype(self._types[to_dtypes])
return arrs, np.dtype(self._types[to_dtypes])

chosen = None
if not precision_fixed:
for in_dtypes in self._types.keys():
if all(
np_can_cast(arr.dtype, to_dtype)
np.can_cast(arr.dtype, to_dtype)
for arr, to_dtype in zip(arrs, in_dtypes)
):
chosen = in_dtypes
Expand All @@ -330,7 +360,7 @@ def _resolve_dtype(self, arrs, casting, precision_fixed):
self._resolution_cache[key] = chosen
arrs = [arr.astype(to_dtype) for arr, to_dtype in zip(arrs, chosen)]

return arrs, np_dtype(self._types[chosen])
return arrs, np.dtype(self._types[chosen])

def __call__(
self,
Expand All @@ -343,7 +373,8 @@ def __call__(
dtype=None,
**kwargs,
):
arrs = [convert_to_cunumeric_ndarray(arr) for arr in (x1, x2)]
orig_args = (x1, x2)
arrs = [convert_to_cunumeric_ndarray(arr) for arr in orig_args]

if out is not None:
if isinstance(out, tuple):
Expand Down Expand Up @@ -384,7 +415,9 @@ def __call__(
# Resolve the dtype to use for the computation and cast the input
# if necessary. If the dtype is already fixed by the caller,
# the dtype must be one of the dtypes supported by this operation.
arrs, res_dtype = self._resolve_dtype(arrs, casting, precision_fixed)
arrs, res_dtype = self._resolve_dtype(
arrs, orig_args, casting, precision_fixed
)

if out is None:
result = ndarray(
Expand All @@ -393,7 +426,7 @@ def __call__(
out = result
else:
if out.dtype != res_dtype:
if not np_can_cast(res_dtype, out.dtype, casting=casting):
if not np.can_cast(res_dtype, out.dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' output from "
f"{res_dtype} to {out.dtype} with casting rule "
Expand Down
134 changes: 134 additions & 0 deletions tests/binary_op_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2021-2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import product

import numpy as np

import cunumeric as num


def value_type(obj):
    """Return a short label describing *obj* for the test log.

    The label is ``"scalar"`` when ``np.isscalar(obj)`` is true; otherwise
    it is the object's dtype followed by ``"0d array"`` (``ndim == 0``) or
    ``"array"``.
    """
    if np.isscalar(obj):
        return "scalar"

    suffix = "0d array" if obj.ndim == 0 else "array"
    return f"{obj.dtype} {suffix}"


def test(lhs_np, rhs_np, lhs_num, rhs_num):
    """Check that ``num.add`` and ``np.add`` produce the same result dtype.

    Parameters
    ----------
    lhs_np, rhs_np
        Operands passed to ``np.add``; also used to label the log line.
    lhs_num, rhs_num
        Operands passed to ``num.add``.

    Raises
    ------
    AssertionError
        If the dtype of the cuNumeric result differs from NumPy's.
    """
    print(f"{value_type(lhs_np)} x {value_type(rhs_np)}")

    out_np = np.add(lhs_np, rhs_np)
    out_num = num.add(lhs_num, rhs_num)

    if out_np.dtype == out_num.dtype:
        return

    print("LHS")
    print(lhs_np)
    print("RHS")
    print(rhs_np)
    mismatch = f"NumPy type: {out_np.dtype}, cuNumeric type: {out_num.dtype}"
    print(mismatch)
    # Raise explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would let a mismatch pass silently,
    # and the exception now carries the mismatch description.
    raise AssertionError(mismatch)


def run_all_tests():
    """Compare NumPy and cuNumeric result dtypes of ``add`` across types.

    Three groups of operand combinations are exercised through ``test``:

    1. n-d array x n-d array, for every unordered pair of type codes;
    2. 0-d array x 0-d array built from each pair of scalar values, for
       every unordered pair of type codes;
    3. n-d array x plain Python scalar, for every type code.
    """
    types = list("bBhHiIlLqQefdFD")
    array_values = [[1]]
    scalar_values = [1, -1, 1.0, 1e-50, 1j]

    for start, lhs_code in enumerate(types):
        # Slicing from `start` visits each unordered pair exactly once.
        for rhs_code in types[start:]:
            for lhs_val, rhs_val in product(array_values, repeat=2):
                lhs_np = np.array(lhs_val, dtype=lhs_code)
                rhs_np = np.array(rhs_val, dtype=rhs_code)
                test(lhs_np, rhs_np, num.array(lhs_np), num.array(rhs_np))

            for lhs_val, rhs_val in product(scalar_values, repeat=2):
                # NOTE(review): the handler covers the whole attempt, so a
                # TypeError raised by the cuNumeric calls or by `test` is
                # skipped as well, not only one raised while building the
                # NumPy operands -- confirm this is intended.
                try:
                    lhs_np = np.array(lhs_val, dtype=lhs_code)
                    rhs_np = np.array(rhs_val, dtype=rhs_code)
                    test(lhs_np, rhs_np, num.array(lhs_np), num.array(rhs_np))
                except TypeError:
                    continue

    for code in types:
        for array_val, scalar in product(array_values, scalar_values):
            array_np = np.array(array_val, dtype=code)
            test(array_np, scalar, num.array(array_np), scalar)

    # TODO: The mixed n-d array x 0-d array cases below are disabled. The
    # current coercion code was only made to agree with NumPy for the cases
    # where Python scalars are passed; re-enable these tests once cuNumeric
    # implements the same typing rules as NumPy for them.
    #
    # for idx, lhs_type in enumerate(types):
    #     for rhs_type in types[idx:]:
    #         for array, scalar in product(array_values, scalar_values):
    #             try:
    #                 lhs_np = np.array(array, dtype=lhs_type)
    #                 rhs_np = np.array(scalar, dtype=rhs_type)
    #
    #                 lhs_num = num.array(lhs_np)
    #                 rhs_num = num.array(rhs_np)
    #
    #                 test(lhs_np, rhs_np, lhs_num, rhs_num)
    #             except TypeError:
    #                 pass
    #
    #             try:
    #                 lhs_np = np.array(scalar, dtype=lhs_type)
    #                 rhs_np = np.array(array, dtype=rhs_type)
    #
    #                 lhs_num = num.array(lhs_np)
    #                 rhs_num = num.array(rhs_np)
    #
    #                 test(lhs_np, rhs_np, lhs_num, rhs_num)
    #             except TypeError:
    #                 pass


# Run the full dtype comparison only when this file is executed as a script.
if __name__ == "__main__":
    run_all_tests()

0 comments on commit aed0a54

Please sign in to comment.