fix: Fix .numpy() issue on fake tensors

- Add `fake_tensor_unsupported` decorator to helper backend
- Refactor `conversion` implementation to use the compilation settings object as well, to reduce code duplication and encourage reuse
- Improve debugger messages by pre-formatting the support string

gs-olive committed May 24, 2023
1 parent c60070b commit 6d65cbf

Showing 4 changed files with 18 additions and 16 deletions.
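
For context on the fix: under torch.compile, TorchDynamo may invoke a backend with FakeTensor example inputs, which carry shape and dtype metadata but no actual storage, so any code path that calls .numpy() on them fails at runtime. A minimal sketch of the failure mode (illustrative only, not code from this repository):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Tensors created under FakeTensorMode carry shape/dtype but no storage.
    with FakeTensorMode():
        fake = torch.empty(2, 3)
        try:
            fake.numpy()  # no backing data to convert to a NumPy array
        except Exception as exc:
            print(f"Fails as expected: {exc}")
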
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -46,6 +46,8 @@ def compile(
     torch_executed_modules=[],
     **kwargs,
 ):
+    if debug:
+        logger.setLevel(logging.DEBUG)
 
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -52,6 +52,7 @@ def aot_torch_tensorrt_aten_backend(
     )
 
 
+@fake_tensor_unsupported
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
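
The `fake_tensor_unsupported` decorator ships with TorchDynamo and marks a backend as unable to operate on fake tensors, prompting Dynamo to supply real example inputs instead. Conceptually it behaves roughly like the following sketch (a hypothetical simplification, not the actual torch._dynamo implementation):

    import functools

    import torch
    from torch._subclasses.fake_tensor import FakeTensor

    def fake_tensor_unsupported_sketch(backend):
        """Hypothetical simplification: hand the backend real tensors, not fakes."""

        @functools.wraps(backend)
        def wrapper(gm, example_inputs, **kwargs):
            # Replace fake example inputs with real zero-filled tensors of the
            # same shape and dtype so code paths like .numpy() can execute.
            real_inputs = [
                torch.zeros(t.shape, dtype=t.dtype) if isinstance(t, FakeTensor) else t
                for t in example_inputs
            ]
            return backend(gm, real_inputs, **kwargs)

        return wrapper
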
18 changes: 7 additions & 11 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,45 +2,41 @@
 import torch
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.fx.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
-from torch_tensorrt.fx.utils import LowerPrecision
 
 import tensorrt as trt
 
 
 def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    debug: bool = False,
-    workspace_size: int = 20 << 30,
-    precision: LowerPrecision = LowerPrecision.FP32,
+    settings: CompilationSettings = CompilationSettings(),
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        settings: Compilation settings
     Returns:
         TRTModule or TRTModuleNext
     """
     interp = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
-        logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
+        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
     r = interp.run(
-        max_workspace_size=workspace_size,
-        lower_precision=precision,
+        max_workspace_size=settings.workspace_size,
+        lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
-            if debug
+            if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
     )
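
The refactor assumes a `CompilationSettings` object exposing `debug`, `workspace_size`, and `precision`. A minimal sketch of such a dataclass, inferred from the fields used in this diff (the real definition lives in `py/torch_tensorrt/dynamo/backend/_settings.py` and may carry additional fields):

    from dataclasses import dataclass

    from torch_tensorrt.fx.utils import LowerPrecision

    @dataclass
    class CompilationSettings:
        debug: bool = False
        workspace_size: int = 20 << 30  # matches the removed 20 GiB default
        precision: LowerPrecision = LowerPrecision.FP32
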
13 changes: 8 additions & 5 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
             f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
         )
 
-        logger.debug("\nSupported Nodes:")
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
         for node_name in self.supported_operators:
-            logger.debug("-", node_name)
+            supported_nodes_str += f"- {node_name}\n"
+
+        logger.debug(supported_nodes_str)
 
         if len(self.unsupported_operators) != 0:
-            logger.debug("\nUnsupported or Excluded Nodes:")
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
             for node_name in self.unsupported_operators:
-                logger.debug("-", node_name)
-            logger.debug("\n")
+                unsupported_nodes_str += f"- {node_name}\n"
+            logger.debug(unsupported_nodes_str)
         else:
             logger.debug("\nAll Nodes Supported\n")
 
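
Beyond consolidating output, this change fixes a latent bug: `logger.debug("-", node_name)` passes `node_name` as a %-style formatting argument to a message containing no placeholder, so the standard-library logger raises a formatting error at emit time instead of printing the node name. Building the full string first and logging it once sidesteps this, as in this standalone sketch:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    nodes = ["aten.add.Tensor", "aten.mul.Tensor"]  # illustrative node names

    # Broken pattern: "-" contains no %s placeholder for the second argument,
    # so logging reports "--- Logging error ---" when the record is emitted:
    # logger.debug("-", node)

    # Fixed pattern: build the full message first, then log it once.
    nodes_str = "\nSupported Nodes:\n"
    for node in nodes:
        nodes_str += f"- {node}\n"
    logger.debug(nodes_str)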
