Skip to content

Commit

Permalink
Support Ternary ops in elementwise and broadcast (#33976)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy authored Aug 5, 2021
1 parent a68709d commit 1d7b75d
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 195 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

Expand Down
125 changes: 62 additions & 63 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct DimensionsTransform {

struct StridesCalculation {
std::vector<std::vector<uint32_t>> strides;
std::vector<FastDivMod> divmoders;
std::vector<platform::FastDivMod> divmoders;

private:
// To calculate the strides of each input_tensor.
Expand All @@ -190,29 +190,29 @@ struct StridesCalculation {
strides.resize(N, std::vector<uint32_t>(dim_size, 1));

for (int i = 0; i < dim_size; ++i) {
divmoders[i] = FastDivMod(out_dims[i]);
divmoders[i] = platform::FastDivMod(out_dims[i]);
}
CalculateStrides(N, dim_size, in_dims);
}
};

template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWarpper {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
struct BroadcastArgsWrapper {
using InVecType = platform::CudaAlignedVector<InT, VecSize>;
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;

OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
FastDivMod divmoders[kDims];
platform::FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
uint32_t scalar_cal_offset;
Functor func;

HOSTDEVICE BroadcastArgsWarpper(
HOSTDEVICE BroadcastArgsWrapper(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int scalar_cal_offset, Functor func,
const StridesCalculation &offset_calculator)
Expand All @@ -227,7 +227,7 @@ struct BroadcastArgsWarpper {
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(FastDivMod));
kDims * sizeof(platform::FastDivMod));
}

__device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
Expand Down Expand Up @@ -302,61 +302,60 @@ struct BroadcastArgsWarpper {
}
};

// Scalar (one element per thread) path of the broadcast kernel.
// Used for the tail elements whose count is not a multiple of VecSize.
// Loads one operand per input tensor (ET of them), applies the functor
// once, and stores the single result. `tid` is the index into the tail
// region; the wrapper translates it to per-tensor offsets.
// NOTE(review): fixed flattened-diff residue — removed the stale
// pre-rename `BroadcastArgsWarpper` lines and the leftover
// `for (int j = 1; j < ET; ++j)` loop that redundantly re-applied func.
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWrapper broadcast_wrapper, int tid) {
  InT args[ET];
  OutT args_out;
  broadcast_wrapper.LoadScalarizedData(args, tid);

  // Calculation on the in_tensor data: one functor application per element.
  args_out = broadcast_wrapper.func(args);

  broadcast_wrapper.StoreScalarizedData(args_out, tid);
}

// Vectorized path of the broadcast kernel: each thread handles VecSize
// consecutive output elements. The wrapper loads a VecSize-wide slice from
// every input tensor into `args`; the functor is applied element-wise and
// the results are stored back as one aligned vector.
// NOTE(review): fixed flattened-diff residue — removed duplicated
// pre-rename `BroadcastArgsWarpper` lines so the function compiles.
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWrapper broadcast_wrapper, int tid) {
  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
  OutVecType args_out;
  InT ins[ET];            // one scalar operand per input tensor
  InT args[ET][VecSize];  // VecSize-wide slice of every input tensor
  broadcast_wrapper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
  for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      ins[j] = args[j][i];
    }
    args_out.val[i] = broadcast_wrapper.func(ins);
  }
  broadcast_wrapper.StoreVectorizedData(args_out, tid);
}

// Entry kernel for broadcast elementwise computation.
// Launch layout: 1D grid of 1D blocks; each thread computes its flat id.
// Threads [0, main_tid) take the vectorized path (VecSize elements each);
// threads [0, tail_tid) additionally handle one scalar tail element, offset
// by the wrapper's scalar_cal_offset past the vectorized region.
// NOTE(review): fixed flattened-diff residue — removed duplicated
// pre-rename `BroadcastArgsWarpper` lines; fixed comment typos
// ("Calcualting", "Scalarzed", "lenght", "multipler").
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  // Vectorized calculation of the main data, whose length is the largest
  // multiple of VecSize, e.g. the first 1024 of 1027 elements when VecSize
  // is 4.
  if (tid < main_tid) {
    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
        broadcast_wrapper, tid);
  }
  // Scalarized calculation of the remaining data whose length cannot fill
  // VecSize, e.g. the last 3 of 1027 elements when VecSize is 4.
  if (tid < tail_tid) {
    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
        broadcast_wrapper, tid);
  }
}

Expand All @@ -367,7 +366,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
int numel = out->numel();
const int threads = 256;
int threads = GetThreadsConfig(ctx, numel, VecSize);
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / VecSize;
int tail_tid = numel % VecSize;
Expand All @@ -380,75 +379,75 @@ void LaunchBroadcastKernelForDifferentDimSize(

switch (merge_dims.dim_size) {
case 1: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 2: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 3: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 4: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 5: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 6: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 7: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
case 8: {
auto broadcast_warpper =
BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
auto broadcast_wrapper =
BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
ins, out, vec_len, func, offset_calculator);
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
VecSize><<<blocks, threads, 0, stream>>>(
broadcast_warpper, main_tid, tail_tid);
broadcast_wrapper, main_tid, tail_tid);
break;
}
default: {
Expand All @@ -473,11 +472,11 @@ void LaunchBroadcastElementwiseCudaKernel(
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);

switch (vec_size) {
Expand Down
Loading

0 comments on commit 1d7b75d

Please sign in to comment.