Skip to content

Commit

Permalink
fix small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Apr 16, 2024
1 parent 645513e commit 2d40091
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def embedding_bag(
weight,
)
embed = cast_trt_tensor(
ctx, embed, torch.float, f"{name}_cast_embed_to_fp16", target, source_ir
ctx, embed, torch.float, f"{name}_cast_embed_to_fp32", target, source_ir
)

# give weights to embedding
Expand Down
4 changes: 2 additions & 2 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def run_test(
):
ref_outputs = [ref_outputs]
for out, ref in zip(outputs, ref_outputs):
ref = ref.cpu() # to_dtype test has cases with gpu output
if not isinstance(ref, torch.Tensor):
ref = torch.tensor([ref])
ref = ref.cpu() # to_dtype test has cases with gpu output
if ref.dtype == torch.int64:
ref = ref.int() # convert torch.max's index output tensor to int32
torch.testing.assert_close(
out.cpu(),
ref.cpu(),
ref,
rtol=rtol,
atol=atol,
equal_nan=True,
Expand Down

0 comments on commit 2d40091

Please sign in to comment.