diff --git a/cunumeric/array.py b/cunumeric/array.py index 4ca0f5e54..bf8d859dd 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -41,6 +41,7 @@ from .config import ( BinaryOpCode, + ConvertCode, FFTDirection, FFTNormalization, FFTType, @@ -4075,6 +4076,8 @@ def _perform_scan( out: Union[ndarray, None] = None, nan_to_identity: bool = False, ) -> ndarray: + if src.dtype.kind != "c" and src.dtype.kind != "f": + nan_to_identity = False if dtype is None: if out is None: if src.dtype.kind == "i": @@ -4084,12 +4087,6 @@ def _perform_scan( dtype = src.dtype else: dtype = out.dtype - if (src.dtype.kind in ("f", "c")) and np.issubdtype(dtype, np.integer): - # Needs changes to convert() - raise NotImplementedError( - "Integer output types currently not supported for " - "floating/complex inputs" - ) # flatten input when axis is None if axis is None: axis = 0 @@ -4110,9 +4107,18 @@ def _perform_scan( out = ndarray(shape=src_arr.shape, dtype=dtype) if dtype != src_arr.dtype: + if nan_to_identity: + if op is ScanCode.SUM: + nan_op = ConvertCode.SUM + else: + nan_op = ConvertCode.PROD + # If convert is called, it will handle NAN conversion + nan_to_identity = False + else: + nan_op = ConvertCode.NOOP # convert input to temporary for type conversion temp = ndarray(shape=src_arr.shape, dtype=dtype) - temp._thunk.convert(src_arr._thunk) + temp._thunk.convert(src_arr._thunk, nan_op=nan_op) src_arr = temp out._thunk.scan( diff --git a/cunumeric/config.py b/cunumeric/config.py index 543880538..752509f85 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -139,6 +139,9 @@ class _CunumericSharedLib: CUNUMERIC_CHOOSE: int CUNUMERIC_CONTRACT: int CUNUMERIC_CONVERT: int + CUNUMERIC_CONVERT_NAN_NOOP: int + CUNUMERIC_CONVERT_NAN_PROD: int + CUNUMERIC_CONVERT_NAN_SUM: int CUNUMERIC_CONVOLVE: int CUNUMERIC_DIAG: int CUNUMERIC_DOT: int @@ -526,6 +529,14 @@ class ScanCode(IntEnum): SUM = _cunumeric.CUNUMERIC_SCAN_SUM +# Match these to CuNumericConvertCode in cunumeric_c.h +@unique +class ConvertCode(IntEnum): + NOOP = _cunumeric.CUNUMERIC_CONVERT_NAN_NOOP + PROD = _cunumeric.CUNUMERIC_CONVERT_NAN_PROD + SUM = _cunumeric.CUNUMERIC_CONVERT_NAN_SUM + + # Match these to BitGeneratorOperation in cunumeric_c.h @unique class BitGeneratorOperation(IntEnum): diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index ebbef808d..ce0783821 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -44,6 +44,7 @@ BitGeneratorDistribution, BitGeneratorOperation, Bitorder, + ConvertCode, CuNumericOpCode, CuNumericRedopCode, RandGenCode, @@ -1145,7 +1146,12 @@ def swapaxes(self, axis1: int, axis2: int) -> DeferredArray: # Convert the source array to the destination array @auto_convert([1]) - def convert(self, rhs: Any, warn: bool = True) -> None: + def convert( + self, + rhs: Any, + warn: bool = True, + nan_op: ConvertCode = ConvertCode.NOOP, + ) -> None: lhs_array = self rhs_array = rhs assert lhs_array.dtype != rhs_array.dtype @@ -1165,7 +1171,7 @@ def convert(self, rhs: Any, warn: bool = True) -> None: task = self.context.create_auto_task(CuNumericOpCode.CONVERT) task.add_output(lhs) task.add_input(rhs) - task.add_dtype_arg(lhs_array.dtype) + task.add_scalar_arg(nan_op, ty.int32) task.add_alignment(lhs, rhs) diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 5f086f288..c6f19fb6b 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -33,6 +33,7 @@ FFT_R2C, FFT_Z2D, BinaryOpCode, + ConvertCode, FFTDirection, ScanCode, UnaryOpCode, @@ -485,15 +486,30 @@ def swapaxes(self, axis1: int, axis2: int) -> NumPyThunk: self.children.append(result) return result - def convert(self, rhs: Any, warn: bool = True) -> None: + def convert( + self, + rhs: Any, + warn: bool = True, + nan_op: ConvertCode = ConvertCode.NOOP, + ) -> None: self.check_eager_args(rhs) if self.deferred is not None: return self.deferred.convert(rhs, warn=warn) else: if self.array.size == 1: - self.array.fill(rhs.array.item()) + if nan_op is ConvertCode.SUM and np.isnan(rhs.array.item()): + self.array.fill(0) + elif nan_op is ConvertCode.PROD and np.isnan(rhs.array.item()): + self.array.fill(1) + else: + self.array.fill(rhs.array.item()) else: - self.array[:] = rhs.array + if nan_op is ConvertCode.SUM: + self.array[:] = np.where(np.isnan(rhs.array), 0, rhs.array) + elif nan_op is ConvertCode.PROD: + self.array[:] = np.where(np.isnan(rhs.array), 1, rhs.array) + else: + self.array[:] = rhs.array def fill(self, value: Any) -> None: if self.deferred is not None: diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index f3b1c93eb..6230905af 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -17,6 +17,8 @@ from abc import ABC, abstractmethod, abstractproperty from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from .config import ConvertCode + if TYPE_CHECKING: import numpy as np import numpy.typing as npt @@ -151,7 +153,12 @@ def swapaxes(self, axis1: int, axis2: int) -> NumPyThunk: ... @abstractmethod - def convert(self, rhs: Any, warn: bool = True) -> None: + def convert( + self, + rhs: Any, + warn: bool = True, + nan_op: ConvertCode = ConvertCode.NOOP, + ) -> None: ... @abstractmethod diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 7b2f8faed..cdf382a0a 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -221,6 +221,14 @@ enum CuNumericScanCode { CUNUMERIC_SCAN_SUM, }; +// Match these to ConvertCode in config.py +// Also, sort these alphabetically for easy lookup later +enum CuNumericConvertCode { + CUNUMERIC_CONVERT_NAN_NOOP = 1, + CUNUMERIC_CONVERT_NAN_PROD, + CUNUMERIC_CONVERT_NAN_SUM, +}; + // Match these to BitGeneratorOperation in config.py enum CuNumericBitGeneratorOperation { CUNUMERIC_BITGENOP_CREATE = 1, diff --git a/src/cunumeric/scan/scan_global_util.h b/src/cunumeric/scan/scan_global_util.h index 58321d308..b53ada288 100644 --- a/src/cunumeric/scan/scan_global_util.h +++ b/src/cunumeric/scan/scan_global_util.h @@ -40,8 +40,6 @@ constexpr decltype(auto) op_dispatch(ScanCode op_code, Functor f, Fnargs&&... ar return f.template operator()(std::forward(args)...); } -// RRRR not sure I fully understand these? - template struct ScanOp { }; diff --git a/src/cunumeric/unary/convert.cc b/src/cunumeric/unary/convert.cc index e333b0809..d96afc749 100644 --- a/src/cunumeric/unary/convert.cc +++ b/src/cunumeric/unary/convert.cc @@ -22,9 +22,9 @@ namespace cunumeric { using namespace Legion; using namespace legate; -template -struct ConvertImplBody { - using OP = ConvertOp; +template +struct ConvertImplBody { + using OP = ConvertOp; using SRC = legate_type_of; using DST = legate_type_of; diff --git a/src/cunumeric/unary/convert.cu b/src/cunumeric/unary/convert.cu index 8b778042e..db853b96e 100644 --- a/src/cunumeric/unary/convert.cu +++ b/src/cunumeric/unary/convert.cu @@ -42,9 +42,9 @@ static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) out[point] = func(in[point]); } -template -struct ConvertImplBody { - using OP = ConvertOp; +template +struct ConvertImplBody { + using OP = ConvertOp; using SRC = legate_type_of; using DST = legate_type_of; diff --git a/src/cunumeric/unary/convert.h b/src/cunumeric/unary/convert.h index 655399aaf..05bbfe112 100644 --- a/src/cunumeric/unary/convert.h +++ b/src/cunumeric/unary/convert.h @@ -16,6 +16,7 @@ #pragma once +#include "cunumeric/unary/convert_util.h" #include "cunumeric/cunumeric.h" namespace cunumeric { @@ -23,6 +24,7 @@ namespace cunumeric { struct ConvertArgs { const Array& out; const Array& in; + ConvertCode nan_op; }; class ConvertTask : public CuNumericTask { diff --git a/src/cunumeric/unary/convert_omp.cc b/src/cunumeric/unary/convert_omp.cc index 353f02cd4..090734a11 100644 --- a/src/cunumeric/unary/convert_omp.cc +++ b/src/cunumeric/unary/convert_omp.cc @@ -22,9 +22,9 @@ namespace cunumeric { using namespace Legion; using namespace legate; -template -struct ConvertImplBody { - using OP = ConvertOp; +template +struct ConvertImplBody { + using OP = ConvertOp; using SRC = legate_type_of; using DST = legate_type_of; diff --git a/src/cunumeric/unary/convert_template.inl b/src/cunumeric/unary/convert_template.inl index b6c715d57..41265892a 100644 --- a/src/cunumeric/unary/convert_template.inl +++ b/src/cunumeric/unary/convert_template.inl @@ -26,15 +26,19 @@ namespace cunumeric { using namespace Legion; using namespace legate; -template +template struct ConvertImplBody; -template +template struct ConvertImpl { template * = nullptr> void operator()(ConvertArgs& args) const { - using OP = ConvertOp; + using OP = ConvertOp; using SRC = legate_type_of; using DST = legate_type_of; @@ -57,7 +61,7 @@ struct ConvertImpl { #endif OP func{}; - ConvertImplBody()(func, out, in, pitches, rect, dense); + ConvertImplBody()(func, out, in, pitches, rect, dense); } template * = nullptr> @@ -67,20 +71,42 @@ struct ConvertImpl { } }; +template +struct ConvertDispatch { + template ::value || + legate::is_complex>::value) || + NAN_OP == ConvertCode::NOOP>* = nullptr> + void operator()(ConvertArgs& args) const + { + auto dim = std::max(1, args.out.dim()); + double_dispatch(dim, args.out.code(), ConvertImpl{}, args); + } + + template ::value || + legate::is_complex>::value) || + (NAN_OP == ConvertCode::NOOP))>* = nullptr> + void operator()(ConvertArgs& args) const + { + assert(false); + } +}; + template struct SourceTypeDispatch { template void operator()(ConvertArgs& args) const { - auto dim = std::max(1, args.out.dim()); - double_dispatch(dim, args.out.code(), ConvertImpl{}, args); + op_dispatch(args.nan_op, ConvertDispatch{}, args); } }; template static void convert_template(TaskContext& context) { - ConvertArgs args{context.outputs()[0], context.inputs()[0]}; + ConvertArgs args{ + context.outputs()[0], context.inputs()[0], context.scalars()[0].value()}; type_dispatch(args.in.code(), SourceTypeDispatch{}, args); } diff --git a/src/cunumeric/unary/convert_util.h b/src/cunumeric/unary/convert_util.h index 5f688d47e..3d4a10d48 100644 --- a/src/cunumeric/unary/convert_util.h +++ b/src/cunumeric/unary/convert_util.h @@ -17,11 +17,38 @@ #pragma once #include "cunumeric/cunumeric.h" +#include "cunumeric/unary/isnan.h" namespace cunumeric { -template +enum class ConvertCode : int { + NOOP = CUNUMERIC_CONVERT_NAN_NOOP, + PROD = CUNUMERIC_CONVERT_NAN_PROD, + SUM = CUNUMERIC_CONVERT_NAN_SUM, +}; + +template +constexpr decltype(auto) op_dispatch(ConvertCode nan_op, Functor f, Fnargs&&... args) +{ + switch (nan_op) { + case ConvertCode::NOOP: + return f.template operator()(std::forward(args)...); + case ConvertCode::PROD: + return f.template operator()(std::forward(args)...); + case ConvertCode::SUM: + return f.template operator()(std::forward(args)...); + default: break; + } + assert(false); + return f.template operator()(std::forward(args)...); +} + +template struct ConvertOp { +}; + +template +struct ConvertOp { using SRC = legate::legate_type_of; using DST = legate::legate_type_of; @@ -49,7 +76,7 @@ struct ConvertOp { }; template -struct ConvertOp { +struct ConvertOp { using SRC = legate::legate_type_of; template ::value>* = nullptr> @@ -66,7 +93,7 @@ struct ConvertOp { }; template -struct ConvertOp { +struct ConvertOp { using DST = legate::legate_type_of; constexpr DST operator()(const __half& src) const @@ -75,4 +102,108 @@ struct ConvertOp { } }; +template +struct ConvertOp { + using SRC = legate::legate_type_of; + using DST = legate::legate_type_of; + + template < + typename _SRC = SRC, + std::enable_if_t::value or legate::is_complex::value>* = nullptr> + constexpr DST operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast(1) : static_cast(src); + } + + template ::value and !legate::is_complex::value>* = + nullptr> + constexpr DST operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast(1) : static_cast(src.real()); + } +}; + +template +struct ConvertOp { + using SRC = legate::legate_type_of; + + template ::value>* = nullptr> + __CUDA_HD__ __half operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast<__half>(1) + : static_cast<__half>(static_cast(src)); + } + + template ::value>* = nullptr> + __CUDA_HD__ __half operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast<__half>(1) + : static_cast<__half>(static_cast(src.real())); + } +}; + +template +struct ConvertOp { + using DST = legate::legate_type_of; + + constexpr DST operator()(const __half& src) const + { + return cunumeric::is_nan(src) ? static_cast(1) + : static_cast(static_cast(src)); + } +}; + +template +struct ConvertOp { + using SRC = legate::legate_type_of; + using DST = legate::legate_type_of; + + template < + typename _SRC = SRC, + std::enable_if_t::value or legate::is_complex::value>* = nullptr> + constexpr DST operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast(0) : static_cast(src); + } + + template ::value and !legate::is_complex::value>* = + nullptr> + constexpr DST operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast(0) : static_cast(src.real()); + } +}; + +template +struct ConvertOp { + using SRC = legate::legate_type_of; + + template ::value>* = nullptr> + __CUDA_HD__ __half operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast<__half>(0) + : static_cast<__half>(static_cast(src)); + } + + template ::value>* = nullptr> + __CUDA_HD__ __half operator()(const _SRC& src) const + { + return cunumeric::is_nan(src) ? static_cast<__half>(0) + : static_cast<__half>(static_cast(src.real())); + } +}; + +template +struct ConvertOp { + using DST = legate::legate_type_of; + + constexpr DST operator()(const __half& src) const + { + return cunumeric::is_nan(src) ? static_cast(0) + : static_cast(static_cast(src)); + } +}; + } // namespace cunumeric diff --git a/tests/integration/test_scan.py b/tests/integration/test_scan.py index b745e03e3..009713eba 100644 --- a/tests/integration/test_scan.py +++ b/tests/integration/test_scan.py @@ -23,19 +23,22 @@ def _gen_array(n0, shape, dt, axis, outtype): - # range 1-10, avoiding zeros to ensure correct testing for int prod case + range_lower = 0 + # range 1-3 for ints, avoid zeros for correct testing in prod case + if np.issubdtype(dt, np.integer): + range_lower = 1 if dt == np.complex64: A = ( - (99 * np.random.random(shape) + 1) - + (99 * np.random.random(shape) + 1) * 1j + (3 * np.random.random(shape) + range_lower) + + (3 * np.random.random(shape) + range_lower) * 1j ).astype(np.complex64) elif dt == np.complex128: A = ( - (99 * np.random.random(shape) + 1) - + (99 * np.random.random(shape) + 1) * 1j + (3 * np.random.random(shape) + range_lower) + + (3 * np.random.random(shape) + range_lower) * 1j ).astype(np.complex128) else: - A = (99 * np.random.random(shape) + 1).astype(dt) + A = (3 * np.random.random(shape) + range_lower).astype(dt) if n0 == "first_half": # second element along all axes is a NAN A[(1,) * len(shape)] = np.nan @@ -58,6 +61,12 @@ def _gen_array(n0, shape, dt, axis, outtype): def _run_tests(op, n0, shape, dt, axis, out0, outtype): + if (np.issubdtype(dt, np.integer) and n0 is not None) or ( + np.issubdtype(outtype, np.integer) + and (op == "cumsum" or op == "cumprod") + and n0 is not None + ): + return print( f"Running test: {op}, shape: {shape}, nan location: {n0}" f", axis: {axis}, in type: {dt}, out type: {outtype}" @@ -89,74 +98,70 @@ def _run_tests(op, n0, shape, dt, axis, out0, outtype): "nancumsum", "nancumprod", ] -# keeping array sizes small to avoid accumulation variance -# between cunumeric and numpy +ops_nan = [ + "nancumsum", + "nancumprod", +] shapes = [ - [200], - [4, 50], + [100], + [4, 25], ] axes = [ None, 0, ] -out0s = [ - True, - False, +dtypes = [ + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + np.complex64, + np.complex128, +] +dtypes_simplified = [ + np.int32, + np.float32, + np.complex64, +] +n0s = [ + None, + "first_half", + "second_half", ] -@pytest.mark.parametrize("op", ops) @pytest.mark.parametrize("shape", shapes) @pytest.mark.parametrize("axis", axes) -@pytest.mark.parametrize("out0", out0s) -def test_scan(op, shape, axis, out0): - n0s = [ - None, - "first_half", - "second_half", - ] - int_types = [ - np.int16, - np.int32, - np.int64, - ] - float_types = [ - np.float32, - np.float64, - ] - complex_types = [ - np.complex64, - np.complex128, - ] - for outtype in int_types: - for dt in int_types: - _run_tests(op, None, shape, dt, axis, out0, outtype) - for dt in float_types: - for n0 in n0s: - print("Float to int NAN conversion currently not supported!") - for dt in complex_types: - for n0 in n0s: - print("Complex to int NAN conversion currently not supported!") - - for outtype in float_types: - for dt in int_types: - _run_tests(op, None, shape, dt, axis, out0, outtype) - for dt in float_types: - for n0 in n0s: - _run_tests(op, n0, shape, dt, axis, out0, outtype) - for dt in complex_types: - for n0 in n0s: - _run_tests(op, n0, shape, dt, axis, out0, outtype) - - for outtype in complex_types: - for dt in int_types: - _run_tests(op, None, shape, dt, axis, out0, outtype) - for dt in float_types: - for n0 in n0s: - _run_tests(op, n0, shape, dt, axis, out0, outtype) - for dt in complex_types: - for n0 in n0s: - _run_tests(op, n0, shape, dt, axis, out0, outtype) +@pytest.mark.parametrize("outtype", dtypes_simplified) +@pytest.mark.parametrize("dt", dtypes_simplified) +def test_scan_out0_shape(shape, axis, outtype, dt): + op = "cumsum" + out0 = True + n0 = None + _run_tests(op, n0, shape, dt, axis, out0, outtype) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("outtype", dtypes_simplified) +@pytest.mark.parametrize("dt", dtypes) +@pytest.mark.parametrize("n0", n0s) +def test_scan_nan(op, outtype, dt, n0): + shape = [100] + axis = None + out0 = False + _run_tests(op, n0, shape, dt, axis, out0, outtype) + + +@pytest.mark.parametrize("op", ops) +@pytest.mark.parametrize("outtype", dtypes_simplified) +@pytest.mark.parametrize("dt", dtypes_simplified) +def test_scan_op(op, outtype, dt): + shape = [100] + axis = None + out0 = False + n0 = None + _run_tests(op, n0, shape, dt, axis, out0, outtype) def test_empty_inputs():