[Prim] Add prod backward composite rule #51238

Merged

Changes from 16 commits (29 commits total):
a320174 first commit (rainyfly, Mar 6, 2023)
65642d1 add registry (rainyfly, Mar 6, 2023)
0495719 add unit test (rainyfly, Mar 6, 2023)
b88ca25 fix format (rainyfly, Mar 6, 2023)
d08874f add unit test (rainyfly, Mar 6, 2023)
fee42de Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 9, 2023)
bc6282e fix bug (rainyfly, Mar 9, 2023)
e22b142 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 9, 2023)
6d59c31 replace unsuqeeze to reshape (rainyfly, Mar 9, 2023)
d7a432a Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 10, 2023)
6978193 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 10, 2023)
e5960aa Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 10, 2023)
0d6e248 fix (rainyfly, Mar 10, 2023)
e9a952a Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 10, 2023)
4b859d3 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 13, 2023)
8605ea2 fix unit test (rainyfly, Mar 13, 2023)
d460ef8 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
ef883ed update test (rainyfly, Mar 22, 2023)
89f3c74 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
0c0ef21 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
905eb55 update test (rainyfly, Mar 22, 2023)
af03ba4 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
13812dd Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
c2deb9d Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 22, 2023)
5f15bf0 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 23, 2023)
97a5074 fix unit test (rainyfly, Mar 23, 2023)
5ab58b9 fix (rainyfly, Mar 23, 2023)
8456de7 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (rainyfly, Mar 28, 2023)
fefccc3 fix (rainyfly, Mar 29, 2023)
40 changes: 40 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
@@ -13,6 +13,9 @@
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/reduce_prod_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
@@ -27,6 +30,42 @@ class OpBase;
} // namespace imperative
} // namespace paddle

namespace paddle {
namespace operators {
class ReduceProdCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor out = this->GetSingleForwardOutput("Out");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");

// get attr
std::vector<int> axis = this->Attr<std::vector<int>>("dim");
bool keep_dim = this->Attr<bool>("keep_dim");
bool reduce_all = this->Attr<bool>("reduce_all");

// get output
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");

// get output ptr
auto x_grad = this->GetOutputPtr(&x_grad_t);

// get output original name
std::string x_grad_name = this->GetOutputName(x_grad_t);
VLOG(6) << "Running prod_grad composite func";
// call composite backward func
prim::prod_grad<prim::DescTensor>(
x, out, out_grad, axis, keep_dim, reduce_all, x_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

class ReduceProdOpMaker : public ops::ReduceOpMaker {
@@ -46,5 +85,6 @@ REGISTER_OPERATOR(
ReduceProdOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ops::ReduceProdCompositeGradOpMaker,
ReduceProdInferShapeFunctor);
REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
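With ops::ReduceProdCompositeGradOpMaker registered, static-graph gradient construction can lower reduce_prod_grad into primitive ops when prim backward is enabled. A minimal sketch of exercising that path; the flag name core._set_prim_backward_enabled follows the prim unit tests of this period, and both it and the program setup are assumptions rather than part of this PR:

```python
import paddle
from paddle.fluid import core

paddle.enable_static()
# Assumed flag: turn on composite (prim) backward rules for grad construction.
core._set_prim_backward_enabled(True)

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[5, 6, 10], dtype='float32')
    x.stop_gradient = False
    out = paddle.prod(x, axis=0)
    # Grad ops for reduce_prod are now emitted via prim::prod_grad
    # instead of the monolithic reduce_prod_grad kernel.
    x_grad = paddle.static.gradients([out], [x])
```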
56 changes: 56 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -887,6 +887,62 @@ void gather_nd_grad(const Tensor& x,
}
}

template <typename T>
void prod_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
Tensor* x_grad) {
if (x_grad) {
std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
// Recompute reduce_all from the axis list; the incoming attribute is
// intentionally ignored (this matches the original branch logic).
reduce_all = axis_size == 0 || axis_size == x_dim_size;
auto x_grad_tmp = Tensor();
auto out_tmp = Tensor();
if (x_dim_size == 1) {
x_grad_tmp = out_grad.expand(IntArray(x_dim));
out_tmp = out.expand(IntArray(x_dim));
} else {
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto reshape_dim = std::vector<int64_t>(x_dim);
for (auto i : axis_) {
reshape_dim[i] = 1;
}
Contributor: Use the unsqueeze API instead.

Contributor Author: Done.
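(The two spellings are equivalent here: writing a 1 at each reduced axis and reshaping matches unsqueezing those axes back. A quick NumPy illustration, with arbitrary shapes:)

```python
import numpy as np

out = np.random.rand(5, 10)  # e.g. the result of reducing axis 1 of a (5, 6, 10) input

# reshape with a 1 inserted at the reduced axis ...
a = out.reshape(5, 1, 10)
# ... is the same as unsqueezing that axis back
b = np.expand_dims(out, axis=1)

assert a.shape == b.shape and np.array_equal(a, b)
```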

auto out_grad_ = reshape<T>(out_grad, reshape_dim);
auto out_ = reshape<T>(out, reshape_dim);
x_grad_tmp = out_grad_.expand(IntArray(x_dim));
out_tmp = out_.expand(IntArray(x_dim));
} else {
x_grad_tmp = out_grad.expand(IntArray(x_dim));
out_tmp = out.expand(IntArray(x_dim));
}
}
auto x_grad_res = x_grad_tmp * out_tmp * (1 / x);
set_output<T>(x_grad_res, x_grad);
}
}

template <typename T>
void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
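The rule above implements d prod(x)/dx_i = prod(x)/x_i: out and out_grad are broadcast back to x's shape (via reshape-with-ones when keep_dim squeezed the reduced axes), then x_grad = out_grad * out / x. A standalone NumPy sketch of the same computation with a finite-difference spot check; shapes, epsilon, and tolerance are illustrative choices:

```python
import numpy as np

x = np.random.rand(5, 6, 10) + 0.5      # keep entries away from 0 (the rule divides by x)
axis = 0
out = x.prod(axis=axis, keepdims=True)  # forward reduce_prod
out_grad = np.ones_like(out)            # upstream gradient

# Composite rule: broadcast out and out_grad to x's shape, then divide by x.
x_grad = np.broadcast_to(out_grad, x.shape) * np.broadcast_to(out, x.shape) / x

# Finite-difference check on a single element.
eps = 1e-6
x_pert = x.copy()
x_pert[0, 0, 0] += eps
numeric = (x_pert.prod(axis=axis) - x.prod(axis=axis))[0, 0] / eps
assert np.isclose(x_grad[0, 0, 0], numeric, rtol=1e-4)
```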
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -965,6 +965,7 @@
param : [x]
kernel :
func : prod_grad
composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad)

- backward_op : psroi_pool_grad
forward : psroi_pool (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale) -> Tensor(out)
31 changes: 20 additions & 11 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -352,6 +352,8 @@ class TestProdOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = raw_reduce_prod
self.prim_op_type = "prim"

self.init_data_type()
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)}
self.outputs = {'Out': self.inputs['X'].prod(axis=0)}
@@ -365,16 +367,20 @@ def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)


class TestProdOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.prod
self.op_type = "reduce_prod"
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}
# reduce doesn't support float64 in cinn.
Contributor: Will this raise an error when fp64 is exercised by the other (non-0-D) tests in CINN?

Contributor Author: fp64 has been added to the tests.

# 0-D tensors are not supported in CINN
self.enable_cinn = False

def test_check_output(self):
self.check_output(check_eager=True)
Expand All @@ -387,6 +393,7 @@ class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = raw_reduce_prod
self.prim_op_type = "prim"
self.init_data_type()
self.inputs = {
'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type)
Contributor: Please confirm that all dtypes have been tested (FP16, FP32, FP64).

Contributor Author: FP16 is not supported; FP64 has been added.

Contributor Author: [screenshot attachment]
@@ -405,13 +412,14 @@ def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)


class TestProd8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
self.python_api = raw_reduce_prod
self.prim_op_type = "prim"
Contributor: If prim isn't being tested here, delete this line.

Contributor Author: Done.
self.init_data_type()
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype(
@@ -1036,15 +1044,16 @@ def test_check_grad(self):

class TestReduceSumOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of reduce_sum_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, paddle.sum, x1)
# The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="uint8")
self.assertRaises(TypeError, paddle.sum, x2)
with paddle.fluid.framework._static_guard():
with program_guard(Program(), Program()):
# The input type of reduce_sum_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, paddle.sum, x1)
# The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="uint8")
self.assertRaises(TypeError, paddle.sum, x2)


class API_TestSumOp(unittest.TestCase):
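Beyond check_prim, the gradient the composite rule must reproduce can be sanity-checked end to end in eager mode. A minimal sketch, assuming a Paddle build with this PR applied; inputs are shifted away from zero because the rule divides by x:

```python
import paddle

x = paddle.rand([5, 6, 10]) + 0.5
x.stop_gradient = False
out = paddle.prod(x, axis=0)
out.backward(paddle.ones_like(out))

# Expected gradient from the composite rule: out_grad * out / x,
# with out broadcast back over the reduced axis.
expected = paddle.prod(x, axis=0, keepdim=True) / x
assert paddle.allclose(x.grad, expected, rtol=1e-5)
```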