diff --git a/py/trtorch/Device.py b/py/trtorch/Device.py index 41a1308518..4a9395415e 100644 --- a/py/trtorch/Device.py +++ b/py/trtorch/Device.py @@ -107,7 +107,11 @@ def _from_torch_device(cls, torch_dev: torch.device): @classmethod def _current_device(cls): - dev = trtorch._C._get_current_device() + try: + dev = trtorch._C._get_current_device() + except RuntimeError: + trtorch.logging.log(trtorch.logging.Level.Error, "Cannot get current device") + return None return cls(gpu_id=dev.gpu_id) @staticmethod