diff --git a/python/tvm/tpat/cuda/kernel.py b/python/tvm/tpat/cuda/kernel.py
index a1a6c57f57adb..80877d4892e9f 100644
--- a/python/tvm/tpat/cuda/kernel.py
+++ b/python/tvm/tpat/cuda/kernel.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
+
 import tvm
 import tvm.contrib.graph_executor as runtime
 import tvm.relay as relay
@@ -86,7 +88,7 @@ def run(self):
         mod, params = relay.frontend.from_onnx(self._config.onnx_model)
 
         # 2. Tune it
-        if self._enable_tunning:
+        if self._enable_tunning and not os.path.exists(self._config.work_dir):
             tunning_option = self._config._tune_option()
             ms.relay_integration.tune_relay(mod=mod, params=params, **tunning_option)
 
diff --git a/python/tvm/tpat/cuda/onnx_util.py b/python/tvm/tpat/cuda/onnx_util.py
index 2c2fa5b702f29..dd2ef1ab0c333 100644
--- a/python/tvm/tpat/cuda/onnx_util.py
+++ b/python/tvm/tpat/cuda/onnx_util.py
@@ -90,9 +90,14 @@ def _handle_trt_not_support_type(
     _remove_unnecessary_cast_nodes(graph)
 
     try:
-        onnx.save(gs.export_onnx(graph), output_model_path)
+        onnx.save(gs.export_onnx(graph), output_model_path["name"])
     except:
-        onnx.save(gs.export_onnx(graph), output_model_path, save_as_external_data=True)
+        onnx.save(
+            gs.export_onnx(graph),
+            output_model_path["name"],
+            save_as_external_data=True,
+            location=output_model_path["weights"],
+        )
 
 
 def _remove_unnecessary_cast_nodes(graph):
diff --git a/python/tvm/tpat/cuda/pipeline.py b/python/tvm/tpat/cuda/pipeline.py
index 45ca7747d9e44..5bdcf31ed6234 100644
--- a/python/tvm/tpat/cuda/pipeline.py
+++ b/python/tvm/tpat/cuda/pipeline.py
@@ -119,7 +119,7 @@ def pipeline(
     node_names: list[str],
     enable_tunning: bool,
     tunning_option: object,
-    output_onnx: str,
+    output_onnx: object,
 ) -> Tuple[str, list[str]]:
     """Generate plugins for specified nodes in an ONNX model.
 
@@ -135,8 +135,11 @@
         Flag indicating whether tunning is enabled.
     tunning_option : object
         Tunning option provided for ms.relay_integration.tune_relay, you don't need to specify mod, params and target.
-    output_onnx : str
+    output_onnx : object
+        A dict of the form { "name": xx, "weights": xx }.
         Path to the output ONNX file where the modified model will be saved.
+        It will first try to save the model without external weights; if that
+        fails, the weights are saved as external data.
 
     Returns
     --------
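
For reference, a minimal usage sketch of the dict-valued output_onnx introduced above. Only the parameter names visible in the diff (node_names, enable_tunning, tunning_option, output_onnx) and the "name"/"weights" keys come from the change itself; the input-model argument, the concrete paths, the node name, the tuning options, and the names of the unpacked return values are hypothetical placeholders.

# Sketch only: illustrates the new dict-valued output_onnx contract.
# All paths, the node name, and the tuning options below are placeholders.
from tvm.tpat.cuda.pipeline import pipeline

output_onnx = {
    "name": "model_with_plugins.onnx",        # where the modified ONNX model is written
    "weights": "model_with_plugins.weights",  # external-data location, used only if the
                                              # plain onnx.save() call fails
}

plugin_model, plugin_names = pipeline(
    "model.onnx",                   # assumed input-model argument (not shown in the hunk)
    node_names=["MyCustomOp_0"],    # nodes to replace with generated plugins
    enable_tunning=True,            # tuning is skipped when work_dir already exists
    tunning_option={"max_trials_global": 1000},  # kwargs forwarded to tune_relay
    output_onnx=output_onnx,
)

The "weights" entry only matters on the fallback path in _handle_trt_not_support_type, where onnx.save(..., save_as_external_data=True, location=...) writes tensor data to a separate file instead of embedding it, which is required for models above the 2 GB protobuf limit.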