Skip to content

Commit

Permalink
Update TensorRT tests and add to CI
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Aug 28, 2019
1 parent 8190d9b commit 782907d
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 15 deletions.
42 changes: 42 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
//
ci_lint = "tvmai/ci-lint:v0.51"
ci_gpu = "tvmai/ci-gpu:v0.52"
ci_gpu_trt = "trevoram/tvm:latest"
ci_cpu = "tvmai/ci-cpu:v0.50"
ci_i386 = "tvmai/ci-i386:v0.50"

Expand Down Expand Up @@ -167,6 +168,36 @@ stage('Build') {
}
}
},
'BUILD: GPUTRT': {
node('GPUTRTBUILD') {
ws('workspace/tvm/build-gpu-trt') {
init_git()
sh """
mkdir -p build
cd build
cp ../cmake/config.cmake .
echo set\\(USE_CUBLAS ON\\) >> config.cmake
echo set\\(USE_CUDNN ON\\) >> config.cmake
echo set\\(USE_CUDA ON\\) >> config.cmake
echo set\\(USE_TENSORRT /usr/include/x86_64-linux-gnu/\\) >> config.cmake
echo set\\(USE_OPENGL ON\\) >> config.cmake
echo set\\(USE_MICRO ON\\) >> config.cmake
echo set\\(USE_LLVM llvm-config-6.0\\) >> config.cmake
echo set\\(USE_NNPACK ON\\) >> config.cmake
echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake
echo set\\(USE_RPC ON\\) >> config.cmake
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
"""
make(ci_gpu, 'build', '-j4')
pack_lib('gpu-trt', tvm_multilib)
}
},
'BUILD: CPU': {
node('CPU') {
ws('workspace/tvm/build-cpu') {
Expand Down Expand Up @@ -293,6 +324,17 @@ stage('Integration Test') {
}
}
},
'tensorrt: GPUTRT': {
node('GPUTRT') {
ws('workspace/tvm/tensorrt-python-gpu') {
init_git()
unpack_lib('gpu-trt', tvm_multilib)
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${ci_gpu_trt} ./tests/scripts/task_python_tensorrt.sh"
}
}
}
},
'docs: GPU': {
node('GPU') {
ws('workspace/tvm/docs-python-gpu') {
Expand Down
15 changes: 10 additions & 5 deletions cmake/modules/contrib/TensorRT.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@

# TensorRT Module

if(IS_DIRECTORY ${USE_TENSORRT})
set(TENSORRT_ROOT_DIR ${USE_TENSORRT})
message(STATUS "Custom TensorRT path: " ${TENSORRT_ROOT_DIR})
set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT_DIR}/include)
set(TENSORRT_LIB_DIR ${TENSORRT_ROOT_DIR}/lib)
if(USE_TENSORRT)
if(IS_DIRECTORY ${USE_TENSORRT})
set(TENSORRT_ROOT_DIR ${USE_TENSORRT})
endif()
find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES include)
find_library(TENSORRT_LIB_DIR nvinfer HINTS ${TENSORRT_ROOT_DIR} PATH_SUFFIXES lib)
find_package_handle_standard_args(TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIB_DIR)
if(NOT TENSORRT_FOUND)
message(ERROR "Could not find TensorRT.")
endif()
file(GLOB TENSORRT_SRCS src/contrib/subgraph/*.cc)
include_directories(${TENSORRT_INCLUDE_DIR})
list(APPEND RUNTIME_SRCS ${TENSORRT_SRCS})
Expand Down
119 changes: 119 additions & 0 deletions docker/Dockerfile.ci_gpu_trt
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# CI docker GPU TRT env
# tag: v0.50
FROM nvcr.io/nvidia/tensorrt:19.07-py3

# Base scripts
RUN apt-get update --fix-missing

COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh

COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh
RUN bash /install/ubuntu_install_python.sh

COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh

COPY install/ubuntu_install_opencl.sh /install/ubuntu_install_opencl.sh
RUN bash /install/ubuntu_install_opencl.sh

COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh

COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
RUN bash /install/ubuntu_install_sphinx.sh

# Fix recommonmark to latest version
RUN git clone --depth=1 https://github.com/rtfd/recommonmark
RUN cd recommonmark; python3 setup.py install

# Enable doxygen for c++ doc build
RUN apt-get update && apt-get install -y doxygen graphviz libprotobuf-dev protobuf-compiler

COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh

COPY install/ubuntu_install_nodejs.sh /install/ubuntu_install_nodejs.sh
RUN bash /install/ubuntu_install_nodejs.sh

COPY install/ubuntu_install_rocm.sh /install/ubuntu_install_rocm.sh
RUN bash /install/ubuntu_install_rocm.sh

COPY install/ubuntu_install_opengl.sh /install/ubuntu_install_opengl.sh
RUN bash /install/ubuntu_install_opengl.sh

# DL Frameworks
COPY install/ubuntu_install_mxnet.sh /install/ubuntu_install_mxnet.sh
RUN bash /install/ubuntu_install_mxnet.sh

COPY install/ubuntu_install_gluoncv.sh /install/ubuntu_install_gluoncv.sh
RUN bash /install/ubuntu_install_gluoncv.sh

COPY install/ubuntu_install_coreml.sh /install/ubuntu_install_coreml.sh
RUN bash /install/ubuntu_install_coreml.sh

COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh
RUN bash /install/ubuntu_install_tensorflow.sh

COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh
RUN bash /install/ubuntu_install_keras.sh

COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
RUN bash /install/ubuntu_install_darknet.sh

COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh
RUN bash /install/ubuntu_install_onnx.sh

COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh
RUN bash /install/ubuntu_install_tflite.sh

COPY install/ubuntu_install_caffe2.sh /install/ubuntu_install_caffe2.sh
RUN bash /install/ubuntu_install_caffe2.sh

RUN pip3 install Pillow

COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh
RUN bash /install/ubuntu_install_vulkan.sh

# AutoTVM deps
COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh
RUN bash /install/ubuntu_install_redis.sh

COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
RUN bash /install/ubuntu_install_antlr.sh

# NNPACK deps
COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh
RUN bash /install/ubuntu_install_nnpack.sh

# Environment variables
ENV PATH=/usr/local/nvidia/bin:${PATH}
ENV PATH=/usr/local/cuda/bin:${PATH}
ENV CPLUS_INCLUDE_PATH=/usr/local/cuda/include:${CPLUS_INCLUDE_PATH}
ENV C_INCLUDE_PATH=/usr/local/cuda/include:${C_INCLUDE_PATH}
ENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LIBRARY_PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}

ENV LD_LIBRARY_PATH=/opt/rocm/lib:${LD_LIBRARY_PATH}
ENV PATH=/node_modules/.bin:${PATH}
ENV VULKAN_SDK=/usr/local/VulkanSDK/1.0.65.0/x86_64
ENV PATH=${PATH}:${VULKAN_SDK}/bin
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${VULKAN_SDK}/lib
ENV VK_LAYER_PATH=${VULKAN_SDK}/etc/explicit_layer.d
15 changes: 13 additions & 2 deletions tests/python/tensorrt/test_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@
import nnvm
import tvm
from tvm.contrib import graph_runtime
import json


def test_avg_pool2d():

# Generate the data
np.random.seed(0)
input_shape = [1, 1, 28, 28]
output_shape = [1, 10]
output_shape = [1, 1, 28, 28]
data = np.random.random(input_shape).astype('float32')

# Baseline model in MXNet
net = gluon.nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.AvgPool2D(pool_size=3, strides=1, padding=1))
net.add(gluon.nn.Dense(10))
net.collect_params().initialize(mx.init.Xavier(), ctx=mx.cpu())
net.hybridize()
baseline_input = mx.nd.array(data, ctx=mx.cpu())
Expand All @@ -48,6 +48,17 @@ def test_avg_pool2d():
graph, lib, params = nnvm.compiler.build(sym, target,
shape={'data': input_shape},
params=params)

# Verify that TRT subgraphs are partitioned
def check_trt_used(graph):
graph = json.loads(graph.json())
num_trt_subgraphs = sum([1 for n in graph['nodes'] if n['op'] == '_tensorrt_subgraph_op'])
assert num_trt_subgraphs == 1
check_trt_used(graph)

# Execute
if not tvm.module.enabled("gpu"):
return
compiled_model = graph_runtime.create(graph, lib, tvm.gpu())
compiled_input = tvm.nd.array(data, ctx=tvm.gpu())
compiled_model.set_input('data', compiled_input)
Expand Down
9 changes: 9 additions & 0 deletions tests/python/tensorrt/test_cross_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tvm
from tvm.contrib import graph_runtime
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
import json

batch_size = 1

Expand Down Expand Up @@ -96,6 +97,14 @@ def get_data_shape(model_name):
with nnvm.compiler.build_config(opt_level=opt_level, ext_accel=ext_accel):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params, target_host=target_host)

# Verify that TRT subgraphs are partitioned
def check_trt_used(graph):
graph = json.loads(graph.json())
num_trt_subgraphs = sum([1 for n in graph['nodes'] if n['op'] == '_tensorrt_subgraph_op'])
assert num_trt_subgraphs >= 1
check_trt_used(graph)

print("===========Compiling model %s took %.3fs" % (network, time.time() - start))

print("===========Saving lowered graph for model %s" % network)
Expand Down
23 changes: 15 additions & 8 deletions tests/python/tensorrt/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
logging.basicConfig(level=logging.INFO)
import numpy as np
import json

import nnvm.compiler
import nnvm.testing
Expand All @@ -30,15 +31,11 @@


def test_tensorrt_image_classification_models():
def compile_model(graph, params, data_shapes, subgraph_backend=None, op_names=None, **kwargs):
def compile_model(graph, params, data_shapes, **kwargs):
_, output_shapes = nnvm.compiler.graph_util.infer_shape(graph, **data_shapes)
assert len(output_shapes) == 1
flags = kwargs
if subgraph_backend is not None and op_names is not None:
graph = nnvm.subgraph._partition(graph, subgraph_backend, op_names)
flags = {}
target = tvm.target.cuda()
with nnvm.compiler.build_config(opt_level=3, **flags):
with nnvm.compiler.build_config(opt_level=3, **kwargs):
graph, lib, params = nnvm.compiler.build(
graph, target, shape=data_shapes, params=params)
return graph, lib, params, output_shapes[0]
Expand All @@ -60,7 +57,16 @@ def copy_params(params):
def check_trt_model(baseline_module, baseline_params, graph, params, data_shape,
subgraph_backend=None, op_names=None, **kwargs):
trt_graph, trt_lib, trt_params, output_shape = compile_model(graph, params, {'data': data_shape},
subgraph_backend, op_names, **kwargs)
**kwargs)
# Verify that TRT subgraphs are partitioned
def check_trt_used(graph):
graph = json.loads(graph.json())
num_trt_subgraphs = sum([1 for n in graph['nodes'] if n['op'] == '_tensorrt_subgraph_op'])
assert num_trt_subgraphs >= 1
check_trt_used(trt_graph)

if not tvm.module.enabled("gpu"):
return
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
baseline_out = get_output(baseline_module, data, baseline_params, output_shape)
trt_module = graph_runtime.create(trt_graph, trt_lib, tvm.gpu())
Expand Down Expand Up @@ -94,7 +100,8 @@ def check_trt_model(baseline_module, baseline_params, graph, params, data_shape,
shape={'data': data_shape}, params=copy_params(params))
baseline_module = graph_runtime.create(baseline_graph, baseline_lib, tvm.gpu())

# test whole graph run using tensorrt, nnvm.compiler.build_config has graph partitioning turned on
# Test whole graph run using tensorrt. nnvm.compiler.build_config has
# graph partitioning turned on when ext_accel='tensorrt'.
check_trt_model(baseline_module, baseline_params, nnvm.graph.load_json(graph_json_str),
copy_params(params), data_shape, ext_accel='tensorrt')

Expand Down
27 changes: 27 additions & 0 deletions tests/scripts/task_python_tensorrt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e
set -u

export PYTHONPATH=nnvm/python:python:topi/python
export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"

rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc

TVM_FFI=ctypes python3 -m nose -v tests/python/tensorrt

0 comments on commit 782907d

Please sign in to comment.