Skip to content

Commit

Permalink
Merge branch 'release/2.1' into release/2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo authored Jun 9, 2021
2 parents 837e7d0 + bad3beb commit 7d0025d
Show file tree
Hide file tree
Showing 23 changed files with 481 additions and 34 deletions.
18 changes: 15 additions & 3 deletions cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,23 @@ function(select_nvcc_arch_flags out_variable)
if(${CUDA_ARCH_NAME} STREQUAL "Kepler")
set(cuda_arch_bin "30 35")
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50")
if (WITH_NV_JETSON)
set(cuda_arch_bin "53")
else()
set(cuda_arch_bin "50")
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
set(cuda_arch_bin "60 61")
if (WITH_NV_JETSON)
set(cuda_arch_bin "62")
else()
set(cuda_arch_bin "60 61")
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
set(cuda_arch_bin "70")
if (WITH_NV_JETSON)
set(cuda_arch_bin "72")
else()
set(cuda_arch_bin "70")
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1525,7 +1525,12 @@ Scope* OperatorWithKernel::PrepareData(
// the rest iterations to save the elapsed time.
// We do not support skipping PrepareData in while block, because the Op's
// input may be changed by subsequent Ops, which may cause an error.
if (pre_scope_ == &scope && new_scope == nullptr) {

// For inference, ops that come after a conditional branch aren't supported
// well, so conservatively disable the prepare-data optimization for them.
bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
Attr<bool>("inference_force_prepare_data");
if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
need_prepare_data_ = false;
}

Expand Down
40 changes: 40 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,46 @@ bool AnalysisPredictor::CreateExecutor() {
executor_.reset(new paddle::framework::NaiveExecutor(place_));
return true;
}

// Returns true when `op` is one of the control-flow ops after which the
// prepare-data optimization is unsafe.
static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
  // The bad case for the prepare-data optimization looks like this:
  // an op sits behind a conditional_block. If the conditional_block takes
  // branch 1, the op needs to call prepare data; if it takes branch 2, the op
  // does not. When the predictor first runs through branch 2 the optimization
  // takes effect, and a later run through branch 1 then misbehaves because
  // the op has lost its chance to prepare data.
  const std::string current_type = op->Type();
  return current_type == "conditional_block_infer" ||
         current_type == "select_input";
}

// Walks the ops of block `block` in `inference_program` and sets the
// "inference_force_prepare_data" attribute on every op that comes after a
// control-flow op reported by IsPrepareDataOptTargetOp, so those ops never
// skip PrepareData.
//
// inference_program: program whose op descriptions are annotated in place.
// block:             index of the block to process.
// pre_disable_opt:   true when the enclosing block had already seen a target
//                    op before reaching the op that owns this sub-block; in
//                    that case every op in this block is marked.
//
// Blocks referenced through an op's "sub_block" attribute are processed
// recursively with the state accumulated so far.
static void DisablePrepareDataOpt(
    std::shared_ptr<framework::ProgramDesc> inference_program, int block,
    bool pre_disable_opt) {
  bool disable_opt = false;
  auto &infer_block = inference_program->Block(block);
  for (auto *op : infer_block.AllOps()) {
    const bool force_prepare = disable_opt || pre_disable_opt;
    if (force_prepare) {
      op->SetAttr("inference_force_prepare_data", true);
    }
    if (op->HasAttr("sub_block")) {
      int blockID = op->GetBlockAttrId("sub_block");
      DisablePrepareDataOpt(inference_program, blockID, force_prepare);
    }
    // Once an unfriendly op has been found, keep the optimization disabled
    // for ALL following ops in this block. Re-assigning the flag on every
    // iteration would reset it at the first non-target op, so only the op
    // directly after the conditional op would be protected.
    if (!disable_opt) {
      disable_opt = IsPrepareDataOptTargetOp(op);
    }
  }
}

bool AnalysisPredictor::PrepareExecutor() {
DisablePrepareDataOpt(inference_program_, 0, false);

executor_->Prepare(sub_scope_, *inference_program_, 0,
config_.use_feed_fetch_ops_);

Expand Down Expand Up @@ -1197,6 +1236,7 @@ USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(nearest_interp);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reshape);
#endif

namespace paddle_infer {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ nv_library(tensorrt_converter
nearest_interp_op.cc
reduce_op.cc
gather_nd_op.cc
reshape_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

nv_test(test_op_converter SRCS test_op_converter.cc DEPS
Expand Down
19 changes: 12 additions & 7 deletions paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,

TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data), bias_size};
auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input,
nv_ksize, weight, bias);
PADDLE_ENFORCE_NOT_NULL(layer,
platform::errors::Fatal("TensorRT create conv2d"
" layer error."));
// In conv2d_transpose and depthwise_conv2d_transpose,
// output channels = filter_dims[1] * groups
auto* layer = (op_desc.Type() == "conv2d_transpose" ||
op_desc.Type() == "depthwise_conv2d_transpose")
? fadd_layer(const_cast<nvinfer1::ITensor*>(X),
n_input * groups, nv_ksize, weight, bias)
: fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output,
nv_ksize, weight, bias);

PADDLE_ENFORCE_NOT_NULL(
layer, platform::errors::Fatal("TensorRT create conv2d/conv2d_transpose"
" layer failed."));
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setNbGroups(groups);
Expand All @@ -134,7 +141,6 @@ class Conv2dOpConverter : public OpConverter {
ConvertConv2d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
int n_input, /* Conv input maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer =
Expand All @@ -156,7 +162,6 @@ class Deconv2dOpConverter : public OpConverter {
ConvertConv2d(
engine_, op, scope, test_mode,
[&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
int n_input, /* Deconv output maps */
nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
auto* layer =
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class OpConverter {
it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
// reshape2 == reshape
if (op_desc.Type() == "reshape2") {
it = Registry<OpConverter>::Global().Lookup("reshape");
PADDLE_ENFORCE_NOT_NULL(
it, platform::errors::Unimplemented("no OpConverter for optype [%s]",
op_desc.Type()));
}
if (!it) {
it = Registry<OpConverter>::Global().Lookup(op_desc.Type());
}
Expand Down
63 changes: 63 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/reshape_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * ReshapeOp converter.
 *
 * Adds a TensorRT Shuffle layer for the op's "X" input and sets the layer's
 * reshape dimensions from the op's "shape" attribute.
 */
class ReshapeOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    // Look up the input tensor and the requested target shape.
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    const std::vector<int>& shape =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("shape"));
    const int rank = shape.size();
    // TRT dynamic-shape mode copies every entry of `shape`; static-shape mode
    // drops the leading entry (presumably the batch dimension -- TODO confirm).
    const int skipped = engine_->with_dynamic_shape() ? 0 : 1;
    nvinfer1::Dims reshape_dim;
    reshape_dim.nbDims = rank - skipped;
    for (int i = skipped; i < rank; ++i) {
      reshape_dim.d[i - skipped] = shape[i];
    }
    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
    layer->setReshapeDimensions(reshape_dim);
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "reshape", {output_name}, test_mode);
  }
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(reshape, ReshapeOpConverter);
17 changes: 17 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
#if IS_TRT_VERSION_GE(7130)
teller_set.insert("group_norm");
#endif
#if CUDA_VERSION >= 10200
teller_set.insert("reshape");
teller_set.insert("reshape2");
#endif
}

Expand Down Expand Up @@ -695,6 +699,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}

if (op_type == "reshape" || op_type == "reshape2") {
if (!desc.HasAttr("shape") || with_dynamic_shape) {
return false;
// Paddle-TRT does not support the input tensors: Shape and ShapeTensor
} else if (desc.Input("Shape").size() >= 1 ||
desc.Input("ShapeTensor").size() >= 1) {
return false;
} else {
std::vector<int> shape =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("shape"));
if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false;
}
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
return false;
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/controlflow/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,18 @@ class CompareOp : public framework::OperatorWithKernel {

REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor);
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor);
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor);
paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor);
paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor);
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/controlflow/compare_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"

REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor,
paddle::operators::LessEqualFunctor);
paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor,
paddle::operators::LessThanFunctor);
paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/strided_slice_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,

REGISTER_OP_CPU_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
Expand All @@ -335,6 +336,7 @@ REGISTER_OP_CPU_KERNEL(

REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/operators/strided_slice_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
Expand All @@ -30,7 +31,8 @@ REGISTER_OP_CUDA_KERNEL(

REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11075,7 +11075,7 @@ def strided_slice(input, axes, starts, ends, strides):
Then:
result = [ [2], ]
Args:
input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``bool``, ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to.
It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`.
starts (list|tuple|Variable): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of
Expand Down Expand Up @@ -11126,7 +11126,7 @@ def strided_slice(input, axes, starts, ends, strides):
helper = LayerHelper('strided_slice', **locals())

check_variable_and_dtype(input, 'input',
['float32', 'float64', 'int32', 'int64'],
['bool', 'float32', 'float64', 'int32', 'int64'],
'strided_slice')
check_type(axes, 'axes', (list, tuple), 'strided_slice')
check_type(starts, 'starts', (list, tuple, Variable), 'strided_slice')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ foreach(TEST_INFERENCE_IR_PASS ${TEST_TRT_IR_PASSES})
endforeach()

if(WITH_GPU AND TENSORRT_FOUND)
list(REMOVE_ITEM TEST_TRT_IR_PASSES test_trt_multiclass_nms_op)
foreach(target ${TEST_TRT_IR_PASSES})
py_test_modules(${target} MODULES ${target})
endforeach()
Expand All @@ -32,6 +33,6 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_trt_subgraph_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_activation_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
#set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
endif()
Loading

1 comment on commit 7d0025d

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.