[FX] Changes done internally at Facebook #1172

Merged 1 commit on Jul 11, 2022
2 changes: 1 addition & 1 deletion examples/fx/quantized_resnet_test.py
@@ -49,7 +49,7 @@ def build_int8_trt(rn18):
# uncomment to check per channel quant works
weight=torch.quantization.default_per_channel_weight_observer,
)
- prepared = prepare_fx(rn18, {"": qconfig})
+ prepared = prepare_fx(rn18, {"": qconfig}, data)
for _ in range(10):
prepared(data)
quantized_rn18 = convert_to_reference(prepared)
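For context, the new third argument matches FX graph mode quantization's expectation of example inputs at prepare time, so the model can be traced with concrete shapes. A minimal sketch of the same call pattern, assuming a torch build whose prepare_fx accepts example inputs and using a stock qconfig rather than the per-channel observer configured above:

import torch
import torchvision.models as models
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

rn18 = models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)
qconfig = get_default_qconfig("fbgemm")

# `data` doubles as the example input for tracing and as calibration data,
# mirroring the loop in the example file above.
prepared = prepare_fx(rn18, {"": qconfig}, data)
for _ in range(10):
    prepared(data)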
5 changes: 4 additions & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -104,10 +104,13 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
),
self.lower_setting.opt_profile_replica,
)
- if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch
+ if self.lower_setting.explicit_batch_dimension
+ and self.lower_setting.dynamic_batch
else InputTensorSpec.from_tensors(input)
)
)
+ logger.info(f"{split_name=} {input_specs_val=}")

# Prepare algorithm selector and timing_cache for TRTInterpreter
algo_selector = None
if self.lower_setting.algo_selector:
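The added log line uses Python 3.8's self-documenting f-string form, where {name=} expands to the variable name followed by its repr. A standalone illustration with made-up values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

split_name = "_run_on_acc_0"          # hypothetical split name
input_specs_val = [(-1, -1, -1, -1)]  # hypothetical spec summary
# Logs: split_name='_run_on_acc_0' input_specs_val=[(-1, -1, -1, -1)]
logger.info(f"{split_name=} {input_specs_val=}")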
6 changes: 1 addition & 5 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -64,11 +64,7 @@ class LowerSetting(LowerSettingBasic):
cache file is provided.
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
preset_lowerer (str): when specified, use a preset logic to build the
- instance of Lowerer. Refer to
- `caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
- how presets are applied. Refer to
- `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
- to add a preset.
+ instance of Lowerer.
opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
only used by explicit batch dim with dynamic shape mode.
dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
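The options described in this docstring are the same fields the lowering code above reads (explicit_batch_dimension, dynamic_batch, opt_profile_replica). A hedged example of constructing a setting that enables dynamic batch with an explicit batch dimension; the field names come from this file, everything else is assumed to keep its default:

from torch_tensorrt.fx.lower_setting import LowerSetting

setting = LowerSetting(
    explicit_batch_dimension=True,  # build the TRT engine with an explicit batch dim
    dynamic_batch=True,             # first dimension becomes dynamic (dim=-1)
    opt_profile_replica=1,          # number of optimization profiles for the engine
)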
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -63,14 +63,17 @@ def pass_with_validation(
y = y.cpu()
accuracy_check = torch.allclose(x, y, **kwargs)
if not accuracy_check:
+ _LOGGER.error(
+ f"Pass {pass_} failed correctness check, get original model output as {x} and processed model output as {y} for output {kk}."
+ )
if suppress_accuracy_check_failure:
_LOGGER.error(
f"pass {pass_} failed correctness check due to output {kk}, escape current pass."
f"Pass {pass_} failed correctness check due to output {kk}."
)
return processed_module
else:
raise RuntimeError(
f"pass {pass_} failed correctness check due to output {kk}"
f"Pass {pass_} failed correctness check due to output {kk}"
)
return processed_module

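The change above logs the original and processed outputs once, then either continues (when accuracy-check failures are suppressed) or raises. A stripped-down sketch of that control flow; the function name and signature here are hypothetical, not the module's API:

import logging

import torch

_LOGGER = logging.getLogger(__name__)


def validate_outputs(pass_name, original, processed, suppress_failure=False, **allclose_kwargs):
    """Return True when outputs match; otherwise log and either continue or raise."""
    if torch.allclose(original, processed, **allclose_kwargs):
        return True
    # Always record the mismatching tensors for debugging, as the diff does.
    _LOGGER.error(
        f"Pass {pass_name} failed correctness check, original output {original}, "
        f"processed output {processed}."
    )
    if suppress_failure:
        _LOGGER.error(f"Pass {pass_name} failed correctness check; continuing.")
        return False
    raise RuntimeError(f"Pass {pass_name} failed correctness check")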
18 changes: 18 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py
@@ -45,6 +45,24 @@ def forward(self, x):
TestModule(), input_specs, expected_ops={acc_ops.dequantize}
)

def test_dequantize_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
x = torch.quantize_per_tensor(x, 1, 0, torch.quint8)
return x.dequantize()

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.dequantize}
)


if __name__ == "__main__":
run_tests()
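The dynamic-shape tests added throughout this PR build their inputs the same way: shape uses -1 for every dynamic dimension, and each shape_ranges entry is a (min, opt, max) triple of concrete shapes that backs a TensorRT optimization profile. A standalone sketch of such a spec with illustrative values:

import torch
from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec  # same import the tests use

spec = InputTensorSpec(
    shape=(-1, -1, -1, -1),  # all four dimensions are dynamic
    dtype=torch.float32,
    # (min, opt, max): TensorRT tunes for `opt` and accepts anything between min and max
    shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
)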
33 changes: 32 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py
@@ -3,7 +3,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
- from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+ from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestConverter(AccTestCase):
@@ -30,6 +30,37 @@ def forward(self, x, y):
test_implicit_batch_dim=False,
)

@parameterized.expand(
[
("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)),
("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)),
# TRT does not support ellipsis or diagonal operations
]
)
def test_einsum_with_dynamic_shape_four_dimensions(
self, _, equation, x_size, y_size
):
class Einsum(nn.Module):
def forward(self, x, y):
return torch.einsum(equation, x, y)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))],
),
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]

self.run_test_with_dynamic_shape(
Einsum(), input_specs, expected_ops={acc_ops.einsum}
)


if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py
@@ -30,6 +30,23 @@ def forward(self, x):
TestModule(), input_specs, expected_ops={acc_ops.elu}
)

def test_elu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.elu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 5), (3, 3, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.elu}
)


if __name__ == "__main__":
run_tests()
42 changes: 41 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py
@@ -5,7 +5,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests
- from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+ from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


@unittest.skip(
@@ -62,6 +62,46 @@ def forward(self, indices, weights):
test_explicit_batch_dim=True,
)

def test_embedding_with_dynamic_shape_four_dimensions(
self,
test_name,
indices_tensor,
weights_tensor,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
):
class TestEmbedding(torch.nn.Module):
def forward(self, indices, weights):
return torch.nn.functional.embedding(
input=indices,
weight=weights,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
]

self.run_test_with_dynamic_shape(
TestEmbedding(), input_specs, expected_ops={acc_ops.embedding}
)


if __name__ == "__main__":
run_tests()
44 changes: 43 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py
@@ -2,7 +2,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
- from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+ from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestEqConverter(AccTestCase):
@@ -184,6 +184,28 @@ def forward(self, x, y):
)


class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase):
def test_eq(self):
class Eq(torch.nn.Module):
def forward(self, x, y):
return x == y

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
),
]

self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq})


class TestEqOperatorConstantConverter(AccTestCase):
@parameterized.expand(
[
@@ -243,5 +265,25 @@ def forward(self, x):
)


class TestConstInputConverterWithDynamicShape(AccTestCase):
def test_eq(self):
class Eq(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.shape[0] == 4

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]

self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq})


if __name__ == "__main__":
run_tests()
17 changes: 17 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py
@@ -35,6 +35,23 @@ def forward(self, x):
TestModule(), input_specs, expected_ops={acc_ops.gelu}
)

def test_gelu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return nn.functional.gelu(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.gelu}
)


if __name__ == "__main__":
run_tests()
46 changes: 46 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py
@@ -148,6 +148,52 @@ def forward(self, x):
Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
)

# Testing with the following parameters results in an error:
# AssertionError: We don't support slicing tensor on dynamic shape.

"""
("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
(
"slice_end_none",
(slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
),
(
"slice_step_none",
(slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
),
"""

@parameterized.expand(
[
("slice_batch_dim", slice(None, None, None)),
(
"slice_all_none",
(slice(None, None, None), slice(None, None, None)),
),
]
)
def test_getitem_with_dynamic_shape_four_dimensions(self, _, idx):
class Getitem(nn.Module):
def __init__(self, idx):
super().__init__()
self.idx = idx

def forward(self, x):
x = x + x
return x[self.idx]

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]

self.run_test_with_dynamic_shape(
Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
)


if __name__ == "__main__":
run_tests()