Skip to content

Commit

Permalink
[tensorrt] [byoc] [plugin] Allow users to specify tunning option
Browse files Browse the repository at this point in the history
  • Loading branch information
Civitasv committed Aug 11, 2023
1 parent 2e68c8b commit 0d8cddb
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
14 changes: 9 additions & 5 deletions python/tvm/tpat/cuda/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@


class Config(object):
def __init__(self, onnx_model, input_shapes, target, work_dir) -> None:
def __init__(self, onnx_model, input_shapes, target, tunning_option) -> None:
self.onnx_model = onnx_model
self.input_shapes = input_shapes
self.work_dir = work_dir
self.tunning_option = tunning_option
self.work_dir = tunning_option["work_dir"] or "./log_db"

if target == "gpu":
self.target = self._detect_cuda_target()

def tune_option(self):
return {
default = {
"target": self.target,
"builder": ms.builder.LocalBuilder(),
"runner": ms.runner.LocalRunner(),
Expand All @@ -41,6 +42,9 @@ def tune_option(self):
"work_dir": self.work_dir,
}

default.update(self.tunning_option)
return default

def _detect_cuda_target(self):
dev = tvm.cuda()
if not dev.exist:
Expand All @@ -59,10 +63,10 @@ def _detect_cuda_target(self):


class Kernel(object):
def __init__(self, name, onnx_model, input_shapes, enable_tunning, work_dir):
def __init__(self, name, onnx_model, input_shapes, enable_tunning, tunning_option):
self._name = name
self._enable_tunning = enable_tunning
self._config = Config(onnx_model, input_shapes, "gpu", work_dir)
self._config = Config(onnx_model, input_shapes, "gpu", tunning_option)

self._lib = None
self._module = None
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/tpat/cuda/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _extract_target_onnx_node(model, tunning_node):


def pipeline(
onnx_file: str, node_names: list[str], enable_tunning: bool, work_dir: str, output_onnx: str
onnx_file: str, node_names: list[str], enable_tunning: bool, tunning_option: object, output_onnx: str
) -> Tuple[str, list[str]]:
"""Generate plugins for specified nodes in an ONNX model.
Expand All @@ -73,8 +73,8 @@ def pipeline(
Names of the nodes to be generated as TensorRT plugins.
enable_tunning : bool
Flag indicating whether tuning is enabled.
work_dir : str
Path to the tunning log file where the records will be saved.
tunning_option : object
Tuning options passed to ms.relay_integration.tune_relay; you don't need to specify mod, params, or target.
output_onnx : str
Path to the output ONNX file where the modified model will be saved.
Expand Down Expand Up @@ -106,7 +106,7 @@ def pipeline(

subgraph, submodel, shapes = _extract_target_onnx_node(inferred_model, node)

kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, work_dir)
kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, tunning_option)
kernel.run()

## 3.1 fill in template
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/tpat/cuda/plugin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# limitations under the License.
#

# Variables need to be defined by Users
CUDA_PATH = /path/to/cuda
CUDNN_PATH = /path/to/cudnn
TRT_PATH = /path/to/TensorRT
ARCH = sm_86
########################################

CUDA_INC_PATH = $(CUDA_PATH)/include
CUDA_LIB_PATH = $(CUDA_PATH)/lib
Expand All @@ -28,13 +31,9 @@ CUDNN_LIB_PATH = $(CUDNN_PATH)/lib
TRT_INC_PATH = $(TRT_PATH)/include
TRT_LIB_PATH = $(TRT_PATH)/lib


ARCH = sm_86
GCC = g++
NVCC = $(CUDA_PATH)/bin/nvcc
# CCFLAGS = -g -std=c++11 -DNDEBUG
CCFLAGS = -w -std=c++11
# CCFLAGS+= -DDEBUG_ME
INCLUDES := -I. -I$(CUDA_COM_PATH) -I$(CUDA_INC_PATH) -I$(CUDNN_INC_PATH) -I$(TRT_INC_PATH) -I/usr/include

LDFLAGS := -L$(CUDA_LIB_PATH) -L$(CUDNN_LIB_PATH) -L$(TRT_LIB_PATH)
Expand Down
12 changes: 10 additions & 2 deletions tests/python/tpat/cuda/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def name_without_num(name):
ops_name = [op_name]

_, trt_plugin_names = tpat.cuda.pipeline(
INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
INPUT_MODEL_FILE,
ops_name,
False,
{"work_dir": "./log_db", "max_trials_per_task": 500},
OUTPUT_MODEL_FILE,
)

load_plugin(trt_plugin_names)
Expand Down Expand Up @@ -197,7 +201,11 @@ def verify_with_ort_with_trt(
ops_name = [op_name]

_, trt_plugin_names = tpat.cuda.pipeline(
INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
INPUT_MODEL_FILE,
ops_name,
False,
{"work_dir": "./log_db", "max_trials_per_task": 500},
OUTPUT_MODEL_FILE,
)

load_plugin(trt_plugin_names)
Expand Down

0 comments on commit 0d8cddb

Please sign in to comment.