Merge pull request #1152 from pytorch/trt_8.4ga

feat: Upgrade TRT to 8.4

peri044 authored Jul 23, 2022
2 parents 1625cd3 + 66c1cab commit 92e32aa
Showing 31 changed files with 145 additions and 131 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -114,8 +114,8 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
- Bazel 5.1.1
- Libtorch 1.11.0 (built with CUDA 11.3)
- CUDA 11.3
- cuDNN 8.2.1
- TensorRT 8.2.4.2
- cuDNN 8.4.1
- TensorRT 8.4.1.5

## Prebuilt Binaries and Wheel files

12 changes: 6 additions & 6 deletions WORKSPACE
@@ -76,20 +76,20 @@ http_archive(
http_archive(
name = "cudnn",
build_file = "@//third_party/cudnn/archive:BUILD",
sha256 = "0e5d2df890b9967efa6619da421310d97323565a79f05a1a8cb9b7165baad0d7",
strip_prefix = "cuda",
sha256 = "ec96d2376d81fca42bdd3d4c3d705a99b29a065bab57f920561c763e29c67d01",
strip_prefix = "cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive",
urls = [
"https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.2.4/11.4_20210831/cudnn-11.4-linux-x64-v8.2.4.15.tgz",
"https://developer.nvidia.com/compute/cudnn/secure/8.4.1/local_installers/11.6/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz",
],
)

http_archive(
name = "tensorrt",
build_file = "@//third_party/tensorrt/archive:BUILD",
sha256 = "826180eaaecdf9a7e76116855b9f1f3400ea9b06e66b06a3f6a0747ba6f863ad",
strip_prefix = "TensorRT-8.2.4.2",
sha256 = "8107861af218694130f170e071f49814fa3e27f1386ce7cb6d807ac05a7fcf0e",
strip_prefix = "TensorRT-8.4.1.5",
urls = [
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.4/tars/tensorrt-8.2.4.2.linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz",
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz",
],
)

36 changes: 14 additions & 22 deletions core/compiler.cpp
@@ -359,14 +359,6 @@ void MapInputsAndDetermineDTypes(
}
}

uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
if (device.major < 6) {
return 256 * (1 << 20);
} else {
return 1 << 30;
}
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
@@ -380,14 +372,14 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

// GPU default WS size : 1 GB
// Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
auto device_spec = cfg.convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
if (workspace_size == 0) {
cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
}
// // GPU default WS size : 1 GB
// // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
// auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
// auto device_spec = cfg.convert_info.engine_settings.device;
// auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
// if (workspace_size == 0) {
// cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
// }

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -399,14 +391,14 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

// GPU default WS size : 1 GB
// Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
// // GPU default WS size : 1 GB
// // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
// auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
auto device_spec = cfg.convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
if (workspace_size == 0) {
cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
}
// if (workspace_size == 0) {
// cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
// }

for (const torch::jit::Method& method : mod.get_methods()) {
if (method.name().compare("forward") == 0) {
21 changes: 17 additions & 4 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -20,9 +20,11 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
<< "\n Debuggable Engine: " << s.debug \
<< "\n GPU ID: " << s.device.gpu_id \
<< "\n Allow GPU Fallback (if running on DLA): " << s.device.allow_gpu_fallback \
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
<< "\n Max Workspace Size: " << s.workspace_size;
<< "\n Max Workspace Size: " << s.workspace_size \
<< "\n DLA SRAM Size: " << s.dla_sram_size \
<< "\n DLA Local DRAM Size: " << s.dla_local_dram_size \
<< "\n DLA Global DRAM Size: " << s.dla_global_dram_size;

os << "\n Device Type: " << s.device.device_type \
<< "\n GPU ID: " << s.device.gpu_id;
@@ -104,9 +106,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
}

cfg->setMinTimingIterations(settings.num_min_timing_iters);
cfg->setAvgTimingIterations(settings.num_avg_timing_iters);
cfg->setMaxWorkspaceSize(settings.workspace_size);
if (settings.workspace_size != 0){
cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, settings.workspace_size);
}

cfg->setDefaultDeviceType(settings.device.device_type);
cfg->setEngineCapability(settings.capability);

@@ -120,6 +124,15 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
"DLA supports only fp16 or int8 precision");
cfg->setDLACore(settings.device.dla_core);
if (settings.dla_sram_size != 1048576){
cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_MANAGED_SRAM, settings.dla_sram_size);
}
if (settings.dla_local_dram_size != 1073741824){
cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_LOCAL_DRAM, settings.dla_local_dram_size);
}
if (settings.dla_global_dram_size != 536870912){
cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_GLOBAL_DRAM, settings.dla_global_dram_size);
}
}
}
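For context: TensorRT 8.4 deprecates IBuilderConfig::setMaxWorkspaceSize() and setMinTimingIterations() in favor of per-pool limits set through setMemoryPoolLimit(), and the workspace pool now defaults to the device's global memory, which is why the hard-coded 1 GB / 256 MB defaults were dropped in core/compiler.cpp above. A minimal sketch (not part of this diff; the helper name and signature are hypothetical) of the TensorRT 8.4 calls the new code relies on:

#include "NvInfer.h"

// Hypothetical helper collecting the TensorRT 8.4 memory-pool calls used above.
void configure_memory_pools(nvinfer1::IBuilderConfig* cfg,
                            uint64_t workspace_size,
                            uint64_t dla_sram_size,
                            uint64_t dla_local_dram_size,
                            uint64_t dla_global_dram_size) {
  // Replaces cfg->setMaxWorkspaceSize(workspace_size); leaving the size at 0
  // defers to TensorRT's own default instead of picking one per GPU.
  if (workspace_size != 0) {
    cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, workspace_size);
  }
  // DLA memory pools; these only take effect when the engine targets a DLA core.
  cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_MANAGED_SRAM, dla_sram_size);
  cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_LOCAL_DRAM, dla_local_dram_size);
  cfg->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kDLA_GLOBAL_DRAM, dla_global_dram_size);
}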

4 changes: 3 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.h
@@ -33,9 +33,11 @@ struct BuilderSettings {
Device device;
nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
uint64_t num_min_timing_iters = 2;
uint64_t num_avg_timing_iters = 1;
uint64_t workspace_size = 0;
uint64_t dla_sram_size = 1048576;
uint64_t dla_local_dram_size = 1073741824;
uint64_t dla_global_dram_size = 536870912;

BuilderSettings() = default;
BuilderSettings(const BuilderSettings& other) = default;
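As an aside (not part of the header), the new byte-count defaults decode to:

uint64_t dla_sram_size = 1048576;           // 1 MiB  (2^20)
uint64_t dla_local_dram_size = 1073741824;  // 1 GiB  (2^30)
uint64_t dla_global_dram_size = 536870912;  // 512 MiB (2^29)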
5 changes: 3 additions & 2 deletions core/conversion/converters/converter_util.cpp
@@ -135,9 +135,10 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nv

auto id_layer = ctx->net->addIdentity(*tensor);
TORCHTRT_CHECK(id_layer, "Unable to create identity layer for ITensor: " << tensor_id.str());
auto casted_tensor = id_layer->getOutput(0);
casted_tensor->setType(dtype);
// layer->setOutputType should be used for casting and not manually setting output_tensor->setType()
id_layer->setOutputType(0, dtype);

auto casted_tensor = id_layer->getOutput(0);
LOG_DEBUG(ctx->logger, "Casting ITensor " << tensor_id.str() << " from " << tensor->getType() << " to " << dtype);

std::stringstream ss;
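As an aside (not part of this diff), the cast pattern adopted above can be written standalone roughly as follows; cast_to is a hypothetical name and the usual NvInfer.h declarations are assumed:

// Sketch: cast an ITensor by letting the identity layer convert the output type,
// rather than forcing the type on the output tensor after the fact.
nvinfer1::ITensor* cast_to(nvinfer1::INetworkDefinition* net,
                           nvinfer1::ITensor* in,
                           nvinfer1::DataType dtype) {
  auto* id_layer = net->addIdentity(*in);
  id_layer->setOutputType(0, dtype);  // preferred over getOutput(0)->setType(dtype)
  return id_layer->getOutput(0);
}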
8 changes: 6 additions & 2 deletions cpp/bin/torchtrtc/README.md
@@ -82,13 +82,17 @@ torchtrtc [input_file_path] [output_file_path]
serialized TensorRT engine and embed it
into a TorchScript module (device spec
must be provided)
--num-min-timing-iter=[num_iters] Number of minimization timing iterations
used to select kernels
--num-avg-timing-iters=[num_iters]
Number of averaging timing iterations
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
--dla-sram-size=[dla_sram_size] Fast software managed RAM used by DLA
to communicate within a layer.
--dla-local-dram-size=[dla_local_dram_size] Host RAM used by DLA to share
intermediate tensor data across operations.
--dla-global-dram-size=[dla_global_dram_size] Host RAM used by DLA to store
weights and metadata for execution
--atol=[atol] Absolute tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-8)
21 changes: 15 additions & 6 deletions cpp/bin/torchtrtc/main.cpp
@@ -113,12 +113,16 @@ int main(int argc, char** argv) {
"Whether to treat input file as a serialized TensorRT engine and embed it into a TorchScript module (device spec must be provided)",
{"embed-engine"});

args::ValueFlag<uint64_t> num_min_timing_iters(
parser, "num_iters", "Number of minimization timing iterations used to select kernels", {"num-min-timing-iter"});
args::ValueFlag<uint64_t> num_avg_timing_iters(
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
args::ValueFlag<uint64_t> workspace_size(
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
args::ValueFlag<uint64_t> dla_sram_size(
parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
args::ValueFlag<uint64_t> dla_local_dram_size(
parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"});
args::ValueFlag<uint64_t> dla_global_dram_size(
parser, "dla_global_dram_size", "DLA Global DRAM size", {"dla-global-dram-size"});
args::ValueFlag<double> atol(
parser,
"atol",
@@ -325,6 +329,15 @@ int main(int argc, char** argv) {
if (dla_core) {
compile_settings.device.dla_core = args::get(dla_core);
}
if (dla_sram_size) {
compile_settings.dla_sram_size = args::get(dla_sram_size);
}
if (dla_local_dram_size) {
compile_settings.dla_local_dram_size = args::get(dla_local_dram_size);
}
if (dla_global_dram_size) {
compile_settings.dla_global_dram_size = args::get(dla_global_dram_size);
}
} else {
torchtrt::logging::log(
torchtrt::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ] found: " + device);
@@ -352,10 +365,6 @@ int main(int argc, char** argv) {
}
}

if (num_min_timing_iters) {
compile_settings.num_min_timing_iters = args::get(num_min_timing_iters);
}

if (num_avg_timing_iters) {
compile_settings.num_avg_timing_iters = args::get(num_avg_timing_iters);
}
19 changes: 15 additions & 4 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -636,10 +636,6 @@ struct TORCHTRT_API CompileSpec {
*/
EngineCapability capability = EngineCapability::kSTANDARD;

/**
* Number of minimization timing iterations used to select kernels
*/
uint64_t num_min_timing_iters = 2;
/**
* Number of averaging timing iterations used to select kernels
*/
@@ -650,6 +646,21 @@
*/
uint64_t workspace_size = 0;

/**
* Fast software managed RAM used by DLA to communicate within a layer.
*/
uint64_t dla_sram_size = 1048576;

/**
* Host RAM used by DLA to share intermediate tensor data across operations
*/
uint64_t dla_local_dram_size = 1073741824;

/**
* Host RAM used by DLA to store weights and metadata for execution
*/
uint64_t dla_global_dram_size = 536870912;

/**
* Calibration dataloaders for each input for post training quantizatiom
*/
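A hedged usage sketch of the new fields through the C++ frontend (not part of this commit; the function name, input shape, and module are placeholders):

#include "torch_tensorrt/torch_tensorrt.h"

// Hypothetical example: exercise the new DLA pool sizes next to workspace_size.
torch::jit::Module build_dla_module(const torch::jit::Module& mod) {
  auto compile_spec = torch_tensorrt::ts::CompileSpec({torch_tensorrt::Input({1, 3, 224, 224})});
  compile_spec.device.device_type = torch_tensorrt::Device::DeviceType::kDLA;
  compile_spec.device.dla_core = 0;
  compile_spec.device.allow_gpu_fallback = true;
  compile_spec.enabled_precisions = {torch::kHalf};  // DLA requires fp16 or int8
  compile_spec.workspace_size = 1 << 30;             // 1 GiB workspace pool
  compile_spec.dla_sram_size = 1 << 20;              // managed SRAM (1 MiB default)
  compile_spec.dla_local_dram_size = 1ULL << 30;     // local DRAM (1 GiB default)
  compile_spec.dla_global_dram_size = 1ULL << 29;    // global DRAM (512 MiB default)
  return torch_tensorrt::ts::compile(mod, compile_spec);
}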
4 changes: 3 additions & 1 deletion cpp/src/compile_spec.cpp
@@ -81,9 +81,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {

internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;
internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
internal.convert_info.engine_settings.num_min_timing_iters = external.num_min_timing_iters;
internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
internal.convert_info.engine_settings.workspace_size = external.workspace_size;
internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;

if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
internal.convert_info.engine_settings.enabled_precisions.end()) {
2 changes: 0 additions & 2 deletions docsrc/tutorials/ptq.rst
@@ -130,8 +130,6 @@ Then all thats required to setup the module for INT8 calibration is to set the f
compile_spec.enabled_precisions.insert(torch::kI8);
/// Use the TensorRT Entropy Calibrator
compile_spec.ptq_calibrator = calibrator;
/// Set a larger workspace (you may get better performace from doing so)
compile_spec.workspace_size = 1 << 28;

auto trt_mod = torch_tensorrt::CompileGraph(mod, compile_spec);

8 changes: 6 additions & 2 deletions docsrc/tutorials/torchtrtc.rst
@@ -85,13 +85,17 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
serialized TensorRT engine and embed it
into a TorchScript module (device spec
must be provided)
--num-min-timing-iter=[num_iters] Number of minimization timing iterations
used to select kernels
--num-avg-timing-iters=[num_iters]
Number of averaging timing iterations
used to select kernels
--workspace-size=[workspace_size] Maximum size of workspace given to
TensorRT
--dla-sram-size=[dla_sram_size] Fast software managed RAM used by DLA
to communicate within a layer.
--dla-local-dram-size=[dla_local_dram_size] Host RAM used by DLA to share
intermediate tensor data across operations.
--dla-global-dram-size=[dla_global_dram_size] Host RAM used by DLA to store
weights and metadata for execution
--atol=[atol] Absolute tolerance threshold for acceptable
numerical deviation from standard torchscript
output (default 1e-8)
1 change: 0 additions & 1 deletion docsrc/tutorials/use_from_pytorch.rst
@@ -45,7 +45,6 @@ at the documentation for the Torch-TensorRT ``TensorRTCompileSpec`` API.
"allow_gpu_fallback": True
},
"capability": torch_tensorrt.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
})
}
3 changes: 0 additions & 3 deletions docsrc/tutorials/using_dla.rst
@@ -33,9 +33,6 @@ Using DLA in a C++ application
# If a layer fails to run on DLA it will fallback to GPU
compile_spec.device.allow_gpu_fallback = true;

# Set the workspace size
compile_spec.workspace_size = 1 << 28;


Using DLA in a python application

2 changes: 0 additions & 2 deletions examples/int8/ptq/main.cpp
@@ -49,8 +49,6 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M
compile_spec.enabled_precisions.insert(torch::kI8);
/// Use the TensorRT Entropy Calibrator
compile_spec.ptq_calibrator = calibrator;
/// Set a larger workspace
compile_spec.workspace_size = 1 << 28;

#ifdef SAVE_ENGINE
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
2 changes: 0 additions & 2 deletions examples/int8/qat/main.cpp
@@ -33,8 +33,6 @@ torch::jit::Module compile_int8_qat_model(const std::string& data_dir, torch::ji
auto compile_spec = torch_tensorrt::ts::CompileSpec(inputs);
/// Set operating precision to INT8
compile_spec.enabled_precisions.insert(torch::kI8);
/// Set a larger workspace
compile_spec.workspace_size = 1 << 28;

#ifdef SAVE_ENGINE
std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
5 changes: 3 additions & 2 deletions py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
@@ -60,11 +60,12 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, refit);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, debug);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, capability);
ADD_FIELD_GET_SET_REGISTRATION(
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_min_timing_iters);
ADD_FIELD_GET_SET_REGISTRATION(
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_avg_timing_iters);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, workspace_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_sram_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_local_dram_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_global_dram_size);
ADD_FIELD_GET_SET_REGISTRATION(
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
}
(Diff truncated: remaining changed files not shown.)