Skip to content

Commit

Permalink
feat(//py): Inital introduction of the Python API
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 14, 2020
1 parent 83e0ed6 commit 7088245
Show file tree
Hide file tree
Showing 8 changed files with 523 additions and 0 deletions.
Empty file added py/BUILD
Empty file.
Empty file added py/requirements.txt
Empty file.
138 changes: 138 additions & 0 deletions py/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
import sys
import setuptools
import os
from torch.utils import cpp_extension
from shutil import copyfile

dir_path = os.path.dirname(os.path.realpath(__file__))

__version__ = '0.0.1'

def gen_version_file():
if not os.path.exists(dir_path + '/trtorch/version.py'):
os.mknod(dir_path + '/trtorch/version.py')

with open(dir_path + '/trtorch/version.py', 'w') as f:
print("creating version file")
f.write("__version__ = \"" + __version__ + '\"')

def copy_libtrtorch():
if not os.path.exists(dir_path + '/trtorch/lib'):
os.makedirs(dir_path + '/trtorch/lib')

print("copying library into module")
copyfile(dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')

class DevelopCommand(develop):
description = "Builds the package and symlinks it into the PYTHONPATH"
user_options = develop.user_options + plugins_user_options

def initialize_options(self):
develop.initialize_options(self)

def finalize_options(self):
develop.finalize_options(self)

def run(self):
gen_version_file()
copy_libtrtorch()
develop.run(self)


class InstallCommand(install):
description = "Builds the package"
user_options = install.user_options + plugins_user_options

def initialize_options(self):
install.initialize_options(self)

def finalize_options(self):
install.finalize_options(self)

def run(self):
gen_version_file()
copy_libtrtorch()
install.run(self)

class CleanCommand(Command):
"""Custom clean command to tidy up the project root."""
PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './*.pyc', './*.tgz', './*.egg-info']
description = "Command to tidy up the project root"
user_options = []

def initialize_options(self):
pass

def finalize_options(self):
pass

def run(self):
for path_spec in self.PY_CLEAN_FILES:
# Make paths absolute and relative to this path
abs_paths = glob.glob(os.path.normpath(os.path.join(dir_path, path_spec)))
for path in [str(p) for p in abs_paths]:
if not path.startswith(root_dir):
# Die if path in CLEAN_FILES is absolute + outside this directory
raise ValueError("%s is not a path inside %s" % (path, root_dir))
print('Removing %s' % os.path.relpath(path))
shutil.rmtree(path)

ext_modules = [
cpp_extension.CUDAExtension('trtorch._C',
['trtorch/csrc/trtorch_py.cpp'],
library_dirs=[
dir_path + '/trtorch/lib/libtrtorch.so',
dir_path + '/trtorch/lib/'
],
libraries=[
"trtorch"
],
include_dirs=[
dir_path + "/../",
],
extra_compile_args=[
"-D_GLIBCXX_USE_CXX11_ABI=0"
],
extra_link_args=[
"-D_GLIBCXX_USE_CXX11_ABI=0"
"-Wl,--no-as-needed",
"-ltrtorch"
],
undef_macros=[ "NDEBUG" ]
)
]

setup(
name='trtorch',
version=__version__,
author='NVIDIA Corporation.',
author_email='[email protected]',
url='https://github.com/nvidia/trtorch',
description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
long_description='',
ext_modules=ext_modules,
install_requires=['pybind11>=2.4'],
setup_requires=['pybind11>=2.4'],
cmdclass={
'install': InstallCommand,
'clean': CleanCommand,
'develop': DevelopCommand,
'build_ext': cpp_extension.BuildExtension
},
zip_safe=False,
license="BSD-3",
packages=find_packages(),
classifiers=["Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"Programming Language :: C++",
"Programming Language :: Python",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artifical Intelligence",
"Topic :: Software Development",
"Topic :: Software Developement :: Libraries"],

)
24 changes: 24 additions & 0 deletions py/trtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import sys

if sys.version_info < (3,):
raise Exception("Python 2 has reached end-of-life and is not supported by TRTorch")

import ctypes
import torch

def _load_trtorch_lib():
lib_name = 'libtrtorch.so'
here = os.path.abspath(__file__)
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)

_load_trtorch_lib()

from .version import __version__
#from trtorch import _C
from trtorch.compiler import *
from trtorch.types import *

def test(mod, data):
_C._test(mod._c, data)
153 changes: 153 additions & 0 deletions py/trtorch/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from typing import List, Dict, Any
import torch
import tensorrt as trt
import trtorch._C
from trtorch import types
from .version import __version__

def _supported_input_size_type(input_size: Any) -> bool:
if isinstance(input_size, torch.Size):
return True
elif isinstance(input_size, tuple):
return True
elif isinstance(input_size, list):
return True
else:
raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size)))

def _parse_input_sizes(input_sizes: List) -> List:

if any (not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")

parsed_input_sizes = []
for i in input_sizes:
if isinstance(i, dict):
if all (k in i for k in ["min", "opt", "min"]):
in_range = trtorch._C.InputRange()
in_range.min = i["min"]
in_range.opt = i["opt"]
in_range.max = i["max"]

parsed_input_sizes.append(in_range.to_internal_input_range())

elif "opt" in i:
in_range = trtorch._C.InputRange()
in_range.min = i["opt"]
in_range.opt = i["opt"]
in_range.max = i["opt"]

parsed_input_sizes.append(in_range.to_internal_input_range())

else:
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")

elif isinstance(i, list):
in_range = trtorch._C.InputRange()
in_range.min = i
in_range.opt = i
in_range.max = i

parsed_input_sizes.append(in_range.to_internal_input_range())

return parsed_input_sizes

def _parse_op_precision(precision: Any) -> types.dtype:
if isinstance(precision, torch.dtype):
if precision == torch.int8:
return types.dtype.int8
elif precision == torch.half:
return types.dtype.half
elif precision == torch.float:
return types.dtype.float
else:
raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))

elif isinstance(precision, types.DataTypes):
return precision

else:
raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))

def _parse_device_type(device: Any) -> types.DeviceType:
if isinstance(device, torch.device):
if torch.device.type == 'cuda':
return types.DeviceType.gpu
else:
raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))

elif isinstance(device, types.DeviceType):
return device

else:
raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))

def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
info = trtorch._C._ExtraInfo()
if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):
raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")

info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])

if "op_precision" in extra_info:
info.op_precision = _parse_op_precision(extra_info["op_precision"])

if "refit" in extra_info:
assert isinstance(extra_info["refit"], bool)
info.refit = extra_info["refit"]

if "debug" in extra_info:
assert isinstance(extra_info["debug"], bool)
info.debug = extra_info["debug"]

if "strict_types" in extra_info:
assert isinstance(extra_info["strict_types"], bool)
info.strict_types = extra_info["strict_types"]

if "allow_gpu_fallback" in extra_info:
assert isinstance(extra_info["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = extra_info["allow_gpu_fallback"]

if "device" in extra_info:
info.device = _parse_device_type(extra_info["device"])

if "capability" in extra_info:
assert isinstance(extra_info["capability"], type.EngineCapability)
info.capability = extra_info["capability"]


if "num_min_timing_iters" in extra_info:
assert type(extra_info["num_min_timing_iters"]) is int
info.num_min_timing_iters = extra_info["num_min_timing_iters"]

if "num_avg_timing_iters" in extra_info:
assert type(extra_info["num_avg_timing_iters"]) is int
info.num_avg_timing_iters = extra_info["num_avg_timing_iters"]

if "workspace_size" in extra_info:
assert type(extra_info["workspace_size"]) is int
info.workspace_size = extra_info["workspace_size"]

if "max_batch_size" in extra_info:
assert type(extra_info["max_batch_size"]) is int
info.max_batch_size = extra_info["max_batch_size"]

return info

def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
return module

def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))

def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
return trtorch._C._check_method_op_support(module._c, method_name)

def dump_build_info():
print(get_build_info())

def get_build_info() -> str:
build_info = trtorch._C._get_build_info()
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
return build_info

Loading

0 comments on commit 7088245

Please sign in to comment.