
fix gpu code
a162837 committed Nov 12, 2024
1 parent eb13906 commit 58480f5
Showing 2 changed files with 51 additions and 12 deletions.
24 changes: 16 additions & 8 deletions paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -23,9 +23,10 @@
 namespace phi {
 
 template <typename T>
-class ClipWithTensorGradFunctor {
-  inline HOSTDEVICE T operator()(const T x, const T y, const T min_, const T max_) const {
-    return (y > min_ && y < max_) ? x : static_cast<T>(0);
+__global__ void ClipWithTensorGradFunctor(const int N, const T* out_grad, const T* x, const T* min, const T* max, T* x_grad) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    x_grad[idx] = (x[idx] > min[idx]) && (x[idx] < max[idx]) ? out_grad[idx] : static_cast<T>(0);
   }
 };

@@ -37,11 +37,18 @@ void ClipWithTensorGradKernel(const Context& dev_ctx,
                               const DenseTensor& out_grad,
                               DenseTensor* x_grad) {
 
-  std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
-  std::vector<DenseTensor*> outs = {x_grad};
-  ClipWithTensorGradFunctor<T> func;
-  dev_ctx.template Alloc<T>(x_grad);
-  phi::funcs::ElementwiseKernel<T, ClipWithTensorGradFunctor<T>, 1>(dev_ctx, ins, &outs, func);
+  const T* x_data = x.data<T>();
+  auto numel = x.numel();
+  const T* min_data = min.data<T>();
+  const T* max_data = max.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+
+  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+
+  auto stream = dev_ctx.stream();
+  auto config = backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
+  ClipWithTensorGradFunctor<T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
+      numel, out_grad_data, x_data, min_data, max_data, x_grad_data);
 }
 
 }
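Note: the rewritten kernel is a standard grid-stride loop, so one launch covers any tensor size regardless of grid shape. Below is a minimal self-contained sketch of the same pattern. The kernel body mirrors the diff; the host-side setup (managed memory, a fixed 256-thread launch shape, the float specialization, the ClipGrad name) is illustrative only, since the committed code derives its launch configuration from backends::gpu::GetGpuLaunchConfig1D.

// Hedged sketch of the grid-stride clip-grad pattern, standalone.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void ClipGrad(const int n, const T* out_grad, const T* x,
                         const T* min, const T* max, T* x_grad) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  // Each thread strides over the array so any grid size covers all n elements.
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    // Gradient passes through only where x lies strictly inside (min, max).
    x_grad[idx] = (x[idx] > min[idx] && x[idx] < max[idx])
                      ? out_grad[idx]
                      : static_cast<T>(0);
  }
}

int main() {
  const int n = 1 << 20;
  float *out_grad, *x, *min, *max, *x_grad;
  cudaMallocManaged(&out_grad, n * sizeof(float));
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&min, n * sizeof(float));
  cudaMallocManaged(&max, n * sizeof(float));
  cudaMallocManaged(&x_grad, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    out_grad[i] = 1.0f;
    x[i] = static_cast<float>(i % 10);  // values 0..9
    min[i] = 2.0f;
    max[i] = 7.0f;
  }
  // Fixed launch shape for the sketch; GetGpuLaunchConfig1D picks this
  // per-device in the committed kernel.
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  ClipGrad<float><<<blocks, threads>>>(n, out_grad, x, min, max, x_grad);
  cudaDeviceSynchronize();
  std::printf("x_grad[3] = %f, x_grad[9] = %f\n", x_grad[3], x_grad[9]);  // 1, 0
  return 0;
}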
39 changes: 35 additions & 4 deletions paddle/phi/kernels/xpu/clip_kernel.cc
@@ -56,16 +56,47 @@ void ClipWithTensorKernel(const Context& dev_ctx,
                           const DenseTensor& min,
                           const DenseTensor& max,
                           DenseTensor* out) {
-  dev_ctx.template Alloc<T>(out);
+  using XPUDataType = typename XPUTypeTrait<T>::Type;
+
+  const XPUDataType* x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
+  const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
+  const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
+  XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));
+
+  auto min_dims = common::vectorize<int>(min.dims());
+  if (min_dims.size() == 0) {
+    min_dims = std::vector<int>({1});
+  }
+  auto max_dims = common::vectorize<int>(max.dims());
+  if (max_dims.size() == 0) {
+    max_dims = std::vector<int>({1});
+  }
 
   DenseTensor min_tensor(phi::DataType::BOOL);
   LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);
-  WhereKernel<T, Context>(dev_ctx, min_tensor, min, x, out);
+
+  auto min_tensor_dims = common::vectorize<int>(min_tensor.dims());
+  if (min_tensor_dims.size() == 0) {
+    min_tensor_dims = std::vector<int>({1});
+  }
+
+  const bool* min_tensor_data = min_tensor.data<bool>();
+  int ret = xpu::select(
+      dev_ctx.x_context(), min_tensor_data, min_data, x_data, out_data, min_tensor_dims, min_dims);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
 
   DenseTensor max_tensor(phi::DataType::BOOL);
   LessThanKernel<T, Context>(dev_ctx, max, x, &max_tensor);
-  WhereKernel<T, Context>(dev_ctx, max_tensor, max, x, out);
+
+  auto max_tensor_dims = common::vectorize<int>(max_tensor.dims());
+  if (max_tensor_dims.size() == 0) {
+    max_tensor_dims = std::vector<int>({1});
+  }
+
+  const bool* max_tensor_data = max_tensor.data<bool>();
+  int ret2 = xpu::select(
+      dev_ctx.x_context(), max_tensor_data, max_data, x_data, out_data, max_tensor_dims, max_dims);
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");
 }
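Note: each LessThanKernel + xpu::select pair above implements one elementwise select (choose the bound where the comparison holds, keep x otherwise). As a reference for the intended clip-by-tensor semantics, first raise x to the elementwise lower bound, then cap the result at the upper bound, here is a hedged plain-C++ sketch; the std::vector helper and the ClipWithTensor name are stand-ins, not Paddle API.

// Hedged CPU sketch of the elementwise semantics the two select passes target:
// v = (x < min) ? min : x, then out = (max < v) ? max : v.
#include <cstdio>
#include <vector>

template <typename T>
std::vector<T> ClipWithTensor(const std::vector<T>& x,
                              const std::vector<T>& min,
                              const std::vector<T>& max) {
  std::vector<T> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    // Pass 1: raise values below the elementwise lower bound.
    T v = (x[i] < min[i]) ? min[i] : x[i];
    // Pass 2: cap values above the elementwise upper bound.
    out[i] = (max[i] < v) ? max[i] : v;
  }
  return out;
}

int main() {
  std::vector<float> x{-1.0f, 0.5f, 3.0f};
  std::vector<float> lo{0.0f, 0.0f, 0.0f};
  std::vector<float> hi{1.0f, 1.0f, 1.0f};
  auto out = ClipWithTensor(x, lo, hi);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // 0 0.5 1
  return 0;
}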

