diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 52ca551142..c697979a43 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -3,9 +3,11 @@
 import torch_tensorrt.ts
 from torch_tensorrt import logging
 import torch
-from torch import fx
+import torch.fx
 from enum import Enum
-from torch_tensorrt import fx
+import torch_tensorrt.fx
+from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.utils import LowerPrecision

 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     elif target_ir == _IRType.fx:
-        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
-        from torch_tensorrt.fx import InputTensorSpec
-        from torch_tensorrt.fx import TRTInterpreter
-        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
-        from torch_tensorrt.fx.trt_module import TRTModule
-        from torch_tensorrt.fx.utils import LowerPrecision
-        acc_model = acc_tracer.trace(module, inputs)
-
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is large than the \
-                threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
-            precision = LowerPrecision.FP16
+        if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions:
+            lower_precision = LowerPrecision.FP16
+        elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions:
+            lower_precision = LowerPrecision.FP32
         else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
-                )
-                # For profile
-                # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-        return split_mod
+            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+
+        return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
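For context, a rough sketch (not part of the patch) of how the reworked FX path above is exercised from the user-facing API. It mirrors the usage in torch_trt_simple_example.py added later in this diff; the model and input shape are illustrative only.

import torch
import torchvision
import torch_tensorrt

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.ones((32, 3, 224, 224), device="cuda")]

# ir="fx" now dispatches to torch_tensorrt.fx.lower.lower_to_trt, with the
# requested precision mapped from enabled_precisions to a LowerPrecision value.
trt_fx_module = torch_tensorrt.compile(
    model, ir="fx", inputs=inputs, enabled_precisions={torch.float32}
)
result = trt_fx_module(*inputs)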
diff --git a/py/torch_tensorrt/fx/example/fx2trt_example.py b/py/torch_tensorrt/fx/example/fx2trt_example.py
index 8c648ec065..b9fbc05f17 100644
--- a/py/torch_tensorrt/fx/example/fx2trt_example.py
+++ b/py/torch_tensorrt/fx/example/fx2trt_example.py
@@ -3,11 +3,11 @@
 import torch
 import torch.fx
 import torch.nn as nn
+from torch_tensorrt.fx.utils import LowerPrecision

 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-
 # The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
 # model to TensorRT via FX with existing FX based tooling. The general lowering flow
 # would be like:
@@ -30,11 +30,12 @@ def forward(self, x):
         x = self.linear(x)
         x = self.relu(x)
         x = torch.linalg.norm(x, ord=2, dim=1)
+        x = self.relu(x)
         return x


-inputs = [torch.randn(1, 10)]
-model = Model().eval()
+inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
+model = Model().cuda().eval()

 # acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
 # to acc ops.
@@ -64,20 +65,23 @@ def forward(self, x):
 # Split.
 split_mod = splitter()

-# After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
+# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
 print(split_mod.graph)
 """
 graph():
     %x : [#users=1] = placeholder[target=x]
     %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
     %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
-    return _run_on_gpu_1
+    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
+    return _run_on_acc_2
 """

 # Take a look at what inside each submodule. _run_on_acc_0 contains linear and relu while
-# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
+# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt. _run_on_acc_2
+# is another supported submodule.
 print(split_mod._run_on_acc_0.graph)
 print(split_mod._run_on_gpu_1.graph)
+print(split_mod._run_on_acc_2.graph)
 """
 graph():
     %x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
     %relu_1 : [#users=1] = placeholder[target=relu_1]
     %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
     return linalg_norm_1
+graph():
+    %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
+    %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
+    return relu_3
 """

-# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
-# we can skip the splitter part.
-interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
-r = interp.run()
-trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
-split_mod._run_on_acc_0 = trt_mod
-
-cuda_inputs = [input.cuda() for input in inputs]
-split_mod.cuda()
-lowered_model_output = split_mod(*cuda_inputs)
+def get_submod_inputs(mod, submod, inputs):
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
+
+# Since the model is split into three segments, we need to lower each TRT-eligible segment.
+# If we know the model can be fully lowered, we can skip the splitter part.
+for name, _ in split_mod.named_children():
+    if "_run_on_acc" in name:
+        submod = getattr(split_mod, name)
+        # Get submodule inputs for fx2trt
+        acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+        # fx2trt replacement
+        interp = TRTInterpreter(
+            submod,
+            InputTensorSpec.from_tensors(acc_inputs),
+            explicit_batch_dimension=True,
+        )
+        r = interp.run(lower_precision=LowerPrecision.FP32)
+        trt_mod = TRTModule(*r)
+        setattr(split_mod, name, trt_mod)
+
+lowered_model_output = split_mod(*inputs)
+
+# Save and load model
+torch.save(split_mod, "trt.pt")
+reload_trt_mod = torch.load("trt.pt")
+reload_model_output = reload_trt_mod(*inputs)

 # Make sure the results match
-model.cuda()
-regular_model_output = model(*cuda_inputs)
+regular_model_output = model(*inputs)
 torch.testing.assert_close(
-    lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2
+    reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
 )
-
-# We can utilize the trt profiler to print out the time spend on each layer.
-trt_mod.enable_profiling()
-trt_mod(*cuda_inputs)
-"""
-Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
-LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
-PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
-"""
-trt_mod.disable_profiling()
diff --git a/py/torch_tensorrt/fx/example/lower_example.py b/py/torch_tensorrt/fx/example/lower_example.py
index b93e93598e..71f15a2f88 100644
--- a/py/torch_tensorrt/fx/example/lower_example.py
+++ b/py/torch_tensorrt/fx/example/lower_example.py
@@ -198,6 +198,6 @@ def run_configuration_benchmark(


 if __name__ == "__main__":
-    test_model = torchvision.models.resnet101()
-    input = [torch.cuda.FloatTensor(1024, 3, 224, 224)]  # type: ignore[attr-defined]
-    benchmark(test_model, input, 100, 1024)
+    test_model = torchvision.models.resnet18(pretrained=True)
+    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 50, 128)
diff --git a/py/torch_tensorrt/fx/example/test_fx2trt.py b/py/torch_tensorrt/fx/example/test_fx2trt.py
deleted file mode 100644
index effc188e7a..0000000000
--- a/py/torch_tensorrt/fx/example/test_fx2trt.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import torch
-import torch_tensorrt
-
-
-class MyModel(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear = torch.nn.Linear(5, 3)
-        self.relu = torch.nn.functional.relu
-
-    def forward(self, x):
-        x = self.linear(x)
-        x = self.relu(x)
-        return x
-
-
-model = MyModel().eval()  # torch module needs to be in eval (not training) mode
-
-# torch tensorrt
-inputs = [
-    torch_tensorrt.Input(
-        (2, 5),
-        dtype=torch.half,
-    )
-]
-enabled_precisions = {torch.float, torch.half}  # Run with fp16
-
-trt_ts_module = torch_tensorrt.compile(
-    model, inputs=inputs, enabled_precisions=enabled_precisions
-)
-
-inputs_ts = [torch.ones(2, 5)]
-inputs_ts = [i.cuda().half() for i in inputs_ts]
-result = trt_ts_module(*inputs_ts)
-print(result)
-
-model.cuda().half()
-ref = model(*inputs_ts)
-print(ref)
-
-# fx2trt
-inputs_fx = [torch.ones((2, 5))]
-
-model.cuda().half()
-inputs_fx = [i.cuda().half() for i in inputs_fx]
-
-trt_fx_module = torch_tensorrt.compile(
-    model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half}
-)
-result = trt_fx_module(*inputs_fx)
-print(result)
-
-ref = model(*inputs_fx)
-print(ref)
diff --git a/py/torch_tensorrt/fx/example/torch_trt_simple_example.py b/py/torch_tensorrt/fx/example/torch_trt_simple_example.py
new file mode 100644
index 0000000000..a6dd732c84
--- /dev/null
+++ b/py/torch_tensorrt/fx/example/torch_trt_simple_example.py
@@ -0,0 +1,57 @@
+import torch
+import copy
+import torchvision
+import torch_tensorrt
+from torch_tensorrt.fx import InputTensorSpec
+
+
+def test_torch_tensorrt(model, inputs):
+    # torchscript path
+    model_ts = copy.deepcopy(model)
+    inputs_ts = copy.deepcopy(inputs)
+    # fp32 test
+    with torch.inference_mode():
+        ref_fp32 = model_ts(*inputs_ts)
+    trt_ts_module = torch_tensorrt.compile(
+        model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
+    )
+    result_fp32 = trt_ts_module(*inputs_ts)
+    assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
+    # fp16 test
+    model_ts = model_ts.half()
+    inputs_ts = [i.cuda().half() for i in inputs_ts]
+    with torch.inference_mode():
+        ref_fp16 = model_ts(*inputs_ts)
+    trt_ts_module = torch_tensorrt.compile(
+        model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
+    )
+    result_fp16 = trt_ts_module(*inputs_ts)
+    assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99)
+
+    # FX path
+    model_fx = copy.deepcopy(model)
+    inputs_fx = copy.deepcopy(inputs)
+    # fp32 test
+    with torch.inference_mode():
+        ref_fp32 = model_fx(*inputs_fx)
+    trt_fx_module = torch_tensorrt.compile(
+        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
+    )
+    result_fp32 = trt_fx_module(*inputs_fx)
+    assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
+    # fp16 test
+    model_fx = model_fx.cuda().half()
+    inputs_fx = [i.cuda().half() for i in inputs_fx]
+    with torch.inference_mode():
+        ref_fp16 = model_fx(*inputs_fx)
+    trt_fx_module = torch_tensorrt.compile(
+        model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
+    )
+    result_fp16 = trt_fx_module(*inputs_fx)
+    assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99)
+
+
+if __name__ == "__main__":
+    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
+    inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))]  # type: ignore[attr-defined]
+    test_torch_tensorrt(model, inputs)
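As a reference for the docstring added to TRTInterpreter.run below, a minimal sketch (not part of the patch) of how those build options are typically passed. Here submod and acc_inputs are assumed to come from an acc_tracer/TRTSplitter flow like the one in fx2trt_example.py above, and the values are illustrative.

from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

interp = TRTInterpreter(
    submod,                                    # an acc-traced TRT-eligible submodule (assumed)
    InputTensorSpec.from_tensors(acc_inputs),  # specs derived from captured submodule inputs
    explicit_batch_dimension=True,
)
result = interp.run(
    max_workspace_size=20 << 30,               # temporary buffer budget for the builder
    lower_precision=LowerPrecision.FP16,       # let TensorRT pick kernels at this precision
    sparse_weights=False,
)
trt_mod = TRTModule(*result)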
diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 16da30575f..29b1490586 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -164,6 +164,21 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
     ) -> TRTInterpreterResult:
+        """
+        Build the TensorRT engine with the given build-time configs.
+        Args:
+            max_batch_size: set to the maximum batch size you will use.
+            max_workspace_size: set to the maximum size we can afford for temporary buffers.
+            lower_precision: the precision model layers run at (TensorRT will choose the best performance precision).
+            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity.
+            force_fp32_output: force the output to be fp32.
+            strict_type_constraints: usually we should set it to False unless we want to control the precision of a certain layer for numeric reasons.
+            algorithm_selector: set up algorithm selection for certain layers.
+            timing_cache: enable the timing cache for TensorRT.
+            profiling_verbosity: TensorRT logging level.
+        Returns:
+            TRTInterpreterResult
+        """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 10b56f31b4..763ffdc653 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -42,6 +42,7 @@ def lower_to_trt(
     timing_cache_prefix="",
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
+    dynamic_batch=False,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -71,6 +72,7 @@ def lower_to_trt(
         timing_cache_prefix=timing_cache_prefix,
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
+        dynamic_batch=dynamic_batch,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, input)
@@ -102,11 +104,10 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
                    ),
                    self.lower_setting.opt_profile_replica,
                )
-                if self.lower_setting.explicit_batch_dimension
+                if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch
                else InputTensorSpec.from_tensors(input)
            )
        )
-
        # Prepare algorithm selector and timing_cache for TRTInterpreter
        algo_selector = None
        if self.lower_setting.algo_selector:
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index 78d4e3a2e9..6695c8ff85 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -71,6 +71,7 @@ class LowerSetting(LowerSettingBasic):
     to add a preset.
        opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field
        is only used by explicit batch dim with dynamic shape mode.
+        dynamic_batch (bool): enable dynamic shape in TRT by setting the first (batch) dimension to -1.
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -89,3 +90,4 @@ class LowerSetting(LowerSettingBasic):
     cuda_graph_batch_size: int = -1
     preset_lowerer: str = ""
     opt_profile_replica: int = 1
+    dynamic_batch: bool = False
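Finally, a short sketch (not part of the patch) of how the new dynamic_batch flag might be used directly through lower_to_trt. The model, shapes, and precision are illustrative; per the changed condition in lower.py above, the flag only takes effect together with explicit_batch_dimension.

import torch
import torchvision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.randn(32, 3, 224, 224, device="cuda")]

trt_mod = lower_to_trt(
    model,
    inputs,
    max_batch_size=inputs[0].size(0),
    explicit_batch_dimension=True,        # required for dynamic_batch to apply
    lower_precision=LowerPrecision.FP16,
    dynamic_batch=True,                   # treat the first (batch) dimension as dynamic
)
out = trt_mod(*inputs)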