From 2e21ce67ab2df6bb0209a8a5f7c7038c5b513308 Mon Sep 17 00:00:00 2001 From: Wei Date: Sun, 22 Jan 2023 14:13:34 -0800 Subject: [PATCH] [FX] Changes done internally at Facebook (#1603) Co-authored-by: Wei Wei --- .circleci/config.yml | 4 +- py/torch_tensorrt/fx/README.md | 4 +- .../fx/converters/acc_ops_converters.py | 2 +- .../fx/converters/aten_ops_converters.py | 53 +- py/torch_tensorrt/fx/fx2trt.py | 7 +- py/torch_tensorrt/fx/lower.py | 5 +- py/torch_tensorrt/fx/lower_setting.py | 16 +- .../fx/passes/lower_basic_pass.py | 22 + .../fx/passes/lower_basic_pass_aten.py | 514 ++++++++++++++++++ .../fx/passes/lower_pass_manager_builder.py | 12 +- py/torch_tensorrt/fx/passes/pass_utils.py | 6 +- .../aten_op/test_binary_ops_aten.py | 23 - .../test/converters/aten_op/test_cat_aten.py | 58 ++ .../converters/aten_op/test_expand_aten.py | 31 ++ .../converters/aten_op/test_flatten_aten.py | 4 +- .../converters/aten_op/test_linear_aten.py | 14 +- .../converters/aten_op/test_maxpool_aten.py | 4 + .../converters/aten_op/test_reshape_aten.py | 3 + .../fx/test/core/test_trt_module.py | 170 +++--- ...test_fix_clamp_numerical_limits_to_fp16.py | 72 +++ .../fx/test/quant/test_quant_trt.py | 14 +- .../fx/test/tracer/test_acc_tracer.py | 48 ++ .../fx/test/tracer/test_dispatch_tracer.py | 15 +- py/torch_tensorrt/fx/tools/common_fx2trt.py | 80 ++- py/torch_tensorrt/fx/tools/trt_minimizer.py | 3 +- .../fx/tools/trt_profiler_sorted.py | 3 +- py/torch_tensorrt/fx/tools/trt_splitter.py | 3 +- .../fx/tracer/acc_tracer/acc_ops.py | 28 +- .../fx/tracer/acc_tracer/acc_shape_prop.py | 75 ++- .../fx/tracer/acc_tracer/acc_tracer.py | 77 +++ .../fx/tracer/dispatch_tracer/aten_tracer.py | 114 ++++ 31 files changed, 1277 insertions(+), 207 deletions(-) create mode 100644 py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_expand_aten.py create mode 100644 py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py create mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 19549721a2..5babd6a280 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -263,7 +263,7 @@ commands: parameters: torch-build: type: string - default: "2.0.0.dev20230103+cu117" + default: "2.0.0.dev20230120+cu117" torch-build-index: type: string default: "https://download.pytorch.org/whl/nightly/cu117" @@ -992,7 +992,7 @@ parameters: # Nightly platform config torch-build: type: string - default: "2.0.0.dev20230103+cu117" + default: "2.0.0.dev20230120+cu117" torch-build-index: type: string default: "https://download.pytorch.org/whl/nightly/cu117" diff --git a/py/torch_tensorrt/fx/README.md b/py/torch_tensorrt/fx/README.md index 916381976d..4ad69ea869 100644 --- a/py/torch_tensorrt/fx/README.md +++ b/py/torch_tensorrt/fx/README.md @@ -5,7 +5,7 @@ FX2TRT is merged as FX module in Torch-TensorRT * Method 1. Follow the instrucions for Torch-TensorRT * Method 2. To install FX path only (Python path) and avoid the C++ build for torchscript path -``` +` $ conda create --name python_env python=3.8 $ conda activate python_env # Recommend to install PyTorch 1.12 and later @@ -18,4 +18,4 @@ FX2TRT is merged as FX module in Torch-TensorRT $ pyton -c "import torch_tensorrt.fx" # Test an example by $ python py/torch_tensorrt/fx/example/lower_example.py -``` +` diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 01b15aa533..77a9b92dfe 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -2802,7 +2802,7 @@ def acc_ops_linear( if isinstance(kwargs["weight"], torch.Tensor): weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") - if target is not acc_ops.linear: + if target not in (acc_ops.linear, torch.ops.aten.linear): weight_op = trt.MatrixOperation.TRANSPOSE else: weight_op = trt.MatrixOperation.NONE diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 40e9c5b716..c79f618be3 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -187,8 +187,7 @@ def aten_ops_fmod( return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name) -@tensorrt_converter(torch.ops.aten.mm.default) -@tensorrt_converter(torch.ops.aten.addmm.default) +@tensorrt_converter(torch.ops.aten.linear) def aten_ops_linear( network: TRTNetwork, target: Target, @@ -196,18 +195,12 @@ def aten_ops_linear( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - if target == torch.ops.aten.addmm.default: - kwargs_new = { - "bias": args[0], - "input": args[1], - "weight": args[2], - } - elif target == torch.ops.aten.mm.default: - kwargs_new = { - "bias": None, - "input": args[0], - "weight": args[1], - } + kwargs_new = { + "input": args[0], + "weight": args[1], + "bias": args[2], + } + return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name) @@ -320,3 +313,35 @@ def aten_ops_reshape( "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]), } return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.cat.default) +def aten_ops_cat( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "tensors": args[0], + "dim": args[1], + } + return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.expand.default) +def aten_ops_expand( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "sizes": args[1], + } + return acc_ops_converters.acc_ops_expand_tensor( + network, target, None, kwargs_new, name + ) diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 96f1f1cadd..d0a6bdf0a1 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -341,7 +341,12 @@ def call_method(self, target, args, kwargs): def output(self, target, args, kwargs): assert len(args) == 1 - outputs = args[0] if isinstance(args[0], tuple) else (args[0],) + if isinstance(args[0], tuple): + outputs = args[0] + elif isinstance(args[0], list): + outputs = tuple(args[0]) + else: + outputs = (args[0],) if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs): raise RuntimeError("TensorRT requires all outputs to be Tensor!") diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 61bd232421..2541143fb6 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -1,7 +1,5 @@ import dataclasses as dc import logging -import dataclasses as dc -import logging from typing import Any, Callable, Optional, Sequence # @manual=//deeplearning/trt/python:py_tensorrt @@ -180,8 +178,9 @@ def lower_pass( interp_res: TRTInterpreterResult = interpreter(mod, input, module_name) if lower_setting.use_experimental_rt: import io - from torch_tensorrt._TRTModuleNext import TRTModuleNext + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext with io.BytesIO() as engine_bytes: engine_bytes.write(interp_res.engine.serialize()) diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index a47f8c77c5..07e7bf0dac 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -66,8 +66,8 @@ class LowerSetting(LowerSettingBasic): cuda_graph_batch_size (int): Cuda graph batch size, default to be -1. preset_lowerer (str): when specified, use a preset logic to build the instance of Lowerer. - opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is - only used by explicit batch dim with dynamic shape mode. + only used by explicit batch dim with dynamic shape mode. In general, we use 2 GPU setting with + 2 stream on each. Set total number to 8 as a safe default value. dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension. tactic_sources: tactic sources for TensorRT kernel selection. Default to None, meaning all possible tactic sources. @@ -81,9 +81,13 @@ class LowerSetting(LowerSettingBasic): explicit_precision: bool = False max_workspace_size: int = 1 << 30 strict_type_constraints: bool = False - customized_fuse_pass: PassManager = PassManager.build_from_passlist([]) - lower_basic_fuse_pass: PassManager = PassManager.build_from_passlist( - [fuse_permute_matmul, fuse_permute_linear] + customized_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist([]) + ) + lower_basic_fuse_pass: PassManager = dc.field( + default_factory=lambda: PassManager.build_from_passlist( + [fuse_permute_matmul, fuse_permute_linear] + ) ) verbose_log: bool = False algo_selector = None @@ -91,7 +95,7 @@ class LowerSetting(LowerSettingBasic): save_timing_cache: bool = False cuda_graph_batch_size: int = -1 preset_lowerer: str = "" - opt_profile_replica: int = 1 + opt_profile_replica: int = 8 dynamic_batch: bool = True tactic_sources: Optional[int] = None correctness_atol: float = 0.1 diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index 844fa24238..e753d6e227 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -608,3 +608,25 @@ def _get_shape(node: fx.Node) -> Optional[torch.Size]: # shape info not available return None return node.meta["tensor_meta"].shape + + +@log_before_after +@validate_inference(atol=1e-3, rtol=1e-2) +def fix_clamp_numerical_limits_to_fp16( + mod: torch.fx.GraphModule, input: Input +) -> torch.fx.GraphModule: + MIN_FP16 = -65504.0 + MAX_FP16 = 65504.0 + for node in mod.graph.nodes: + if node.op == "call_function" and "clamp" in str(node.target): + input_kwargs = node.kwargs + if input_kwargs["min"] < MIN_FP16 and input_kwargs["max"] > MAX_FP16: + new_kwargs = { + "input": input_kwargs["input"], + "min": MIN_FP16, + "max": MAX_FP16, + } + node.kwargs = new_kwargs + + mod.recompile() + return mod diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py new file mode 100644 index 0000000000..0ca4383f6e --- /dev/null +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -0,0 +1,514 @@ +import logging +import operator +from typing import Any + +import torch +import torch.fx +from torch.fx.experimental.const_fold import split_const_subgraphs +from torch.fx.passes.infra.pass_base import PassResult + +_LOGGER = logging.getLogger(__name__) + +# Create an alias for module input type to avoid littering pyre-ignore for Any +# throughout the file. +Input = Any + + +def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Now we do constant folding on traced module. + def skip_folding(node: torch.fx.Node): + if node.target == torch.ops.aten.sym_size: + return True + + const_split_mod = split_const_subgraphs( + traced_mod, skip_folding_node_fn=skip_folding + ) + const_split_mod.run_folding() + return const_split_mod + + +def replace_inplace_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + Remove this func after functionalization is workable + """ + modified = False + map_func = { + torch.ops.aten.relu_.default: torch.ops.aten.relu.default, + torch.ops.aten.hardtanh_.default: torch.ops.aten.hardtanh.default, + torch.ops.aten.add_.Tensor: torch.ops.aten.add.Tensor, + } + for n in module.graph.nodes: + if n.op == "call_function" and n.target in map_func.keys(): + modified = True + node = n + with module.graph.inserting_after(node): + new_args = node.args + new_node = module.graph.create_node( + "call_function", + map_func[node.target], + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_native_layernorm_with_layernorm( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if ( + n.op == "call_function" + and n.target == torch.ops.aten.native_layer_norm.default + ): + for v in n.users: + if v.op == "call_function" and v.target == operator.getitem: + if v.args[1] != 0: + raise RuntimeError( + f"Got args[{v.args[1]}]!!\n" + "layernorm can only generate output (args[0]), " + "not mean (args[1]) or std (args[2])!" + ) + new_op = torch.ops.aten.layer_norm.default + new_args = (*n.args, True) # cudnn_enable=True + modified = True + else: + continue + + with module.graph.inserting_after(v): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=v.kwargs, + ) + v.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_transpose_mm_op_with_linear( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.t.default: + to_erase = [] + for v in n.users: + if v.op == "call_function" and v.target == torch.ops.aten.addmm.default: + new_op = torch.ops.aten.linear + bias, inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, bias) + modified = True + elif v.op == "call_function" and v.target == torch.ops.aten.mm.default: + new_op = torch.ops.aten.linear + inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, None) + modified = True + # this pass should be after `compose_bmm` + elif v.op == "call_function" and v.target == aten_compose_bmm_2d: + new_op = torch.ops.aten.linear + inp, _ = list(v.args) + weight = list(n.args)[0] + new_args = (inp, weight, None) + modified = True + else: + continue + + with module.graph.inserting_after(v): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=v.kwargs, + ) + v.replace_all_uses_with(new_node) + to_erase.append(v) + for v in to_erase: + module.graph.erase_node(v) + module.graph.eliminate_dead_code() + module.recompile() + # handle the linear with multiple dim, remove the extra reshape + for n in module.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear: + before = n.args[0] + after = next(iter(n.users)) + if (len(n.users) == 1 and after.target == torch.ops.aten.view.default) and ( + before.target == torch.ops.aten.view.default and len(before.users) == 1 + ): + real_input = before.args[0] + new_args = list(n.args) + new_args[0] = real_input + n.args = tuple(new_args) + after.replace_all_uses_with(n) + module.graph.eliminate_dead_code() + module.recompile() + + return PassResult(module, modified) + + +def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.max_pool3d_with_indices.default, + torch.ops.aten.native_batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + ): + modified = True + if len(n.users) != 1: + raise RuntimeError( + f"{n.target} has users={len(n.users)}. We can only handle it with 1 user" + ) + if n.target == torch.ops.aten.max_pool2d_with_indices.default: + new_op = torch.ops.aten.max_pool2d + new_args = n.args + elif n.target == torch.ops.aten.max_pool3d_with_indices.default: + new_op = torch.ops.aten.max_pool3d + new_args = n.args + elif ( + n.target == torch.ops.aten.native_batch_norm.default + or n.target == torch.ops.aten._native_batch_norm_legit.default + ): + new_op = torch.ops.aten.batch_norm + new_args = list(n.args) + new_args.append(False) + new_args = tuple(new_args) + + getitem_node = next(iter(n.users)) + with module.graph.inserting_after(getitem_node): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=n.kwargs, + ) + getitem_node.replace_all_uses_with(new_node) + module.graph.erase_node(getitem_node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def replace_aten_reshape_alias_with_replace( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + # The stride parameter is not used. Replace with reshape without stride + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten._reshape_alias.default, + ): + modified = True + node = n + with module.graph.inserting_after(node): + new_args = (node.args[0], node.args[1]) + new_node = module.graph.create_node( + "call_function", + torch.ops.aten.reshape, + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + break + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def remove_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + 1. Remove clone, _unsafe_view node. #TODO Remove this func after functionalization is workable + 2. Remove inefficient op getitem(index=slice) P561572458 + """ + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (torch.ops.aten.clone.default,): + modified = True + node = n + input_n = node.all_input_nodes[0] + node.replace_all_uses_with(input_n) + module.graph.eliminate_dead_code() + module.recompile() + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten._unsafe_view.default, + ): + modified = True + node = n + with module.graph.inserting_after(node): + new_node = module.graph.create_node( + "call_function", + torch.ops.aten.reshape, + args=node.args, + kwargs=node.kwargs, + ) + node.replace_all_uses_with(new_node) + module.graph.erase_node(node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_operator_getitem(*args): + return operator.getitem(*args) + + +def replace_builtin_ops( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + To differential the same op in fx2ait as they are registered in the same dictionary + """ + + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (operator.getitem,): + modified = True + n.target = aten_operator_getitem + module.graph.eliminate_dead_code() + module.recompile() + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +############### +""" +Trace compose. For some ops, we do not want to decompose further but want coarse granularity +For ex: +1. bmm +2. chunk +3. getitem(input, idx=(slice(),slice()...)) +""" + + +def aten_compose_getitem_slice(input, list_args): + for _, args in enumerate(list_args): + input = torch.ops.aten.slice.Tensor(input, *args) + return input + + +def compose_getitem_slice( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed getitem(input, idx=(slice(),slice()...)) + """ + + def match_pattern(module, node): + if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor: + holder = [] + holder.append(node) + while ( + len(node.users.keys()) == 1 + and next(iter(node.users)).target == torch.ops.aten.slice.Tensor + and node.args[1] + 1 == next(iter(node.users)).args[1] + ): + node = next(iter(node.users)) + holder.append(node) + if len(holder) == 1: + return (False,) + else: + return (True, holder) + return (False,) + + modified = False + for node in module.graph.nodes: + res = match_pattern(module, node) + if res[0]: + modified = True + holder = res[1] + input_n = holder[0].args[0] + last_n = holder[-1] + list_args = [] + for h_n in holder: + list_args.append(h_n.args[1:]) + + with module.graph.inserting_after(last_n): + new_args = (input_n, list_args) + new_node = module.graph.create_node( + "call_function", + aten_compose_getitem_slice, + args=new_args, + kwargs=None, + ) + last_n.replace_all_uses_with(new_node) + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_compose_bmm_2d(flat_args_1, flat_args_2): + sym_size = torch.ops.aten.sym_size(flat_args_1, 0) + sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) + sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) + expand = torch.ops.aten.expand.default( + flat_args_1, [sym_size, sym_size_1, sym_size_2] + ) + view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) + sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 0) + sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 1) + expand_1 = torch.ops.aten.expand.default( + flat_args_2, [sym_size, sym_size_3, sym_size_4] + ) + view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) + bmm = torch.ops.aten.bmm.default(view, view_1) + view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) + return view_2 + + +def aten_compose_bmm_3d(flat_args_1, flat_args_2): + sym_size = torch.ops.aten.sym_size(flat_args_1, 0) + sym_size_1 = torch.ops.aten.sym_size(flat_args_1, 1) + sym_size_2 = torch.ops.aten.sym_size(flat_args_1, 2) + expand = torch.ops.aten.expand.default( + flat_args_1, [sym_size, sym_size_1, sym_size_2] + ) + view = torch.ops.aten.view.default(expand, [sym_size, sym_size_1, sym_size_2]) + sym_size_3 = torch.ops.aten.sym_size(flat_args_2, 1) + sym_size_4 = torch.ops.aten.sym_size(flat_args_2, 2) + expand_1 = torch.ops.aten.expand.default( + flat_args_2, [sym_size, sym_size_3, sym_size_4] + ) + view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) + bmm = torch.ops.aten.bmm.default(view, view_1) + view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) + return view_2 + + +def compose_bmm( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed bmm (matmul) + """ + modified = False + for n in module.graph.nodes: + if n.op == "call_function" and n.target in (torch.ops.aten.bmm.default,): + modified = True + node = n + input_n = node.all_input_nodes[0] + other_n = node.all_input_nodes[1] + output = next(iter(node.users)) + input_input_n = input_n.all_input_nodes[0] + if ( + input_input_n.target != torch.ops.aten.expand.default + and input_n.target != torch.ops.aten.view.default + ): + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" + ) + real_input = input_input_n.all_input_nodes[0] + input_other_n = other_n.all_input_nodes[0] + if ( + input_other_n.target != torch.ops.aten.expand.default + and other_n.target != torch.ops.aten.view.default + ): + raise RuntimeError( + "Bmm is addressed in fixed pattern. A new pattern is met!" + ) + real_other = input_other_n.all_input_nodes[0] + if len(real_other.meta["val"].size()) == 2: + new_func = aten_compose_bmm_2d + if len(real_other.meta["val"].size()) == 3: + new_func = aten_compose_bmm_3d + + with module.graph.inserting_after(node): + new_args = (real_input, real_other) + new_node = module.graph.create_node( + "call_function", + new_func, + args=new_args, + kwargs=None, + ) + output.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) + + +def aten_compose_chunk(flat_args_1, chunk, dim): + sym_size = torch.ops.aten.sym_size(flat_args_1, dim) + add = operator.add(sym_size, chunk) + sub = operator.sub(add, 1) + floordiv = operator.floordiv(sub, chunk) + split = torch.ops.aten.split.Tensor(flat_args_1, floordiv, dim) + return split + + +def compose_chunk( + module: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """ + combine decomposed chunk + """ + + def match_pattern(module, node): + if node.op == "call_function" and node.target in (torch.ops.aten.split.Tensor,): + div = node.args[1] + input = node.args[0] + if isinstance(div, int): + return (False,) + if div.target != operator.floordiv: + return (False,) + else: + div_const = div.args[1] + sub = div.args[0] + if sub.target != operator.sub: + return (False,) + else: + add = sub.args[0] + if add.target != operator.add: + return (False,) + else: + add_const = add.args[1] + if add_const != div_const: + return (False,) + symsize = add.args[0] + if symsize.target != torch.ops.aten.sym_size: + return (False,) + else: + symsize_input = symsize.args[0] + dim = symsize.args[1] + if symsize_input != input: + return (False,) + + return (True, div_const, dim) + else: + return (False,) + + modified = False + for node in module.graph.nodes: + res = match_pattern(module, node) + if res[0]: + modified = True + with module.graph.inserting_after(node): + new_args = (node.args[0], res[1], res[2]) + new_node = module.graph.create_node( + "call_function", + aten_compose_chunk, + args=new_args, + kwargs=None, + ) + node.replace_all_uses_with(new_node) + + module.graph.eliminate_dead_code() + module.recompile() + return PassResult(module, modified) diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index c4bb927b85..514a52fab8 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -8,6 +8,7 @@ from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch_tensorrt.fx.utils import LowerPrecision from ..input_tensor_spec import generate_input_specs @@ -16,7 +17,8 @@ from ..passes.remove_duplicate_output_args import remove_duplicate_output_args from .graph_opts import common_subexpression_elimination -from .lower_basic_pass import ( +from .lower_basic_pass import ( # noqa + fix_clamp_numerical_limits_to_fp16, fix_reshape_batch_dim, replace_mutable_op, replace_op_with_indices, @@ -108,6 +110,14 @@ def graph_optimization_pass(self) -> PassManager: passes.append(wrapper(p, self._input)) for p in self.lower_setting.lower_basic_fuse_pass.passes: passes.append(wrapper(p, self._input)) + if ( + hasattr(self.lower_setting, "lower_precision") + and self.lower_setting.lower_precision is LowerPrecision.FP16 + ) or ( + hasattr(self.lower_setting, "precision") + and self.lower_setting.precision is LowerPrecision.FP16 + ): + passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) passes.append(inplace_wrapper(common_subexpression_elimination)) passes.append( diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 78e9ec1b22..fabc92881d 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -1,6 +1,7 @@ import io import logging import tempfile +from datetime import datetime from functools import wraps from typing import Any, Callable, List, Optional @@ -243,13 +244,14 @@ def pass_with_before_after_log( ) as f: print(f"[{pass_}] Before:\n{module.graph}", file=f) print(module.graph, file=before_io) - + start_time = datetime.now() module = pass_(module, input) + t_elapsed = datetime.now() - start_time print(f"[{pass_}] After:\n{module.graph}", file=f) print(module.graph, file=after_io) t = before_io.getvalue() == after_io.getvalue() _LOGGER.info( - f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}" + f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}, time elapsed = {t_elapsed}" ) return module diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py index 028510b472..b91ff301ef 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py @@ -200,29 +200,6 @@ def forward(self, x, y): ] self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) - def test_elementwise_ops_with_scalar_lhs(self): - def orig_op(x, y): - return x + y - - class TestModule(nn.Module): - def __init__(self, orig_op): - super().__init__() - self.constant = torch.randn(1) - self.orig_op = orig_op - - def forward(self, x): - return self.orig_op(x, self.constant) - - m = TestModule(orig_op) - inputs = [torch.randn(10)] - self.run_test( - m, - inputs, - expected_ops={torch.ops.aten.add.Tensor}, - test_explicit_batch_dim=False, - test_implicit_batch_dim=True, - ) - if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py new file mode 100644 index 0000000000..cfeb235af3 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestCatConverter(DispatchTestCase): + @parameterized.expand( + [ + ("pos", 1), + # ("neg", -2), #Dynamo tracer issue + ] + ) + def test_cat(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z), dim) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + @parameterized.expand( + [ + ("pos", 1), + # ("neg", -2), #Dynamo tracer issue + ] + ) + def test_cat_dynamic_shape(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y), dim) + + input_specs = [ + InputTensorSpec( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))], + ), + InputTensorSpec( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_expand_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_expand_aten.py new file mode 100644 index 0000000000..e1e5eb356f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_expand_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase + + +class TestExpandConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (2, 3), (2, 1)), + ("3d_dim", (2, 3, 4), (2, 1, 1)), + ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), + ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), + ] + ) + def test_expand(self, _, sizes, init_size): + class Expand(nn.Module): + def forward(self, x): + return x.expand(*sizes) + + inputs = [torch.randn(*init_size)] + self.run_test( + Expand(), + inputs, + expected_ops={torch.ops.aten.expand.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py index 69dea57efb..54f9e1e53e 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch.nn as nn from parameterized import parameterized @@ -13,6 +15,7 @@ class TestFlattenConverter(DispatchTestCase): ("flatten_all", 0, 3), ] ) + @unittest.skip("Not support yet") def test_flatten(self, _, start_dim, end_dim): class Flatten(nn.Module): def __init__(self, start, end): @@ -28,7 +31,6 @@ def forward(self, x): Flatten(start_dim, end_dim), inputs, expected_ops={torch.ops.aten.view.default}, - test_implicit_batch_dim=(start_dim != 0), ) ## Dynamic shape does not work due to flatten converts to reshape in tracing. And batch or dynamic dimension is converted to fixed integer and loose dynamic diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py index 408361f31b..1870490e85 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py @@ -7,20 +7,20 @@ class TestLinearConverter(DispatchTestCase): @parameterized.expand( [ - ("default", [1, 512], True, torch.ops.aten.addmm.default), - ("matrix", [5, 512], True, torch.ops.aten.addmm.default), - ("no_bias", [1, 512], False, torch.ops.aten.mm.default), + ("default", [1, 512], True, torch.ops.aten.linear), + ("matrix", [5, 512], True, torch.ops.aten.linear), + ("no_bias", [1, 512], False, torch.ops.aten.linear), ( "multi_dim_matrix", [4, 5, 512], True, - torch.ops.aten.addmm.default, + torch.ops.aten.linear, ), ( "multi_dim_matrix", [4, 5, 512], False, - torch.ops.aten.mm.default, + torch.ops.aten.linear, ), ] ) @@ -34,9 +34,7 @@ def forward(self, x): return self.linear(x) inputs = [torch.randn(shape)] - self.run_test( - TestModule(), inputs, expected_ops={op}, test_implicit_batch_dim=False - ) + self.run_test(TestModule(), inputs, expected_ops={op}) # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below. diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py index 95a86f4827..fac55ad46a 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized @@ -71,6 +73,7 @@ def forward(self, x): # param("ceil_mode", 1, ceil_mode=True), ] ) + @unittest.skip("PT tracer issue") def test_max_pool3d( self, test_name, @@ -144,6 +147,7 @@ def forward(self, x): param("stride", 2, stride=()), ] ) + @unittest.skip("PT tracer issue") def test_stride_none_max_pool3d( self, test_name, diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py index 22a254e407..96c8fe7423 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized @@ -12,6 +14,7 @@ class TestReshapeConverter(DispatchTestCase): ((1, 10, -1),), ] ) + @unittest.skip("Need support") def test_reshape(self, target_shape): class TestModule(torch.nn.Module): def __init__(self, target_shape): diff --git a/py/torch_tensorrt/fx/test/core/test_trt_module.py b/py/torch_tensorrt/fx/test/core/test_trt_module.py index 71855e1299..df4de754ba 100644 --- a/py/torch_tensorrt/fx/test/core/test_trt_module.py +++ b/py/torch_tensorrt/fx/test/core/test_trt_module.py @@ -9,8 +9,9 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule -from torch_tensorrt import TRTModuleNext -from torch_tensorrt import Device + +# from torch_tensorrt import TRTModuleNext +# from torch_tensorrt import Device from torch_tensorrt.fx.utils import LowerPrecision @@ -57,88 +58,89 @@ def forward(self, x): ) -class TestTRTModuleNext(TestCase): - def test_save_and_load_trt_module(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x + x - - inputs = [torch.randn(1, 1)] - mod = TestModule().eval() - ref_output = mod(*inputs) - - mod = acc_tracer.trace(mod, inputs) - - interp = TRTInterpreter( - mod, - input_specs=InputTensorSpec.from_tensors(inputs), - explicit_batch_dimension=True, - ) - interp_res = interp.run(lower_precision=LowerPrecision.FP32) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interp_res.engine.serialize()) - engine_str = engine_bytes.getvalue() - - trt_mod = TRTModuleNext( - name="TestModule", - serialized_engine=engine_str, - input_binding_names=interp_res.input_names, - output_binding_names=interp_res.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - ) - - torch.save(trt_mod, "trt.pt") - reload_trt_mod = torch.load("trt.pt") - - torch.testing.assert_allclose( - reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), - ref_output, - rtol=1e-04, - atol=1e-04, - ) - os.remove(f"{os.getcwd()}/trt.pt") - - def test_save_and_load_state_dict(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x + x - - inputs = [torch.randn(1, 1)] - mod = TestModule().eval() - ref_output = mod(*inputs) - - mod = acc_tracer.trace(mod, inputs) - interp = TRTInterpreter( - mod, - input_specs=InputTensorSpec.from_tensors(inputs), - explicit_batch_dimension=True, - ) - interp_res = interp.run(lower_precision=LowerPrecision.FP32) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interp_res.engine.serialize()) - engine_str = engine_bytes.getvalue() - - trt_mod = TRTModuleNext( - name="TestModule", - serialized_engine=engine_str, - input_binding_names=interp_res.input_names, - output_binding_names=interp_res.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - ) - - st = trt_mod.state_dict() - - new_trt_mod = TRTModuleNext() - new_trt_mod.load_state_dict(st) - - torch.testing.assert_allclose( - new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), - ref_output, - rtol=1e-04, - atol=1e-04, - ) +# TODO add unittest.skip later +# class TestTRTModuleNext(TestCase): +# def test_save_and_load_trt_module(self): +# class TestModule(torch.nn.Module): +# def forward(self, x): +# return x + x + +# inputs = [torch.randn(1, 1)] +# mod = TestModule().eval() +# ref_output = mod(*inputs) + +# mod = acc_tracer.trace(mod, inputs) + +# interp = TRTInterpreter( +# mod, +# input_specs=InputTensorSpec.from_tensors(inputs), +# explicit_batch_dimension=True, +# ) +# interp_res = interp.run(lower_precision=LowerPrecision.FP32) + +# with io.BytesIO() as engine_bytes: +# engine_bytes.write(interp_res.engine.serialize()) +# engine_str = engine_bytes.getvalue() + +# trt_mod = TRTModuleNext( +# name="TestModule", +# serialized_engine=engine_str, +# input_binding_names=interp_res.input_names, +# output_binding_names=interp_res.output_names, +# target_device=Device(f"cuda:{torch.cuda.current_device()}"), +# ) + +# torch.save(trt_mod, "trt.pt") +# reload_trt_mod = torch.load("trt.pt") + +# torch.testing.assert_allclose( +# reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), +# ref_output, +# rtol=1e-04, +# atol=1e-04, +# ) +# os.remove(f"{os.getcwd()}/trt.pt") + +# def test_save_and_load_state_dict(self): +# class TestModule(torch.nn.Module): +# def forward(self, x): +# return x + x + +# inputs = [torch.randn(1, 1)] +# mod = TestModule().eval() +# ref_output = mod(*inputs) + +# mod = acc_tracer.trace(mod, inputs) +# interp = TRTInterpreter( +# mod, +# input_specs=InputTensorSpec.from_tensors(inputs), +# explicit_batch_dimension=True, +# ) +# interp_res = interp.run(lower_precision=LowerPrecision.FP32) + +# with io.BytesIO() as engine_bytes: +# engine_bytes.write(interp_res.engine.serialize()) +# engine_str = engine_bytes.getvalue() + +# trt_mod = TRTModuleNext( +# name="TestModule", +# serialized_engine=engine_str, +# input_binding_names=interp_res.input_names, +# output_binding_names=interp_res.output_names, +# target_device=Device(f"cuda:{torch.cuda.current_device()}"), +# ) + +# st = trt_mod.state_dict() + +# new_trt_mod = TRTModuleNext() +# new_trt_mod.load_state_dict(st) + +# torch.testing.assert_allclose( +# new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output), +# ref_output, +# rtol=1e-04, +# atol=1e-04, +# ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py new file mode 100644 index 0000000000..457a9e415a --- /dev/null +++ b/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py @@ -0,0 +1,72 @@ +import logging +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch_tensorrt.fx.passes.lower_basic_pass import fix_clamp_numerical_limits_to_fp16 + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: + """ + Helper func to print model's graph in plain and tabular format, also print code. + """ + _LOGGER.info(mod_graph.graph) + mod_graph.graph.print_tabular() + _LOGGER.info(mod_graph.code) + + +class ClampNumericalLimitsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def test_clamp_numerical_limits_to_fp16(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.clamp(x + x, min=-1e8, max=1e8) + return y + + module = TestModule() + inputs = [torch.rand(3, 2, 1)] + + module.eval() + + # Before Opt + before_results = module(*inputs) + mod_traced = acc_tracer.trace(module, inputs) + before_node_list = list(mod_traced.graph.nodes) + clamp_node_before = [node for node in before_node_list if "clamp" in str(node)] + min_val_before = clamp_node_before[0].kwargs["min"] + max_val_before = clamp_node_before[0].kwargs["max"] + _LOGGER.info("Model before opt.") + debug_print_graph_module(mod_traced) + + # Apply Opt + module_after_pass = fix_clamp_numerical_limits_to_fp16(mod_traced, inputs) + + # After Opt + after_results = module_after_pass(*inputs) + after_node_list = list(mod_traced.graph.nodes) + clamp_node_after = [node for node in after_node_list if "clamp" in str(node)] + min_val_after = clamp_node_after[0].kwargs["min"] + max_val_after = clamp_node_after[0].kwargs["max"] + _LOGGER.info("Model after opt.") + mod_traced.recompile() + debug_print_graph_module(mod_traced) + + # Tests + # * Numerics + tol_args = {"rtol": 1e-2, "atol": 1e-2} + torch.testing.assert_close(before_results, after_results, **tol_args) + + # graph should not change + self.assertTrue(before_node_list == after_node_list) + + # values of clamp node changed + self.assertTrue(min_val_before != min_val_after) + self.assertTrue(max_val_before != max_val_after) diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index 10ff886d33..5d5e747505 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -696,7 +696,7 @@ def conv_add_extra_inputs_getter(pattern): return [extra_input] conv_add_config = { - "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode), + "pattern_complex_format": (operator.add, torch.nn.Conv2d, MatchAllNode), "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, "dtype_configs": [ weighted_op_qint8_dtype_config, @@ -728,9 +728,6 @@ def conv_add_extra_inputs_getter(pattern): } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - @unittest.skip( - "This is not stable. We can enable the test after it becomes stable." - ) def test_conv_add_standalone_module(self): class Standalone(torch.nn.Module): def __init__(self): @@ -751,9 +748,7 @@ def forward(self, x): y = self.conv(x) return self.standalone(x, y) - from torch.ao.quantization.backend_config.observation_type import ( - ObservationType, - ) + from torch.ao.quantization.backend_config import ObservationType weighted_op_quint8_dtype_config = { # optional, input activation dtype @@ -769,7 +764,10 @@ def forward(self, x): } conv_add_config = { - "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)), + "pattern_complex_format": ( + torch.nn.ReLU, + (operator.add, torch.nn.Conv2d, MatchAllNode), + ), "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, "dtype_configs": [ weighted_op_quint8_dtype_config, diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 8709661dfc..709df1cd2f 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -19,6 +19,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) +torch.fx.wrap("len") + class AccTracerTest(unittest.TestCase): def _make_model_unit_test( @@ -1393,6 +1395,37 @@ def test_transpose(self): acc_ops.permute, lambda x: torch.transpose(x, 1, 0) ) + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + x = len(a.shape) - 2 + y = len(a.shape) - 1 + return a.transpose(x, y) + + m = TestModule() + m.eval() + + a = torch.randn(2, 3, 4, 5) + traced = acc_tracer.trace(m, [a]) + + ph_a = permute = None + for node in traced.graph.nodes: + if node.op == "placeholder": + ph_a = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.permute) + self.assertEqual(node.kwargs["input"], ph_a) + self.assertEqual(node.kwargs["permutation"], [0, 1, 3, 2]) + permute = node + elif node.op == "output": + self.assertEqual(permute, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + self.assertTrue(torch.equal(m(a), traced(a))) + def test_permute(self): """ Test that torch.permute is traced correctly. @@ -2341,11 +2374,25 @@ def test_resnext50_32x4d(self): self._make_model_unit_test(m) def test_cumsum(self): + # Tests call_function version self._make_acc_op_function_test(acc_ops.cumsum, torch.cumsum, dim=1) self._make_acc_op_function_test( acc_ops.cumsum, torch.cumsum, dim=1, dtype=torch.float ) + # Tests call_method version + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return a.cumsum(dim=0) + + m = TestModule() + a = torch.rand(2, 2) + gm = acc_tracer.trace(m, [a]) + self.assertTrue(torch.equal(m(a), gm(a))) + def test_chunk(self): self._make_acc_op_function_test(acc_ops.chunk, torch.chunk, chunks=2, dim=0) @@ -2659,6 +2706,7 @@ def test_all_acc_ops_registered(self): acc_ops.trunc_div, acc_ops.pow, acc_ops.relu, + acc_ops.prelu, acc_ops.leaky_relu, acc_ops.elu, acc_ops.selu, diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index bb515252ed..39a1a46a34 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -3,6 +3,8 @@ import torch import torch._dynamo as torchdynamo + +import torch._dynamo.config import torchvision from functorch.experimental import functionalize from torch._dynamo.optimizations import backends @@ -13,6 +15,9 @@ from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace +# TODO(ezyang): remove this after we properly support fake example inputs +torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True + torch.manual_seed(0) wrap_lib = Library("wrap", "DEF") @@ -109,9 +114,6 @@ def test_resnet18_aten(self): mod = torchvision.models.resnet18() mod = mod.cuda().half().eval() - def f(x): - return mod(x) - inputs = [torch.ones(32, 3, 224, 224)] inputs = [i.cuda().half() for i in inputs] @@ -148,9 +150,10 @@ def f(x): # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) # so we choose to use cosine similarity - cos = torch.nn.CosineSimilarity(dim=0, eps=1e-4) - cos_val = cos(aten_output.flatten(), fx_output.flatten()) - self.assertTrue(cos_val.cpu().numpy() > 0.999) + cos_val = torch.nn.functional.cosine_similarity( + aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 + ) + self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) def test_resnet18_dynamo(self): mod = torchvision.models.resnet18() diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 2aa82b1d72..bd22e8bb4e 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -1,16 +1,30 @@ import logging import time import unittest -from typing import Callable, List, Tuple +from typing import Callable, List, Optional, Set, Tuple import torch import torch.fx import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer from torch.fx.experimental.normalize import NormalizeArgs from torch.fx.passes import shape_prop +from torch.fx.passes.infra.pass_base import PassResult from torch.testing._internal.common_utils import TestCase from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( + compose_bmm, + compose_chunk, + compose_getitem_slice, + remove_ops, + replace_aten_op_with_indices, + replace_aten_reshape_alias_with_replace, + replace_builtin_ops, + replace_native_layernorm_with_layernorm, + replace_transpose_mm_op_with_linear, + run_const_fold, +) from torch_tensorrt.fx.passes.pass_utils import chain_passes from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace @@ -88,9 +102,15 @@ def run_test( f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" ) - if isinstance(outputs, torch.Tensor): - ref_outputs = [ref_outputs] + if type(outputs) not in (list, tuple): outputs = [outputs] + if type(ref_outputs) not in ( + list, + tuple, + torch.return_types.max, + torch.return_types.min, + ): + ref_outputs = [ref_outputs] for out, ref in zip(outputs, ref_outputs): if not isinstance(ref, torch.Tensor): ref = torch.tensor([ref]) @@ -309,6 +329,47 @@ def run_test_with_dynamic_shape( class DispatchTestCase(TRTTestCase): + def generate_graph( + self, + mod: torch.nn.Module, + original_inputs: List[torch.Tensor], + expected_ops: Set[Callable], + unexpected_ops: Optional[Set[Callable]] = None, + customized_passes: List[Callable] = None, + ): + # Torchdynamo+aot proxytensor tracer + # Below are common passes + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + ] + # Combine with customized passes specific to any model + if customized_passes: + passes_list.extend(customized_passes) + + fx_module, _ = aten_tracer.trace(mod, original_inputs) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + fx_module(*original_inputs) + + fx_module = run_const_fold(fx_module) + print(fx_module.graph) + + if len(expected_ops): + self.assert_has_op(fx_module, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(fx_module, unexpected_ops) + + return fx_module + def run_test( self, mod, @@ -317,27 +378,18 @@ def run_test( unexpected_ops=None, apply_passes=None, test_explicit_batch_dim=True, - test_implicit_batch_dim=True, test_explicit_precision=False, rtol=1e-03, atol=1e-03, precision=LowerPrecision.FP32, ): - mod = proxytensor_trace(mod, inputs) + mod.eval() + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) if apply_passes is not None: pass_tracer = chain_passes(*apply_passes) mod = pass_tracer(mod, inputs) - if test_implicit_batch_dim: - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision - ) - if test_explicit_batch_dim: interp = TRTInterpreter( mod, diff --git a/py/torch_tensorrt/fx/tools/trt_minimizer.py b/py/torch_tensorrt/fx/tools/trt_minimizer.py index d396453e22..f44a5e1d25 100644 --- a/py/torch_tensorrt/fx/tools/trt_minimizer.py +++ b/py/torch_tensorrt/fx/tools/trt_minimizer.py @@ -22,8 +22,9 @@ def lower_mod_default( interpreter_result = interp.run(max_batch_size=batch_size) if use_experimental_rt: import io - from torch_tensorrt._TRTModuleNext import TRTModuleNext + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext with io.BytesIO() as engine_bytes: engine_bytes.write(interpreter_result.engine.serialize()) diff --git a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py index ac0a02ac1d..48293773c4 100644 --- a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py +++ b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py @@ -3,9 +3,10 @@ import operator from typing import List, Mapping, Optional -import tensorrt as trt import torch +from tensorrt import tensorrt as trt + from .. import TRTModule _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index a9b692ead3..bea925453f 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -92,8 +92,9 @@ def _lower_model_to_backend( interpreter_result = interp.run(*inputs) if self.settings.use_experimental_rt: import io - from torch_tensorrt._TRTModuleNext import TRTModuleNext + from torch_tensorrt._Device import Device + from torch_tensorrt._TRTModuleNext import TRTModuleNext with io.BytesIO() as engine_bytes: engine_bytes.write(interpreter_result.engine.serialize()) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 86b05b0d9e..bf693114fc 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -262,7 +262,7 @@ def custom_type_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: input_obj = node.kwargs["input"] dtype_obj = node.kwargs.get("dtype") with node.graph.inserting_before(node): - if dtype_obj == None: + if dtype_obj is None: dtype_node = node.graph.call_function(dtype, kwargs={"input": input_obj}) dtype_node.meta["type"] = torch.dtype return dtype_node @@ -787,6 +787,13 @@ def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: ("mat2", "other"), ], ) +@register_acc_op_mapping( + op_and_target=("call_method", "matmul"), + arg_replacement_tuples=[ + ("input", "input"), + ("mat2", "other"), + ], +) @register_acc_op_mapping( op_and_target=("call_function", operator.matmul), arg_replacement_tuples=[ @@ -1163,6 +1170,16 @@ def leaky_relu(*, input, negative_slope=0.01, inplace=False): ) +@register_acc_op_properties(AccOpProperty.pointwise) +@register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.prelu)) +@register_acc_op +def prelu(*, input, weight): + return nn.functional.prelu( + input=input, + weight=weight, + ) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.elu)) @register_acc_op @@ -2838,7 +2855,14 @@ def gelu(*, input, approximate="none"): @register_acc_op_properties(AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.cumsum)) -@register_acc_op_mapping(op_and_target=("call_method", "cumsum")) +@register_acc_op_mapping( + op_and_target=("call_method", "cumsum"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("dtype", "dtype", this_arg_is_optional), + ], +) @register_acc_op def cumsum(*, input, dim, dtype=None): return torch.cumsum(input=input, dim=dim, dtype=dtype) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py index 4bc4fc0063..96411246f0 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py @@ -32,19 +32,15 @@ class AccShapeProp(shape_prop.ShapeProp): """ def _run_node(self, n: torch.fx.Node) -> Any: - # Run embedding bag ops with XL weights in a customized way, see - # docstring for self.run_embedding_bag for more details - if ( - n.target - in { - acc_ops.embedding_bag, - acc_ops.embedding_bag_4bit_rowwise_offsets, - acc_ops.embedding_bag_byte_rowwise_offsets, - } - and n.kwargs["weight"].target == acc_ops.xl_weight + # Run ops with XL weights by clamping their inputs, see + # docstring for self.run_node_with_xl_weights for more details + if any( + isinstance(kwarg, torch.fx.Node) and kwarg.target == acc_ops.xl_weight + for kwarg in n.kwargs.values() ): - return self.run_embedding_bag(n) - return super().run_node(n) + return self.run_node_with_xl_weights(n) + else: + return super().run_node(n) def run_node(self, n: torch.fx.Node) -> Any: # First try running shape_prop with the original inputs. @@ -79,7 +75,7 @@ def run_node(self, n: torch.fx.Node) -> Any: return result - def run_embedding_bag(self, n: torch.fx.Node) -> Any: + def run_node_with_xl_weights(self, n: torch.fx.Node) -> Any: """ EmbeddingBag with XL Weights of shape (num_embeddings, embedding_dim) are replaced with smaller proxies of shape @@ -87,21 +83,48 @@ def run_embedding_bag(self, n: torch.fx.Node) -> Any: cause index out of bounds issues when sample inputs lead to the embedding bag op indexing into the first dimension of the weight tensor which it expects to be bigger than it is during tracing. - """ - if n.target == acc_ops.embedding_bag: - indices = n.kwargs["input"] - else: - indices = n.kwargs["indices"] - # Replace indices with zeros of same shape and dtype - indices_tensor = self.env[indices] - indices_zeros = torch.zeros_like(indices_tensor, dtype=indices_tensor.dtype) - self.env[indices] = indices_zeros + For these ops, return a zeros tensor of the correct shape and dtype. - # Run node - result = super().run_node(n) + # TODO(T137066700): migrate shape inference to OSS and use it here to + determine shape/dtype of output tensor. This will enable all ops to use + xl_weights instead of just the ones treated here. + """ + + op = n.target.__module__ + "." + n.target.__name__ + + if op.endswith("acc_ops.int_nbit_split_embedding_codegen_lookup_function"): + output_dtype_int = n.kwargs["output_dtype"] + assert output_dtype_int < 2, "only support float16 and float32" + output_dtype = torch.float if output_dtype_int == 0 else torch.float16 + total_D = n.kwargs["total_D"] + + D_offsets_shape = self.env[n.kwargs["D_offsets"]].shape + offsets_shape = self.env[n.kwargs["offsets"]].shape + batches = (offsets_shape[0] - 1) // (D_offsets_shape[0] - 1) + result = torch.zeros((batches, total_D), dtype=output_dtype) + + elif op.find("acc_ops.embedding_bag"): + weight = self.env[n.kwargs["weight"]] + offsets_shape = self.env[n.kwargs["offsets"]].shape + batches = offsets_shape[0] - int(n.kwargs["include_last_offset"]) + output_dtype = weight.dtype + + embedding_size = weight.shape[1] + if op.endswith("acc_ops.embedding_bag_byte_rowwise_offsets"): + embedding_size -= 8 + # output dtype is hardcoded in https://fburl.com/code/unc4l6lj + output_dtype = torch.float32 + elif op.endswith("acc_ops.embedding_bag_4bit_rowwise_offsets"): + embedding_size = (embedding_size - 4) * 2 + # output dtype is hardcoded in https://fburl.com/code/434rkdtk + output_dtype = torch.float32 + + result = torch.zeros((batches, embedding_size), dtype=output_dtype) - # Restore indices - self.env[indices] = indices_tensor + else: + raise NotImplementedError( + f"The op {op} cannot be run with xl_weight(s) inputs" + ) return result diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index af7e27aa09..c3a5ad850e 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -3,12 +3,14 @@ import copy import inspect import logging +import operator import textwrap import warnings from types import FunctionType from typing import ( Any, Callable, + cast, Dict, Iterable, Optional, @@ -507,6 +509,78 @@ def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule): del node.meta["tensor_meta"] +def _replace_transpose_last_dims_impl( + transpose_node: torch.fx.Node, +) -> int: + transpose_input_node = transpose_node.args[0] + dim0 = cast(int, transpose_node.args[1]) + dim1 = cast(int, transpose_node.args[2]) + changed = False + + def _calculate_dim( + transpose_dim: Union[torch.fx.Node, int] + ) -> Union[torch.fx.Node, int]: + nonlocal transpose_input_node + nonlocal changed + if isinstance(transpose_dim, torch.fx.Node): + # Transpose dim is sub node + if not ( + transpose_dim.op == "call_function" + and transpose_dim.target == operator.sub + and len(transpose_dim.args) == 2 + ): + return transpose_dim + # Validity of length/subtracted int + len_node = transpose_dim.args[0] + sub_value = transpose_dim.args[1] + if not ( + isinstance(len_node, torch.fx.Node) + and len_node.target == len + and isinstance(sub_value, int) + ): + return transpose_dim + getattr_node = len_node.args[0] + # Check nodes for input.shape + if not ( + isinstance(getattr_node, torch.fx.Node) + and getattr_node.target == getattr + and len(getattr_node.args) == 2 + and getattr_node.args[0] == transpose_input_node + and getattr_node.args[1] == "shape" + ): + return transpose_dim + changed = True + rank = transpose_input_node.meta["tensor_rank"] + return rank - sub_value + return transpose_dim + + dim0 = _calculate_dim(dim0) + dim1 = _calculate_dim(dim1) + if changed: + with transpose_node.graph.inserting_before(transpose_node): + new_transpose_node = transpose_node.graph.call_method( + "transpose", (transpose_input_node, dim0, dim1) + ) + new_transpose_node.meta = transpose_node.meta.copy() + transpose_node.replace_all_uses_with(new_transpose_node) + return changed + + +# Allows mapping for transpose in the case where inputs are of the form x.transpose(a, b), +# where a and b are len(x.shape()) - n, where n is an int. In this case the inputs to transpose +# would be nodes rather than ints, so this replaces those nodes with their integral values +def _replace_transpose_last_dims(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if node.op == "call_method" and node.target == "transpose": + if len(node.args) != 3: + continue + changed = _replace_transpose_last_dims_impl(node) + if changed: + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list): rewritten_graph, rewritten_mod = AccRewritingTracer().trace( mod, @@ -608,6 +682,9 @@ def trace( # Swap out tensor_meta for tensor_rank, because we don't actually want to rely on # tensor_meta yet for normalization/lowering, though rank shouldn't change. _replace_tensor_meta_with_rank(traced) + # Replace occurrences of x.transpose(len(x.shape) - a, len(x.shape) - b), where + # a and b are integers with their directly calculated dimensions + _replace_transpose_last_dims(traced) # Now normalize args/kwargs to make default values visible. Leave args/kwargs as # they were, since all-kwarg normalization is broken, and we don't need it anyway. traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py new file mode 100644 index 0000000000..5d81dec6b0 --- /dev/null +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -0,0 +1,114 @@ +import copy +import sys +from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union + +import torch +import torch._dynamo as torchdynamo +from torch._dynamo.guards import Guard +from typing_extensions import TypeAlias + +Value: TypeAlias = Union[ + Tuple["Value", ...], + List["Value"], + Dict[str, "Value"], +] + + +class DynamoConfig: + """ + Manage Exir-specific configurations of Dynamo. + """ + + def __init__( + self, + capture_scalar_outputs: bool = True, + guard_nn_modules: bool = True, + dynamic_shapes: bool = True, + specialize_int_float: bool = True, + verbose: bool = True, + ) -> None: + + self.capture_scalar_outputs = capture_scalar_outputs + self.guard_nn_modules = guard_nn_modules + self.dynamic_shapes = dynamic_shapes + self.specialize_int_float = specialize_int_float + self.verbose = verbose + + def activate(self) -> None: + torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs + torchdynamo.config.guard_nn_modules = self.guard_nn_modules + torchdynamo.config.dynamic_shapes = self.dynamic_shapes + torchdynamo.config.specialize_int_float = self.specialize_int_float + torchdynamo.config.verbose = self.verbose + + def deactivate(self) -> None: + torchdynamo.config.capture_scalar_outputs = True + torchdynamo.config.guard_nn_modules = True + torchdynamo.config.dynamic_shapes = True + torchdynamo.config.specialize_int_float = True + torchdynamo.config.verbose = True + + +@contextmanager +def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]: + config.activate() + try: + yield config + finally: + config.deactivate() + + +@contextmanager +def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]: + """ + Temporarily increase the python interpreter stack recursion limit. + This is mostly used for pickling large scale modules. + """ + default = sys.getrecursionlimit() + if limit > default: + sys.setrecursionlimit(limit) + try: + yield + finally: + sys.setrecursionlimit(default) + + +def dynamo_trace( + f: Callable[..., Value], + # pyre-ignore + args: Tuple[Any, ...], + aten_graph: bool, + tracing_mode: str = "real", + dynamo_config: Optional[DynamoConfig] = None, +) -> Tuple[torch.fx.GraphModule, Set[Guard]]: + """ + TODO: Once we fully migrate to torchdynamo frontend, we will remove + this config option alltogether. For now, it helps with quick + experiments with playing around with TorchDynamo + """ + if dynamo_config is None: + dynamo_config = DynamoConfig() + with using_config(dynamo_config), setting_python_recursive_limit(2000): + torchdynamo.reset() + try: + return torchdynamo.export( + f, + *copy.deepcopy(args), + aten_graph=aten_graph, + tracing_mode=tracing_mode, + ) + except torchdynamo.exc.Unsupported as exc: + raise RuntimeError( + "The user code is using a feature we don't support. " + "Please try torchdynamo.explain() to get possible the reasons", + ) from exc + except Exception as exc: + raise RuntimeError( + "torchdynamo internal error occured. Please see above stacktrace" + ) from exc + + +def trace(f, args, *rest): + graph_module, guards = dynamo_trace(f, args, True, "symbolic") + return graph_module, guards