[FX] Changes done internally at Facebook #1208

Merged · 1 commit · Jul 27, 2022
137 changes: 137 additions & 0 deletions py/torch_tensorrt/fx/Dynamic_Shape_Support.md
@@ -0,0 +1,137 @@
# PyTorch Operations Dynamic Shape Support Summary



| Operation | Test Method | Supports Dynamic Shape | Shape | Num of dimensions | Reason |
| --- | --- | --- | --- | --- | --- |
| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. |
| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 |
| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | avg_pool1d | partially | (-1, 3, 3) | 1 | |
| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." |
| binary_ops | | yes | (-1, -1, -1, -1) | 4 | |
| cat | | yes | (-1, -1, -1, -1) | 4 | |
| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! |
| clamp | | yes | (-1, -1, -1, -1) | | |
| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. |
| | conv1d | partially | (-1, 3, 3) | 1 | |
| | conv3d | partially | (-1, 3, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. |
| dequantize | | yes | (-1, -1, -1, -1) | 4 | |
| einsum | | yes | (-1, -1, -1, -1) | 4 | |
| elu | | yes | (-1, -1, -1, -1) | 4 | |
| embedding | | yes | (-1, -1, -1, -1) | 4 | |
| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConstant | partially | (3,-1) | 1 | |
| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| expand | | no | | | Dynamic shape is not suitable for the expand operation. |
| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | |
| gelu | | yes | (-1, -1, -1, -1) | 4 | |
| getitem | | yes | (-1, -1, -1, -1) | 4 | |
| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| hardtanh | | yes | (-1, -1, -1, -1) | 4 | |
| interpolate | | yes | (-1, -1, -1, -1) | 4 | |
| isinf | | yes | (-1, -1, -1, -1) | 4 | |
| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | |
| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim. |
| logical_and | | yes | (-1, -1, -1, -1) | 4 | |
| logical_or | | yes | (-1, -1, -1, -1) | 4 | |
| logical_xor | | yes | (-1, -1, -1, -1) | 4 | |
| lt | | yes | (-1, -1, -1, -1) | 4 | |
| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| mat_mul | | yes | batch dim | | |
| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | |
| maximum | | yes | (-1, -1, -1, -1) | 4 | |
| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | Shape is not set to (-1, -1, -1) because a reshape dimension with more than one -1 wildcard is not allowed when adding the unsqueeze layer. |
| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | |
| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinMethod | yes | (-1, -1, -1, -1) | 4 | |
| minimum | | yes | (-1, -1, -1, -1) | 4 | |
| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! |
| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConstantConverter | partially | (3, -1) | 1 | |
| new_ones | | yes | (-1, -1, -1, -1) | 4 | |
| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. |
| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. |
| permute | | yes | (-1, -1, -1, -1) | 4 | |
| prod | | yes | (-1, -1, -1, -1) | 4 | |
| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | |
| reduce op | | yes | (-1, -1, -1, -1) | 4 | |
| relu | | yes | (-1, -1, -1, -1) | 4 | |
| repeat_interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| reshape | | yes | (-1, -1, -1, -1) | 4 | |
| selu | | yes | (-1, -1, -1, -1) | 4 | |
| sigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| silu | | yes | (-1, -1, -1, -1) | 4 | |
| size | | yes | (-1, -1, -1, -1) | 4 | |
| softmax | | yes | (-1, -1, -1, -1) | 4 | |
| softsign | | yes | (-1, -1, -1, -1) | 4 | |
| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! |
| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. |
| std | | yes | (-1, -1, -1, -1) | 4 | |
| tanh | | yes | (-1, -1, -1, -1) | 4 | |
| tile | | yes | (-1, -1, -1, -1) | 4 | |
| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | |
| | float | yes | (-1, -1, -1, -1) | 4 | |
| topk | | yes | (-1, -1, -1, -1) | 4 | |
| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | |
| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | |
| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} |
| unary ops | | yes | (-1, -1, -1, -1) | 4 | |
| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] |



Binary ops include the following operations:
|Binary Ops |
|----------|
|add |
|sub |
|div |
|mul |
|floor_div |
|fmod |
|floor_divide|
|pow |


Unary ops include the following operations:
|Unary Ops |
|----------|
|rsqrt |
|sin |
|cos |
|tan |
|sinh |
|cosh |
|asin |
|acos |
|atan |
|abs |
|neg |
|reciprocal|
|sqrt |
|log |
|exp |
|floor |
|ceil |
|sign |

Note: For more information about the test methods, refer to the operation test files. The test files also record the errors encountered during dynamic shape testing.
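
To make the shape notation above concrete, here is a minimal sketch of how a dynamic-shape input might be described when lowering through the FX path. It assumes the `InputTensorSpec` class exported from `torch_tensorrt.fx` and its `shape_ranges` field of (min, opt, max) shapes; the exact constructor may differ, so treat this as an illustration of the `(-1, 3, -1, -1)` convention rather than the harness used for the table.

```python
# Sketch only: assumes torch_tensorrt.fx.InputTensorSpec with a `shape_ranges`
# argument of (min, opt, max) shapes; adapt to the actual constructor if it differs.
import torch
from torch_tensorrt.fx import InputTensorSpec

# (-1, 3, -1, -1): dynamic batch, height, and width with a fixed channel dim --
# the "partially" supported layout listed for convolution in the table above.
spec = InputTensorSpec(
    shape=(-1, 3, -1, -1),
    dtype=torch.float32,
    shape_ranges=[((1, 3, 128, 128), (8, 3, 224, 224), (32, 3, 512, 512))],
)
```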
10 changes: 8 additions & 2 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1,4 +1,5 @@
# flake8: noqa
import logging
import math
import operator
import warnings
@@ -22,6 +23,9 @@
from .converter_utils import * # noqa: F403


_LOGGER: logging.Logger = logging.getLogger(__name__)


@tensorrt_converter(acc_ops.conv1d)
def acc_ops_conv1d(
network: TRTNetwork,
@@ -641,7 +645,7 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
except TypeError:
print("Unable to convert normalized_shape to a field, fall back to []")
_LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
normalized_shape = np.array([], dtype=np.int32)

normalized_shape_filed = trt.PluginField(
@@ -657,7 +661,9 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
else:
plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
except AssertionError:
print("Unable to find layer norm plugin, fall back to TensorRT implementation.")
_LOGGER.error(
"Unable to find layer norm plugin, fall back to TensorRT implementation."
)
return layer_norm(network, target, args, kwargs, name)
layer = network.add_plugin_v2([input_val], plugin)
layer.name = name
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -197,6 +197,7 @@ def create(
cls,
lower_setting: LowerSetting,
interpreter_builder: Callable = create_lower_trt_interpreter,
split_func: Callable = default_split_function,
) -> "Lowerer":
"""Instantiate a `Lowerer` instance."""

@@ -209,7 +210,7 @@ def create(
ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
leaf_module_list=lower_setting.leaf_module_list,
),
split_func=default_split_function,
split_func=split_func,
lower_func=default_lower_pass(interpreter_builder),
)
)
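
With the new `split_func` argument, callers can swap in their own splitter when constructing a `Lowerer`. A hedged sketch of the intended usage follows; only the `Lowerer.create` signature comes from the diff above, while the module paths, the default-constructed `LowerSetting`, and `my_split_function` (including its argument list) are assumptions made for illustration.

```python
# Sketch only: my_split_function is a hypothetical replacement for
# default_split_function; its signature is assumed, not taken from the diff.
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.lower_setting import LowerSetting


def my_split_function(model, inputs, lower_setting):
    # Custom acc/non-acc splitting logic would go here.
    raise NotImplementedError


# The new keyword argument overrides default_split_function.
lowerer = Lowerer.create(
    lower_setting=LowerSetting(),
    split_func=my_split_function,
)
```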
13 changes: 9 additions & 4 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,4 +1,5 @@
import datetime
import logging
from functools import partial, wraps
from typing import Any, Callable, Optional, Sequence

@@ -17,6 +18,10 @@

from .lower_basic_pass import run_const_fold


_LOGGER: logging.Logger = logging.getLogger(__name__)


Input = Sequence[Any]


@@ -143,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
print("Now lowering submodule", submod_name)
_LOGGER.info("Now lowering submodule", submod_name)
lowering_start_time = datetime.datetime.now()

self.lower_setting.input_specs = generate_input_specs(
@@ -160,7 +165,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
LOWER_SPLIT_POST_OBSERVER.observe(
submod_name, lowered_module, submod_inputs
)
print(
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time",
datetime.datetime.now() - lowering_start_time,
)
@@ -179,7 +184,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
print("Now lowering submodule", submod_name)
_LOGGER.info("Now lowering submodule", submod_name)
lowering_start_time = datetime.datetime.now()

lowered_module = self._lower_func(
@@ -189,7 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
LOWER_SPLIT_POST_OBSERVER.observe(
submod_name, lowered_module, submod_inputs
)
print(
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time",
datetime.datetime.now() - lowering_start_time,
)
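
Because these messages now go through the standard `logging` module instead of `print`, they are only visible once the application opts into INFO-level output. The snippet below is plain stdlib usage; nothing is torch_tensorrt-specific beyond the logger name, which follows from the `getLogger(__name__)` calls added above.

```python
import logging

# The root logger emits WARNING and above by default, so the new
# _LOGGER.info(...) calls stay silent unless INFO is enabled.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torch_tensorrt.fx").setLevel(logging.INFO)
```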
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/passes/pass_utils.py
@@ -116,7 +116,7 @@ def pass_with_before_after_log(
encoding="utf-8",
delete=False,
) as f:
print(f"== Log pass {pass_} before/after graph to {f.name}")
_LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
print(f"[{pass_}] Before:\n{module.graph}", file=f)
module = pass_(module, input)
print(f"[{pass_}] After:\n{module.graph}", file=f)
12 changes: 8 additions & 4 deletions py/torch_tensorrt/fx/test/passes/test_graph_opts.py
@@ -1,3 +1,4 @@
import logging
import unittest
from collections import Counter
from typing import Callable, Dict, List
@@ -8,13 +9,16 @@
from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination


_LOGGER: logging.Logger = logging.getLogger(__name__)


def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None:
"""
Helper func to print model's graph in plain and tabular format, also print code.
"""
print(mod_graph.graph)
_LOGGER.info(mod_graph.graph)
mod_graph.graph.print_tabular()
print(mod_graph.code)
_LOGGER.info(mod_graph.code)


@torch.fx.wrap
@@ -46,7 +50,7 @@ def _test_opt_with_module(
before_results = module(*inputs)
mod_traced = acc_tracer.trace(module, inputs)
before_node_list = list(mod_traced.graph.nodes)
print("Model before opt.")
_LOGGER.info("Model before opt.")
debug_print_graph_module(mod_traced)

# Apply Opt
@@ -55,7 +59,7 @@ def _test_opt_with_module(
# After Opt
after_results = mod_traced(*inputs)
after_node_list = list(mod_traced.graph.nodes)
print("Model after opt.")
_LOGGER.info("Model after opt.")
mod_traced.recompile()
debug_print_graph_module(mod_traced)

8 changes: 5 additions & 3 deletions py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -1,5 +1,5 @@
# Owner(s): ["oncall: fx"]

import logging
import unittest
from typing import Callable, List

@@ -16,6 +16,8 @@

torch.manual_seed(0)

_LOGGER: logging.Logger = logging.getLogger(__name__)


class AccTracerTest(unittest.TestCase):
def _make_model_unit_test(
@@ -258,7 +260,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
print(traced.graph)
_LOGGER.info(traced.graph)
ph = weight_attr = bias_attr = conv = None
for node in traced.graph.nodes:
if node.op == "placeholder":
@@ -626,7 +628,7 @@ def run_embedding_bag_test(is_4bit, use_weights):
)

traced = acc_tracer.trace(m, inputs)
print(traced.graph)
_LOGGER.info(traced.graph)

expected_target = (
acc_ops.embedding_bag_4bit_rowwise_offsets
10 changes: 7 additions & 3 deletions py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: gpu_enablement"]
import functools
import glob
import logging
import os
import shutil
import tempfile
@@ -10,6 +11,9 @@
import torch_tensorrt.fx.diagnostics as diag


_LOGGER: logging.Logger = logging.getLogger(__name__)


def reset_diag(fn):
@functools.wraps(fn)
def reset(*a, **kw):
@@ -53,7 +57,7 @@ def boom() -> str:
zip_fn = collector._last_zip_path_for_test
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")
_check_file(tempdir, "bbb", "world")
@@ -78,7 +82,7 @@ def test_condition_func_name(self):
zip_fn = collector._last_zip_path_for_test
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")

@@ -160,7 +164,7 @@ def _test_cond(
if should_collect:
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")
else: