From b8fa228777f196143220e180266b3a302763fc0b Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 30 Sep 2020 16:42:05 -0700 Subject: [PATCH] refactor!: Renaming extra info to compile spec to be more consistent with other backends and between APIs in TRTorch BREAKING CHANGE: This changes the top level api for setting the specification for compilation, a simple find and replace should allow users to port forward Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .bazelrc | 1 + README.md | 4 +- core/compiler.cpp | 12 ++-- core/compiler.h | 8 +-- .../conversionctx/ConversionCtx.cpp | 2 +- cpp/api/BUILD | 2 +- cpp/api/README.md | 16 ++--- cpp/api/include/trtorch/ptq.h | 4 +- cpp/api/include/trtorch/trtorch.h | 16 ++--- .../src/{extra_info.cpp => compile_spec.cpp} | 40 +++++------ cpp/api/src/trtorch.cpp | 12 ++-- cpp/benchmark/main.cpp | 10 +-- cpp/ptq/README.md | 12 ++-- cpp/ptq/main.cpp | 14 ++-- cpp/trtorchc/main.cpp | 22 +++--- cpp/trtorchexec/main.cpp | 8 +-- docsrc/tutorials/getting_started.rst | 10 +-- docsrc/tutorials/ptq.rst | 12 ++-- py/BUILD | 2 +- .../{_extra_info.py => _compile_spec.py} | 70 +++++++++---------- py/trtorch/_compiler.py | 18 ++--- py/trtorch/csrc/trtorch_py.cpp | 48 +++++++------ tests/accuracy/test_fp16_accuracy.cpp | 6 +- tests/accuracy/test_fp32_accuracy.cpp | 6 +- tests/accuracy/test_int8_accuracy.cpp | 12 ++-- tests/modules/test_serialization.cpp | 6 +- tests/py/test_api.py | 8 +-- 27 files changed, 194 insertions(+), 187 deletions(-) rename cpp/api/src/{extra_info.cpp => compile_spec.cpp} (73%) rename py/trtorch/{_extra_info.py => _compile_spec.py} (64%) diff --git a/.bazelrc b/.bazelrc index 4a2a7423df..0a89848888 100644 --- a/.bazelrc +++ b/.bazelrc @@ -29,6 +29,7 @@ build --cxxopt='-std=c++14' build:python --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" build:python --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0" build:python --define=abi=pre_cxx11_abi +build:python --define=target_lang=python build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0" diff --git a/README.md b/README.md index c9c6c90ccf..61f71fa493 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ More Information / System Architecture: #include "trtorch/trtorch.h" ... 
-auto compile_settings = trtorch::ExtraInfo(dims); +auto compile_settings = trtorch::CompileSpec(dims); // FP16 execution compile_settings.op_precision = torch::kFloat; // Compile module @@ -54,7 +54,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") ``` > Notes on running in lower precisions: -> - Set precision with extra_info.op_precision +> - Set precision with compile_spec.op_precision > - The module should be left in FP32 before compilation (FP16 can support half tensor models) > - In FP16 only input tensors should be converted to FP16, other precisions use FP32 diff --git a/core/compiler.cpp b/core/compiler.cpp index d45dcf4a38..099651a1fa 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -20,7 +20,7 @@ #include "core/lowering/lowering.h" #include "core/conversion/conversion.h" -#include "core/execution/execution.h" +#include "core/runtime/runtime.h" namespace trtorch { namespace core { @@ -42,7 +42,7 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr& g, std::string& serialized_engine) { - auto engine_ptr = c10::make_intrusive(mod._ivalue()->name(), serialized_engine); + auto engine_ptr = c10::make_intrusive(mod._ivalue()->name(), serialized_engine); // Get required metadata about the engine out auto num_io = engine_ptr->num_io; auto name = engine_ptr->name; @@ -50,7 +50,7 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr>(), + c10::getCustomClassType>(), c10::IValue(std::move(engine_ptr)), false ); @@ -125,7 +125,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, - ExtraInfo cfg) { + CompileSpec cfg) { // Go through Lowering to simplify graph and extract weight parameters auto graph_and_parameters = lowering::Lower(mod, method_name); @@ -137,12 +137,12 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, LOG_INFO(*g << "(CompileGraph)\n"); - auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params); + auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params); return std::move(engine); } torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, - ExtraInfo cfg) { + CompileSpec cfg) { // TODO: Should be doing a functional transform but need PR #31978 // [jit] More robust mangling //torch::jit::script::Module new_mod = mod.clone(); diff --git a/core/compiler.h b/core/compiler.h index f9ff400159..281973d4d6 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -7,8 +7,8 @@ namespace trtorch { namespace core { -struct ExtraInfo { - ExtraInfo(std::vector input_ranges) +struct CompileSpec { + CompileSpec(std::vector input_ranges) : convert_info(std::move(input_ranges)) {} conversion::ConversionInfo convert_info; }; @@ -16,9 +16,9 @@ struct ExtraInfo { bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name); std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, - std::string method_name, ExtraInfo cfg); + std::string method_name, CompileSpec cfg); -torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg); +torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg); } // namespace core } // namespace trtorch diff --git a/core/conversion/conversionctx/ConversionCtx.cpp 
b/core/conversion/conversionctx/ConversionCtx.cpp index 2993ee593e..3280464635 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -55,7 +55,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) cfg->setFlag(nvinfer1::BuilderFlag::kFP16); } input_type = nvinfer1::DataType::kFLOAT; - TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator"); + TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator"); cfg->setInt8Calibrator(settings.calibrator); break; case nvinfer1::DataType::kFLOAT: diff --git a/cpp/api/BUILD b/cpp/api/BUILD index d396d1690a..18ce5b8118 100644 --- a/cpp/api/BUILD +++ b/cpp/api/BUILD @@ -9,7 +9,7 @@ cc_library( "include/trtorch/ptq.h" ], srcs = [ - "src/extra_info.cpp", + "src/compile_spec.cpp", "src/logging.cpp", "src/trtorch.cpp", "src/ptq.cpp" diff --git a/cpp/api/README.md b/cpp/api/README.md index ab1bf03cfe..4bdbae379b 100644 --- a/cpp/api/README.md +++ b/cpp/api/README.md @@ -31,7 +31,7 @@ namespace trtorch { * Settings data structure for TRTorch compilation * */ -struct TRTORCH_API ExtraInfo { +struct TRTORCH_API CompileSpec { /** * @brief A struct to hold an input range (used by TensorRT Optimization profile) * @@ -132,10 +132,10 @@ struct TRTORCH_API ExtraInfo { kSAFE_DLA, }; - ExtraInfo(std::vector input_ranges) + CompileSpec(std::vector input_ranges) : input_ranges(std::move(input_ranges)) {} - ExtraInfo(std::vector> fixed_sizes); - ExtraInfo(std::vector> fixed_sizes); + CompileSpec(std::vector> fixed_sizes); + CompileSpec(std::vector> fixed_sizes); // Defaults should reflect TensorRT defaults for BuilderConfig @@ -236,27 +236,27 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& mo * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT * * @param module: torch::jit::script::Module - Existing TorchScript module - * @param info: trtorch::ExtraInfo - Compilation settings + * @param info: trtorch::CompileSpec - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert methods to JIT Graphs which call equivalent TensorRT engines * * Converts specifically the forward method of a TorchScript Module */ -TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info); +TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info); /** * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT * * @param module: torch::jit::script::Module - Existing TorchScript module * @param method_name: std::string - Name of method to compile - * @param info: trtorch::ExtraInfo - Compilation settings + * @param info: trtorch::CompileSpec - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert selected method to a serialized TensorRT engine which can be run with * TensorRT */ -TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info); +TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, CompileSpec info); namespace ptq { /** diff --git 
a/cpp/api/include/trtorch/ptq.h b/cpp/api/include/trtorch/ptq.h index 05f3583947..4932218405 100644 --- a/cpp/api/include/trtorch/ptq.h +++ b/cpp/api/include/trtorch/ptq.h @@ -145,7 +145,7 @@ class Int8Calibrator : Algorithm { /** * @brief operator to cast to nvinfer1::IInt8Calibrator* * - * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo + * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec * * @return nvinfer1::IInt8Calibrator* */ @@ -259,7 +259,7 @@ class Int8CacheCalibrator : Algorithm { /** * @brief operator to cast to nvinfer1::IInt8Calibrator* * - * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo + * Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec * * @return nvinfer1::IInt8Calibrator* */ diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index 8e2757ad3b..cf8bd9e329 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -39,7 +39,7 @@ namespace trtorch { * Settings data structure for TRTorch compilation * */ -struct TRTORCH_API ExtraInfo { +struct TRTORCH_API CompileSpec { /** * @brief A struct to hold an input range (used by TensorRT Optimization profile) * @@ -256,7 +256,7 @@ struct TRTORCH_API ExtraInfo { * * @param input_ranges */ - ExtraInfo(std::vector input_ranges) + CompileSpec(std::vector input_ranges) : input_ranges(std::move(input_ranges)) {} /** * @brief Construct a new Extra Info object @@ -265,14 +265,14 @@ struct TRTORCH_API ExtraInfo { * * @param fixed_sizes */ - ExtraInfo(std::vector> fixed_sizes); + CompileSpec(std::vector> fixed_sizes); /** * @brief Construct a new Extra Info object * Convienence constructor to set fixed input size from c10::ArrayRef's (the output of tensor.sizes()) describing size of input tensors. * Each entry in the vector represents a input and should be provided in call order. 
* @param fixed_sizes */ - ExtraInfo(std::vector> fixed_sizes); + CompileSpec(std::vector> fixed_sizes); // Defaults should reflect TensorRT defaults for BuilderConfig @@ -379,7 +379,7 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT * * @param module: torch::jit::Module - Existing TorchScript module - * @param info: trtorch::ExtraInfo - Compilation settings + * @param info: trtorch::CompileSpec - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert methods to JIT Graphs which call equivalent TensorRT engines @@ -388,14 +388,14 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st * * @return: A new module trageting a TensorRT engine */ -TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info); +TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info); /** * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT * * @param module: torch::jit::Module - Existing TorchScript module * @param method_name: std::string - Name of method to compile - * @param info: trtorch::ExtraInfo - Compilation settings + * @param info: trtorch::CompileSpec - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert selected method to a serialized TensorRT engine which can be run with @@ -403,5 +403,5 @@ TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, Ex * * @return: std::string: Serialized TensorRT engine equivilant to the method graph */ -TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info); +TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, CompileSpec info); } // namespace trtorch diff --git a/cpp/api/src/extra_info.cpp b/cpp/api/src/compile_spec.cpp similarity index 73% rename from cpp/api/src/extra_info.cpp rename to cpp/api/src/compile_spec.cpp index 5bc12fa204..bfec3e7ba7 100644 --- a/cpp/api/src/extra_info.cpp +++ b/cpp/api/src/compile_spec.cpp @@ -6,7 +6,7 @@ #include "trtorch/trtorch.h" namespace trtorch { -ExtraInfo::DataType::DataType(c10::ScalarType t) { +CompileSpec::DataType::DataType(c10::ScalarType t) { TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported"); switch (t) { case at::kHalf: @@ -21,52 +21,52 @@ ExtraInfo::DataType::DataType(c10::ScalarType t) { } } -ExtraInfo::DeviceType::DeviceType(c10::DeviceType t) { +CompileSpec::DeviceType::DeviceType(c10::DeviceType t) { TRTORCH_CHECK(t == at::kCUDA, "Device type when specified using torch device enum must be torch::kCUDA"); value = DeviceType::kGPU; } -ExtraInfo::InputRange::InputRange(std::vector opt) { +CompileSpec::InputRange::InputRange(std::vector opt) { this->opt = opt; this->min = opt; this->max = opt; } -ExtraInfo::InputRange::InputRange(c10::IntArrayRef opt) { +CompileSpec::InputRange::InputRange(c10::IntArrayRef opt) { this->opt = core::util::toVec(opt); this->min = core::util::toVec(opt); this->max = core::util::toVec(opt); } -ExtraInfo::InputRange::InputRange(std::vector min, std::vector opt, std::vector max) { +CompileSpec::InputRange::InputRange(std::vector min, std::vector opt, std::vector max) { this->opt = opt; this->min = min; this->max = max; } 
-ExtraInfo::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) { +CompileSpec::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) { this->opt = core::util::toVec(opt); this->min = core::util::toVec(min); this->max = core::util::toVec(max); } -ExtraInfo::ExtraInfo(std::vector> fixed_sizes) { +CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { input_ranges.push_back(InputRange(in)); } } -ExtraInfo::ExtraInfo(std::vector> fixed_sizes) { +CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { input_ranges.push_back(InputRange(in)); } } -core::conversion::InputRange to_internal_input_range(ExtraInfo::InputRange i) { +core::conversion::InputRange to_internal_input_range(CompileSpec::InputRange i) { return core::conversion::InputRange(i.min, i.opt, i.max); } -std::vector to_vec_internal_input_ranges(std::vector external) { +std::vector to_vec_internal_input_ranges(std::vector external) { std::vector internal; for (auto range : external) { internal.push_back(to_internal_input_range(range)); @@ -74,17 +74,17 @@ std::vector to_vec_internal_input_ranges(std::vect return internal; } -core::ExtraInfo to_internal_extra_info(ExtraInfo external) { - core::ExtraInfo internal(to_vec_internal_input_ranges(external.input_ranges)); +core::CompileSpec to_internal_compile_spec(CompileSpec external) { + core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges)); switch(external.op_precision) { - case ExtraInfo::DataType::kChar: + case CompileSpec::DataType::kChar: internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8; break; - case ExtraInfo::DataType::kHalf: + case CompileSpec::DataType::kHalf: internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF; break; - case ExtraInfo::DataType::kFloat: + case CompileSpec::DataType::kFloat: default: internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT; } @@ -96,22 +96,22 @@ core::ExtraInfo to_internal_extra_info(ExtraInfo external) { internal.convert_info.engine_settings.max_batch_size = external.max_batch_size; switch(external.device) { - case ExtraInfo::DeviceType::kDLA: + case CompileSpec::DeviceType::kDLA: internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kDLA; break; - case ExtraInfo::DeviceType::kGPU: + case CompileSpec::DeviceType::kGPU: default: internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kGPU; } switch(external.capability) { - case ExtraInfo::EngineCapability::kSAFE_GPU: + case CompileSpec::EngineCapability::kSAFE_GPU: internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU; break; - case ExtraInfo::EngineCapability::kSAFE_DLA: + case CompileSpec::EngineCapability::kSAFE_DLA: internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA; break; - case ExtraInfo::EngineCapability::kDEFAULT: + case CompileSpec::EngineCapability::kDEFAULT: default: internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT; diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp index e6a1940db1..742b4111a9 100644 --- a/cpp/api/src/trtorch.cpp +++ b/cpp/api/src/trtorch.cpp @@ -7,8 +7,8 @@ namespace trtorch { -// Defined in extra_info.cpp -core::ExtraInfo to_internal_extra_info(ExtraInfo external); +// Defined in compile_spec.cpp +core::CompileSpec to_internal_compile_spec(CompileSpec external); bool 
CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) { @@ -16,18 +16,18 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, - std::string method_name, ExtraInfo info) { + std::string method_name, CompileSpec info) { LOG_DEBUG(get_build_info()); // Want to export a much simpler (non TRT header dependent) API so doing the // type conversion here - return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_extra_info(info))); + return std::move(core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info))); } -torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info) { +torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info) { LOG_DEBUG(get_build_info()); // Want to export a much simpler (non TRT header dependent) API so doing the // type conversion here - return core::CompileGraph(module, to_internal_extra_info(info)); + return core::CompileGraph(module, to_internal_compile_spec(info)); } std::string get_build_info() { diff --git a/cpp/benchmark/main.cpp b/cpp/benchmark/main.cpp index e73f1da4e8..48566b60c6 100644 --- a/cpp/benchmark/main.cpp +++ b/cpp/benchmark/main.cpp @@ -121,18 +121,18 @@ int main(int argc, const char* argv[]) { at::globalContext().setBenchmarkCuDNN(true); #ifdef TRT - auto extra_info = trtorch::ExtraInfo(dims); - extra_info.workspace_size = 1 << 20; + auto compile_spec = trtorch::CompileSpec(dims); + compile_spec.workspace_size = 1 << 20; #ifdef HALF - extra_info.op_precision = torch::kF16; + compile_spec.op_precision = torch::kF16; #endif - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); #ifdef SAVE_ENGINE std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; - auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info); + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec); std::ofstream out("/tmp/engine_converted_from_jit.trt"); out << engine; out.close(); diff --git a/cpp/ptq/README.md b/cpp/ptq/README.md index 70eb990fb3..ceffb6dcec 100644 --- a/cpp/ptq/README.md +++ b/cpp/ptq/README.md @@ -92,20 +92,20 @@ The calibrator factories create a calibrator that inherits from a `nvinfer1::IIn auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); ``` -Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::ExtraInfo` struct and compiling the module: +Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::CompileSpec` struct and compiling the module: ```C++ std::vector> input_shape = {{32, 3, 32, 32}}; /// Configure settings for compilation - auto extra_info = trtorch::ExtraInfo({input_shape}); + auto compile_spec = trtorch::CompileSpec({input_shape}); /// Set operating precision to INT8 - extra_info.op_precision = torch::kI8; + compile_spec.op_precision = torch::kI8; /// Use the TensorRT Entropy Calibrator - extra_info.ptq_calibrator = calibrator; + compile_spec.ptq_calibrator = calibrator; /// Set a larger workspace (you may get better performace from doing so) - extra_info.workspace_size = 1 << 28; + compile_spec.workspace_size = 1 << 28; - auto trt_mod = 
trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); ``` If you have an existing Calibrator implementation for TensorRT you may directly set the `ptq_calibrator` field with a pointer to your calibrator and it will work as well. diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp index 241261dfba..340ec9cd66 100644 --- a/cpp/ptq/main.cpp +++ b/cpp/ptq/main.cpp @@ -50,28 +50,28 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M std::vector> input_shape = {{32, 3, 32, 32}}; /// Configure settings for compilation - auto extra_info = trtorch::ExtraInfo({input_shape}); + auto compile_spec = trtorch::CompileSpec({input_shape}); /// Set operating precision to INT8 - extra_info.op_precision = torch::kI8; + compile_spec.op_precision = torch::kI8; /// Use the TensorRT Entropy Calibrator - extra_info.ptq_calibrator = calibrator; + compile_spec.ptq_calibrator = calibrator; /// Set max batch size for the engine - extra_info.max_batch_size = 32; + compile_spec.max_batch_size = 32; /// Set a larger workspace - extra_info.workspace_size = 1 << 28; + compile_spec.workspace_size = 1 << 28; mod.eval(); #ifdef SAVE_ENGINE std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; - auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info); + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec); std::ofstream out("/tmp/engine_converted_from_jit.trt"); out << engine; out.close(); #endif std::cout << "Compiling and quantizing module" << std::endl; - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); return std::move(trt_mod); } diff --git a/cpp/trtorchc/main.cpp b/cpp/trtorchc/main.cpp index b37e2a53e9..0e3aaf61d8 100644 --- a/cpp/trtorchc/main.cpp +++ b/cpp/trtorchc/main.cpp @@ -66,7 +66,7 @@ std::vector parseSingleDim(std::string shape_str) { return {}; } -trtorch::ExtraInfo::InputRange parseDynamicDim(std::string shape_str) { +trtorch::CompileSpec::InputRange parseDynamicDim(std::string shape_str) { shape_str = shape_str.substr(1, shape_str.size() - 2); std::vector> shape; std::stringstream ss; @@ -89,7 +89,7 @@ trtorch::ExtraInfo::InputRange parseDynamicDim(std::string shape_str) { exit(1); } - return trtorch::ExtraInfo::InputRange(shape[0], shape[1], shape[2]); + return trtorch::CompileSpec::InputRange(shape[0], shape[1], shape[2]); } std::string get_cwd() { @@ -190,10 +190,10 @@ int main(int argc, char** argv) { } - std::vector ranges; + std::vector ranges; for (const auto shapes : args::get(input_shapes)) { if (shapes.rfind("(", 0) == 0) { - ranges.push_back(trtorch::ExtraInfo::InputRange(parseSingleDim(shapes))); + ranges.push_back(trtorch::CompileSpec::InputRange(parseSingleDim(shapes))); } else if (shapes.rfind("[", 0) == 0) { ranges.push_back(parseDynamicDim(shapes)); } else { @@ -203,7 +203,7 @@ int main(int argc, char** argv) { } } - auto compile_settings = trtorch::ExtraInfo(ranges); + auto compile_settings = trtorch::CompileSpec(ranges); if (build_debuggable_engine) { compile_settings.debug = true; @@ -251,9 +251,9 @@ int main(int argc, char** argv) { auto device = args::get(device_type); std::transform(device.begin(), device.end(), device.begin(), [](unsigned char c){ return std::tolower(c); }); if (device == "gpu") { - compile_settings.device = trtorch::ExtraInfo::DeviceType::kGPU; + compile_settings.device = trtorch::CompileSpec::DeviceType::kGPU; } else if (device == "dla") 
{ - compile_settings.device = trtorch::ExtraInfo::DeviceType::kDLA; + compile_settings.device = trtorch::CompileSpec::DeviceType::kDLA; } else { trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ]"); std::cerr << parser; @@ -265,11 +265,11 @@ int main(int argc, char** argv) { auto capability = args::get(engine_capability); std::transform(capability.begin(), capability.end(), capability.begin(), [](unsigned char c){ return std::tolower(c); }); if (capability == "default") { - compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kDEFAULT; + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDEFAULT; } else if (capability == "safe_gpu") { - compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kSAFE_GPU; + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_GPU; } else if (capability == "safe_dla") { - compile_settings.capability = trtorch::ExtraInfo::EngineCapability::kSAFE_DLA; + compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_DLA; } else { trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]"); std::cerr << parser; @@ -320,7 +320,7 @@ int main(int argc, char** argv) { } else { auto trt_mod = trtorch::CompileGraph(mod, compile_settings); - if (compile_settings.op_precision == trtorch::ExtraInfo::DataType::kFloat) { + if (compile_settings.op_precision == trtorch::CompileSpec::DataType::kFloat) { double threshold_val = 2e-5; if (threshold) { threshold_val = args::get(threshold); diff --git a/cpp/trtorchexec/main.cpp b/cpp/trtorchexec/main.cpp index 8b3e114e62..1dcc74e91b 100644 --- a/cpp/trtorchexec/main.cpp +++ b/cpp/trtorchexec/main.cpp @@ -56,8 +56,8 @@ int main(int argc, const char* argv[]) { dims.push_back(v); } - auto extra_info = trtorch::ExtraInfo(dims); - extra_info.workspace_size = 1 << 24; + auto compile_spec = trtorch::CompileSpec(dims); + compile_spec.workspace_size = 1 << 24; std::cout << "Checking operator support" << std::endl; if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) { @@ -66,7 +66,7 @@ int main(int argc, const char* argv[]) { } std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; - auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info); + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec); std::ofstream out("/tmp/engine_converted_from_jit.trt"); out << engine; out.close(); @@ -89,7 +89,7 @@ int main(int argc, const char* argv[]) { } std::cout << "Compiling graph as module" << std::endl; - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); std::cout << "Running TRT module" << std::endl; torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues); std::vector trt_results; diff --git a/docsrc/tutorials/getting_started.rst b/docsrc/tutorials/getting_started.rst index 05c4e9efba..a1978927b1 100644 --- a/docsrc/tutorials/getting_started.rst +++ b/docsrc/tutorials/getting_started.rst @@ -305,7 +305,7 @@ With out module loaded, we can feed it into the TRTorch compiler. When we do so mod.eval(); auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}); - auto trt_mod = trtorch::CompileGraph(mod, std::vector{{in.sizes()}}); + auto trt_mod = trtorch::CompileGraph(mod, std::vector{{in.sizes()}}); auto out = trt_mod.forward({in}); Thats it! 
Now the graph runs primarily not with the JIT compiler but using TensorRT (though we execute the graph using the JIT runtime). @@ -322,8 +322,8 @@ We can also set settings like operating precision to run in FP16. mod.eval(); auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF); - auto input_sizes = std::vector({in.sizes()}); - trtorch::ExtraInfo info(input_sizes); + auto input_sizes = std::vector({in.sizes()}); + trtorch::CompileSpec info(input_sizes); info.op_precision = torch::kHALF; auto trt_mod = trtorch::CompileGraph(mod, info); auto out = trt_mod.forward({in}); @@ -370,8 +370,8 @@ If you want to save the engine produced by TRTorch to use in a TensorRT applicat mod.eval(); auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF); - auto input_sizes = std::vector({in.sizes()}); - trtorch::ExtraInfo info(input_sizes); + auto input_sizes = std::vector({in.sizes()}); + trtorch::CompileSpec info(input_sizes); info.op_precision = torch::kHALF; auto trt_mod = trtorch::ConvertGraphToTRTEngine(mod, "forward", info); std::ofstream out("/tmp/engine_converted_from_jit.trt"); diff --git a/docsrc/tutorials/ptq.rst b/docsrc/tutorials/ptq.rst index fb12e46ef4..28d60acec3 100644 --- a/docsrc/tutorials/ptq.rst +++ b/docsrc/tutorials/ptq.rst @@ -115,21 +115,21 @@ defines the calibration algorithm used when calibrating. You can explicitly make // MinMax Calibrator is geared more towards NLP tasks auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); -Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::ExtraInfo` struct and compiling the module: +Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::CompileSpec` struct and compiling the module: .. code-block:: c++ std::vector> input_shape = {{32, 3, 32, 32}}; /// Configure settings for compilation - auto extra_info = trtorch::ExtraInfo({input_shape}); + auto compile_spec = trtorch::CompileSpec({input_shape}); /// Set operating precision to INT8 - extra_info.op_precision = torch::kI8; + compile_spec.op_precision = torch::kI8; /// Use the TensorRT Entropy Calibrator - extra_info.ptq_calibrator = calibrator; + compile_spec.ptq_calibrator = calibrator; /// Set a larger workspace (you may get better performace from doing so) - extra_info.workspace_size = 1 << 28; + compile_spec.workspace_size = 1 << 28; - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well. 
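For reference, a minimal sketch of that last point under the renamed API. `MyCalibrator` stands in for any user-defined class implementing `nvinfer1::IInt8EntropyCalibrator2` (the name is hypothetical); the `CompileSpec` fields used are the ones shown in the documentation above:

```C++
#include "torch/script.h"
#include "trtorch/trtorch.h"

// MyCalibrator is a placeholder for an existing TensorRT calibrator,
// e.g. a class deriving from nvinfer1::IInt8EntropyCalibrator2.
MyCalibrator calibrator(/*calibration data, cache path, ...*/);

// Any TorchScript module to be quantized
auto mod = torch::jit::load("path/to/module.jit.pt");
mod.eval();

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
auto compile_spec = trtorch::CompileSpec({input_shape});
compile_spec.op_precision = torch::kI8;
// Any nvinfer1::IInt8Calibrator* can be assigned directly; no
// trtorch::ptq factory is required.
compile_spec.ptq_calibrator = &calibrator;
compile_spec.max_batch_size = 32;
compile_spec.workspace_size = 1 << 28;

auto trt_mod = trtorch::CompileGraph(mod, compile_spec);
```

Because `ptq_calibrator` is simply an `nvinfer1::IInt8Calibrator*`, the same spec works whether the pointer comes from `trtorch::ptq::make_int8_calibrator` or from a hand-written calibrator.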
diff --git a/py/BUILD b/py/BUILD index a2eb3c004b..be5b2d7047 100644 --- a/py/BUILD +++ b/py/BUILD @@ -9,7 +9,7 @@ py_library( "trtorch/__init__.py", "trtorch/_version.py", "trtorch/_compiler.py", - "trtorch/_extra_info.py", + "trtorch/_compile_spec.py", "trtorch/_types.py", "trtorch/logging.py" ], diff --git a/py/trtorch/_extra_info.py b/py/trtorch/_compile_spec.py similarity index 64% rename from py/trtorch/_extra_info.py rename to py/trtorch/_compile_spec.py index 5247b91a0a..aa060bd085 100644 --- a/py/trtorch/_extra_info.py +++ b/py/trtorch/_compile_spec.py @@ -84,53 +84,53 @@ def _parse_device_type(device: Any) -> _types.DeviceType: else: raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device))) -def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C.ExtraInfo: - info = trtorch._C.ExtraInfo() - if "input_shapes" not in extra_info: +def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: + info = trtorch._C.CompileSpec() + if "input_shapes" not in compile_spec: raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict") - info.input_ranges = _parse_input_ranges(extra_info["input_shapes"]) + info.input_ranges = _parse_input_ranges(compile_spec["input_shapes"]) - if "op_precision" in extra_info: - info.op_precision = _parse_op_precision(extra_info["op_precision"]) + if "op_precision" in compile_spec: + info.op_precision = _parse_op_precision(compile_spec["op_precision"]) - if "refit" in extra_info: - assert isinstance(extra_info["refit"], bool) - info.refit = extra_info["refit"] + if "refit" in compile_spec: + assert isinstance(compile_spec["refit"], bool) + info.refit = compile_spec["refit"] - if "debug" in extra_info: - assert isinstance(extra_info["debug"], bool) - info.debug = extra_info["debug"] + if "debug" in compile_spec: + assert isinstance(compile_spec["debug"], bool) + info.debug = compile_spec["debug"] - if "strict_types" in extra_info: - assert isinstance(extra_info["strict_types"], bool) - info.strict_types = extra_info["strict_types"] + if "strict_types" in compile_spec: + assert isinstance(compile_spec["strict_types"], bool) + info.strict_types = compile_spec["strict_types"] - if "allow_gpu_fallback" in extra_info: - assert isinstance(extra_info["allow_gpu_fallback"], bool) - info.allow_gpu_fallback = extra_info["allow_gpu_fallback"] + if "allow_gpu_fallback" in compile_spec: + assert isinstance(compile_spec["allow_gpu_fallback"], bool) + info.allow_gpu_fallback = compile_spec["allow_gpu_fallback"] - if "device" in extra_info: - info.device = _parse_device_type(extra_info["device"]) + if "device" in compile_spec: + info.device = _parse_device_type(compile_spec["device"]) - if "capability" in extra_info: - assert isinstance(extra_info["capability"], type.EngineCapability) - info.capability = extra_info["capability"] + if "capability" in compile_spec: + assert isinstance(compile_spec["capability"], type.EngineCapability) + info.capability = compile_spec["capability"] - if "num_min_timing_iters" in extra_info: - assert type(extra_info["num_min_timing_iters"]) is int - info.num_min_timing_iters = extra_info["num_min_timing_iters"] + if "num_min_timing_iters" in compile_spec: + assert type(compile_spec["num_min_timing_iters"]) is int + info.num_min_timing_iters = compile_spec["num_min_timing_iters"] - if "num_avg_timing_iters" in extra_info: - assert type(extra_info["num_avg_timing_iters"]) 
is int - info.num_avg_timing_iters = extra_info["num_avg_timing_iters"] + if "num_avg_timing_iters" in compile_spec: + assert type(compile_spec["num_avg_timing_iters"]) is int + info.num_avg_timing_iters = compile_spec["num_avg_timing_iters"] - if "workspace_size" in extra_info: - assert type(extra_info["workspace_size"]) is int - info.workspace_size = extra_info["workspace_size"] + if "workspace_size" in compile_spec: + assert type(compile_spec["workspace_size"]) is int + info.workspace_size = compile_spec["workspace_size"] - if "max_batch_size" in extra_info: - assert type(extra_info["max_batch_size"]) is int - info.max_batch_size = extra_info["max_batch_size"] + if "max_batch_size" in compile_spec: + assert type(compile_spec["max_batch_size"]) is int + info.max_batch_size = compile_spec["max_batch_size"] return info \ No newline at end of file diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py index 1627e5a05f..1c35dbe4a1 100644 --- a/py/trtorch/_compiler.py +++ b/py/trtorch/_compiler.py @@ -3,12 +3,12 @@ from torch import nn import trtorch._C -from trtorch._extra_info import _parse_extra_info +from trtorch._compile_spec import _parse_compile_spec from trtorch._version import __version__ from types import FunctionType -def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule: +def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.ScriptModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT Takes a existing TorchScript module and a set of settings to configure the compiler @@ -19,13 +19,13 @@ def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.Script Args: module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch ``torch.nn.Module`` - extra_info (dict): Compilation settings including operating precision, target device, etc. + compile_spec (dict): Compilation settings including operating precision, target device, etc. One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs to the graph. All other keys are optional .. code-block:: py - ExtraInfo = { + compile_spec = { "input_shapes": [ (1, 3, 224, 224), # Static input shape for input #1 { @@ -58,11 +58,11 @@ def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.Script if isinstance(module, torch.jit.ScriptFunction): raise TypeError("torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile") - compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_extra_info(extra_info)) + compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(compile_spec)) compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) return compiled_module -def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str: +def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, compile_spec: Any) -> str: """Convert a TorchScript module method to a serialized TensorRT engine Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings @@ -71,13 +71,13 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch ``torch.nn.Module`` method_name (str): Name of method to convert - extra_info (dict): Compilation settings including operating precision, target device, etc. 
+ compile_spec (dict): Compilation settings including operating precision, target device, etc. One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs to the graph. All other keys are optional .. code-block:: py - ExtraInfo = { + CompileSpec = { "input_shapes": [ (1, 3, 224, 224), # Static input shape for input #1 { @@ -109,7 +109,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st if isinstance(module, torch.jit.ScriptFunction): raise TypeError("torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile") - return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info)) + return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec)) def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool: """Checks to see if a method is fully supported by TRTorch diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index 765f75d56a..da6d2b2688 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -1,5 +1,7 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" +//TODO: Remove when we have access to PyTorch to_backend autoregistration +#include "core/backend.h" #include "core/compiler.h" #include "core/conversion/conversion.h" #include "torch/torch.h" @@ -73,13 +75,13 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) { } } -struct ExtraInfo { +struct CompileSpec { - core::ExtraInfo toInternalExtraInfo() { + core::CompileSpec toInternalCompileSpec() { for (auto i : input_ranges) { internal_input_ranges.push_back(i.toInternalInputRange()); } - auto info = core::ExtraInfo(internal_input_ranges); + auto info = core::CompileSpec(internal_input_ranges); info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision); info.convert_info.engine_settings.refit = refit; info.convert_info.engine_settings.debug = debug; @@ -109,15 +111,15 @@ struct ExtraInfo { uint64_t max_batch_size = 0; }; -torch::jit::Module CompileGraph(const torch::jit::Module& mod, ExtraInfo& info) { +torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info) { py::gil_scoped_acquire gil; - auto trt_mod = core::CompileGraph(mod, info.toInternalExtraInfo()); + auto trt_mod = core::CompileGraph(mod, info.toInternalCompileSpec()); return trt_mod; } -py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, ExtraInfo& info) { +py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) { py::gil_scoped_acquire gil; - auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalExtraInfo()); + auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalCompileSpec()); return py::bytes(trt_engine); } @@ -189,20 +191,20 @@ PYBIND11_MODULE(_C, m) { .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only") .value("default", EngineCapability::kDEFAULT, "Use default behavior"); - py::class_(m, "ExtraInfo") + py::class_(m, "CompileSpec") .def(py::init<>()) - .def_readwrite("input_ranges", &ExtraInfo::input_ranges) - .def_readwrite("op_precision", &ExtraInfo::op_precision) - .def_readwrite("refit", &ExtraInfo::refit) - .def_readwrite("debug", &ExtraInfo::debug) - .def_readwrite("strict_types", &ExtraInfo::strict_types) - .def_readwrite("allow_gpu_fallback", 
&ExtraInfo::allow_gpu_fallback) - .def_readwrite("device", &ExtraInfo::device) - .def_readwrite("capability", &ExtraInfo::capability) - .def_readwrite("num_min_timing_iters", &ExtraInfo::num_min_timing_iters) - .def_readwrite("num_avg_timing_iters", &ExtraInfo::num_avg_timing_iters) - .def_readwrite("workspace_size", &ExtraInfo::workspace_size) - .def_readwrite("max_batch_size", &ExtraInfo::max_batch_size); + .def_readwrite("input_ranges", &CompileSpec::input_ranges) + .def_readwrite("op_precision", &CompileSpec::op_precision) + .def_readwrite("refit", &CompileSpec::refit) + .def_readwrite("debug", &CompileSpec::debug) + .def_readwrite("strict_types", &CompileSpec::strict_types) + .def_readwrite("allow_gpu_fallback", &CompileSpec::allow_gpu_fallback) + .def_readwrite("device", &CompileSpec::device) + .def_readwrite("capability", &CompileSpec::capability) + .def_readwrite("num_min_timing_iters", &CompileSpec::num_min_timing_iters) + .def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters) + .def_readwrite("workspace_size", &CompileSpec::workspace_size) + .def_readwrite("max_batch_size", &CompileSpec::max_batch_size); m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT"; m.def("compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded"); @@ -225,7 +227,11 @@ PYBIND11_MODULE(_C, m) { .value("INFO", core::util::logging::LogLevel::kINFO) .value("DEBUG", core::util::logging::LogLevel::kDEBUG) .export_values(); + + //TODO: Remove when we have access to PyTorch autoregistration + //m.def("to_tensorrt", backend::GetTensorRTBackend().generateToBackendFn()); } -} // namespace py + +} // namespace pyapi } // namespace trtorch diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp index 6de40a6c31..b19c01cb38 100644 --- a/tests/accuracy/test_fp16_accuracy.cpp +++ b/tests/accuracy/test_fp16_accuracy.cpp @@ -27,10 +27,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector> input_shape = {{32, 3, 32, 32}}; - auto extra_info = trtorch::ExtraInfo({input_shape}); - extra_info.op_precision = torch::kF16; + auto compile_spec = trtorch::CompileSpec({input_shape}); + compile_spec.op_precision = torch::kF16; - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA}); for (auto batch : *eval_dataloader) { diff --git a/tests/accuracy/test_fp32_accuracy.cpp b/tests/accuracy/test_fp32_accuracy.cpp index d3d8bddb96..11ed944077 100644 --- a/tests/accuracy/test_fp32_accuracy.cpp +++ b/tests/accuracy/test_fp32_accuracy.cpp @@ -27,10 +27,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; std::vector> input_shape = {{32, 3, 32, 32}}; - auto extra_info = trtorch::ExtraInfo({input_shape}); - extra_info.op_precision = torch::kF32; + auto compile_spec = trtorch::CompileSpec({input_shape}); + compile_spec.op_precision = torch::kF32; - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA}); for (auto batch : *eval_dataloader) { 
diff --git a/tests/accuracy/test_int8_accuracy.cpp b/tests/accuracy/test_int8_accuracy.cpp index aa4824948a..db5b259657 100644 --- a/tests/accuracy/test_int8_accuracy.cpp +++ b/tests/accuracy/test_int8_accuracy.cpp @@ -20,15 +20,15 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { std::vector> input_shape = {{32, 3, 32, 32}}; // Configure settings for compilation - auto extra_info = trtorch::ExtraInfo({input_shape}); + auto compile_spec = trtorch::CompileSpec({input_shape}); // Set operating precision to INT8 - extra_info.op_precision = torch::kI8; + compile_spec.op_precision = torch::kI8; // Use the TensorRT Entropy Calibrator - extra_info.ptq_calibrator = calibrator; + compile_spec.ptq_calibrator = calibrator; // Set max batch size for the engine - extra_info.max_batch_size = 32; + compile_spec.max_batch_size = 32; // Set a larger workspace - extra_info.workspace_size = 1 << 28; + compile_spec.workspace_size = 1 << 28; mod.eval(); @@ -57,7 +57,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; // Compile Graph - auto trt_mod = trtorch::CompileGraph(mod, extra_info); + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); // Check the INT8 accuracy in TRT torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA}); diff --git a/tests/modules/test_serialization.cpp b/tests/modules/test_serialization.cpp index a7fcea5558..bb6b984f4f 100644 --- a/tests/modules/test_serialization.cpp +++ b/tests/modules/test_serialization.cpp @@ -1,7 +1,7 @@ #include "module_test.h" -std::vector toInputRangesDynamic(std::vector> opts) { - std::vector a; +std::vector toInputRangesDynamic(std::vector> opts) { + std::vector a; for (auto opt : opts) { std::vector min_range(opt); @@ -12,7 +12,7 @@ std::vector toInputRangesDynamic(std::vector