-
Notifications
You must be signed in to change notification settings - Fork 352
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//py): Initial introduction of the Python API
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
- Loading branch information
1 parent
83e0ed6
commit 7088245
Showing
8 changed files
with
523 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import glob
import os
import shutil
import sys
from shutil import copyfile

import setuptools
from setuptools import Command, setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.develop import develop
from setuptools.command.install import install
from torch.utils import cpp_extension
|
||
# Absolute path of the directory containing this setup.py; all staged
# artifacts (version.py, trtorch/lib) are placed relative to it.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Single source of truth for the package version; gen_version_file() also
# writes it into trtorch/version.py at build time.
__version__ = '0.0.1'
|
||
def gen_version_file():
    """Write trtorch/version.py containing the package ``__version__``.

    The file is (re)created on every build so the installed package always
    reports the version declared in this setup.py.
    """
    # Fixed: the original called os.mknod() before opening the file for
    # writing.  open(..., 'w') already creates the file, and mknod() is not
    # portable (fails on macOS without root), so it was dropped.
    with open(dir_path + '/trtorch/version.py', 'w') as f:
        print("creating version file")
        f.write("__version__ = \"" + __version__ + '\"')
|
||
def copy_libtrtorch():
    """Stage the prebuilt libtrtorch.so from the bazel output tree into the package."""
    lib_dir = dir_path + '/trtorch/lib'
    if not os.path.exists(lib_dir):
        os.makedirs(lib_dir)

    print("copying library into module")
    src = dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so"
    copyfile(src, lib_dir + '/libtrtorch.so')
|
||
class DevelopCommand(develop):
    """setuptools ``develop`` command that stages TRTorch artifacts first."""

    description = "Builds the package and symlinks it into the PYTHONPATH"
    # Fixed: the original appended ``plugins_user_options``, a name that is
    # not defined anywhere in this file and raised NameError at import time.
    user_options = develop.user_options

    def initialize_options(self):
        develop.initialize_options(self)

    def finalize_options(self):
        develop.finalize_options(self)

    def run(self):
        # Generate version.py and copy the native library before the
        # standard develop (editable) install runs.
        gen_version_file()
        copy_libtrtorch()
        develop.run(self)
|
||
|
||
class InstallCommand(install):
    """setuptools ``install`` command that stages TRTorch artifacts first."""

    description = "Builds the package"
    # Fixed: the original appended ``plugins_user_options``, a name that is
    # not defined anywhere in this file and raised NameError at import time.
    user_options = install.user_options

    def initialize_options(self):
        install.initialize_options(self)

    def finalize_options(self):
        install.finalize_options(self)

    def run(self):
        # Generate version.py and copy the native library before the
        # standard install runs.
        gen_version_file()
        copy_libtrtorch()
        install.run(self)
|
||
class CleanCommand(Command):
    """Custom clean command to tidy up the project root."""

    # Globs are resolved relative to dir_path; entries may match files
    # (*.pyc, *.tgz) as well as directories (build/, dist/, ...).
    PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './*.pyc', './*.tgz', './*.egg-info']
    description = "Command to tidy up the project root"
    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        for path_spec in self.PY_CLEAN_FILES:
            # Make paths absolute and relative to this path
            abs_paths = glob.glob(os.path.normpath(os.path.join(dir_path, path_spec)))
            for path in abs_paths:
                # Fixed: the original compared against an undefined name
                # ``root_dir`` (NameError); the intended root is dir_path.
                if not path.startswith(dir_path):
                    # Die if path in CLEAN_FILES is absolute + outside this directory
                    raise ValueError("%s is not a path inside %s" % (path, dir_path))
                print('Removing %s' % os.path.relpath(path))
                # Fixed: shutil.rmtree only handles directories; the *.pyc
                # and *.tgz globs match plain files, which need os.remove.
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
|
||
# C++ extension binding the TRTorch compiler into Python as trtorch._C.
ext_modules = [
    cpp_extension.CUDAExtension('trtorch._C',
        ['trtorch/csrc/trtorch_py.cpp'],
        # Fixed: library_dirs takes directories only; the stray path to the
        # .so file itself was dropped.
        library_dirs=[
            dir_path + '/trtorch/lib/'
        ],
        libraries=[
            "trtorch"
        ],
        include_dirs=[
            dir_path + "/../",
        ],
        # Match the pre-cxx11 ABI of the bundled libtrtorch build.
        extra_compile_args=[
            "-D_GLIBCXX_USE_CXX11_ABI=0"
        ],
        extra_link_args=[
            # Fixed: a missing comma after the first flag fused it with
            # "-Wl,--no-as-needed" via implicit string concatenation,
            # producing one bogus linker argument.
            "-D_GLIBCXX_USE_CXX11_ABI=0",
            "-Wl,--no-as-needed",
            "-ltrtorch"
        ],
        undef_macros=["NDEBUG"]
    )
]
|
||
setup(
    name='trtorch',
    version=__version__,
    author='NVIDIA Corporation.',
    author_email='[email protected]',
    # Fixed: the URL was garbled to a mirror host (github.com);
    # the canonical project URL is github.com.
    url='https://github.com/nvidia/trtorch',
    description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
    long_description='',
    ext_modules=ext_modules,
    install_requires=['pybind11>=2.4'],
    setup_requires=['pybind11>=2.4'],
    cmdclass={
        'install': InstallCommand,
        'clean': CleanCommand,
        'develop': DevelopCommand,
        'build_ext': cpp_extension.BuildExtension
    },
    zip_safe=False,
    license="BSD-3",
    packages=find_packages(),
    # NOTE: classifiers must exactly match the canonical PyPI trove list.
    classifiers=["Intended Audience :: Developers",
                 "Intended Audience :: Science/Research",
                 "Operating System :: POSIX :: Linux",
                 "Programming Language :: C++",
                 "Programming Language :: Python",
                 "Programming Language :: Python :: Implementation :: CPython",
                 "Topic :: Scientific/Engineering",
                 # Fixed typos in two classifiers: "Artifical" -> "Artificial"
                 # and "Developement" -> "Development" (invalid trove names).
                 "Topic :: Scientific/Engineering :: Artificial Intelligence",
                 "Topic :: Software Development",
                 "Topic :: Software Development :: Libraries"],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import sys | ||
|
||
# TRTorch supports Python 3 only; fail fast at import time with a clear
# message instead of an obscure syntax/runtime error later.
if sys.version_info < (3,):
    raise Exception("Python 2 has reached end-of-life and is not supported by TRTorch")
|
||
import ctypes | ||
import torch | ||
|
||
def _load_trtorch_lib():
    """Load the bundled libtrtorch.so into the process.

    RTLD_GLOBAL makes its symbols visible so the ``_C`` extension module can
    resolve them when it is imported.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    full_path = os.path.join(package_dir, 'lib', 'libtrtorch.so')
    ctypes.CDLL(full_path, mode=ctypes.RTLD_GLOBAL)
|
||
# Eagerly load the native library at package-import time, before any
# submodule that links against it is imported.
_load_trtorch_lib()
|
||
from .version import __version__ | ||
#from trtorch import _C | ||
from trtorch.compiler import * | ||
from trtorch.types import * | ||
|
||
def test(mod, data):
    """Run the native smoke test on *mod* with *data*.

    NOTE(review): looks like a temporary debugging hook — confirm before
    shipping as public API.
    """
    # Fixed: the module-level ``from trtorch import _C`` import (above) is
    # commented out, so ``_C`` was a NameError here; import it locally.
    from trtorch import _C
    _C._test(mod._c, data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from typing import List, Dict, Any | ||
import torch | ||
import tensorrt as trt | ||
import trtorch._C | ||
from trtorch import types | ||
from .version import __version__ | ||
|
||
def _supported_input_size_type(input_size: Any) -> bool: | ||
if isinstance(input_size, torch.Size): | ||
return True | ||
elif isinstance(input_size, tuple): | ||
return True | ||
elif isinstance(input_size, list): | ||
return True | ||
else: | ||
raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size))) | ||
|
||
def _parse_input_sizes(input_sizes: List) -> List:
    """Convert user-supplied input sizes into internal input-range objects.

    Each entry is either a static size (list, tuple or torch.Size) or a dict
    describing a dynamic range with "min", "opt" and "max" keys (or a single
    "opt" key, which is expanded to a degenerate min == opt == max range).

    Raises:
        KeyError: if an entry is a dict missing the required keys.
        TypeError: if an entry is not a dict or a supported size container
            (raised by _supported_input_size_type).
    """
    if any(not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):
        raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")

    parsed_input_sizes = []
    for i in input_sizes:
        if isinstance(i, dict):
            # Fixed: the key list previously checked "min" twice and never
            # "max", so {"min", "opt"} dicts slipped into this branch and
            # raised KeyError on i["max"].
            if all(k in i for k in ["min", "opt", "max"]):
                in_range = trtorch._C.InputRange()
                in_range.min = i["min"]
                in_range.opt = i["opt"]
                in_range.max = i["max"]
                parsed_input_sizes.append(in_range.to_internal_input_range())
            elif "opt" in i:
                in_range = trtorch._C.InputRange()
                in_range.min = i["opt"]
                in_range.opt = i["opt"]
                in_range.max = i["opt"]
                parsed_input_sizes.append(in_range.to_internal_input_range())
            else:
                raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
        else:
            # Fixed: only ``list`` was handled here, so tuple / torch.Size
            # entries passed validation but were silently dropped from the
            # result. Normalize every static size container to a list.
            static_size = list(i)
            in_range = trtorch._C.InputRange()
            in_range.min = static_size
            in_range.opt = static_size
            in_range.max = static_size
            parsed_input_sizes.append(in_range.to_internal_input_range())

    return parsed_input_sizes
|
||
def _parse_op_precision(precision: Any) -> types.dtype:
    """Map a torch.dtype or trtorch dtype to the internal operating-precision enum."""
    if isinstance(precision, torch.dtype):
        # Dispatch table instead of an if/elif chain; unsupported dtypes
        # fall through to the same TypeError as before.
        torch_to_internal = {
            torch.int8: types.dtype.int8,
            torch.half: types.dtype.half,
            torch.float: types.dtype.float,
        }
        if precision in torch_to_internal:
            return torch_to_internal[precision]
        raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))

    if isinstance(precision, types.DataTypes):
        return precision

    raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))
|
||
def _parse_device_type(device: Any) -> types.DeviceType:
    """Map a torch.device or trtorch DeviceType to the internal device enum.

    Raises:
        TypeError: if *device* is a non-CUDA torch.device, or neither a
            torch.device nor a types.DeviceType.
    """
    if isinstance(device, torch.device):
        # Fixed: the original compared the class attribute
        # ``torch.device.type`` (a descriptor) instead of this instance's
        # ``device.type``, so every torch.device — including CUDA ones —
        # fell into the error branch.
        if device.type == 'cuda':
            return types.DeviceType.gpu
        else:
            raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))

    elif isinstance(device, types.DeviceType):
        return device

    else:
        raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
|
||
def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
    """Translate the user-facing extra_info dict into the internal _ExtraInfo struct.

    "input_shapes" is required; every other key is optional and only copied
    when present.

    Raises:
        KeyError: if "input_shapes" is missing or not a list.
        AssertionError: if an optional value has the wrong type.
    """
    info = trtorch._C._ExtraInfo()
    # Fixed: the original used ``and`` so a present-but-non-list value passed
    # validation, and when the key was missing the isinstance operand still
    # indexed the dict, raising a raw KeyError before this message.
    if "input_shapes" not in extra_info or not isinstance(extra_info["input_shapes"], list):
        raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")

    info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])

    if "op_precision" in extra_info:
        info.op_precision = _parse_op_precision(extra_info["op_precision"])

    if "refit" in extra_info:
        assert isinstance(extra_info["refit"], bool)
        info.refit = extra_info["refit"]

    if "debug" in extra_info:
        assert isinstance(extra_info["debug"], bool)
        info.debug = extra_info["debug"]

    if "strict_types" in extra_info:
        assert isinstance(extra_info["strict_types"], bool)
        info.strict_types = extra_info["strict_types"]

    if "allow_gpu_fallback" in extra_info:
        assert isinstance(extra_info["allow_gpu_fallback"], bool)
        info.allow_gpu_fallback = extra_info["allow_gpu_fallback"]

    if "device" in extra_info:
        info.device = _parse_device_type(extra_info["device"])

    if "capability" in extra_info:
        # Fixed: ``type.EngineCapability`` referenced the builtin ``type``
        # (AttributeError); the enum lives in the imported ``types`` module.
        assert isinstance(extra_info["capability"], types.EngineCapability)
        info.capability = extra_info["capability"]

    if "num_min_timing_iters" in extra_info:
        assert type(extra_info["num_min_timing_iters"]) is int
        info.num_min_timing_iters = extra_info["num_min_timing_iters"]

    if "num_avg_timing_iters" in extra_info:
        assert type(extra_info["num_avg_timing_iters"]) is int
        info.num_avg_timing_iters = extra_info["num_avg_timing_iters"]

    if "workspace_size" in extra_info:
        assert type(extra_info["workspace_size"]) is int
        info.workspace_size = extra_info["workspace_size"]

    if "max_batch_size" in extra_info:
        assert type(extra_info["max_batch_size"]) is int
        info.max_batch_size = extra_info["max_batch_size"]

    return info
|
||
def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
    # NOTE(review): placeholder — currently returns the input module
    # unchanged; compilation is not yet wired up to the C extension.
    return module
|
||
def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
    """Convert *method_name* of *module* into a serialized TensorRT engine.

    *extra_info* is parsed by _parse_extra_info; see that function for the
    accepted keys. Returns the serialized engine as a string.
    """
    return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
|
||
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
    """Return whether every op in *method_name* of *module* is supported by the native backend."""
    return trtorch._C._check_method_op_support(module._c, method_name)
|
||
def dump_build_info():
    """Print the TRTorch version and native build information to stdout."""
    print(get_build_info())
|
||
def get_build_info() -> str:
    """Return a human-readable string with the TRTorch version and native build info."""
    native_info = trtorch._C._get_build_info()
    return "TRTorch Version: " + str(__version__) + '\n' + native_info
|
Oops, something went wrong.