Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise type coercion #264

Merged
merged 2 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions cunumeric/_ufunc/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
#

import numpy as np
from cunumeric.array import (
broadcast_shapes,
convert_to_cunumeric_ndarray,
ndarray,
)
from numpy import can_cast as np_can_cast, dtype as np_dtype

_UNARY_DOCSTRING_TEMPLATE = """{}

Expand Down Expand Up @@ -132,7 +132,7 @@ def _maybe_cast_input(self, arr, to_dtype, casting):
if arr.dtype == to_dtype:
return arr

if not np_can_cast(arr.dtype, to_dtype, casting=casting):
if not np.can_cast(arr.dtype, to_dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' input from "
f"{arr.dtype} to {to_dtype} with casting rule '{casting}'"
Expand Down Expand Up @@ -168,17 +168,17 @@ def ntypes(self):

def _resolve_dtype(self, arr, casting, precision_fixed):
if arr.dtype.char in self._types:
return arr, np_dtype(self._types[arr.dtype.char])
return arr, np.dtype(self._types[arr.dtype.char])

if arr.dtype in self._resolution_cache:
to_dtype = self._resolution_cache[arr.dtype]
arr = arr.astype(to_dtype)
return arr, np_dtype(self._types[to_dtype.char])
return arr, np.dtype(self._types[to_dtype.char])

chosen = None
if not precision_fixed:
for in_ty in self._types.keys():
if np_can_cast(arr.dtype, in_ty):
if np.can_cast(arr.dtype, in_ty):
chosen = in_ty
break

Expand All @@ -188,10 +188,10 @@ def _resolve_dtype(self, arr, casting, precision_fixed):
"for the given casting"
)

to_dtype = np_dtype(chosen)
to_dtype = np.dtype(chosen)
self._resolution_cache[arr.dtype] = to_dtype

return arr.astype(to_dtype), np_dtype(self._types[to_dtype.char])
return arr.astype(to_dtype), np.dtype(self._types[to_dtype.char])

def __call__(
self,
Expand Down Expand Up @@ -244,7 +244,7 @@ def __call__(
out = result
else:
if out.dtype != res_dtype:
if not np_can_cast(res_dtype, out.dtype, casting=casting):
if not np.can_cast(res_dtype, out.dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' output from "
f"{res_dtype} to {out.dtype} with casting rule "
Expand Down Expand Up @@ -296,26 +296,56 @@ def types(self):
def ntypes(self):
return len(self._types)

def _resolve_dtype(self, arrs, casting, precision_fixed):
common_dtype = ndarray.find_common_type(*arrs)
@staticmethod
def _find_common_type(arrs, orig_args):
# FIXME: The following is a miserable attempt to implement type
# coercion rules that try to match NumPy's rules for a subset of cases;
# for the others, cuNumeric computes a type different from what
# NumPy produces for the same operands. Type coercion rules shouldn't
# be this difficult to imitate...

all_scalars = all(arr.ndim == 0 for arr in arrs)
all_arrays = all(arr.ndim > 0 for arr in arrs)
kinds = set(arr.dtype.kind for arr in arrs)
lossy_conversion = ("i" in kinds or "u" in kinds) and (
"f" in kinds or "c" in kinds
)
use_min_scalar = not (all_scalars or all_arrays or lossy_conversion)

scalar_types = []
array_types = []
for arr, orig_arg in zip(arrs, orig_args):
if arr.ndim == 0:
scalar_types.append(
np.dtype(np.min_scalar_type(orig_arg))
if use_min_scalar
else arr.dtype
)
else:
array_types.append(arr.dtype)

return np.find_common_type(array_types, scalar_types)

def _resolve_dtype(self, arrs, orig_args, casting, precision_fixed):
common_dtype = self._find_common_type(arrs, orig_args)

key = (common_dtype.char, common_dtype.char)
if key in self._types:
arrs = [arr.astype(common_dtype) for arr in arrs]
return arrs, np_dtype(self._types[key])
return arrs, np.dtype(self._types[key])

if key in self._resolution_cache:
to_dtypes = self._resolution_cache[key]
arrs = [
arr.astype(to_dtype) for arr, to_dtype in zip(arrs, to_dtypes)
]
return arrs, np_dtype(self._types[to_dtypes])
return arrs, np.dtype(self._types[to_dtypes])

chosen = None
if not precision_fixed:
for in_dtypes in self._types.keys():
if all(
np_can_cast(arr.dtype, to_dtype)
np.can_cast(arr.dtype, to_dtype)
for arr, to_dtype in zip(arrs, in_dtypes)
):
chosen = in_dtypes
Expand All @@ -330,7 +360,7 @@ def _resolve_dtype(self, arrs, casting, precision_fixed):
self._resolution_cache[key] = chosen
arrs = [arr.astype(to_dtype) for arr, to_dtype in zip(arrs, chosen)]

return arrs, np_dtype(self._types[chosen])
return arrs, np.dtype(self._types[chosen])

def __call__(
self,
Expand All @@ -343,7 +373,8 @@ def __call__(
dtype=None,
**kwargs,
):
arrs = [convert_to_cunumeric_ndarray(arr) for arr in (x1, x2)]
orig_args = (x1, x2)
arrs = [convert_to_cunumeric_ndarray(arr) for arr in orig_args]

if out is not None:
if isinstance(out, tuple):
Expand Down Expand Up @@ -384,7 +415,9 @@ def __call__(
# Resolve the dtype to use for the computation and cast the input
# if necessary. If the dtype is already fixed by the caller,
# the dtype must be one of the dtypes supported by this operation.
arrs, res_dtype = self._resolve_dtype(arrs, casting, precision_fixed)
arrs, res_dtype = self._resolve_dtype(
arrs, orig_args, casting, precision_fixed
)

if out is None:
result = ndarray(
Expand All @@ -393,7 +426,7 @@ def __call__(
out = result
else:
if out.dtype != res_dtype:
if not np_can_cast(res_dtype, out.dtype, casting=casting):
if not np.can_cast(res_dtype, out.dtype, casting=casting):
raise TypeError(
f"Cannot cast ufunc '{self._name}' output from "
f"{res_dtype} to {out.dtype} with casting rule "
Expand Down
134 changes: 134 additions & 0 deletions tests/binary_op_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2021-2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import product

import numpy as np

import cunumeric as num


def value_type(obj):
if np.isscalar(obj):
return "scalar"
elif obj.ndim == 0:
return f"{obj.dtype} 0d array"
else:
return f"{obj.dtype} array"


def test(lhs_np, rhs_np, lhs_num, rhs_num):
print(f"{value_type(lhs_np)} x {value_type(rhs_np)}")

out_np = np.add(lhs_np, rhs_np)
out_num = num.add(lhs_num, rhs_num)

if out_np.dtype != out_num.dtype:
print("LHS")
print(lhs_np)
print("RHS")
print(rhs_np)
print(f"NumPy type: {out_np.dtype}, cuNumeric type: {out_num.dtype}")
assert False


def run_all_tests():
types = [
"b",
"B",
"h",
"H",
"i",
"I",
"l",
"L",
"q",
"Q",
"e",
"f",
"d",
"F",
"D",
]
array_values = [[1]]
scalar_values = [1, -1, 1.0, 1e-50, 1j]

for idx, lhs_type in enumerate(types):
for rhs_type in types[idx:]:
for lhs_value, rhs_value in product(array_values, array_values):
lhs_np = np.array(lhs_value, dtype=lhs_type)
rhs_np = np.array(rhs_value, dtype=rhs_type)

lhs_num = num.array(lhs_np)
rhs_num = num.array(rhs_np)

test(lhs_np, rhs_np, lhs_num, rhs_num)

for lhs_value, rhs_value in product(scalar_values, scalar_values):
try:
lhs_np = np.array(lhs_value, dtype=lhs_type)
rhs_np = np.array(rhs_value, dtype=rhs_type)

lhs_num = num.array(lhs_np)
rhs_num = num.array(rhs_np)

test(lhs_np, rhs_np, lhs_num, rhs_num)
except TypeError:
pass

for ty in types:
for array, scalar in product(array_values, scalar_values):
array_np = np.array(array, dtype=ty)
array_num = num.array(array_np)

test(array_np, scalar, array_num, scalar)

# TODO: NumPy's type coercion rules are confusing at best and impossible
# for any human being to understand in my opinion. My attempt to make
# sense of it for the past two days failed miserably. I managed to make
# the code somewhat compatible with NumPy for cases where Python scalars
# are passed.
#
# If anyone can do a better job than me and finally make cuNumeric
# implement the same typing rules, please put these tests back.
#
# for idx, lhs_type in enumerate(types):
# for rhs_type in types[idx:]:
# for array, scalar in product(array_values, scalar_values):
# try:
# lhs_np = np.array(array, dtype=lhs_type)
# rhs_np = np.array(scalar, dtype=rhs_type)

# lhs_num = num.array(lhs_np)
# rhs_num = num.array(rhs_np)

# test(lhs_np, rhs_np, lhs_num, rhs_num)
# except TypeError:
# pass

# try:
# lhs_np = np.array(scalar, dtype=lhs_type)
# rhs_np = np.array(array, dtype=rhs_type)

# lhs_num = num.array(lhs_np)
# rhs_num = num.array(rhs_np)

# test(lhs_np, rhs_np, lhs_num, rhs_num)
# except TypeError:
# pass


if __name__ == "__main__":
run_all_tests()