From b0b14c52b31d827e431bde0a545a418852fe77bb Mon Sep 17 00:00:00 2001 From: Maxim Berman Date: Tue, 25 Jun 2019 16:11:13 -0700 Subject: [PATCH] [MXNET-1086] added sub and mul to ONNX->TensorRT conversion (#15344) * added sub and mul to ONNX->TensorRT conversion * add test for elementwise ops in TRT --- CMakeLists.txt | 2 +- .../subgraph/tensorrt/nnvm_to_onnx-inl.h | 12 ++++ .../subgraph/tensorrt/nnvm_to_onnx.cc | 12 ++++ tests/python/tensorrt/test_ops.py | 68 +++++++++++++++++++ 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 tests/python/tensorrt/test_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2142a09d6d2e..0148ac302d54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,7 +47,7 @@ mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON) mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF) mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." ON) -mxnet_option(USE_TENSORRT "Enable infeference optimization with TensorRT." OFF) +mxnet_option(USE_TENSORRT "Enable inference optimization with TensorRT." OFF) mxnet_option(USE_ASAN "Enable Clang/GCC ASAN sanitizers." OFF) mxnet_option(ENABLE_TESTCOVERAGE "Enable compilation with test coverage metric output" OFF) mxnet_option(USE_INT64_TENSOR_SIZE "Use int64_t to represent the total number of elements in a tensor" OFF) diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h index 4a88aee886db..edf4d357e922 100644 --- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h +++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h @@ -126,6 +126,16 @@ void ConvertElementwiseAdd(NodeProto *node_proto, const nnvm::IndexedGraph &ig, const array_view &inputs); +void ConvertElementwiseSub(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertElementwiseMul(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + void ConvertConcatenate(NodeProto *node_proto, const NodeAttrs &attrs, const nnvm::IndexedGraph &ig, @@ -152,6 +162,8 @@ static const std::unordered_map converter_map = {"Concat", ConvertConcatenate}, {"Dropout", ConvertDropout}, {"elemwise_add", ConvertElementwiseAdd}, + {"elemwise_sub", ConvertElementwiseSub}, + {"elemwise_mul", ConvertElementwiseMul}, {"Flatten", ConvertFlatten}, {"FullyConnected", ConvertFullyConnected}, {"Pad", ConvertPad}, diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc index da89c2b476ee..9d98a48c2ec2 100644 --- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc +++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc @@ -393,6 +393,18 @@ void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/, node_proto->set_op_type("Add"); } +void ConvertElementwiseSub(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Sub"); +} + +void ConvertElementwiseMul(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Mul"); +} + void ConvertConcatenate(NodeProto* node_proto, const NodeAttrs& attrs, const nnvm::IndexedGraph& /*ig*/, const array_view& /*inputs*/) { diff --git a/tests/python/tensorrt/test_ops.py b/tests/python/tensorrt/test_ops.py new file mode 100644 index 000000000000..2df9104aa06c --- /dev/null +++ b/tests/python/tensorrt/test_ops.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from mxnet.test_utils import assert_almost_equal +import mxnet as mx +import numpy as np +import os + +def check_elementwise_random(op='sum', shape=(1, 3, 224, 224)): + """ + Check elementwise operators with vanilla/TensorRT executors with uniform random tensors + """ + a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + if op == 'sum': + sym = a + b + elif op == 'sub': + sym = a - b + elif op == 'mul': + sym = a * b + + a_data = mx.ndarray.random.uniform(shape=shape, ctx=mx.gpu()) + b_data = mx.ndarray.random.uniform(shape=shape, ctx=mx.gpu()) + + executor = sym.simple_bind(ctx=mx.gpu(), a=shape, b=shape, + grad_req='null', force_rebind=True) + y = executor.forward(is_train=False, a=a_data, b=b_data) + trt_sym = sym.get_backend_symbol('TensorRT') + original_precision_value = mx.contrib.tensorrt.get_use_fp16() + try: + mx.contrib.tensorrt.set_use_fp16(True) + executor = trt_sym.simple_bind(ctx=mx.gpu(), a=shape, b=shape, + grad_req='null', force_rebind=True) + y_trt = executor.forward(is_train=False, a=a_data, b=b_data) + mx.contrib.tensorrt.set_use_fp16(False) + executor = trt_sym.simple_bind(ctx=mx.gpu(), a=shape, b=shape, + grad_req='null', force_rebind=True) + y_trt_fp32 = executor.forward(is_train=False, a=a_data, b=b_data) + assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-1, 1e-2) + assert_almost_equal(y[0].asnumpy(), y_trt_fp32[0].asnumpy(), 1e-4, 1e-4) + finally: + mx.contrib.tensorrt.set_use_fp16(original_precision_value) + + +def test_elementwise(): + for op in ['sum', 'sub', 'mul']: + for shape in [(20, 25), (3, 4, 20), (1, 3, 20, 25), (10, 10, 100, 100)]: + for itry in range(10): + check_elementwise_random(op, shape) + + +if __name__ == '__main__': + import nose + nose.runmodule()