diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index 814be63c13..2950e08803 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -122,6 +122,23 @@ def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
     return info
 
 
+def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
+    info = trtorch._C.TorchFallback()
+    if "enabled" not in fallback_info:
+        raise KeyError("Enabled is required parameter")
+    else:
+        assert isinstance(fallback_info["enabled"], bool)
+        info.enabled = fallback_info["enabled"]
+    if "min_block_size" in fallback_info:
+        assert isinstance(fallback_info["min_block_size"], int)
+        info.min_block_size = fallback_info["min_block_size"]
+
+    if "forced_fallback_operators" in fallback_info:
+        assert isinstance(fallback_info["forced_fallback_operators"], list)
+        info.forced_fallback_operators = fallback_info["forced_fallback_operators"]
+
+    return info
+
 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     info = trtorch._C.CompileSpec()
 
@@ -174,6 +191,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         assert type(compile_spec["max_batch_size"]) is int
         info.max_batch_size = compile_spec["max_batch_size"]
 
+    if "torch_fallback" in compile_spec:
+        info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"])
+
+
     return info
 
 
@@ -242,7 +263,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     d.set_dla_core(parsed_spec.device.dla_core)
     d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
 
+    torch_fallback = torch.classes.tensorrt.TorchFallback()
+    torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled)
+    torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size)
+    torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
+
     backend_spec.set_device(d)
+    backend_spec.set_torch_fallback(torch_fallback)
     backend_spec.set_op_precision(int(parsed_spec.op_precision))
     backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
     backend_spec.set_refit(parsed_spec.refit)
diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index 048e69dbe1..a591c2eadd 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -24,11 +24,19 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
 
+  static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
+      torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "Fallback").def(torch::init<>());
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
+
+
   static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
       torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
           .def(torch::init<>())
           .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
           .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
+          .def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
           .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index 54b47d9111..98f21601e9 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -107,6 +107,9 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+  info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled;
+  info.convert_info.engine_settings.torch_fallback.min_block_size = torch_fallback.min_block_size;
+  info.convert_info.engine_settings.torch_fallback.forced_fallback_operators = torch_fallback.forced_fallback_operators;
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
 
   TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index a60390d1e2..98dfa345b8 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -78,6 +78,17 @@ struct Device : torch::CustomClassHolder {
 std::string to_str(DeviceType value);
 nvinfer1::DeviceType toTRTDeviceType(DeviceType value);
 
+struct TorchFallback : torch::CustomClassHolder {
+  bool enabled;
+  int64_t min_block_size;
+  std::vector<std::string> forced_fallback_operators;
+  TorchFallback() : enabled(false), min_block_size(1) {}
+
+  ADD_FIELD_GET_SET(enabled, bool);
+  ADD_FIELD_GET_SET(min_block_size, int64_t);
+  ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
+};
+
 enum class EngineCapability : int8_t {
   kDEFAULT,
   kSAFE_GPU,
@@ -98,6 +109,10 @@ struct CompileSpec : torch::CustomClassHolder {
     device = *d;
   }
 
+  void setTorchFallbackIntrusive(const c10::intrusive_ptr<TorchFallback>& fb) {
+    torch_fallback = *fb;
+  }
+
   ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
   ADD_FIELD_GET_SET(disable_tf32, bool);
   ADD_FIELD_GET_SET(refit, bool);
@@ -109,6 +124,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(workspace_size, int64_t);
   ADD_FIELD_GET_SET(max_batch_size, int64_t);
   ADD_FIELD_GET_SET(device, Device);
+  ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
 
   std::vector<InputRange> input_ranges;
   DataType op_precision = DataType::kFloat;
@@ -117,6 +133,7 @@ struct CompileSpec : torch::CustomClassHolder {
   bool debug = false;
   bool strict_types = false;
   Device device;
+  TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;
   int64_t num_min_timing_iters = 2;
   int64_t num_avg_timing_iters = 1;
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index 418423db41..747d9dcd73 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -124,6 +124,12 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("dla_core", &Device::dla_core)
       .def_readwrite("allow_gpu_fallback", &Device::allow_gpu_fallback);
 
+  py::class_<TorchFallback>(m, "TorchFallback")
+      .def(py::init<>())
+      .def_readwrite("enabled", &TorchFallback::enabled)
+      .def_readwrite("min_block_size", &TorchFallback::min_block_size)
+      .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);
+
   m.doc() =
       "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
   m.def(