Skip to content

Commit

Permalink
feat(//py): add user level device class in py for embed engine
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jul 22, 2021
1 parent df87de3 commit d99169f
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 30 deletions.
110 changes: 110 additions & 0 deletions py/trtorch/Device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch

from trtorch import _types
import logging
import trtorch._C

import warnings


class Device(object):
"""
Defines a device that can be used to specify target devices for engines
Attributes:
device_type (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""

device_type = None
gpu_id = -1
dla_core = -1
allow_gpu_fallback = False

def __init__(self, *args, **kwargs):
""" __init__ Method for trtorch.Device
Device accepts one of a few construction patterns
Args:
spec (str): String with device spec e.g. "dla:0" for dla, core_id 0
Keyword Arguments:
gpu_id (int): ID of target GPU (will get overrided if dla_core is specified to the GPU managing DLA). If specified, no positional arguments should be provided
dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
Examples:
- Device("gpu:1")
- Device("cuda:1")
- Device("dla:0", allow_gpu_fallback=True)
- Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
- Device(dla_core=0, allow_gpu_fallback=True)
- Device(gpu_id=1)
"""
if len(args) == 1:
if not isinstance(args[0], str):
raise TypeError("When specifying Device through positional argument, argument must be str")
else:
(self.device_type, id) = Device._parse_device_str(args[0])
if self.device_type == _types.DeviceType.GPU:
self.gpu_id = id
else:
self.dla_core = id
self.gpu_id = 0
logging.log(logging.log.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")

elif len(args) == 0:
if not "gpu_id" in kwargs or not "dla_core" in kwargs:
if "dla_core" in kwargs:
self.device_type = _types.DeviceType.DLA
self.dla_core = kwargs["dla_core"]
if "gpu_id" in kwargs:
self.gpu_id = kwargs["gpu_id"]
else:
self.gpu_id = 0
logging.log(logging.log.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
else:
self.gpu_id = kwargs["gpu_id"]
self.device_type == _types.DeviceType.GPU

else:
raise ValueError(
"Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments"
.format(len(args)))

if "allow_gpu_fallback" in kwargs:
if not isinstance(kwargs["allow_gpu_fallback"], bool):
raise TypeError("allow_gpu_fallback must be a bool")

def __str__(self) -> str:
return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \
+ ")" if self.device_type == _types.DeviceType.GPU else ", dla_core={}, allow_gpu_fallback={}".format(self.dla_core, self.allow_gpu_fallback)

def _to_internal(self) -> trtorch._C.Device:
internal_dev = trtorch._C.Device()
internal_dev.device_type = self.device_type
internal_dev.gpu_id = self.gpu_id
internal_dev.dla_core = self.dla_core
internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
return internal_dev

@classmethod
def _from_torch_device(cls, torch_dev: torch.device):
if torch_dev.type != 'cuda':
raise ValueError("Torch Device specs must have type \"cuda\"")
gpu_id = torch_dev.index
return cls(gpu_id=gpu_id)

@staticmethod
def _parse_device_str(s):
s = s.lower()
spec = s.split(':')
if spec[0] == "gpu" or spec[0] == "cuda":
return (_types.DeviceType.GPU, int(spec[1]))
elif spec[0] == "dla":
return (_types.DeviceType.DLA, int(spec[1]))
1 change: 1 addition & 0 deletions py/trtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from trtorch._types import *
from trtorch import logging
from trtorch.Input import Input
from trtorch.Device import Device


def _register_with_torch():
Expand Down
49 changes: 29 additions & 20 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import trtorch._C
from trtorch import _types
from trtorch.Input import Input
from trtorch.Device import Device

import warnings

Expand Down Expand Up @@ -101,27 +102,35 @@ def _parse_device_type(device: Any) -> _types.DeviceType:
str(type(device)))


def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
info = trtorch._C.Device()
if "device_type" not in device_info:
raise KeyError("Device type is required parameter")
def _parse_device(device_info: Any) -> trtorch._C.Device:
if isinstance(device_info, dict):
info = trtorch._C.Device()
if "device_type" not in device_info:
raise KeyError("Device type is required parameter")
else:
assert isinstance(device_info["device_type"], _types.DeviceType)
info.device_type = _parse_device_type(device_info["device_type"])

if "gpu_id" in device_info:
assert isinstance(device_info["gpu_id"], int)
info.gpu_id = device_info["gpu_id"]

if "dla_core" in device_info:
assert isinstance(device_info["dla_core"], int)
info.dla_core = device_info["dla_core"]

if "allow_gpu_fallback" in device_info:
assert isinstance(device_info["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = device_info["allow_gpu_fallback"]

return info
elif isinstance(device_info, Device):
return device_info._to_internal()
elif isinstance(device_info, torch.device):
return (Device._from_torch_device(device_info))._to_internal()
else:
assert isinstance(device_info["device_type"], _types.DeviceType)
info.device_type = _parse_device_type(device_info["device_type"])

if "gpu_id" in device_info:
assert isinstance(device_info["gpu_id"], int)
info.gpu_id = device_info["gpu_id"]

if "dla_core" in device_info:
assert isinstance(device_info["dla_core"], int)
info.dla_core = device_info["dla_core"]

if "allow_gpu_fallback" in device_info:
assert isinstance(device_info["allow_gpu_fallback"], bool)
info.allow_gpu_fallback = device_info["allow_gpu_fallback"]

return info
raise ValueError(
"Unsupported data for device specification. Expected either a dict, trtorch.Device or torch.Device")


def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
Expand Down
14 changes: 7 additions & 7 deletions py/trtorch/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import trtorch._C
from trtorch._compile_spec import _parse_compile_spec
from trtorch._version import __version__
from trtorch.Device import Device
from types import FunctionType


Expand Down Expand Up @@ -42,8 +43,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
"dla_core": 0, # (DLA only) Target dla core id to run engine
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"input_dtypes": [torch.float32] # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
"refit": false, # enable refit
"debug": false, # enable debuggable engine
"strict_types": false, # kernels should strictly run in operating precision
Expand All @@ -61,7 +61,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
}
}
Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
to select device type.
Expand Down Expand Up @@ -110,7 +110,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
"dla_core": 0, # (DLA only) Target dla core id to run engine
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"refit": false, # enable refit
Expand All @@ -123,7 +123,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
"max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
}
Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
to select device type.
Expand All @@ -137,7 +137,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))


def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
def embed_engine_in_new_module(serialized_engine: bytes, device: Device) -> torch.jit.ScriptModule:
"""Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
Takes a pre-built serialied TensorRT engine (as bytes) and embeds it within a TorchScript module.
Expand All @@ -153,7 +153,7 @@ def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModu
Returns:
torch.jit.ScriptModule: New TorchScript module with engine embedded
"""
cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine, device._to_internal())
return torch.jit._recursive.wrap_cpp_module(cpp_mod)


Expand Down
4 changes: 4 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
}
}

core::runtime::CudaDevice Device::toInternalRuntimeDevice() {
return core::runtime::CudaDevice(gpu_id, toTRTDeviceType(device_type));
}

std::string Device::to_str() {
std::stringstream ss;
std::string fallback = allow_gpu_fallback ? "True" : "False";
Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct Device : torch::CustomClassHolder {
ADD_FIELD_GET_SET(dla_core, int64_t);
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);

core::runtime::CudaDevice toInternalRuntimeDevice();
std::string to_str();
};

Expand Down
4 changes: 2 additions & 2 deletions py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
return core::CheckMethodOperatorSupport(module, method_name);
}

torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, core::runtime::CudaDevice& device) {
return core::EmbedEngineInNewModule(engine, device);
torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, Device& device) {
return core::EmbedEngineInNewModule(engine, device.toInternalRuntimeDevice());
}

std::string get_build_info() {
Expand Down
2 changes: 1 addition & 1 deletion tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_pt_to_trt_to_pt(self):
}

trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
self.assertTrue(same < 2e-3)

Expand Down

0 comments on commit d99169f

Please sign in to comment.