[cherry-pick] add pad and concat double grad #29549 (#30432)
* add pad and concat double grad

* resolve conflict
ceci3 authored Jan 18, 2021
1 parent de003ce commit 5e4d54a
Showing 4 changed files with 115 additions and 1 deletion.
16 changes: 16 additions & 0 deletions paddle/fluid/operators/concat_op.cc
@@ -201,6 +201,20 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};

template <typename T>
class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("concat");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
}
};

} // namespace operators
} // namespace paddle

@@ -209,6 +223,8 @@ REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::ConcatOpGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
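What the new maker encodes: the backward of concat splits the output gradient back into per-input gradients, and differentiating that split again is itself a concat, so ConcatDoubleGradOpMaker simply emits another "concat" op over the incoming gradients. A minimal dygraph sketch of the behaviour this enables, assuming a Paddle 2.x build that already contains this patch (variable names are illustrative):

import paddle

# Two leaf tensors that require gradients.
x1 = paddle.rand([2, 3])
x1.stop_gradient = False
x2 = paddle.rand([2, 3])
x2.stop_gradient = False

out = paddle.concat([x1, x2], axis=0)
loss = paddle.sum(out * out)

# First-order gradients; create_graph=True keeps the backward graph so the
# concat_grad (split) op can itself be differentiated.
dx1, dx2 = paddle.grad([loss], [x1, x2], create_graph=True)

# Second-order gradient: differentiating through concat_grad relies on the
# double-grad maker registered above. Since d(loss)/dx1 == 2 * x1, ddx1
# should come out as a tensor of 2.0s.
(ddx1,) = paddle.grad([paddle.sum(dx1)], [x1])
print(ddx1.numpy())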
18 changes: 18 additions & 0 deletions paddle/fluid/operators/pad3d_op.cc
@@ -893,6 +893,22 @@ class Pad3dOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};

template <typename T>
class Pad3dOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> grad_op) const override {
if (this->HasInput("Paddings")) {
grad_op->SetInput("Paddings", this->Input("Paddings"));
}
grad_op->SetType("pad3d");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
}
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad3dOpGradNoNeedBufferVarsInferer, "X");

} // namespace operators
@@ -904,6 +920,8 @@ REGISTER_OPERATOR(pad3d, ops::Pad3dOp, ops::Pad3dOpMaker,
ops::Pad3dOpGradMaker<paddle::framework::OpDesc>,
ops::Pad3dOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pad3d_grad, ops::Pad3dOpGrad,
ops::Pad3dOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::Pad3dOpDoubleGradMaker<paddle::imperative::OpBase>,
ops::Pad3dOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(pad3d, ops::Pad3dCPUKernel<float>,
ops::Pad3dCPUKernel<double>, ops::Pad3dCPUKernel<int>,
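For pad3d the relationship is analogous: the backward of a constant pad3d crops the output gradient, and differentiating that crop again pads it back out, so Pad3dOpDoubleGradMaker re-emits a "pad3d" op (forwarding the Paddings tensor input when the original op had one). A minimal dygraph sketch, under the assumption that paddle.nn.functional.pad with a 4-element pad list on a 4-D tensor lowers to pad3d in this Paddle version:

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 3, 4, 5])
x.stop_gradient = False

# Constant (zero) padding of the two spatial dims; assumed to lower to pad3d.
y = F.pad(x, [1, 1, 1, 1], mode='constant', value=0.0)
loss = paddle.sum(y * y)

# First-order gradient; keep the graph so pad3d_grad can be differentiated.
(dx,) = paddle.grad([loss], [x], create_graph=True)

# Second-order gradient: the grad of pad3d_grad is pad3d again, supplied by
# the maker above. The padded border is zeros, so dx == 2 * x and ddx is a
# tensor of 2.0s with the shape of x.
(ddx,) = paddle.grad([paddle.sum(dx)], [x])
print(ddx.shape, ddx.numpy().mean())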
17 changes: 16 additions & 1 deletion paddle/fluid/operators/pad_op.cc
@@ -142,6 +142,19 @@ class PadOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};

template <typename T>
class PadOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("pad");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
}
};

} // namespace operators
} // namespace paddle

@@ -150,7 +163,9 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
ops::PadOpGradMaker<paddle::framework::OpDesc>,
ops::PadOpGradMaker<paddle::imperative::OpBase>);
- REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
+ REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
+                   ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
+                   ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
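The same check applies to the plain pad operator. Assuming that paddle.nn.functional.pad with mode='constant' and a pad list of length 2 * x.ndim dispatches to pad rather than pad3d (which is how the functional API routed around this release, and which the new unit tests below also exercise), the sketch above carries over directly:

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 3, 4, 5])
x.stop_gradient = False

# A length-8 pad list with constant mode is assumed to hit the plain pad op.
y = F.pad(x, [1, 0, 1, 0, 1, 0, 1, 0], mode='constant', value=0.0)
loss = paddle.sum(y * y)

(dx,) = paddle.grad([loss], [x], create_graph=True)
# Differentiating pad_grad again uses PadOpDoubleGradMaker; ddx is all 2.0s.
(ddx,) = paddle.grad([paddle.sum(dx)], [x])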
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -373,5 +373,70 @@ def test_grad(self):
self.func(p)


class TestConstantPadDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
x_shape = [2, 3, 4, 5]
pad = [1, 1, 1, 1]
eps = 0.005
dtype = np.float64

x = layers.data('x', x_shape, False, dtype)
x.persistable = True
out = paddle.nn.functional.pad(x, pad)
x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

gradient_checker.double_grad_check(
[x], out, x_init=x_arr, place=place, eps=eps)

def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)


class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck):
@prog_scope()
def func(self, place):
x_shape = [2, 3, 4, 5]
pad = [1, 0, 1, 0, 1, 0, 1, 0]
dtype = np.float64

x = layers.data('x', x_shape, False, dtype)
x.persistable = True
out = paddle.nn.functional.pad(x, pad)
x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place)


class TestConcatDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
x_shape = [2, 3, 4, 5]
dtype = np.float64

        x1 = layers.data('x1', x_shape, False, dtype)
        x2 = layers.data('x2', x_shape, False, dtype)
x1.persistable = True
x2.persistable = True
out = paddle.concat([x1, x2], axis=0)
x2_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
x1_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)

gradient_checker.double_grad_check(
[x1, x2], out, x_init=[x1_arr, x2_arr], place=place)

def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)


if __name__ == "__main__":
unittest.main()
