Skip to content

Commit

Permalink
fix(//py): Fix trtorch.Device alternate contructor options
Browse files Browse the repository at this point in the history
There were issues setting fields of trtorch.Device via
kwargs, this patch should resolve those and verify that they
work

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Sep 24, 2021
1 parent 0a39189 commit fa08311
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
19 changes: 12 additions & 7 deletions py/trtorch/Device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from trtorch import _types
import logging
import trtorch.logging
import trtorch._C

import warnings
Expand Down Expand Up @@ -54,23 +54,27 @@ def __init__(self, *args, **kwargs):
else:
self.dla_core = id
self.gpu_id = 0
logging.log(logging.log.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
trtorch.logging.log(trtorch.logging.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")

elif len(args) == 0:
if not "gpu_id" in kwargs or not "dla_core" in kwargs:
if "gpu_id" in kwargs or "dla_core" in kwargs:
if "dla_core" in kwargs:
self.device_type = _types.DeviceType.DLA
self.dla_core = kwargs["dla_core"]
if "gpu_id" in kwargs:
self.gpu_id = kwargs["gpu_id"]
else:
self.gpu_id = 0
logging.log(logging.log.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
trtorch.logging.log(trtorch.logging.Level.Warning,
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
else:
self.gpu_id = kwargs["gpu_id"]
self.device_type == _types.DeviceType.GPU
self.device_type = _types.DeviceType.GPU
else:
raise ValueError(
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
)

else:
raise ValueError(
Expand All @@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs):
if "allow_gpu_fallback" in kwargs:
if not isinstance(kwargs["allow_gpu_fallback"], bool):
raise TypeError("allow_gpu_fallback must be a bool")
self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]

def __str__(self) -> str:
return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \
Expand Down
48 changes: 48 additions & 0 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,53 @@ def test_is_colored_output_on(self):
self.assertTrue(color)


class TestDevice(unittest.TestCase):

def test_from_string_constructor(self):
device = trtorch.Device("cuda:0")
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
self.assertEqual(device.gpu_id, 0)

device = trtorch.Device("gpu:1")
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
self.assertEqual(device.gpu_id, 1)

def test_from_string_constructor_dla(self):
device = trtorch.Device("dla:0")
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
self.assertEqual(device.gpu_id, 0)
self.assertEqual(device.dla_core, 0)

device = trtorch.Device("dla:1", allow_gpu_fallback=True)
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
self.assertEqual(device.gpu_id, 0)
self.assertEqual(device.dla_core, 1)
self.assertEqual(device.allow_gpu_fallback, True)

def test_kwargs_gpu(self):
device = trtorch.Device(gpu_id=0)
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
self.assertEqual(device.gpu_id, 0)

def test_kwargs_dla_and_settings(self):
device = trtorch.Device(dla_core=1, allow_gpu_fallback=False)
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
self.assertEqual(device.gpu_id, 0)
self.assertEqual(device.dla_core, 1)
self.assertEqual(device.allow_gpu_fallback, False)

device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
self.assertEqual(device.gpu_id, 1)
self.assertEqual(device.dla_core, 0)
self.assertEqual(device.allow_gpu_fallback, True)

def test_from_torch(self):
device = trtorch.Device._from_torch_device(torch.device("cuda:0"))
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
self.assertEqual(device.gpu_id, 0)


def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
Expand All @@ -231,6 +278,7 @@ def test_suite():
suite.addTest(
TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
suite.addTest(unittest.makeSuite(TestDevice))

return suite

Expand Down

0 comments on commit fa08311

Please sign in to comment.