[FEATURE] Use RTC for reduction ops (apache#19426)
* Initial rebase

* Fixes after merge

* Fixes

* Fix lint

* Fix lint for real

* Cleaning and code reuse

* Fix lint

* Try to WAR the maybe-uninitialized warning

* Second try

* Fix Windows compilation

* More fixes for Windows compilation

* Breaking the strings to please Windows compiler

* Do not use the default stream in kron

* Fix argmin/argmax

* Fix layernorm
ptrendx authored May 25, 2021
1 parent ec119d3 commit 57d0ace
Showing 53 changed files with 1,944 additions and 1,907 deletions.
5 changes: 4 additions & 1 deletion src/common/cuda/rtc.cc
@@ -150,13 +150,16 @@ CUfunction get_function(const std::string &parameters,
         std::string(fp16_support_string) + "\n" +
         type_support_string + "\n" +
         util_string + "\n" +
+        limits + "\n" +
         special_functions_definitions + '\n' +
         vectorization_support_string + "\n" +
         function_definitions_util + "\n" +
         function_definitions_binary + "\n" +
         function_definitions_unary + "\n" +
         backward_function_definitions + "\n" +
-        reducer + "\n";
+        grad_function_definitions + "\n" +
+        reducer + "\n" +
+        logic_reducer + "\n";
   std::string code_with_header = common_header + parameters + code;
   // If verbose mode, output kernel source, though not including the common header
   if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
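For context, the common_header string assembled above is compiled at runtime with NVRTC and turned into a CUfunction. A minimal, self-contained sketch of that flow, independent of MXNet (the kernel body, the name scale_kernel, and the option list are illustrative assumptions, not MXNet's actual code; link with -lnvrtc -lcuda):

// Minimal NVRTC + CUDA driver-API flow: compile a kernel from a source
// string at runtime, load the resulting PTX, and launch the function.
#include <cuda.h>
#include <nvrtc.h>
#include <cstdio>
#include <vector>

int main() {
  const char *code =
      "extern \"C\" __global__ void scale_kernel(float *out, int n) {\n"
      "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
      "  if (i < n) out[i] = 2.0f * i;\n"
      "}\n";

  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, code, "scale_kernel.cu", 0, nullptr, nullptr);
  const char *opts[] = {"--std=c++14"};
  if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) {
    size_t log_size;
    nvrtcGetProgramLogSize(prog, &log_size);
    std::vector<char> log(log_size);
    nvrtcGetProgramLog(prog, log.data());
    std::printf("%s\n", log.data());  // compiler diagnostics for the string
    return 1;
  }
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Load the compiled PTX and obtain a CUfunction handle.
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);
  CUcontext ctx;
  cuCtxCreate(&ctx, 0, dev);
  CUmodule mod;
  cuModuleLoadData(&mod, ptx.data());
  CUfunction func;
  cuModuleGetFunction(&func, mod, "scale_kernel");

  int n = 1024;
  CUdeviceptr out;
  cuMemAlloc(&out, n * sizeof(float));
  void *args[] = {&out, &n};
  cuLaunchKernel(func, (n + 255) / 256, 1, 1,  // grid
                 256, 1, 1,                    // block
                 0, nullptr, args, nullptr);   // smem, stream, params
  cuCtxSynchronize();
  cuMemFree(out);
  cuCtxDestroy(ctx);
  return 0;
}

Setting MXNET_RTC_VERBOSE=1, per the hunk above, prints the generated kernel source (minus the common header), which is the natural first step when a kernel assembled from these fragments fails to compile.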
240 changes: 159 additions & 81 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -237,6 +237,98 @@ backward_square(const DTypeGrad grad, const DType val) {
   return 2 * val * grad;
 }
+template <typename DType, typename DType2>
+__device__ inline DType div_rgrad(const DType val,
+                                  const DType2 val2) {
+  return -val / (val2 * val2);
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_clip(const DTypeGrad grad, const DType val,
+              const float a_min, const float a_max) {
+  if (val > a_max || val < a_min) {
+    return 0;
+  } else {
+    return grad;
+  }
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_reciprocal(const DTypeGrad grad, const DType val) {
+  return -grad / (val * val);
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erf(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  constexpr type my_pi = pi;
+  return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_erfinv(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  constexpr type my_pi = pi;
+  const type g = grad;
+  const type v = val;
+  return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gamma(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
+  } else {
+    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
+  }
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gammaln(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::special_functions::cephes::psi<double>(v);
+  } else {
+    return grad * op::special_functions::cephes::psi<float>(v);
+  }
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_digamma(const DTypeGrad grad, const DType val) {
+  using type = mixed_type<DTypeGrad, DType>;
+  const type v = val;
+  if (type_util::is_same<DTypeGrad, double>::value) {
+    return grad * op::special_functions::trigamma<double>(v);
+  } else {
+    return grad * op::special_functions::trigamma<float>(v);
+  }
+}
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_gelu(const DTypeGrad grad, const DType val) {
+  return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
+                 val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
+}
+} // namespace op
+)code";
+
+const char grad_function_definitions[] = R"code(
+namespace op {
 template <typename DType, typename DType2>
 __device__ inline mixed_type<DType, DType2>
 rdiv_grad(const DType val,
@@ -252,12 +344,6 @@ div_grad(const DType val,
   return op::reciprocal(temp);
 }
-template <typename DType, typename DType2>
-__device__ inline DType div_rgrad(const DType val,
-                                  const DType2 val2) {
-  return -val / (val2 * val2);
-}
 template <typename DType, typename DType2>
 __device__ inline DType mod_grad(const DType val,
                                  const DType2 val2) {
@@ -368,80 +454,6 @@ rldexp_grad(const DType val,
   return val2 * op::power(static_cast<type>(2), val) * op::log(static_cast<type>(2));
 }
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_clip(const DTypeGrad grad, const DType val,
-              const float a_min, const float a_max) {
-  if (val > a_max || val < a_min) {
-    return 0;
-  } else {
-    return grad;
-  }
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_reciprocal(const DTypeGrad grad, const DType val) {
-  return -grad / (val * val);
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erf(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
-  return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_erfinv(const DTypeGrad grad, const DType val) {
-  constexpr mixed_type<DTypeGrad, DType> my_pi = pi;
-  const mixed_type<DTypeGrad, DType> g = grad;
-  const mixed_type<DTypeGrad, DType> v = val;
-  return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gamma(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
-  } else {
-    return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
-  }
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gammaln(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::special_functions::cephes::psi<double>(v);
-  } else {
-    return grad * op::special_functions::cephes::psi<float>(v);
-  }
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_digamma(const DTypeGrad grad, const DType val) {
-  const mixed_type<DTypeGrad, DType> v = val;
-  if (type_util::is_same<DTypeGrad, double>::value) {
-    return grad * op::special_functions::trigamma<double>(v);
-  } else {
-    return grad * op::special_functions::trigamma<float>(v);
-  }
-}
-template <typename DType, typename DTypeGrad>
-__device__ inline mixed_type<DTypeGrad, DType>
-backward_gelu(const DTypeGrad grad, const DType val) {
-  return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
-                 val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
-}
 template <typename DType, typename DType2>
 __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
   auto bsq = scalar * scalar;
@@ -467,8 +479,74 @@ __device__ inline DType prelu_grad(const DType val,
   return (val > 0) ? 0 : val;
 }
-} // namespace op
+template <typename DType, typename DType2>
+__device__ inline mixed_type<DType2, DType>
+gamma_implicit_grad(const DType a_in, const DType2 x_in) {
+  using OType = mixed_type<DType2, DType>;
+  const OType a = a_in;
+  const OType x = x_in;
+  if (x < 0.8f) {
+    OType numer = 1;
+    OType denom = a;
+    OType series1 = numer / denom;
+    OType series2 = numer / (denom * denom);
+#pragma unroll
+    for (int i = 1; i <= 5; i++) {
+      numer *= -x / static_cast<DType>(i);
+      denom += 1;
+      series1 += numer / denom;
+      series2 += numer / (denom * denom);
+    }
+    OType pow_x_alpha = op::power(x, a);
+    OType gamma_pdf = op::power(x, a - 1) * op::exp(-x);
+    OType gamma_cdf = pow_x_alpha * series1;
+    OType gamma_cdf_alpha =
+        (op::log(x) - OType(special_functions::cephes::psi<float>(a))) *
+            gamma_cdf -
+        pow_x_alpha * series2;
+    OType result = -gamma_cdf_alpha / gamma_pdf;
+    return op::isnan(result) ? 0.f : result;
+  }
+  if (a > 8.0f) {
+    if (0.9f * a <= x && x <= 1.1f * a) {
+      OType numer_1 = 1 + 24 * a * (1 + 12 * a);
+      OType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
+                      65 * x * x / a + a * (107 + 3600 * x);
+      OType denom = 1244160 * (a * a) * (a * a);
+      return numer_1 * numer_2 / denom;
+    }
+    OType denom = op::sqrt(8 * a);
+    OType term2 = denom / (a - x);
+    OType term3 =
+        op::power(x - a - a * op::log(x / a), static_cast<OType>(-1.5));
+    OType term23 = (x < a) ? term2 - term3 : term2 + term3;
+    OType term1 = op::log(x / a) * term23 -
+                  op::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
+    OType stirling = 1.f + 1.f / (12.f * a) * (1.f + 1.f / (24.f * a));
+    OType numer = x * term1;
+    return -stirling * numer / denom;
+  }
+  OType u = op::log(x / a);
+  OType v = op::log(a);
+  OType coef_uv[3][8] = {
+      {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
+       0.10406089, 0.0014179084},
+      {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
+       0.020070113, -0.0035938915, -0.00058392623},
+      {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
+       -0.0021309326, 0.00085092367, -1.5247877e-07},
+  };
+  OType coef_v[8];
+#pragma unroll
+  for (int i = 0; i < 8; i++) {
+    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+  }
+  OType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+  OType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+  return op::exp(p / q);
+}
+} // namespace op
 )code";

} // namespace rtc
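For reference, the identities these moved and added gradient helpers implement — standard calculus; the mapping to function names is my annotation, not part of the diff:

\begin{aligned}
\frac{\partial}{\partial y}\Bigl(\frac{x}{y}\Bigr) &= -\frac{x}{y^{2}} && \text{(div\_rgrad)} \\
\frac{\partial}{\partial x}\,\mathrm{clip}(x, a_{\min}, a_{\max}) &= \mathbf{1}[a_{\min} \le x \le a_{\max}] && \text{(backward\_clip)} \\
\frac{d}{dx}\,\mathrm{erf}(x) &= \frac{2}{\sqrt{\pi}}\,e^{-x^{2}} && \text{(backward\_erf)} \\
\frac{d}{dy}\,\mathrm{erf}^{-1}(y) &= \frac{\sqrt{\pi}}{2}\,e^{\left(\mathrm{erf}^{-1}(y)\right)^{2}} && \text{(backward\_erfinv)} \\
\Gamma'(x) &= \Gamma(x)\,\psi(x) && \text{(backward\_gamma)} \\
\frac{d}{dx}\,\ln\Gamma(x) &= \psi(x) && \text{(backward\_gammaln)} \\
\psi'(x) &= \psi_{1}(x) && \text{(backward\_digamma)} \\
\mathrm{GELU}'(x) &= \frac{1}{2}\Bigl(1 + \mathrm{erf}\frac{x}{\sqrt{2}}\Bigr) + \frac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2} && \text{(backward\_gelu)}
\end{aligned}

Here \psi and \psi_{1} are the digamma and trigamma functions (cephes::psi and trigamma in the code). The new gamma_implicit_grad computes the implicit reparameterization gradient for a sample x \sim \mathrm{Gamma}(a, 1) with CDF F(x; a) and density f(x; a):

\frac{\partial x}{\partial a} = -\,\frac{\partial F(x; a)/\partial a}{f(x; a)},

approximated by a Taylor series for x < 0.8, an asymptotic expansion for a > 8, and a rational fit in u = \log(x/a), v = \log a otherwise — the technique of Figurnov et al., "Implicit Reparameterization Gradients" (2018), which appears to be the source, though the diff itself does not cite it.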
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -696,6 +696,10 @@ __device__ inline DType log_sigmoid(const DType val) {
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
   // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 in single precision.
+  if (val > 20) return val;
   if (type_util::has_double_or_integral<DType>::value) {
     return ::log(1 + ::exp(val));
   } else {
@@ -936,6 +940,11 @@ __device__ inline bool_t np_logical_not(const DType val) {
   return !static_cast<bool>(val);
 }
+template <typename DType>
+__device__ inline bool_t NonZero(const DType val) {
+  return val != 0;
+}
 #undef DEFINE_UNARY_MATH_FUNC
 template <typename DType>
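A quick check of the softrelu threshold in the hunk above (my arithmetic, not part of the diff):

\log(1 + e^{a}) = a + \log(1 + e^{-a}), \qquad \log(1 + e^{-20}) \approx e^{-20} \approx 2.06 \times 10^{-9},

while the spacing between adjacent float32 values near 20 is 2^{-19} \approx 1.9 \times 10^{-6}. So for a > 20 the exact result rounds back to a in single precision, making the early return exact, and it also sidesteps overflow of exp(val) for large inputs. (In double precision the shortcut drops a term of about e^{-a}, consistent with the comment's single-precision rationale.)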