Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… variable-keeping-with-ensor
  • Loading branch information
CtfGo committed Jun 21, 2021
2 parents 3b1d04c + 1b0c5ef commit 58a9374
Show file tree
Hide file tree
Showing 107 changed files with 2,651 additions and 1,516 deletions.
2 changes: 1 addition & 1 deletion cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ function(version version_file)
file(APPEND ${version_file} "CXX compiler version: ${CMAKE_CXX_COMPILER_VERSION}\n")
if(TENSORRT_FOUND)
file(APPEND ${version_file}
"WITH_TENSORRT: ${TENSORRT_FOUND}\n" "TensorRT version: v${TENSORRT_MAJOR_VERSION}\n")
"WITH_TENSORRT: ${TENSORRT_FOUND}\n" "TensorRT version: v${TENSORRT_MAJOR_VERSION}.${TENSORRT_MINOR_VERSION}.${TENSORRT_PATCH_VERSION}.${TENSORRT_BUILD_VERSION}\n")
endif()
if(WITH_LITE)
file(APPEND ${version_file} "WITH_LITE: ${WITH_LITE}\n" "LITE_GIT_TAG: ${LITE_GIT_TAG}\n")
Expand Down
20 changes: 19 additions & 1 deletion cmake/tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,23 @@ if(TENSORRT_FOUND)
file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_MINOR +([0-9]+)" TENSORRT_MINOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_PATCH +([0-9]+)" TENSORRT_PATCH_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_BUILD +([0-9]+)" TENSORRT_BUILD_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")

if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_MINOR +([0-9]+)" TENSORRT_MINOR_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_PATCH +([0-9]+)" TENSORRT_PATCH_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
string(REGEX MATCH "define NV_TENSORRT_BUILD +([0-9]+)" TENSORRT_BUILD_VERSION
"${TENSORRT_VERSION_FILE_CONTENTS}")
endif()

if("${TENSORRT_MAJOR_VERSION}" STREQUAL "")
Expand All @@ -60,9 +72,15 @@ if(TENSORRT_FOUND)

string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
string(REGEX REPLACE "define NV_TENSORRT_MINOR +([0-9]+)" "\\1"
TENSORRT_MINOR_VERSION "${TENSORRT_MINOR_VERSION}")
string(REGEX REPLACE "define NV_TENSORRT_PATCH +([0-9]+)" "\\1"
TENSORRT_PATCH_VERSION "${TENSORRT_PATCH_VERSION}")
string(REGEX REPLACE "define NV_TENSORRT_BUILD +([0-9]+)" "\\1"
TENSORRT_BUILD_VERSION "${TENSORRT_BUILD_VERSION}")

message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}.${TENSORRT_MINOR_VERSION}.${TENSORRT_PATCH_VERSION}.${TENSORRT_BUILD_VERSION} ")
include_directories(${TENSORRT_INCLUDE_DIR})
link_directories(${TENSORRT_LIBRARY})
add_definitions(-DPADDLE_WITH_TENSORRT)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ message DistributedStrategy {
optional bool tensor_parallel = 29 [ default = false ];
optional bool without_graph_optimization = 30 [ default = false ];
optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
optional bool calc_comm_same_stream = 32 [ default = false ];

optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
Expand Down
259 changes: 256 additions & 3 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope,
}
}

ConvBNFusePass::ConvBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();

AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();

AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}

void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
Expand All @@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse";

// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
Expand Down Expand Up @@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
if (!IsCompat(*conv->Op())) {
LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
return;
}
GraphSafeRemoveNodes(
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
Expand All @@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
if (!IsCompat(desc)) {
LOG(WARNING)
<< "conv_bn fuse pass in out elementwise_add op compat failed.";
return;
}
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.

GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
Expand All @@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count);
}

ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();

AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();

AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}

void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
Expand All @@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse";

// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
Expand Down Expand Up @@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

if (!IsCompat(*eltwise->Op())) {
LOG(WARNING)
<< "conv_eltwise_bn fuse pass in out eltwise op compat failed.";
return;
}
GraphSafeRemoveNodes(
graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
Expand All @@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count);
}

ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}

ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}

} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down
Loading

0 comments on commit 58a9374

Please sign in to comment.