Skip to content

Commit

Permalink
empty tensor moving to default device (#2948)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored and cehongwang committed Jul 8, 2024
1 parent 0352a5f commit e2019c9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device

from ._decomposition_groups import (
ENABLED_TORCH_DECOMPOSITIONS,
Expand Down Expand Up @@ -172,6 +174,7 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
perm = [0] * len(empty_size)
for permute_index, permute_element in enumerate(empty_permute):
perm[permute_element] = permute_index
kwargs["device"] = to_torch_device(default_device())
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


Expand Down Expand Up @@ -233,7 +236,11 @@ def select_scatter_decomposition(
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
empty_size = args[0]
empty_stride = args[1]
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
return torch.as_strided(
torch.empty(empty_size, device=to_torch_device(default_device())),
empty_size,
empty_stride,
)


def get_decompositions(
Expand Down

0 comments on commit e2019c9

Please sign in to comment.