feat(//py): Working portable package
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 16, 2020
1 parent a71bca9 commit 482ef2c
Showing 6 changed files with 27 additions and 28 deletions.
py/setup.py: 3 changes (2 additions, 1 deletion)
@@ -121,7 +121,8 @@ def run(self):
extra_link_args=[
"-D_GLIBCXX_USE_CXX11_ABI=0"
"-Wl,--no-as-needed",
"-ltrtorch"
"-ltrtorch",
"-Wl,-rpath,$ORIGIN/lib"
],
undef_macros=[ "NDEBUG" ]
)
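The new -Wl,-rpath,$ORIGIN/lib link flag is what makes the package portable: at import time the dynamic linker looks for libtrtorch.so in the lib/ directory shipped next to the compiled extension, so users no longer need LD_LIBRARY_PATH tweaks or a manual preload. Below is a minimal, illustrative setup.py sketch of the same idea, not the full TRTorch build; the source path, staging directory, and package_data pattern are assumptions:

# Illustrative sketch: link a C++ extension against a shared library bundled
# inside the package, resolved at runtime through an $ORIGIN-relative rpath.
# Paths and names are assumptions, not the actual TRTorch build configuration.
from setuptools import setup, Extension

ext = Extension(
    name="trtorch._C",
    sources=["trtorch/csrc/trtorch_py.cpp"],
    library_dirs=["trtorch/lib"],               # where libtrtorch.so is staged before building
    libraries=["trtorch"],
    extra_link_args=[
        "-Wl,--no-as-needed",
        "-Wl,-rpath,$ORIGIN/lib",               # find libtrtorch.so next to the extension at runtime
    ],
)

setup(
    name="trtorch",
    packages=["trtorch"],
    package_data={"trtorch": ["lib/*.so"]},     # ship the shared library inside the wheel
    ext_modules=[ext],
)

Because setuptools passes the flag straight to the linker without a shell, $ORIGIN stays literal here; in a Makefile it would need to be escaped as $$ORIGIN.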
py/trtorch/__init__.py: 8 changes (0 additions, 8 deletions)
@@ -7,14 +7,6 @@
import ctypes
import torch

def _load_trtorch_lib():
lib_name = 'libtrtorch.so'
here = os.path.abspath(__file__)
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)

_load_trtorch_lib()

from trtorch._version import __version__
from trtorch._compiler import *
from trtorch._types import *
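With the rpath embedded in the extension, the ctypes preload shim deleted above is redundant: a plain import trtorch lets the dynamic loader resolve libtrtorch.so on its own. A quick way to confirm that on a Linux install is to inspect the extension's RUNPATH; this sketch assumes readelf is on PATH and that the extension file matches _C*.so:

# Sketch: check that the installed extension carries an $ORIGIN-relative
# RPATH/RUNPATH entry. Assumes Linux, binutils (readelf), and an installed
# trtorch package; the _C*.so filename pattern is an assumption.
import glob
import os
import subprocess

import trtorch

pkg_dir = os.path.dirname(trtorch.__file__)
for ext_path in glob.glob(os.path.join(pkg_dir, "_C*.so")):
    result = subprocess.run(["readelf", "-d", ext_path], capture_output=True, text=True)
    for line in result.stdout.splitlines():
        if "RPATH" in line or "RUNPATH" in line:
            print(ext_path, "->", line.strip())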
py/trtorch/_compiler.py: 19 changes (15 additions, 4 deletions)
@@ -1,8 +1,12 @@
from typing import List, Dict, Any
import torch
from torch import nn

import trtorch._C
from trtorch._extra_info import _parse_extra_info
from trtorch._version import __version__
from types import FunctionType


def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -50,7 +54,11 @@ def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.Script
Returns:
torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
"""
compiled_cpp_mod = trtorch._C._compile_graph(module._c, _parse_extra_info(extra_info))

if isinstance(module, torch.jit.ScriptFunction):
raise TypeError("torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile")

compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_extra_info(extra_info))
compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
return compiled_module

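For context, a usage sketch of the public compile() wrapper. The module, shapes, and precision value are illustrative; the only keys taken from this commit are input_shapes and op_precision, and the accepted value types may vary by version:

# Sketch: compile a scripted module with TRTorch. Scripting a bare function
# yields a torch.jit.ScriptFunction, which the new check rejects, so a Module
# is used instead.
import torch
import trtorch

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Net())

extra_info = {
    "input_shapes": [(1, 3, 224, 224)],   # one static shape per input (illustrative)
    "op_precision": torch.float,          # assumed to accept a torch dtype
}

trt_module = trtorch.compile(scripted, extra_info)
print(trt_module(torch.randn(1, 3, 224, 224).cuda()).shape)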
@@ -98,7 +106,10 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
if isinstance(module, torch.jit.ScriptFunction):
raise TypeError("torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile")

return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))

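Likewise, a sketch of serializing a single method straight to a TensorRT engine; the module, shapes, and output file name are illustrative:

# Sketch: convert a scripted module's forward method into a serialized
# TensorRT engine and write it to disk for the TensorRT runtime to load.
import torch
import trtorch

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Net())
engine_bytes = trtorch.convert_method_to_trt_engine(
    scripted, "forward", {"input_shapes": [(1, 3, 224, 224)]})

with open("net.engine", "wb") as f:        # hypothetical output path
    f.write(engine_bytes)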
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
"""Checks to see if a method is fully supported by TRTorch
@@ -114,7 +125,7 @@ def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) ->
Returns:
bool: True if supported Method
"""
return trtorch._C._check_method_op_support(module._c, method_name)
return trtorch._C.check_method_op_support(module._c, method_name)

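A small sketch of probing convertibility before committing to a full compile (the module is illustrative):

# Sketch: ask TRTorch whether every operator in a method can be converted.
import torch
import trtorch

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

scripted = torch.jit.script(Net())
if trtorch.check_method_op_support(scripted, "forward"):
    print("forward is fully convertible")
else:
    print("forward contains unsupported operators")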
def dump_build_info():
"""Prints build information about the TRTorch distribution to stdout
@@ -127,7 +138,7 @@ def get_build_info() -> str:
Returns:
str: String containing the build information for TRTorch distribution
"""
build_info = trtorch._C._get_build_info()
build_info = trtorch._C.get_build_info()
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
return build_info

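And the two build-info helpers in use; this sketch only relies on names visible in the diff:

# Sketch: report version and build information for the installed package.
import trtorch

print(trtorch.get_build_info())   # returns the build string, prefixed with the TRTorch version
trtorch.dump_build_info()         # prints the same information to stdout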
py/trtorch/_extra_info.py: 7 changes (3 additions, 4 deletions)
@@ -84,13 +84,12 @@ def _parse_device_type(device: Any) -> _types.DeviceType:
else:
raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))

def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
info = trtorch._C._ExtraInfo()
if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):
def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C.ExtraInfo:
info = trtorch._C.ExtraInfo()
if "input_shapes" not in extra_info:
raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")

info.input_ranges = _parse_input_ranges(extra_info["input_shapes"])
print(info.input_ranges)

if "op_precision" in extra_info:
info.op_precision = _parse_op_precision(extra_info["op_precision"])
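The tightened check only demands that input_shapes be present; per the error message, each entry is either a static size or a dict of three sizes (min, opt, max). A hedged sketch of both spellings; the exact key names and accepted containers are assumptions:

# Sketch: the two forms an "input_shapes" entry can take (values illustrative).
import torch

static_info = {
    "input_shapes": [(1, 3, 224, 224)],      # fixed shape for a single input
}

dynamic_info = {
    "input_shapes": [{
        "min": (1, 3, 128, 128),             # smallest shape the engine must accept
        "opt": (1, 3, 224, 224),             # shape the engine is optimized for
        "max": (1, 3, 512, 512),             # largest shape the engine must accept
    }],
    "op_precision": torch.float,             # optional, handled by _parse_op_precision
}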
py/trtorch/csrc/trtorch_py.cpp: 14 changes (5 additions, 9 deletions)
@@ -18,9 +18,6 @@ struct InputRange {
std::vector<int64_t> max;

core::conversion::InputRange toInternalInputRange() {
for (auto o : opt) {
std::cout << o << std::endl;
}
return core::conversion::InputRange(min, opt, max);
}
};
@@ -79,7 +76,6 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
struct ExtraInfo {

core::ExtraInfo toInternalExtraInfo() {
std::cout << "HELLO" << input_ranges.size() << std::endl;
for (auto i : input_ranges) {
internal_input_ranges.push_back(i.toInternalInputRange());
}
@@ -193,7 +189,7 @@ PYBIND11_MODULE(_C, m) {
.value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
.value("default", EngineCapability::kDEFAULT, "Use default behavior");

py::class_<ExtraInfo>(m, "_ExtraInfo")
py::class_<ExtraInfo>(m, "ExtraInfo")
.def(py::init<>())
.def_readwrite("input_ranges", &ExtraInfo::input_ranges)
.def_readwrite("op_precision", &ExtraInfo::op_precision)
@@ -209,10 +205,10 @@
.def_readwrite("max_batch_size", &ExtraInfo::max_batch_size);

m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
m.def("_convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string");
m.def("compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
m.def("convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
m.def("check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");

m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
m.def("_set_logging_prefix", &logging::set_logging_prefix, "Set the logging prefix for logging output");
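Dropping the leading underscores keeps the C++ binding names in step with the Python wrappers that call them (compile_graph, convert_graph_to_trt_engine, check_method_op_support, get_build_info). A quick sanity check, assuming the extension imports as trtorch._C:

# Sketch: confirm the renamed bindings are exposed by the compiled extension.
import trtorch._C as _C

for name in ("compile_graph", "convert_graph_to_trt_engine",
             "check_method_op_support", "get_build_info"):
    print(name, "->", hasattr(_C, name))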
py/trtorch/logging.py: 4 changes (2 additions, 2 deletions)
@@ -40,7 +40,7 @@ def set_logging_prefix(prefix: str):
Args:
prefix (str): Prefix to use for logging messages
"""
_set_logging_prefix(str)
_set_logging_prefix(prefix)

def get_reportable_log_level() -> Level:
"""Get the level required for a message to be printed in the log
@@ -84,4 +84,4 @@ def log(level: Level, msg: str):
level (trtorch.logging.Level): Severity of the message
msg (str): Actual message text
"""
_log(level, msg)
_log(Level._to_internal_level(level), msg)

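Both logging fixes are in thin wrappers: set_logging_prefix now forwards the prefix argument instead of the str type, and log converts the public Level enum to its internal value before handing it to the binding. A usage sketch; the Level.Info member name is an assumption:

# Sketch: route a message through TRTorch's logging wrappers.
import trtorch.logging as logging

logging.set_logging_prefix("[my-app] ")                    # applied to subsequent log output
logging.log(logging.Level.Info, "compiling with TRTorch")  # Level.Info is assumed to exist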