Skip to content

Commit

Permalink
refactor!: Renaming extra info to compile spec to be more consistent
Browse files Browse the repository at this point in the history
with other backends and between APIs in TRTorch

BREAKING CHANGE: This changes the top level api for setting the
specification for compilation, a simple find and replace should allow
users to port forward

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Sep 30, 2020
1 parent e4a4574 commit b8fa228
Show file tree
Hide file tree
Showing 27 changed files with 194 additions and 187 deletions.
1 change: 1 addition & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ build --cxxopt='-std=c++14'
build:python --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:python --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:python --define=abi=pre_cxx11_abi
build:python --define=target_lang=python

build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ More Information / System Architecture:
#include "trtorch/trtorch.h"

...
auto compile_settings = trtorch::ExtraInfo(dims);
auto compile_settings = trtorch::CompileSpec(dims);
// FP16 execution
compile_settings.op_precision = torch::kFloat;
// Compile module
Expand Down Expand Up @@ -54,7 +54,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```

> Notes on running in lower precisions:
> - Set precision with extra_info.op_precision
> - Set precision with compile_spec.op_precision
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
Expand Down
12 changes: 6 additions & 6 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "core/lowering/lowering.h"
#include "core/conversion/conversion.h"
#include "core/execution/execution.h"
#include "core/runtime/runtime.h"

namespace trtorch {
namespace core {
Expand All @@ -42,15 +42,15 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str


void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;

// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
mod.register_attribute(
name,
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
c10::getCustomClassType<c10::intrusive_ptr<runtime::TRTEngine>>(),
c10::IValue(std::move(engine_ptr)),
false
);
Expand Down Expand Up @@ -125,7 +125,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name,
ExtraInfo cfg) {
CompileSpec cfg) {

// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name);
Expand All @@ -137,12 +137,12 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,

LOG_INFO(*g << "(CompileGraph)\n");

auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
return std::move(engine);
}

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
ExtraInfo cfg) {
CompileSpec cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
//torch::jit::script::Module new_mod = mod.clone();
Expand Down
8 changes: 4 additions & 4 deletions core/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
namespace trtorch {
namespace core {

struct ExtraInfo {
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
struct CompileSpec {
CompileSpec(std::vector<conversion::InputRange> input_ranges)
: convert_info(std::move(input_ranges)) {}
conversion::ConversionInfo convert_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name, ExtraInfo cfg);
std::string method_name, CompileSpec cfg);

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
Expand Down
2 changes: 1 addition & 1 deletion cpp/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cc_library(
"include/trtorch/ptq.h"
],
srcs = [
"src/extra_info.cpp",
"src/compile_spec.cpp",
"src/logging.cpp",
"src/trtorch.cpp",
"src/ptq.cpp"
Expand Down
16 changes: 8 additions & 8 deletions cpp/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace trtorch {
* Settings data structure for TRTorch compilation
*
*/
struct TRTORCH_API ExtraInfo {
struct TRTORCH_API CompileSpec {
/**
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
*
Expand Down Expand Up @@ -132,10 +132,10 @@ struct TRTORCH_API ExtraInfo {
kSAFE_DLA,
};

ExtraInfo(std::vector<InputRange> input_ranges)
CompileSpec(std::vector<InputRange> input_ranges)
: input_ranges(std::move(input_ranges)) {}
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

// Defaults should reflect TensorRT defaults for BuilderConfig

Expand Down Expand Up @@ -236,27 +236,27 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& mo
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::script::Module - Existing TorchScript module
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
* Takes a existing TorchScript module and a set of settings to configure the compiler
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
*
* Converts specifically the forward method of a TorchScript Module
*/
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info);

/**
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::script::Module - Existing TorchScript module
* @param method_name: std::string - Name of method to compile
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
* Takes a existing TorchScript module and a set of settings to configure the compiler
* and will convert selected method to a serialized TensorRT engine which can be run with
* TensorRT
*/
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, CompileSpec info);

namespace ptq {
/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/api/include/trtorch/ptq.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Int8Calibrator : Algorithm {
/**
* @brief operator to cast to nvinfer1::IInt8Calibrator*
*
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
*
* @return nvinfer1::IInt8Calibrator*
*/
Expand Down Expand Up @@ -259,7 +259,7 @@ class Int8CacheCalibrator : Algorithm {
/**
* @brief operator to cast to nvinfer1::IInt8Calibrator*
*
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
*
* @return nvinfer1::IInt8Calibrator*
*/
Expand Down
16 changes: 8 additions & 8 deletions cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace trtorch {
* Settings data structure for TRTorch compilation
*
*/
struct TRTORCH_API ExtraInfo {
struct TRTORCH_API CompileSpec {
/**
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
*
Expand Down Expand Up @@ -256,7 +256,7 @@ struct TRTORCH_API ExtraInfo {
*
* @param input_ranges
*/
ExtraInfo(std::vector<InputRange> input_ranges)
CompileSpec(std::vector<InputRange> input_ranges)
: input_ranges(std::move(input_ranges)) {}
/**
* @brief Construct a new Extra Info object
Expand All @@ -265,14 +265,14 @@ struct TRTORCH_API ExtraInfo {
*
* @param fixed_sizes
*/
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
/**
* @brief Construct a new Extra Info object
* Convienence constructor to set fixed input size from c10::ArrayRef's (the output of tensor.sizes()) describing size of input tensors.
* Each entry in the vector represents a input and should be provided in call order.
* @param fixed_sizes
*/
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

// Defaults should reflect TensorRT defaults for BuilderConfig

Expand Down Expand Up @@ -379,7 +379,7 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::Module - Existing TorchScript module
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
* Takes a existing TorchScript module and a set of settings to configure the compiler
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
Expand All @@ -388,20 +388,20 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st
*
* @return: A new module trageting a TensorRT engine
*/
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info);
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);

/**
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::Module - Existing TorchScript module
* @param method_name: std::string - Name of method to compile
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
* Takes a existing TorchScript module and a set of settings to configure the compiler
* and will convert selected method to a serialized TensorRT engine which can be run with
* TensorRT
*
* @return: std::string: Serialized TensorRT engine equivilant to the method graph
*/
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, CompileSpec info);
} // namespace trtorch
40 changes: 20 additions & 20 deletions cpp/api/src/extra_info.cpp → cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "trtorch/trtorch.h"

namespace trtorch {
ExtraInfo::DataType::DataType(c10::ScalarType t) {
CompileSpec::DataType::DataType(c10::ScalarType t) {
TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
switch (t) {
case at::kHalf:
Expand All @@ -21,70 +21,70 @@ ExtraInfo::DataType::DataType(c10::ScalarType t) {
}
}

ExtraInfo::DeviceType::DeviceType(c10::DeviceType t) {
CompileSpec::DeviceType::DeviceType(c10::DeviceType t) {
TRTORCH_CHECK(t == at::kCUDA, "Device type when specified using torch device enum must be torch::kCUDA");
value = DeviceType::kGPU;
}

ExtraInfo::InputRange::InputRange(std::vector<int64_t> opt) {
CompileSpec::InputRange::InputRange(std::vector<int64_t> opt) {
this->opt = opt;
this->min = opt;
this->max = opt;
}

ExtraInfo::InputRange::InputRange(c10::IntArrayRef opt) {
CompileSpec::InputRange::InputRange(c10::IntArrayRef opt) {
this->opt = core::util::toVec(opt);
this->min = core::util::toVec(opt);
this->max = core::util::toVec(opt);
}

ExtraInfo::InputRange::InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
CompileSpec::InputRange::InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
this->opt = opt;
this->min = min;
this->max = max;
}

ExtraInfo::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
CompileSpec::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
this->opt = core::util::toVec(opt);
this->min = core::util::toVec(min);
this->max = core::util::toVec(max);
}

ExtraInfo::ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
for (auto in : fixed_sizes) {
input_ranges.push_back(InputRange(in));
}
}

ExtraInfo::ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes) {
CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
for (auto in : fixed_sizes) {
input_ranges.push_back(InputRange(in));
}
}

core::conversion::InputRange to_internal_input_range(ExtraInfo::InputRange i) {
core::conversion::InputRange to_internal_input_range(CompileSpec::InputRange i) {
return core::conversion::InputRange(i.min, i.opt, i.max);
}

std::vector<core::conversion::InputRange> to_vec_internal_input_ranges(std::vector<ExtraInfo::InputRange> external) {
std::vector<core::conversion::InputRange> to_vec_internal_input_ranges(std::vector<CompileSpec::InputRange> external) {
std::vector<core::conversion::InputRange> internal;
for (auto range : external) {
internal.push_back(to_internal_input_range(range));
}
return internal;
}

core::ExtraInfo to_internal_extra_info(ExtraInfo external) {
core::ExtraInfo internal(to_vec_internal_input_ranges(external.input_ranges));
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));

switch(external.op_precision) {
case ExtraInfo::DataType::kChar:
case CompileSpec::DataType::kChar:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8;
break;
case ExtraInfo::DataType::kHalf:
case CompileSpec::DataType::kHalf:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF;
break;
case ExtraInfo::DataType::kFloat:
case CompileSpec::DataType::kFloat:
default:
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
}
Expand All @@ -96,22 +96,22 @@ core::ExtraInfo to_internal_extra_info(ExtraInfo external) {
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;

switch(external.device) {
case ExtraInfo::DeviceType::kDLA:
case CompileSpec::DeviceType::kDLA:
internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kDLA;
break;
case ExtraInfo::DeviceType::kGPU:
case CompileSpec::DeviceType::kGPU:
default:
internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kGPU;
}

switch(external.capability) {
case ExtraInfo::EngineCapability::kSAFE_GPU:
case CompileSpec::EngineCapability::kSAFE_GPU:
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU;
break;
case ExtraInfo::EngineCapability::kSAFE_DLA:
case CompileSpec::EngineCapability::kSAFE_DLA:
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA;
break;
case ExtraInfo::EngineCapability::kDEFAULT:
case CompileSpec::EngineCapability::kDEFAULT:
default:
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT;

Expand Down
Loading

0 comments on commit b8fa228

Please sign in to comment.