
Commit

update grad
a162837 committed Nov 11, 2024
1 parent aae8c04 commit 94f71c6
Showing 7 changed files with 23 additions and 10 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/clip_grad_kernel.h
@@ -31,8 +31,8 @@ void ClipGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ClipWithTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
-const DenseTensor& out_grad,
const DenseTensor& min,
const DenseTensor& max,
+const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/clip_grad_kernel.cc
@@ -24,9 +24,9 @@ template <typename T, typename Context>
void ClipWithTensorGradKernel(const Context& ctx,
const Context& dev_ctx,
const DenseTensor& x,
-const DenseTensor& out_grad,
const DenseTensor& min,
const DenseTensor& max,
+const DenseTensor& out_grad,
DenseTensor* x_grad) {
const T* x_data = x.data<T>();
const T* min_data = min.data<T>();
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/clip_kernel.cc
@@ -26,15 +26,15 @@ void ClipWithTensorKernel(const Context& ctx,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
-const T* x_data = x.data<bool>();
+const T* x_data = x.data<T>();
const T* min_data = min.data<T>();
const T* max_data = max.data<T>();
auto x_numel = x.numel();

T* out_data = ctx.template Alloc<T>(out);

for (int i = 0; i < x_numel; i++) {
-out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x;
+out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
}
}

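As a readability note, the corrected loop above clips every element of x to the per-element range [min_data[i], max_data[i]]. A minimal standalone sketch of the same logic, plain C++ with hypothetical names and none of the phi kernel machinery, would be:

#include <cstddef>

// Element-wise clip against tensor-valued bounds; mirrors the nested ternary in
// ClipWithTensorKernel: values below min[i] are raised to min[i], values above
// max[i] are lowered to max[i], everything else passes through unchanged.
template <typename T>
void ClipWithTensorRef(const T* x, const T* min_v, const T* max_v, T* out,
                       std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    if (x[i] < min_v[i]) {
      out[i] = min_v[i];
    } else if (x[i] > max_v[i]) {
      out[i] = max_v[i];
    } else {
      out[i] = x[i];
    }
  }
}
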
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/clip_grad_kernel.cu
@@ -32,9 +32,9 @@ class ClipWithTensorGradFunctor {
template <typename T, typename Context>
void ClipWithTensorGradKernel(const Context& dev_ctx,
const DenseTensor& x,
-const DenseTensor& out_grad,
const DenseTensor& min,
const DenseTensor& max,
+const DenseTensor& out_grad,
DenseTensor* x_grad) {

std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
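For context, this hunk collects {out_grad, x, min, max} into ins, and the hunk header mentions a ClipWithTensorGradFunctor, so the truncated part presumably launches an element-wise/broadcast kernel with that functor. A hedged sketch of what such a gradient functor typically computes, written as plain C++ with assumed names and an assumed boundary convention rather than the actual Paddle code:

// Hypothetical gradient rule for clip with tensor bounds: the incoming gradient
// passes through where x lies strictly inside (min, max) and is zero elsewhere.
// Whether the endpoints are included depends on the kernel's convention, and a
// CUDA build would mark operator() with Paddle's HOSTDEVICE macro.
template <typename T>
struct ClipWithTensorGradSketch {
  T operator()(T out_grad, T x, T min_v, T max_v) const {
    return (x > min_v && x < max_v) ? out_grad : static_cast<T>(0);
  }
};
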
15 changes: 15 additions & 0 deletions paddle/phi/kernels/xpu/clip_kernel.cc
@@ -47,6 +47,21 @@ void ClipKernel(const Context& dev_ctx,
XPUAPIErrorMsg[r]));
}

+template <typename T, typename Context>
+void ClipWithTensorKernel(const Context& ctx,
+const DenseTensor& x,
+const DenseTensor& min,
+const DenseTensor& max,
+DenseTensor* out) {
+dev_ctx.template Alloc<T>(out);
+using XPUDataType = typename XPUTypeTrait<T>::Type;
+
+int ret = xpu::select(
+ctx.x_context(), cond_data, min_data, x_data, out_data, cond_dims, x_dims);
+
+PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
+}

} // namespace phi

PD_REGISTER_KERNEL(clip,
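The new XPU kernel above is built around xpu::select, which (judging from its arguments here) appears to act as a where-like primitive that picks one of two inputs per element based on a condition. As an illustration of that pattern only, a two-sided clip can be composed from two select passes, one per bound; the plain C++ stand-in below uses hypothetical names and is not the XDNN API:

#include <cstddef>

// Two-sided clip expressed as two "select/where" steps:
//   tmp = select(x < min, min, x)      raise values below the lower bound
//   out = select(tmp > max, max, tmp)  lower values above the upper bound
template <typename T>
void ClipViaTwoSelects(const T* x, const T* min_v, const T* max_v, T* out,
                       std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    T tmp = (x[i] < min_v[i]) ? min_v[i] : x[i];
    out[i] = (tmp > max_v[i]) ? max_v[i] : tmp;
  }
}
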
6 changes: 3 additions & 3 deletions paddle/phi/ops/yaml/backward.yaml
@@ -403,8 +403,8 @@
inplace : (out_grad -> x_grad)

- backward_op : clipwithtensor_double_grad
-forward : clipwithtensor_grad (Tensor x, Tensor grad_out, Tensor min, Tensor max) -> Tensor(grad_x)
-args : (Tensor x, Tensor grad_x_grad, Tensor min, Tensor max)
+forward : clipwithtensor_grad (Tensor x, Tensor min, Tensor max, Tensor grad_out) -> Tensor(grad_x)
+args : (Tensor x, Tensor min, Tensor max, Tensor grad_x_grad)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
@@ -415,7 +415,7 @@

- backward_op : clipwithtensor_grad
forward : clipwithtensor (Tensor x, Tensor min, Tensor max) -> Tensor(out)
-args : (Tensor x, Tensor out_grad, Tensor min, Tensor)
+args : (Tensor x, Tensor min, Tensor max, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
2 changes: 0 additions & 2 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -610,8 +610,6 @@
max : Max
outputs :
out : Out
-extra :
-attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

- op : clip_by_norm
inputs :
