From 046553c71389bf715edcc6836792627dd1443caa Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 10 Jan 2022 11:00:48 +0800 Subject: [PATCH] Support setting infershape function for custom grad op (#38776) * unify infer_shape func calling * support set grad infer shape fn for custom op * unify infershape in new executor and eager * remove todo comment * revert infershape in operator --- .../fluid/eager/legacy/prepared_operator.cc | 3 +- paddle/fluid/framework/custom_operator.cc | 303 ++++++++++-------- .../framework/new_executor/data_transfer.cc | 3 +- .../new_executor/interpretercore_util.cc | 2 +- paddle/fluid/framework/operator.cc | 4 +- paddle/fluid/imperative/prepared_operator.cc | 6 +- paddle/pten/api/lib/op_meta_info.cc | 7 - .../fluid/tests/custom_op/custom_relu_op.cc | 46 +++ .../fluid/tests/custom_op/custom_relu_op.cu | 19 ++ .../custom_op/test_custom_relu_op_jit.py | 3 +- 10 files changed, 236 insertions(+), 160 deletions(-) diff --git a/paddle/fluid/eager/legacy/prepared_operator.cc b/paddle/fluid/eager/legacy/prepared_operator.cc index 1c3429207f8b5..4e892b14a9c9c 100644 --- a/paddle/fluid/eager/legacy/prepared_operator.cc +++ b/paddle/fluid/eager/legacy/prepared_operator.cc @@ -174,8 +174,7 @@ static void PreparedOpRunImpl( EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, op.Type()); - static_cast(op).InferShape( - &infer_shape_ctx); + op.Info().infer_shape_(&infer_shape_ctx); func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs)); diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 785973e041a0d..fd2522b0336ff 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -94,7 +94,7 @@ std::vector ParseAttrStr(const std::string& attr) { // 2. type rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1))); - VLOG(1) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1]; + VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1]; return rlt; } @@ -109,11 +109,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, const std::vector& inputs, const std::vector& outputs, const std::vector& attrs) { - VLOG(1) << "Custom Operator: Start run KernelFunc."; + VLOG(3) << "Custom Operator: Start run KernelFunc."; std::vector custom_ins; std::vector> custom_vec_ins; for (auto& in_name : inputs) { - VLOG(1) << "Custom Operator: input name - " << in_name; + VLOG(3) << "Custom Operator: input name - " << in_name; if (detail::IsDuplicableVar(in_name)) { // return const std::vector auto vec_x = ctx.MultiInput(in_name); @@ -185,11 +185,11 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, } } - VLOG(1) << "Custom Operator: Run ComputeFunc."; + VLOG(3) << "Custom Operator: Run ComputeFunc."; try { auto outs = func(custom_ins, custom_vec_ins, custom_attrs); - VLOG(1) << "Custom Operator: Share outputs into ExecutionContext."; + VLOG(3) << "Custom Operator: Share outputs into ExecutionContext."; for (size_t i = 0; i < outputs.size(); ++i) { auto out_name = outputs[i]; if (detail::IsDuplicableVar(out_name)) { @@ -230,6 +230,95 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, } } +static void RunInferShapeFunc(framework::InferShapeContext* ctx, + const paddle::InferShapeFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& attrs) { + std::vector> input_shapes; + std::vector>> vec_input_shapes; + + VLOG(3) << "Custom Operator: InferShape - get input ddim."; + for (auto& in_name : inputs) { + if (detail::IsDuplicableVar(in_name)) { + OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom"); + auto vec_ddim = ctx->GetInputsDim(in_name); + std::vector> vec_shape; + vec_shape.reserve(vec_ddim.size()); + std::transform(vec_ddim.begin(), vec_ddim.end(), + std::back_inserter(vec_shape), + [&](const DDim& ddim) -> std::vector { + return framework::vectorize(ddim); + }); + vec_input_shapes.emplace_back(vec_shape); + } else { + OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom"); + auto ddim = ctx->GetInputDim(in_name); + input_shapes.emplace_back(framework::vectorize(ddim)); + } + } + + std::vector custom_attrs; + for (auto& attr_str : attrs) { + auto attr_name_and_type = detail::ParseAttrStr(attr_str); + auto attr_name = attr_name_and_type[0]; + auto attr_type_str = attr_name_and_type[1]; + if (attr_type_str == "bool") { + custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); + } else if (attr_type_str == "int") { + custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); + } else if (attr_type_str == "float") { + custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); + } else if (attr_type_str == "int64_t") { + custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); + } else if (attr_type_str == "std::string") { + custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); + } else if (attr_type_str == "std::vector") { + custom_attrs.emplace_back(ctx->Attrs().Get>(attr_name)); + } else if (attr_type_str == "std::vector") { + custom_attrs.emplace_back( + ctx->Attrs().Get>(attr_name)); + } else if (attr_type_str == "std::vector") { + // NOTE(chenweihang): InferShape can't support std::vector + // attr type, because the input type is std::vector, only + // can use one rule to parse std::vector parameter + continue; + } else if (attr_type_str == "std::vector") { + custom_attrs.emplace_back( + ctx->Attrs().Get>(attr_name)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector`, " + "Please check whether the attribute data type and " + "data type string are matched.", + attr_type_str)); + } + } + + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + auto output_shapes = func(input_shapes, vec_input_shapes, custom_attrs); + + VLOG(3) << "Custom Operator: InferShape - set output ddim."; + for (size_t i = 0; i < outputs.size(); ++i) { + auto out_name = outputs[i]; + if (detail::IsDuplicableVar(out_name)) { + std::vector vec_ddim; + vec_ddim.reserve(output_shapes.size()); + std::transform(output_shapes.begin(), output_shapes.end(), + std::back_inserter(vec_ddim), + [&](const std::vector& shape) -> DDim { + return framework::make_ddim(shape); + }); + ctx->SetOutputsDim(out_name, vec_ddim); + } else { + ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i])); + } + } +} + //////////////////// Operator Define ///////////////// class CustomOperator : public OperatorWithKernel { @@ -239,7 +328,7 @@ class CustomOperator : public OperatorWithKernel { // Dummy infershape // Because it is a pure virtual function, it must be implemented void InferShape(framework::InferShapeContext* ctx) const override { - VLOG(1) << "Custom Operator: Dummy infer shape of custom operator."; + VLOG(3) << "Custom Operator: Dummy infer shape of custom operator."; } /** @@ -381,7 +470,7 @@ class CustomGradOpMaker : public SingleGradOpMaker { auto fwd_op_outputs = this->OutputNames(); for (auto& in_name : inputs_) { - VLOG(1) << "Custom Operator: GradOpDescMaker - input: " << in_name; + VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name; if (!detail::IsGradVar(in_name)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) { grad_op->SetInput(in_name, this->Input(in_name)); @@ -398,7 +487,7 @@ class CustomGradOpMaker : public SingleGradOpMaker { } } for (auto& out_name : outputs_) { - VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << out_name; + VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name; if (detail::IsDuplicableVar(out_name)) { grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name), @@ -447,7 +536,7 @@ class CustomGradOpMaker auto fwd_op_outputs = this->OutputNames(); for (auto& in_name : inputs_) { - VLOG(1) << "Custom Operator: GradOpBaseMaker - input: " << in_name; + VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name; if (!detail::IsGradVar(in_name)) { if (detail::IsMemberOf(fwd_op_inputs, in_name)) { grad_op->SetInput(in_name, this->Input(in_name)); @@ -464,7 +553,7 @@ class CustomGradOpMaker } } for (auto& out_name : outputs_) { - VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << out_name; + VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name; grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name))); } grad_op->SetAttrMap(this->Attrs()); @@ -486,11 +575,11 @@ void RegisterOperatorKernelWithPlace(const std::string& name, const std::vector& outputs, const std::vector& attrs) { OpKernelType key(type, experimental::ConvertExtPlaceToInnerPlace(place)); - VLOG(1) << "Custom Operator: op kernel key: " << key; + VLOG(3) << "Custom Operator: op kernel key: " << key; OperatorWithKernel::AllOpKernels()[name][key] = [kernel_func, inputs, outputs, attrs](const framework::ExecutionContext& ctx) { - VLOG(1) << "Custom Operator: run custom kernel func in lambda."; + VLOG(3) << "Custom Operator: run custom kernel func in lambda."; RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); }; } @@ -500,7 +589,7 @@ void RegisterOperatorKernel(const std::string& name, const std::vector& inputs, const std::vector& outputs, const std::vector& attrs) { - VLOG(1) << "Custom Operator: op name in kernel: " << name; + VLOG(3) << "Custom Operator: op name in kernel: " << name; // NOTE [ Dummy Op Kernel Key ] // TODO(chenweihang): Because execute engine need get device context based // op_kernel_key.place_, so we should register kernel for each @@ -535,12 +624,12 @@ void RegisterOperatorWithMetaInfo( auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta); auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta); - VLOG(1) << "Custom Operator: forward, op name: " << op_name; - VLOG(1) << "Custom Operator: forward, op inputs: " + VLOG(3) << "Custom Operator: forward, op name: " << op_name; + VLOG(3) << "Custom Operator: forward, op inputs: " << string::join_strings(op_inputs, ','); - VLOG(1) << "Custom Operator: forward, op outputs: " + VLOG(3) << "Custom Operator: forward, op outputs: " << string::join_strings(op_outputs, ','); - VLOG(1) << "Custom Operator: forward, op attrs: " + VLOG(3) << "Custom Operator: forward, op attrs: " << string::join_strings(op_attrs, ','); // Op @@ -588,96 +677,13 @@ void RegisterOperatorWithMetaInfo( "Please set the InferShapeFn of custom " "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); - VLOG(1) << "Custom Operator: Default InferShape - share ddim."; + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; ctx->ShareDim(op_inputs[0], op_outputs[0]); }; } else { info.infer_shape_ = [op_inputs, op_outputs, op_attrs, infer_shape_func](InferShapeContext* ctx) { - std::vector> input_shapes; - std::vector>> vec_input_shapes; - - VLOG(1) << "Custom Operator: InferShape - get input ddim."; - for (auto& in_name : op_inputs) { - if (detail::IsDuplicableVar(in_name)) { - OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom"); - auto vec_ddim = ctx->GetInputsDim(in_name); - std::vector> vec_shape; - vec_shape.reserve(vec_ddim.size()); - std::transform(vec_ddim.begin(), vec_ddim.end(), - std::back_inserter(vec_shape), - [&](const DDim& ddim) -> std::vector { - return framework::vectorize(ddim); - }); - vec_input_shapes.emplace_back(vec_shape); - } else { - OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom"); - auto ddim = ctx->GetInputDim(in_name); - input_shapes.emplace_back(framework::vectorize(ddim)); - } - } - - std::vector custom_attrs; - for (auto& attr_str : op_attrs) { - auto attr_name_and_type = detail::ParseAttrStr(attr_str); - auto attr_name = attr_name_and_type[0]; - auto attr_type_str = attr_name_and_type[1]; - if (attr_type_str == "bool") { - custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); - } else if (attr_type_str == "int") { - custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); - } else if (attr_type_str == "float") { - custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); - } else if (attr_type_str == "int64_t") { - custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); - } else if (attr_type_str == "std::string") { - custom_attrs.emplace_back(ctx->Attrs().Get(attr_name)); - } else if (attr_type_str == "std::vector") { - custom_attrs.emplace_back( - ctx->Attrs().Get>(attr_name)); - } else if (attr_type_str == "std::vector") { - custom_attrs.emplace_back( - ctx->Attrs().Get>(attr_name)); - } else if (attr_type_str == "std::vector") { - // NOTE(chenweihang): InferShape can't support std::vector - // attr type, because the input type is std::vector, only - // can use one rule to parse std::vector parameter - continue; - } else if (attr_type_str == "std::vector") { - custom_attrs.emplace_back( - ctx->Attrs().Get>(attr_name)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported `%s` type value as custom attribute now. " - "Supported data types include `bool`, `int`, `float`, " - "`int64_t`, `std::string`, `std::vector`, " - "`std::vector`, `std::vector`, " - "Please check whether the attribute data type and " - "data type string are matched.", - attr_type_str)); - } - } - - VLOG(1) << "Custom Operator: InferShape - calc output ddim."; - auto output_shapes = - infer_shape_func(input_shapes, vec_input_shapes, custom_attrs); - - VLOG(1) << "Custom Operator: InferShape - set output ddim."; - for (size_t i = 0; i < op_outputs.size(); ++i) { - auto out_name = op_outputs[i]; - if (detail::IsDuplicableVar(out_name)) { - std::vector vec_ddim; - vec_ddim.reserve(output_shapes.size()); - std::transform(output_shapes.begin(), output_shapes.end(), - std::back_inserter(vec_ddim), - [&](const std::vector& shape) -> DDim { - return framework::make_ddim(shape); - }); - ctx->SetOutputsDim(out_name, vec_ddim); - } else { - ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i])); - } - } + RunInferShapeFunc(ctx, infer_shape_func, op_inputs, op_outputs, op_attrs); }; } @@ -706,7 +712,7 @@ void RegisterOperatorWithMetaInfo( "Please set the InferDtypeFn of custom " "operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))")); - VLOG(1) << "Custom Operator: InferDtype - share dtype."; + VLOG(3) << "Custom Operator: InferDtype - share dtype."; auto dtype = ctx->GetInputDataType(op_inputs[0]); ctx->SetOutputDataType(op_outputs[0], dtype); }; @@ -716,7 +722,7 @@ void RegisterOperatorWithMetaInfo( std::vector input_dtypes; std::vector> vec_input_dtypes; - VLOG(1) << "Custom Operator: InferDtype - get input dtype."; + VLOG(3) << "Custom Operator: InferDtype - get input dtype."; for (auto& in_name : op_inputs) { if (detail::IsDuplicableVar(in_name)) { std::vector vec_custom_dtype; @@ -731,10 +737,10 @@ void RegisterOperatorWithMetaInfo( } } - VLOG(1) << "Custom Operator: InferDtype - infer output dtype."; + VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; auto output_dtypes = infer_dtype_func(input_dtypes, vec_input_dtypes); - VLOG(1) << "Custom Operator: InferDtype - set output dtype."; + VLOG(3) << "Custom Operator: InferDtype - set output dtype."; for (size_t i = 0; i < op_outputs.size(); ++i) { auto out_name = op_outputs[i]; if (detail::IsDuplicableVar(out_name)) { @@ -763,11 +769,12 @@ void RegisterOperatorWithMetaInfo( auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); + auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op); - VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name; - VLOG(1) << "Custom Operator: backward, op inputs: " + VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name; + VLOG(3) << "Custom Operator: backward, op inputs: " << string::join_strings(grad_op_inputs, ','); - VLOG(1) << "Custom Operator: backward, op outputs: " + VLOG(3) << "Custom Operator: backward, op outputs: " << string::join_strings(grad_op_outputs, ','); // GradOpDescMaker @@ -809,40 +816,52 @@ void RegisterOperatorWithMetaInfo( }; // Grad InferShape - grad_info.infer_shape_ = [grad_op_inputs, - grad_op_outputs](InferShapeContext* ctx) { - // 1. if forward input exists, gradient's shape is same with forward input - // default - // [Suitable for most situations] - // 2. if forward input not exists, and only contains one grad input and - // output, - // use grad input shape as grad output shape - // [Suitable for the situation that forward input is not used as - // backward input] - // TODO(chenweihang): support set grad op infershape func if needed - for (auto& out_name : grad_op_outputs) { - auto fwd_name = detail::NoGrad(out_name); - if (detail::IsDuplicableVar(fwd_name)) { - // Duplicable forward var must as backward input - ctx->ShareDim(fwd_name, out_name); - } else { - if (ctx->HasInput(fwd_name)) { + if (grad_infer_shape_fn == nullptr) { + grad_info.infer_shape_ = [grad_op_inputs, + grad_op_outputs](InferShapeContext* ctx) { + // 1. if forward input exists, gradient's shape is same with forward + // input + // default + // [Suitable for most situations] + // 2. if forward input not exists, and only contains one grad input and + // output, + // use grad input shape as grad output shape + // [Suitable for the situation that forward input is not used as + // backward input] + for (auto& out_name : grad_op_outputs) { + auto fwd_name = detail::NoGrad(out_name); + if (detail::IsDuplicableVar(fwd_name)) { + // Duplicable forward var must as backward input ctx->ShareDim(fwd_name, out_name); } else { - PADDLE_ENFORCE_EQ( - grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL, - true, - platform::errors::Unavailable( - "Custom grad operator infershape error. " - "If a custom grad operator contains only one input and " - "only one output, the input shape will be directly set to " - "the output shape. Otherwise, Please set the forward input " - "as the grad operator's input.")); - ctx->ShareDim(grad_op_inputs[0], out_name); + if (ctx->HasInput(fwd_name)) { + ctx->ShareDim(fwd_name, out_name); + } else { + PADDLE_ENFORCE_EQ( + grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL, + true, + platform::errors::Unavailable( + "Custom grad operator infershape error. " + "If a custom grad operator contains only one input and " + "only one output, the input shape will be directly set " + "to " + "the output shape. Otherwise, Please set the forward " + "input " + "as the grad operator's input or set the InferShapeFn " + "of custom grad operator by " + ".SetInferShapeFn(PD_INFER_SHAPE(...))")); + ctx->ShareDim(grad_op_inputs[0], out_name); + } } } - } - }; + }; + } else { + grad_info.infer_shape_ = [grad_op_inputs, grad_op_outputs, grad_op_attrs, + grad_infer_shape_fn](InferShapeContext* ctx) { + RunInferShapeFunc(ctx, grad_infer_shape_fn, grad_op_inputs, + grad_op_outputs, grad_op_attrs); + }; + } // Kernel func RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs, @@ -860,11 +879,11 @@ void RegisterOperatorWithMetaInfo( void RegisterOperatorWithMetaInfoMap( const paddle::OpMetaInfoMap& op_meta_info_map) { auto& meta_info_map = op_meta_info_map.GetMap(); - VLOG(1) << "Custom Operator: size of op meta info map - " + VLOG(3) << "Custom Operator: size of op meta info map - " << meta_info_map.size(); // pair: {op_type, OpMetaInfo} for (auto& pair : meta_info_map) { - VLOG(1) << "Custom Operator: pair first -> op name: " << pair.first; + VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first; RegisterOperatorWithMetaInfo(pair.second); } } @@ -874,7 +893,7 @@ void RegisterOperatorWithMetaInfoMap( // load op api void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name); - VLOG(1) << "load custom_op lib: " << dso_name; + VLOG(3) << "load custom_op lib: " << dso_name; typedef OpMetaInfoMap& get_op_meta_info_map_t(); auto* get_op_meta_info_map = detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 064dfa0170bdb..9230c36a0c745 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -94,8 +94,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode( // 2. Execute infer shape and choose kernel auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); - static_cast(op.get())->InferShape( - &infer_shape_ctx); + op.get()->Info().infer_shape_(&infer_shape_ctx); auto kernels_iter = all_op_kernels.find(op_type); PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(), platform::errors::Unavailable( diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 41c4faa67fbeb..7ced4853c2d8f 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -355,7 +355,7 @@ void build_op_func_list(const platform::Place& place, // TODO(Aurelius84): In case of control flow ops, they are NOT // inheritted // from OperatorWithKernel. - op_with_kernel->InferShape(&infer_shape_ctx); + op_with_kernel->Info().infer_shape_(&infer_shape_ctx); } auto kernels_iter = all_op_kernels.find(op->Type()); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 2d2e198ef40ec..a0c1bd44da01e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1090,7 +1090,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const { RuntimeInferShapeContext infer_shape_ctx(*this, ctx); - this->InferShape(&infer_shape_ctx); + this->Info().infer_shape_(&infer_shape_ctx); } void OperatorWithKernel::RunImpl(const Scope& scope, @@ -1178,6 +1178,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::RecordEvent record_event("infer_shape", platform::EventRole::kInnerOp); RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx); + // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()` + // in some mkldnn infershape functions, such conv2d infershape this->InferShape(&infer_shape_ctx); } diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c5623a8f4f243..29cd24a1e7793 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -491,8 +491,7 @@ static void PreparedOpRunImpl( DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, op.Type()); - static_cast(op).InferShape( - &infer_shape_ctx); + op.Info().infer_shape_(&infer_shape_ctx); func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs)); @@ -537,8 +536,7 @@ static void PreparedOpRunPtImpl( const framework::AttributeMap& default_attrs) { DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, op.Type()); - static_cast(op).InferShape( - &infer_shape_ctx); + op.Info().infer_shape_(&infer_shape_ctx); BuildDygraphPtenKernelContext(pt_kernel_signature, pt_kernel, ins, outs, attrs, default_attrs, dev_ctx, diff --git a/paddle/pten/api/lib/op_meta_info.cc b/paddle/pten/api/lib/op_meta_info.cc index 586fa0cc05526..aa2e33afb94b8 100644 --- a/paddle/pten/api/lib/op_meta_info.cc +++ b/paddle/pten/api/lib/op_meta_info.cc @@ -122,13 +122,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) { } OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) { - PADDLE_ENFORCE_EQ( - index_, - 0UL, - platform::errors::Unimplemented( - "Currently, the InferShapeFn setting of Grad Op is not supported, " - "And backward Tensor `X@GRAD` will use the shape of forward Tensor " - "`X` by default.")); info_ptr_->SetInferShapeFn(std::forward(func)); return *this; } diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc index b2ef90bf87a1a..c5ec3191c1b02 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc @@ -105,3 +105,49 @@ PD_BUILD_GRAD_OP(custom_relu) .Inputs({"X", "Out", paddle::Grad("Out")}) .Outputs({paddle::Grad("X")}) .SetKernelFn(PD_KERNEL(ReluBackward)); + +std::vector relu_cpu_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape()); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.mutable_data(out.place()), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out); + +std::vector ReluBackwardWithoutX( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + if (out.place() == paddle::PlaceType::kCPU) { + return relu_cpu_backward_without_x(out, grad_out); + } else if (out.place() == paddle::PlaceType::kGPU) { + return relu_cuda_backward_without_x(out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector> ReluBackwardWithoutXInferShape( + const std::vector& out_shape, + const std::vector& grad_out_shape) { + return {out_shape}; +} + +PD_BUILD_OP(custom_relu_no_x_in_backward) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward) + .Inputs({"Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX)) + .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape)); diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu index dda42a5c05984..637deeb90569c 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu @@ -70,3 +70,22 @@ std::vector relu_cuda_backward(const paddle::Tensor& x, return {grad_x}; } + +std::vector relu_cuda_backward_without_x( + const paddle::Tensor& out, const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape()); + + int numel = out.size(); + int block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(out.place()), + numel); + })); + + return {grad_x}; +} diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py index 4f075066b9d93..16458841f4488 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py @@ -49,7 +49,8 @@ class TestJITLoad(unittest.TestCase): def setUp(self): self.custom_ops = [ - custom_module.custom_relu, custom_module.custom_relu_dup + custom_module.custom_relu, custom_module.custom_relu_dup, + custom_module.custom_relu_no_x_in_backward ] self.dtypes = ['float32', 'float64'] if paddle.is_compiled_with_cuda():