diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index a896448d39..72a8f6bf55 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -12,6 +12,7 @@ namespace conversion { std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { os << "Settings requested for TensorRT engine:" \ << "\n Operating Precision: " << s.op_precision \ + << "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \ << "\n Make Refittable Engine: " << s.refit \ << "\n Debuggable Engine: " << s.debug \ << "\n Strict Types: " << s.strict_types \ @@ -77,6 +78,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) } op_precision = settings.op_precision; + if (settings.disable_tf32) { + cfg->clearFlag(nvinfer1::BuilderFlag::kTF32); + } + if (settings.refit) { cfg->setFlag(nvinfer1::BuilderFlag::kREFIT); } diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h index d115b3e519..444936cb8b 100644 --- a/core/conversion/conversionctx/ConversionCtx.h +++ b/core/conversion/conversionctx/ConversionCtx.h @@ -24,6 +24,7 @@ struct Device { struct BuilderSettings { nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT; + bool disable_tf32 = false; bool refit = false; bool debug = false; bool strict_types = false; diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index 103c80d421..4739d9199a 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -239,6 +239,15 @@ struct TRTORCH_API CompileSpec { */ DataType op_precision = DataType::kFloat; + /** + * Prevent Float32 layers from using TF32 data format + * + * TF32 computes inner products by rounding the inputs to 10-bit mantissas + * before multiplying, but accumulates the sum using 23-bit mantissas. + * This is the behavior of FP32 layers by default. 
+ */ + bool disable_tf32 = false; + /** * Build a refitable engine */ diff --git a/cpp/api/src/compile_spec.cpp b/cpp/api/src/compile_spec.cpp index 253a3b1e63..25dbda9c96 100644 --- a/cpp/api/src/compile_spec.cpp +++ b/cpp/api/src/compile_spec.cpp @@ -89,6 +89,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) { internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT; } + internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32; internal.convert_info.engine_settings.refit = external.refit; internal.convert_info.engine_settings.debug = external.debug; internal.convert_info.engine_settings.strict_types = external.strict_types; diff --git a/cpp/trtorchc/main.cpp b/cpp/trtorchc/main.cpp index 3023c189d1..61cc4f903c 100644 --- a/cpp/trtorchc/main.cpp +++ b/cpp/trtorchc/main.cpp @@ -163,6 +163,12 @@ int main(int argc, char** argv) { "(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA", {"allow-gpu-fallback"}); + args::Flag disable_tf32( + parser, + "disable-tf32", + "Prevent Float32 layers from using the TF32 data format", + {"disable-tf32"}); + args::ValueFlag op_precision( parser, "precision", @@ -263,6 +269,10 @@ int main(int argc, char** argv) { compile_settings.device.allow_gpu_fallback = true; } + if (disable_tf32) { + compile_settings.disable_tf32 = true; + } + std::string calibration_cache_file_path = ""; if (calibration_cache_file) { calibration_cache_file_path = resolve_path(args::get(calibration_cache_file)); diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index 311bee34c1..814be63c13 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -135,6 +135,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: if "op_precision" in compile_spec: info.op_precision = _parse_op_precision(compile_spec["op_precision"]) + if "disable_tf32" in compile_spec: + assert 
isinstance(compile_spec["disable_tf32"], bool) + info.disable_tf32 = compile_spec["disable_tf32"] + if "refit" in compile_spec: assert isinstance(compile_spec["refit"], bool) info.refit = compile_spec["refit"] @@ -201,6 +205,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt. "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU }, "op_precision": torch.half, # Operating precision set to FP16 + "disable_tf32": False, # Force FP32 layers to use the traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas "refit": False, # enable refit "debug": False, # enable debuggable engine "strict_types": False, # kernels should strictly run in operating precision @@ -239,6 +244,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt. backend_spec.set_device(d) backend_spec.set_op_precision(int(parsed_spec.op_precision)) + backend_spec.set_disable_tf32(parsed_spec.disable_tf32) backend_spec.set_refit(parsed_spec.refit) backend_spec.set_debug(parsed_spec.debug) backend_spec.set_refit(parsed_spec.refit) diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py index cd9e8e203b..76a7923ca2 100644 --- a/py/trtorch/_compiler.py +++ b/py/trtorch/_compiler.py @@ -99,6 +99,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU }, "op_precision": torch.half, # Operating precision set to FP16 + "disable_tf32": False, # Force FP32 layers to use the traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas "refit": false, # enable refit "debug": false, # enable debuggable engine "strict_types": false, # kernels should strictly run in operating precision diff --git 
a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp index 4446ece752..048e69dbe1 100644 --- a/py/trtorch/csrc/register_tensorrt_classes.cpp +++ b/py/trtorch/csrc/register_tensorrt_classes.cpp @@ -32,6 +32,7 @@ void RegisterTRTCompileSpec() { .def("__str__", &trtorch::pyapi::CompileSpec::stringify); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision); + ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug); ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types); diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp index 88a9ab59c5..54b47d9111 100644 --- a/py/trtorch/csrc/tensorrt_classes.cpp +++ b/py/trtorch/csrc/tensorrt_classes.cpp @@ -99,6 +99,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { } auto info = core::CompileSpec(internal_input_ranges); info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision); + info.convert_info.engine_settings.disable_tf32 = disable_tf32; info.convert_info.engine_settings.refit = refit; info.convert_info.engine_settings.debug = debug; info.convert_info.engine_settings.strict_types = strict_types; @@ -128,6 +129,7 @@ std::string CompileSpec::stringify() { } ss << " ]" << std::endl; ss << " \"Op Precision\": " << to_str(op_precision) << std::endl; + ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl; ss << " \"Refit\": " << refit << std::endl; ss << " \"Debug\": " << debug << std::endl; ss << " \"Strict Types\": " << strict_types << std::endl; diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index 1ad32b3167..a60390d1e2 100644 --- 
a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -99,6 +99,7 @@ struct CompileSpec : torch::CustomClassHolder { } ADD_ENUM_GET_SET(op_precision, DataType, static_cast(DataType::kChar)); + ADD_FIELD_GET_SET(disable_tf32, bool); ADD_FIELD_GET_SET(refit, bool); ADD_FIELD_GET_SET(debug, bool); ADD_FIELD_GET_SET(strict_types, bool); @@ -111,6 +112,7 @@ struct CompileSpec : torch::CustomClassHolder { std::vector input_ranges; DataType op_precision = DataType::kFloat; + bool disable_tf32 = false; bool refit = false; bool debug = false; bool strict_types = false; diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index c071420a11..420e27cccc 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -103,6 +103,7 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("input_ranges", &CompileSpec::input_ranges) .def_readwrite("op_precision", &CompileSpec::op_precision) .def_readwrite("refit", &CompileSpec::refit) + .def_readwrite("disable_tf32", &CompileSpec::disable_tf32) .def_readwrite("debug", &CompileSpec::debug) .def_readwrite("strict_types", &CompileSpec::strict_types) .def_readwrite("device", &CompileSpec::device) diff --git a/tests/py/test_api.py b/tests/py/test_api.py index b99199036f..fe0613413b 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -20,7 +20,8 @@ def test_compile_traced(self): "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, "dla_core": 0, - "allow_gpu_fallback": False + "allow_gpu_fallback": False, + "disable_tf32": False } } @@ -35,7 +36,8 @@ def test_compile_script(self): "device_type": trtorch.DeviceType.GPU, "gpu_id": 0, "dla_core": 0, - "allow_gpu_fallback": False + "allow_gpu_fallback": False, + "disable_tf32": False } } diff --git a/tests/py/test_to_backend_api.py b/tests/py/test_to_backend_api.py index f111694bfc..77ada08931 100644 --- a/tests/py/test_to_backend_api.py +++ b/tests/py/test_to_backend_api.py @@ -29,6 +29,7 @@ def setUp(self): 
"num_min_timing_iters": 2, "num_avg_timing_iters": 1, "max_batch_size": 0, + "disable_tf32": False, }) }