Skip to content

Commit

Permalink
fix(to_backend): Clean up to_backend implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 20, 2021
1 parent a5bc3b0 commit 4e15605
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions py/trtorch/csrc/tensorrt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,29 @@ c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValu

c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
auto mod = mod_val.toModule();
mod = core::lowering::LowerModule(mod);

auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
core::lowering::LowerInfo lower_info;
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
const auto& method_name = it->key();
auto method = mod.get_method(method_name);
auto graph = method.graph();
core::lowering::LowerGraph(graph, lower_info);
}

auto handles = c10::impl::GenericDict(
c10::StringType::get(), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());

for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
auto mod_ = mod.clone();
const auto& method_name = it->key();
auto method = mod.get_method(method_name);
auto g = method.graph();

auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
LOG_DEBUG(raw_spec->stringify());
auto cfg = raw_spec->toInternalCompileSpec();
auto convert_cfg = std::move(cfg.convert_info);
auto graph_and_ivalues = torch::jit::LowerGraph(*g, mod._ivalue());
auto graph_and_ivals = Lower(mod_, method_name, cfg.lower_info);

g = graph_and_ivalues.first;
auto params = graph_and_ivalues.second;
auto g = graph_and_ivals.first;
auto params = graph_and_ivals.second;
auto named_params = core::conversion::get_named_params(g->inputs(), params);

auto convert_cfg = std::move(cfg.convert_info);
auto device_spec = convert_cfg.engine_settings.device;
auto device = core::runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine, device);
handles.insert(method.name(), at::IValue(engine_handle));
handles.insert(method_name, at::IValue(engine_handle));
}

return c10::impl::toGenericDict(handles);
Expand Down

0 comments on commit 4e15605

Please sign in to comment.