From 7088245620661e0390b7357e84bcd4e4d2a11db5 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Wed, 13 May 2020 20:12:01 -0700
Subject: [PATCH] feat(//py): Initial introduction of the Python API

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 py/BUILD                       |   0
 py/requirements.txt            |   0
 py/setup.py                    | 149 +++++++++++++++++++++++
 py/trtorch/__init__.py         |  24 ++++
 py/trtorch/compiler.py         | 153 ++++++++++++++++++++++++
 py/trtorch/csrc/trtorch_py.cpp | 206 +++++++++++++++++++++++++++++++++
 py/trtorch/types.py            |   1 +
 py/trtorch/version.py          |   1 +
 8 files changed, 534 insertions(+)
 create mode 100644 py/BUILD
 create mode 100644 py/requirements.txt
 create mode 100644 py/setup.py
 create mode 100644 py/trtorch/__init__.py
 create mode 100644 py/trtorch/compiler.py
 create mode 100644 py/trtorch/csrc/trtorch_py.cpp
 create mode 100644 py/trtorch/types.py
 create mode 100644 py/trtorch/version.py

diff --git a/py/BUILD b/py/BUILD
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/requirements.txt b/py/requirements.txt
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/setup.py b/py/setup.py
new file mode 100644
index 0000000000..7dc86f09e0
--- /dev/null
+++ b/py/setup.py
@@ -0,0 +1,149 @@
+from setuptools import setup, Extension, find_packages, Command
+from setuptools.command.build_ext import build_ext
+from setuptools.command.develop import develop
+from setuptools.command.install import install
+import glob
+import shutil
+import sys
+import setuptools
+import os
+from torch.utils import cpp_extension
+from shutil import copyfile
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+__version__ = '0.0.1'
+
+def gen_version_file():
+    """Write trtorch/version.py so the package exposes the __version__ defined in this file."""
+    # open(..., 'w') creates the file when it is missing, so no separate os.mknod() call is needed
+    with open(dir_path + '/trtorch/version.py', 'w') as f:
+        print("creating version file")
+        f.write("__version__ = \"" + __version__ + '\"')
+
+def copy_libtrtorch():
+    """Copy libtrtorch.so from ../bazel-bin/cpp/api/lib into trtorch/lib next to this file."""
+    if not os.path.exists(dir_path + '/trtorch/lib'):
+        os.makedirs(dir_path + '/trtorch/lib')
+
+    print("copying library into module")
+    copyfile(dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')
+
+class DevelopCommand(develop):
+    description = "Builds the package and symlinks it into the PYTHONPATH"
+    # plugins_user_options was never defined in this file; expose the stock develop options only
+    user_options = develop.user_options
+
+    def initialize_options(self):
+        develop.initialize_options(self)
+
+    def finalize_options(self):
+        develop.finalize_options(self)
+
+    def run(self):
+        gen_version_file()
+        copy_libtrtorch()
+        develop.run(self)
+
+
+class InstallCommand(install):
+    description = "Builds the package"
+    # plugins_user_options was never defined in this file; expose the stock install options only
+    user_options = install.user_options
+
+    def initialize_options(self):
+        install.initialize_options(self)
+
+    def finalize_options(self):
+        install.finalize_options(self)
+
+    def run(self):
+        gen_version_file()
+        copy_libtrtorch()
+        install.run(self)
+
+class CleanCommand(Command):
+    """Custom clean command to tidy up the project root."""
+    PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './*.pyc', './*.tgz', './*.egg-info']
+    description = "Command to tidy up the project root"
+    user_options = []
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self):
+        for path_spec in self.PY_CLEAN_FILES:
+            # Make paths absolute and relative to this path
+            abs_paths = glob.glob(os.path.normpath(os.path.join(dir_path, path_spec)))
+            for path in [str(p) for p in abs_paths]:
+                if not path.startswith(dir_path):
+                    # Die if path in CLEAN_FILES is absolute + outside this directory
+                    raise ValueError("%s is not a path inside %s" % (path, dir_path))
+                print('Removing %s' % os.path.relpath(path))
+                # Globs such as ./*.pyc match plain files, which shutil.rmtree() cannot remove
+                if os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    os.remove(path)
+
+ext_modules = [
+    cpp_extension.CUDAExtension('trtorch._C',
+        ['trtorch/csrc/trtorch_py.cpp'],
+        library_dirs=[
+            # Only directories belong here; the library itself is selected through `libraries`
+            dir_path + '/trtorch/lib/'
+        ],
+        libraries=[
+            "trtorch"
+        ],
+        include_dirs=[
+            dir_path + "/../",
+        ],
+        extra_compile_args=[
+            "-D_GLIBCXX_USE_CXX11_ABI=0"
+        ],
+        extra_link_args=[
+            # The comma after the first flag matters: adjacent string literals are concatenated
+            "-D_GLIBCXX_USE_CXX11_ABI=0",
+            "-Wl,--no-as-needed",
+            "-ltrtorch"
+        ],
+        undef_macros=[ "NDEBUG" ]
+    )
+]
+
+setup(
+    name='trtorch',
+    version=__version__,
+    author='NVIDIA Corporation.',
+    author_email='narens@nvidia.com',
+    url='https://github.com/nvidia/trtorch',
+    description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
+    long_description='',
+    ext_modules=ext_modules,
+    install_requires=['pybind11>=2.4'],
+    setup_requires=['pybind11>=2.4'],
+    cmdclass={
+        'install': InstallCommand,
+        'clean': CleanCommand,
+        'develop': DevelopCommand,
+        'build_ext': cpp_extension.BuildExtension
+    },
+    zip_safe=False,
+    license="BSD-3",
+    packages=find_packages(),
+    classifiers=["Intended Audience :: Developers",
+                 "Intended Audience :: Science/Research",
+                 "Operating System :: POSIX :: Linux",
+                 "Programming Language :: C++",
+                 "Programming Language :: Python",
+                 "Programming Language :: Python :: Implementation :: CPython",
+                 "Topic :: Scientific/Engineering",
+                 "Topic :: Scientific/Engineering :: Artificial Intelligence",
+                 "Topic :: Software Development",
+                 "Topic :: Software Development :: Libraries"],
+
+)
diff --git a/py/trtorch/__init__.py b/py/trtorch/__init__.py
new file mode 100644
index 0000000000..e7f690effd
--- /dev/null
+++ b/py/trtorch/__init__.py
@@ -0,0 +1,24 @@
+import os
+import sys
+
+if sys.version_info < (3,):
+    raise Exception("Python 2 has reached end-of-life and is not supported by TRTorch")
+
+import ctypes
+import torch
+
+def _load_trtorch_lib():
+    lib_name = 'libtrtorch.so'
+    here = os.path.abspath(__file__)
+    lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
+    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
+
+_load_trtorch_lib()
+
+from .version import __version__
+from trtorch import _C  # required: test() below calls _C._test
+from trtorch.compiler import *
+from trtorch.types import *
+
+def test(mod, data):
+    _C._test(mod._c, data)
diff --git a/py/trtorch/compiler.py b/py/trtorch/compiler.py
new file mode 100644
index 0000000000..053da6e589
--- 
/dev/null
+++ b/py/trtorch/compiler.py
@@ -0,0 +1,153 @@
+from typing import List, Dict, Any
+import torch
+import tensorrt as trt
+import trtorch._C
+from trtorch import types
+from .version import __version__
+
+def _supported_input_size_type(input_size: Any) -> bool:
+    if isinstance(input_size, torch.Size):
+        return True
+    elif isinstance(input_size, tuple):
+        return True
+    elif isinstance(input_size, list):
+        return True
+    else:
+        raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size)))
+
+def _parse_input_sizes(input_sizes: List) -> List:
+
+    if any (not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):
+        raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
+
+    parsed_input_sizes = []
+    for i in input_sizes:
+        if isinstance(i, dict):
+            if all (k in i for k in ["min", "opt", "max"]):  # all three bounds must be present
+                in_range = trtorch._C.InputRange()
+                in_range.min = i["min"]
+                in_range.opt = i["opt"]
+                in_range.max = i["max"]
+
+                parsed_input_sizes.append(in_range._to_internal_input_range())  # name as bound in trtorch_py.cpp
+
+            elif "opt" in i:
+                in_range = trtorch._C.InputRange()
+                in_range.min = i["opt"]
+                in_range.opt = i["opt"]
+                in_range.max = i["opt"]
+
+                parsed_input_sizes.append(in_range._to_internal_input_range())
+
+            else:
+                raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
+
+        elif isinstance(i, (list, tuple)):  # tuple also covers torch.Size (a tuple subclass), which the check above accepts
+            in_range = trtorch._C.InputRange()
+            in_range.min = i
+            in_range.opt = i
+            in_range.max = i
+
+            parsed_input_sizes.append(in_range._to_internal_input_range())
+
+    return parsed_input_sizes
+
+def _parse_op_precision(precision: Any) -> types.dtype:
+    if isinstance(precision, torch.dtype):
+        if precision == torch.int8:
+            return types.dtype.int8
+        elif precision == torch.half:
+            return types.dtype.half
+        elif precision == torch.float:
+            return types.dtype.float
+        else:
+            raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))
+
+    elif isinstance(precision, types.dtype):  # trtorch.types exports dtype; there is no DataTypes name
+        return precision
+
+    else:
+        raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))
+
+def _parse_device_type(device: Any) -> types.DeviceType:
+    if isinstance(device, torch.device):
+        if device.type == 'cuda':  # inspect the instance, not the torch.device class attribute
+            return types.DeviceType.gpu
+        else:
+            raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type: " + str(device.type))
+
+    elif isinstance(device, types.DeviceType):
+        return device
+
+    else:
+        raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
+
+def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
+    info = trtorch._C._ExtraInfo()
+    if "input_shapes" not in extra_info or not isinstance(extra_info["input_shapes"], list):  # reject a missing key OR a non-list value
+        raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")
+
+    info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])
+
+    if "op_precision" in extra_info:
+        info.op_precision = _parse_op_precision(extra_info["op_precision"])
+
+    if "refit" in extra_info:
+        assert isinstance(extra_info["refit"], bool)
+        info.refit = extra_info["refit"]
+
+    if "debug" in extra_info:
+        assert isinstance(extra_info["debug"], bool)
+        info.debug = extra_info["debug"]
+
+    if "strict_types" in extra_info:
+        assert isinstance(extra_info["strict_types"], bool)
+        info.strict_types = extra_info["strict_types"]
+
+    if "allow_gpu_fallback" in extra_info:
+        assert isinstance(extra_info["allow_gpu_fallback"], bool)
+        info.allow_gpu_fallback = extra_info["allow_gpu_fallback"]
+
+    if "device" in extra_info:
+        info.device = _parse_device_type(extra_info["device"])
+
+    if "capability" in extra_info:
+        assert isinstance(extra_info["capability"], types.EngineCapability)  # `types` module, not the builtin `type`
+        info.capability = extra_info["capability"]
+
+
+    if "num_min_timing_iters" in extra_info:
+        assert type(extra_info["num_min_timing_iters"]) is int
+        info.num_min_timing_iters = extra_info["num_min_timing_iters"]
+
+    if "num_avg_timing_iters" in extra_info:
+        assert type(extra_info["num_avg_timing_iters"]) is int
+        info.num_avg_timing_iters = extra_info["num_avg_timing_iters"]
+
+    if "workspace_size" in extra_info:
+        assert type(extra_info["workspace_size"]) is int
+        info.workspace_size = extra_info["workspace_size"]
+
+    if "max_batch_size" in extra_info:
+        assert type(extra_info["max_batch_size"]) is int
+        info.max_batch_size = extra_info["max_batch_size"]
+
+    return info
+
+def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
+    return module  # NOTE(review): stub - returns the input unchanged; extra_info is ignored and no binding is called
+
+def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
+    return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
+
+def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
+    return trtorch._C._check_method_op_support(module._c, method_name)
+
+def dump_build_info():
+    print(get_build_info())
+
+def get_build_info() -> str:
+    build_info = trtorch._C._get_build_info()
+    build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
+    return build_info
+
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
new file mode 100644
index 0000000000..a0204461f4
--- /dev/null
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -0,0 +1,206 @@
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "core/compiler.h"
+#include "core/conversion/conversion.h"
+#include "torch/torch.h"
+#include "torch/script.h"
+#include "torch/csrc/jit/python/pybind_utils.h"
+#include "Python.h"
+
+namespace py = pybind11;
+
+namespace trtorch {
+namespace 
pyapi {
+
+struct InputRange {
+    std::vector<int64_t> min; // NOTE(review): template arguments in this file were stripped in extraction and are reconstructed from usage - confirm
+    std::vector<int64_t> opt;
+    std::vector<int64_t> max;
+
+    core::conversion::InputRange toInternalInputRange() {
+        return core::conversion::InputRange(min, opt, max);
+    }
+};
+
+enum class DataType : int8_t {
+    kFloat,
+    kHalf,
+    kChar,
+};
+
+nvinfer1::DataType toTRTDataType(DataType value) {
+    switch (value) {
+    case DataType::kChar:
+        return nvinfer1::DataType::kINT8;
+    case DataType::kHalf:
+        return nvinfer1::DataType::kHALF;
+    case DataType::kFloat:
+    default:
+        return nvinfer1::DataType::kFLOAT;
+    }
+}
+
+enum DeviceType : int8_t {
+    kGPU,
+    kDLA,
+};
+
+nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
+    switch (value) {
+    case DeviceType::kDLA:
+        return nvinfer1::DeviceType::kDLA;
+    case DeviceType::kGPU:
+    default:
+        return nvinfer1::DeviceType::kGPU; // GPU and unknown values map to the GPU, not the DLA
+    }
+}
+
+enum class EngineCapability : int8_t {
+    kDEFAULT,
+    kSAFE_GPU,
+    kSAFE_DLA,
+};
+
+nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
+    switch (value) {
+    case EngineCapability::kSAFE_DLA:
+        return nvinfer1::EngineCapability::kSAFE_DLA;
+    case EngineCapability::kSAFE_GPU:
+        return nvinfer1::EngineCapability::kSAFE_GPU;
+    case EngineCapability::kDEFAULT:
+    default:
+        return nvinfer1::EngineCapability::kDEFAULT;
+    }
+}
+
+struct ExtraInfo {
+
+    core::ExtraInfo toInternalExtraInfo() {
+        auto info = core::ExtraInfo(input_ranges);
+        info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
+        info.convert_info.engine_settings.refit = refit;
+        info.convert_info.engine_settings.debug = debug;
+        info.convert_info.engine_settings.strict_types = strict_types;
+        info.convert_info.engine_settings.allow_gpu_fallback = allow_gpu_fallback;
+        info.convert_info.engine_settings.device = toTRTDeviceType(device);
+        info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
+        info.convert_info.engine_settings.num_min_timing_iters = num_min_timing_iters;
+        info.convert_info.engine_settings.num_avg_timing_iters = num_avg_timing_iters;
+        info.convert_info.engine_settings.workspace_size = workspace_size;
+        info.convert_info.engine_settings.max_batch_size = max_batch_size;
+        return info;
+    }
+
+    std::vector<core::conversion::InputRange> input_ranges; // NOTE(review): element type reconstructed - confirm
+    DataType op_precision = DataType::kFloat;
+    bool refit = false;
+    bool debug = false;
+    bool strict_types = false;
+    bool allow_gpu_fallback = true;
+    DeviceType device = DeviceType::kGPU;
+    EngineCapability capability = EngineCapability::kDEFAULT;
+    uint64_t num_min_timing_iters = 2;
+    uint64_t num_avg_timing_iters = 1;
+    uint64_t workspace_size = 0;
+    uint64_t max_batch_size = 0;
+};
+
+
+torch::jit::Module CompileGraph(const torch::jit::Module& mod, ExtraInfo& info) {
+    py::gil_scoped_acquire gil;
+    auto trt_mod = trtorch::CompileGraph(mod, info.toInternalExtraInfo());
+    return trt_mod;
+}
+
+std::string ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) {
+    py::gil_scoped_acquire gil;
+    auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalExtraInfo());
+    return trt_engine;
+}
+
+bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::string& method_name) {
+    return core::CheckMethodOperatorSupport(module, method_name);
+}
+
+void test(torch::jit::Module& mod, torch::Tensor data) {
+    std::cout << mod.forward({data}) << std::endl;
+}
+
+std::string get_build_info() {
+    auto info = core::util::get_build_info();
+    return info;
+}
+
+PYBIND11_MODULE(_C, m) {
+    py::class_<InputRange>(m, "InputRange")
+        .def(py::init<>())
+        .def_readwrite("min", &InputRange::min)
+        .def_readwrite("opt", &InputRange::opt)
+        .def_readwrite("max", &InputRange::max)
+        .def("_to_internal_input_range", &InputRange::toInternalInputRange);
+
+    py::class_<core::conversion::InputRange>(m, "_InternalInputRange") // NOTE(review): bound type reconstructed - confirm
+        .def(py::init<>());
+
+    py::enum_<DataType>(m, "dtype")
+        .value("float", DataType::kFloat)
+        .value("float32", DataType::kFloat)
+        .value("half", DataType::kHalf)
+        .value("float16", DataType::kHalf)
+        .value("int8", DataType::kChar)
+        .export_values();
+
+    py::enum_<DeviceType>(m, "DeviceType")
+        .value("gpu", DeviceType::kGPU)
+        .value("dla", DeviceType::kDLA)
+        .export_values();
+
+    py::enum_<EngineCapability>(m, "EngineCapability")
+        .value("safe_gpu", EngineCapability::kSAFE_GPU)
+        .value("safe_dla", EngineCapability::kSAFE_DLA)
+        .value("default", EngineCapability::kDEFAULT);
+
+    py::class_<ExtraInfo>(m, "_ExtraInfo")
+        .def(py::init<>())
+        .def_readwrite("input_ranges", &ExtraInfo::input_ranges)
+        .def_readwrite("op_precision", &ExtraInfo::op_precision)
+        .def_readwrite("refit", &ExtraInfo::refit)
+        .def_readwrite("debug", &ExtraInfo::debug)
+        .def_readwrite("strict_types", &ExtraInfo::strict_types)
+        .def_readwrite("allow_gpu_fallback", &ExtraInfo::allow_gpu_fallback)
+        .def_readwrite("device", &ExtraInfo::device)
+        .def_readwrite("capability", &ExtraInfo::capability)
+        .def_readwrite("num_min_timing_iters", &ExtraInfo::num_min_timing_iters)
+        .def_readwrite("num_avg_timing_iters", &ExtraInfo::num_avg_timing_iters)
+        .def_readwrite("workspace_size", &ExtraInfo::workspace_size)
+        .def_readwrite("max_batch_size", &ExtraInfo::max_batch_size);
+
+    m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
+    m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded");
+    m.def("_convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine");
+    m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators");
+    m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string");
+    m.def("_test", &test);
+}
+
+// namespace logging {
+//     PYBIND11_MODULE(logging, m) {
+//         m.attr("__name__") = "trtorch.logging";
+//         m.def("get_logging_prefix", &trtorch::logging::get_logging_prefix, "Get the current prefix for the logging output");
+//         m.def("set_logging_prefix", &trtorch::logging::set_logging_prefix, "Set the logging prefix for logging output");
+//         m.def("get_reportable_log_level", &trtorch::logging::get_reportable_log_level, "Get the current log level");
+//         m.def("set_reportable_log_level", &trtorch::logging::set_reportable_log_level, "Set the level required to be met for a log message to be printed");
+//         m.def("get_is_colored_output_on", &trtorch::logging::get_is_colored_output_on, "Get if the logging output will be colored");
+//         m.def("set_is_colored_output_on", &trtorch::logging::set_is_colored_output_on, "Set if the logging output should be colored");
+//         m.def("log", &trtorch::logging::log, "Add a message to the logger");
+//         py::enum_<trtorch::logging::Level>(m, "Level", py::arithmetic())
+//             .value("INTERNAL_ERROR", trtorch::logging::Level::kINTERNAL_ERROR)
+//             .value("ERROR", trtorch::logging::Level::kERROR)
+//             .value("WARNING", trtorch::logging::Level::kWARNING)
+//             .value("INFO", trtorch::logging::Level::kINFO)
+//             .value("DEBUG", trtorch::logging::Level::kDEBUG)
+//             .export_values();
+//     }
+//} // namespace logging
+} // namespace pyapi
+} // namespace trtorch
diff --git a/py/trtorch/types.py b/py/trtorch/types.py
new file mode 100644
index 0000000000..48244c3e85
--- /dev/null
+++ b/py/trtorch/types.py
@@ -0,0 +1 @@
+from trtorch._C import dtype, DeviceType, EngineCapability
diff --git a/py/trtorch/version.py b/py/trtorch/version.py
new file mode 100644
index 0000000000..b3c06d4883
--- /dev/null
+++ b/py/trtorch/version.py
@@ -0,0 +1 @@
+__version__ = "0.0.1"
\ No newline at end of file