diff --git a/py/trtorch/Device.py b/py/trtorch/Device.py
new file mode 100644
index 0000000000..a435e95889
--- /dev/null
+++ b/py/trtorch/Device.py
@@ -0,0 +1,110 @@
+import torch
+
+from trtorch import _types
+import logging
+import trtorch._C
+
+import warnings
+
+
+class Device(object):
+    """
+    Defines a device that can be used to specify target devices for engines
+
+    Attributes:
+        device_type (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
+        gpu_id (int): Device ID for target GPU
+        dla_core (int): Core ID for target DLA core
+        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
+    """
+
+    device_type = None
+    gpu_id = -1
+    dla_core = -1
+    allow_gpu_fallback = False
+
+    def __init__(self, *args, **kwargs):
+        """ __init__ Method for trtorch.Device
+
+        Device accepts one of a few construction patterns
+
+        Args:
+            spec (str): String with device spec e.g. "dla:0" for DLA, core_id 0
+
+        Keyword Arguments:
+            gpu_id (int): ID of target GPU (will be overridden to the GPU managing DLA if dla_core is specified). If specified, no positional arguments should be provided
+            dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
+            allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
+
+        Examples:
+            - Device("gpu:1")
+            - Device("cuda:1")
+            - Device("dla:0", allow_gpu_fallback=True)
+            - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
+            - Device(dla_core=0, allow_gpu_fallback=True)
+            - Device(gpu_id=1)
+        """
+        if len(args) == 1:
+            if not isinstance(args[0], str):
+                raise TypeError("When specifying Device through positional argument, argument must be str")
+            else:
+                (self.device_type, id) = Device._parse_device_str(args[0])
+                if self.device_type == _types.DeviceType.GPU:
+                    self.gpu_id = id
+                else:
+                    self.dla_core = id
+                    self.gpu_id = 0
+                    logging.warning(
+                        "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+
+        elif len(args) == 0:
+            if "gpu_id" in kwargs or "dla_core" in kwargs:
+                if "dla_core" in kwargs:
+                    self.device_type = _types.DeviceType.DLA
+                    self.dla_core = kwargs["dla_core"]
+                    if "gpu_id" in kwargs:
+                        self.gpu_id = kwargs["gpu_id"]
+                    else:
+                        self.gpu_id = 0
+                        logging.warning(
+                            "Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+                else:
+                    self.gpu_id = kwargs["gpu_id"]
+                    self.device_type = _types.DeviceType.GPU
+            else:
+                raise ValueError("Either gpu_id or dla_core must be specified when no device string is provided")
+
+        else:
+            raise ValueError(
+                "Unexpected number of positional arguments for class Device \n    Found {} arguments, expected either zero or one positional argument"
+                .format(len(args)))
+
+        if "allow_gpu_fallback" in kwargs:
+            if not isinstance(kwargs["allow_gpu_fallback"], bool):
+                raise TypeError("allow_gpu_fallback must be a bool")
+            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
+
+    def __str__(self) -> str:
+        return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \
+            + (")" if self.device_type == _types.DeviceType.GPU else
+               ", dla_core={}, allow_gpu_fallback={})".format(self.dla_core, self.allow_gpu_fallback))
+
+    def _to_internal(self) -> trtorch._C.Device:
+        internal_dev = trtorch._C.Device()
+        internal_dev.device_type = self.device_type
+        internal_dev.gpu_id = self.gpu_id
+        internal_dev.dla_core = self.dla_core
+        internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
+        return internal_dev
+
+    @classmethod
+    def _from_torch_device(cls, torch_dev: torch.device):
+        if torch_dev.type != 'cuda':
+            raise ValueError("Torch Device specs must have type \"cuda\"")
+        gpu_id = torch_dev.index
+        return cls(gpu_id=gpu_id)
+
+    @staticmethod
+    def _parse_device_str(s):
+        s = s.lower()
+        spec = s.split(':')
+        if spec[0] == "gpu" or spec[0] == "cuda":
+            return (_types.DeviceType.GPU, int(spec[1]))
+        elif spec[0] == "dla":
+            return (_types.DeviceType.DLA, int(spec[1]))
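For orientation, a minimal usage sketch of the Device class introduced above, following the construction patterns listed in its docstring (not part of the patch itself):

    import trtorch

    gpu_dev = trtorch.Device("cuda:0")                             # GPU 0, from a device string
    dla_dev = trtorch.Device(dla_core=0, allow_gpu_fallback=True)  # DLA core 0; GPU 0 manages the DLA
    internal = dla_dev._to_internal()                              # trtorch._C.Device handed to the bindings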
diff --git a/py/trtorch/__init__.py b/py/trtorch/__init__.py
index a31aafbd1f..26cebd2475 100644
--- a/py/trtorch/__init__.py
+++ b/py/trtorch/__init__.py
@@ -14,6 +14,7 @@
 from trtorch._types import *
 from trtorch import logging
 from trtorch.Input import Input
+from trtorch.Device import Device
 
 
 def _register_with_torch():
diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index 63ebc288bc..c1a534d35b 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -3,6 +3,7 @@
 import trtorch._C
 from trtorch import _types
 from trtorch.Input import Input
+from trtorch.Device import Device
 
 import warnings
 
@@ -101,27 +102,35 @@ def _parse_device_type(device: Any) -> _types.DeviceType:
             str(type(device)))
 
 
-def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
-    info = trtorch._C.Device()
-    if "device_type" not in device_info:
-        raise KeyError("Device type is required parameter")
+def _parse_device(device_info: Any) -> trtorch._C.Device:
+    if isinstance(device_info, dict):
+        info = trtorch._C.Device()
+        if "device_type" not in device_info:
+            raise KeyError("Device type is required parameter")
+        else:
+            assert isinstance(device_info["device_type"], _types.DeviceType)
+            info.device_type = _parse_device_type(device_info["device_type"])
+
+        if "gpu_id" in device_info:
+            assert isinstance(device_info["gpu_id"], int)
+            info.gpu_id = device_info["gpu_id"]
+
+        if "dla_core" in device_info:
+            assert isinstance(device_info["dla_core"], int)
+            info.dla_core = device_info["dla_core"]
+
+        if "allow_gpu_fallback" in device_info:
+            assert isinstance(device_info["allow_gpu_fallback"], bool)
+            info.allow_gpu_fallback = device_info["allow_gpu_fallback"]
+
+        return info
+    elif isinstance(device_info, Device):
+        return device_info._to_internal()
+    elif isinstance(device_info, torch.device):
+        return (Device._from_torch_device(device_info))._to_internal()
     else:
-        assert isinstance(device_info["device_type"], _types.DeviceType)
-        info.device_type = _parse_device_type(device_info["device_type"])
-
-    if "gpu_id" in device_info:
-        assert isinstance(device_info["gpu_id"], int)
-        info.gpu_id = device_info["gpu_id"]
-
-    if "dla_core" in device_info:
-        assert isinstance(device_info["dla_core"], int)
-        info.dla_core = device_info["dla_core"]
-
-    if "allow_gpu_fallback" in device_info:
-        assert isinstance(device_info["allow_gpu_fallback"], bool)
-        info.allow_gpu_fallback = device_info["allow_gpu_fallback"]
-
-    return info
+        raise ValueError(
+            "Unsupported data for device specification. Expected either a dict, trtorch.Device or torch.Device")
 
 
 def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
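With this change, _parse_device accepts a dict, a trtorch.Device, or a torch.device. A rough sketch of the equivalent forms for the "device" entry of a compile spec (key names follow the docstrings in _compiler.py below; treat as illustrative):

    import torch
    import trtorch

    dev_as_dict  = {"device_type": trtorch.DeviceType.GPU, "gpu_id": 0}  # dict form
    dev_as_class = trtorch.Device("cuda:0")                              # trtorch.Device form
    dev_as_torch = torch.device("cuda:0")                                # torch.device form
    # e.g. trtorch.compile(script_module, {..., "device": dev_as_class})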
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index dc4a93f1ac..0b769b6c0a 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -5,6 +5,7 @@
 import trtorch._C
 from trtorch._compile_spec import _parse_compile_spec
 from trtorch._version import __version__
+from trtorch.Device import Device
 from types import FunctionType
 
 
@@ -42,8 +43,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
                     "dla_core": 0, # (DLA only) Target dla core id to run engine
                     "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
                 },
-                "op_precision": torch.half, # Operating precision set to FP16
-                "input_dtypes": [torch.float32] # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
+                "enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
                 "refit": false, # enable refit
                 "debug": false, # enable debuggable engine
                 "strict_types": false, # kernels should strictly run in operating precision
@@ -61,7 +61,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
                 }
             }
 
-    Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+    Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
     torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
     to select device type.
 
@@ -110,7 +110,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
                     "dla_core": 0, # (DLA only) Target dla core id to run engine
                     "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
                 },
-                "op_precision": torch.half, # Operating precision set to FP16
+                "enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
                 # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
                 "disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
                 "refit": false, # enable refit
@@ -123,7 +123,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
                 "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
             }
 
-    Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+    Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
     torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
     to select device type.
 
@@ -137,7 +137,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
-def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+def embed_engine_in_new_module(serialized_engine: bytes, device: Device) -> torch.jit.ScriptModule:
     """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
 
     Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
@@ -153,7 +153,7 @@ def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModu
     Returns:
         torch.jit.ScriptModule: New TorchScript module with engine embedded
     """
-    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine, device._to_internal())
     return torch.jit._recursive.wrap_cpp_module(cpp_mod)
 
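The device argument to embed_engine_in_new_module is now required. A short sketch of the updated call, mirroring the test change at the end of this patch:

    import trtorch

    # trt_engine: bytes returned by trtorch.convert_method_to_trt_engine(...)
    trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
    # the embedded engine is now associated with GPU 0 at runtime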
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index 19b921dc8a..3e668a5086 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -115,6 +115,10 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
   }
 }
 
+core::runtime::CudaDevice Device::toInternalRuntimeDevice() {
+  return core::runtime::CudaDevice(gpu_id, toTRTDeviceType(device_type));
+}
+
 std::string Device::to_str() {
   std::stringstream ss;
   std::string fallback = allow_gpu_fallback ? "True" : "False";
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index f5009fa777..6b3674c527 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -79,6 +79,7 @@ struct Device : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(dla_core, int64_t);
   ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
 
+  core::runtime::CudaDevice toInternalRuntimeDevice();
   std::string to_str();
 };
 
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index 52fb2031b2..5cee8952f8 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -119,8 +119,8 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
-torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, core::runtime::CudaDevice& device) {
-  return core::EmbedEngineInNewModule(engine, device);
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, Device& device) {
+  return core::EmbedEngineInNewModule(engine, device.toInternalRuntimeDevice());
 }
 
 std::string get_build_info() {
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index 42ecf6316a..b7ebc57389 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -162,7 +162,7 @@ def test_pt_to_trt_to_pt(self):
         }
 
         trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
-        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
         same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
         self.assertTrue(same < 2e-3)