Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Ternary ops in elmentwise and broadcast #33976

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_add_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License. */
#include <utility>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

Expand Down
27 changes: 13 additions & 14 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct DimensionsTransform {

struct StridesCalculation {
std::vector<std::vector<uint32_t>> strides;
std::vector<FastDivMod> divmoders;
std::vector<platform::FastDivMod> divmoders;

private:
// To calculate the strides of each input_tensor.
Expand All @@ -190,7 +190,7 @@ struct StridesCalculation {
strides.resize(N, std::vector<uint32_t>(dim_size, 1));

for (int i = 0; i < dim_size; ++i) {
divmoders[i] = FastDivMod(out_dims[i]);
divmoders[i] = platform::FastDivMod(out_dims[i]);
}
CalculateStrides(N, dim_size, in_dims);
}
Expand All @@ -199,15 +199,15 @@ struct StridesCalculation {
template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
int VecSize, int kDims>
struct BroadcastArgsWarpper {
using InVecType = CudaAlignedVector<InT, VecSize>;
using OutVecType = CudaAlignedVector<OutT, VecSize>;
using InVecType = platform::CudaAlignedVector<InT, VecSize>;
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;

OutT *out_data;
OutVecType *vec_out_data;
const InT *__restrict__ in_data[ET];
const InVecType *__restrict__ vec_in_data[ET];
bool no_broadcast[ET];
FastDivMod divmoders[kDims];
platform::FastDivMod divmoders[kDims];
uint32_t strides[ET][framework::DDim::kMaxRank];
uint32_t scalar_cal_offset;
Functor func;
Expand All @@ -227,7 +227,7 @@ struct BroadcastArgsWarpper {
out_data = out->data<OutT>();
vec_out_data = reinterpret_cast<OutVecType *>(out_data);
memcpy(divmoders, offset_calculator.divmoders.data(),
kDims * sizeof(FastDivMod));
kDims * sizeof(platform::FastDivMod));
}

__device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
Expand Down Expand Up @@ -310,18 +310,17 @@ __device__ inline void ScalarizedBroadcastKernelImpl(
OutT args_out;
broadcast_warpper.LoadScalarizedData(args, tid);

#pragma unroll(ET)
for (int j = 1; j < ET; ++j) {
args_out = broadcast_warpper.func(args);
}
// Calcualtion of the in_tensor data.
args_out = broadcast_warpper.func(args);

broadcast_warpper.StoreScalarizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
BroadcastArgsWarpper broadcast_warpper, int tid) {
using OutVecType = CudaAlignedVector<OutT, VecSize>;
using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
OutVecType args_out;
InT ins[ET];
InT args[ET][VecSize];
Expand Down Expand Up @@ -367,7 +366,7 @@ void LaunchBroadcastKernelForDifferentDimSize(
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
int numel = out->numel();
const int threads = 256;
int threads = GetThreadsConfig(ctx, numel, VecSize);
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / VecSize;
int tail_tid = numel % VecSize;
Expand Down Expand Up @@ -473,11 +472,11 @@ void LaunchBroadcastElementwiseCudaKernel(
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);

switch (vec_size) {
Expand Down
Loading