From 86a9ba77f02c1adc78ab0bf861a10bb5bfb831eb Mon Sep 17 00:00:00 2001
From: Trevor Morris <trevoraidanmorris@gmail.com>
Date: Tue, 4 Feb 2020 15:48:01 -0800
Subject: [PATCH] [Relay/TRT] Support clip for TRT 4 using relu + eltwise (#83)

* Support clip for TRT 4 using relu + eltwise

* Re-enable consistency check

* Invoke ConvertLayout properly
---
 python/tvm/relay/tensorrt.py                  |  7 ++-
 .../contrib/tensorrt/enable_tensorrt.cc       |  2 +-
 .../contrib/tensorrt/tensorrt_builder.cc      |  2 +
 src/runtime/contrib/tensorrt/tensorrt_ops.h   | 56 +++++++++++++++++++
 tests/python/relay/test_tensorrt.py           |  4 +-
 5 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/tensorrt.py b/python/tvm/relay/tensorrt.py
index 7ceea2d0c18c4..449fa9e7c31b4 100644
--- a/python/tvm/relay/tensorrt.py
+++ b/python/tvm/relay/tensorrt.py
@@ -182,9 +182,10 @@ def EnableTrt(mod, params=None, trt_version=None):
     assert len(trt_version) == 3
 
     # Apply passes required for TRT
-    mod = relay.transform.RemoveUnusedFunctions()(mod)
-    mod = relay.transform.InferType()(mod)
-    mod = relay.transform.ConvertLayout('NCHW')(mod)
+    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+                                      relay.transform.ConvertLayout('NCHW')])
+    with relay.transform.PassContext(opt_level=3):
+        mod = seq(mod)
     mod = PreprocessForTrt(mod)
     if params:
         # Bind params so that we can use FoldConstant.
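
Note on the pass-pipeline change above: ConvertLayout is registered at
opt_level 3 (with InferType among its declared prerequisites), so running it
through a Sequential lets the pass infrastructure resolve those prerequisites,
and it requires an explicit PassContext(opt_level=3); under the default
context a Sequential would skip it. A minimal sketch of the new invocation
pattern (the module, shapes, and ops here are illustrative, not from this
patch):

    import tvm
    from tvm import relay

    # Build a small NHWC conv so ConvertLayout has something to rewrite.
    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 32))
    y = relay.nn.conv2d(x, w, channels=32, kernel_size=(3, 3),
                        data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                      relay.transform.ConvertLayout('NCHW')])
    with relay.transform.PassContext(opt_level=3):
        mod = seq(mod)
    print(mod)  # the conv2d now carries data_layout="NCHW"
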
diff --git a/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc b/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc
index c2f22f7db1f20..5f36da3720796 100644
--- a/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc
+++ b/src/relay/backend/contrib/tensorrt/enable_tensorrt.cc
@@ -392,8 +392,8 @@ static const std::unordered_map<std::string, IsCompatibleFn>
         {"mean", ReduceOpChecker},
         {"contrib.adaptive_max_pool2d", AdapativePool2DOpChecker},
         {"contrib.adaptive_avg_pool2d", AdapativePool2DOpChecker},
+        {"clip", AlwaysChecker},
         // Ops which require TRT 5.1.5+
-        {"clip", TrtVersionChecker<5, 1, 5>},
         {"nn.leaky_relu", TrtVersionChecker<5, 1, 5>},
         {"sin", TrtVersionChecker<5, 1, 5>},
         {"cos", TrtVersionChecker<5, 1, 5>},
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 169783f02b1e5..37bc4c807bd4d 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -92,6 +92,8 @@ GetOpConverters() {
   map->emplace("ceil", std::make_shared<UnaryOpConverter>());
   map->emplace("floor", std::make_shared<UnaryOpConverter>());
   map->emplace("strided_slice", std::make_shared<StridedSliceOpConverter>());
+#else
+  map->emplace("clip", std::make_shared<ClipLegacyOpConverter>());
 #endif
 #if TRT_VERSION_GE(6, 0, 1)
   map->emplace("image.resize", std::make_shared<ResizeOpConverter>());
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h
index 4aee6b9c74272..c52657f755c44 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h
@@ -147,6 +147,28 @@ class TrtOpConverter {
     // Subtract 1 for implicit batch dim.
     return axis - 1;
   }
+
+  /*!
+   * \brief Create a scalar constant that can be broadcast against an input.
+   * \param params Parameters for this op.
+   * \param value Value of the scalar.
+   * \param broadcast_to_dims Dims that the scalar should be broadcastable
+   * against. The constant is created with the same rank, all dims of size 1.
+   * \return Constant tensor.
+   */
+  nvinfer1::ITensor* CreateScalar(
+      AddTrtLayerParams* params, float value,
+      const nvinfer1::Dims& broadcast_to_dims) const {
+    nvinfer1::Dims dims;
+    dims.nbDims = broadcast_to_dims.nbDims;
+    std::fill_n(dims.d, dims.nbDims, 1);
+    float* values = new float[1];  // Must stay valid until engine build.
+    values[0] = value;
+    nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT,
+                              static_cast<void*>(values), 1};
+    params->trt_weights->push_back(weights);
+    return params->network->addConstant(dims, weights)->getOutput(0);
+  }
 };
 
 class ActivationOpConverter : public TrtOpConverter {
@@ -185,6 +207,40 @@ class ActivationOpConverter : public TrtOpConverter {
   }
 };
 
+class ClipLegacyOpConverter : public TrtOpConverter {
+ public:
+  ClipLegacyOpConverter() : TrtOpConverter({kTensor}) {}
+
+  void Convert(AddTrtLayerParams* params) const {
+    const auto* attrs = params->call->attrs.as<ClipAttrs>();
+    CHECK_EQ(params->inputs.size(), 1) << "Clip op expects 1 input.";
+    auto input = params->inputs.at(0).tensor;
+    // Lower bound: max(x, a_min).
+    nvinfer1::ITensor* output = nullptr;
+    if (attrs->a_min == 0.0f) {
+      // When a_min == 0, use relu instead of max(x, 0) so TRT can fuse it.
+      nvinfer1::IActivationLayer* relu_layer = params->network->addActivation(
+          *input, nvinfer1::ActivationType::kRELU);
+      CHECK(relu_layer != nullptr);
+      output = relu_layer->getOutput(0);
+    } else {
+      // max(x, a_min)
+      nvinfer1::ITensor* a_min =
+          CreateScalar(params, attrs->a_min, input->getDimensions());
+      nvinfer1::IElementWiseLayer* max_layer = params->network->addElementWise(
+          *input, *a_min, nvinfer1::ElementWiseOperation::kMAX);
+      CHECK(max_layer != nullptr);
+      output = max_layer->getOutput(0);
+    }
+    // Upper bound: min(max(x, a_min), a_max).
+    nvinfer1::ITensor* a_max =
+        CreateScalar(params, attrs->a_max, input->getDimensions());
+    nvinfer1::IElementWiseLayer* min_layer = params->network->addElementWise(
+        *output, *a_max, nvinfer1::ElementWiseOperation::kMIN);
+    params->outputs.push_back(min_layer->getOutput(0));
+  }
+};
+
 class ElementWiseBinaryOpConverter : public TrtOpConverter {
  public:
   ElementWiseBinaryOpConverter() : TrtOpConverter({kTensor, kTensor}) {}
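
To see that ClipLegacyOpConverter matches relay clip semantics, note that it
builds min(max(x, a_min), a_max) out of TRT 4 elementwise layers, with a relu
shortcut for the common a_min == 0 case (per the comment in the converter,
relu can be fused by TRT, which is why it is preferred over an explicit max).
A NumPy sketch of the same decomposition, useful for checking the equivalence
(illustrative, not part of the patch):

    import numpy as np

    def clip_via_relu_eltwise(x, a_min, a_max):
        if a_min == 0.0:
            y = np.maximum(x, 0.0)   # relu(x), the fusable special case
        else:
            y = np.maximum(x, a_min) # eltwise max against a broadcast scalar
        return np.minimum(y, a_max)  # eltwise min against a broadcast scalar

    x = np.array([-2.0, -0.5, 0.5, 7.0], dtype="float32")
    assert np.allclose(clip_via_relu_eltwise(x, 0.0, 6.0), np.clip(x, 0.0, 6.0))
    assert np.allclose(clip_via_relu_eltwise(x, -1.0, 6.0), np.clip(x, -1.0, 6.0))
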
diff --git a/tests/python/relay/test_tensorrt.py b/tests/python/relay/test_tensorrt.py
index 1b640a112d97f..cf1cffd1be131 100644
--- a/tests/python/relay/test_tensorrt.py
+++ b/tests/python/relay/test_tensorrt.py
@@ -494,8 +494,8 @@ def check_trt_used(graph):
     i_data = np.random.uniform(0, 1, input_shape).astype(dtype)
     for model in models:
         latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True)
-        # _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
-        # tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)
+        _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)
     
     for model in models:
         print(model, latency[model])
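
The re-enabled check above compares the TRT-offloaded output against a plain
TVM run of the same model; the loose rtol/atol of 1e-3 leaves headroom for
TensorRT selecting different (but still valid) fp32 kernels and fusions than
TVM does.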