Skip to content

Commit

Permalink
rebase
Browse files: browse the repository at this point in the history
  • Loading branch information
zewenli98 committed May 1, 2024
1 parent a95a217 commit 1f58d47
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name

import tensorrt as trt
from torch_tensorrt.fx.types import TRTTensor


def embedding(
Expand All @@ -31,6 +30,10 @@ def embedding(
) -> TRTTensor:
indices_tensor = input
embedding_tensor = weight
if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
raise RuntimeError(
"The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
)
indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
# unsupported parameters
Expand Down
6 changes: 2 additions & 4 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def run_test(
cuda_inputs.append(i.cuda())

mod.eval()
mod = mod.cuda()
start = time.perf_counter()
interpreter_result = interpreter.run()
sec = time.perf_counter() - start
Expand All @@ -73,6 +72,7 @@ def run_test(
interpreter_result.output_names,
)

mod = mod.cuda()
ref_outputs = mod(*cuda_inputs)

torch.cuda.synchronize()
Expand All @@ -96,11 +96,9 @@ def run_test(
):
ref_outputs = [ref_outputs]
for out, ref in zip(outputs, ref_outputs):
ref = ref.cpu() # to_dtype test has cases with gpu output
if not isinstance(ref, torch.Tensor):
ref = torch.tensor([ref])
if ref.dtype == torch.int64:
ref = ref.int() # convert torch.max's index output tensor to int32
ref = ref.cpu() # to_dtype test has cases with gpu output
torch.testing.assert_close(
out.cpu(),
ref,
Expand Down

0 comments on commit 1f58d47

Please sign in to comment.