Skip to content

Commit

Permalink
[tensorrt] [byoc] [plugin] allow saving external data
Browse files Browse the repository at this point in the history
  • Loading branch information
Civitasv committed Aug 21, 2023
1 parent dcd46ca commit b1653c0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 3 additions & 1 deletion python/tvm/tpat/cuda/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import os

import tvm
import tvm.contrib.graph_executor as runtime
import tvm.relay as relay
Expand Down Expand Up @@ -86,7 +88,7 @@ def run(self):
mod, params = relay.frontend.from_onnx(self._config.onnx_model)

# 2. Tune it
if self._enable_tunning:
if self._enable_tunning and not os.path.exists(self._config.work_dir):
tunning_option = self._config._tune_option()
ms.relay_integration.tune_relay(mod=mod, params=params, **tunning_option)

Expand Down
9 changes: 7 additions & 2 deletions python/tvm/tpat/cuda/onnx_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ def _handle_trt_not_support_type(
_remove_unnecessary_cast_nodes(graph)

try:
onnx.save(gs.export_onnx(graph), output_model_path)
onnx.save(gs.export_onnx(graph), output_model_path["name"])
except:
onnx.save(gs.export_onnx(graph), output_model_path, save_as_external_data=True)
onnx.save(
gs.export_onnx(graph),
output_model_path["name"],
save_as_external_data=True,
location=output_model_path["weights"],
)


def _remove_unnecessary_cast_nodes(graph):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/tpat/cuda/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def pipeline(
node_names: list[str],
enable_tunning: bool,
tunning_option: object,
output_onnx: str,
output_onnx: object,
) -> Tuple[str, list[str]]:
"""Generate plugins for specified nodes in an ONNX model.
Expand All @@ -135,8 +135,11 @@ def pipeline(
Flag indicating whether tuning is enabled.
tunning_option : object
Tuning options passed to ms.relay_integration.tune_relay; you don't need to specify mod, params, or target.
output_onnx : str
output_onnx : object
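{ "name": xx, "weights": xx }
Paths for the output ONNX model: "name" is the file the modified model is saved to, and "weights" is the location used for the weights when they are stored as external data.
Saving is first attempted as a single file; if that fails, the model is
saved again with its weights written as external data.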
Returns
-------
Expand Down

0 comments on commit b1653c0

Please sign in to comment.