add shape, sum trt layer
zhoutianzi666 committed Jul 19, 2022
1 parent 50c0257 commit a7c2e98
Showing 7 changed files with 447 additions and 14 deletions.
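With the two new converters registered, Paddle-TRT can fold `sum` and `shape` ops into the TensorRT subgraph when the predictor is built. A minimal sketch of a deployment that would exercise them (the model path, input shape, and tuning values are hypothetical, not part of this commit):

    import numpy as np
    import paddle.inference as paddle_infer

    # Hypothetical saved inference model containing sum/shape ops.
    config = paddle_infer.Config("model/inference.pdmodel",
                                 "model/inference.pdiparams")
    config.enable_use_gpu(256, 0)  # 256 MB initial pool on GPU 0
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=4,
        min_subgraph_size=1,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False)

    predictor = paddle_infer.create_predictor(config)
    name = predictor.get_input_names()[0]
    handle = predictor.get_input_handle(name)
    handle.copy_from_cpu(np.ones([1, 3, 24, 24], dtype=np.float32))
    predictor.run()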
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -2089,6 +2089,8 @@ USE_TRT_CONVERTER(top_k)
USE_TRT_CONVERTER(top_k_v2)
USE_TRT_CONVERTER(squeeze2)
USE_TRT_CONVERTER(unsqueeze2)
USE_TRT_CONVERTER(sum)
USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -69,6 +69,8 @@ list(
top_k_op.cc
squeeze2_op.cc
unsqueeze2_op.cc
sum_op.cc
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc)

3 changes: 2 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/shape_op.cc
@@ -21,7 +21,8 @@ namespace tensorrt {
class ShapeOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid shape op to tensorrt shape layer";

    framework::OpDesc op_desc(op, nullptr);
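The shape_op.cc hunk only reflows the converter signature, but for context: the op it converts returns a 1-D integer tensor holding the input's dimensions, which is what TRT's IShapeLayer produces. An eager-mode illustration (not part of the commit):

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.ones([1, 3, 24, 24], dtype="float32"))
    s = paddle.shape(x)  # the `shape` op: a 1-D int32 tensor
    print(s.numpy())     # [1, 3, 24, 24]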
25 changes: 14 additions & 11 deletions paddle/fluid/inference/tensorrt/convert/sum_op.cc
@@ -21,23 +21,26 @@ namespace tensorrt {
class SumOpConverter : public OpConverter {
public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid sum op to tensorrt sum layer";

    framework::OpDesc op_desc(op, nullptr);
    nvinfer1::ILayer* layer = nullptr;
    // Declare the first input
    auto* sum_tmp = engine_->GetITensor(op_desc.Input("X")[0]);
    if (op_desc.Input("X").size() == 1) {
      layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *sum_tmp);
    } else {
      for (size_t i = 1; i < op_desc.Input("X").size(); i++) {
        auto* input_i = engine_->GetITensor(op_desc.Input("X")[i]);
        layer = TRT_ENGINE_ADD_LAYER(engine_,
                                     ElementWise,
                                     *input_i,
                                     *sum_tmp,
                                     nvinfer1::ElementWiseOperation::kSUM);
        sum_tmp = layer->getOutput(0);
      }
    }
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "sum", {output_name}, test_mode);
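The converter lowers an N-input `sum` into an `Identity` layer (N == 1) or a chain of N-1 elementwise kSUM additions. A quick eager-mode sanity check of that equivalence (not part of the commit; plain Paddle, no TRT involved):

    import numpy as np
    import paddle

    # Three same-shape inputs, as the fluid `sum` op expects.
    xs = [paddle.to_tensor(np.random.rand(2, 3).astype("float32"))
          for _ in range(3)]

    out = paddle.add_n(xs)         # the `sum` op in eager mode
    ref = (xs[0] + xs[1]) + xs[2]  # chained additions, in converter order
    np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-6)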
13 changes: 11 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -170,6 +170,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2"};
std::unordered_set<std::string> teller_set{
@@ -276,6 +278,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"recover_padding",
"remove_padding",
"fill_constant",
"sum",
"shape",
"squeeze2",
"unsqueeze2",
"fused_token_prune"};
@@ -1208,6 +1212,11 @@ bool OpTeller::Tell(const framework::ir::Node* node,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
auto dtype = x_var_desc->GetDataType();
// At present, only float32 (proto::VarType FP32 = 5) or
// float16 (FP16 = 4) inputs are supported by TRT.
if (!(dtype == 5 || dtype == 4)) {
  return false;
}
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt";
return false;
@@ -1361,9 +1370,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
}

// Note: 1-D input in static-shape mode has already been filtered out above.
if (op_type == "sum") {
  return true;
}

if (op_type == "shape" && !with_dynamic_shape) {
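The hunk is truncated here; the `shape` branch rejects conversion when dynamic shape is off (the new test below accordingly expects zero TRT nodes in static-shape mode). To let `shape` enter TRT, a deployment would register dynamic shape ranges, continuing the Config sketch above (the shape values mirror the 4-D case in the test; "input1" is that test's input name):

    # Sketch: dynamic-shape info is required before `shape` is offloaded.
    config.set_trt_dynamic_shape_info(
        {"input1": [1, 3, 24, 24]},  # min_input_shape
        {"input1": [4, 3, 48, 48]},  # max_input_shape
        {"input1": [1, 3, 24, 24]})  # optim_input_shape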
@@ -0,0 +1,120 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest


class TrtConvertShapeTest(TrtLayerAutoScanTest):

def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True

def sample_program_configs(self):

def generate_input1(batch):
if self.dims == 4:
return np.ones([batch, 3, 24, 24]).astype(np.float32)
elif self.dims == 3:
return np.ones([batch, 3, 24]).astype(np.float32)
elif self.dims == 2:
return np.ones([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for batch in [1, 4]:
self.dims = dims
ops_config = [{
"op_type": "shape",
"op_inputs": {
"Input": ["input1"]
},
"op_outputs": {
"Out": ["output"]
},
"op_attrs": {}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input1":
TensorConfig(data_gen=partial(generate_input1, batch))
},
outputs=["output"])

yield program_config

def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):

def generate_dynamic_shape():
if self.dims == 4:
self.dynamic_shape.min_input_shape = {"input1": [1, 3, 24, 24]}
self.dynamic_shape.max_input_shape = {"input1": [4, 3, 48, 48]}
self.dynamic_shape.opt_input_shape = {"input1": [1, 3, 24, 24]}
elif self.dims == 3:
self.dynamic_shape.min_input_shape = {"input1": [1, 3, 24]}
self.dynamic_shape.max_input_shape = {"input1": [4, 3, 48]}
self.dynamic_shape.opt_input_shape = {"input1": [1, 3, 24]}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {"input1": [1, 24]}
self.dynamic_shape.max_input_shape = {"input1": [4, 48]}
self.dynamic_shape.opt_input_shape = {"input1": [1, 24]}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input1": [24]}
self.dynamic_shape.max_input_shape = {"input1": [48]}
self.dynamic_shape.opt_input_shape = {
"input1": [24],
}

        def generate_trt_nodes_num(dynamic_shape):
            if not dynamic_shape:
                return 0, 3
            return 1, 2

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
False), 1e-5

# for dynamic_shape
generate_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5

def test(self):
self.run_test()


if __name__ == "__main__":
unittest.main()
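For reading the expectations above: the harness compares pairs of (number of fused tensorrt_engine ops, number of ops left to Paddle). A comment-only summary of the two cases (inferred from the node counts, not stated in the commit):

    # static shape : generate_trt_nodes_num(False) == (0, 3)
    #   -> `shape` is not converted; the graph stays feed + shape + fetch.
    # dynamic shape: generate_trt_nodes_num(True) == (1, 2)
    #   -> one tensorrt_engine op, plus feed + fetch.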