feat(disable_tf32): Add a new API to disable TF32
Signed-off-by: Naren Dasan <[email protected]>
narendasan committed Feb 9, 2021
1 parent 1660633 commit 536983b
Showing 13 changed files with 44 additions and 2 deletions.
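For context, here is a minimal sketch of how the new option is intended to be used from the Python front end. The "op_precision" and "disable_tf32" keys follow the docstrings changed below; the "input_shapes" entry and MyModel are illustrative placeholders, not part of this commit.

import torch
import trtorch

# Hypothetical TorchScript module; any scripted nn.Module works the same way.
scripted_model = torch.jit.script(MyModel().eval().cuda())

compile_spec = {
    "input_shapes": [[1, 3, 224, 224]],  # assumed input-shape key and format
    "op_precision": torch.float,         # keep FP32 kernels
    "disable_tf32": True,                # new field: clear TensorRT's TF32 flag
}

trt_model = trtorch.compile(scripted_model, compile_spec)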
5 changes: 5 additions & 0 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -12,6 +12,7 @@ namespace conversion {
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
os << "Settings requested for TensorRT engine:" \
<< "\n Operating Precision: " << s.op_precision \
<< "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
<< "\n Strict Types: " << s.strict_types \
@@ -77,6 +78,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
}
op_precision = settings.op_precision;

if (settings.disable_tf32) {
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
}

if (settings.refit) {
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
}
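The converter change above clears TensorRT's kTF32 builder flag when the new setting is enabled. For comparison, a rough equivalent in the standalone TensorRT Python API (a sketch only, assuming TensorRT 7.1 or newer, where TF32 is enabled by default for FP32 layers):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# TF32 is on by default on Ampere GPUs; clearing the flag forces ordinary
# FP32 math, mirroring cfg->clearFlag(nvinfer1::BuilderFlag::kTF32) above.
config.clear_flag(trt.BuilderFlag.TF32)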
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.h
@@ -24,6 +24,7 @@ struct Device {

struct BuilderSettings {
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
bool strict_types = false;
9 changes: 9 additions & 0 deletions cpp/api/include/trtorch/trtorch.h
@@ -239,6 +239,15 @@ struct TRTORCH_API CompileSpec {
*/
DataType op_precision = DataType::kFloat;

/**
* Prevent Float32 layers from using TF32 data format
*
* TF32 computes inner products by rounding the inputs to 10-bit mantissas
* before multiplying, but accumulates the sum using 23-bit mantissas.
* This is the behavior of FP32 layers by default.
*/
bool disable_tf32 = false;

/**
* Build a refitable engine
*/
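To make the 10-bit mantissa description above concrete, here is a small illustrative Python sketch that approximates TF32 input rounding by truncating a float32 mantissa to 10 bits. Real hardware rounds to nearest rather than truncating, and the Python multiply below runs in double precision, loosely standing in for the wide accumulator.

import struct

def to_f32(x: float) -> float:
    # Round a Python float to the nearest float32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]

def tf32_truncate(x: float) -> float:
    # Zero the 13 low mantissa bits of the float32 bit pattern, keeping the
    # 10 mantissa bits that TF32 retains (truncation, as an approximation).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & ~0x1FFF))[0]

a, b = to_f32(1.0000001), to_f32(3.1415927)
print(a * b)                                # product with full FP32 inputs
print(tf32_truncate(a) * tf32_truncate(b))  # product with TF32-rounded inputs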
1 change: 1 addition & 0 deletions cpp/api/src/compile_spec.cpp
@@ -89,6 +89,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
}

internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
internal.convert_info.engine_settings.refit = external.refit;
internal.convert_info.engine_settings.debug = external.debug;
internal.convert_info.engine_settings.strict_types = external.strict_types;
10 changes: 10 additions & 0 deletions cpp/trtorchc/main.cpp
@@ -163,6 +163,12 @@ int main(int argc, char** argv) {
"(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
{"allow-gpu-fallback"});

args::Flag disable_tf32(
parser,
"disable-tf32",
"Prevent Float32 layers from using the TF32 data format",
{"disable-tf32"});

args::ValueFlag<std::string> op_precision(
parser,
"precision",
@@ -263,6 +269,10 @@ int main(int argc, char** argv) {
compile_settings.device.allow_gpu_fallback = true;
}

if (disable_tf32) {
compile_settings.disable_tf32 = true;
}

std::string calibration_cache_file_path = "";
if (calibration_cache_file) {
calibration_cache_file_path = resolve_path(args::get(calibration_cache_file));
6 changes: 6 additions & 0 deletions py/trtorch/_compile_spec.py
@@ -135,6 +135,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "op_precision" in compile_spec:
info.op_precision = _parse_op_precision(compile_spec["op_precision"])

if "disable_tf32" in compile_spec:
assert isinstance(compile_spec["disable_tf32"], bool)
info.disable_tf32 = compile_spec["disable_tf32"]

if "refit" in compile_spec:
assert isinstance(compile_spec["refit"], bool)
info.refit = compile_spec["refit"]
@@ -201,6 +205,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"refit": False, # enable refit
"debug": False, # enable debuggable engine
"strict_types": False, # kernels should strictly run in operating precision
@@ -239,6 +244,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.

backend_spec.set_device(d)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
backend_spec.set_refit(parsed_spec.refit)
backend_spec.set_debug(parsed_spec.debug)
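The same field is also exposed through the TensorRTCompileSpec path documented above, which feeds PyTorch's to_backend machinery. A minimal sketch, assuming a scripted placeholder module and the "input_shapes" entry used elsewhere in the test suite:

import torch
import trtorch

scripted_model = torch.jit.script(MyModel().eval().cuda())  # hypothetical module

spec = {
    "forward": trtorch.TensorRTCompileSpec({
        "input_shapes": [[1, 3, 224, 224]],  # assumed input-shape entry
        "op_precision": torch.float,
        "disable_tf32": True,                # new field added by this commit
    })
}

trt_model = torch._C._jit_to_backend("tensorrt", scripted_model, spec)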
1 change: 1 addition & 0 deletions py/trtorch/_compiler.py
@@ -99,6 +99,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"refit": false, # enable refit
"debug": false, # enable debuggable engine
"strict_types": false, # kernels should strictly run in operating precision
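The flag is likewise accepted by convert_method_to_trt_engine, whose docstring is updated above. A brief sketch, under the same assumptions about the input-shape key and the placeholder module:

import torch
import trtorch

scripted_model = torch.jit.script(MyModel().eval().cuda())  # hypothetical module

# Returns a serialized TensorRT engine built without TF32 kernels.
engine = trtorch.convert_method_to_trt_engine(
    scripted_model,
    "forward",
    {
        "input_shapes": [[1, 3, 224, 224]],  # assumed input-shape entry
        "op_precision": torch.float,
        "disable_tf32": True,
    },
)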
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -32,6 +32,7 @@ void RegisterTRTCompileSpec() {
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
}
auto info = core::CompileSpec(internal_input_ranges);
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
info.convert_info.engine_settings.debug = debug;
info.convert_info.engine_settings.strict_types = strict_types;
@@ -128,6 +129,7 @@ std::string CompileSpec::stringify() {
}
ss << " ]" << std::endl;
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
ss << " \"Refit\": " << refit << std::endl;
ss << " \"Debug\": " << debug << std::endl;
ss << " \"Strict Types\": " << strict_types << std::endl;
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -99,6 +99,7 @@ struct CompileSpec : torch::CustomClassHolder {
}

ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
ADD_FIELD_GET_SET(disable_tf32, bool);
ADD_FIELD_GET_SET(refit, bool);
ADD_FIELD_GET_SET(debug, bool);
ADD_FIELD_GET_SET(strict_types, bool);
@@ -111,6 +112,7 @@

std::vector<InputRange> input_ranges;
DataType op_precision = DataType::kFloat;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
bool strict_types = false;
1 change: 1 addition & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -103,6 +103,7 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
.def_readwrite("op_precision", &CompileSpec::op_precision)
.def_readwrite("refit", &CompileSpec::refit)
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
.def_readwrite("debug", &CompileSpec::debug)
.def_readwrite("strict_types", &CompileSpec::strict_types)
.def_readwrite("device", &CompileSpec::device)
6 changes: 4 additions & 2 deletions tests/py/test_api.py
@@ -20,7 +20,8 @@ def test_compile_traced(self):
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": False
"allow_gpu_fallback": False,
"disable_tf32": False
}
}

@@ -35,7 +36,8 @@ def test_compile_script(self):
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": False
"allow_gpu_fallback": False,
"disable_tf32": False
}
}

1 change: 1 addition & 0 deletions tests/py/test_to_backend_api.py
@@ -29,6 +29,7 @@ def setUp(self):
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
"max_batch_size": 0,
"disable_tf32": False,
})
}

