
Commit

add multi-output
Zjq9409 committed Dec 27, 2021
1 parent b1f58dc commit e07e54e
Showing 3 changed files with 108 additions and 62 deletions.
130 changes: 73 additions & 57 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -14,7 +14,6 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
@@ -31,13 +30,18 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
const T* dout,
int64_t size, T* dx,
T* dy) {
// for (int i = blockIdx.x * blockDim.x + threadIdx.x;
// i < size; i += blockDim.x * gridDim.x) {
// T o = dout[i];
// dx[i] = o / y[i];
// dy[i] = -o * out[i] / y[i];
// }

int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
T o = dout[col];
if (dx != nullptr) {
dx[col] = o / y[col];
}
dx[col] = o / y[col];
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
}
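
For context (not part of the diff): with out = x / y, the chain rule gives exactly the two gradients the kernel writes,

dx = \frac{\partial\,\mathrm{out}}{\partial x}\,d\mathrm{out} = \frac{d\mathrm{out}}{y},\qquad
dy = \frac{\partial\,\mathrm{out}}{\partial y}\,d\mathrm{out} = -\frac{x}{y^{2}}\,d\mathrm{out} = -\frac{\mathrm{out}}{y}\,d\mathrm{out}.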
@@ -59,10 +63,8 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
dx[col] = o / y_conj;
dy[col] = -dout[col] * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}
@@ -83,14 +85,30 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}
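
The complex specializations conjugate y and out/y because the convention used here (stated for context; it is what the code above implements) is that the gradient of a complex input is the upstream gradient times the conjugated partial derivative:

dx = d\mathrm{out}\cdot\overline{\left(\frac{\partial\,\mathrm{out}}{\partial x}\right)} = \frac{d\mathrm{out}}{\bar{y}},\qquad
dy = d\mathrm{out}\cdot\overline{\left(\frac{\partial\,\mathrm{out}}{\partial y}\right)} = -\,d\mathrm{out}\cdot\overline{\left(\frac{\mathrm{out}}{y}\right)}.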

template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
const framework::Tensor* in, const framework::Tensor* out,
framework::Tensor* src, framework::Tensor* dst) {
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (dst->dims() == out->dims()) {
// dst->ShareDataWith(*src);
framework::TensorCopy(*src, ctx.GetPlace(), dev_ctx, dst);
return;
}
int axis = ctx.Attr<int>("axis");
std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
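
For readers unfamiliar with GetReduceDim: the idea is to find the broadcast axes along which the elementwise gradient must be summed so that it shrinks back to the original input's shape. A standalone sketch with hypothetical names (not Paddle's implementation; it assumes NumPy-style right-aligned broadcasting and ignores the axis attribute):

// Standalone sketch, not Paddle's GetReduceDim: collect the axes of `out_dims`
// that were broadcast from a size-1 (or missing) axis of `in_dims`.
// The gradient is summed over these axes before being written back.
#include <cstdio>
#include <vector>

std::vector<int> BroadcastReduceDims(const std::vector<int>& in_dims,
                                     const std::vector<int>& out_dims) {
  std::vector<int> reduce_dims;
  int offset =
      static_cast<int>(out_dims.size()) - static_cast<int>(in_dims.size());
  for (int i = 0; i < static_cast<int>(out_dims.size()); ++i) {
    int in_d = (i < offset) ? 1 : in_dims[i - offset];
    if (in_d == 1 && out_dims[i] != 1) reduce_dims.push_back(i);
  }
  return reduce_dims;
}

int main() {
  // y with shape [3, 1, 5] broadcast against out with shape [2, 3, 4, 5]:
  // dy must be summed over axes 0 and 2 before it matches y's shape.
  for (int d : BroadcastReduceDims({3, 1, 5}, {2, 3, 4, 5})) {
    std::printf("%d ", d);  // prints: 0 2
  }
  std::printf("\n");
  return 0;
}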

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
@@ -103,63 +121,61 @@ default_elementwise_div_grad(const framework::ExecutionContext& ctx,
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
// dx
if (dx != nullptr) {
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
framework::Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
if (dx != nullptr && dy != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
if (dx->dims() == dout->dims()) {
// dx = dout/y
ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
ctx, dout, y, axis, DivGradFunctor<T>(), dx);
// dout.dims==out.dims
std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&tmp_dx, &tmp_dy};
auto functor = DivGradXYFunctor<T, T>();
LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kTernary, T, T,
decltype(functor), 2>(
dev_ctx, ins, &outs, axis, functor);

if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
dx->ShareDataWith(tmp_dx);
dy->ShareDataWith(tmp_dy);
} else {
framework::Tensor tmp_dx;
tmp_dx.Resize(dout->dims());

ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
ctx, dout, y, axis, DivGradFunctor<T>(), &tmp_dx);

std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
tmp_dx, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
reduce_functor<T>(ctx, x, out, &tmp_dx, dx);
reduce_functor<T>(ctx, y, out, &tmp_dy, dy);
}
}
// dy
if (dy != nullptr) {
} else if (dx != nullptr && dy == nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<const framework::Tensor*> ins = {dout, y};
std::vector<framework::Tensor*> outs = {&tmp_dx};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, axis, DivGradFunctor<T>());
if (dx->dims() != dout->dims()) {
reduce_functor<T>(ctx, x, out, &tmp_dx, dx);
} else {
dx->ShareDataWith(tmp_dx);
}
} else if (dy != nullptr && dx == nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
// dy = - dout * out / y
auto size = dy->numel();
dim3 grid_size = dim3(
(size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
SimpleElemwiseDivGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
nullptr, dy->mutable_data<T>(ctx.GetPlace()));
}
std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&tmp_dy};
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
dev_ctx, ins, &outs, axis, DivGradYFunctor<T>());
if (dy->dims() != dout->dims()) {
reduce_functor<T>(ctx, y, out, &tmp_dy, dy);
} else {
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());

std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&tmp_dy};

const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
dev_ctx, ins, &outs, axis, DivGradYFunctor<T>());

std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
tmp_dy, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
dy->ShareDataWith(tmp_dy);
}
}
}
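
The net effect of the dx-and-dy branch above is that both gradients are produced in a single pass over dout instead of two separate launches. A minimal standalone CUDA sketch of that idea (hypothetical names, no broadcasting, not Paddle's launcher):

// Minimal sketch: one grid-stride pass fills both gradient buffers.
// Assumes all tensors share the same length n (no broadcasting).
#include <cuda_runtime.h>

template <typename T>
__global__ void DivGradBothKernel(const T* dout, const T* out, const T* y,
                                  T* dx, T* dy, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    T o = dout[i];
    dx[i] = o / y[i];            // dx = dout / y
    dy[i] = -o * out[i] / y[i];  // dy = -dout * out / y
  }
}

// Example launch (device allocation and error handling omitted):
// int threads = 256, blocks = (n + threads - 1) / threads;
// DivGradBothKernel<<<blocks, threads>>>(dout, out, y, dx, dy, n);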
35 changes: 32 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
@@ -131,6 +132,34 @@ struct MulGradFunctor<paddle::platform::complex<T>> {
}
};

template <typename InT, typename OutT>
struct DivGradXYFunctor {
paddle::framework::Array<OutT, 2> outs;
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(InT a, InT b,
InT c) {
// dx = dout / y
// dy = - dout * out / y
outs[0] = a / c;
outs[1] = -a * b / c;
return outs;
}
};

template <typename InT, typename OutT>
struct DivGradXYFunctor<paddle::platform::complex<InT>,
paddle::platform::complex<OutT>> {
paddle::framework::Array<paddle::platform::complex<OutT>, 2> outs;
inline HOSTDEVICE paddle::framework::Array<paddle::platform::complex<OutT>, 2>
operator()(paddle::platform::complex<InT> a, paddle::platform::complex<InT> b,
paddle::platform::complex<InT> c) {
paddle::platform::complex<InT> c_conj(c.real, -c.imag);
paddle::platform::complex<InT> out_div_y_conj((b / c).real, -(b / c).imag);
outs[0] = a / c_conj;
outs[1] = -a * out_div_y_conj;
return outs;
}
};

// Float div grad
template <typename T>
struct DivGradFunctor {
@@ -152,7 +181,7 @@ struct DivGradFunctor<paddle::platform::complex<T>> {
template <typename T>
struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
return a * b / c;
return -a * b / c;
}
};

@@ -163,8 +192,8 @@ struct DivGradYFunctor<paddle::platform::complex<T>> {
const paddle::platform::complex<T>& a,
const paddle::platform::complex<T>& b,
const paddle::platform::complex<T>& c) const {
paddle::platform::complex<T> c_conj(c.real, -c.imag);
return a * b / c_conj;
paddle::platform::complex<T> out_div_y_conj((b / c).real, -(b / c).imag);
return -a * out_div_y_conj;
}
};

5 changes: 3 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -162,7 +162,8 @@ struct DimensionsTransform {
}
};

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1>
void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
@@ -190,7 +191,7 @@ void LaunchBroadcastElementwiseCudaKernel(
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(
ctx, pt_inputs, &pt_outputs, axis, func);
}

