Skip to content

Commit

Permalink
add xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
a162837 committed Nov 25, 2024
1 parent 7b543b7 commit 0d8a99c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 84 deletions.
52 changes: 26 additions & 26 deletions paddle/phi/kernels/xpu/clip_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,25 @@ void ClipGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
}

// template <typename T, typename Context>
// void ClipTensorGradKernel(const Context& dev_ctx,
// const DenseTensor& x,
// const DenseTensor& min,
// const DenseTensor& max,
// const DenseTensor& out_grad,
// DenseTensor* x_grad) {
// dev_ctx.template Alloc<T>(x_grad);
// Backward of clip with tensor bounds:
//   x_grad[i] = out_grad[i] if the element was NOT clipped, else 0.
// The "not clipped" mask is built from two elementwise comparisons below.
template <typename T, typename Context>
void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
// Allocate the gradient output before any kernel writes into it.
dev_ctx.template Alloc<T>(x_grad);

// DenseTensor min_tensor(phi::DataType::BOOL);
// DenseTensor max_tensor(phi::DataType::BOOL);
// LessThanKernel<T, Context>(dev_ctx, min, x, &min_tensor);
// LessThanKernel<T, Context>(dev_ctx, x, max, &max_tensor);
// DenseTensor out(phi::DataType::BOOL);
// EqualKernel<T, Context>(dev_ctx, min_tensor, max_tensor, &out);
// DenseTensor zero_tensor(x_grad->dtype());
// FullKernel<T, Context>(dev_ctx, common::vectorize(x_grad->dims()), 0.0f, zero_tensor.dtype(), &zero_tensor);
// WhereKernel<T, Context>(dev_ctx, out, out_grad, zero_tensor, x_grad);
// }
// min_tensor[i] = (min[i] < x[i]); max_tensor[i] = (x[i] < max[i]).
DenseTensor min_tensor(phi::DataType::BOOL);
DenseTensor max_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, min, x, &min_tensor);
LessThanKernel<T, Context>(dev_ctx, x, max, &max_tensor);
// NOTE(review): EqualKernel on the two bool masks is used as a stand-in for
// logical AND: (a == b) equals (a && b) only when the both-false case cannot
// occur. Here both masks are false only if x <= min AND max <= x, which with
// the usual precondition min <= max forces x == min == max — and in that
// degenerate case the gradient is passed through. Confirm this boundary
// behavior matches the reference (CPU/GPU) clip_tensor_grad implementation;
// a LogicalAnd kernel would express the intent more directly.
DenseTensor out(phi::DataType::BOOL);
EqualKernel<T, Context>(dev_ctx, min_tensor, max_tensor, &out);
// zero_tensor: all-zeros with x_grad's dtype/shape, used as the "clipped" branch.
DenseTensor zero_tensor(x_grad->dtype());
FullKernel<T, Context>(dev_ctx, common::vectorize(x_grad->dims()), 0.0f, zero_tensor.dtype(), &zero_tensor);
// x_grad = mask ? out_grad : 0.
WhereKernel<T, Context>(dev_ctx, out, out_grad, zero_tensor, x_grad);
}
} // namespace phi

PD_REGISTER_KERNEL(clip_grad,
Expand All @@ -74,11 +74,11 @@ PD_REGISTER_KERNEL(clip_grad,
int64_t,
int) {}

// PD_REGISTER_KERNEL(clip_tensor_grad,
// XPU,
// ALL_LAYOUT,
// phi::ClipTensorGradKernel,
// float,
// phi::dtype::float16,
// int64_t,
// int) {}
// Registers the tensor-bounds clip backward kernel for the XPU backend.
// NOTE(review): the forward clip_tensor registration also lists
// phi::dtype::bfloat16, but it is absent here — confirm whether
// ClipTensorGradKernel intentionally omits bfloat16 support on XPU.
PD_REGISTER_KERNEL(clip_tensor_grad,
XPU,
ALL_LAYOUT,
phi::ClipTensorGradKernel,
float,
phi::dtype::float16,
int64_t,
int) {}
116 changes: 58 additions & 58 deletions paddle/phi/kernels/xpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,55 +50,55 @@ void ClipKernel(const Context& dev_ctx,
XPUAPIErrorMsg[r]));
}

// template <typename T, typename Context>
// void ClipTensorKernel(const Context& dev_ctx,
// const DenseTensor& x,
// const DenseTensor& min,
// const DenseTensor& max,
// DenseTensor* out) {
// using XPUDataType = typename XPUTypeTrait<T>::Type;
// const XPUDataType* x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
// const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
// const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
// XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));

// auto min_dims = common::vectorize<int>(min.dims());
// if (min_dims.size() == 0) {
// min_dims = std::vector<int>({1});
// }
// auto max_dims = common::vectorize<int>(max.dims());
// if (max_dims.size() == 0) {
// max_dims = std::vector<int>({1});
// }

// DenseTensor min_tensor(phi::DataType::BOOL);
// LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);

// auto min_tensor_dims = common::vectorize<int>(min_tensor.dims());
// if (min_tensor_dims.size() == 0) {
// min_tensor_dims = std::vector<int>({1});
// }

// const bool* min_tensor_data = min_tensor.data<bool>();
// int ret = xpu::select(
// dev_ctx.x_context(), min_tensor_data, min_data, x_data, out_data, min_tensor_dims, min_dims);

// PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");

// DenseTensor max_tensor(phi::DataType::BOOL);
// LessThanKernel<T, Context>(dev_ctx, max, x, &max_tensor);

// auto max_tensor_dims = common::vectorize<int>(max_tensor.dims());
// if (max_tensor_dims.size() == 0) {
// max_tensor_dims = std::vector<int>({1});
// }

// const bool* max_tensor_data = max_tensor.data<bool>();
// int ret2 = xpu::select(
// dev_ctx.x_context(), max_tensor_data, max_data, x_data, out_data, max_tensor_dims, max_dims);
// PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");

// }
// Elementwise clip of `x` against per-element tensor bounds:
//   out[i] = min[i] if x[i] < min[i]
//          = max[i] if max[i] < x[i]
//          = x[i]   otherwise
// Implemented as two xpu::select passes: clamp from below into `out`,
// then clamp from above in place on `out`.
template <typename T, typename Context>
void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
using XPUDataType = typename XPUTypeTrait<T>::Type;
const XPUDataType* x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
const XPUDataType* min_data = reinterpret_cast<const XPUDataType*>(min.data<T>());
const XPUDataType* max_data = reinterpret_cast<const XPUDataType*>(max.data<T>());
XPUDataType* out_data = reinterpret_cast<XPUDataType*>(dev_ctx.template Alloc<T>(out));

// Promote 0-D (scalar) bounds to shape {1} so xpu::select gets non-empty dims.
auto min_dims = common::vectorize<int>(min.dims());
if (min_dims.size() == 0) {
min_dims = std::vector<int>({1});
}
auto max_dims = common::vectorize<int>(max.dims());
if (max_dims.size() == 0) {
max_dims = std::vector<int>({1});
}

// Pass 1: out = (x < min) ? min : x  — clamp from below.
DenseTensor min_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, x, min, &min_tensor);

auto min_tensor_dims = common::vectorize<int>(min_tensor.dims());
if (min_tensor_dims.size() == 0) {
min_tensor_dims = std::vector<int>({1});
}

const bool* min_tensor_data = min_tensor.data<bool>();
int ret = xpu::select(
dev_ctx.x_context(), min_tensor_data, min_data, x_data, out_data, min_tensor_dims, min_dims);

PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");

// Pass 2: out = (max < x) ? max : out — clamp from above.
// BUG FIX: the false branch must read the pass-1 result (out_data), not the
// raw input x_data. With x_data, an element already clamped to min[i] was
// overwritten with the original x[i] again (for x[i] < min[i] <= max[i] the
// condition max < x is false, so the old code emitted x[i] instead of min[i]).
// Comparing `max` against the original `x` is still correct: whenever
// max[i] < x[i], pass 1 left out[i] == x[i] unchanged.
// NOTE(review): this relies on xpu::select tolerating out_data aliasing one
// of its inputs for this elementwise op — confirm against the XDNN API docs;
// if aliasing is unsupported, stage the pass-1 result in a temp tensor.
DenseTensor max_tensor(phi::DataType::BOOL);
LessThanKernel<T, Context>(dev_ctx, max, x, &max_tensor);

auto max_tensor_dims = common::vectorize<int>(max_tensor.dims());
if (max_tensor_dims.size() == 0) {
max_tensor_dims = std::vector<int>({1});
}

const bool* max_tensor_data = max_tensor.data<bool>();
int ret2 = xpu::select(
dev_ctx.x_context(), max_tensor_data, max_data, out_data, out_data, max_tensor_dims, max_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret2, "xpu::select");
}

} // namespace phi

Expand All @@ -112,12 +112,12 @@ PD_REGISTER_KERNEL(clip,
int64_t,
int) {}

// PD_REGISTER_KERNEL(clip_tensor,
// XPU,
// ALL_LAYOUT,
// phi::ClipTensorKernel,
// float,
// phi::dtype::float16,
// phi::dtype::bfloat16,
// int64_t,
// int) {}
// Registers the tensor-bounds clip forward kernel for the XPU backend.
// Supported dtypes mirror the scalar clip registration plus bfloat16.
PD_REGISTER_KERNEL(clip_tensor,
XPU,
ALL_LAYOUT,
phi::ClipTensorKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {}

0 comments on commit 0d8a99c

Please sign in to comment.