Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【prim】Matmul double grad composite api #50452

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d636d72
modify name
xiaoguoguo626807 Jan 18, 2023
184fa04
fix_conflict
xiaoguoguo626807 Jan 18, 2023
127d266
merge develop
xiaoguoguo626807 Jan 18, 2023
3ebc0f7
fix_conflict
xiaoguoguo626807 Jan 30, 2023
dd5f33e
original code
xiaoguoguo626807 Feb 1, 2023
90d5165
Merge branch 'matmul_double_grad' of https://github.com/xiaoguoguo626…
xiaoguoguo626807 Feb 1, 2023
35ebc01
build modify
xiaoguoguo626807 Feb 2, 2023
da60ca4
success 2*2
xiaoguoguo626807 Feb 3, 2023
34f633d
fused dim=1 failed
xiaoguoguo626807 Feb 3, 2023
9ae2915
success
xiaoguoguo626807 Feb 8, 2023
277da45
modify static
xiaoguoguo626807 Feb 9, 2023
8e659ec
fix_conflict
xiaoguoguo626807 Feb 9, 2023
37f7664
success for static except dim=1
xiaoguoguo626807 Feb 13, 2023
6f41b26
fix_conflict
xiaoguoguo626807 Feb 13, 2023
3e0ae31
delete log
xiaoguoguo626807 Feb 13, 2023
04d0d3b
tmp modify
xiaoguoguo626807 Feb 14, 2023
201e865
success
xiaoguoguo626807 Feb 15, 2023
9745b06
success
xiaoguoguo626807 Feb 15, 2023
d8d3625
add fp1664
xiaoguoguo626807 Feb 17, 2023
651c486
fix_conflict
xiaoguoguo626807 Feb 17, 2023
8f86a00
delete fp16 cpu test
xiaoguoguo626807 Feb 20, 2023
c9e35a7
stop windows test
xiaoguoguo626807 Feb 21, 2023
2c9e691
review modify
xiaoguoguo626807 Feb 22, 2023
5faf2cb
fix_conflixt
xiaoguoguo626807 Feb 23, 2023
d93ad7f
Merge branch 'develop' into matmul_double_grad
xiaoguoguo626807 Feb 23, 2023
883ec48
modify tanh test
xiaoguoguo626807 Feb 24, 2023
cf766e1
modify tanh
xiaoguoguo626807 Feb 24, 2023
d7bd49a
Merge branch 'develop' into matmul_double_grad
xiaoguoguo626807 Feb 24, 2023
cb28a07
fix_conflict
xiaoguoguo626807 Feb 24, 2023
313c4bc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Feb 24, 2023
dba7cae
fix_conflixt
xiaoguoguo626807 Feb 24, 2023
16aec14
modift static prim
xiaoguoguo626807 Feb 24, 2023
17717c2
Merge branch 'develop' into matmul_double_grad
xiaoguoguo626807 Feb 27, 2023
dac2764
fix_conflict
xiaoguoguo626807 Feb 28, 2023
61ea077
fix_conflict
xiaoguoguo626807 Feb 28, 2023
8cd6842
Update test_static_prim.cc
xiaoguoguo626807 Feb 28, 2023
93a8143
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Feb 28, 2023
1a1270b
update
xiaoguoguo626807 Feb 28, 2023
4bf2af3
fix_conflict
xiaoguoguo626807 Feb 28, 2023
e97e738
bug fix
xiaoguoguo626807 Feb 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions paddle/fluid/operators/matmul_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"

Expand Down Expand Up @@ -246,6 +247,51 @@ class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
}
};

// Composite double-grad maker for matmul_v2: instead of registering a
// monolithic double-grad kernel, it expands the second-order backward pass
// into primitive ops via prim::matmul_double_grad, so the result can be
// further transformed/differentiated by the prim framework.
class MatMulCompositeDoubleGradOpMaker : public prim::CompositeGradOpMakerBase {
 public:
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

  // Builds the composite backward program. Inputs are the forward operands
  // (X, Y), the first-order output grad (Grad@Out), and the optional
  // second-order grads (ddx/ddy); outputs are the grads of X, Y and of
  // Grad@Out produced by the double-grad computation.
  void Apply() override {
    // get inputs
    paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
    paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
    paddle::experimental::Tensor dout =
        this->GetSingleForwardInput(framework::GradVarName("Out"));
    // ddx / ddy are optional: a caller may request double-grad with respect
    // to only one of the operands.
    paddle::optional<paddle::experimental::Tensor> ddx =
        this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
    paddle::optional<paddle::experimental::Tensor> ddy =
        this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));

    // get attr
    bool trans_x = this->Attr<bool>("trans_x");
    bool trans_y = this->Attr<bool>("trans_y");

    // get output
    paddle::experimental::Tensor x_grad_t = this->GetSingleInputGrad("X");
    paddle::experimental::Tensor y_grad_t = this->GetSingleInputGrad("Y");
    paddle::experimental::Tensor grad_out_grad_t =
        this->GetSingleInputGrad(framework::GradVarName("Out"));

    // get output ptr
    paddle::experimental::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
    paddle::experimental::Tensor* y_grad = this->GetOutputPtr(&y_grad_t);
    paddle::experimental::Tensor* grad_out_grad =
        this->GetOutputPtr(&grad_out_grad_t);
    // get output original name. The composite call below may rewrite the
    // descriptor names, so record them first and restore afterwards.
    std::string x_grad_name = this->GetOutputName(x_grad_t);
    std::string y_grad_name = this->GetOutputName(y_grad_t);
    std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
    VLOG(3) << "Running matmul_double_grad composite func";
    // call composite backward func
    prim::matmul_double_grad<prim::DescTensor>(
        x, y, dout, ddx, ddy, trans_x, trans_y, x_grad, y_grad, grad_out_grad);
    // recover output name
    this->RecoverOutputName(x_grad_t, x_grad_name);
    this->RecoverOutputName(y_grad_t, y_grad_name);
    this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
  }
};

class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
Expand Down Expand Up @@ -335,6 +381,7 @@ REGISTER_OPERATOR(matmul_v2_grad,
ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
ops::MatMulCompositeDoubleGradOpMaker,
MatMulV2GradInferShapeFunctor);

REGISTER_OPERATOR(matmul_v2_grad_grad,
Expand Down
Loading