[Prim] support batch_norm vjp #51283

Merged · 10 commits · Mar 17, 2023
71 changes: 70 additions & 1 deletion paddle/fluid/operators/batch_norm_op.cc
@@ -23,6 +23,10 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/multiary.h"

@@ -534,6 +538,70 @@ phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType(
ctx.GetPlace());
}

class BatchNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

public:
void Apply() override {
// inputs and outputs of batch_norm
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor scale = this->GetSingleForwardInput("Scale");
paddle::Tensor bias = this->GetSingleForwardInput("Bias");
paddle::Tensor mean = this->GetSingleForwardInput("Mean");
paddle::Tensor variance = this->GetSingleForwardInput("Variance");
paddle::Tensor y = this->GetSingleForwardOutput("Y");
paddle::Tensor mean_out = this->GetSingleForwardOutput("MeanOut");
paddle::Tensor variance_out = this->GetSingleForwardOutput("VarianceOut");
paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
paddle::Tensor saved_variance =
this->GetSingleForwardOutput("SavedVariance");
paddle::optional<paddle::Tensor> reserve_space;

paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
paddle::Tensor x_grad = this->GetSingleInputGrad("X");
paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");

auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dscale_ptr = this->GetOutputPtr(&scale_grad);
std::string dscale_name = this->GetOutputName(scale_grad);
auto dbias_ptr = this->GetOutputPtr(&bias_grad);
std::string dbias_name = this->GetOutputName(bias_grad);

// attrs of batch_norm
auto momentum = this->Attr<float>("momentum");
auto epsilon = this->Attr<float>("epsilon");
auto data_layout = this->Attr<std::string>("data_layout");
auto is_test = this->Attr<bool>("is_test");
auto use_global_stats = this->Attr<bool>("use_global_stats");
auto trainable_statistics = this->Attr<bool>("trainable_statistics");

VLOG(3) << "Running batch_norm composite func";
prim::batch_norm_grad<prim::DescTensor>(x,
scale,
bias,
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space,
y_grad,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
dx_ptr,
dscale_ptr,
dbias_ptr);
this->RecoverOutputName(x_grad, dx_name);
this->RecoverOutputName(scale_grad, dscale_name);
this->RecoverOutputName(bias_grad, dbias_name);
}
};

DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});

} // namespace operators
@@ -550,7 +618,8 @@ REGISTER_OPERATOR(batch_norm,
ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType,
ops::BatchNormGradMaker<paddle::framework::OpDesc>,
ops::BatchNormGradMaker<paddle::imperative::OpBase>);
ops::BatchNormGradMaker<paddle::imperative::OpBase>,
ops::BatchNormCompositeGradOpMaker);

REGISTER_OPERATOR(batch_norm_grad,
ops::BatchNormGradOp,
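The maker above hooks a composite (primitive-op) backward into static-graph gradient construction, alongside the existing BatchNormGradMaker. A minimal Python sketch of how the decomposition can be observed — assuming `core` is `paddle.fluid.core` (the same private switch the updated test below uses) and that the flag is honored when paddle.static.append_backward builds the backward block:

import paddle
from paddle.fluid import core

paddle.enable_static()
core._set_prim_backward_enabled(True)  # turn on composite (prim) backward

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [8, 3, 32, 32], 'float32')
    x.stop_gradient = False
    y = paddle.nn.BatchNorm2D(3)(x)
    loss = paddle.mean(y)
    paddle.static.append_backward(loss)

# With BatchNormCompositeGradOpMaker registered, batch_norm_grad is expected
# to be replaced by primitive ops in the backward block.
print(any(op.type == 'batch_norm_grad' for op in main.block(0).ops))

core._set_prim_backward_enabled(False)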
163 changes: 163 additions & 0 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -982,5 +982,168 @@ void dropout_grad(const Tensor& mask,
}
}
}

template <typename T>
void batch_norm_grad(const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const paddle::optional<Tensor>& mean_out,
const paddle::optional<Tensor>& variance_out,
const Tensor& saved_mean,
const Tensor& saved_variance,
const paddle::optional<Tensor>& reserve_space,
const Tensor& out_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
Tensor* x_grad,
Tensor* scale_grad,
Tensor* bias_grad) {
use_global_stats = is_test || use_global_stats;

DataLayout data_layout_ = phi::StringToDataLayout(data_layout);

Tensor x_data = x;
Tensor out_grad_data = out_grad;
if (x.dtype() == phi::DataType::FLOAT16) {
x_data = cast<T>(x, phi::DataType::FLOAT32);
}
if (out_grad.dtype() == phi::DataType::FLOAT16) {
out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
}
auto x_dims = x_data.dims();
const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
int nume = 1;
for (auto i = 0; i < x_dims.size(); i++) {
nume = nume * x_dims[i];
}

const int nhw = nume / C;

if (x_dims.size() == 2 && data_layout_ == DataLayout::kNCHW) {
data_layout_ = DataLayout::kNHWC;
}

auto run_var = variance_out.get();
auto run_mean = mean_out.get();

Tensor mean_data;
Tensor rsqrt_var;

if (use_global_stats) {
auto eps =
full<T>(phi::vectorize(run_var.dims()), epsilon, run_var.dtype());
mean_data = run_mean;
rsqrt_var = 1 / (run_var + eps).pow(0.5);
} else {
mean_data = saved_mean;
rsqrt_var = saved_variance;
}

// rsqrt_var = 1 / sqrt(var + eps)
// reduce_axis = [0, 2, 3] (NCHW), [0, 1, 2] (NHWC)
//
// d_bias = np.sum(d_y, reduce_axis)
// d_scale = np.sum((X - mean) * rsqrt_var * d_y, reduce_axis)
//
// train mode
// d_x = (1. / nhw) * scale * rsqrt_var *
//       (nhw * d_y - np.sum(d_y, reduce_axis) -
//        (X - mean) * rsqrt_var * rsqrt_var * np.sum(d_y * (X - mean), reduce_axis))
//
// test / use_global_stats mode
// d_x = d_y * scale * rsqrt_var
std::vector<int> nchw_to_nhwc_dim = {0, 2, 3, 1};
std::vector<int> nhwc_to_nchw_dim = {0, 3, 1, 2};
auto reduce_axis = IntArray(std::vector<int>{0, 1, 2});
auto dtype = x_data.dtype();

switch (data_layout_) {
case DataLayout::kNCHW: {
auto nhwc_x = transpose<T>(x_data, nchw_to_nhwc_dim);
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);

auto x_sub_mean = nhwc_x - mean_data;

if (x_grad) {
if (use_global_stats) {
auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
set_output<T>(nchw_x_grad, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(nhwc_out_grad, reduce_axis, dtype, false) / nhw;

auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;

auto x_grad_data = part1 * part2;
auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
if (x.dtype() == phi::DataType::FLOAT16) {
nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
}
set_output<T>(nchw_x_grad, x_grad);
}
}
if (scale_grad) {
auto scale_grad_data = sum<T>(
nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false);
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
}
break;
}
case DataLayout::kNHWC: {
auto x_sub_mean = x_data - mean_data;
if (x_grad) {
if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data;
set_output<T>(x_grad_data, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(out_grad_data, reduce_axis, dtype, false) / nhw;

auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;

auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_data = cast<T>(x_grad_data, x.dtype());
}
set_output<T>(x_grad_data, x_grad);
}
}
if (scale_grad) {
auto scale_grad_data = sum<T>(out_grad_data * x_sub_mean * rsqrt_var,
reduce_axis,
dtype,
false);
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data =
sum<T>(out_grad_data, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
}
break;
}
default:
PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %s",
data_layout));
}
}

} // namespace prim
} // namespace paddle
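In standard notation, with the per-channel reduction taken over the N, H, W axes (nhw elements per channel) and \hat{\sigma}^{-1} = 1/\sqrt{\mathrm{var} + \epsilon}, the decomposition above implements

\frac{\partial L}{\partial \beta} = \sum dy, \qquad \frac{\partial L}{\partial \gamma} = \sum dy\,(x - \mu)\,\hat{\sigma}^{-1},

\frac{\partial L}{\partial x} = \gamma\,\hat{\sigma}^{-1}\Big(dy - \operatorname{mean}(dy) - (x - \mu)\,\hat{\sigma}^{-2}\,\operatorname{mean}\big(dy\,(x - \mu)\big)\Big)

in training mode, and \frac{\partial L}{\partial x} = \gamma\,\hat{\sigma}^{-1}\,dy when is_test or use_global_stats is set (with \mu and var taken from the running statistics).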
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -130,6 +130,7 @@
func : batch_norm_grad
data_type : out_grad
optional : mean_out, variance_out, reserve_space
composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics)
Review comment (Contributor): add grad outputs in

backward : batch_norm_double_grad

- backward_op : bce_loss_grad
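For reference, the composite rule registered above corresponds to the following NumPy sketch (NHWC layout, training mode; function and argument names are illustrative only — the code treats SavedVariance as the already-computed 1/sqrt(var + eps)):

import numpy as np

def batch_norm_grad_ref(x, scale, saved_mean, saved_inv_std, dy):
    # x, dy: (N, H, W, C); scale, saved_mean, saved_inv_std: (C,)
    axis = (0, 1, 2)
    nhw = x.shape[0] * x.shape[1] * x.shape[2]
    x_sub_mean = x - saved_mean

    d_bias = dy.sum(axis=axis)
    d_scale = (dy * x_sub_mean * saved_inv_std).sum(axis=axis)

    mean_dy = dy.sum(axis=axis) / nhw
    mean_dy_xmu = (dy * x_sub_mean * saved_inv_std ** 2).sum(axis=axis) / nhw
    d_x = scale * saved_inv_std * (dy - mean_dy - x_sub_mean * mean_dy_xmu)
    return d_x, d_scale, d_bias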
@@ -427,6 +427,7 @@ def test_resnet(self):

def test_resnet_composite(self):
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")
static_loss = self.train(to_static=True)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=False)
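The skip list added here keeps batch_norm out of the decomposition for this ResNet case while the prim flag is on. A minimal sketch of the toggling, using only the private helpers the test itself calls (these APIs may change):

from paddle.fluid import core

# Decompose backward ops into primitives, but leave batch_norm_grad intact.
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")

# ... build and run the static-graph training here, as test_resnet_composite does ...

core._set_prim_backward_enabled(False)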