From c67cad53167d8e4f4bb5e07f2fbec68ecaeb0b65 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 24 Jun 2024 14:28:24 -0700 Subject: [PATCH 1/2] empty tensor moving to default device --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b96d912897..4b99c7f6e4 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: perm = [0] * len(empty_size) for permute_index, permute_element in enumerate(empty_permute): perm[permute_element] = permute_index + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + kwargs["device"] = default_device return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) @@ -233,7 +235,10 @@ def select_scatter_decomposition( def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: empty_size = args[0] empty_stride = args[1] - return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride) + default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.as_strided( + torch.empty(empty_size, device=default_device), empty_size, empty_stride + ) def get_decompositions( From c24d169d136caa9a04e64d5e7d698268f8860f3c Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 24 Jun 2024 17:12:47 -0700 Subject: [PATCH 2/2] addressing review comments --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 4b99c7f6e4..8ec5a95da2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -4,7 +4,9 @@ import torch from torch._decomp import register_decomposition from torch._ops import OpOverload +from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim +from torch_tensorrt.dynamo.utils import to_torch_device from ._decomposition_groups import ( ENABLED_TORCH_DECOMPOSITIONS, @@ -172,8 +174,7 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: perm = [0] * len(empty_size) for permute_index, permute_element in enumerate(empty_permute): perm[permute_element] = permute_index - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - kwargs["device"] = default_device + kwargs["device"] = to_torch_device(default_device()) return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) @@ -235,9 +236,10 @@ def select_scatter_decomposition( def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: empty_size = args[0] empty_stride = args[1] - default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.as_strided( - torch.empty(empty_size, device=default_device), empty_size, empty_stride + torch.empty(empty_size, device=to_torch_device(default_device())), + empty_size, + empty_stride, )