diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 8bde3b3deb8db..2a4af31744640 100755
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -12,13 +12,11 @@ register_operators(
   fused_feedforward_op
   fused_multi_transformer_op
   fused_multi_transformer_int8_op
-  resnet_unit_op
   fused_gemm_epilogue_op
   resnet_basic_block_op)

 if(WITH_XPU)
   op_library(resnet_basic_block_op)
-  op_library(resnet_unit_op)
   op_library(fused_gemm_epilogue_op)
   op_library(fused_attention_op)
   op_library(fused_feedforward_op)
@@ -46,10 +44,6 @@ if(WITH_GPU OR WITH_ROCM)
     op_library(fused_multi_transformer_op)
     op_library(fused_multi_transformer_int8_op)
   endif()
-  # resnet_unit needs cudnn 8.0 above
-  if((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
-    op_library(resnet_unit_op)
-  endif()

   if(CUDA_VERSION GREATER_EQUAL 11.6)
     op_library(fused_gemm_epilogue_op)
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc
deleted file mode 100644
index 9f8b8d0744ffe..0000000000000
--- a/paddle/fluid/operators/fused/resnet_unit_op.cc
+++ /dev/null
@@ -1,461 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/common/float16.h"
-
-namespace paddle::operators {
-
-// Shape of bitmask
-static phi::DDim GetBitmaskDims(std::vector<int> out_shape) {
-  int c = out_shape.back();
-  int64_t nhw = std::accumulate(out_shape.begin(),
-                                out_shape.end(),
-                                1,
-                                std::multiplies<int64_t>()) /  // NOLINT
-                c;
-  int32_t c_int32_elems = ((c + 63) & ~63) / 32;
-  int32_t nhw_int32_elems = static_cast<int32_t>(((nhw + 31) & ~31));
-  std::vector<int> bitmask_shape = {nhw_int32_elems, c_int32_elems, 1};
-  return common::make_ddim(bitmask_shape);
-}
-
-class ResNetUnitOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    // Check input
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitOp");
-    OP_INOUT_CHECK(
-        ctx->HasInput("FilterX"), "Input", "FilterX", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp");
-
-    bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
-    bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
-    if (fuse_add || has_shortcut) {
-      OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp");
-    }
-    if (has_shortcut) {
-      OP_INOUT_CHECK(
-          ctx->HasInput("FilterZ"), "Input", "FilterZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(
-          ctx->HasInput("ScaleZ"), "Input", "ScaleZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasInput("MeanZ"), "Input", "MeanZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasInput("VarZ"), "Input", "VarZ", "ResNetUnitOp");
-    }
-
-    // Check output
-    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "ResNetUnitOp");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("BitMask"), "Output", "BitMask", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasOutput("ConvX"), "Output", "ConvX", "ResNetUnitOp");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("SavedMeanX"), "Output", "SavedMeanX", "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdX"),
-                   "Output",
-                   "SavedInvstdX",
-                   "ResNetUnitOp");
-    OP_INOUT_CHECK(ctx->HasOutput("RunningMeanX"),
-                   "Output",
-                   "RunningMeanX",
-                   "ResNetUnitOp");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("RunningVarX"), "Output", "RunningVarX", "ResNetUnitOp");
-    if (has_shortcut) {
-      OP_INOUT_CHECK(
-          ctx->HasOutput("ConvZ"), "Output", "ConvZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(
-          ctx->HasOutput("SavedMeanZ"), "Output", "SavedMeanZ", "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdZ"),
-                     "Output",
-                     "SavedInvstdZ",
-                     "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasOutput("RunningMeanZ"),
-                     "Output",
-                     "RunningMeanZ",
-                     "ResNetUnitOp");
-      OP_INOUT_CHECK(ctx->HasOutput("RunningVarZ"),
-                     "Output",
-                     "RunningVarZ",
-                     "ResNetUnitOp");
-    }
-
-    // make sure Mean/RunningMean and Var/RunningVar share memory
-    PADDLE_ENFORCE_EQ(
-        ctx->Inputs("MeanX")[0],
-        ctx->Outputs("RunningMeanX")[0],
-        phi::errors::InvalidArgument(
-            "MeanX and RunningMeanX should share the same memory"));
-    PADDLE_ENFORCE_EQ(ctx->Inputs("VarX")[0],
-                      ctx->Outputs("RunningVarX")[0],
-                      phi::errors::InvalidArgument(
-                          "VarX and RunningVarX should share the same memory"));
-    if (has_shortcut) {
-      PADDLE_ENFORCE_EQ(
-          ctx->Inputs("MeanZ")[0],
-          ctx->Outputs("RunningMeanZ")[0],
-          phi::errors::InvalidArgument(
-              "MeanZ and RunningMeanZ should share the same memory"));
-      PADDLE_ENFORCE_EQ(
-          ctx->Inputs("VarZ")[0],
-          ctx->Outputs("RunningVarZ")[0],
-          phi::errors::InvalidArgument(
-              "VarZ and RunningVarZ should share the same memory"));
-    }
-
-    // Check dims of inputs
-    const auto x_dims = ctx->GetInputDim("X");
-    const auto w_dims = ctx->GetInputDim("FilterX");
-    std::vector<int64_t> bn_param_shape =
-        common::vectorize(ctx->GetInputDim("ScaleX"));
-    if (1 == bn_param_shape.size()) {
-      bn_param_shape = {1, 1, 1, bn_param_shape[0]};
-    }
-    phi::DDim bn_param_dims = common::make_ddim(bn_param_shape);
-    PADDLE_ENFORCE_EQ(
-        x_dims.size(),
-        4,
-        phi::errors::InvalidArgument("The dimensions of input "
-                                     "must equal to 4."
-                                     "But received: the shape of input "
-                                     "= [%s], the dimension of input = "
-                                     "[%d]",
-                                     x_dims,
-                                     x_dims.size()));
-    PADDLE_ENFORCE_EQ(
-        w_dims.size(),
-        4,
-        phi::errors::InvalidArgument("The dimensions of filter "
-                                     "must equal to 4."
-                                     "But received: the shape of filter "
-                                     "= [%s], the dimension of filter = [%d] ",
-                                     w_dims,
-                                     w_dims.size()));
-    PADDLE_ENFORCE_EQ(bn_param_dims.size(),
-                      4,
-                      phi::errors::InvalidArgument(
-                          "The dimensions of bn param "
-                          "must equal to 4."
-                          "But received: the shape of bn param "
-                          "= [%s], the dimension of bn param = [%d] ",
-                          bn_param_dims,
-                          bn_param_dims.size()));
-    auto data_format = ctx->Attrs().Get<std::string>("data_format");
-    bool is_nchw = (data_format == "NCHW");
-    // Calculate the dims of outputs
-    int batch = x_dims[0];
-    int output_channel = w_dims[0];
-    int filter_size = w_dims[2];
-    int stride = ctx->Attrs().Get<int>("stride");
-    int padding = ctx->Attrs().Get<int>("padding");
-    std::vector<int> out_shape;
-    out_shape.push_back(batch);
-    if (is_nchw) {
-      int out_h = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
-      int out_w = (x_dims[3] + padding * 2 - filter_size) / stride + 1;
-      out_shape.push_back(output_channel);
-      out_shape.push_back(out_h);
-      out_shape.push_back(out_w);
-    } else {
-      int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1;
-      int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
-      out_shape.push_back(out_h);
-      out_shape.push_back(out_w);
-      out_shape.push_back(output_channel);
-    }
-
-    auto y_dims = common::make_ddim(out_shape);
-    auto bitmask_dims = GetBitmaskDims(out_shape);
-    // Set dims of outputs
-    ctx->SetOutputDim("Y", y_dims);
-    ctx->SetOutputDim("BitMask", bitmask_dims);
-    ctx->SetOutputDim("ConvX", y_dims);
-    ctx->SetOutputDim("SavedMeanX", bn_param_dims);
-    ctx->SetOutputDim("SavedInvstdX", bn_param_dims);
-    ctx->SetOutputDim("RunningMeanX", bn_param_dims);
-    ctx->SetOutputDim("RunningVarX", bn_param_dims);
-    if (has_shortcut) {
-      ctx->SetOutputDim("ConvZ", y_dims);
-      ctx->SetOutputDim("SavedMeanZ", bn_param_dims);
-      ctx->SetOutputDim("SavedInvstdZ", bn_param_dims);
-      ctx->SetOutputDim("RunningMeanZ", bn_param_dims);
-      ctx->SetOutputDim("RunningVarZ", bn_param_dims);
-    }
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    // By default, the type of the scale, bias, mean,
-    // and var tensors should be float when input tensor's dtype is float16.
- auto bn_param_type = phi::DataType::FLOAT32; - - PADDLE_ENFORCE_EQ( - bn_param_type, - ctx.Input("ScaleX")->dtype(), - phi::errors::InvalidArgument("Scale input should be of float type")); - PADDLE_ENFORCE_EQ( - bn_param_type, - ctx.Input("BiasX")->dtype(), - phi::errors::InvalidArgument("Bias input should be of float type")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The input 1 tensor"); - AddInput("FilterX", "Filter tensor of input 1"); - AddInput("ScaleX", "Scale tensor of input 1 used in batchnorm"); - AddInput("BiasX", "Bias tensor of input 1 used in batchnorm"); - AddInput("MeanX", "Mean tensor of input 1 used in batchnorm"); - AddInput("VarX", "Variance tensor of input 1 used in batchnorm"); - AddInput("Z", "The input 2 tensor").AsDispensable(); - AddInput("FilterZ", "Filter tensor of input 2").AsDispensable(); - AddInput("ScaleZ", "Scale tensor of input 2").AsDispensable(); - AddInput("BiasZ", "Bias tensor of input 2").AsDispensable(); - AddInput("MeanZ", "Mean tensor of input 2").AsDispensable(); - AddInput("VarZ", "Variance tensor of input 2").AsDispensable(); - AddOutput("Y", "The result of the resnet unit"); - AddOutput("BitMask", "The bitmask generated after relu"); - AddOutput("ConvX", "The output of input 1 after conv"); - AddOutput("SavedMeanX", "Mean of input 1 in the current batch"); - AddOutput("SavedInvstdX", "Invstd of input 1 in the current batch"); - AddOutput("RunningMeanX", "Shared memory with MeanX"); - AddOutput("RunningVarX", "Shared memory with VarX"); - AddOutput("ConvZ", "The output of input 2 after conv").AsDispensable(); - AddOutput("SavedMeanZ", "Mean of input 1 in the current batch") - .AsDispensable(); - AddOutput("SavedInvstdZ", "Invstd of input 1 in the current batch") - .AsDispensable(); - AddOutput("RunningMeanZ", "Shared memory with MeanZ").AsDispensable(); - AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable(); - AddAttr("stride", "").SetDefault(1); - AddAttr("stride_z", "").SetDefault(1); - AddAttr("padding", "").SetDefault(0); - AddAttr("dilation", "").SetDefault(1); - AddAttr("group", "").SetDefault(1); - AddAttr("momentum", "").SetDefault(0.9); - AddAttr("epsilon", "").SetDefault(1e-5); - AddAttr("data_format", "").SetDefault("NHWC"); - AddAttr("fuse_add", "").SetDefault(false); - AddAttr("has_shortcut", "").SetDefault(false); - AddAttr("use_global_stats", "").SetDefault(false); - AddAttr("is_test", - "(bool, default false) Set to true for inference only, false " - "for training. Some layers may run faster when this is true.") - .SetDefault(false); - AddAttr("use_addto", "").SetDefault(false); - AddAttr("act_type", "The activation type to be fused.") - .SetDefault("relu"); - AddComment(R"DOC( -Fusion op of the basic unit of resnet block. - -The implementation is based on the latest fusion op interface in cuDNN v8.0. 
-For more details: -https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t - -)DOC"); - } -}; - -class ResNetUnitGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - // check input - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("FilterX"), "Input", "FilterX", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("ConvX"), "Input", "ConvX", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("SavedMeanX"), "Input", "SavedMeanX", "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), - "Input", - "SavedInvstdX", - "ResNetUnitGradOp"); - - bool fuse_add = ctx->Attrs().Get("fuse_add"); - bool has_shortcut = ctx->Attrs().Get("has_shortcut"); - if (fuse_add || has_shortcut) { - OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp"); - } - if (has_shortcut) { - OP_INOUT_CHECK( - ctx->HasInput("FilterZ"), "Input", "FilterZ", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("ConvZ"), "Input", "ConvZ", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("ScaleZ"), "Input", "ScaleZ", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasInput("SavedMeanZ"), - "Input", - "SavedMeanZ", - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasInput("SavedInvstdZ"), - "Input", - "SavedInvstdZ", - "ResNetUnitGradOp"); - } - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ResNetUnitGradOp"); - OP_INOUT_CHECK( - ctx->HasInput("BitMask"), "Input", "BitMask", "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), - "Input", - framework::GradVarName("Y"), - "ResNetUnitGradOp"); - - // check output - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - framework::GradVarName("X"), - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterX")), - "Output", - framework::GradVarName("FilterX"), - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleX")), - "Output", - framework::GradVarName("ScaleX"), - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), - "Output", - framework::GradVarName("BiasX"), - "ResNetUnitGradOp"); - if (fuse_add) { - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), - "Output", - framework::GradVarName("Z"), - "ResNetUnitGradOp"); - } - if (has_shortcut) { - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterZ")), - "Output", - framework::GradVarName("FilterZ"), - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleZ")), - "Output", - framework::GradVarName("ScaleZ"), - "ResNetUnitGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasZ")), - "Output", - framework::GradVarName("BiasZ"), - "ResNetUnitGradOp"); - } - const auto x_dims = ctx->GetInputDim("X"); - const auto filter_x_dims = ctx->GetInputDim("FilterX"); - const auto param_dims = ctx->GetInputDim("ScaleX"); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims); - ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims); - 
ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims); - if (fuse_add || has_shortcut) { - const auto z_dims = ctx->GetInputDim("Z"); - ctx->SetOutputDim(framework::GradVarName("Z"), z_dims); - } - if (has_shortcut) { - const auto filter_z_dims = ctx->GetInputDim("FilterZ"); - ctx->SetOutputDim(framework::GradVarName("FilterZ"), filter_z_dims); - ctx->SetOutputDim(framework::GradVarName("ScaleZ"), param_dims); - ctx->SetOutputDim(framework::GradVarName("BiasZ"), param_dims); - } - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL( - ctx.InputVar(framework::GradVarName("Y")), - phi::errors::NotFound("Can not find Y@GRAD in the execution context.")); - - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); - } -}; - -template -class ResNetUnitGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("resnet_unit_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput("FilterX", this->Input("FilterX")); - op->SetInput("ConvX", this->Output("ConvX")); - op->SetInput("ScaleX", this->Input("ScaleX")); - op->SetInput("BiasX", this->Input("BiasX")); - op->SetInput("SavedMeanX", this->Output("SavedMeanX")); - op->SetInput("SavedInvstdX", this->Output("SavedInvstdX")); - op->SetInput("Z", this->Input("Z")); - op->SetInput("FilterZ", this->Input("FilterZ")); - op->SetInput("ConvZ", this->Output("ConvZ")); - op->SetInput("ScaleZ", this->Input("ScaleZ")); - op->SetInput("BiasZ", this->Input("BiasZ")); - op->SetInput("SavedMeanZ", this->Output("SavedMeanZ")); - op->SetInput("SavedInvstdZ", this->Output("SavedInvstdZ")); - op->SetInput("Y", this->Output("Y")); - op->SetInput("BitMask", this->Output("BitMask")); - op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); - - op->SetAttrMap(this->Attrs()); - - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("FilterX"), - this->InputGrad("FilterX")); - op->SetOutput(framework::GradVarName("ScaleX"), this->InputGrad("ScaleX")); - op->SetOutput(framework::GradVarName("BiasX"), this->InputGrad("BiasX")); - op->SetOutput(framework::GradVarName("Z"), this->InputGrad("Z")); - op->SetOutput(framework::GradVarName("FilterZ"), - this->InputGrad("FilterZ")); - op->SetOutput(framework::GradVarName("ScaleZ"), this->InputGrad("ScaleZ")); - op->SetOutput(framework::GradVarName("BiasZ"), this->InputGrad("BiasZ")); - } -}; - -class ResNetUnitOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map& GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Y"}}; - return m; - } -}; - -} // namespace paddle::operators - -namespace ops = paddle::operators; -REGISTER_OPERATOR(resnet_unit, - ops::ResNetUnitOp, - ops::ResNetUnitOpMaker, - ops::ResNetUnitOpInferVarType, - ops::ResNetUnitGradOpMaker, - ops::ResNetUnitGradOpMaker); -REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp); diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu deleted file mode 100644 index cbf9b1ce2e517..0000000000000 --- a/paddle/fluid/operators/fused/resnet_unit_op.cu +++ /dev/null @@ -1,429 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/kernels/fusion/gpu/cudnn_bn_stats_finalize.cu.h" -#include "paddle/phi/kernels/fusion/gpu/cudnn_norm_conv.cu.h" -#include "paddle/phi/kernels/fusion/gpu/cudnn_scale_bias_add_relu.cu.h" - -namespace paddle { -namespace operators { - -template -class ResNetUnitKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ( - ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, - CUDNN_DATA_HALF, - phi::errors::Unavailable( - "ResNetUnitOp only supports float16 for now.")); - - // input x - const phi::DenseTensor *input_x = ctx.Input("X"); - const phi::DenseTensor *filter_x = ctx.Input("FilterX"); - const phi::DenseTensor *scale_x = ctx.Input("ScaleX"); - const phi::DenseTensor *bias_x = ctx.Input("BiasX"); - // norm conv - phi::DenseTensor *conv_out_x = ctx.Output("ConvX"); - // bn finalize - phi::DenseTensor *saved_mean_x = ctx.Output("SavedMeanX"); - phi::DenseTensor *saved_invstd_x = - ctx.Output("SavedInvstdX"); - phi::DenseTensor *running_mean_x = - ctx.Output("RunningMeanX"); - phi::DenseTensor *running_var_x = - ctx.Output("RunningVarX"); - // sbar - phi::DenseTensor *output = ctx.Output("Y"); - phi::DenseTensor *bitmask = ctx.Output("BitMask"); - // attrs - int padding = ctx.Attr("padding"); - int stride = ctx.Attr("stride"); - int stride_z = ctx.Attr("stride_z"); - int dilation = ctx.Attr("dilation"); - int group = ctx.Attr("group"); - double eps = static_cast(ctx.Attr("epsilon")); - double momentum = static_cast(ctx.Attr("momentum")); - bool has_shortcut = ctx.Attr("has_shortcut"); - bool fuse_add = ctx.Attr("fuse_add"); - bool use_global_stats = ctx.Attr("use_global_stats"); - bool is_test = ctx.Attr("is_test"); - bool is_train = !is_test && !use_global_stats; - std::string act_type = ctx.Attr("act_type"); - - auto input_x_shape = common::vectorize(input_x->dims()); - auto filter_x_shape = common::vectorize(filter_x->dims()); - // std::swap used to convert shape of filter from conv2d when kernel size is - // 1. - if (filter_x_shape[1] != filter_x_shape[2] && 1 == filter_x_shape[2]) { - std::swap(filter_x_shape[1], filter_x_shape[3]); - } - auto param_dims = scale_x->dims(); - auto param_shape = common::vectorize(scale_x->dims()); - if (1 == param_shape.size()) { - param_shape = {1, 1, 1, param_shape[0]}; - } - auto output_shape = common::vectorize(output->dims()); - auto bitmask_shape = common::vectorize(bitmask->dims()); - int output_channel = filter_x_shape[0]; - int64_t ele_count = std::accumulate(output_shape.begin(), - output_shape.end(), - 1, - std::multiplies()) / - output_channel; - - auto place = ctx.GetPlace(); - auto &dev_ctx = ctx.template device_context(); - - // 1. 
Conv - phi::DenseTensor sum_x; - phi::DenseTensor sum_of_squares_x; - sum_x.Resize(param_dims); - sum_of_squares_x.Resize(param_dims); - phi::fusion::CudnnNormConvolution conv_x_op(dev_ctx, - input_x_shape, - filter_x_shape, - output_shape, - padding, - stride, - dilation, - group); - conv_x_op.Forward( - dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x, &sum_of_squares_x); - - // 2. BN - phi::DenseTensor equiv_scale_x; - phi::DenseTensor equiv_bias_x; - equiv_scale_x.Resize(param_dims); - equiv_bias_x.Resize(param_dims); - phi::fusion::CudnnBNStatsFinalize bn_x_op(dev_ctx, param_shape); - bn_x_op.Forward(dev_ctx, - sum_x, - sum_of_squares_x, - *scale_x, - *bias_x, - saved_mean_x, - saved_invstd_x, - running_mean_x, - running_var_x, - &equiv_scale_x, - &equiv_bias_x, - eps, - momentum, - ele_count, - is_train); - - // 3. scale + bias + add + relu - phi::fusion::CudnnScaleBiasAddRelu sbar_op(dev_ctx, - act_type, - fuse_add, - has_shortcut, - output_shape, - param_shape, - bitmask_shape); - if (has_shortcut) { - // input z - const phi::DenseTensor *input_z = ctx.Input("Z"); - const phi::DenseTensor *filter_z = ctx.Input("FilterZ"); - const phi::DenseTensor *scale_z = ctx.Input("ScaleZ"); - const phi::DenseTensor *bias_z = ctx.Input("BiasZ"); - // norm conv - phi::DenseTensor *conv_out_z = ctx.Output("ConvZ"); - // bn finalize - phi::DenseTensor *saved_mean_z = - ctx.Output("SavedMeanZ"); - phi::DenseTensor *saved_invstd_z = - ctx.Output("SavedInvstdZ"); - phi::DenseTensor *running_mean_z = - ctx.Output("RunningMeanZ"); - phi::DenseTensor *running_var_z = - ctx.Output("RunningVarZ"); - - auto input_z_shape = common::vectorize(input_z->dims()); - auto filter_z_shape = common::vectorize(filter_z->dims()); - - // 3.1 Conv for second input - phi::DenseTensor sum_z; - phi::DenseTensor sum_of_squares_z; - sum_z.Resize(param_dims); - sum_of_squares_z.Resize(param_dims); - phi::fusion::CudnnNormConvolution conv_z_op(dev_ctx, - input_z_shape, - filter_z_shape, - output_shape, - padding, - stride_z, - dilation, - group); - conv_z_op.Forward( - dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z, &sum_of_squares_z); - - // 3.2 BN for second input - phi::DenseTensor equiv_scale_z; - phi::DenseTensor equiv_bias_z; - equiv_scale_z.Resize(param_dims); - equiv_bias_z.Resize(param_dims); - phi::fusion::CudnnBNStatsFinalize bn_z_op(dev_ctx, param_shape); - bn_z_op.Forward(dev_ctx, - sum_z, - sum_of_squares_z, - *scale_z, - *bias_z, - saved_mean_z, - saved_invstd_z, - running_mean_z, - running_var_z, - &equiv_scale_z, - &equiv_bias_z, - eps, - momentum, - ele_count, - is_train); - // 3.3 sbar - sbar_op.Forward(dev_ctx, - *conv_out_x, - equiv_scale_x, - equiv_bias_x, - conv_out_z, - &equiv_scale_z, - &equiv_bias_z, - output, - bitmask); - } else { - const phi::DenseTensor *input_z = - fuse_add ? 
ctx.Input("Z") : nullptr; - sbar_op.Forward(dev_ctx, - *conv_out_x, - equiv_scale_x, - equiv_bias_x, - input_z, - nullptr, - nullptr, - output, - bitmask); - } - } -}; - -template -class ResNetUnitGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ( - ctx.GetPlace().GetType() == phi::AllocationType::GPU, - true, - phi::errors::PreconditionNotMet("It must use CUDAPlace.")); - PADDLE_ENFORCE_EQ(phi::backends::gpu::CudnnDataType::type, - CUDNN_DATA_HALF, - phi::errors::Unavailable( - "ResNetUnitOp only supports float16 for now.")); - - const phi::DenseTensor *y_grad = - ctx.Input(framework::GradVarName("Y")); - - const phi::DenseTensor *x = ctx.Input("X"); - const phi::DenseTensor *filter_x = ctx.Input("FilterX"); - const phi::DenseTensor *scale_x = ctx.Input("ScaleX"); - const phi::DenseTensor *bias_x = ctx.Input("BiasX"); - const phi::DenseTensor *saved_mean_x = - ctx.Input("SavedMeanX"); - const phi::DenseTensor *saved_invstd_x = - ctx.Input("SavedInvstdX"); - - const phi::DenseTensor *conv_out_x = ctx.Input("ConvX"); - const phi::DenseTensor *output = ctx.Input("Y"); - const phi::DenseTensor *bitmask = ctx.Input("BitMask"); - - phi::DenseTensor *x_grad = - ctx.Output(framework::GradVarName("X")); - phi::DenseTensor *filter_x_grad = - ctx.Output(framework::GradVarName("FilterX")); - phi::DenseTensor *scale_x_grad = - ctx.Output(framework::GradVarName("ScaleX")); - phi::DenseTensor *bias_x_grad = - ctx.Output(framework::GradVarName("BiasX")); - - int padding = ctx.Attr("padding"); - int stride = ctx.Attr("stride"); - int stride_z = ctx.Attr("stride_z"); - int dilation = ctx.Attr("dilation"); - int group = ctx.Attr("group"); - double eps = static_cast(ctx.Attr("epsilon")); - double momentum = static_cast(ctx.Attr("momentum")); - bool has_shortcut = ctx.Attr("has_shortcut"); - bool fuse_add = ctx.Attr("fuse_add"); - bool use_global_stats = ctx.Attr("use_global_stats"); - std::string act_type = ctx.Attr("act_type"); - - auto x_shape = common::vectorize(x->dims()); - auto filter_x_shape = common::vectorize(filter_x->dims()); - auto param_shape = common::vectorize(scale_x->dims()); - auto output_shape = common::vectorize(output->dims()); - auto bitmask_shape = common::vectorize(bitmask->dims()); - - auto place = ctx.GetPlace(); - auto &dev_ctx = ctx.template device_context(); - - // 1. 
Backward of BN (+ Add + Relu) for x, get conv_out_x_grad, - // scale_x_grad, bias_x_grad - phi::DenseTensor conv_out_x_grad; - conv_out_x_grad.Resize(conv_out_x->dims()); - phi::fusion::CudnnScaleBiasAddRelu sbar_x_op(dev_ctx, - act_type, - fuse_add, - has_shortcut, - output_shape, - param_shape, - bitmask_shape); - if (has_shortcut) { - // X Z - // | | - // NormConv NormConv - // | | - // BNStatsFinalize BNStatsFinalize - // \ / - // ScaleBiasAddRelu - // | - // Y - const phi::DenseTensor *z = ctx.Input("Z"); - const phi::DenseTensor *filter_z = ctx.Input("FilterZ"); - const phi::DenseTensor *scale_z = ctx.Input("ScaleZ"); - const phi::DenseTensor *bias_z = ctx.Input("BiasZ"); - const phi::DenseTensor *saved_mean_z = - ctx.Input("SavedMeanZ"); - const phi::DenseTensor *saved_invstd_z = - ctx.Input("SavedInvstdZ"); - const phi::DenseTensor *conv_out_z = ctx.Input("ConvZ"); - - phi::DenseTensor *z_grad = - ctx.Output(framework::GradVarName("Z")); - phi::DenseTensor *filter_z_grad = - ctx.Output(framework::GradVarName("FilterZ")); - phi::DenseTensor *scale_z_grad = - ctx.Output(framework::GradVarName("ScaleZ")); - phi::DenseTensor *bias_z_grad = - ctx.Output(framework::GradVarName("BiasZ")); - - // 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad, - // scale_x_grad, bias_x_grad and z_grad_temp - phi::DenseTensor z_grad_temp; - z_grad_temp.Resize(conv_out_z->dims()); - sbar_x_op.Backward(dev_ctx, - *y_grad, - *conv_out_x, - *scale_x, - *bias_x, - *saved_mean_x, - *saved_invstd_x, - bitmask, - &conv_out_x_grad, - &z_grad_temp, - scale_x_grad, - bias_x_grad, - eps); - - // 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z - phi::DenseTensor conv_out_z_grad; - conv_out_z_grad.Resize(conv_out_z->dims()); - phi::fusion::CudnnScaleBiasAddRelu sbar_z_op( - dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape); - sbar_z_op.Backward(dev_ctx, - z_grad_temp, - *conv_out_z, - *scale_z, - *bias_z, - *saved_mean_z, - *saved_invstd_z, - nullptr, - &conv_out_z_grad, - nullptr, - scale_z_grad, - bias_z_grad, - eps); - - // 1.3 Backward of Conv for z, get z_grad and filter_z_grad - auto z_shape = common::vectorize(z->dims()); - auto filter_z_shape = common::vectorize(filter_z->dims()); - phi::fusion::CudnnNormConvolutionGrad conv_z_op(dev_ctx, - z_shape, - filter_z_shape, - output_shape, - padding, - stride_z, - dilation, - group); - conv_z_op.Backward( - dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad, filter_z_grad); - } else { - // 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad, - // scale_x_grad, bias_x_grad (and z_grad) - phi::DenseTensor *z_grad = - fuse_add ? ctx.Output(framework::GradVarName("Z")) - : nullptr; - sbar_x_op.Backward(dev_ctx, - *y_grad, - *conv_out_x, - *scale_x, - *bias_x, - *saved_mean_x, - *saved_invstd_x, - bitmask, - &conv_out_x_grad, - z_grad, - scale_x_grad, - bias_x_grad, - eps); - } - - // 2. 
Backward of Conv for x, get x_grad and filter_x_grad - bool use_addto = ctx.Attr("use_addto"); - phi::fusion::CudnnNormConvolutionGrad conv_x_op(dev_ctx, - x_shape, - filter_x_shape, - output_shape, - padding, - stride, - dilation, - group); - conv_x_op.Backward(dev_ctx, - *x, - *filter_x, - conv_out_x_grad, - x_grad, - filter_x_grad, - use_addto); - } -}; - -} // namespace operators -} // namespace paddle - -#if CUDNN_VERSION >= 8000 -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL( - resnet_unit, GPU, ALL_LAYOUT, ops::ResNetUnitKernel, phi::dtype::float16) {} -PD_REGISTER_STRUCT_KERNEL(resnet_unit_grad, - GPU, - ALL_LAYOUT, - ops::ResNetUnitGradKernel, - phi::dtype::float16) {} -#endif diff --git a/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc b/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc deleted file mode 100644 index 729dc0de0303b..0000000000000 --- a/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc +++ /dev/null @@ -1,373 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/phi/common/float16.h" - -namespace paddle { -namespace operators { - -template -class ResNetUnitXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::XPU, - true, - phi::errors::PreconditionNotMet("It must use XPUPlace.")); - - bool is_nchw = (ctx.Attr("data_format") == "NCHW"); - // input x - const phi::DenseTensor *input_x = ctx.Input("X"); - const phi::DenseTensor *filter_x = ctx.Input("FilterX"); - const phi::DenseTensor *scale_x = ctx.Input("ScaleX"); - const phi::DenseTensor *bias_x = ctx.Input("BiasX"); - - // output x - phi::DenseTensor *conv_out_x = ctx.Output("ConvX"); - phi::DenseTensor *saved_mean_x = ctx.Output("SavedMeanX"); - phi::DenseTensor *saved_invstd_x = - ctx.Output("SavedInvstdX"); - phi::DenseTensor *running_mean_x = - ctx.Output("RunningMeanX"); - phi::DenseTensor *running_var_x = - ctx.Output("RunningVarX"); - - phi::DenseTensor *output = ctx.Output("Y"); - - // attrs - int padding = ctx.Attr("padding"); - int stride = ctx.Attr("stride"); - int stride_z = ctx.Attr("stride_z"); - int dilation = ctx.Attr("dilation"); - int group = ctx.Attr("group"); - float eps = ctx.Attr("epsilon"); - float momentum = ctx.Attr("momentum"); - bool has_shortcut = ctx.Attr("has_shortcut"); - bool fuse_add = ctx.Attr("fuse_add"); - bool use_global_stats = ctx.Attr("use_global_stats"); - bool is_test = ctx.Attr("is_test"); - bool is_train = !is_test && !use_global_stats; - std::string act_type = ctx.Attr("act_type"); - auto &dev_ctx = ctx.template device_context(); - - std::vector x_list = { - reinterpret_cast(input_x->data())}; - std::vector w_list = { - reinterpret_cast(filter_x->data())}; - std::vector 
conv_y_list = { - reinterpret_cast(conv_out_x->mutable_data(place))}; - - std::vector> x_shape_list = { - common::vectorize(input_x->dims())}; - - auto filter_x_shape = common::vectorize(filter_x->dims()); - std::vector ksize = {filter_x_shape[2], filter_x_shape[3]}; - if (!is_nchw) { - ksize[0] = filter_x_shape[1]; - ksize[1] = filter_x_shape[2]; - } - std::vector strides = {stride, stride}; - std::vector> ksize_list = {ksize}; - std::vector> stride_list = {strides}; - std::vector paddings = {padding, padding}; - std::vector dilations = {dilation, dilation}; - std::vector scale_list = {scale_x->data()}; - std::vector bias_list = {bias_x->data()}; - std::vector batch_mean_list = { - saved_mean_x->mutable_data(place)}; - std::vector batch_invstd_list = { - saved_invstd_x->mutable_data(place)}; - std::vector global_mean_list = { - running_mean_x->mutable_data(place)}; - std::vector global_var_list = { - running_var_x->mutable_data(place)}; - - std::vector x_maxlist = {nullptr}; - std::vector w_maxlist = {nullptr}; - if (has_shortcut) { - // input z - const phi::DenseTensor *input_z = ctx.Input("Z"); - const phi::DenseTensor *filter_z = ctx.Input("FilterZ"); - const phi::DenseTensor *scale_z = ctx.Input("ScaleZ"); - const phi::DenseTensor *bias_z = ctx.Input("BiasZ"); - - phi::DenseTensor *conv_out_z = ctx.Output("ConvZ"); - phi::DenseTensor *saved_mean_z = - ctx.Output("SavedMeanZ"); - phi::DenseTensor *saved_invstd_z = - ctx.Output("SavedInvstdZ"); - phi::DenseTensor *running_mean_z = - ctx.Output("RunningMeanZ"); - phi::DenseTensor *running_var_z = - ctx.Output("RunningVarZ"); - - x_list.push_back(reinterpret_cast(input_z->data())); - w_list.push_back(reinterpret_cast(filter_z->data())); - conv_y_list.push_back( - reinterpret_cast(conv_out_z->mutable_data(place))); - - x_shape_list.push_back(common::vectorize(input_z->dims())); - - auto filter_z_shape = common::vectorize(filter_z->dims()); - std::vector ksize_z = {filter_z_shape[2], filter_z_shape[3]}; - if (!is_nchw) { - ksize_z[0] = filter_z_shape[1]; - ksize_z[1] = filter_z_shape[2]; - } - ksize_list.push_back(ksize_z); - stride_list.push_back({stride_z, stride_z}); - scale_list.push_back(scale_z->data()); - bias_list.push_back(bias_z->data()); - batch_mean_list.push_back(saved_mean_z->mutable_data(place)); - batch_invstd_list.push_back(saved_invstd_z->mutable_data(place)); - global_mean_list.push_back(running_mean_z->mutable_data(place)); - global_var_list.push_back(running_var_z->mutable_data(place)); - x_maxlist.push_back(nullptr); - w_maxlist.push_back(nullptr); - } else { - if (fuse_add) { - const phi::DenseTensor *input_z = ctx.Input("Z"); - auto input_z_shape = common::vectorize(input_z->dims()); - x_list.push_back(reinterpret_cast(input_z->data())); - x_shape_list.push_back(input_z_shape); - x_maxlist.push_back(nullptr); - } - } - int r = xpu::resnet_unit_fusion( - dev_ctx.x_context(), - x_list, - w_list, - conv_y_list, - reinterpret_cast(output->mutable_data(place)), - x_shape_list, - filter_x_shape[0], - ksize_list, - stride_list, - paddings, - dilations, - group, - eps, - momentum, - x_maxlist, - w_maxlist, - scale_list, - bias_list, - batch_mean_list, - batch_invstd_list, - global_mean_list, - global_var_list, - xpu::Activation_t::RELU, - is_nchw, - has_shortcut, - fuse_add, - is_train); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_fusion"); - } -}; - -template -class ResNetUnitGradXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const 
framework::ExecutionContext &ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(place.GetType() == phi::AllocationType::XPU, - true, - phi::errors::PreconditionNotMet("It must use XPUPlace.")); - - bool is_nchw = (ctx.Attr("data_format") == "NCHW"); - const phi::DenseTensor *y_grad = - ctx.Input(framework::GradVarName("Y")); - const phi::DenseTensor *x = ctx.Input("X"); - const phi::DenseTensor *filter_x = ctx.Input("FilterX"); - const phi::DenseTensor *scale_x = ctx.Input("ScaleX"); - const phi::DenseTensor *saved_mean_x = - ctx.Input("SavedMeanX"); - const phi::DenseTensor *saved_invstd_x = - ctx.Input("SavedInvstdX"); - const phi::DenseTensor *conv_out_x = ctx.Input("ConvX"); - const phi::DenseTensor *output = ctx.Input("Y"); - - phi::DenseTensor *x_grad = - ctx.Output(framework::GradVarName("X")); - phi::DenseTensor *filter_x_grad = - ctx.Output(framework::GradVarName("FilterX")); - phi::DenseTensor *scale_x_grad = - ctx.Output(framework::GradVarName("ScaleX")); - phi::DenseTensor *bias_x_grad = - ctx.Output(framework::GradVarName("BiasX")); - - int padding = ctx.Attr("padding"); - int stride = ctx.Attr("stride"); - int stride_z = ctx.Attr("stride_z"); - int dilation = ctx.Attr("dilation"); - int group = ctx.Attr("group"); - float eps = ctx.Attr("epsilon"); - bool has_shortcut = ctx.Attr("has_shortcut"); - bool fuse_add = ctx.Attr("fuse_add"); - std::string act_type = ctx.Attr("act_type"); - - auto &dev_ctx = ctx.template device_context(); - - std::vector x_list = { - reinterpret_cast(x->data())}; - std::vector w_list = { - reinterpret_cast(filter_x->data())}; - std::vector conv_y_list = { - reinterpret_cast(conv_out_x->data())}; - std::vector dx_list = { - reinterpret_cast(x_grad->mutable_data(place))}; - std::vector dw_list = { - reinterpret_cast(filter_x_grad->mutable_data(place))}; - - std::vector> x_shape_list = { - common::vectorize(x->dims())}; - - auto filter_x_shape = common::vectorize(filter_x->dims()); - std::vector x_ksize = {filter_x_shape[2], filter_x_shape[3]}; - if (!is_nchw) { - x_ksize[0] = filter_x_shape[1]; - x_ksize[1] = filter_x_shape[2]; - } - std::vector> ksize_list = {x_ksize}; - std::vector> stride_list = {{stride, stride}}; - std::vector paddings = {padding, padding}; - std::vector dilations = {dilation, dilation}; - - std::vector x_maxlist = {nullptr}; - std::vector w_maxlist = {nullptr}; - - std::vector scale_list = {scale_x->data()}; - std::vector batch_mean_list = {saved_mean_x->data()}; - std::vector batch_invstd_list = { - saved_invstd_x->data()}; - std::vector dscale_list = { - scale_x_grad->mutable_data(place)}; - std::vector dbias_list = {bias_x_grad->mutable_data(place)}; - - if (has_shortcut) { - // X Z - // | | - // NormConv NormConv - // | | - // BNStatsFinalize BNStatsFinalize - // \ / - // ScaleBiasAddRelu - // | - // Y - const phi::DenseTensor *z = ctx.Input("Z"); - const phi::DenseTensor *filter_z = ctx.Input("FilterZ"); - const phi::DenseTensor *scale_z = ctx.Input("ScaleZ"); - const phi::DenseTensor *saved_mean_z = - ctx.Input("SavedMeanZ"); - const phi::DenseTensor *saved_invstd_z = - ctx.Input("SavedInvstdZ"); - const phi::DenseTensor *conv_out_z = ctx.Input("ConvZ"); - - phi::DenseTensor *z_grad = - ctx.Output(framework::GradVarName("Z")); - phi::DenseTensor *filter_z_grad = - ctx.Output(framework::GradVarName("FilterZ")); - phi::DenseTensor *scale_z_grad = - ctx.Output(framework::GradVarName("ScaleZ")); - phi::DenseTensor *bias_z_grad = - ctx.Output(framework::GradVarName("BiasZ")); - 
x_list.push_back(reinterpret_cast(z->data())); - w_list.push_back(reinterpret_cast(filter_z->data())); - conv_y_list.push_back( - reinterpret_cast(conv_out_z->data())); - dx_list.push_back( - reinterpret_cast(z_grad->mutable_data(place))); - dw_list.push_back( - reinterpret_cast(filter_z_grad->mutable_data(place))); - x_shape_list.push_back(common::vectorize(z->dims())); - - auto filter_z_shape = common::vectorize(filter_z->dims()); - std::vector ksize_z = {filter_z_shape[2], filter_z_shape[3]}; - if (!is_nchw) { - ksize_z[0] = filter_z_shape[1]; - ksize_z[1] = filter_z_shape[2]; - } - ksize_list.push_back(ksize_z); - stride_list.push_back({stride_z, stride_z}); - x_maxlist.push_back(nullptr); - w_maxlist.push_back(nullptr); - - scale_list.push_back(scale_z->data()); - batch_mean_list.push_back(saved_mean_z->data()); - batch_invstd_list.push_back(saved_invstd_z->data()); - dscale_list.push_back(scale_z_grad->mutable_data(place)); - dbias_list.push_back(bias_z_grad->mutable_data(place)); - } else { - if (fuse_add) { - auto z_grad = ctx.Output(framework::GradVarName("Z")); - dx_list.push_back( - reinterpret_cast(z_grad->mutable_data(place))); - } - } - - int r = xpu::resnet_unit_grad_fusion( - dev_ctx.x_context(), - x_list, - w_list, - reinterpret_cast(y_grad->data()), - reinterpret_cast(output->data()), - conv_y_list, - dx_list, - dw_list, - x_shape_list, - filter_x_shape[0], - ksize_list, - stride_list, - paddings, - dilations, - group, - x_maxlist, - w_maxlist, - scale_list, - batch_mean_list, - batch_invstd_list, - dscale_list, - dbias_list, - xpu::Activation_t::RELU, - eps, - is_nchw, - has_shortcut, - fuse_add); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(resnet_unit, - XPU, - ALL_LAYOUT, - ops::ResNetUnitXPUKernel, - phi::dtype::float16, - float) {} -PD_REGISTER_STRUCT_KERNEL(resnet_unit_grad, - XPU, - ALL_LAYOUT, - ops::ResNetUnitGradXPUKernel, - phi::dtype::float16, - float) {} diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 0332b04b2c0b9..49286e8883125 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -5119,4 +5119,214 @@ void FusionSeqpoolConcatInferMeta(const std::vector& x, out->set_dtype(x[0]->dtype()); } +// Shape of bitmask +static phi::DDim GetBitmaskDims(std::vector out_shape) { + int c = out_shape.back(); + int64_t nhw = std::accumulate(out_shape.begin(), + out_shape.end(), + 1, + std::multiplies()) / // NOLINT + c; + int32_t c_int32_elems = ((c + 63) & ~63) / 32; + int32_t nhw_int32_elems = static_cast(((nhw + 31) & ~31)); + std::vector bitmask_shape = {nhw_int32_elems, c_int32_elems, 1}; + return common::make_ddim(bitmask_shape); +} + +void ResnetUnitInferMeta(const MetaTensor& x, + const MetaTensor& filter_x, + const MetaTensor& scale_x, + const MetaTensor& bias_x, + const MetaTensor& mean_x, + const MetaTensor& var_x, + const MetaTensor& z, + const MetaTensor& filter_z, + const MetaTensor& scale_z, + const MetaTensor& bias_z, + const MetaTensor& mean_z, + const MetaTensor& var_z, + int stride, + int stride_z, + int padding, + int dilation, + int group, + float momentum, + float epsilon, + const std::string& data_format, + bool fuse_add, + bool has_shortcut, + bool use_global_stats, + bool is_test, + bool use_addto, + const std::string& act_type, + MetaTensor* out, + MetaTensor* bit_mask, + MetaTensor* conv_x, + MetaTensor* saved_mean_x, + MetaTensor* 
saved_invstd_x, + MetaTensor* running_mean_x, + MetaTensor* running_var_x, + MetaTensor* conv_z, + MetaTensor* saved_mean_z, + MetaTensor* saved_invstd_z, + MetaTensor* running_mean_z, + MetaTensor* running_var_z) { + // Check dims of inputs + const auto& x_dims = x.dims(); + const auto& w_dims = filter_x.dims(); + std::vector bn_param_shape = common::vectorize(scale_x.dims()); + if (1 == bn_param_shape.size()) { + bn_param_shape = {1, 1, 1, bn_param_shape[0]}; + } + phi::DDim bn_param_dims = common::make_ddim(bn_param_shape); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 4, + phi::errors::InvalidArgument("The dimensions of input " + "must equal to 4." + "But received: the shape of input " + "= [%s], the dimension of input = " + "[%d]", + x_dims, + x_dims.size())); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 4, + phi::errors::InvalidArgument("The dimensions of filter " + "must equal to 4." + "But received: the shape of filter " + "= [%s], the dimension of filter = [%d] ", + w_dims, + w_dims.size())); + PADDLE_ENFORCE_EQ( + bn_param_dims.size(), + 4, + phi::errors::InvalidArgument("The dimensions of bn param " + "must equal to 4." + "But received: the shape of bn param " + "= [%s], the dimension of bn param = [%d] ", + bn_param_dims, + bn_param_dims.size())); + bool is_nchw = (data_format == "NCHW"); + // Calculate the dims of outputs + int batch = x_dims[0]; + int output_channel = w_dims[0]; + int filter_size = w_dims[2]; + std::vector out_shape; + out_shape.push_back(batch); + if (is_nchw) { + int out_h = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[3] + padding * 2 - filter_size) / stride + 1; + out_shape.push_back(output_channel); + out_shape.push_back(out_h); + out_shape.push_back(out_w); + } else { + int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + out_shape.push_back(out_h); + out_shape.push_back(out_w); + out_shape.push_back(output_channel); + } + + auto y_dims = common::make_ddim(out_shape); + auto bitmask_dims = GetBitmaskDims(out_shape); + // Set dims of outputs + out->set_dims(y_dims); + bit_mask->set_dims(bitmask_dims); + conv_x->set_dims(y_dims); + saved_mean_x->set_dims(bn_param_dims); + saved_invstd_x->set_dims(bn_param_dims); + running_mean_x->set_dims(bn_param_dims); + running_var_x->set_dims(bn_param_dims); + + out->set_dtype(x.dtype()); + bit_mask->set_dtype(filter_x.dtype()); + conv_x->set_dtype(x.dtype()); + saved_mean_x->set_dtype(mean_x.dtype()); + saved_invstd_x->set_dtype(var_x.dtype()); + running_mean_x->set_dtype(mean_x.dtype()); + running_var_x->set_dtype(var_x.dtype()); + if (has_shortcut) { + conv_z->set_dims(y_dims); + saved_mean_z->set_dims(bn_param_dims); + saved_invstd_z->set_dims(bn_param_dims); + running_mean_z->set_dims(bn_param_dims); + running_var_z->set_dims(bn_param_dims); + + conv_z->set_dtype(z.dtype()); + saved_mean_z->set_dtype(mean_z.dtype()); + saved_invstd_z->set_dtype(var_z.dtype()); + running_mean_z->set_dtype(mean_z.dtype()); + running_var_z->set_dtype(var_z.dtype()); + } +} + +void ResnetUnitGradInferMeta(const MetaTensor& x, + const MetaTensor& filter_x, + const MetaTensor& conv_x, + const MetaTensor& scale_x, + const MetaTensor& bias_x, + const MetaTensor& saved_mean_x, + const MetaTensor& saved_invstd_x, + const MetaTensor& z, + const MetaTensor& filter_z, + const MetaTensor& conv_z, + const MetaTensor& scale_z, + const MetaTensor& bias_z, + const MetaTensor& saved_mean_z, + const MetaTensor& saved_invstd_z, + const MetaTensor& 
out, + const MetaTensor& bit_mask, + const MetaTensor& out_grad, + int stride, + int stride_z, + int padding, + int dilation, + int group, + float momentum, + float epsilon, + const std::string& data_format, + bool fuse_add, + bool has_shortcut, + bool use_global_stats, + bool is_test, + bool use_addto, + const std::string& act_type, + MetaTensor* x_grad, + MetaTensor* filter_x_grad, + MetaTensor* scale_x_grad, + MetaTensor* bias_x_grad, + MetaTensor* z_grad, + MetaTensor* filter_z_grad, + MetaTensor* scale_z_grad, + MetaTensor* bias_z_grad) { + const auto& x_dims = x.dims(); + const auto& filter_x_dims = filter_x.dims(); + const auto& param_dims = scale_x.dims(); + x_grad->set_dims(x_dims); + filter_x_grad->set_dims(filter_x_dims); + scale_x_grad->set_dims(param_dims); + bias_x_grad->set_dims(param_dims); + x_grad->set_dtype(x.dtype()); + filter_x_grad->set_dtype(filter_x.dtype()); + scale_x_grad->set_dtype(scale_x.dtype()); + bias_x_grad->set_dtype(bias_x.dtype()); + + if (fuse_add || has_shortcut) { + const auto& z_dims = z.dims(); + z_grad->set_dims(z_dims); + z_grad->set_dtype(z.dtype()); + } + if (has_shortcut) { + const auto filter_z_dims = filter_z.dims(); + filter_z_grad->set_dims(filter_z_dims); + scale_z_grad->set_dims(param_dims); + bias_z_grad->set_dims(param_dims); + + filter_z_grad->set_dtype(filter_z.dtype()); + scale_z_grad->set_dtype(scale_z.dtype()); + bias_z_grad->set_dtype(bias_z.dtype()); + } +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 10c376aefa651..36b41b98a928a 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -1136,4 +1136,83 @@ void FusionSeqpoolConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ResnetUnitInferMeta(const MetaTensor& x, + const MetaTensor& filter_x, + const MetaTensor& scale_x, + const MetaTensor& bias_x, + const MetaTensor& mean_x, + const MetaTensor& var_x, + const MetaTensor& z, + const MetaTensor& filter_z, + const MetaTensor& scale_z, + const MetaTensor& bias_z, + const MetaTensor& mean_z, + const MetaTensor& var_z, + int stride, + int stride_z, + int padding, + int dilation, + int group, + float momentum, + float epsilon, + const std::string& data_format, + bool fuse_add, + bool has_shortcut, + bool use_global_stats, + bool is_test, + bool use_addto, + const std::string& act_type, + MetaTensor* out, + MetaTensor* bit_mask, + MetaTensor* conv_x, + MetaTensor* saved_mean_x, + MetaTensor* saved_invstd_x, + MetaTensor* running_mean_x, + MetaTensor* running_var_x, + MetaTensor* conv_z, + MetaTensor* saved_mean_z, + MetaTensor* saved_invstd_z, + MetaTensor* running_mean_z, + MetaTensor* running_var_z); + +void ResnetUnitGradInferMeta(const MetaTensor& x, + const MetaTensor& filter_x, + const MetaTensor& conv_x, + const MetaTensor& scale_x, + const MetaTensor& bias_x, + const MetaTensor& saved_mean_x, + const MetaTensor& saved_invstd_x, + const MetaTensor& z, + const MetaTensor& filter_z, + const MetaTensor& conv_z, + const MetaTensor& scale_z, + const MetaTensor& bias_z, + const MetaTensor& saved_mean_z, + const MetaTensor& saved_invstd_z, + const MetaTensor& out, + const MetaTensor& bit_mask, + const MetaTensor& out_grad, + int stride, + int stride_z, + int padding, + int dilation, + int group, + float momentum, + float epsilon, + const std::string& data_format, + bool fuse_add, + bool has_shortcut, + bool use_global_stats, + bool is_test, + bool use_addto, + const std::string& act_type, + MetaTensor* x_grad, + 
MetaTensor* filter_x_grad, + MetaTensor* scale_x_grad, + MetaTensor* bias_x_grad, + MetaTensor* z_grad, + MetaTensor* filter_z_grad, + MetaTensor* scale_z_grad, + MetaTensor* bias_z_grad); + } // namespace phi
diff --git a/paddle/phi/kernels/fusion/gpu/resnet_unit_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/resnet_unit_grad_kernel.cu
new file mode 100644
index 0000000000000..58d3250d33713
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/resnet_unit_grad_kernel.cu
@@ -0,0 +1,271 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_bn_stats_finalize.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_norm_conv.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_scale_bias_add_relu.cu.h"
+#include "paddle/utils/optional.h"
+
+#if CUDNN_VERSION >= 8000
+namespace phi {
+
+template <typename T, typename Context>
+void ResNetUnitGradKernel(
+    const Context &dev_ctx,
+    const DenseTensor &x_in,
+    const DenseTensor &filter_x_in,
+    const DenseTensor &conv_x_in,
+    const DenseTensor &scale_x_in,
+    const DenseTensor &bias_x_in,
+    const DenseTensor &saved_mean_x_in,
+    const DenseTensor &saved_invstd_x_in,
+    const paddle::optional<DenseTensor> &z_in,
+    const paddle::optional<DenseTensor> &filter_z_in,
+    const paddle::optional<DenseTensor> &conv_z_in,
+    const paddle::optional<DenseTensor> &scale_z_in,
+    const paddle::optional<DenseTensor> &bias_z_in,
+    const paddle::optional<DenseTensor> &saved_mean_z_in,
+    const paddle::optional<DenseTensor> &saved_invstd_z_in,
+    const DenseTensor &out,
+    const DenseTensor &bit_mask,
+    const DenseTensor &out_grad,
+    int stride,
+    int stride_z,
+    int padding,
+    int dilation,
+    int group,
+    float momentum_in,
+    float epsilon,
+    const std::string &data_format,
+    bool fuse_add,
+    bool has_shortcut,
+    bool use_global_stats,
+    bool is_test,
+    bool use_addto,
+    const std::string &act_type,
+    DenseTensor *x_grad,
+    DenseTensor *filter_x_grad,
+    DenseTensor *scale_x_grad,
+    DenseTensor *bias_x_grad,
+    DenseTensor *z_grad,
+    DenseTensor *filter_z_grad,
+    DenseTensor *scale_z_grad,
+    DenseTensor *bias_z_grad) {
+  PADDLE_ENFORCE_EQ(
+      phi::backends::gpu::CudnnDataType<T>::type,
+      CUDNN_DATA_HALF,
+      phi::errors::Unavailable("ResNetUnitOp only supports float16 for now."));
+
+  const phi::DenseTensor *y_grad = &out_grad;
+
+  const phi::DenseTensor *x = &x_in;
+  const phi::DenseTensor *filter_x = &filter_x_in;
+  const phi::DenseTensor *scale_x = &scale_x_in;
+  const phi::DenseTensor *bias_x = &bias_x_in;
+  const phi::DenseTensor *saved_mean_x = &saved_mean_x_in;
+  const phi::DenseTensor *saved_invstd_x = &saved_invstd_x_in;
+
+  const phi::DenseTensor *conv_out_x = &conv_x_in;
+  const phi::DenseTensor *output = &out;
+  const phi::DenseTensor *bitmask = &bit_mask;
+
+  double eps = static_cast<double>(epsilon);
+  double momentum = static_cast<double>(momentum_in);
+
+  auto x_shape = common::vectorize<int>(x->dims());
+  auto filter_x_shape = common::vectorize<int>(filter_x->dims());
+  auto param_shape = common::vectorize<int>(scale_x->dims());
+  auto output_shape = common::vectorize<int>(output->dims());
+  auto bitmask_shape = common::vectorize<int>(bitmask->dims());
+
+  // 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
+  // scale_x_grad, bias_x_grad
+  phi::DenseTensor conv_out_x_grad;
+  conv_out_x_grad.Resize(conv_out_x->dims());
+  phi::fusion::CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx,
+                                                  act_type,
+                                                  fuse_add,
+                                                  has_shortcut,
+                                                  output_shape,
+                                                  param_shape,
+                                                  bitmask_shape);
+  if (has_shortcut) {
+    //       X                   Z
+    //       |                   |
+    //    NormConv            NormConv
+    //       |                   |
+    // BNStatsFinalize    BNStatsFinalize
+    //       \                   /
+    //         ScaleBiasAddRelu
+    //                 |
+    //                 Y
+    const phi::DenseTensor *z = z_in.get_ptr();
+    const phi::DenseTensor *filter_z = filter_z_in.get_ptr();
+    const phi::DenseTensor *scale_z = scale_z_in.get_ptr();
+    const phi::DenseTensor *bias_z = bias_z_in.get_ptr();
+    const phi::DenseTensor *saved_mean_z = saved_mean_z_in.get_ptr();
+    const phi::DenseTensor *saved_invstd_z = saved_invstd_z_in.get_ptr();
+    const phi::DenseTensor *conv_out_z = conv_z_in.get_ptr();
+
+    // 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad,
+    // scale_x_grad, bias_x_grad and z_grad_temp
+    phi::DenseTensor z_grad_temp;
+    z_grad_temp.Resize(conv_out_z->dims());
+    sbar_x_op.Backward(dev_ctx,
+                       *y_grad,
+                       *conv_out_x,
+                       *scale_x,
+                       *bias_x,
+                       *saved_mean_x,
+                       *saved_invstd_x,
+                       bitmask,
+                       &conv_out_x_grad,
+                       &z_grad_temp,
+                       scale_x_grad,
+                       bias_x_grad,
+                       eps);
+
+    // 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z
+    phi::DenseTensor conv_out_z_grad;
+    conv_out_z_grad.Resize(conv_out_z->dims());
+    phi::fusion::CudnnScaleBiasAddRelu<T> sbar_z_op(
+        dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape);
+    sbar_z_op.Backward(dev_ctx,
+                       z_grad_temp,
+                       *conv_out_z,
+                       *scale_z,
+                       *bias_z,
+                       *saved_mean_z,
+                       *saved_invstd_z,
+                       nullptr,
+                       &conv_out_z_grad,
+                       nullptr,
+                       scale_z_grad,
+                       bias_z_grad,
+                       eps);
+
+    // 1.3 Backward of Conv for z, get z_grad and filter_z_grad
+    auto z_shape = common::vectorize<int>(z->dims());
+    auto filter_z_shape = common::vectorize<int>(filter_z->dims());
+    phi::fusion::CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx,
+                                                       z_shape,
+                                                       filter_z_shape,
+                                                       output_shape,
+                                                       padding,
+                                                       stride_z,
+                                                       dilation,
+                                                       group);
+    conv_z_op.Backward(
+        dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad, filter_z_grad);
+  } else {
+    // 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
+    // scale_x_grad, bias_x_grad (and z_grad)
+    phi::DenseTensor *z_grad_tmp = fuse_add ? z_grad : nullptr;
+    sbar_x_op.Backward(dev_ctx,
+                       *y_grad,
+                       *conv_out_x,
+                       *scale_x,
+                       *bias_x,
+                       *saved_mean_x,
+                       *saved_invstd_x,
+                       bitmask,
+                       &conv_out_x_grad,
+                       z_grad_tmp,
+                       scale_x_grad,
+                       bias_x_grad,
+                       eps);
+  }
+
+  // 2. Backward of Conv for x, get x_grad and filter_x_grad
+  phi::fusion::CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx,
+                                                     x_shape,
+                                                     filter_x_shape,
+                                                     output_shape,
+                                                     padding,
+                                                     stride,
+                                                     dilation,
+                                                     group);
+  conv_x_op.Backward(dev_ctx,
+                     *x,
+                     *filter_x,
+                     conv_out_x_grad,
+                     x_grad,
+                     filter_x_grad,
+                     use_addto);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(resnet_unit_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ResNetUnitGradKernel,
+                   phi::dtype::float16) {}
+#else
+namespace phi {
+
+template <typename T, typename Context>
+void ResNetUnitGradEmptyKernel(
+    const Context &dev_ctx,
+    const DenseTensor &x_in,
+    const DenseTensor &filter_x_in,
+    const DenseTensor &conv_x_in,
+    const DenseTensor &scale_x_in,
+    const DenseTensor &bias_x_in,
+    const DenseTensor &saved_mean_x_in,
+    const DenseTensor &saved_invstd_x_in,
+    const paddle::optional<DenseTensor> &z_in,
+    const paddle::optional<DenseTensor> &filter_z_in,
+    const paddle::optional<DenseTensor> &conv_z_in,
+    const paddle::optional<DenseTensor> &scale_z_in,
+    const paddle::optional<DenseTensor> &bias_z_in,
+    const paddle::optional<DenseTensor> &saved_mean_z_in,
+    const paddle::optional<DenseTensor> &saved_invstd_z_in,
+    const DenseTensor &out,
+    const DenseTensor &bit_mask,
+    const DenseTensor &out_grad,
+    int stride,
+    int stride_z,
+    int padding,
+    int dilation,
+    int group,
+    float momentum_in,
+    float epsilon,
+    const std::string &data_format,
+    bool fuse_add,
+    bool has_shortcut,
+    bool use_global_stats,
+    bool is_test,
+    bool use_addto,
+    const std::string &act_type,
+    DenseTensor *x_grad,
+    DenseTensor *filter_x_grad,
+    DenseTensor *scale_x_grad,
+    DenseTensor *bias_x_grad,
+    DenseTensor *z_grad,
+    DenseTensor *filter_z_grad,
+    DenseTensor *scale_z_grad,
+    DenseTensor *bias_z_grad) {}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(resnet_unit_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ResNetUnitGradEmptyKernel,
+                   phi::dtype::float16) {}
+#endif
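Note that both sides of the CUDNN_VERSION guard register the same resnet_unit_grad symbol: on cuDNN older than 8.0 the stub keeps kernel-registry lookups resolving, with an empty body for the backward pass (the forward kernel's fallback below throws instead). A minimal, self-contained sketch of this guard pattern, using a hypothetical FAKE_CUDNN_VERSION macro and function name rather than the real cuDNN header or Paddle registry:

#include <cstdio>

#define FAKE_CUDNN_VERSION 7605  // stand-in for cuDNN's real CUDNN_VERSION

#if FAKE_CUDNN_VERSION >= 8000
void resnet_unit_grad_impl() { std::puts("running fused backward"); }
#else
// Same symbol, empty body: registration still succeeds on old toolchains,
// and the backward stub is a silent no-op (the forward stub throws).
void resnet_unit_grad_impl() {}
#endif

int main() {
  resnet_unit_grad_impl();  // resolves either way
  std::puts("symbol resolved");
}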
diff --git a/paddle/phi/kernels/fusion/gpu/resnet_unit_kernel.cu b/paddle/phi/kernels/fusion/gpu/resnet_unit_kernel.cu
new file mode 100644
index 0000000000000..491061ff9fe6f
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/resnet_unit_kernel.cu
@@ -0,0 +1,281 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_bn_stats_finalize.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_norm_conv.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/cudnn_scale_bias_add_relu.cu.h"
+#include "paddle/utils/optional.h"
+
+#if CUDNN_VERSION >= 8000
+namespace phi {
+
+template <typename T, typename Context>
+void ResNetUnitKernel(const Context &dev_ctx,
+                      const DenseTensor &x_in,
+                      const DenseTensor &filter_x_in,
+                      const DenseTensor &scale_x_in,
+                      const DenseTensor &bias_x_in,
+                      const DenseTensor &mean_x_in,
+                      const DenseTensor &var_x_in,
+                      const paddle::optional<DenseTensor> &z_in,
+                      const paddle::optional<DenseTensor> &filter_z_in,
+                      const paddle::optional<DenseTensor> &scale_z_in,
+                      const paddle::optional<DenseTensor> &bias_z_in,
+                      const paddle::optional<DenseTensor> &mean_z_in,
+                      const paddle::optional<DenseTensor> &var_z_in,
+                      int stride,
+                      int stride_z,
+                      int padding,
+                      int dilation,
+                      int group,
+                      float momentum_in,
+                      float epsilon,
+                      const std::string &data_format,
+                      bool fuse_add,
+                      bool has_shortcut,
+                      bool use_global_stats,
+                      bool is_test,
+                      bool use_addto,
+                      const std::string &act_type,
+                      DenseTensor *out,
+                      DenseTensor *bit_mask,
+                      DenseTensor *conv_x,
+                      DenseTensor *saved_mean_x,
+                      DenseTensor *saved_invstd_x,
+                      DenseTensor *running_mean_x,
+                      DenseTensor *running_var_x,
+                      DenseTensor *conv_z,
+                      DenseTensor *saved_mean_z,
+                      DenseTensor *saved_invstd_z,
+                      DenseTensor *running_mean_z,
+                      DenseTensor *running_var_z) {
+  PADDLE_ENFORCE_EQ(
+      phi::backends::gpu::CudnnDataType<T>::type,
+      CUDNN_DATA_HALF,
+      phi::errors::Unavailable("ResNetUnitOp only supports float16 for now."));
+
+  // input x
+  const phi::DenseTensor *input_x = &x_in;
+  const phi::DenseTensor *filter_x = &filter_x_in;
+  const phi::DenseTensor *scale_x = &scale_x_in;
+  const phi::DenseTensor *bias_x = &bias_x_in;
+  // norm conv
+  phi::DenseTensor *conv_out_x = conv_x;
+  // sbar
+  phi::DenseTensor *output = out;
+  phi::DenseTensor *bitmask = bit_mask;
+  // attrs
+  double eps = static_cast<double>(epsilon);
+  double momentum = static_cast<double>(momentum_in);
+
+  bool is_train = !is_test && !use_global_stats;
+
+  auto input_x_shape = common::vectorize<int>(input_x->dims());
+  auto filter_x_shape = common::vectorize<int>(filter_x->dims());
+  // std::swap used to convert shape of filter from conv2d when kernel size
+  // is 1.
+  if (filter_x_shape[1] != filter_x_shape[2] && 1 == filter_x_shape[2]) {
+    std::swap(filter_x_shape[1], filter_x_shape[3]);
+  }
+  auto param_dims = scale_x->dims();
+  auto param_shape = common::vectorize<int>(scale_x->dims());
+  if (1 == param_shape.size()) {
+    param_shape = {1, 1, 1, param_shape[0]};
+  }
+  auto output_shape = common::vectorize<int>(output->dims());
+  auto bitmask_shape = common::vectorize<int>(bitmask->dims());
+  int output_channel = filter_x_shape[0];
+  int64_t ele_count =
+      std::accumulate(output_shape.begin(),
+                      output_shape.end(),
+                      1,
+                      std::multiplies<int>()) /
+      output_channel;
+
+  // 1. Conv
+  phi::DenseTensor sum_x;
+  phi::DenseTensor sum_of_squares_x;
+  sum_x.Resize(param_dims);
+  sum_of_squares_x.Resize(param_dims);
+  phi::fusion::CudnnNormConvolution<T> conv_x_op(dev_ctx,
+                                                 input_x_shape,
+                                                 filter_x_shape,
+                                                 output_shape,
+                                                 padding,
+                                                 stride,
+                                                 dilation,
+                                                 group);
+  conv_x_op.Forward(
+      dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x, &sum_of_squares_x);
+
+  // 2. BN
+  phi::DenseTensor equiv_scale_x;
+  phi::DenseTensor equiv_bias_x;
+  equiv_scale_x.Resize(param_dims);
+  equiv_bias_x.Resize(param_dims);
+  phi::fusion::CudnnBNStatsFinalize<T> bn_x_op(dev_ctx, param_shape);
+  bn_x_op.Forward(dev_ctx,
+                  sum_x,
+                  sum_of_squares_x,
+                  *scale_x,
+                  *bias_x,
+                  saved_mean_x,
+                  saved_invstd_x,
+                  running_mean_x,
+                  running_var_x,
+                  &equiv_scale_x,
+                  &equiv_bias_x,
+                  eps,
+                  momentum,
+                  ele_count,
+                  is_train);
+
+  // 3. scale + bias + add + relu
+  phi::fusion::CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx,
+                                                act_type,
+                                                fuse_add,
+                                                has_shortcut,
+                                                output_shape,
+                                                param_shape,
+                                                bitmask_shape);
+  if (has_shortcut) {
+    // input z
+    const phi::DenseTensor *input_z = z_in.get_ptr();
+    const phi::DenseTensor *filter_z = filter_z_in.get_ptr();
+    const phi::DenseTensor *scale_z = scale_z_in.get_ptr();
+    const phi::DenseTensor *bias_z = bias_z_in.get_ptr();
+    // norm conv
+    phi::DenseTensor *conv_out_z = conv_z;
+
+    auto input_z_shape = common::vectorize<int>(input_z->dims());
+    auto filter_z_shape = common::vectorize<int>(filter_z->dims());
+
+    // 3.1 Conv for second input
+    phi::DenseTensor sum_z;
+    phi::DenseTensor sum_of_squares_z;
+    sum_z.Resize(param_dims);
+    sum_of_squares_z.Resize(param_dims);
+    phi::fusion::CudnnNormConvolution<T> conv_z_op(dev_ctx,
+                                                   input_z_shape,
+                                                   filter_z_shape,
+                                                   output_shape,
+                                                   padding,
+                                                   stride_z,
+                                                   dilation,
+                                                   group);
+    conv_z_op.Forward(
+        dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z, &sum_of_squares_z);
+
+    // 3.2 BN for second input
+    phi::DenseTensor equiv_scale_z;
+    phi::DenseTensor equiv_bias_z;
+    equiv_scale_z.Resize(param_dims);
+    equiv_bias_z.Resize(param_dims);
+    phi::fusion::CudnnBNStatsFinalize<T> bn_z_op(dev_ctx, param_shape);
+    bn_z_op.Forward(dev_ctx,
+                    sum_z,
+                    sum_of_squares_z,
+                    *scale_z,
+                    *bias_z,
+                    saved_mean_z,
+                    saved_invstd_z,
+                    running_mean_z,
+                    running_var_z,
+                    &equiv_scale_z,
+                    &equiv_bias_z,
+                    eps,
+                    momentum,
+                    ele_count,
+                    is_train);
+    // 3.3 sbar
+    sbar_op.Forward(dev_ctx,
+                    *conv_out_x,
+                    equiv_scale_x,
+                    equiv_bias_x,
+                    conv_out_z,
+                    &equiv_scale_z,
+                    &equiv_bias_z,
+                    output,
+                    bitmask);
+  } else {
+    const phi::DenseTensor *input_z = fuse_add ? z_in.get_ptr() : nullptr;
+    sbar_op.Forward(dev_ctx,
+                    *conv_out_x,
+                    equiv_scale_x,
+                    equiv_bias_x,
+                    input_z,
+                    nullptr,
+                    nullptr,
+                    output,
+                    bitmask);
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    resnet_unit, GPU, ALL_LAYOUT, phi::ResNetUnitKernel, phi::dtype::float16) {}
+#else
+namespace phi {
+template <typename T, typename Context>
+void ResNetUnitEmptyKernel(const Context &dev_ctx,
+                           const DenseTensor &x_in,
+                           const DenseTensor &filter_x_in,
+                           const DenseTensor &scale_x_in,
+                           const DenseTensor &bias_x_in,
+                           const DenseTensor &mean_x_in,
+                           const DenseTensor &var_x_in,
+                           const paddle::optional<DenseTensor> &z_in,
+                           const paddle::optional<DenseTensor> &filter_z_in,
+                           const paddle::optional<DenseTensor> &scale_z_in,
+                           const paddle::optional<DenseTensor> &bias_z_in,
+                           const paddle::optional<DenseTensor> &mean_z_in,
+                           const paddle::optional<DenseTensor> &var_z_in,
+                           int stride,
+                           int stride_z,
+                           int padding,
+                           int dilation,
+                           int group,
+                           float momentum_in,
+                           float epsilon,
+                           const std::string &data_format,
+                           bool fuse_add,
+                           bool has_shortcut,
+                           bool use_global_stats,
+                           bool is_test,
+                           bool use_addto,
+                           const std::string &act_type,
+                           DenseTensor *out,
+                           DenseTensor *bit_mask,
+                           DenseTensor *conv_x,
+                           DenseTensor *saved_mean_x,
+                           DenseTensor *saved_invstd_x,
+                           DenseTensor *running_mean_x,
+                           DenseTensor *running_var_x,
+                           DenseTensor *conv_z,
+                           DenseTensor *saved_mean_z,
+                           DenseTensor *saved_invstd_z,
+                           DenseTensor *running_mean_z,
+                           DenseTensor *running_var_z) {
+  PADDLE_THROW(phi::errors::Unavailable(
+      "ResNetUnitOp only supports CUDNN_VERSION >= 8000 for now."));
+}
+}  // namespace phi
+PD_REGISTER_KERNEL(resnet_unit,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ResNetUnitEmptyKernel,
+                   phi::dtype::float16) {}
+#endif
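Two bits of shape bookkeeping in the forward kernel above are worth pinning down with numbers: ele_count is the output's element total divided by its channel count (the population each channel's BN statistics average over), and 1x1 filters arriving from conv2d get their dims 1 and 3 swapped into the fused layout. A standalone sketch with made-up sizes (the [32, 56, 56, 64] NHWC output and the [256, 64, 1, 1] filter are illustrative values, not from this patch):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  // ele_count: elements each channel's BN statistics are computed over.
  std::vector<int> output_shape = {32, 56, 56, 64};  // NHWC
  int output_channel = 64;                           // filter_x_shape[0]
  int64_t ele_count =
      std::accumulate(output_shape.begin(), output_shape.end(), int64_t{1},
                      std::multiplies<int64_t>()) /
      output_channel;
  assert(ele_count == 32 * 56 * 56);  // N * H * W = 100352

  // 1x1 filter fix-up: conv2d hands over [C_out, C_in, 1, 1]; the fused op
  // expects [C_out, 1, 1, C_in], hence the guarded dim-1/dim-3 swap.
  std::vector<int> filter_shape = {256, 64, 1, 1};
  if (filter_shape[1] != filter_shape[2] && 1 == filter_shape[2]) {
    std::swap(filter_shape[1], filter_shape[3]);
  }
  assert((filter_shape == std::vector<int>{256, 1, 1, 64}));
}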
diff --git a/paddle/phi/kernels/fusion/xpu/resnet_unit_grad_kernel.cc b/paddle/phi/kernels/fusion/xpu/resnet_unit_grad_kernel.cc
new file mode 100644
index 0000000000000..14a81e3d05ba2
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/resnet_unit_grad_kernel.cc
@@ -0,0 +1,203 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/utils/optional.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ResNetUnitGradXPUKernel(
+    const Context &dev_ctx,
+    const DenseTensor &x_in,
+    const DenseTensor &filter_x_in,
+    const DenseTensor &conv_x_in,
+    const DenseTensor &scale_x_in,
+    const DenseTensor &bias_x_in,
+    const DenseTensor &saved_mean_x_in,
+    const DenseTensor &saved_invstd_x_in,
+    const paddle::optional<DenseTensor> &z_in,
+    const paddle::optional<DenseTensor> &filter_z_in,
+    const paddle::optional<DenseTensor> &conv_z_in,
+    const paddle::optional<DenseTensor> &scale_z_in,
+    const paddle::optional<DenseTensor> &bias_z_in,
+    const paddle::optional<DenseTensor> &saved_mean_z_in,
+    const paddle::optional<DenseTensor> &saved_invstd_z_in,
+    const DenseTensor &out,
+    const DenseTensor &bit_mask,
+    const DenseTensor &out_grad,
+    int stride,
+    int stride_z,
+    int padding,
+    int dilation,
+    int group,
+    float momentum_in,
+    float epsilon,
+    const std::string &data_format,
+    bool fuse_add,
+    bool has_shortcut,
+    bool use_global_stats,
+    bool is_test,
+    bool use_addto,
+    const std::string &act_type,
+    DenseTensor *x_grad,
+    DenseTensor *filter_x_grad,
+    DenseTensor *scale_x_grad,
+    DenseTensor *bias_x_grad,
+    DenseTensor *z_grad,
+    DenseTensor *filter_z_grad,
+    DenseTensor *scale_z_grad,
+    DenseTensor *bias_z_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  bool is_nchw = (data_format == "NCHW");
+  const phi::DenseTensor *y_grad = &out_grad;
+  const phi::DenseTensor *x = &x_in;
+  const phi::DenseTensor *filter_x = &filter_x_in;
+  const phi::DenseTensor *scale_x = &scale_x_in;
+  const phi::DenseTensor *saved_mean_x = &saved_mean_x_in;
+  const phi::DenseTensor *saved_invstd_x = &saved_invstd_x_in;
+  const phi::DenseTensor *conv_out_x = &conv_x_in;
+  const phi::DenseTensor *output = &out;
+
+  float eps = epsilon;
+
+  std::vector<const XPUType *> x_list = {
+      reinterpret_cast<const XPUType *>(x->data<T>())};
+  std::vector<const XPUType *> w_list = {
+      reinterpret_cast<const XPUType *>(filter_x->data<T>())};
+  std::vector<const XPUType *> conv_y_list = {
+      reinterpret_cast<const XPUType *>(conv_out_x->data<T>())};
+  std::vector<XPUType *> dx_list = {
+      reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(x_grad))};
+  std::vector<XPUType *> dw_list = {
+      reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(filter_x_grad))};
+
+  std::vector<std::vector<int>> x_shape_list = {
+      common::vectorize<int>(x->dims())};
+
+  auto filter_x_shape = common::vectorize<int>(filter_x->dims());
+  std::vector<int> x_ksize = {filter_x_shape[2], filter_x_shape[3]};
+  if (!is_nchw) {
+    x_ksize[0] = filter_x_shape[1];
+    x_ksize[1] = filter_x_shape[2];
+  }
+  std::vector<std::vector<int>> ksize_list = {x_ksize};
+  std::vector<std::vector<int>> stride_list = {{stride, stride}};
+  std::vector<int> paddings = {padding, padding};
+  std::vector<int> dilations = {dilation, dilation};
+
+  std::vector<const float *> x_maxlist = {nullptr};
+  std::vector<const float *> w_maxlist = {nullptr};
+
+  std::vector<const float *> scale_list = {scale_x->data<float>()};
+  std::vector<const float *> batch_mean_list = {saved_mean_x->data<float>()};
+  std::vector<const float *> batch_invstd_list = {
+      saved_invstd_x->data<float>()};
+  std::vector<float *> dscale_list = {
+      dev_ctx.template Alloc<float>(scale_x_grad)};
+  std::vector<float *> dbias_list = {
+      dev_ctx.template Alloc<float>(bias_x_grad)};
+
+  if (has_shortcut) {
+    //       X                   Z
+    //       |                   |
+    //    NormConv            NormConv
+    //       |                   |
+    // BNStatsFinalize    BNStatsFinalize
+    //       \                   /
+    //         ScaleBiasAddRelu
+    //                 |
+    //                 Y
+    const phi::DenseTensor *z = z_in.get_ptr();
+    const phi::DenseTensor *filter_z = filter_z_in.get_ptr();
+    const phi::DenseTensor *scale_z = scale_z_in.get_ptr();
+    const phi::DenseTensor *saved_mean_z = saved_mean_z_in.get_ptr();
+    const phi::DenseTensor *saved_invstd_z = saved_invstd_z_in.get_ptr();
+    const phi::DenseTensor *conv_out_z = conv_z_in.get_ptr();
+
+    x_list.push_back(reinterpret_cast<const XPUType *>(z->data<T>()));
+    w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
+    conv_y_list.push_back(
+        reinterpret_cast<const XPUType *>(conv_out_z->data<T>()));
+    dx_list.push_back(
+        reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(z_grad)));
+    dw_list.push_back(
+        reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(filter_z_grad)));
+    x_shape_list.push_back(common::vectorize<int>(z->dims()));
+
+    auto filter_z_shape = common::vectorize<int>(filter_z->dims());
+    std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
+    if (!is_nchw) {
+      ksize_z[0] = filter_z_shape[1];
+      ksize_z[1] = filter_z_shape[2];
+    }
+    ksize_list.push_back(ksize_z);
+    stride_list.push_back({stride_z, stride_z});
+    x_maxlist.push_back(nullptr);
+    w_maxlist.push_back(nullptr);
+
+    scale_list.push_back(scale_z->data<float>());
+    batch_mean_list.push_back(saved_mean_z->data<float>());
+    batch_invstd_list.push_back(saved_invstd_z->data<float>());
+    dscale_list.push_back(dev_ctx.template Alloc<float>(scale_z_grad));
+    dbias_list.push_back(dev_ctx.template Alloc<float>(bias_z_grad));
+  } else {
+    if (fuse_add) {
+      auto z_grad_tmp = z_grad;
+      dx_list.push_back(
+          reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(z_grad_tmp)));
+    }
+  }
+
+  int r = xpu::resnet_unit_grad_fusion<XPUType>(
+      dev_ctx.x_context(),
+      x_list,
+      w_list,
+      reinterpret_cast<const XPUType *>(y_grad->data<T>()),
+      reinterpret_cast<const XPUType *>(output->data<T>()),
+      conv_y_list,
+      dx_list,
+      dw_list,
+      x_shape_list,
+      filter_x_shape[0],
+      ksize_list,
+      stride_list,
+      paddings,
+      dilations,
+      group,
+      x_maxlist,
+      w_maxlist,
+      scale_list,
+      batch_mean_list,
+      batch_invstd_list,
+      dscale_list,
+      dbias_list,
+      xpu::Activation_t::RELU,
+      eps,
+      is_nchw,
+      has_shortcut,
+      fuse_add);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion");
+}
+
+}  // namespace phi
+PD_REGISTER_KERNEL(resnet_unit_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ResNetUnitGradXPUKernel,
+                   phi::dtype::float16,
+                   float) {}
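The XPU kernels above and below share a list-batching convention: slot 0 of every vector is the main (x) branch, and when has_shortcut is set the z branch is appended as slot 1, so one fused XDNN call covers both convolutions. A simplified sketch of that packing (the Branch struct and PackBranches helper are hypothetical; the real code passes raw XPUType pointers plus per-branch shape/ksize/stride vectors):

#include <cassert>
#include <vector>

struct Branch {
  const void *x;
  const void *w;
  std::vector<int> shape;
};

// Slot 0 is always the main branch; the shortcut branch, when present,
// is appended as slot 1.
void PackBranches(const Branch &x_branch,
                  const Branch *z_branch,  // nullptr unless has_shortcut
                  std::vector<const void *> *x_list,
                  std::vector<const void *> *w_list,
                  std::vector<std::vector<int>> *shape_list) {
  *x_list = {x_branch.x};
  *w_list = {x_branch.w};
  *shape_list = {x_branch.shape};
  if (z_branch != nullptr) {
    x_list->push_back(z_branch->x);
    w_list->push_back(z_branch->w);
    shape_list->push_back(z_branch->shape);
  }
}

int main() {
  float x{}, w{}, z{}, wz{};
  Branch bx{&x, &w, {1, 64, 56, 56}};
  Branch bz{&z, &wz, {1, 64, 56, 56}};
  std::vector<const void *> xs, ws;
  std::vector<std::vector<int>> shapes;
  PackBranches(bx, &bz, &xs, &ws, &shapes);
  assert(xs.size() == 2 && ws.size() == 2 && shapes.size() == 2);
}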
diff --git a/paddle/phi/kernels/fusion/xpu/resnet_unit_kernel.cc b/paddle/phi/kernels/fusion/xpu/resnet_unit_kernel.cc
new file mode 100644
index 0000000000000..9efd4a51b6bb3
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/resnet_unit_kernel.cc
@@ -0,0 +1,194 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/utils/optional.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ResNetUnitXPUKernel(const Context &dev_ctx,
+                         const DenseTensor &x_in,
+                         const DenseTensor &filter_x_in,
+                         const DenseTensor &scale_x_in,
+                         const DenseTensor &bias_x_in,
+                         const DenseTensor &mean_x_in,
+                         const DenseTensor &var_x_in,
+                         const paddle::optional<DenseTensor> &z_in,
+                         const paddle::optional<DenseTensor> &filter_z_in,
+                         const paddle::optional<DenseTensor> &scale_z_in,
+                         const paddle::optional<DenseTensor> &bias_z_in,
+                         const paddle::optional<DenseTensor> &mean_z_in,
+                         const paddle::optional<DenseTensor> &var_z_in,
+                         int stride,
+                         int stride_z,
+                         int padding,
+                         int dilation,
+                         int group,
+                         float momentum_in,
+                         float epsilon,
+                         const std::string &data_format,
+                         bool fuse_add,
+                         bool has_shortcut,
+                         bool use_global_stats,
+                         bool is_test,
+                         bool use_addto,
+                         const std::string &act_type,
+                         DenseTensor *out,
+                         DenseTensor *bit_mask,
+                         DenseTensor *conv_x,
+                         DenseTensor *saved_mean_x,
+                         DenseTensor *saved_invstd_x,
+                         DenseTensor *running_mean_x,
+                         DenseTensor *running_var_x,
+                         DenseTensor *conv_z,
+                         DenseTensor *saved_mean_z,
+                         DenseTensor *saved_invstd_z,
+                         DenseTensor *running_mean_z,
+                         DenseTensor *running_var_z) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  bool is_nchw = (data_format == "NCHW");
+  // input x
+  const phi::DenseTensor *input_x = &x_in;
+  const phi::DenseTensor *filter_x = &filter_x_in;
+  const phi::DenseTensor *scale_x = &scale_x_in;
+  const phi::DenseTensor *bias_x = &bias_x_in;
+
+  // output x
+  phi::DenseTensor *conv_out_x = conv_x;
+
+  phi::DenseTensor *output = out;
+
+  // attrs
+  float eps = epsilon;
+  float momentum = momentum_in;
+  bool is_train = !is_test && !use_global_stats;
+
+  std::vector<const XPUType *> x_list = {
+      reinterpret_cast<const XPUType *>(input_x->data<T>())};
+  std::vector<const XPUType *> w_list = {
+      reinterpret_cast<const XPUType *>(filter_x->data<T>())};
+  std::vector<XPUType *> conv_y_list = {
+      reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(conv_out_x))};
+
+  std::vector<std::vector<int>> x_shape_list = {
+      common::vectorize<int>(input_x->dims())};
+
+  auto filter_x_shape = common::vectorize<int>(filter_x->dims());
+  std::vector<int> ksize = {filter_x_shape[2], filter_x_shape[3]};
+  if (!is_nchw) {
+    ksize[0] = filter_x_shape[1];
+    ksize[1] = filter_x_shape[2];
+  }
+  std::vector<int> strides = {stride, stride};
+  std::vector<std::vector<int>> ksize_list = {ksize};
+  std::vector<std::vector<int>> stride_list = {strides};
+  std::vector<int> paddings = {padding, padding};
+  std::vector<int> dilations = {dilation, dilation};
+  std::vector<const float *> scale_list = {scale_x->data<float>()};
+  std::vector<const float *> bias_list = {bias_x->data<float>()};
+  std::vector<float *> batch_mean_list = {
+      dev_ctx.template Alloc<float>(saved_mean_x)};
+  std::vector<float *> batch_invstd_list = {
+      dev_ctx.template Alloc<float>(saved_invstd_x)};
+  std::vector<float *> global_mean_list = {
+      dev_ctx.template Alloc<float>(running_mean_x)};
+  std::vector<float *> global_var_list = {
+      dev_ctx.template Alloc<float>(running_var_x)};
+
+  std::vector<const float *> x_maxlist = {nullptr};
+  std::vector<const float *> w_maxlist = {nullptr};
+  if (has_shortcut) {
+    // input z
+    const phi::DenseTensor *input_z = z_in.get_ptr();
+    const phi::DenseTensor *filter_z = filter_z_in.get_ptr();
+    const phi::DenseTensor *scale_z = scale_z_in.get_ptr();
+    const phi::DenseTensor *bias_z = bias_z_in.get_ptr();
+
+    phi::DenseTensor *conv_out_z = conv_z;
+
+    x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
+    w_list.push_back(reinterpret_cast<const XPUType *>(filter_z->data<T>()));
+    conv_y_list.push_back(
+        reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(conv_out_z)));
+
+    x_shape_list.push_back(common::vectorize<int>(input_z->dims()));
+
+    auto filter_z_shape = common::vectorize<int>(filter_z->dims());
+    std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
+    if (!is_nchw) {
+      ksize_z[0] = filter_z_shape[1];
+      ksize_z[1] = filter_z_shape[2];
+    }
+    ksize_list.push_back(ksize_z);
+    stride_list.push_back({stride_z, stride_z});
+    scale_list.push_back(scale_z->data<float>());
+    bias_list.push_back(bias_z->data<float>());
+    batch_mean_list.push_back(dev_ctx.template Alloc<float>(saved_mean_z));
+    batch_invstd_list.push_back(dev_ctx.template Alloc<float>(saved_invstd_z));
+    global_mean_list.push_back(dev_ctx.template Alloc<float>(running_mean_z));
+    global_var_list.push_back(dev_ctx.template Alloc<float>(running_var_z));
+    x_maxlist.push_back(nullptr);
+    w_maxlist.push_back(nullptr);
+  } else {
+    if (fuse_add) {
+      const phi::DenseTensor *input_z = z_in.get_ptr();
+      auto input_z_shape = common::vectorize<int>(input_z->dims());
+      x_list.push_back(reinterpret_cast<const XPUType *>(input_z->data<T>()));
+      x_shape_list.push_back(input_z_shape);
+      x_maxlist.push_back(nullptr);
+    }
+  }
+  int r = xpu::resnet_unit_fusion<XPUType>(
+      dev_ctx.x_context(),
+      x_list,
+      w_list,
+      conv_y_list,
+      reinterpret_cast<XPUType *>(dev_ctx.template Alloc<T>(output)),
+      x_shape_list,
+      filter_x_shape[0],
+      ksize_list,
+      stride_list,
+      paddings,
+      dilations,
+      group,
+      eps,
+      momentum,
+      x_maxlist,
+      w_maxlist,
+      scale_list,
+      bias_list,
+      batch_mean_list,
+      batch_invstd_list,
+      global_mean_list,
+      global_var_list,
+      xpu::Activation_t::RELU,
+      is_nchw,
+      has_shortcut,
+      fuse_add,
+      is_train);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_fusion");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(resnet_unit,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ResNetUnitXPUKernel,
+                   phi::dtype::float16,
+                   float) {}
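Both XPU kernels read the spatial kernel size from different filter dimensions depending on layout: dims 2 and 3 for NCHW ([C_out, C_in, kH, kW]) versus dims 1 and 2 for NHWC ([C_out, kH, kW, C_in]). A standalone illustration of that branch with made-up 3x3 filter shapes (the SpatialKsize helper is illustrative, not a Paddle API):

#include <cassert>
#include <vector>

// Returns {kH, kW} from a 4-D filter shape, mirroring the is_nchw branch
// used by the XPU kernels above.
std::vector<int> SpatialKsize(const std::vector<int> &filter, bool is_nchw) {
  return is_nchw ? std::vector<int>{filter[2], filter[3]}
                 : std::vector<int>{filter[1], filter[2]};
}

int main() {
  assert((SpatialKsize({64, 3, 3, 3}, /*is_nchw=*/true) ==
          std::vector<int>{3, 3}));  // [C_out, C_in, kH, kW]
  assert((SpatialKsize({64, 3, 3, 16}, /*is_nchw=*/false) ==
          std::vector<int>{3, 3}));  // [C_out, kH, kW, C_in]
}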
diff --git a/paddle/phi/ops/yaml/fused_backward.yaml b/paddle/phi/ops/yaml/fused_backward.yaml
index 78438aa0295c2..85040bf5ed022 100644
--- a/paddle/phi/ops/yaml/fused_backward.yaml
+++ b/paddle/phi/ops/yaml/fused_backward.yaml
@@ -99,3 +99,31 @@
   kernel :
     func : max_pool2d_v2_grad
     param: [x, out, saved_idx, out_grad, kernel_size, strides, paddings, data_format, global_pooling, adaptive]
+
+- backward_op : resnet_unit_grad
+  forward: resnet_unit (Tensor x, Tensor filter_x, Tensor scale_x, Tensor bias_x, Tensor mean_x,
+    Tensor var_x, Tensor z, Tensor filter_z, Tensor scale_z, Tensor bias_z, Tensor
+    mean_z, Tensor var_z, int stride = 1, int stride_z = 1, int padding = 0, int dilation
+    = 1, int group = 1, float momentum = 0.9, float epsilon = 1e-5, str data_format
+    = "NHWC", bool fuse_add = false, bool has_shortcut = false, bool use_global_stats
+    = false, bool is_test = false, bool use_addto = false, str act_type = "relu") ->
+    Tensor (out), Tensor (bit_mask), Tensor (conv_x), Tensor (saved_mean_x),
+    Tensor (saved_invstd_x), Tensor (running_mean_x), Tensor (running_var_x), Tensor
+    (conv_z), Tensor (saved_mean_z), Tensor (saved_invstd_z), Tensor (running_mean_z),
+    Tensor (running_var_z)
+  args: (Tensor x, Tensor filter_x, Tensor conv_x, Tensor scale_x, Tensor bias_x, Tensor saved_mean_x,
+    Tensor saved_invstd_x, Tensor z, Tensor filter_z, Tensor conv_z, Tensor scale_z, Tensor bias_z, Tensor
+    saved_mean_z, Tensor saved_invstd_z, Tensor out, Tensor bit_mask, Tensor out_grad,
+    int stride = 1, int stride_z = 1, int padding = 0, int dilation
+    = 1, int group = 1, float momentum = 0.9, float epsilon = 1e-5, str data_format
+    = "NHWC", bool fuse_add = false, bool has_shortcut = false, bool use_global_stats
+    = false, bool is_test = false, bool use_addto = false, str act_type = "relu")
+  output: Tensor (x_grad), Tensor (filter_x_grad), Tensor (scale_x_grad), Tensor (bias_x_grad),
+    Tensor (z_grad), Tensor (filter_z_grad), Tensor (scale_z_grad), Tensor (bias_z_grad)
+  infer_meta:
+    func: ResnetUnitGradInferMeta
+  kernel:
+    func: resnet_unit_grad
+    data_type: x
+  optional: z, filter_z, conv_z, scale_z, bias_z, saved_mean_z, saved_invstd_z
+  support_dygraph_mode : true
diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml
index 514e31032029d..eee85e454ecd4 100644
--- a/paddle/phi/ops/yaml/fused_ops.yaml
+++ b/paddle/phi/ops/yaml/fused_ops.yaml
@@ -681,6 +681,27 @@
     func : quantize_xpu
     data_type : x
+
+- op : resnet_unit
+  args: (Tensor x, Tensor filter_x, Tensor scale_x, Tensor bias_x, Tensor mean_x,
+    Tensor var_x, Tensor z, Tensor filter_z, Tensor scale_z, Tensor bias_z, Tensor
+    mean_z, Tensor var_z, int stride = 1, int stride_z = 1, int padding = 0, int dilation
+    = 1, int group = 1, float momentum = 0.9, float epsilon = 1e-5, str data_format
+    = "NHWC", bool fuse_add = false, bool has_shortcut = false, bool use_global_stats
+    = false, bool is_test = false, bool use_addto = false, str act_type = "relu")
+  output: Tensor (out), Tensor (bit_mask), Tensor (conv_x), Tensor (saved_mean_x),
+    Tensor (saved_invstd_x), Tensor (running_mean_x), Tensor (running_var_x), Tensor
+    (conv_z), Tensor (saved_mean_z), Tensor (saved_invstd_z), Tensor (running_mean_z),
+    Tensor (running_var_z)
+  infer_meta:
+    func: ResnetUnitInferMeta
+  kernel:
+    func: resnet_unit
+    data_type: x
+  optional: z, filter_z, scale_z, bias_z, mean_z, var_z, conv_z, saved_mean_z, saved_invstd_z,
+    running_mean_z, running_var_z
+  backward: resnet_unit_grad
+  support_dygraph_mode : true
+
 - op : roformer_relative_embedding_xpu
   args : (Tensor x, Tensor sin_emb, Tensor cos_emb, int max_pos_len)
   output : Tensor(out)
diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml
index bd634d596d4c9..3625929da2b3e 100755
--- a/paddle/phi/ops/yaml/op_compat.yaml
+++ b/paddle/phi/ops/yaml/op_compat.yaml
@@ -3093,6 +3093,13 @@
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool use_quantizer = false]
+
+- op : resnet_unit
+  backward: resnet_unit_grad
+  inputs:
+    {x : X, filter_x : FilterX, scale_x : ScaleX, bias_x : BiasX, mean_x : MeanX, var_x : VarX, z : Z, filter_z : FilterZ, scale_z : ScaleZ, bias_z : BiasZ, mean_z : MeanZ, var_z : VarZ}
+  outputs:
+    {out : Y, bit_mask : BitMask, conv_x : ConvX, saved_mean_x : SavedMeanX, saved_invstd_x : SavedInvstdX, running_mean_x : RunningMeanX, running_var_x : RunningVarX, conv_z : ConvZ, saved_mean_z : SavedMeanZ, saved_invstd_z : SavedInvstdZ, running_mean_z : RunningMeanZ, running_var_z : RunningVarZ}
+
 - op : reverse
   inputs:
     x : X
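The op_compat entry above pairs each phi-style argument name with the legacy operator's capitalized parameter name, which keeps programs saved with the old ResNetUnitOp resolving against the new kernels. An illustrative lookup table equivalent to the inputs mapping (the real translation table is generated from the YAML at build time, not hand-written like this):

#include <map>
#include <string>

// Phi argument name -> legacy operator parameter name, mirroring the
// `inputs` block of the op_compat entry.
const std::map<std::string, std::string> kResnetUnitInputNames = {
    {"x", "X"},          {"filter_x", "FilterX"}, {"scale_x", "ScaleX"},
    {"bias_x", "BiasX"}, {"mean_x", "MeanX"},     {"var_x", "VarX"},
    {"z", "Z"},          {"filter_z", "FilterZ"}, {"scale_z", "ScaleZ"},
    {"bias_z", "BiasZ"}, {"mean_z", "MeanZ"},     {"var_z", "VarZ"}};

int main() { return kResnetUnitInputNames.at("filter_x") == "FilterX" ? 0 : 1; }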