feat: Use a global timing cache and add a save option #2898

Merged · 6 commits · Jun 17, 2024
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -78,6 +78,7 @@ def compile(
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
+timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -137,6 +138,7 @@ def compile(
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the number of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
+timing_cache_path (str): Path to the timing cache if it exists, or where it will be saved after compilation
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -220,6 +222,7 @@ def compile(
"dla_global_dram_size": dla_global_dram_size,
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
}

settings = CompilationSettings(**compilation_options)
@@ -477,6 +480,7 @@ def convert_module_to_trt_engine(
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
allow_shape_tensors: bool = False,
+timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -532,7 +536,7 @@ def convert_module_to_trt_engine(
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
-
+timing_cache_path (str): Path to the timing cache if it exists, or where it will be saved after compilation
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -585,6 +589,7 @@ def convert_module_to_trt_engine(
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"timing_cache_path": timing_cache_path,
}

exported_program = pre_export_lowering(exported_program, torch_inputs)
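Both entry points above now accept timing_cache_path. To see the option end to end, here is a hypothetical usage sketch (the model and input shapes are invented for illustration; the cache path mirrors the new default):

import os
import tempfile

import torch
import torch_tensorrt

# Hypothetical user model; any eval-mode module the dynamo frontend can
# compile works here.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Same location as the new default: a fixed path in the system temp
# directory, so repeated compilations (even across processes) share a cache.
cache_path = os.path.join(tempfile.gettempdir(), "timing_cache.bin")

trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    timing_cache_path=cache_path,
)

The first compilation populates the cache file; later builds that profile the same kernels can skip re-timing them and finish noticeably faster.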
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -1,3 +1,6 @@
+import os
+import tempfile
+
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
@@ -28,6 +31,7 @@
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")


def default_device() -> Device:
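One note on the TIMING_CACHE_PATH default added above: tempfile.gettempdir() makes the location platform-dependent, which is what makes the cache "global" across independent compilations on the same machine. A minimal sketch of how the path resolves:

import os
import tempfile

# /tmp/timing_cache.bin on most Linux systems; the user's temp directory
# on Windows and macOS.
print(os.path.join(tempfile.gettempdir(), "timing_cache.bin"))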
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -24,6 +24,7 @@
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
+TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -71,6 +72,7 @@ class CompilationSettings:
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
+timing_cache_path (str): Path to the timing cache if it exists, or where it will be saved after compilation
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -101,3 +103,4 @@ class CompilationSettings:
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
dryrun: Union[bool, str] = DRYRUN
hardware_compatible: bool = HARDWARE_COMPATIBLE
+timing_cache_path: str = TIMING_CACHE_PATH
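Because CompilationSettings is a dataclass, the new field composes with the existing defaults. A small sketch (assuming the internal import path below, which is not part of the public API):

from torch_tensorrt.dynamo._settings import CompilationSettings

# Override only the cache path; every other field keeps its default.
settings = CompilationSettings(timing_cache_path="/tmp/resnet_timing_cache.bin")
print(settings.timing_cache_path)  # /tmp/resnet_timing_cache.bin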
53 changes: 34 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -1,4 +1,5 @@
import logging
+import os
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
@@ -44,7 +45,6 @@ class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
output_names: Sequence[str]
-serialized_cache: bytearray


class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@@ -276,30 +276,43 @@ def _populate_trt_builder_config(
def _create_timing_cache(
self,
builder_config: trt.IBuilderConfig,
-existing_cache: Optional[trt.ITimingCache] = None,
-) -> trt.ITimingCache:
-cache = None
-if existing_cache:
-cache_file = np.array(existing_cache)
-cache = builder_config.create_timing_cache(cache_file.tobytes())
-else:
-cache = builder_config.create_timing_cache(b"")
+timing_cache_path: str = "",
+) -> None:
+"""
+Create a timing cache to enable faster build time for TRT engines.
+By default the timing_cache_path="/tmp/timing_cache.bin"
+"""
+buffer = b""
+if os.path.isfile(timing_cache_path):
+# Load from existing cache
+with open(timing_cache_path, mode="rb") as timing_cache_file:
+buffer = timing_cache_file.read()
+cache = builder_config.create_timing_cache(buffer)
builder_config.set_timing_cache(cache, False)
-return cache

+def _save_timing_cache(
+self,
+builder_config: trt.IBuilderConfig,
+timing_cache_path: str,
+) -> None:
+"""
+This is called after a TensorRT engine is built. Save the timing cache
+"""
+timing_cache = builder_config.get_timing_cache()
+with open(timing_cache_path, "wb") as timing_cache_file:
+timing_cache_file.write(memoryview(timing_cache.serialize()))
+
def run(
self,
strict_type_constraints: bool = False,
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
-existing_cache: Optional[trt.ITimingCache] = None,
tactic_sources: Optional[int] = None,
) -> TRTInterpreterResult:
"""
Build TensorRT engine with some configs.
Args:
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
algorithm_selector: set up algorithm selection for certain layer
-existing_cache: enable timing cache for TensorRT
Return:
TRTInterpreterResult
"""
@@ -316,25 +329,27 @@ def run(
builder_config = self._populate_trt_builder_config(
strict_type_constraints, algorithm_selector, tactic_sources
)
-timing_cache = self._create_timing_cache(builder_config, existing_cache)

+self._create_timing_cache(
+builder_config, self.compilation_settings.timing_cache_path
+)
+
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
assert serialized_engine

-serialized_cache = (
-bytearray(timing_cache.serialize())
-if builder_config.get_timing_cache()
-else bytearray()
-)
_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

+self._save_timing_cache(
+builder_config, self.compilation_settings.timing_cache_path
+)
+
return TRTInterpreterResult(
-serialized_engine, self._input_names, self._output_names, serialized_cache
+serialized_engine, self._input_names, self._output_names
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
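For reviewers less familiar with the underlying TensorRT calls, here is a minimal standalone sketch of the load/build/save round-trip that _create_timing_cache and _save_timing_cache implement, assuming the caller has already defined a network (network construction is elided; build_with_timing_cache is a hypothetical helper, not part of this PR):

import os

import tensorrt as trt

def build_with_timing_cache(
    builder: trt.Builder,
    network: trt.INetworkDefinition,
    cache_path: str,
) -> trt.IHostMemory:
    config = builder.create_builder_config()

    # Load: an empty buffer yields a fresh cache, matching the PR's behavior.
    buffer = b""
    if os.path.isfile(cache_path):
        with open(cache_path, "rb") as f:
            buffer = f.read()
    cache = config.create_timing_cache(buffer)
    # False -> do not ignore a mismatched cache (wrong device / TRT version).
    config.set_timing_cache(cache, False)

    serialized_engine = builder.build_serialized_network(network, config)

    # Save: persist any tactic timings measured during this build so the
    # next build can reuse them.
    with open(cache_path, "wb") as f:
        f.write(memoryview(config.get_timing_cache().serialize()))
    return serialized_engine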