diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 21a06e5257acd..bb69ef734ac7b 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -23,6 +23,10 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" + #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/infermeta/multiary.h" @@ -534,6 +538,70 @@ phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType( ctx.GetPlace()); } +class BatchNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { + using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; + + public: + void Apply() override { + // inputs and outputs of batch_norm + paddle::Tensor x = this->GetSingleForwardInput("X"); + paddle::Tensor scale = this->GetSingleForwardInput("Scale"); + paddle::Tensor bias = this->GetSingleForwardInput("Bias"); + paddle::Tensor mean = this->GetSingleForwardInput("Mean"); + paddle::Tensor variance = this->GetSingleForwardInput("Variance"); + paddle::Tensor y = this->GetSingleForwardOutput("Y"); + paddle::Tensor mean_out = this->GetSingleForwardOutput("MeanOut"); + paddle::Tensor variance_out = this->GetSingleForwardOutput("VarianceOut"); + paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean"); + paddle::Tensor saved_variance = + this->GetSingleForwardOutput("SavedVariance"); + paddle::optional reserve_space; + + paddle::Tensor y_grad = this->GetSingleOutputGrad("Y"); + paddle::Tensor x_grad = this->GetSingleInputGrad("X"); + paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale"); + paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias"); + + auto dx_ptr = this->GetOutputPtr(&x_grad); + std::string dx_name = this->GetOutputName(x_grad); + auto dscale_ptr = this->GetOutputPtr(&scale_grad); + std::string dscale_name = this->GetOutputName(scale_grad); + auto dbias_ptr = this->GetOutputPtr(&bias_grad); + std::string dbias_name = this->GetOutputName(bias_grad); + + // attrs of batch_norm + auto momentum = this->Attr("momentum"); + auto epsilon = this->Attr("epsilon"); + auto data_layout = this->Attr("data_layout"); + auto is_test = this->Attr("is_test"); + auto use_global_stats = this->Attr("use_global_stats"); + auto trainable_statistics = this->Attr("trainable_statistics"); + + VLOG(3) << "Runing batch_norm composite func"; + prim::batch_norm_grad(x, + scale, + bias, + mean_out, + variance_out, + saved_mean, + saved_variance, + reserve_space, + y_grad, + momentum, + epsilon, + data_layout, + is_test, + use_global_stats, + trainable_statistics, + dx_ptr, + dscale_ptr, + dbias_ptr); + this->RecoverOutputName(x_grad, dx_name); + this->RecoverOutputName(scale_grad, dscale_name); + this->RecoverOutputName(bias_grad, dbias_name); + } +}; + DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); } // namespace operators @@ -550,7 +618,8 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOpMaker, ops::BatchNormOpInferVarType, ops::BatchNormGradMaker, - ops::BatchNormGradMaker); + ops::BatchNormGradMaker, + ops::BatchNormCompositeGradOpMaker); REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp, diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h 
index da1daac8b8873..63a677fd2b99a 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -982,5 +982,168 @@ void dropout_grad(const Tensor& mask, } } } + +template +void batch_norm_grad(const Tensor& x, + const Tensor& scale, + const Tensor& bias, + const paddle::optional& mean_out, + const paddle::optional& variance_out, + const Tensor& saved_mean, + const Tensor& saved_variance, + const paddle::optional& reserve_space, + const Tensor& out_grad, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + Tensor* x_grad, + Tensor* scale_grad, + Tensor* bias_grad) { + use_global_stats = is_test || use_global_stats; + + DataLayout data_layout_ = phi::StringToDataLayout(data_layout); + + Tensor x_data = x; + Tensor out_grad_data = out_grad; + if (x.dtype() == phi::DataType::FLOAT16) { + x_data = cast(x, phi::DataType::FLOAT32); + } + if (out_grad.dtype() == phi::DataType::FLOAT16) { + out_grad_data = cast(out_grad, phi::DataType::FLOAT32); + } + auto x_dims = x_data.dims(); + const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + int nume = 1; + for (auto i = 0; i < x_dims.size(); i++) { + nume = nume * x_dims[i]; + } + + const int nhw = nume / C; + + if (x_dims.size() == 2 && data_layout_ == DataLayout::kNCHW) { + data_layout_ = DataLayout::kNHWC; + } + + auto run_var = variance_out.get(); + auto run_mean = mean_out.get(); + + Tensor mean_data; + Tensor rsqrt_var; + + if (use_global_stats) { + auto eps = + full(phi::vectorize(run_var.dims()), epsilon, run_var.dtype()); + mean_data = run_mean; + rsqrt_var = 1 / (run_var + eps).pow(0.5); + } else { + mean_data = saved_mean; + rsqrt_var = saved_variance; + } + + // inv_var = 1 / sqrt(var + eps) + // reduce_axis = [0, 2, 3] (NCHW) [0, 1, 2] (NHWC) + // + // d_bias = np.sum(d_y, reduce_axis) + // d_scale = np.sum((X - mean) / inv_var * dy, reduce_axis) + // + // train mode + // d_x = (1. 
/ nhw) * scale * inv_var + // *(nhw * d_y - np.sum(d_y, reduce_axis) - (X - mean) * inv_var * inv_var * + // np.sum(d_y * (X - mean), reduce_axis)) + // + // test mode + // d_x = d_y * scale * inv_var + + std::vector nchw_to_nhwc_dim = {0, 2, 3, 1}; + std::vector nhwc_to_nchw_dim = {0, 3, 1, 2}; + auto reduce_axis = IntArray(std::vector{0, 1, 2}); + auto dtype = x_data.dtype(); + + switch (data_layout_) { + case DataLayout::kNCHW: { + auto nhwc_x = transpose(x_data, nchw_to_nhwc_dim); + auto nhwc_out_grad = transpose(out_grad_data, nchw_to_nhwc_dim); + + auto x_sub_mean = nhwc_x - mean_data; + + if (x_grad) { + if (use_global_stats) { + auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad; + auto nchw_x_grad = transpose(nhwc_x_grad, nhwc_to_nchw_dim); + set_output(nchw_x_grad, x_grad); + } else { + auto part1 = scale * rsqrt_var; + auto mean_temp1 = + sum(nhwc_out_grad, reduce_axis, dtype, false) / nhw; + + auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw; + auto mean_temp2 = sum(tmp, reduce_axis, dtype, false); + auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2; + + auto x_grad_data = part1 * part2; + auto nchw_x_grad = transpose(x_grad_data, nhwc_to_nchw_dim); + if (x.dtype() == phi::DataType::FLOAT16) { + nchw_x_grad = cast(nchw_x_grad, x.dtype()); + } + set_output(nchw_x_grad, x_grad); + } + } + if (scale_grad) { + auto scale_grad_data = sum( + nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false); + set_output(scale_grad_data, scale_grad); + } + if (bias_grad) { + auto bias_grad_data = sum(nhwc_out_grad, reduce_axis, dtype, false); + set_output(bias_grad_data, bias_grad); + } + break; + } + case DataLayout::kNHWC: { + if (x_grad) { + auto x_sub_mean = x_data - mean_data; + if (use_global_stats) { + auto x_grad_data = scale * rsqrt_var * out_grad_data; + set_output(x_grad_data, x_grad); + } else { + auto part1 = scale * rsqrt_var; + auto mean_temp1 = + sum(out_grad_data, reduce_axis, dtype, false) / nhw; + + auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw; + auto mean_temp2 = sum(tmp, reduce_axis, dtype, false); + auto part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2; + + auto x_grad_data = part1 * part2; + if (x.dtype() == phi::DataType::FLOAT16) { + x_grad_data = cast(x_grad_data, x.dtype()); + } + set_output(x_grad_data, x_grad); + } + if (scale_grad) { + auto scale_grad_data = sum(out_grad_data * x_sub_mean * rsqrt_var, + reduce_axis, + dtype, + false); + set_output(scale_grad_data, scale_grad); + } + if (bias_grad) { + auto bias_grad_data = + sum(out_grad_data, reduce_axis, dtype, false); + set_output(bias_grad_data, bias_grad); + } + break; + } + } + default: + PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %s", + data_layout)); + } +} + } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f1978a6c970dd..d93a3544d9273 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -130,6 +130,7 @@ func : batch_norm_grad data_type : out_grad optional : mean_out, variance_out, reserve_space + composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics) backward : batch_norm_double_grad - backward_op : bce_loss_grad diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py 
b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py index 5bbeba860f590..5c79f8882619a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py @@ -427,6 +427,7 @@ def test_resnet(self): def test_resnet_composite(self): core._set_prim_backward_enabled(True) + core._add_skip_comp_ops("batch_norm") static_loss = self.train(to_static=True) core._set_prim_backward_enabled(False) dygraph_loss = self.train(to_static=False) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py index 2c5bc6f72e262..ed6dcd6a6823e 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py @@ -137,73 +137,97 @@ def expect_forward( ) +def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): + paddle.enable_static() + core._set_prim_all_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x1 = paddle.static.data( + 'x1', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x2 = paddle.static.data( + 'x2', shape=running_mean.shape, dtype=str(running_mean.dtype) + ) + x3 = paddle.static.data( + 'x3', + shape=running_variance.shape, + dtype=str(running_variance.dtype), + ) + x4 = paddle.static.data( + 'x4', shape=weight.shape, dtype=str(weight.dtype) + ) + x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype)) + y = fn( + x1, + x2, + x3, + x4, + x5, + attrs.training, + attrs.momentum, + attrs.epsilon, + attrs.data_format, + attrs.use_global_stats, + ) + blocks = main_program.blocks + + names = dict( + zip( + blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names + ) + ) + vars_list = [ + names[key] + for key in [ + "Y", + "MeanOut", + "VarianceOut", + "SavedMean", + "SavedVariance", + ] + ] + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that batch_norm in original block + assert 'batch_norm' in fwd_ops + + if mode: + primapi.to_prim(blocks) + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that batch_norm is splitted into small ops + assert 'batch_norm' not in fwd_ops_new + + exe = paddle.static.Executor() + exe.run(startup_program) + + # indeed SavedVariance is 1/sqrt(batch_var+eps) + Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run( + main_program, + feed={ + 'x1': inputs, + 'x2': running_mean, + 'x3': running_variance, + 'x4': weight, + 'x5': bias, + }, + fetch_list=vars_list, + ) + paddle.disable_static() + core._set_prim_all_enabled(False) + + return Y, MeanOut, VarianceOut, SavedMean, SavedVariance + + class TestCompositeBatchNorm(unittest.TestCase): def setUp(self): self.dtypes = ["float32", "float64"] self.training = [False, True] - self.shapes = [[8, 8, 16, 16], [2, 1, 2, 3]] + self.shapes = [[8, 8, 16, 16], [2, 3, 4, 4]] self.momentum = [0.1, 0.9] self.data_formats = ["NCHW", "NHWC"] self.use_global_stats = [None, True, False] - def cal_composite( - self, inputs, running_mean, running_variance, weight, bias - ): - paddle.enable_static() - core._set_prim_all_enabled(True) - startup_program = paddle.static.Program() - main_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - x1 = 
paddle.static.data( - 'x1', shape=inputs.shape, dtype=str(inputs.dtype) - ) - x2 = paddle.static.data( - 'x2', shape=running_mean.shape, dtype=str(running_mean.dtype) - ) - x3 = paddle.static.data( - 'x3', - shape=running_variance.shape, - dtype=str(running_variance.dtype), - ) - x4 = paddle.static.data( - 'x4', shape=weight.shape, dtype=str(weight.dtype) - ) - x5 = paddle.static.data( - 'x5', shape=bias.shape, dtype=str(bias.dtype) - ) - y = fn( - x1, - x2, - x3, - x4, - x5, - attrs.training, - attrs.momentum, - attrs.epsilon, - attrs.data_format, - attrs.use_global_stats, - ) - blocks = main_program.blocks - primapi.to_prim(blocks) - - exe = paddle.static.Executor() - exe.run(startup_program) - res = exe.run( - main_program, - feed={ - 'x1': inputs, - 'x2': running_mean, - 'x3': running_variance, - 'x4': weight, - 'x5': bias, - }, - fetch_list=[y], - ) - paddle.disable_static() - core._set_prim_all_enabled(False) - - return res - def compare_forward(self): np_data = generate_data(attrs.shape, attrs.dtype) tensor_data = paddle.to_tensor(np_data) @@ -234,32 +258,87 @@ def compare_forward(self): np_running_variance = np.ones(C, dtype=attrs.dtype) np_weight = np.ones(C, dtype=attrs.dtype) * 2 np_bias = np.ones(C, dtype=attrs.dtype) - actual = self.cal_composite( + res_origin = cal_static( np_data, np_running_mean, np_running_variance, np_weight, np_bias - )[0] - assert expect.dtype == actual.dtype + ) + res_prim = cal_static( + np_data, + np_running_mean, + np_running_variance, + np_weight, + np_bias, + mode="prim", + ) + + # prim out vs dygraph mode out + assert expect.dtype == res_prim[0].dtype np.testing.assert_allclose( expect, - actual, + res_prim[0], rtol=attrs.get_rtol("forward"), atol=attrs.get_atol("forward"), ) + # prim all outs vs origin static all outs + use_global_stats = attrs.use_global_stats + if use_global_stats is None: + use_global_stats = not attrs.training + trainable_statistics = False + else: + trainable_statistics = not use_global_stats + test_mode = (not attrs.training) and (not trainable_statistics) + + global_stats = test_mode or use_global_stats + vars_name = [ + "Y", + "MeanOut", + "VarianceOut", + "SavedMean", + "SavedVariance", + ] + + assert len(res_origin) == len(res_prim) + for idx in range(len(res_origin)): + if global_stats and idx >= 3: + # In this case saved_mean and saved_var are not expected. 
+ continue + origin_item = res_origin[idx] + prim_item = res_prim[idx] + + assert origin_item.dtype == prim_item.dtype + rtol = attrs.get_rtol("forward") + atol = attrs.get_atol("forward") + if attrs.dtype == "float64" and idx in (1, 2, 3): + atol = 1e-7 + rtol = 1e-7 + if not isinstance( + framework._current_expected_place(), core.CPUPlace + ) and idx in (2, 3): + atol = 5e-3 + rtol = 5e-3 + np.testing.assert_allclose( + origin_item, + prim_item, + rtol=rtol, + atol=atol, + err_msg=f"Check diff failed of output: {vars_name[idx]}", + ) + def test_forward(self): for i in self.training: for j in self.dtypes: - for m in self.momentum: + for k in self.use_global_stats: attrs.set_training(i) attrs.set_dtype(j) - attrs.set_momentum(m) + attrs.set_use_global_stats(k) self.compare_forward() for n in self.shapes: - for s in self.data_formats: - for t in self.use_global_stats: + for m in self.momentum: + for s in self.data_formats: + attrs.set_momentum(m) attrs.set_shape(n) attrs.set_data_format(s) - attrs.set_use_global_stats(t) self.compare_forward() diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py index 379ee30fb840c..32e83c4b2abe7 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py +++ b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py @@ -142,6 +142,7 @@ def setUpClass(cls): def test_prim(self): # todo: to be removed after adjust of rtol core._set_prim_forward_blacklist("batch_norm") + core._add_skip_comp_ops("batch_norm") dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False) # NOTE: Now dy2st is equal to dy2st_prim. With the splitting of kernels, the threshold here may need to be adjusted np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-6) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_batch_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_batch_norm_grad.py new file mode 100644 index 0000000000000..5fa9599097398 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_batch_norm_grad.py @@ -0,0 +1,261 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + +np.random.seed(2023) + + +class Arg: + dout = None + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = [8, 8, 16, 16] + self.training = True + self.momentum = 0.9 + self.epsilon = 1e-05 + self.data_format = "NCHW" + self.use_global_stats = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_training(self, training) -> None: + self.training = training + return + + def set_momentum(self, momentum) -> None: + self.momentum = momentum + return + + def set_epsilon(self, epsilon) -> None: + self.epsilon = epsilon + return + + def set_data_format(self, data_format) -> None: + self.data_format = data_format + return + + def set_use_global_stats(self, use_global_stats) -> None: + self.use_global_stats = use_global_stats + return + + +attrs = Attr() + + +def fn( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, +): + z = F.batch_norm( + x, + running_mean, + running_variance, + weight, + bias, + training=training, + momentum=momentum, + epsilon=epsilon, + data_format=data_format, + use_global_stats=use_global_stats, + ) + out = z * paddle.to_tensor(Arg.dout) + res = paddle.mean(out) + return res + + +def expect_grad( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, +): + x.stop_gradient = False + res = fn( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, + ) + gradients = paddle.grad(res, x) + return gradients + + +def cal_composite(inputs, running_mean, running_variance, weight, bias): + paddle.enable_static() + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x1 = paddle.static.data( + 'x1', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x1.stop_gradient = False + x2 = paddle.static.data( + 'x2', shape=running_mean.shape, dtype=str(running_mean.dtype) + ) + x3 = paddle.static.data( + 'x3', + shape=running_variance.shape, + dtype=str(running_variance.dtype), + ) + x4 = paddle.static.data( + 'x4', shape=weight.shape, dtype=str(weight.dtype) + ) + x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype)) + y = fn( + x1, + x2, + x3, + x4, + x5, + attrs.training, + attrs.momentum, + attrs.epsilon, + attrs.data_format, + attrs.use_global_stats, + ) + z = paddle.static.gradients([y], [x1]) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x1': inputs, + 'x2': running_mean, + 'x3': running_variance, + 'x4': weight, + 'x5': bias, + }, + fetch_list=[z], + ) + paddle.disable_static() + return res + + +class TestCompositeBatchNorm(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32"] + self.training = [False, True] + self.shapes = [[8, 8, 16, 16], [2, 1, 2, 3]] + self.momentum = [0.1, 0.9] + self.epsilon = [1e-05, 2e-05] + self.data_formats = ["NCHW"] + self.use_global_stats = [None, True, False] + + def compare_backward(self): + if attrs.training is True and attrs.use_global_stats is False: + # in this case, 
origin bn grad kernel is not the same as forward kernel. + return + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + Arg.dout = np.random.random(np_data.shape).astype(attrs.dtype) + C = np_data.shape[1] + + running_mean = paddle.zeros(C, dtype=attrs.dtype) + running_variance = paddle.ones(C, dtype=attrs.dtype) + weight = paddle.ones(C, dtype=attrs.dtype) * 2 + bias = paddle.ones(C, dtype=attrs.dtype) + + expect = expect_grad( + tensor_data, + running_mean, + running_variance, + weight, + bias, + attrs.training, + attrs.momentum, + attrs.epsilon, + attrs.data_format, + attrs.use_global_stats, + )[0].numpy() + np_running_mean = np.zeros(C, dtype=attrs.dtype) + np_running_variance = np.ones(C, dtype=attrs.dtype) + np_weight = np.ones(C, dtype=attrs.dtype) * 2 + np_bias = np.ones(C, dtype=attrs.dtype) + + actual = cal_composite( + np_data, np_running_mean, np_running_variance, np_weight, np_bias + )[0] + + assert expect.dtype == actual.dtype + + np.testing.assert_allclose( + expect, + actual, + rtol=1e-5, + atol=1e-5, + ) + + def test_backward_prim_dygraph_vjp(self): + core.set_prim_eager_enabled(True) + for i in self.training: + for j in self.dtypes: + for m in self.momentum: + attrs.set_training(i) + attrs.set_dtype(j) + attrs.set_momentum(m) + self.compare_backward() + + for n in self.shapes: + for t in self.use_global_stats: + attrs.set_shape(n) + attrs.set_use_global_stats(t) + self.compare_backward() + core.set_prim_eager_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_batch_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_batch_norm_grad.py new file mode 100644 index 0000000000000..5f084bab8617c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_batch_norm_grad.py @@ -0,0 +1,291 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core, framework + +np.random.seed(2023) + + +class Arg: + dout = None + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = [8, 8, 16, 16] + self.training = True + self.momentum = 0.9 + self.epsilon = 1e-05 + self.data_format = "NCHW" + self.use_global_stats = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_training(self, training) -> None: + self.training = training + return + + def set_momentum(self, momentum) -> None: + self.momentum = momentum + return + + def set_epsilon(self, epsilon) -> None: + self.epsilon = epsilon + return + + def set_data_format(self, data_format) -> None: + self.data_format = data_format + return + + def set_use_global_stats(self, use_global_stats) -> None: + self.use_global_stats = use_global_stats + return + + +attrs = Attr() + + +def fn( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, +): + z = F.batch_norm( + x, + running_mean, + running_variance, + weight, + bias, + training=training, + momentum=momentum, + epsilon=epsilon, + data_format=data_format, + use_global_stats=use_global_stats, + ) + out = z * paddle.to_tensor(Arg.dout) + res = paddle.mean(out) + return res + + +def expect_grad( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, +): + x.stop_gradient = False + weight.stop_gradient = False + bias.stop_gradient = False + res = fn( + x, + running_mean, + running_variance, + weight, + bias, + training, + momentum, + epsilon, + data_format, + use_global_stats, + ) + gradients = paddle.grad(res, (x, weight, bias)) + return gradients + + +def cal_composite(inputs, running_mean, running_variance, weight, bias): + paddle.enable_static() + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x1 = paddle.static.data( + 'x1', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x1.stop_gradient = False + x2 = paddle.static.data( + 'x2', shape=running_mean.shape, dtype=str(running_mean.dtype) + ) + x3 = paddle.static.data( + 'x3', + shape=running_variance.shape, + dtype=str(running_variance.dtype), + ) + x4 = paddle.static.data( + 'x4', shape=weight.shape, dtype=str(weight.dtype) + ) + x4.stop_gradient = False + x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype)) + x5.stop_gradient = False + y = fn( + x1, + x2, + x3, + x4, + x5, + attrs.training, + attrs.momentum, + attrs.epsilon, + attrs.data_format, + attrs.use_global_stats, + ) + blocks = main_program.blocks + paddle.incubate.autograd.primapi.to_prim(blocks) + z = paddle.static.gradients([y], [x1, x4, x5]) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run( + main_program, + feed={ + 'x1': inputs, + 'x2': running_mean, + 'x3': running_variance, + 'x4': weight, + 'x5': bias, + }, + fetch_list=[z], + ) + paddle.disable_static() + return res + + +class TestCompositeBatchNorm(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32", "float64"] + self.training = [False, True] + self.shapes = [[8, 8, 16, 16], [2, 4, 3, 3]] + self.momentum 
= [0.1, 0.9] + self.epsilon = [1e-05, 2e-05] + self.data_formats = ["NCHW", "NHWC"] + self.use_global_stats = [None, True, False] + + def compare_backward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + Arg.dout = np.random.random(np_data.shape).astype(attrs.dtype) + if attrs.data_format == 'NCHW': + C = np_data.shape[1] + elif attrs.data_format == 'NHWC': + C = np_data.shape[-1] + else: + raise TypeError + + running_mean = paddle.zeros(C, dtype=attrs.dtype) + running_variance = paddle.ones(C, dtype=attrs.dtype) + weight = paddle.ones(C, dtype=attrs.dtype) * 2 + bias = paddle.ones(C, dtype=attrs.dtype) + + res_origin = expect_grad( + tensor_data, + running_mean, + running_variance, + weight, + bias, + attrs.training, + attrs.momentum, + attrs.epsilon, + attrs.data_format, + attrs.use_global_stats, + ) + np_running_mean = np.zeros(C, dtype=attrs.dtype) + np_running_variance = np.ones(C, dtype=attrs.dtype) + np_weight = np.ones(C, dtype=attrs.dtype) * 2 + np_bias = np.ones(C, dtype=attrs.dtype) + + res_prim = cal_composite( + np_data, np_running_mean, np_running_variance, np_weight, np_bias + ) + + vars_name = ["x_grad", "weight_grad", "bias_grad"] + assert len(res_origin) == len(res_prim) + for idx in range(len(res_origin)): + origin_item = res_origin[idx].numpy() + prim_item = res_prim[idx] + assert origin_item.dtype == prim_item.dtype + rtol = 1e-5 + atol = 1e-5 + if ( + not isinstance( + framework._current_expected_place(), core.CPUPlace + ) + and attrs.data_format == "NHWC" + ): + rtol = 1e-4 + atol = 1e-4 + if idx in (1, 2): + continue + + np.testing.assert_allclose( + origin_item, + prim_item, + rtol=rtol, + atol=atol, + err_msg=f"Check diff failed of output: {vars_name[idx]} with data_format: {attrs.data_format}", + ) + + def test_backward_prim_static_vjp(self): + core._set_prim_backward_enabled(True) + for i in self.training: + for j in self.dtypes: + for k in self.data_formats: + for m in self.momentum: + attrs.set_training(i) + attrs.set_dtype(j) + attrs.set_data_format(k) + attrs.set_momentum(m) + self.compare_backward() + + for s in self.training: + for n in self.shapes: + for t in self.use_global_stats: + attrs.set_training(s) + attrs.set_shape(n) + attrs.set_use_global_stats(t) + self.compare_backward() + core._set_prim_backward_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py index fec2375f7649b..85517ffeba2d0 100644 --- a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py @@ -134,6 +134,7 @@ def test_check_resnet50_accuracy_with_composite(self): loop_num = 10 feed = self.generate_random_data(loop_num) core._set_prim_backward_enabled(True) + core._add_skip_comp_ops("batch_norm") loss_c = self.train(place, loop_num, feed, use_cinn=True) core._set_prim_backward_enabled(False) loss_p = self.train(place, loop_num, feed, use_cinn=True) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index ce70f42df25a5..d31a0185f9270 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -109,16 +109,19 @@ def composite_batchnorm( if is_amp: y = cast(y, "float16") + # As the same with op kernel, indeed return inverse std + inv_std = 1.0 / sqrt(batch_var + epsilon) + # add op 
assign to detach tensor to avoid unsafe changes outside the rule. batch_mean_ = assign(reshape(batch_mean, run_mean.shape)) - batch_var_ = assign(reshape(batch_var, run_var.shape)) + inv_std_ = assign(reshape(inv_std, run_var.shape)) run_mean_ = assign(run_mean) run_var_ = assign(run_var) # reserve_space is not needed in the composite rule, but still return None to keep the outputs the same as the phi op definition. reserve_space = None - return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space + return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space @REGISTER_COMPOSITE('layer_norm')
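For reference, the training-mode gradient that the composite batch_norm_grad in composite_backward_api.h decomposes into primitive ops can be written compactly in NumPy. The sketch below is illustrative only and not part of the patch: the function name np_batch_norm_grad and its argument names are made up, it assumes NCHW layout with use_global_stats disabled, and, matching the composite forward rule above, inv_std is 1 / sqrt(batch_var + epsilon), i.e. the quantity now stored in SavedVariance.

import numpy as np

def np_batch_norm_grad(x, d_y, scale, batch_mean, inv_std):
    # x, d_y: (N, C, H, W); scale, batch_mean, inv_std: (C,)
    c = x.shape[1]
    axis = (0, 2, 3)  # reduce over N, H, W, one value per channel
    xc = x - batch_mean.reshape(1, c, 1, 1)
    istd = inv_std.reshape(1, c, 1, 1)
    d_bias = d_y.sum(axis)
    d_scale = (d_y * xc * istd).sum(axis)
    # d_x = scale * inv_std * (d_y - mean(d_y) - (x - mean) * inv_std**2 * mean(d_y * (x - mean)))
    d_x = scale.reshape(1, c, 1, 1) * istd * (
        d_y
        - d_y.mean(axis, keepdims=True)
        - xc * istd * istd * (d_y * xc).mean(axis, keepdims=True)
    )
    return d_x, d_scale, d_bias

# Example usage with batch statistics computed from x itself:
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype("float32")
d_y = rng.standard_normal(x.shape).astype("float32")
scale = np.ones(3, dtype="float32")
mean = x.mean(axis=(0, 2, 3))
inv_std = 1.0 / np.sqrt(x.var(axis=(0, 2, 3)) + 1e-5)
d_x, d_scale, d_bias = np_batch_norm_grad(x, d_y, scale, mean, inv_std)

In the use_global_stats / test-mode branch the same expression collapses to d_x = d_y * scale / sqrt(running_var + epsilon), with d_scale and d_bias reduced over the same axes, which is what the use_global_stats branches of the kNCHW and kNHWC cases above compute.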