Skip to content

Commit

Permalink
feat: support Python APIs for Automatic Fallback
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <[email protected]>
  • Loading branch information
bowang007 committed Mar 11, 2021
1 parent 6d3064a commit 100b090
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 0 deletions.
27 changes: 27 additions & 0 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,23 @@ def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:

return info

def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
info = trtorch._C.TorchFallback()
if "enabled" not in fallback_info:
raise KeyError("Enabled is required parameter")
else:
assert isinstance(fallback_info["enabled"], bool)
info.enabled = fallback_info["enabled"]
if "min_block_size" in fallback_info:
assert isinstance(fallback_info["min_block_size"], int)
info.min_block_size = fallback_info["min_block_size"]

if "forced_fallback_operators" in fallback_info:
assert isinstance(fallback_info["forced_fallback_operators"], list)
info.forced_fallback_operators = fallback_info["forced_fallback_operators"]

return info


def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info = trtorch._C.CompileSpec()
Expand Down Expand Up @@ -174,6 +191,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
assert type(compile_spec["max_batch_size"]) is int
info.max_batch_size = compile_spec["max_batch_size"]

if "torch_fallback" in compile_spec:
info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"])


return info


Expand Down Expand Up @@ -242,7 +263,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
d.set_dla_core(parsed_spec.device.dla_core)
d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)

torch_fallback = torch.classes.tensorrt.TorchFallback()
torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled)
torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size)
torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)

backend_spec.set_device(d)
backend_spec.set_torch_fallback(fallback)
backend_spec.set_op_precision(int(parsed_spec.op_precision))
backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
backend_spec.set_refit(parsed_spec.refit)
Expand Down
8 changes: 8 additions & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);

static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "Fallback").def(torch::init<>());
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);


static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
Expand Down
3 changes: 3 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
info.convert_info.engine_settings.device.dla_core = device.dla_core;
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled;
info.convert_info.engine_settings.torch_fallback.min_block_size = torch_fallback.min_block_size;
info.convert_info.engine_settings.torch_fallback.forced_fallback_operators = torch_fallback.forced_fallback_operators;

info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
Expand Down
17 changes: 17 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ struct Device : torch::CustomClassHolder {
std::string to_str(DeviceType value);
nvinfer1::DeviceType toTRTDeviceType(DeviceType value);

struct TorchFallback : torch::CustomClassHolder {
bool enabled;
int64_t min_block_size;
std::vector<std::string> forced_fallback_operators;
TorchFallback() : enabled(false), min_block_size(1) {}

ADD_FIELD_GET_SET(enabled, bool);
ADD_FIELD_GET_SET(min_block_size, int64_t);
ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
};

enum class EngineCapability : int8_t {
kDEFAULT,
kSAFE_GPU,
Expand All @@ -98,6 +109,10 @@ struct CompileSpec : torch::CustomClassHolder {
device = *d;
}

void setTorchFallbackIntrusive(const c10::intrusive_ptr<TorchFallback> &fb) {
torch_fallback = *fb;
}

ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
ADD_FIELD_GET_SET(disable_tf32, bool);
ADD_FIELD_GET_SET(refit, bool);
Expand All @@ -109,6 +124,7 @@ struct CompileSpec : torch::CustomClassHolder {
ADD_FIELD_GET_SET(workspace_size, int64_t);
ADD_FIELD_GET_SET(max_batch_size, int64_t);
ADD_FIELD_GET_SET(device, Device);
ADD_FIELD_GET_SET(torch_fallback, TorchFallback);

std::vector<InputRange> input_ranges;
DataType op_precision = DataType::kFloat;
Expand All @@ -117,6 +133,7 @@ struct CompileSpec : torch::CustomClassHolder {
bool debug = false;
bool strict_types = false;
Device device;
TorchFallback torch_fallback;
EngineCapability capability = EngineCapability::kDEFAULT;
int64_t num_min_timing_iters = 2;
int64_t num_avg_timing_iters = 1;
Expand Down
6 changes: 6 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("dla_core", &Device::dla_core)
.def_readwrite("allow_gpu_fallback", &Device::allow_gpu_fallback);

py::class_<TorchFallback>(m, "TorchFallback")
.def(py::init<>())
.def_readwrite("enabled", &TorchFallback::enabled)
.def_readwrite("min_block_size", &TorchFallback::min_block_size)
.def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);

m.doc() =
"TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
m.def(
Expand Down

0 comments on commit 100b090

Please sign in to comment.