
Add partition logic and torch backend integration
Signed-off-by: Arvind Sridhar <[email protected]>
ArvindSridhar authored and narendasan committed Aug 20, 2021
1 parent b4feb49 commit b96087b
Showing 10 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion core/compiler.cpp
@@ -119,7 +119,7 @@ void AddEngineToGraph(
}

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
- // Go through Lowering to simplify graph
+ // Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name, lowering::LowerInfo());

auto g = graph_and_parameters.first;
2 changes: 1 addition & 1 deletion core/compiler.h
@@ -19,7 +19,7 @@ struct CompileSpec {
partitioning::PartitionInfo partition_info;
};

- bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
+ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);

3 changes: 2 additions & 1 deletion core/partitioning/partitioning.cpp
@@ -274,7 +274,8 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
}

std::string node_string(n->kind().toQualString());
- if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
+ auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
+ if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) && (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
tensorrt_nodes.push_back(n);
if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
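The new condition consults a "to_compile" attribute on each node: a node carrying to_compile = false stays in a Torch segment even when a TensorRT converter exists for it. The sketch below is illustrative only and is not code from this commit; it shows one way a lowering pass could tag every node of a forced-fallback module with that attribute using standard torch::jit IR calls. The function name and traversal are assumptions; only the "to_compile" attribute key comes from the hunk above.

// Illustrative sketch: tag all nodes in a block (and nested blocks) so the
// partitioner's new condition assigns them to a Torch segment.
#include <torch/csrc/jit/ir/ir.h>

void mark_block_for_torch_fallback(torch::jit::Block* block) {
  for (auto node : block->nodes()) {
    // segment_graph reads n->i(c10::Symbol::attr("to_compile")); 0 (false)
    // means "do not hand this node to TensorRT".
    node->i_(c10::Symbol::attr("to_compile"), 0);
    for (auto nested : node->blocks()) {
      mark_block_for_torch_fallback(nested); // recurse into if/loop bodies
    }
  }
}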
2 changes: 1 addition & 1 deletion cpp/api/include/trtorch/trtorch.h
@@ -784,7 +784,7 @@ TRTORCH_API void dump_build_info();
*
* @returns bool: Method is supported by TRTorch
*/
- TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name, CompileSpec info);
+ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, std::string method_name);

/**
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
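With this change the support check no longer takes a CompileSpec. A minimal usage sketch against the new signature follows; the model path and surrounding scaffolding are illustrative.

#include <iostream>
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Load a TorchScript module; the path is illustrative.
  auto mod = torch::jit::load("model.ts");
  // Only the module and the method name are needed now.
  bool supported = trtorch::CheckMethodOperatorSupport(mod, "forward");
  std::cout << "forward fully convertible: " << std::boolalpha << supported << std::endl;
  return supported ? 0 : 1;
}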
4 changes: 2 additions & 2 deletions cpp/api/src/trtorch.cpp
@@ -11,8 +11,8 @@ namespace trtorch {
core::CompileSpec to_internal_compile_spec(CompileSpec external);
core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device);

- bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name, CompileSpec info) {
- return core::CheckMethodOperatorSupport(module, method_name, to_internal_compile_spec(info));
+ bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
+ return core::CheckMethodOperatorSupport(module, method_name);
}

std::string ConvertGraphToTRTEngine(
5 changes: 5 additions & 0 deletions py/trtorch/_compile_spec.py
@@ -148,6 +148,10 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
assert isinstance(fallback_info["forced_fallback_ops"], list)
info.forced_fallback_operators = fallback_info["forced_fallback_ops"]

if "forced_fallback_modules" in fallback_info:
assert isinstance(fallback_info["forced_fallback_modules"], list)
info.forced_fallback_modules = fallback_info["forced_fallback_modules"]

return info


@@ -338,6 +342,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
torch_fallback._set_min_block_size(parsed_spec.torch_fallback.min_block_size)
torch_fallback._set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
+ torch_fallback._set_forced_fallback_modules(parsed_spec.torch_fallback.forced_fallback_modules)

backend_spec._set_device(d)
backend_spec._set_torch_fallback(torch_fallback)
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -37,6 +37,7 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
+ ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_modules);

static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
6 changes: 6 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.cpp
@@ -166,6 +166,11 @@ std::string TorchFallback::to_str() {
ss << " " << i << ',' << std::endl;
}
ss << " ]" << std::endl;
ss << " \"forced_fallback_modules\": [" << std::endl;
for (auto i : forced_fallback_modules) {
ss << " " << i << ',' << std::endl;
}
ss << " ]" << std::endl;
ss << " }" << std::endl;
return ss.str();
}
@@ -203,6 +208,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.partition_info.enabled = torch_fallback.enabled;
info.partition_info.min_block_size = torch_fallback.min_block_size;
info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
+ info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;

info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
2 changes: 2 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
@@ -90,11 +90,13 @@ struct TorchFallback : torch::CustomClassHolder {
bool enabled;
int64_t min_block_size;
std::vector<std::string> forced_fallback_operators;
+ std::vector<std::string> forced_fallback_modules;
TorchFallback() : enabled(false), min_block_size(1) {}

ADD_FIELD_GET_SET(enabled, bool);
ADD_FIELD_GET_SET(min_block_size, int64_t);
ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
+ ADD_FIELD_GET_SET(forced_fallback_modules, std::vector<std::string>);

std::string to_str();
};
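For reference, the new field can be exercised directly on this struct from C++; the small sketch below is an assumption made for illustration (the Python front end populates it through the registered bindings rather than constructing the struct by hand, and the example values are placeholders).

#include <iostream>
#include "tensorrt_classes.h" // py/trtorch/csrc/tensorrt_classes.h

int main() {
  trtorch::pyapi::TorchFallback fallback;
  fallback.enabled = true;
  fallback.min_block_size = 3;
  fallback.forced_fallback_operators = {"aten::max_pool2d"}; // illustrative value
  // New in this commit: whole module types can be forced to run in Torch.
  fallback.forced_fallback_modules = {"torchvision.models.resnet.BasicBlock"};
  std::cout << fallback.to_str() << std::endl; // output now includes the new list
  return 0;
}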
3 changes: 2 additions & 1 deletion py/trtorch/csrc/trtorch_py.cpp
@@ -281,7 +281,8 @@ PYBIND11_MODULE(_C, m) {
.def("__str__", &trtorch::pyapi::TorchFallback::to_str)
.def_readwrite("enabled", &TorchFallback::enabled)
.def_readwrite("min_block_size", &TorchFallback::min_block_size)
.def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);
.def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators)
.def_readwrite("forced_fallback_modules", &TorchFallback::forced_fallback_modules);

m.doc() =
"TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
