-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implementation of broadcast div backward by reduce #38044
Changes from 1 commit
d3173f8
c6cef2e
9265a8d
080bf95
f0f1cf3
8c43581
b1f58dc
e07e54e
3594f6b
7adf371
560ed45
2920824
476c797
8259c34
d2f3776
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ limitations under the License. */ | |
|
||
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" | ||
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" | ||
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" | ||
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" | ||
#include "paddle/fluid/platform/complex.h" | ||
#include "paddle/fluid/platform/float16.h" | ||
|
||
|
@@ -23,6 +25,34 @@ namespace plat = paddle::platform; | |
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
struct MulDxDyFunctor { | ||
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } | ||
}; | ||
template <typename T> | ||
struct MulDxDyFunctor<paddle::platform::complex<T>> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接叫MulFunctor和DivFunctor不行吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 名字修改为MulGradFunctor和DivGradFunctor |
||
inline HOSTDEVICE paddle::platform::complex<T> operator()( | ||
const paddle::platform::complex<T>& x, | ||
const paddle::platform::complex<T>& y) const { | ||
paddle::platform::complex<T> y_conj(y.real, -y.imag); | ||
return x * y_conj; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct DivDxDyFunctor { | ||
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } | ||
}; | ||
template <typename T> | ||
struct DivDxDyFunctor<paddle::platform::complex<T>> { | ||
inline HOSTDEVICE paddle::platform::complex<T> operator()( | ||
const paddle::platform::complex<T>& x, | ||
const paddle::platform::complex<T>& y) const { | ||
paddle::platform::complex<T> y_conj(y.real, -y.imag); | ||
return x / y_conj; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, | ||
const T* out, | ||
|
@@ -33,7 +63,9 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, | |
|
||
while (col < size) { | ||
T o = dout[col]; | ||
dx[col] = o / y[col]; | ||
if (dx != nullptr) { | ||
dx[col] = o / y[col]; | ||
} | ||
dy[col] = -o * out[col] / y[col]; | ||
col += blockDim.x * gridDim.x; | ||
} | ||
|
@@ -55,7 +87,9 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>( | |
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag); | ||
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real, | ||
-(out[col] / y[col]).imag); | ||
dx[col] = o / y_conj; | ||
if (dx != nullptr) { | ||
dx[col] = o / y_conj; | ||
} | ||
dy[col] = -o * out_div_y_conj; | ||
col += blockDim.x * gridDim.x; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这种写法可以修改成为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
} | ||
|
@@ -77,12 +111,85 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>( | |
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag); | ||
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real, | ||
-(out[col] / y[col]).imag); | ||
dx[col] = o / y_conj; | ||
if (dx != nullptr) { | ||
dx[col] = o / y_conj; | ||
} | ||
dy[col] = -o * out_div_y_conj; | ||
col += blockDim.x * gridDim.x; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
} | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy) { | ||
int axis = ctx.Attr<int>("axis"); | ||
auto* dout_data = dout->data<T>(); | ||
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. block_size 定义了但没有被使用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经删掉 |
||
// dx | ||
if (dx != nullptr) { | ||
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mutable_data的结果不必传给指针(下文没用到指针),下同 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (dx->dims() == dout->dims()) { | ||
// dx = dout/y | ||
ElementwiseComputeEx<DivDxDyFunctor<T>, DeviceContext, T>( | ||
ctx, dout, y, axis, DivDxDyFunctor<T>(), dx); | ||
} else { | ||
// For inplace strategy, dx will be stored in addr of dout, which makes | ||
// the result of dy wrong. | ||
if (dx->IsSharedBufferWith(*dout)) { | ||
dx->clear(); | ||
dx->mutable_data<T>(x->dims(), ctx.GetPlace()); | ||
} | ||
framework::Tensor div_dx; | ||
div_dx.Resize(dout->dims()); | ||
|
||
ElementwiseComputeEx<DivDxDyFunctor<T>, DeviceContext, T>( | ||
ctx, dout, y, axis, DivDxDyFunctor<T>(), &div_dx); | ||
|
||
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis); | ||
gpuStream_t stream = ctx.cuda_device_context().stream(); | ||
TensorReduceFunctorImpl<T, T, CustomSum>(div_dx, dx, reduce_dims, stream); | ||
} | ||
} | ||
// dy | ||
if (dy != nullptr) { | ||
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace()); | ||
if (dy->dims() == dout->dims()) { | ||
if (dy_data != dout_data) { | ||
// dy = - dout * out / y | ||
auto size = dy->numel(); | ||
dim3 grid_size = dim3( | ||
(size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); | ||
SimpleElemwiseDivGradCUDAKernel<T><<< | ||
grid_size, block_size, 0, | ||
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>( | ||
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size, | ||
nullptr, dy->mutable_data<T>(ctx.GetPlace())); | ||
} | ||
} else { | ||
// dy = - dout * out / y | ||
framework::Tensor mul_dy; | ||
mul_dy.Resize(dout->dims()); | ||
ElementwiseComputeEx<MulDxDyFunctor<T>, DeviceContext, T>( | ||
ctx, dout, out, axis, MulDxDyFunctor<T>(), &mul_dy); | ||
|
||
framework::Tensor div_dy; | ||
div_dy.Resize(dout->dims()); | ||
ElementwiseComputeEx<DivDxDyFunctor<T>, DeviceContext, T>( | ||
ctx, &mul_dy, y, axis, DivDxDyFunctor<T>(), &div_dy); | ||
|
||
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis); | ||
gpuStream_t stream = ctx.cuda_device_context().stream(); | ||
TensorReduceFunctorImpl<T, T, CustomSub>(div_dy, dy, reduce_dims, stream); | ||
} | ||
} | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,6 +109,21 @@ struct DivDoubleDY { | |
} | ||
}; | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy) { | ||
int axis = ctx.Attr<int>("axis"); | ||
|
||
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
|
@@ -117,13 +132,21 @@ elementwise_div_grad(const framework::ExecutionContext& ctx, | |
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy) { | ||
int axis = ctx.Attr<int>("axis"); | ||
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); | ||
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
} | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
template <typename DeviceContext, typename T> | ||
// cuda definition | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy); | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
|
@@ -147,14 +170,12 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> { | |
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
int axis = ctx.Attr<int>("axis"); | ||
|
||
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DefaultElementwiseDivGrad已经包括这个分支了,可以删除 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
} else { | ||
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), | ||
DivGradDY<T>()); | ||
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. default也改个名字吧,比如改成Common,或者其他更好的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后续会统一修改 |
||
dy); | ||
} | ||
} | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件已经删除