feat: Support fallback options in trtorchc
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 committed Jul 18, 2021
1 parent 75e86e8 commit ad966b7
Showing 2 changed files with 31 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cpp/trtorchc/README.md
@@ -37,13 +37,17 @@ trtorchc [input_file_path] [output_file_path]
--allow-gpu-fallback (Only used when targeting DLA
(device-type)) Lets engine run layers on
GPU if they are not supported on DLA
--allow-torch-fallback Enable layers to run in torch
if they are not supported in TensorRT
-p[precision],
--default-op-precision=[precision]
Default operating precision for the
engine (Int8 requires a
calibration-cache argument) [ float |
float32 | f32 | half | float16 | f16 |
int8 | i8 ] (default: float)
--forced-fallback-ops List of operators in the graph that
should be forced to fall back to PyTorch for execution
-d[type], --device-type=[type] The type of device the engine should be
built for [ gpu | dla ] (default: gpu)
--engine-capability=[capability] The type of device the engine should be
@@ -84,4 +88,4 @@ trtorchc [input_file_path] [output_file_path]
e.g.
```
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
```
```
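With the new fallback options, a hypothetical invocation could look like the following; the operator names passed to `--forced-fallback-ops` are placeholders for illustration, not taken from this commit:
```
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300)]" -p f16 --allow-torch-fallback --forced-fallback-ops "aten::topk,aten::max_pool2d"
```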
26 changes: 26 additions & 0 deletions cpp/trtorchc/main.cpp
@@ -163,6 +163,9 @@ int main(int argc, char** argv) {
"(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
{"allow-gpu-fallback"});

args::Flag allow_torch_fallback(
parser, "allow-torch-fallback", "Enable layers to run in torch if they are not supported in TensorRT", {"allow-torch-fallback"});

args::Flag disable_tf32(
parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});

@@ -191,6 +194,11 @@ int main(int argc, char** argv) {
"file_path",
"Path to calibration cache file to use for post training quantization",
{"calibration-cache-file"});
args::ValueFlag<std::string> forced_fallback_ops(
parser,
"forced_fallback_ops",
"List of operators in the graph that should be forced to fallback to Pytorch for execution.",
{"ffo", "forced-fallback-ops"});
args::ValueFlag<int> num_min_timing_iters(
parser, "num_iters", "Number of minimization timing iterations used to select kernels", {"num-min-timing-iter"});
args::ValueFlag<int> num_avg_timing_iters(
@@ -266,6 +274,10 @@ int main(int argc, char** argv) {
compile_settings.device.allow_gpu_fallback = true;
}

if (allow_torch_fallback) {
compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
}

if (disable_tf32) {
compile_settings.disable_tf32 = true;
}
@@ -277,6 +289,20 @@ int main(int argc, char** argv) {

auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);

if (forced_fallback_ops) {
std::string fallback_ops = args::get(forced_fallback_ops);
if (!allow_torch_fallback) {
trtorch::logging::log(
trtorch::logging::Level::kERROR,
"Forced fallback ops provided but allow_torch_fallback is False. Please use --allow-torch-fallback to enable automatic fallback of operators.");
}
std::string op;
std::stringstream ss(fallback_ops);
while (getline(ss, op, ',')) {
compile_settings.torch_fallback.forced_fallback_ops.push_back(op);
}
}

if (op_precision) {
auto precision = args::get(op_precision);
std::transform(
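The operator list handed to `--forced-fallback-ops` is split on commas with the standard `std::getline` idiom shown in the last hunk. A self-contained sketch of the same technique (the `split_ops` helper name is ours, not from the commit):
```
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Split a comma-separated operator list such as "aten::topk,aten::max_pool2d"
// into individual operator names, as trtorchc does for --forced-fallback-ops.
std::vector<std::string> split_ops(const std::string& fallback_ops) {
  std::vector<std::string> ops;
  std::string op;
  std::stringstream ss(fallback_ops);
  while (std::getline(ss, op, ',')) {
    ops.push_back(op);
  }
  return ops;
}

int main() {
  for (const auto& op : split_ops("aten::topk,aten::max_pool2d")) {
    std::cout << op << "\n";
  }
  return 0;
}
```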
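For context, `trtorchc` is populating the same `torch_fallback` field exposed by the TRTorch C++ API. A minimal sketch of setting it directly, assuming a traced module at `model.jit.pt` and a 1x3x224x224 input (both are assumptions for illustration, not part of this commit):
```
#include <torch/script.h>

#include "trtorch/trtorch.h"

int main() {
  // Load a traced TorchScript module; the path is an assumption.
  auto mod = torch::jit::load("model.jit.pt");

  // Compile settings with a fixed input shape (assumed shape).
  auto compile_settings = trtorch::CompileSpec({{1, 3, 224, 224}});

  // Let unsupported layers keep running in Torch,
  // mirroring what --allow-torch-fallback does in trtorchc.
  compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);

  // Force a specific operator to run in Torch, mirroring --forced-fallback-ops
  // (the operator name here is a placeholder).
  compile_settings.torch_fallback.forced_fallback_ops.push_back("aten::topk");

  auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
  trt_mod.save("model_trt.ts");
  return 0;
}
```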
