[FX] refactor the fx path in compile function #1141

Merged Jun 28, 2022 · 14 commits
84 changes: 11 additions & 73 deletions py/torch_tensorrt/_compile.py
@@ -3,9 +3,11 @@
 import torch_tensorrt.ts
 from torch_tensorrt import logging
 import torch
-from torch import fx
+import torch.fx
 from enum import Enum
-from torch_tensorrt import fx
+import torch_tensorrt.fx
+from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.utils import LowerPrecision

 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     elif target_ir == _IRType.fx:
-        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
-        from torch_tensorrt.fx import InputTensorSpec
-        from torch_tensorrt.fx import TRTInterpreter
-        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
-        from torch_tensorrt.fx.trt_module import TRTModule
-        from torch_tensorrt.fx.utils import LowerPrecision
-        acc_model = acc_tracer.trace(module, inputs)
-
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is large than the \
-threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
-            precision = LowerPrecision.FP16
+        if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions:
+            lower_precision = LowerPrecision.FP16
+        elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions:
+            lower_precision = LowerPrecision.FP32
         else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
-                )
-                # For profile
-                # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-        return split_mod
+            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+
+        return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
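Taken together, the refactor replaces the inline tracer/splitter/interpreter pipeline with a single call into lower_to_trt. A minimal sketch of how the refactored path is exercised from the public API, mirroring the new example file added below (the resnet18 model and the input shape are illustrative, not part of this diff):

import torch
import torchvision
import torch_tensorrt

model = torchvision.models.resnet18().cuda().eval()
inputs = [torch.randn(8, 3, 224, 224, device="cuda")]

# enabled_precisions is mapped onto LowerPrecision internally:
#   torch.float16 / torch_tensorrt.dtype.half  -> LowerPrecision.FP16
#   torch.float32 / torch_tensorrt.dtype.float -> LowerPrecision.FP32
trt_mod = torch_tensorrt.compile(model, ir="fx", inputs=inputs, enabled_precisions={torch.float32})
out = trt_mod(*inputs)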

6 changes: 3 additions & 3 deletions py/torch_tensorrt/fx/example/lower_example.py
@@ -198,6 +198,6 @@ def run_configuration_benchmark(


 if __name__ == "__main__":
-    test_model = torchvision.models.resnet101()
-    input = [torch.cuda.FloatTensor(1024, 3, 224, 224)]  # type: ignore[attr-defined]
-    benchmark(test_model, input, 100, 1024)
+    test_model = torchvision.models.resnet18(pretrained=True)
+    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 50, 128)
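For orientation, a minimal sketch of the kind of timing loop such a benchmark entry point drives (the real helper lives in lower_example.py; the function below and its names are illustrative, assuming lower_to_trt from torch_tensorrt.fx.lower):

import time
import torch
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

def simple_benchmark(model, inputs, iters=50):
    # Lower once, then time repeated inference on the resulting TRT module.
    model = model.cuda().eval()
    inputs = [t.cuda() for t in inputs]
    trt_mod = lower_to_trt(model, inputs, lower_precision=LowerPrecision.FP32,
                           max_batch_size=inputs[0].size(0))
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.inference_mode():
        for _ in range(iters):
            trt_mod(*inputs)
    torch.cuda.synchronize()
    print(f"avg latency: {(time.perf_counter() - start) / iters * 1e3:.2f} ms")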
54 changes: 0 additions & 54 deletions py/torch_tensorrt/fx/example/test_fx2trt.py

This file was deleted.

57 changes: 57 additions & 0 deletions py/torch_tensorrt/fx/example/torch_trt_simple_example.py
@@ -0,0 +1,57 @@
import torch
import copy
import torchvision
import torch_tensorrt
from torch_tensorrt.fx import InputTensorSpec


def test_torch_tensorrt(model, inputs):
    # torchscript path
    model_ts = copy.deepcopy(model)
    inputs_ts = copy.deepcopy(inputs)
    # fp32 test
    with torch.inference_mode():
        ref_fp32 = model_ts(*inputs_ts)
        trt_ts_module = torch_tensorrt.compile(
            model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
        )
        result_fp32 = trt_ts_module(*inputs_ts)
        assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999)
    # fp16 test
    model_ts = model_ts.half()
    inputs_ts = [i.cuda().half() for i in inputs_ts]
    with torch.inference_mode():
        ref_fp16 = model_ts(*inputs_ts)
        trt_ts_module = torch_tensorrt.compile(
            model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
        )
        result_fp16 = trt_ts_module(*inputs_ts)
        assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99)

    # FX path
    model_fx = copy.deepcopy(model)
    inputs_fx = copy.deepcopy(inputs)
    # fp32 test
    with torch.inference_mode():
        ref_fp32 = model_fx(*inputs_fx)
        trt_fx_module = torch_tensorrt.compile(
            model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
        )
        result_fp32 = trt_fx_module(*inputs_fx)
        assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0) > 0.9999)
    # fp16 test
    model_fx = model_fx.cuda().half()
    inputs_fx = [i.cuda().half() for i in inputs_fx]
    with torch.inference_mode():
        ref_fp16 = model_fx(*inputs_fx)
        trt_fx_module = torch_tensorrt.compile(
            model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
        )
        result_fp16 = trt_fx_module(*inputs_fx)
        assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0) > 0.99)


if __name__ == "__main__":
    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
    inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))]  # type: ignore[attr-defined]
    test_torch_tensorrt(model, inputs)
6 changes: 4 additions & 2 deletions py/torch_tensorrt/fx/lower.py
@@ -42,6 +42,7 @@ def lower_to_trt(
     timing_cache_prefix="",
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
+    dynamic_batch=True,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -71,6 +72,7 @@
         timing_cache_prefix=timing_cache_prefix,
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
+        dynamic_batch=dynamic_batch,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, input)
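With the flag threaded through lower_to_trt into LowerSetting, a caller who wants a fixed-shape engine can now opt out of the batch-dim optimization profile. A sketch under the semantics discussed in the review thread below (model and input values are illustrative; the call matches the one added to _compile.py above):

import torch
import torchvision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

model = torchvision.models.resnet18().cuda().eval()
inputs = [torch.randn(8, 3, 224, 224, device="cuda")]

# dynamic_batch=True (the default) keeps the previous production behavior:
# the engine is built with an optimization profile over the batch dim.
# dynamic_batch=False freezes the sample inputs' shapes into the engine.
trt_mod = lower_to_trt(
    model,
    inputs,
    lower_precision=LowerPrecision.FP32,
    max_batch_size=inputs[0].size(0),
    explicit_batch_dimension=True,
    dynamic_batch=False,
)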
@@ -100,12 +102,12 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
                         self.lower_setting.max_batch_size,
                         self.lower_setting.max_batch_size,
                     ),
+                    self.lower_setting.opt_profile_replica,
                 )
-                if self.lower_setting.explicit_batch_dimension
+                if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch
frank-wei (Contributor, Author) commented:

    dynamic_batch is added to differentiate two cases: with or without dynamic shape on the batch dim (dim=0). cc @wushirong. I keep dynamic_batch=True as the default value so it will not change the previous behavior in production. Please have a review.

A reviewer commented:

    What's the use case of no dynamic shape on the batch dim with explicit_batch_dimension=True? What's the difference in terms of behavior in TensorRT? Basically, if I have explicit_batch_dimension=True while all my input dims are positive, how does TRT interpret it?

    Maybe a question for @narendasan too.

frank-wei (Contributor, Author) commented on Jun 27, 2022:

    Here is my understanding: firstly, explicit_batch_dimension=True will become the default next year, and there will be no explicit_batch_dimension=False (implicit) mode. Secondly, as to "if I have explicit_batch_dimension=True while all my input dims are positive, how does TRT interpret it?": TRT will treat it as a fixed shape for any future input, and that is what I tested for all the torchdynamo benchmarks.

                 else InputTensorSpec.from_tensors(input)
             )
         )

         # Prepare algorithm selector and timing_cache for TRTInterpreter
         algo_selector = None
         if self.lower_setting.algo_selector:
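The condition above is the crux of the change; a sketch of the two spec-construction paths it selects between (InputTensorSpec methods as used in this diff; the exact (min, opt, max) tuple layout and sample values are an assumption):

import torch
from torch_tensorrt.fx import InputTensorSpec

inputs = [torch.randn(8, 3, 224, 224)]  # sample lowering inputs (illustrative)
max_batch_size, opt_profile_replica = 32, 1

# dynamic_batch=True: dim 0 ranges over a (min, opt, max) profile, so TRT
# builds optimization profile(s) and the engine accepts varying batch sizes.
specs_dynamic = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    inputs,
    (1, max_batch_size, max_batch_size),
    opt_profile_replica,
)

# dynamic_batch=False: every dim, including batch, is taken from the sample
# inputs' shapes; TRT treats them as fixed shapes for all future inputs.
specs_fixed = InputTensorSpec.from_tensors(inputs)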
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -69,6 +69,8 @@ class LowerSetting(LowerSettingBasic):
     how presets are applied. Refer to
     `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
     to add a preset.
+    opt_profile_replica (int): the number of opt profiles set for the TensorRT engine; this field
+    is only used in explicit-batch-dim mode with dynamic shape.
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)

@@ -86,3 +88,5 @@
     save_timing_cache: bool = False
     cuda_graph_batch_size: int = -1
     preset_lowerer: str = ""
+    opt_profile_replica: int = 1
+    dynamic_batch: bool = True
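For completeness, a sketch of how the two new fields might be set together (field names as added in this diff; the chosen values are illustrative, not defaults the PR prescribes beyond opt_profile_replica=1 and dynamic_batch=True):

from torch_tensorrt.fx.lower_setting import LowerSetting

# One replica of the batch-size optimization profile; dynamic batch kept
# on, matching the previous production behavior noted in the review thread.
setting = LowerSetting(
    max_batch_size=256,
    explicit_batch_dimension=True,
    opt_profile_replica=1,
    dynamic_batch=True,
)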