From ae1bd9a9fc92ccfe1dc34be7ad5e4c659907e44e Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Fri, 17 May 2024 10:17:10 +0800
Subject: [PATCH] [Prim]Support mean_grad decompose in vjp (#64346)

* [Prim]Support mean_grad decompose in vjp

* fix compile

* refine code
---
 paddle/fluid/primitive/codegen/gen.py     |   1 +
 paddle/fluid/primitive/rule/vjp/details.h |  47 ++++++++++
 paddle/fluid/primitive/utils/utils.h      |  25 ++++-
 test/legacy_test/test_mean_op.py          | 108 +++++++++++++++++++++-
 4 files changed, 173 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 60131dda70b10..ee1a2e63cb5e7 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -119,6 +119,7 @@
     'instance_norm_grad',
     'layer_norm_grad',
     'leaky_relu_grad',
+    'mean_grad',
     'minimum_grad',
     'pow_grad',
     'relu_grad',
diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 59c031952ee7f..bc460be6f832a 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -154,6 +154,53 @@ void sum_grad(const Tensor& x,
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void mean_grad(const Tensor& x,
+               const Tensor& out_grad,
+               const IntArray& axis,
+               bool keepdim,
+               bool reduce_all,
+               Tensor* x_grad) {
+  if (!x_grad) {
+    return;
+  }
+  Tensor x_grad_tmp;
+  sum_grad<T>(x, out_grad, axis, keepdim, reduce_all, &x_grad_tmp);
+
+  Tensor div_factor = [&] {
+    Tensor factor_tensor;
+    auto axis_data = axis.GetData();
+    const std::vector<int64_t> x_dim = x.shape();
+    if (axis.size() == 0) {
+      for (size_t i = 0; i < x_dim.size(); ++i) {
+        axis_data.push_back(i);
+      }
+    }
+    if (has_dynamic_shape(x_dim, axis_data)) {
+      auto x_shape = shape<T>(x);
+      factor_tensor =
+          slice<T>(x_shape, {0}, {axis_data[0]}, {axis_data[0] + 1}, {1}, {0});
+      for (size_t i = 1; i < axis_data.size(); ++i) {
+        factor_tensor =
+            factor_tensor *
+            slice<T>(
+                x_shape, {0}, {axis_data[i]}, {axis_data[i] + 1}, {1}, {0});
+      }
+      factor_tensor = cast<T>(factor_tensor, x.dtype());
+    } else {
+      int64_t factor = 1;
+      for (int64_t idx : axis_data) {
+        if (idx < 0) idx += x_dim.size();
+        factor *= x_dim[idx];
+      }
+      factor_tensor = full<T>(std::vector<int64_t>{}, factor, x.dtype());
+    }
+    return factor_tensor;
+  }();
+
+  set_output<T>(x_grad_tmp / div_factor, x_grad);
+}
+
 template <typename T>
 void gelu_grad(const Tensor& x,
                const Tensor& out_grad,
diff --git a/paddle/fluid/primitive/utils/utils.h b/paddle/fluid/primitive/utils/utils.h
index 42f1533db723e..a836e27b585b3 100644
--- a/paddle/fluid/primitive/utils/utils.h
+++ b/paddle/fluid/primitive/utils/utils.h
@@ -180,12 +180,27 @@ static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
   }
 }
 
-static bool has_dynamic_shape(const std::vector<int64_t>& vec) {
-  if (std::find(vec.begin(), vec.end(), -1) != vec.end()) {
-    return true;
-  } else {
-    return false;
+static bool has_dynamic_shape(const std::vector<int64_t>& shape) {
+  return std::find(shape.begin(), shape.end(), -1) != shape.end();
+}
+
+static bool has_dynamic_shape(const std::vector<int64_t>& shape,
+                              const std::vector<int64_t>& axis) {
+  bool flag = false;
+  const int64_t rank = shape.size();
+  for (int64_t idx : axis) {
+    if (idx < 0) idx += rank;
+    PADDLE_ENFORCE_LT(
+        idx,
+        rank,
+        ::common::errors::PreconditionNotMet(
+            "Required idx < shape.size(), but received %d.", idx));
+    if (shape[idx] == -1) {
+      flag = true;
+      break;
+    }
   }
+  return flag;
 }
 
 }  // namespace primitive
diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py
index 
2de8cd082b8fa..a82173c90eb2c 100644 --- a/test/legacy_test/test_mean_op.py +++ b/test/legacy_test/test_mean_op.py @@ -44,11 +44,16 @@ class TestMeanOp(OpTest): def setUp(self): self.op_type = "mean" self.python_api = paddle.mean + self.public_python_api = paddle.mean self.dtype = np.float64 self.init_dtype_type() + self.init_prim_type() self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)} self.outputs = {'Out': np.mean(self.inputs["X"])} + def init_prim_type(self): + self.prim_op_type = "comp" + def init_dtype_type(self): pass @@ -56,7 +61,12 @@ def test_check_output(self): self.check_output(check_pir=True) def test_checkout_grad(self): - self.check_grad(['X'], 'Out', check_pir=True) + self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True) + + +class TestMeanOpPrim(TestMeanOp): + def init_prim_type(self): + self.prim_op_type = "prim" class TestMeanOp_ZeroDim(OpTest): @@ -64,14 +74,24 @@ def setUp(self): self.op_type = "mean" self.python_api = paddle.mean self.dtype = np.float64 + self.public_python_api = paddle.mean + self.init_prim_type() self.inputs = {'X': np.random.random([]).astype(self.dtype)} self.outputs = {'Out': np.mean(self.inputs["X"])} + def init_prim_type(self): + self.prim_op_type = "comp" + def test_check_output(self): self.check_output(check_pir=True) def test_checkout_grad(self): - self.check_grad(['X'], 'Out', check_pir=True) + self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True) + + +class TestMeanOp_ZeroDim_Prim(TestMeanOp_ZeroDim): + def init_prim_type(self): + self.prim_op_type = "prim" class TestMeanOpError(unittest.TestCase): @@ -161,7 +181,7 @@ def setUp(self): self.op_type = 'reduce_mean' self.python_api = reduce_mean_wrapper self.public_python_api = reduce_mean_wrapper - self.prim_op_type = "comp" + self.init_prim_type() self.dtype = 'float64' self.init_shapes() self.axis = [0] @@ -186,6 +206,9 @@ def setUp(self): 'reduce_all': self.reduce_all, } + def init_prim_type(self): + self.prim_op_type = "comp" + def init_shapes(self): self.shape = [2, 3, 4, 5] @@ -231,6 +254,43 @@ def test_check_grad(self): ) +class TestReduceMeanOpPrim(TestReduceMeanOp): + def init_prim_type(self): + self.prim_op_type = "prim" + + @test_with_pir_api + def test_check_output(self): + if self.dtype != 'float16': + self.check_output(check_prim_pir=True, check_pir=True) + else: + place = paddle.CUDAPlace(0) + self.check_output_with_place( + place=place, + check_prim_pir=True, + check_pir=True, + ) + + @test_with_pir_api + def test_check_grad(self): + if self.dtype != 'float16': + self.check_grad( + ['X'], + ['Out'], + check_prim_pir=True, + check_pir=True, + ) + else: + place = paddle.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X'], + ['Out'], + numeric_grad_delta=0.5, + check_prim_pir=True, + check_pir=True, + ) + + class TestReduceMeanOp_ZeroDim(TestReduceMeanOp): def init_shapes(self): self.shape = [] @@ -306,16 +366,41 @@ def setUp(self): self.outputs = {'Out': out_np} +class TestReduceMeanOpDefaultAttrsForPrim(TestReduceMeanOpPrim): + def setUp(self): + self.op_type = 'reduce_mean' + self.python_api = reduce_mean_wrapper + self.public_python_api = reduce_mean_wrapper + self.init_prim_type() + self.dtype = 'float64' + self.shape = [2, 3, 4, 5] + + x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + out_np = np.mean(x_np, axis=0) + self.inputs = {'X': x_np} + self.outputs = {'Out': out_np} + + class TestReduceMeanOpFloat32(TestReduceMeanOp): def set_attrs(self): self.dtype = 'float32' +class 
TestReduceMeanOpFloat32Prim(TestReduceMeanOpPrim): + def set_attrs(self): + self.dtype = 'float32' + + class TestReduceMeanOpFloat16(TestReduceMeanOp): def set_attrs(self): self.dtype = 'float16' +class TestReduceMeanOpFloat16Prim(TestReduceMeanOpPrim): + def set_attrs(self): + self.dtype = 'float16' + + class TestReduceMeanOpShape1D(TestReduceMeanOp): def set_attrs(self): self.shape = [100] @@ -348,12 +433,23 @@ def set_attrs(self): self.axis = [0, 1, 2, 3] +class TestReduceMeanOpAxisAllPrim(TestReduceMeanOpPrim): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + + class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp): def set_attrs(self): self.axis = [0, 1, 2, 3] self.dtype = 'float16' +class TestReduceMeanOpAxisAllFP16Prim(TestReduceMeanOpPrim): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + self.dtype = 'float16' + + class TestReduceMeanOpAxisAllBF16(TestReduceMeanBF16Op): def set_attrs(self): self.axis = [0, 1, 2, 3] @@ -386,6 +482,12 @@ def set_attrs(self): self.dtype = 'float16' +class TestReduceMeanOpAxisNegativeFP16Prim(TestReduceMeanOpPrim): + def set_attrs(self): + self.axis = [-2, -1] + self.dtype = 'float16' + + class TestReduceMeanOpAxisNegativeBF16(TestReduceMeanBF16Op): def set_attrs(self): self.axis = [-2, -1]
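

Note on the decomposition (not part of the patch): mean_grad is expressed as
sum_grad followed by division by the product of the reduced extents. Below is
a minimal NumPy sketch of the static-shape branch, for illustration only; the
helper names sum_grad_ref and mean_grad_ref are invented here, and the
dynamic-shape branch (which assembles the divisor at runtime from
shape<T>/slice<T>) is not modeled.

    import numpy as np

    def sum_grad_ref(x, out_grad, axis, keepdim):
        # Mirror sum_grad: broadcast out_grad back to x's shape,
        # re-inserting the reduced axes if keepdim was False.
        g = np.asarray(out_grad)
        if not keepdim:
            g = np.expand_dims(g, tuple(a % x.ndim for a in axis))
        return np.broadcast_to(g, x.shape)

    def mean_grad_ref(x, out_grad, axis, keepdim=False):
        # Static-shape branch of mean_grad: sum_grad divided by the
        # number of elements averaged over.
        axis = list(axis) if axis else list(range(x.ndim))  # empty axis => all dims
        factor = 1
        for idx in axis:
            if idx < 0:
                idx += x.ndim  # normalize negative axes, as the C++ loop does
            factor *= x.shape[idx]
        return sum_grad_ref(x, out_grad, axis, keepdim) / factor

    x = np.random.rand(2, 3, 4)
    g = np.ones_like(x.mean(axis=(0, -1)))  # upstream gradient of ones
    # Each input element contributes 1/(2*4) to the mean it participates in.
    assert np.allclose(mean_grad_ref(x, g, axis=[0, -1]),
                       np.full_like(x, 1.0 / (2 * 4)))

The runtime slice-and-multiply path in details.h exists because any reduced
extent may be -1 (unknown) at trace time, in which case the divisor cannot be
folded into a constant full<T> tensor.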