Skip to content

Commit

Permalink
refactor!(//cpp): Inlining partial compilation settings since the
Browse files Browse the repository at this point in the history
feature is now on by default

BREAKING CHANGE: This commit changes the API for automatic fallback
to inline settings regarding partial compilation in preparation
for it to be turned on by default

Now in the compile spec instead of a `torch_fallback` field with its
associated struct, there are four new fields in the compile spec

```c++
bool require_full_compilation = true;
uint64_t min_block_size = 3;
std::vector<std::string> torch_executed_ops = {};
std::vector<std::string> torch_executed_modules = {};
```

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Oct 19, 2021
1 parent 2a0d1c8 commit 19ecc64
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 61 deletions.
58 changes: 21 additions & 37 deletions cpp/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,38 +516,6 @@ struct TRTORCH_API CompileSpec {
bool explicit_set_dtype;
};

/**
* @brief A struct to hold fallback info
*/
struct TRTORCH_API TorchFallback {
/// enable the automatic fallback feature
bool enabled = false;

/// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
uint64_t min_block_size = 1;

/// A list of names of operations that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_ops;

/// A list of names of modules that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_modules;

/**
* @brief Construct a default Torch Fallback object, fallback will be off
*/
TorchFallback() = default;

/**
* @brief Construct from a bool
*/
TorchFallback(bool enabled) : enabled(enabled) {}

/**
* @brief Constructor for setting min_block_size
*/
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
};

/**
* @brief Construct a new Extra Info object
* Convienence constructor to set fixed input size from vectors describing
Expand Down Expand Up @@ -643,11 +611,6 @@ struct TRTORCH_API CompileSpec {
*/
Device device;

/**
* @brief Settings related to partial compilation
*/
TorchFallback torch_fallback;

/**
* Sets the restrictions for the engine (CUDA Safety)
*/
Expand Down Expand Up @@ -676,6 +639,27 @@ struct TRTORCH_API CompileSpec {
* Calibration dataloaders for each input for post training quantizatiom
*/
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;

/**
* Require the full module be compiled to TensorRT instead of potentially running unsupported operations in PyTorch
*/
bool require_full_compilation = false;

/**
* Minimum number of contiguous supported operators to compile a subgraph to TensorRT
*/
uint64_t min_block_size = 3;

/**
* List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
*/
std::vector<std::string> torch_executed_ops;


/**
* List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
*/
std::vector<std::string> torch_executed_modules;
};

/**
Expand Down
32 changes: 13 additions & 19 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,6 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}

// /* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype
// for inputs they will follow PyTorch convetions */ for (size_t i = 0; i < external.inputs.size(); i++) {
// if (!external.inputs[i].get_explicit_set_dtype()) {
// auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
// auto& internal_ins = internal.convert_info.inputs;
// if (precisions.find(nvinfer1::DataType::kINT8) != precisions.end()) {
// internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
// } else if (precisions.find(nvinfer1::DataType::kHALF) != precisions.end()) {
// internal_ins[i].dtype = nvinfer1::DataType::kHALF;
// } else {
// internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
// }
// }
// }

internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
internal.convert_info.engine_settings.refit = external.refit;
Expand All @@ -346,10 +331,19 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.strict_types = external.strict_types;
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
internal.partition_info.enabled = external.torch_fallback.enabled;
internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;

TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
"require_full_compilation is enabled however the list of ops to run in torch is not empty (Found "
<< external.torch_executed_ops.size() << " ops)");

TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_modules.size() > 0)),
"require_full_compilation is enabled however the list of modules to run in torch is not empty (Found "
<< external.torch_executed_modules.size() << " modules)");

internal.partition_info.enabled = external.require_full_compilation;
internal.partition_info.min_block_size = external.min_block_size;
internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);

switch (external.device.device_type) {
case CompileSpec::Device::DeviceType::kDLA:
Expand Down
31 changes: 31 additions & 0 deletions tests/core/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
config_setting(
name = "use_pre_cxx11_abi",
values = {
"define": "abi=pre_cxx11_abi",
}
)

filegroup(
name = "jit_models",
srcs = ["//tests/modules:mobilenet_v2_scripted.jit.pt"]
)

cc_test(
name = "test_detecting_input_type",
srcs = ["test_detecting_input_type.cpp"],
deps = [
"//tests/util",
"//core",
"//core/lowering",
"//core/util:prelude",
"@googletest//:gtest_main",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
data = [
":jit_models"
]
)

test_suite(
name = "core_tests",
tests = [
":test_detecting_input_type",
"//tests/core/conversion:conversion_tests",
"//tests/core/lowering:lowering_tests",
"//tests/core/partitioning:partitioning_tests"
Expand Down
8 changes: 3 additions & 5 deletions tests/cpp/test_module_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
}

trtorch::CompileSpec cfg(input_shapes);
cfg.torch_fallback.enabled = true;
cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = trtorch::CompileGraph(mod, cfg);
Expand All @@ -51,9 +50,8 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
}

trtorch::CompileSpec cfg(input_shapes);
cfg.torch_fallback.enabled = true;
cfg.torch_fallback.min_block_size = 5;
cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
cfg.min_block_size = 5;
cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = trtorch::CompileGraph(mod, cfg);
Expand Down

0 comments on commit 19ecc64

Please sign in to comment.