
【Prim】Add multiply,expand,div vjp rules #49831

Merged: 25 commits, Jan 17, 2023
Commits
374b354 support elementwise base func (JiabinYang, Jan 10, 2023)
2d3cb54 fix compiling error and add test (JiabinYang, Jan 10, 2023)
488e587 support vjp for div using comp (JiabinYang, Jan 11, 2023)
18c5742 remove additional change (JiabinYang, Jan 11, 2023)
25b3262 fix dy2st error with magic num (JiabinYang, Jan 11, 2023)
8167a5d fix dy magic num (JiabinYang, Jan 11, 2023)
18def5a another magic (JiabinYang, Jan 11, 2023)
e77540f another magic (JiabinYang, Jan 11, 2023)
e7587f3 another magic (JiabinYang, Jan 11, 2023)
820f56e add skip rename strategy (JiabinYang, Jan 12, 2023)
8922413 support add vjp (JiabinYang, Jan 12, 2023)
30642bb support add with new axis cal (JiabinYang, Jan 12, 2023)
b850d1e support sub vjp (JiabinYang, Jan 12, 2023)
14113f3 [prim] add multiply vjp rules (cxxly, Jan 11, 2023)
6bd61aa [prim] add multiply vjp rules (cxxly, Jan 11, 2023)
fbe8061 [prim] fix no infershape with composite in _append_backward_ops (cxxly, Jan 12, 2023)
2cf72cf [prim] add expand vjp rule (cxxly, Jan 13, 2023)
3ac18b8 [prim] add exp vjp rule (cxxly, Jan 13, 2023)
baff012 uncomment infer shape for reshape/sum static prim api (cxxly, Jan 13, 2023)
4b91d6c [prim] fix tanh nullptr error (cxxly, Jan 13, 2023)
e6d3d39 remove some print message (cxxly, Jan 13, 2023)
ad8545c fix magic number in run_program relative tests @JiaBinYang (cxxly, Jan 13, 2023)
cd64be5 [prim] add expand,multiply,exp vjp rules (cxxly, Jan 14, 2023)
31ec399 fix only support single direction reduce error (cxxly, Jan 14, 2023)
bbe3480 infer reduce dims using out dims (cxxly, Jan 16, 2023)
33 changes: 32 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_mul_op.cc
@@ -19,6 +19,9 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

namespace paddle {
namespace operators {
@@ -63,6 +66,33 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class ElementwiseMulGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;

public:
void Apply() override {
auto x = this->GetSingleForwardInput("X");
auto y = this->GetSingleForwardInput("Y");
auto out_grad = this->GetSingleOutputGrad("Out");
auto x_grad = this->GetSingleInputGrad("X");
auto x_grad_p = this->GetOutputPtr(&x_grad);
auto x_grad_name = this->GetOutputName(x_grad);
auto y_grad = this->GetSingleInputGrad("Y");
auto y_grad_p = this->GetOutputPtr(&y_grad);
auto y_grad_name = this->GetOutputName(y_grad);
prim::multiply_grad<prim::DescTensor>(
x,
y,
out_grad,
static_cast<int>(this->Attr<int>("axis")),
x_grad_p,
y_grad_p);
this->RecoverOutputName(x_grad, x_grad_name);
this->RecoverOutputName(y_grad, y_grad_name);
}
};

template <typename T>
class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -123,7 +153,8 @@ REGISTER_OPERATOR(elementwise_mul,
ops::ElementwiseMulOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwiseMulOpGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseMulOpGradMaker<paddle::imperative::OpBase>);
ops::ElementwiseMulOpGradMaker<paddle::imperative::OpBase>,
ops::ElementwiseMulGradCompositeOpMaker);
REGISTER_OPERATOR(
elementwise_mul_grad,
ops::ElementwiseOpGrad,
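For context, the composite maker registered above lowers elementwise_mul's backward pass into primitive ops instead of calling the dedicated grad kernel. Assuming NumPy-style broadcasting, the rule it expresses (see multiply_grad in composite_backward_api.h below) is roughly:

    dX = reduce_to_shape(X, out_grad * Y)
    dY = reduce_to_shape(Y, out_grad * X)

where reduce_to_shape is shorthand (not a real API) for summing over the axes along which that input was broadcast and reshaping the result back to the input's shape.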
21 changes: 21 additions & 0 deletions paddle/fluid/operators/expand_v2_op.cc
@@ -20,6 +20,9 @@ limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

@@ -190,6 +193,23 @@ class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
}
};

class ExpandV2GradCompositeOpMaker : public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;

public:
void Apply() override {
auto x = this->GetSingleForwardInput("X");
auto out_grad = this->GetSingleOutputGrad("Out");
auto x_grad = this->GetSingleInputGrad("X");
auto x_grad_p = this->GetOutputPtr(&x_grad);
auto x_grad_name = this->GetOutputName(x_grad);
auto shape = this->Attr<std::vector<int>>("shape");
prim::expand_grad<prim::DescTensor>(
x, out_grad, paddle::experimental::IntArray(shape), x_grad_p);
this->RecoverOutputName(x_grad, x_grad_name);
}
};

template <typename T>
class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -223,6 +243,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(expand_v2,
ops::ExpandV2Op,
ops::ExpandV2OpMaker,
ops::ExpandV2GradCompositeOpMaker,
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>,
ExpandInferShapeFunctor);
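Both the expand VJP above and the elementwise VJPs in this PR hinge on deciding which axes of the broadcast gradient must be summed away. Below is a minimal standalone sketch of that logic, assuming right-aligned NumPy-style broadcasting; reduce_axes is a hypothetical helper, and get_reduce_dims in the Paddle sources is presumed to compute the equivalent on phi::DDim.

#include <cstdint>
#include <iostream>
#include <vector>

// Return the axes of `out` (the broadcast/expanded shape) over which a gradient
// must be summed to recover the original input shape `in`.
std::vector<int64_t> reduce_axes(const std::vector<int64_t>& in,
                                 const std::vector<int64_t>& out) {
  std::vector<int64_t> axes;
  const int64_t offset =
      static_cast<int64_t>(out.size()) - static_cast<int64_t>(in.size());
  for (int64_t i = 0; i < static_cast<int64_t>(out.size()); ++i) {
    const int64_t in_dim = (i < offset) ? 1 : in[static_cast<size_t>(i - offset)];
    if (in_dim == 1 && out[static_cast<size_t>(i)] != 1) {
      axes.push_back(i);  // this axis was created or stretched by broadcasting
    }
  }
  return axes;
}

int main() {
  // expand_v2 from shape [3, 1] to [2, 3, 4]: the gradient is summed over axes 0 and 2,
  // then reshaped back to [3, 1], matching the expand_grad rule added in this PR.
  for (int64_t axis : reduce_axes({3, 1}, {2, 3, 4})) std::cout << axis << ' ';
  std::cout << '\n';  // prints: 0 2
}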
186 changes: 147 additions & 39 deletions paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -23,16 +23,17 @@ namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray =
paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
// using IntArray = paddle::experimental::IntArray;
// These functions should have the same signatures as the corresponding phi
// backward APIs declared in paddle/phi/api/backward/backward_api.h
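// tanh: since d tanh(x)/dx = 1 - tanh(x)^2 = 1 - out^2, the VJP below computes
// grad_x = grad_out * (1 - out^2); the early return skips the work when the
// caller does not request grad_x.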
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
}

template <typename T>
void subtract_grad(const Tensor& x,
const Tensor& y,
@@ -42,25 +43,33 @@ void subtract_grad(const Tensor& x,
Tensor* dy) {
if (dy) {
auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
by_pass<T>(scale_out_grad, dy);
} else {
auto dy_reduce_res = sum<T>(
scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
}
} else {
by_pass<T>(scale_out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dx);
} else {
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
}
} else {
by_pass<T>(out_grad, dx);
}
@@ -75,25 +84,34 @@ void add_grad(const Tensor& x,
Tensor* dx,
Tensor* dy) {
if (dy) {
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dy);
} else {
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
}

} else {
by_pass<T>(out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
by_pass<T>(out_grad, dx);
} else {
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
}
} else {
by_pass<T>(out_grad, dx);
}
Expand Down Expand Up @@ -130,9 +148,9 @@ void sum_grad(const Tensor& x,
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, x_dim);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
x_grad_tmp = expand<T>(out_grad, x_dim);
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}

x_grad->set_impl(x_grad_tmp.impl());
@@ -152,13 +170,17 @@ void divide_grad(const Tensor& x,
auto tmp1 = divide<T>(x, tmp0);
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = multiply<T>(tmp2, out_grad);
if (phi::product(x.dims()) > phi::product(y.dims())) {
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
dy->set_impl(dy_res.impl());
} else {
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
}
} else {
dy->set_impl(dy_res.impl());
}
@@ -168,13 +190,18 @@
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (phi::product(y.dims()) > phi::product(x.dims())) {
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
dx->set_impl(dx_res.impl());
} else {
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
}

} else {
dx->set_impl(dx_res.impl());
}
@@ -190,5 +217,86 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
x_grad->set_impl(x_grad_tmp.impl());
}
}

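// Composite VJP for elementwise multiply: dx = out_grad * y and dy = out_grad * x.
// When x and y were broadcast against each other, the unreduced gradient is summed
// over the broadcast axes returned by get_reduce_dims and reshaped back to the
// corresponding input shape.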
template <typename T>
void multiply_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
Tensor* x_grad,
Tensor* y_grad) {
if (x_grad) {
auto x_grad_unreduce = multiply<T>(out_grad, y);
if (x.dims() != y.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims());
if (!axes.size()) {
x_grad->set_impl(x_grad_unreduce.impl());
} else {
auto x_grad_reduced = sum<T>(x_grad_unreduce,
phi::vectorize(axes),
x_grad_unreduce.dtype(),
false);
if (x_grad_reduced.dims().size() != x.dims().size()) {
x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
}
x_grad->set_impl(x_grad_reduced.impl());
}
} else {
x_grad->set_impl(x_grad_unreduce.impl());
}
}
if (y_grad) {
auto y_grad_unreduce = multiply<T>(out_grad, x);
if (y.dims() != x.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims());
if (!axes.size()) {
y_grad->set_impl(y_grad_unreduce.impl());
} else {
auto y_grad_reduced = sum<T>(y_grad_unreduce,
phi::vectorize(axes),
y_grad_unreduce.dtype(),
false);
if (y_grad_reduced.dims().size() != y.dims().size()) {
y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
}
y_grad->set_impl(y_grad_reduced.impl());
}
} else {
y_grad->set_impl(y_grad_unreduce.impl());
}
}
}

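// Composite VJP for expand: the gradient of the expanded output is summed over
// the axes along which x was broadcast to `shape` (get_reduce_dims) and reshaped
// back to x's shape; if nothing was actually broadcast, out_grad passes through.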
template <typename T>
void expand_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& shape,
Tensor* x_grad) {
if (x_grad) {
auto out_dims = phi::make_ddim(shape.GetData());
if (out_dims != x.dims()) {
auto axes = get_reduce_dims(x.dims(), out_dims);
if (!axes.size()) {
by_pass<T>(out_grad, x_grad);
} else {
auto reduced = sum<T>(out_grad, phi::vectorize(axes), x.dtype(), false);
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
x_grad->set_impl(reduced.impl());
}
} else {
by_pass<T>(out_grad, x_grad);
}
}
}

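// Composite VJP for exp: since d exp(x)/dx = exp(x) = out, dx = out_grad * out.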
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
x_grad->set_impl(multiply<T>(out_grad, out).impl());
}
}

} // namespace prim
} // namespace paddle
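A note on the broadcast handling changed throughout this file: the grad rules now pass the target input's dims as the first argument to get_reduce_dims and branch on an empty axis list, so the gradient is passed through (or used directly) when nothing was actually broadcast, instead of comparing element counts with phi::product. Under the assumption that get_reduce_dims(target_dims, other_dims) returns the axes to sum so the result can be reshaped to target_dims, the divide rule being lowered is dx = out_grad / y and dy = -out_grad * x / (y * y), each summed over its broadcast axes and reshaped back to the matching input shape whenever x.dims() != y.dims().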
10 changes: 10 additions & 0 deletions paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc
@@ -67,5 +67,15 @@ template <>
Tensor reshape<Tensor>(Tensor x, IntArray shape) {
return ::reshape_ad_func(x, shape);
}

template <>
Tensor exp<Tensor>(const Tensor& x) {
return ::exp_ad_func(x);
}

template <>
Tensor expand<Tensor>(const Tensor& x, const IntArray& shape) {
  return ::expand_ad_func(x, shape);
}
} // namespace prim
} // namespace paddle
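These manual prim APIs follow a declare-in-header, specialize-per-backend pattern: prim_api.h declares the primary templates and each backend translation unit supplies the explicit specialization it supports. A minimal standalone illustration of that pattern, with made-up names (backend_name, the int specialization), not Paddle code:

#include <iostream>
#include <string>

// Header side: declare the primary template only (analogous to prim_api.h).
template <typename T>
std::string backend_name();

// Backend side: one translation unit defines an explicit specialization
// (analogous to eager_prim_api.cc specializing the API for Tensor).
template <>
std::string backend_name<int>() {
  return "eager";
}

int main() {
  std::cout << backend_name<int>() << '\n';  // prints: eager
}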
6 changes: 6 additions & 0 deletions paddle/fluid/prim/api/manual/prim_api/prim_api.h
@@ -57,5 +57,11 @@ Tensor sum(Tensor x,

template <typename T>
Tensor reshape(Tensor x, IntArray shape);

template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);

template <typename T>
Tensor exp(const Tensor& x);
} // namespace prim
} // namespace paddle