
fix: bug in elementwise base for static inputs #2819

Merged · 2 commits · May 8, 2024

74 changes: 71 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
@@ -15,11 +16,12 @@
ConverterRegistry,
DynamoConverterImplSignature,
)
from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op
from torch_tensorrt.fx.converters.converter_utils import (
broadcast,
get_axes_for_reduce_op,
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -205,6 +207,72 @@ def broadcastable(
return True


def broadcast_to_same_shape(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
lhs_val: TRTTensor,
rhs_val: TRTTensor,
) -> Tuple[TRTTensor, TRTTensor]:
"""Broadcast ITensors `lhs_val` and `rhs_val` to the same shape. If the shapes are already the same, return the
original tensors. If the shapes are different, broadcast the tensors to the same shape.

This helper function is different from fx/converter_utils.broadcast.
fx/converter_utils.broadcast only broadcasts two ITensors to the same number of dimensions (ranks)
by prepending 1s, while this function broadcasts two ITensors to the same shape.

For example, we have original ITensors: lhs_val.shape: (2, 3) rhs_val.shape: (2, 2, 1, 3)
If calling fx/converter_utils.broadcast, lhs_val.shape: (1, 1, 2, 3) rhs_val.shape: (2, 2, 1, 3).
If calling this function broadcast_to_same_shape, lhs_val.shape: (2, 2, 2, 3) rhs_val.shape: (2, 2, 2, 3).

Args:
ctx (ConversionContext): TensorRT ConversionContext containing the network being built.
target (Target): Target of the fx node being converted.
source_ir (Optional[SourceIR]): Source IR of the node.
name (str): Name prefix for any layers created while broadcasting.
lhs_val (TRTTensor): A TensorRT ITensor.
rhs_val (TRTTensor): A TensorRT ITensor.

Returns:
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape

"""
lhs_val, rhs_val = broadcast(
ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
)

lhs_val_shape = lhs_val.shape
rhs_val_shape = rhs_val.shape

if tuple(lhs_val_shape) != tuple(rhs_val_shape):
rank = len(lhs_val_shape)
expanded_dims = [-1] * len(lhs_val_shape)

for dim in range(rank):
expanded_dims[dim] = max(lhs_val_shape[dim], rhs_val_shape[dim])

expanded_shape = tuple(expanded_dims)

if lhs_val_shape != expanded_shape:
lhs_val = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_lhs_val",
lhs_val,
expanded_shape,
)

if rhs_val_shape != expanded_shape:
rhs_val = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_rhs_val",
rhs_val,
expanded_shape,
)

return lhs_val, rhs_val


get_axes_for_reduce_op = functools.partial(
get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
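
For readers without a TensorRT environment, the shape arithmetic that broadcast_to_same_shape performs can be sketched in plain Python. This is a hypothetical standalone helper (broadcasted_static_shape is not part of the PR): pad the lower-rank shape with leading 1s, as fx's broadcast does, then take the per-dimension maximum, reproducing the docstring example where (2, 3) and (2, 2, 1, 3) both expand to (2, 2, 2, 3).

from typing import Sequence, Tuple

def broadcasted_static_shape(lhs: Sequence[int], rhs: Sequence[int]) -> Tuple[int, ...]:
    # Pad the lower-rank shape with leading 1s, mirroring fx's rank-only broadcast.
    rank = max(len(lhs), len(rhs))
    lhs_padded = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs_padded = (1,) * (rank - len(rhs)) + tuple(rhs)
    # Each output dimension is the max of the two padded dimensions
    # (assumes the shapes are broadcastable).
    return tuple(max(l, r) for l, r in zip(lhs_padded, rhs_padded))

assert broadcasted_static_shape((2, 3), (2, 2, 1, 3)) == (2, 2, 2, 3)
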
39 changes: 5 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -8,9 +8,9 @@
from torch.fx.node import Target
from torch_tensorrt import _enums
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcast_to_same_shape,
cast_trt_tensor,
get_trt_tensor,
)
@@ -152,41 +152,12 @@ def convert_binary_elementwise(

if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
lhs_val, rhs_val = broadcast(
ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
)
else:
lhs_val_shape = lhs_val.shape
rhs_val_shape = rhs_val.shape
rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
if rank_diff > 0:
rhs_val = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
)
elif rank_diff < 0:
lhs_val = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
)
else:
if tuple(lhs_val_shape) != tuple(rhs_val_shape):
sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
if sum_diff > 0:
rhs_val = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_rhs_val",
rhs_val,
lhs_val_shape,
)
elif sum_diff < 0:
lhs_val = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_lhs_val",
lhs_val,
rhs_val_shape,
)
lhs_val, rhs_val = broadcast_to_same_shape(
ctx, target, source_ir, f"{name}_broadcast_to_same_shape", lhs_val, rhs_val
)

layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
set_layer_name(layer, target, name, source_ir)
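
Why the removed branch was a bug for static inputs can be seen with the (2, 3) / (2, 1, 3) pair exercised by the new 3d_broadcast test cases below: the old code picked a single side to expand based on rank (and, for equal ranks, the sum of the dimensions), but here the common broadcast shape is (2, 2, 3), which neither input shape equals, so expanding one tensor to the other's shape cannot succeed. A plain-Python illustration, reusing the hypothetical broadcasted_static_shape helper from the sketch above:

lhs_shape, rhs_shape = (2, 3), (2, 1, 3)  # shapes from the new 3d_broadcast test case
# Old heuristic: rank(lhs) < rank(rhs), so lhs would be expanded to rhs_shape,
# i.e. (2, 1, 3) -- not a valid broadcast target for (2, 3).
# New behavior: broadcast_to_same_shape expands both sides to the common shape.
assert broadcasted_static_shape(lhs_shape, rhs_shape) == (2, 2, 3)
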
2 changes: 0 additions & 2 deletions tests/py/dynamo/conversion/test_binary_ops_aten.py
@@ -1,4 +1,3 @@
import unittest
from typing import Callable

import torch
@@ -59,7 +58,6 @@ def forward(self, x):
self.run_test(m, inputs)

@parameterized.expand([(op[0].__name__, op[0]) for op in elementwise_ops])
@unittest.skip("Pending reimplementation of all binary converters in Dynamo")
def test_elementwise_ops_mismatched_dtypes(self, name, orig_op: Callable):
class TestModule(nn.Module):
def __init__(self, orig_op):
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_and_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseAndConverter(DispatchTestCase):
@parameterized.expand(
[
("2d", (5, 3)),
("3d", (5, 3, 2)),
("2d", (2, 3), (2, 3)),
("3d", (5, 3, 2), (5, 3, 2)),
("3d_broadcast", (2, 3), (2, 1, 3)),
("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
]
)
def test_bitwise_and_tensor(self, _, shape):
def test_bitwise_and_tensor(self, _, lhs_shape, rhs_shape):
class bitwise_and(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

inputs = [
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, lhs_shape, dtype=bool),
torch.randint(0, 2, rhs_shape, dtype=bool),
]
self.run_test(
bitwise_and(),
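
As a quick sanity check (a sketch, not part of the PR), eager PyTorch broadcasts the new mismatched-shape cases to the same result shape the converter is now expected to produce:

import torch

lhs = torch.randint(0, 2, (2, 3), dtype=torch.bool)
rhs = torch.randint(0, 2, (2, 1, 3), dtype=torch.bool)
out = torch.ops.aten.bitwise_and.Tensor(lhs, rhs)
assert out.shape == (2, 2, 3)
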
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_or_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseOrConverter(DispatchTestCase):
@parameterized.expand(
[
("2d", (5, 3)),
("3d", (5, 3, 2)),
("2d", (2, 3), (2, 3)),
("3d", (5, 3, 2), (5, 3, 2)),
("3d_broadcast", (2, 3), (2, 1, 3)),
("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
]
)
def test_bitwise_or_tensor(self, _, shape):
def test_bitwise_or_tensor(self, _, lhs_shape, rhs_shape):
class bitwise_or(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)

inputs = [
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, lhs_shape, dtype=bool),
torch.randint(0, 2, rhs_shape, dtype=bool),
]
self.run_test(
bitwise_or(),
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_xor_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseXorConverter(DispatchTestCase):
@parameterized.expand(
[
("2d", (5, 3)),
("3d", (5, 3, 2)),
("2d", (2, 3), (2, 3)),
("3d", (5, 3, 2), (5, 3, 2)),
("3d_broadcast", (2, 3), (2, 1, 3)),
("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
]
)
def test_bitwise_xor_tensor(self, _, shape):
def test_bitwise_xor_tensor(self, _, lhs_shape, rhs_shape):
class bitwise_xor(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val)

inputs = [
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, shape, dtype=bool),
torch.randint(0, 2, lhs_shape, dtype=bool),
torch.randint(0, 2, rhs_shape, dtype=bool),
]
self.run_test(
bitwise_xor(),