Skip to content

Commit

Permalink
Refactor isnan (replace Isnan with is_nan) and remove redundant file.
Browse files Browse the repository at this point in the history
  • Loading branch information
rkarim2 committed Aug 10, 2022
1 parent c4f4d8b commit 07e8340
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions src/cunumeric/unary/convert_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include "cunumeric/cunumeric.h"
#include "cunumeric/isnan.h"
#include "cunumeric/unary/isnan.h"

namespace cunumeric {

Expand Down Expand Up @@ -101,53 +101,50 @@ struct ConvertOp<ConvertCode::PROD, DST_TYPE, SRC_TYPE> {
using SRC = legate::legate_type_of<SRC_TYPE>;
using DST = legate::legate_type_of<DST_TYPE>;

cunumeric::Isnan<SRC_TYPE> isn;

template <
typename _SRC = SRC,
std::enable_if_t<!legate::is_complex<_SRC>::value or legate::is_complex<DST>::value>* = nullptr>
constexpr DST operator()(const _SRC& src) const
{
return isn(src) ? static_cast<DST>(1) : static_cast<DST>(src);
return cunumeric::is_nan(src) ? static_cast<DST>(1) : static_cast<DST>(src);
}

template <typename _SRC = SRC,
std::enable_if_t<legate::is_complex<_SRC>::value and !legate::is_complex<DST>::value>* =
nullptr>
constexpr DST operator()(const _SRC& src) const
{
return isn(src) ? static_cast<DST>(1) : static_cast<DST>(src.real());
return cunumeric::is_nan(src) ? static_cast<DST>(1) : static_cast<DST>(src.real());
}
};

template <legate::LegateTypeCode SRC_TYPE>
struct ConvertOp<ConvertCode::PROD, legate::LegateTypeCode::HALF_LT, SRC_TYPE> {
using SRC = legate::legate_type_of<SRC_TYPE>;

cunumeric::Isnan<SRC_TYPE> isn;

template <typename _SRC = SRC, std::enable_if_t<!legate::is_complex<_SRC>::value>* = nullptr>
__CUDA_HD__ __half operator()(const _SRC& src) const
{
return isn(src) ? static_cast<__half>(1) : static_cast<__half>(static_cast<double>(src));
return cunumeric::is_nan(src) ? static_cast<__half>(1)
: static_cast<__half>(static_cast<double>(src));
}

template <typename _SRC = SRC, std::enable_if_t<legate::is_complex<_SRC>::value>* = nullptr>
__CUDA_HD__ __half operator()(const _SRC& src) const
{
return isn(src) ? static_cast<__half>(1) : static_cast<__half>(static_cast<double>(src.real()));
return cunumeric::is_nan(src) ? static_cast<__half>(1)
: static_cast<__half>(static_cast<double>(src.real()));
}
};

template <legate::LegateTypeCode DST_TYPE>
struct ConvertOp<ConvertCode::PROD, DST_TYPE, legate::LegateTypeCode::HALF_LT> {
using DST = legate::legate_type_of<DST_TYPE>;

cunumeric::Isnan<HALF_LT> isn;

constexpr DST operator()(const __half& src) const
{
return isn(src) ? static_cast<DST>(1) : static_cast<DST>(static_cast<double>(src));
return cunumeric::is_nan(src) ? static_cast<DST>(1)
: static_cast<DST>(static_cast<double>(src));
}
};

Expand All @@ -156,53 +153,50 @@ struct ConvertOp<ConvertCode::SUM, DST_TYPE, SRC_TYPE> {
using SRC = legate::legate_type_of<SRC_TYPE>;
using DST = legate::legate_type_of<DST_TYPE>;

cunumeric::Isnan<SRC_TYPE> isn;

template <
typename _SRC = SRC,
std::enable_if_t<!legate::is_complex<_SRC>::value or legate::is_complex<DST>::value>* = nullptr>
constexpr DST operator()(const _SRC& src) const
{
return isn(src) ? static_cast<DST>(0) : static_cast<DST>(src);
return cunumeric::is_nan(src) ? static_cast<DST>(0) : static_cast<DST>(src);
}

template <typename _SRC = SRC,
std::enable_if_t<legate::is_complex<_SRC>::value and !legate::is_complex<DST>::value>* =
nullptr>
constexpr DST operator()(const _SRC& src) const
{
return isn(src) ? static_cast<DST>(0) : static_cast<DST>(src.real());
return cunumeric::is_nan(src) ? static_cast<DST>(0) : static_cast<DST>(src.real());
}
};

template <legate::LegateTypeCode SRC_TYPE>
struct ConvertOp<ConvertCode::SUM, legate::LegateTypeCode::HALF_LT, SRC_TYPE> {
using SRC = legate::legate_type_of<SRC_TYPE>;

cunumeric::Isnan<SRC_TYPE> isn;

template <typename _SRC = SRC, std::enable_if_t<!legate::is_complex<_SRC>::value>* = nullptr>
__CUDA_HD__ __half operator()(const _SRC& src) const
{
return isn(src) ? static_cast<__half>(0) : static_cast<__half>(static_cast<double>(src));
return cunumeric::is_nan(src) ? static_cast<__half>(0)
: static_cast<__half>(static_cast<double>(src));
}

template <typename _SRC = SRC, std::enable_if_t<legate::is_complex<_SRC>::value>* = nullptr>
__CUDA_HD__ __half operator()(const _SRC& src) const
{
return isn(src) ? static_cast<__half>(0) : static_cast<__half>(static_cast<double>(src.real()));
return cunumeric::is_nan(src) ? static_cast<__half>(0)
: static_cast<__half>(static_cast<double>(src.real()));
}
};

template <legate::LegateTypeCode DST_TYPE>
struct ConvertOp<ConvertCode::SUM, DST_TYPE, legate::LegateTypeCode::HALF_LT> {
using DST = legate::legate_type_of<DST_TYPE>;

cunumeric::Isnan<HALF_LT> isn;

constexpr DST operator()(const __half& src) const
{
return isn(src) ? static_cast<DST>(0) : static_cast<DST>(static_cast<double>(src));
return cunumeric::is_nan(src) ? static_cast<DST>(0)
: static_cast<DST>(static_cast<double>(src));
}
};

Expand Down

0 comments on commit 07e8340

Please sign in to comment.