From 86a9ba77f02c1adc78ab0bf861a10bb5bfb831eb Mon Sep 17 00:00:00 2001 From: Trevor Morris <trevoraidanmorris@gmail.com> Date: Tue, 4 Feb 2020 15:48:01 -0800 Subject: [PATCH] [Relay/TRT] Support clip for TRT 4 using relu + eltwise (#83) * Support clip for TRT 4 using relu + eltwise * Re-enable consistency check * Invoke convertlayout properly --- python/tvm/relay/tensorrt.py | 7 ++- .../contrib/tensorrt/enable_tensorrt.cc | 2 +- .../contrib/tensorrt/tensorrt_builder.cc | 2 + src/runtime/contrib/tensorrt/tensorrt_ops.h | 56 +++++++++++++++++++ tests/python/relay/test_tensorrt.py | 4 +- 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/tensorrt.py b/python/tvm/relay/tensorrt.py index 7ceea2d0c18c4..449fa9e7c31b4 100644 --- a/python/tvm/relay/tensorrt.py +++ b/python/tvm/relay/tensorrt.py @@ -182,9 +182,10 @@ def EnableTrt(mod, params=None, trt_version=None): assert len(trt_version) == 3 # Apply passes required for TRT - mod = relay.transform.RemoveUnusedFunctions()(mod) - mod = relay.transform.InferType()(mod) - mod = relay.transform.ConvertLayout('NCHW')(mod) + seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout('NCHW')]) + with relay.transform.PassContext(opt_level=3): + mod = seq(mod) mod = PreprocessForTrt(mod) if params: # Bind params so that we can use FoldConstant. diff --git a/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc b/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc index c2f22f7db1f20..5f36da3720796 100644 --- a/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc +++ b/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc @@ -392,8 +392,8 @@ static const std::unordered_map<std::string, IsCompatibleFn> {"mean", ReduceOpChecker}, {"contrib.adaptive_max_pool2d", AdapativePool2DOpChecker}, {"contrib.adaptive_avg_pool2d", AdapativePool2DOpChecker}, + {"clip", AlwaysChecker}, // Ops which require TRT 5.1.5+ - {"clip", TrtVersionChecker<5, 1, 5>}, {"nn.leaky_relu", TrtVersionChecker<5, 1, 5>}, {"sin", TrtVersionChecker<5, 1, 5>}, {"cos", TrtVersionChecker<5, 1, 5>}, diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 169783f02b1e5..37bc4c807bd4d 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -92,6 +92,8 @@ GetOpConverters() { map->emplace("ceil", std::make_shared<UnaryOpConverter>()); map->emplace("floor", std::make_shared<UnaryOpConverter>()); map->emplace("strided_slice", std::make_shared<StridedSliceOpConverter>()); +#else + map->emplace("clip", std::make_shared<ClipLegacyOpConverter>()); #endif #if TRT_VERSION_GE(6, 0, 1) map->emplace("image.resize", std::make_shared<ResizeOpConverter>()); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h index 4aee6b9c74272..c52657f755c44 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -147,6 +147,28 @@ class TrtOpConverter { // Subtract 1 for implicit batch dim. return axis - 1; } + + // Create constant that is broadcastable against input. + /*! + * \brief Create constant that is broadcastable. + * \param params Parameters for this op. + * \param value Value of scalar. + * \param broadcast_to_dims Dims that scalar should be broadcastable against. + * \return Constant tensor. + */ + nvinfer1::ITensor* CreateScalar( + AddTrtLayerParams* params, float value, + const nvinfer1::Dims& broadcast_to_dims) const { + nvinfer1::Dims dims; + dims.nbDims = broadcast_to_dims.nbDims; + std::fill_n(dims.d, dims.nbDims, 1); + float* values = new float[1]; + values[0] = value; + nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, + static_cast<void*>(values), 1}; + params->trt_weights->push_back(weights); + return params->network->addConstant(dims, weights)->getOutput(0); + } }; class ActivationOpConverter : public TrtOpConverter { @@ -185,6 +207,40 @@ class ActivationOpConverter : public TrtOpConverter { } }; +class ClipLegacyOpConverter : public TrtOpConverter { + public: + ClipLegacyOpConverter() : TrtOpConverter({kTensor}) {} + + void Convert(AddTrtLayerParams* params) const { + const auto* attrs = params->call->attrs.as<ClipAttrs>(); + CHECK_EQ(params->inputs.size(), 1) << "Activation op expects 1 input."; + auto input = params->inputs.at(0).tensor; + // relu(x) + nvinfer1::ITensor* output = nullptr; + if (attrs->a_min == 0.0f) { + // Use relu instead of max(x, 0) because relu can be fused. + nvinfer1::IActivationLayer* relu_layer = params->network->addActivation( + *input, nvinfer1::ActivationType::kRELU); + CHECK(relu_layer != nullptr); + output = relu_layer->getOutput(0); + } else { + // max(x, a_min) + nvinfer1::ITensor* a_min = + CreateScalar(params, attrs->a_min, input->getDimensions()); + nvinfer1::IElementWiseLayer* max_layer = params->network->addElementWise( + *input, *a_min, nvinfer1::ElementWiseOperation::kMAX); + CHECK(max_layer != nullptr); + output = max_layer->getOutput(0); + } + // min(relu(x), a_max) + nvinfer1::ITensor* a_max = + CreateScalar(params, attrs->a_max, input->getDimensions()); + nvinfer1::IElementWiseLayer* min_layer = params->network->addElementWise( + *output, *a_max, nvinfer1::ElementWiseOperation::kMIN); + params->outputs.push_back(min_layer->getOutput(0)); + } +}; + class ElementWiseBinaryOpConverter : public TrtOpConverter { public: ElementWiseBinaryOpConverter() : TrtOpConverter({kTensor, kTensor}) {} diff --git a/tests/python/relay/test_tensorrt.py b/tests/python/relay/test_tensorrt.py index 1b640a112d97f..cf1cffd1be131 100644 --- a/tests/python/relay/test_tensorrt.py +++ b/tests/python/relay/test_tensorrt.py @@ -494,8 +494,8 @@ def check_trt_used(graph): i_data = np.random.uniform(0, 1, input_shape).astype(dtype) for model in models: latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True) - # _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1) - # tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3) + _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1) + tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3) for model in models: print(model, latency[model])