refactor: Reorder the API since everything but the engine is optional
Also adds a new destructor to order cleanup

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Nov 21, 2022
1 parent 71082d3 commit 4ab2856
Showing 6 changed files with 23 additions and 8 deletions.
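
The lowering call sites changed below all construct TRTModuleNext with keyword arguments under the reordered signature. A minimal sketch of that usage, assuming trt_engine is an already-built TensorRT engine and that TRTModuleNext and Device are importable from the top-level torch_tensorrt package (the import path, module name, and binding names here are illustrative assumptions):

    import io

    import torch
    from torch_tensorrt import Device, TRTModuleNext  # import path assumed

    # Serialize the built engine into a byte buffer, mirroring the pattern
    # used by the lowering call sites in this commit.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(trt_engine.serialize())
        # The constructor documents serialized_engine as a bytearray.
        engine_str = bytearray(engine_bytes.getvalue())

    # Per this commit, everything but the engine is optional; the remaining
    # settings are passed as keyword arguments.
    trt_module = TRTModuleNext(
        name="my_module",
        serialized_engine=engine_str,
        input_binding_names=["x"],
        output_binding_names=["output"],
        target_device=Device(f"cuda:{torch.cuda.current_device()}"),
    )
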
6 changes: 6 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -145,6 +145,12 @@ TRTEngine::TRTEngine(
  LOG_DEBUG(*this);
}

TRTEngine::~TRTEngine() {
  exec_ctx.reset();
  cuda_engine.reset();
  rt.reset();
}

void TRTEngine::disable_profiling() {
  torch::cuda::synchronize(device_info.id);
  profile_execution = false;
2 changes: 1 addition & 1 deletion core/runtime/TRTEngine.h
@@ -33,7 +33,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX

~TRTEngine() = default;
~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
17 changes: 13 additions & 4 deletions py/torch_tensorrt/_TRTModuleNext.py
@@ -31,8 +31,8 @@ class TRTModuleNext(torch.nn.Module):

def __init__(
self,
serialized_engine: bytearray,
name: str = "",
serialized_engine: bytearray = bytearray(),
input_binding_names: List[str] = [],
output_binding_names: List[str] = [],
target_device: Device = Device._current_device(),
@@ -42,6 +42,11 @@ def __init__(
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it.
If binding names are not provided, it is assumed that the engine binding names follow the following convention:
- [symbol].[index in input / output array]
- ex. [x.0, x.1, x.2] -> [y.0]
Args:
name (str): Name for module
serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray
@@ -51,15 +56,15 @@
Example:
..code-block:: python
..code-block:: py
with io.BytesIO() as engine_bytes:
engine_bytes.write(trt_engine.serialize())
engine_str = engine_bytes.getvalue()
trt_module = TRTModule(
engine_name="my_engine",
serialized_engine=engine_str,
engine_str,
engine_name="my_module",
input_names=["x"],
output_names=["output"],
)
@@ -69,6 +74,10 @@ def __init__(
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
)
super(TRTModuleNext, self).__init__()

if not isinstance(serialized_engine, bytearray):
    raise ValueError("Expected serialized engine as bytearray")

self.input_binding_names = input_binding_names
self.output_binding_names = output_binding_names
self.name = name
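
The docstring in _TRTModuleNext.py assumes that, when binding names are not supplied, the engine bindings follow a [symbol].[index in input / output array] naming convention (e.g. [x.0, x.1, x.2] -> [y.0]). A small sketch of producing names in that convention; the default_binding_names helper is purely illustrative and not part of torch_tensorrt:

    from typing import List


    def default_binding_names(symbol: str, count: int) -> List[str]:
        """Build binding names in the documented convention:
        [symbol].[index in input / output array]."""
        return [f"{symbol}.{i}" for i in range(count)]


    # Matches the docstring example [x.0, x.1, x.2] -> [y.0]:
    input_binding_names = default_binding_names("x", 3)   # ['x.0', 'x.1', 'x.2']
    output_binding_names = default_binding_names("y", 1)  # ['y.0']
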
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -185,8 +185,8 @@ def lower_pass(
engine_str = engine_bytes.getvalue()

trt_module = TRTModuleNext(
engine_str,
name=module_name,
serialized_engine=engine_str,
input_binding_names=interp_res.input_names,
output_binding_names=interp_res.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/tools/trt_minimizer.py
@@ -30,8 +30,8 @@ def lower_mod_default(
engine_str = engine_bytes.getvalue()

res_mod = TRTModuleNext(
engine_str,
name=str(type(mod)),
serialized_engine=engine_str,
input_binding_names=interpreter_result.input_names,
output_binding_names=interpreter_result.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/tools/trt_splitter.py
@@ -100,8 +100,8 @@ def _lower_model_to_backend(
engine_str = engine_bytes.getvalue()

return TRTModuleNext(
engine_str,
name=str(type(mod)),
serialized_engine=engine_str,
input_binding_names=interpreter_result.input_names,
output_binding_names=interpreter_result.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
