[Prim]Support mean_grad decompose in vjp (#64346)
* [Prim]Support mean_grad decompose in vjp

* fix compile

* refine code
Aurelius84 authored May 17, 2024
1 parent 8885339 commit ae1bd9a
Showing 4 changed files with 173 additions and 8 deletions.
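
The decomposition added here expresses the backward of mean through existing primitives: reuse the backward of sum, then divide by the number of elements that were averaged. A minimal NumPy sketch of that identity, for orientation only (mean_grad_reference is an illustrative helper, not a Paddle API):

import numpy as np

def mean_grad_reference(x, out_grad, axis, keepdim=False):
    # Empty axis means reduce over every dimension; negative axes wrap around.
    axis = tuple(range(x.ndim)) if len(axis) == 0 else tuple(a % x.ndim for a in axis)
    factor = np.prod([x.shape[a] for a in axis])   # number of elements averaged together
    if not keepdim:
        out_grad = np.expand_dims(out_grad, axis)  # restore reduced dims as size 1
    sum_grad = np.broadcast_to(out_grad, x.shape)  # what sum's backward would produce
    return sum_grad / factor                       # mean_grad = sum_grad / factor

x = np.random.rand(2, 3, 4)
g = mean_grad_reference(x, np.ones((2, 4)), axis=[1])
assert g.shape == x.shape and np.allclose(g, 1.0 / 3.0)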
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -119,6 +119,7 @@
    'instance_norm_grad',
    'layer_norm_grad',
    'leaky_relu_grad',
    'mean_grad',
    'minimum_grad',
    'pow_grad',
    'relu_grad',
47 changes: 47 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -154,6 +154,53 @@ void sum_grad(const Tensor& x,
  set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void mean_grad(const Tensor& x,
               const Tensor& out_grad,
               const IntArray& axis,
               bool keepdim,
               bool reduce_all,
               Tensor* x_grad) {
  if (!x_grad) {
    return;
  }
  Tensor x_grad_tmp;
  sum_grad<T>(x, out_grad, axis, keepdim, reduce_all, &x_grad_tmp);

  Tensor div_factor = [&] {
    Tensor factor_tensor;
    auto axis_data = axis.GetData();
    const std::vector<int64_t> x_dim = x.shape();
    if (axis.size() == 0) {
      for (size_t i = 0; i < x_dim.size(); ++i) {
        axis_data.push_back(i);
      }
    }
    if (has_dynamic_shape(x_dim, axis_data)) {
      auto x_shape = shape<T>(x);
      factor_tensor =
          slice<T>(x_shape, {0}, {axis_data[0]}, {axis_data[0] + 1}, {1}, {0});
      for (size_t i = 1; i < axis_data.size(); ++i) {
        factor_tensor =
            factor_tensor *
            slice<T>(
                x_shape, {0}, {axis_data[i]}, {axis_data[i] + 1}, {1}, {0});
      }
      factor_tensor = cast<T>(factor_tensor, x.dtype());
    } else {
      int64_t factor = 1;
      for (int64_t idx : axis_data) {
        if (idx < 0) idx += x_dim.size();
        factor *= x_dim[idx];
      }
      factor_tensor = full<T>(std::vector<int64_t>{}, factor, x.dtype());
    }
    return factor_tensor;
  }();

  set_output<T>(x_grad_tmp / div_factor, x_grad);
}

template <typename T>
void gelu_grad(const Tensor& x,
               const Tensor& out_grad,
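
The div_factor lambda in mean_grad above has two paths: when every reduced dimension is statically known, the factor is simply the product of those dimensions; otherwise it is assembled at runtime by slicing shape<T>(x) at each reduced axis and multiplying the slices together. A rough Python rendering of that dispatch, with plain lists standing in for shapes (compute_div_factor is illustrative only):

def compute_div_factor(x_shape, axis):
    # Empty axis means reduce over all dimensions, mirroring the C++ rule.
    rank = len(x_shape)
    axis = list(range(rank)) if len(axis) == 0 else [a + rank if a < 0 else a for a in axis]
    if any(x_shape[a] == -1 for a in axis):
        # Dynamic branch: the C++ code multiplies slice<T>(x_shape, ...) results
        # instead, because a -1 dimension is only known when the program runs.
        raise NotImplementedError("factor must be computed from the runtime shape tensor")
    factor = 1
    for a in axis:
        factor *= x_shape[a]
    return factor

print(compute_div_factor([2, 3, 4, 5], [0, -1]))  # 2 * 5 = 10
print(compute_div_factor([2, 3, 4, 5], []))       # 120: full reduction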
25 changes: 20 additions & 5 deletions paddle/fluid/primitive/utils/utils.h
@@ -180,12 +180,27 @@ static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
  }
}

-static bool has_dynamic_shape(const std::vector<int64_t>& vec) {
-  if (std::find(vec.begin(), vec.end(), -1) != vec.end()) {
-    return true;
-  } else {
-    return false;
+static bool has_dynamic_shape(const std::vector<int64_t>& shape) {
+  return std::find(shape.begin(), shape.end(), -1) != shape.end();
}

static bool has_dynamic_shape(const std::vector<int64_t>& shape,
                              const std::vector<int64_t>& axis) {
  bool flag = false;
  const int64_t rank = shape.size();
  for (int64_t idx : axis) {
    if (idx < 0) idx += rank;
    PADDLE_ENFORCE_LT(
        idx,
        rank,
        ::common::errors::PreconditionNotMet(
            "Required idx < shape.size(), but received %d.", idx));
    if (shape[idx] == -1) {
      flag = true;
      break;
    }
  }
  return flag;
}

} // namespace primitive
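
The new overload checks only the axes being reduced for the -1 dynamic-dimension marker, so a tensor that is dynamic outside the reduction can still take the cheaper static path in mean_grad. Equivalent logic in Python, for reference (hypothetical helper, not part of Paddle):

def has_dynamic_shape(shape, axis=None):
    if axis is None:
        # Original overload: any -1 anywhere marks the shape as dynamic.
        return -1 in shape
    rank = len(shape)
    for idx in axis:
        if idx < 0:
            idx += rank
        assert idx < rank, f"Required idx < shape.size(), but received {idx}."
        if shape[idx] == -1:
            return True
    return False

print(has_dynamic_shape([2, -1, 4]))            # True
print(has_dynamic_shape([2, -1, 4], axis=[0]))  # False: the reduced axis is static
print(has_dynamic_shape([2, -1, 4], axis=[1]))  # True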
108 changes: 105 additions & 3 deletions test/legacy_test/test_mean_op.py
@@ -44,34 +44,54 @@ class TestMeanOp(OpTest):
    def setUp(self):
        self.op_type = "mean"
        self.python_api = paddle.mean
        self.public_python_api = paddle.mean
        self.dtype = np.float64
        self.init_dtype_type()
        self.init_prim_type()
        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
        self.outputs = {'Out': np.mean(self.inputs["X"])}

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output(check_pir=True)

    def test_checkout_grad(self):
-        self.check_grad(['X'], 'Out', check_pir=True)
+        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestMeanOpPrim(TestMeanOp):
    def init_prim_type(self):
        self.prim_op_type = "prim"


class TestMeanOp_ZeroDim(OpTest):
    def setUp(self):
        self.op_type = "mean"
        self.python_api = paddle.mean
        self.dtype = np.float64
        self.public_python_api = paddle.mean
        self.init_prim_type()
        self.inputs = {'X': np.random.random([]).astype(self.dtype)}
        self.outputs = {'Out': np.mean(self.inputs["X"])}

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def test_check_output(self):
        self.check_output(check_pir=True)

    def test_checkout_grad(self):
-        self.check_grad(['X'], 'Out', check_pir=True)
+        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestMeanOp_ZeroDim_Prim(TestMeanOp_ZeroDim):
    def init_prim_type(self):
        self.prim_op_type = "prim"


class TestMeanOpError(unittest.TestCase):
@@ -161,7 +181,7 @@ def setUp(self):
        self.op_type = 'reduce_mean'
        self.python_api = reduce_mean_wrapper
        self.public_python_api = reduce_mean_wrapper
-        self.prim_op_type = "comp"
+        self.init_prim_type()
        self.dtype = 'float64'
        self.init_shapes()
        self.axis = [0]
@@ -186,6 +206,9 @@ def setUp(self):
            'reduce_all': self.reduce_all,
        }

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def init_shapes(self):
        self.shape = [2, 3, 4, 5]

@@ -231,6 +254,43 @@ def test_check_grad(self):
        )


class TestReduceMeanOpPrim(TestReduceMeanOp):
    def init_prim_type(self):
        self.prim_op_type = "prim"

    @test_with_pir_api
    def test_check_output(self):
        if self.dtype != 'float16':
            self.check_output(check_prim_pir=True, check_pir=True)
        else:
            place = paddle.CUDAPlace(0)
            self.check_output_with_place(
                place=place,
                check_prim_pir=True,
                check_pir=True,
            )

    @test_with_pir_api
    def test_check_grad(self):
        if self.dtype != 'float16':
            self.check_grad(
                ['X'],
                ['Out'],
                check_prim_pir=True,
                check_pir=True,
            )
        else:
            place = paddle.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X'],
                ['Out'],
                numeric_grad_delta=0.5,
                check_prim_pir=True,
                check_pir=True,
            )


class TestReduceMeanOp_ZeroDim(TestReduceMeanOp):
    def init_shapes(self):
        self.shape = []
@@ -306,16 +366,41 @@ def setUp(self):
        self.outputs = {'Out': out_np}


class TestReduceMeanOpDefaultAttrsForPrim(TestReduceMeanOpPrim):
    def setUp(self):
        self.op_type = 'reduce_mean'
        self.python_api = reduce_mean_wrapper
        self.public_python_api = reduce_mean_wrapper
        self.init_prim_type()
        self.dtype = 'float64'
        self.shape = [2, 3, 4, 5]

        x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out_np = np.mean(x_np, axis=0)
        self.inputs = {'X': x_np}
        self.outputs = {'Out': out_np}


class TestReduceMeanOpFloat32(TestReduceMeanOp):
    def set_attrs(self):
        self.dtype = 'float32'


class TestReduceMeanOpFloat32Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.dtype = 'float32'


class TestReduceMeanOpFloat16(TestReduceMeanOp):
    def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpFloat16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpShape1D(TestReduceMeanOp):
    def set_attrs(self):
        self.shape = [100]
@@ -348,12 +433,23 @@ def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisAllPrim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
        self.dtype = 'float16'


class TestReduceMeanOpAxisAllFP16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
        self.dtype = 'float16'


class TestReduceMeanOpAxisAllBF16(TestReduceMeanBF16Op):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
@@ -386,6 +482,12 @@ def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpAxisNegativeFP16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [-2, -1]
        self.dtype = 'float16'


class TestReduceMeanOpAxisNegativeBF16(TestReduceMeanBF16Op):
    def set_attrs(self):
        self.axis = [-2, -1]
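
The added *Prim test classes rerun the existing mean and reduce_mean cases with prim_op_type = "prim", so check_grad differentiates through the decomposed mean_grad instead of the fused kernel. A sketch of running just those cases with unittest, assuming a Paddle development build and test/legacy_test on the import path:

import unittest

# Assumes test/legacy_test is on sys.path and Paddle is built with PIR/prim support.
from test_mean_op import TestMeanOpPrim, TestReduceMeanOpPrim

loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(TestMeanOpPrim))
suite.addTests(loader.loadTestsFromTestCase(TestReduceMeanOpPrim))
unittest.TextTestRunner(verbosity=2).run(suite)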
