diff --git a/python/tvm/tpat/cuda/kernel.py b/python/tvm/tpat/cuda/kernel.py
index b9a543acb33d9..dc4b68964f697 100644
--- a/python/tvm/tpat/cuda/kernel.py
+++ b/python/tvm/tpat/cuda/kernel.py
@@ -23,16 +23,17 @@ class Config(object):
-    def __init__(self, onnx_model, input_shapes, target, work_dir) -> None:
+    def __init__(self, onnx_model, input_shapes, target, tunning_option) -> None:
         self.onnx_model = onnx_model
         self.input_shapes = input_shapes
-        self.work_dir = work_dir
+        self.tunning_option = tunning_option
+        self.work_dir = tunning_option.get("work_dir") or "./log_db"
 
         if target == "gpu":
             self.target = self._detect_cuda_target()
 
     def tune_option(self):
-        return {
+        default = {
             "target": self.target,
             "builder": ms.builder.LocalBuilder(),
             "runner": ms.runner.LocalRunner(),
@@ -41,6 +42,9 @@ def tune_option(self):
             "work_dir": self.work_dir,
         }
 
+        default.update(self.tunning_option)
+        return default
+
     def _detect_cuda_target(self):
         dev = tvm.cuda()
         if not dev.exist:
@@ -59,10 +63,10 @@ def _detect_cuda_target(self):
 
 class Kernel(object):
-    def __init__(self, name, onnx_model, input_shapes, enable_tunning, work_dir):
+    def __init__(self, name, onnx_model, input_shapes, enable_tunning, tunning_option):
         self._name = name
         self._enable_tunning = enable_tunning
-        self._config = Config(onnx_model, input_shapes, "gpu", work_dir)
+        self._config = Config(onnx_model, input_shapes, "gpu", tunning_option)
 
         self._lib = None
         self._module = None
diff --git a/python/tvm/tpat/cuda/pipeline.py b/python/tvm/tpat/cuda/pipeline.py
index 5e1d112626df7..8302fd0cb769f 100644
--- a/python/tvm/tpat/cuda/pipeline.py
+++ b/python/tvm/tpat/cuda/pipeline.py
@@ -59,7 +59,7 @@ def _extract_target_onnx_node(model, tunning_node):
 
 
 def pipeline(
-    onnx_file: str, node_names: list[str], enable_tunning: bool, work_dir: str, output_onnx: str
+    onnx_file: str, node_names: list[str], enable_tunning: bool, tunning_option: dict, output_onnx: str
 ) -> Tuple[str, list[str]]:
     """Generate plugins for specified nodes in an ONNX model.
 
@@ -73,8 +73,8 @@ def pipeline(
         Names of the nodes to be generated as TensorRT plugins.
     enable_tunning : bool
         Flag indicating whether tunning is enabled.
-    work_dir : str
-        Path to the tunning log file where the records will be saved.
+    tunning_option : dict
+        Tuning options forwarded to ms.relay_integration.tune_relay; mod, params, and target do not need to be specified.
     output_onnx : str
         Path to the output ONNX file where the modified model will be saved.
 
@@ -106,7 +106,7 @@ def pipeline(
 
     subgraph, submodel, shapes = _extract_target_onnx_node(inferred_model, node)
 
-    kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, work_dir)
+    kernel = Kernel(plugin_name, submodel, shapes, enable_tunning, tunning_option)
     kernel.run()
 
     ## 3.1 fill in template
diff --git a/python/tvm/tpat/cuda/plugin/Makefile b/python/tvm/tpat/cuda/plugin/Makefile
index d90f15f1bd77f..1aa97fcb7b627 100644
--- a/python/tvm/tpat/cuda/plugin/Makefile
+++ b/python/tvm/tpat/cuda/plugin/Makefile
@@ -14,9 +14,12 @@
 # limitations under the License.
 #
+# Variables to be defined by users
 CUDA_PATH = /path/to/cuda
 CUDNN_PATH = /path/to/cudnn
 TRT_PATH = /path/to/TensorRT
+ARCH = sm_86
+########################################
 
 CUDA_INC_PATH = $(CUDA_PATH)/include
 CUDA_LIB_PATH = $(CUDA_PATH)/lib
@@ -28,13 +31,9 @@ CUDNN_LIB_PATH = $(CUDNN_PATH)/lib
 TRT_INC_PATH = $(TRT_PATH)/include
 TRT_LIB_PATH = $(TRT_PATH)/lib
 
-
-ARCH = sm_86
 GCC = g++
 NVCC = $(CUDA_PATH)/bin/nvcc
-# CCFLAGS = -g -std=c++11 -DNDEBUG
 CCFLAGS = -w -std=c++11
-# CCFLAGS+= -DDEBUG_ME
 
 INCLUDES := -I. -I$(CUDA_COM_PATH) -I$(CUDA_INC_PATH) -I$(CUDNN_INC_PATH) -I$(TRT_INC_PATH) -I/usr/include
 LDFLAGS := -L$(CUDA_LIB_PATH) -L$(CUDNN_LIB_PATH) -L$(TRT_LIB_PATH)
diff --git a/tests/python/tpat/cuda/common.py b/tests/python/tpat/cuda/common.py
index 250535015d1fc..019a0cf366b00 100644
--- a/tests/python/tpat/cuda/common.py
+++ b/tests/python/tpat/cuda/common.py
@@ -94,7 +94,11 @@ def name_without_num(name):
     ops_name = [op_name]
 
     _, trt_plugin_names = tpat.cuda.pipeline(
-        INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
+        INPUT_MODEL_FILE,
+        ops_name,
+        False,
+        {"work_dir": "./log_db", "max_trials_per_task": 500},
+        OUTPUT_MODEL_FILE,
     )
     load_plugin(trt_plugin_names)
 
@@ -197,7 +201,11 @@ def verify_with_ort_with_trt(
     ops_name = [op_name]
 
     _, trt_plugin_names = tpat.cuda.pipeline(
-        INPUT_MODEL_FILE, ops_name, False, "./log_db", OUTPUT_MODEL_FILE
+        INPUT_MODEL_FILE,
+        ops_name,
+        False,
+        {"work_dir": "./log_db", "max_trials_per_task": 500},
+        OUTPUT_MODEL_FILE,
    )
     load_plugin(trt_plugin_names)
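For reviewers, a minimal usage sketch of the new interface (not part of the diff). It mirrors the updated calls in `tests/python/tpat/cuda/common.py`; the model paths and node name are placeholders, and it assumes `tpat.cuda` re-exports `pipeline` the same way the tests import it:

```python
from tvm import tpat

# Entries in tunning_option override the defaults assembled in
# Config.tune_option() before everything is handed to
# ms.relay_integration.tune_relay. "work_dir" doubles as the location
# of the tuning database and falls back to "./log_db" when omitted.
tunning_option = {
    "work_dir": "./log_db",
    "max_trials_per_task": 500,
}

onnx_with_plugins, trt_plugin_names = tpat.cuda.pipeline(
    "model.onnx",               # onnx_file: placeholder input model
    ["my_custom_node"],         # node_names: nodes to become TRT plugins
    True,                       # enable_tunning
    tunning_option,
    "model_with_plugins.onnx",  # output_onnx: placeholder output path
)
```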
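On the Makefile change: grouping the user-editable variables (now including `ARCH`) into one block at the top also means each can be overridden per invocation without editing the file, since command-line assignments such as `make CUDA_PATH=/usr/local/cuda ARCH=sm_80` take precedence over plain `VAR = value` lines (paths here are placeholders).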