Skip to content

Commit

Permalink
More precise mkldnn kernel rules in GetExpectedKernelType (#29840)
Browse files Browse the repository at this point in the history
* More precise mkldnn kernel choice in GetExpectedKernelType

* Fixes after review

* Refresh develop for CI

* CI experiment

* get back from CI exper
  • Loading branch information
arlesniak authored Jan 25, 2021
1 parent a28a202 commit 5bf25d1
Show file tree
Hide file tree
Showing 25 changed files with 111 additions and 114 deletions.
14 changes: 8 additions & 6 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1040,21 +1040,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
op_type, name));
}

bool OperatorWithKernel::SupportsMKLDNN() const {
bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
[data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ ==
LibraryType::kMKLDNN;
LibraryType::kMKLDNN &&
kern_pair.first.data_type_ == data_type;
});
}

bool OperatorWithKernel::CanMKLDNNBeUsed(
const framework::ExecutionContext& ctx) const {
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
bool use_mkldnn_ctx =
ctx.Attr<bool>("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN();
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ class OperatorBase {

virtual bool SupportGPU() const { return false; }

virtual bool SupportsMKLDNN() const { return false; }

const std::string& Type() const { return type_; }

bool HasAttr(const std::string& name) const { return attrs_.count(name); }
Expand Down Expand Up @@ -492,9 +490,10 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
bool SupportsMKLDNN() const override;
bool SupportsMKLDNN(proto::VarType::Type data_type) const;

bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;

virtual void InferShape(InferShapeContext* ctx) const = 0;

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = oper.IndicateVarDataType(ctx, name);
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
Expand All @@ -106,13 +107,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
oper.CanMKLDNNBeUsed(ctx)) {
oper.CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

class ActivationOp : public framework::OperatorWithKernel {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/addmm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;

Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
Expand Down Expand Up @@ -524,17 +525,17 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
"All Inputs of Concat OP are Empty!"));
}
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
customized_type_value =
Expand Down Expand Up @@ -556,6 +557,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");

#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
Expand All @@ -564,17 +566,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
const std::string data_format = ctx.Attr<std::string>("data_format");
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
}
#endif

auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
auto type = framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
library_, customized_type_value);
return type;
}

Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/conv_transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Expand All @@ -193,15 +194,13 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
}

framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
Expand Down
9 changes: 4 additions & 5 deletions paddle/fluid/operators/data_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
Expand Down Expand Up @@ -483,18 +483,17 @@ class DataNormGradOp : public framework::OperatorWithKernel {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detection/prior_box_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_input_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_div_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ElementwiseMulOp : public ElementwiseOp {
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -280,8 +280,9 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};

if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
(ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -331,7 +332,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -384,7 +385,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
}

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/fused/fusion_gru_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,14 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

void FusionGRUOpMaker::Make() {
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context(), layout, library);
return framework::OpKernelType(data_type, ctx.device_context(), layout,
library);
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down
14 changes: 6 additions & 8 deletions paddle/fluid/operators/gelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,16 @@ class GeluOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down Expand Up @@ -86,17 +85,16 @@ class GeluGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/interpolate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,19 @@ class InterpolateOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx) &&
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down
Loading

0 comments on commit 5bf25d1

Please sign in to comment.