Skip to content

Commit

Permalink
empty tensor moving to default device
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Jun 24, 2024
1 parent 6aa439b commit 7f8bb4f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
perm = [0] * len(empty_size)
for permute_index, permute_element in enumerate(empty_permute):
perm[permute_element] = permute_index
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs[device] = default_device
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


Expand Down Expand Up @@ -233,7 +235,10 @@ def select_scatter_decomposition(
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
empty_size = args[0]
empty_stride = args[1]
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.as_strided(
torch.empty(empty_size, device=default_device), empty_size, empty_stride
)


def get_decompositions(
Expand Down

0 comments on commit 7f8bb4f

Please sign in to comment.