feat: [collection] update python api, refactor code
Signed-off-by: inocsin <[email protected]>
inocsin committed Apr 6, 2022
1 parent b91c1c9 commit a206336
Showing 11 changed files with 255 additions and 74 deletions.
8 changes: 2 additions & 6 deletions core/compiler.h
@@ -14,12 +14,8 @@ namespace torch_tensorrt {
 namespace core {

 struct CompileSpec {
-  CompileSpec(std::vector<ir::Input> inputs) {
-    graph_inputs.inputs = inputs;
-  }
-  CompileSpec(torch::jit::IValue& input_signature) {
-    graph_inputs.input_signature = input_signature;
-  }
+  CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(inputs) {}
+  CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
   ir::GraphInputs graph_inputs;
   conversion::ConversionInfo convert_info;
   lowering::LowerInfo lower_info;
3 changes: 2 additions & 1 deletion core/ir/BUILD
@@ -15,7 +15,8 @@ cc_library(
     srcs = [
         "ir.cpp",
         "Input.cpp",
-        "StaticParams.cpp"
+        "StaticParams.cpp",
+        "GraphInputs.cpp"
     ],
     deps = [
         "@tensorrt//:nvinfer",
75 changes: 75 additions & 0 deletions core/ir/GraphInputs.cpp
@@ -0,0 +1,75 @@
#include "core/ir/ir.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace ir {

// Recursively walk a nested input signature (tuples/lists of ir::Input) and
// record each leaf twice: in DFS order in `flattened_inputs`, and grouped by
// its top-level position in `collection_inputs`.
void flatten_dfs(
    std::vector<torch_tensorrt::core::ir::Input>& flattened_inputs,
    std::vector<std::vector<torch_tensorrt::core::ir::Input>>& collection_inputs,
    torch::jit::IValue input_ivalue,
    int level,
    int index) {
  if (input_ivalue.isTuple()) {
    auto input_tuple = input_ivalue.toTuple();
    if (level == 0) {
      collection_inputs.resize(input_tuple->elements().size());
    }
    int idx = 0;
    for (auto item : input_tuple->elements()) {
      int cur_idx = level < 1 ? idx : index;
      flatten_dfs(flattened_inputs, collection_inputs, item, level + 1, cur_idx);
      idx++;
    }
  } else if (input_ivalue.isList()) {
    auto input_list = input_ivalue.toList().vec();
    if (level == 0) {
      collection_inputs.resize(input_list.size());
    }
    int idx = 0;
    for (auto item : input_list) {
      int cur_idx = level < 1 ? idx : index;
      flatten_dfs(flattened_inputs, collection_inputs, item, level + 1, cur_idx);
      idx++;
    }
  } else if (input_ivalue.isCustomClass()) {
    torch_tensorrt::core::ir::Input cur_input = *(input_ivalue.toCustomClass<torch_tensorrt::core::ir::Input>());
    flattened_inputs.push_back(cur_input);
    if (level == 0) { // a single value like A
      collection_inputs.resize(1);
      collection_inputs[0].push_back(cur_input);
    } else if (level == 1) { // like A in [A, A] or [(B, B), A]
      collection_inputs[index].push_back(cur_input);
    } else if (level == 2) { // like A in [(A, A), C]
      collection_inputs[index].push_back(cur_input);
    } else { // only two levels of nesting are supported
      LOG_ERROR("Input nesting depth exceeds the currently supported depth (2); use one level ([A, B]) or two levels ([A, (B, C)])");
    }
  }
}


GraphInputs::GraphInputs(std::vector<ir::Input> inputs_) {
LOG_DEBUG("Construct GraphInput with ir::Input");
inputs = inputs_;
collection_inputs.resize(inputs_.size());
  for (size_t i = 0; i < inputs_.size(); i++) {
collection_inputs[i].push_back(inputs_[i]);
}
}

GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
LOG_DEBUG("Construct GraphInput with IValue");

std::vector<torch_tensorrt::core::ir::Input> flattened_inputs;
std::vector<std::vector<torch_tensorrt::core::ir::Input>> collection_inputs_;

flatten_dfs(flattened_inputs, collection_inputs_, input_signature_, 0, 0);
inputs = flattened_inputs;
input_signature = input_signature_;
collection_inputs = collection_inputs_;
}

} // namespace ir
} // namespace core
} // namespace torch_tensorrt
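
For illustration, flatten_dfs records every leaf ir::Input twice: once in DFS order in flattened_inputs, and once bucketed by its top-level position in collection_inputs. The Python sketch below is an analogue written for this page, not code from the commit; plain strings stand in for ir::Input leaves.

```python
# Illustration only: a Python analogue of flatten_dfs above.
# Strings stand in for torch_tensorrt::core::ir::Input leaves.
def flatten_dfs(signature, flattened, collection, level=0, index=0):
    if isinstance(signature, (tuple, list)):
        if level == 0:
            collection.extend([] for _ in signature)  # one bucket per top-level entry
        for idx, item in enumerate(signature):
            flatten_dfs(item, flattened, collection, level + 1,
                        idx if level < 1 else index)
    else:  # a leaf input spec
        flattened.append(signature)
        if level == 0:            # a single value like A
            collection.append([signature])
        elif level <= 2:          # A in [A, B], or A in [(A, B), C]
            collection[index].append(signature)
        else:
            raise ValueError("only two levels of nesting are supported")

flattened, collection = [], []
flatten_dfs((("A", "B"), ["C", "D"], "E"), flattened, collection)
print(flattened)   # ['A', 'B', 'C', 'D', 'E']
print(collection)  # [['A', 'B'], ['C', 'D'], ['E']]
```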
2 changes: 2 additions & 0 deletions core/ir/ir.h
@@ -40,6 +40,8 @@ struct Input : torch::CustomClassHolder {

 // Add to spec
 struct GraphInputs {
+  GraphInputs(std::vector<ir::Input> inputs);
+  GraphInputs(torch::jit::IValue& input_signature);
   torch::jit::IValue input_signature; // nested Input, full input spec
   std::vector<Input> inputs; // flattened Input
   std::vector<std::vector<Input>> collection_inputs; // only supports two levels of nesting, e.g. ((a, b), [c, d], e)
67 changes: 13 additions & 54 deletions cpp/src/compile_spec.cpp
@@ -38,88 +38,47 @@ CompileSpec::CompileSpec(torch::jit::IValue input_signature) {



-void flatten_dfs(std::vector<torchtrt::core::ir::Input>& flattened_inputs, std::vector<std::vector<torchtrt::core::ir::Input>>& collection_inputs,
-                 torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int level, int index) {
+void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) {
   if (input_ivalue.isTuple()) {
     auto input_tuple = input_ivalue.toTuple();
     std::vector<torch::jit::IValue> converted_elements;
-    int idx = 0;
-    if (level == 0) {
-      collection_inputs.resize(input_tuple->elements().size());
-    }
     for (auto item: input_tuple->elements()) {
       torch::jit::IValue converted_item;
-      int cur_idx = level < 1 ? idx: index;
-      flatten_dfs(flattened_inputs, collection_inputs, item, converted_item, level+1, cur_idx);
+      to_internal_input_signature(item, converted_item);
       converted_elements.push_back(converted_item);
       auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements);
       converted_ivalue = torch::jit::IValue(tuple_ptr);
-      idx++;
     }
   } else if(input_ivalue.isList()) {
     auto input_list = input_ivalue.toList().vec();
-    if (level == 0) {
-      collection_inputs.resize(input_list.size());
-    }
     c10::TypePtr type = input_list[0].type();
     auto converted_elements = c10::impl::GenericList(type);
-    int idx = 0;
     for (auto item: input_list) {
-      int cur_idx = level < 1 ? idx: index;
       torch::jit::IValue converted_item;
-      flatten_dfs(flattened_inputs, collection_inputs, item, converted_item, level+1, cur_idx);
+      to_internal_input_signature(item, converted_item);
       converted_elements.push_back(converted_item);
-      idx++;
     }
     converted_ivalue = torch::jit::IValue(converted_elements);
   } else if(input_ivalue.isCustomClass()) {
     torchtrt::core::ir::Input cur_input = to_internal_input(*(input_ivalue.toCustomClass<torchtrt::Input>()));
-    flattened_inputs.push_back(cur_input);
     converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::core::ir::Input>(cur_input)));
-    if (level == 0) { // a single value like A
-      collection_inputs.resize(1);
-      collection_inputs[0].push_back(cur_input);
-    } else if (level == 1) { // like A in [A, A] or [(B, B), A]
-      collection_inputs[index].push_back(cur_input);
-    } else if (level == 2) { // like A in [(A, A), C]
-      collection_inputs[index].push_back(cur_input);
-    } else {// only support 2 level
-      LOG_ERROR("Input nesting depth exceeds currently supported depth (3), use 1 level: [A, B], or 2 level: [A, (B, C)]");
-    }
   }
 }

-torch_tensorrt::core::ir::GraphInputs to_internal_graph_inputs(GraphInputs external_graph_input) {
-  torch_tensorrt::core::ir::GraphInputs internal_graph_input;
-
-  std::vector<torchtrt::core::ir::Input> flattened_inputs;
-  std::vector<std::vector<torchtrt::core::ir::Input>> collection_inputs;
-
+torchtrt::core::CompileSpec init_compile_spec(CompileSpec external) {
+  if (external.graph_inputs.inputs.size() > 0) {
+    torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs));
+    return internal;
+  } else {
     torch::jit::IValue converted_input_signature;
-  flatten_dfs(flattened_inputs, collection_inputs, external_graph_input.input_signature, converted_input_signature, 0, 0);
-  internal_graph_input.inputs = flattened_inputs;
-  internal_graph_input.input_signature = converted_input_signature;
-  internal_graph_input.collection_inputs = collection_inputs;
-
-  LOG_DEBUG("Convert external_graph_input to internal_graph_inputs, total input spec number: " << flattened_inputs.size() << ", top level input spec number " << collection_inputs.size());
-
-  return internal_graph_input;
+    to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
+    torchtrt::core::CompileSpec internal(converted_input_signature);
+    return internal;
+  }
 }

 torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
-  torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs));
-  if (internal.graph_inputs.inputs.size() == 0) {
-    LOG_DEBUG("GraphInput.inputs size == 0, using GraphInput.input_signature to get Input spec");
-    internal.graph_inputs = to_internal_graph_inputs(external.graph_inputs);
-  } else {
-    LOG_DEBUG("GraphInput.inputs size != 0, using GraphInput.inputs to get Input spec");
-    internal.graph_inputs.collection_inputs.resize(internal.graph_inputs.inputs.size());
-    for (int i = 0; i < internal.graph_inputs.inputs.size(); i++) {
-      internal.graph_inputs.collection_inputs[i].push_back(internal.graph_inputs.inputs[i]);
-    }
-  }
+  torchtrt::core::CompileSpec internal = init_compile_spec(external);

   for (auto p : external.enabled_precisions) {
     internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
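
For illustration, the refactor moves the choice between the two construction paths into init_compile_spec: a non-empty flat inputs list wins, otherwise the nested input_signature is converted. Below is a runnable Python sketch of that dispatch, with dataclass stand-ins for the C++ types (not the commit's API):

```python
# Illustration only: the dispatch performed by init_compile_spec above.
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class GraphInputs:
    input_signature: Any = None                       # nested structure of leaf specs
    inputs: List[str] = field(default_factory=list)   # flat list; strings stand in for ir::Input

@dataclass
class CompileSpec:
    graph_inputs: GraphInputs = field(default_factory=GraphInputs)

def init_compile_spec(external: CompileSpec) -> str:
    # A non-empty flat `inputs` list takes priority; otherwise the nested
    # `input_signature` is converted and used to build the core spec.
    if external.graph_inputs.inputs:
        return "core spec from flat inputs: {}".format(external.graph_inputs.inputs)
    return "core spec from signature: {}".format(external.graph_inputs.input_signature)

print(init_compile_spec(CompileSpec(GraphInputs(inputs=["A", "B"]))))
print(init_compile_spec(CompileSpec(GraphInputs(input_signature=(("A", "B"), "C")))))
```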
7 changes: 7 additions & 0 deletions py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
@@ -23,6 +23,13 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, input_is_dynamic);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, explicit_set_dtype);

+  static auto TORCHTRT_UNUSED TRTGraphInpuTSRegistration =
+      torch::class_<torch_tensorrt::pyapi::GraphInputs>("tensorrt", "_GraphInputs")
+          .def(torch::init<>())
+          .def("__str__", &torch_tensorrt::pyapi::GraphInputs::to_str);
+
+  ADD_FIELD_GET_SET_REGISTRATION(TRTGraphInpuTSRegistration, torch_tensorrt::pyapi::GraphInputs, input_signature);
+
   static auto TORCHTRT_UNUSED TRTDeviceTSRegistration =
       torch::class_<torch_tensorrt::pyapi::Device>("tensorrt", "_Device")
           .def(torch::init<>())
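
Classes registered through torch::class_ surface under the torch.classes namespace, so the block above makes the new container reachable from Python and TorchScript. A hedged sketch (it assumes the torch_tensorrt extension library has been loaded):

```python
# Hedged sketch: torch::class_<GraphInputs>("tensorrt", "_GraphInputs") above
# should make the class reachable as torch.classes.tensorrt._GraphInputs.
import torch
import torch_tensorrt  # importing the package loads the extension and its registrations

gi = torch.classes.tensorrt._GraphInputs()
print(str(gi))  # __str__ is bound to GraphInputs::to_str, currently an empty stub
```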
53 changes: 48 additions & 5 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -104,6 +104,11 @@ std::string Input::to_str() {
   return ss.str();
 }

+std::string GraphInputs::to_str() {
+  std::stringstream ss;
+  return ss.str();
+}
+
 std::string to_str(DeviceType value) {
   switch (value) {
     case DeviceType::kDLA:

@@ -184,13 +189,51 @@ std::string TorchFallback::to_str() {
   return ss.str();
 }

-core::CompileSpec CompileSpec::toInternalCompileSpec() {
-  std::vector<core::ir::Input> internal_inputs;
-  for (auto i : inputs) {
-    internal_inputs.push_back(i.toInternalInput());
+void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) {
+  if (input_ivalue.isTuple()) {
+    auto input_tuple = input_ivalue.toTuple();
+    std::vector<torch::jit::IValue> converted_elements;
+    for (auto item: input_tuple->elements()) {
+      torch::jit::IValue converted_item;
+      to_internal_input_signature(item, converted_item);
+      converted_elements.push_back(converted_item);
+      auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements);
+      converted_ivalue = torch::jit::IValue(tuple_ptr);
+    }
+  } else if(input_ivalue.isList()) {
+    auto input_list = input_ivalue.toList().vec();
+    c10::TypePtr type = input_list[0].type();
+    auto converted_elements = c10::impl::GenericList(type);
+    for (auto item: input_list) {
+      torch::jit::IValue converted_item;
+      to_internal_input_signature(item, converted_item);
+      converted_elements.push_back(converted_item);
+    }
+    converted_ivalue = torch::jit::IValue(converted_elements);
+  } else if(input_ivalue.isCustomClass()) {
+    core::ir::Input cur_input = (*(input_ivalue.toCustomClass<Input>())).toInternalInput();
+    converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<core::ir::Input>(cur_input)));
+  }
+}
+
+core::CompileSpec init_compile_spec(CompileSpec external) {
+  if (external.graph_inputs.inputs.size() > 0) {
+    std::vector<core::ir::Input> internal_inputs;
+    for (auto i : external.graph_inputs.inputs) {
+      internal_inputs.push_back(i.toInternalInput());
+    }
+    core::CompileSpec internal(internal_inputs);
+    return internal;
+  } else {
+    torch::jit::IValue converted_input_signature;
+    to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
+    core::CompileSpec internal(converted_input_signature);
+    return internal;
+  }
+}

-  auto info = core::CompileSpec(internal_inputs);
+core::CompileSpec CompileSpec::toInternalCompileSpec() {
+  core::CompileSpec info = init_compile_spec(*this);

   for (auto p : enabled_precisions) {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
8 changes: 8 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -57,6 +57,13 @@ struct Input : torch::CustomClassHolder {
   std::string to_str();
 };

+struct GraphInputs : torch::CustomClassHolder {
+  torch::jit::IValue input_signature; // nested Input, full input spec
+  std::vector<Input> inputs; // flattened input spec
+  ADD_FIELD_GET_SET(input_signature, torch::jit::IValue);
+  std::string to_str();
+};
+
 enum DeviceType : int8_t {
   kGPU,
   kDLA,
@@ -156,6 +163,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);

   std::vector<Input> inputs;
+  GraphInputs graph_inputs;
   nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
   std::set<DataType> enabled_precisions = {};
   bool sparse_weights = false;
7 changes: 7 additions & 0 deletions py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -178,6 +178,12 @@ PYBIND11_MODULE(_C, m) {
.def_readwrite("dtype", &Input::dtype)
.def_readwrite("format", &Input::format);

py::class_<GraphInputs>(m, "GraphInputs")
.def(py::init<>())
.def("__str__", &torch_tensorrt::pyapi::GraphInputs::to_str)
.def_readwrite("input_signature", &GraphInputs::input_signature)
.def_readwrite("inputs", &GraphInputs::inputs);

py::enum_<DataType>(m, "dtype", "Enum to specifiy operating precision for engine execution")
.value("float", DataType::kFloat, "32 bit floating point number")
.value("float32", DataType::kFloat, "32 bit floating point number")
@@ -292,6 +298,7 @@ PYBIND11_MODULE(_C, m) {
       .def("__str__", &torch_tensorrt::pyapi::CompileSpec::stringify)
       .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
       .def_readwrite("inputs", &CompileSpec::inputs)
+      .def_readwrite("graph_inputs", &CompileSpec::graph_inputs)
       .def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions)
       .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
       .def_readwrite("refit", &CompileSpec::refit)
39 changes: 31 additions & 8 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -5,7 +5,7 @@
 from torch_tensorrt import _enums
 from torch_tensorrt._Input import Input
 from torch_tensorrt._Device import Device
-
+from typing import Tuple, List, Dict
 import warnings

@@ -156,6 +156,24 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:

     return info

+def _parse_collection_input(input_signature: Any) -> Any:
+    if isinstance(input_signature, tuple):
+        input_list = []
+        for item in input_signature:
+            parsed = _parse_collection_input(item)
+            input_list.append(parsed)
+        return tuple(input_list)
+    elif isinstance(input_signature, list):
+        input_list = []
+        for item in input_signature:
+            parsed = _parse_collection_input(item)
+            input_list.append(parsed)
+        return input_list
+    elif isinstance(input_signature, (Input, torch.Tensor)):
+        # convert a Tensor to an Input, then to the internal pyapi handle
+        parsed = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature
+        return parsed._to_internal()
+    else:
+        raise KeyError("Invalid Input spec: expected a (possibly nested) tuple/list of Input or torch.Tensor, got {}".format(type(input_signature)))
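
In short, _parse_collection_input walks an arbitrarily nested tuple/list structure, converts each torch.Tensor or Input leaf to its internal handle, and preserves the container shape. A hedged usage sketch (the shapes are invented for illustration):

```python
# Hedged usage sketch for _parse_collection_input; shapes are made up.
import torch
from torch_tensorrt._Input import Input

nested = (
    (Input(shape=[1, 3, 224, 224]), Input(shape=[1, 3, 224, 224])),  # tuple branch
    [torch.randn(1, 10), Input(shape=[1, 10])],  # list branch; Tensors become Inputs
)
signature = _parse_collection_input(nested)
# -> ((<pyapi Input>, <pyapi Input>), [<pyapi Input>, <pyapi Input>])
```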

 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
     info = _ts_C.CompileSpec()

@@ -165,14 +183,19 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
     )

     if "inputs" in compile_spec:
-        if not all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]):
-            raise KeyError("Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format(
-                [type(i) for i in compile_spec["inputs"]]))
-
-        inputs = [Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
-        info.inputs = [i._to_internal() for i in inputs]
+        if isinstance(compile_spec["inputs"], list) and all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]):
+            # a flat list of Input/Tensor: convert each python Input to torch_tensorrt::pyapi::Input
+            inputs = [Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
+            info.graph_inputs.inputs = [i._to_internal() for i in inputs]
+        else:
+            # a nested collection: convert it to an input signature
+            info.graph_inputs.input_signature = _parse_collection_input(compile_spec["inputs"])

-    assert (len(info.inputs) > 0), "Require at least one input definition to compile model"
+    assert (len(info.graph_inputs.inputs) > 0), "Require at least one input definition to compile model"

     if "enabled_precisions" in compile_spec:
         info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])
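
Taken together, the user-visible change is that compile_spec["inputs"] may now be either the flat list it always was, or a nested collection that is routed through _parse_collection_input. A hedged sketch of the two accepted forms (shapes invented):

```python
# Hedged sketch: the two forms of compile_spec["inputs"] accepted above.
from torch_tensorrt._Input import Input

flat_spec = {
    "inputs": [Input(shape=[1, 3, 224, 224])],  # flat list -> info.graph_inputs.inputs
}
nested_spec = {
    # nested collection -> info.graph_inputs.input_signature
    "inputs": ((Input(shape=[1, 10]), Input(shape=[1, 10])), [Input(shape=[1, 5])]),
}
```

One caveat worth flagging: as committed, the final assert checks only graph_inputs.inputs, which the nested-signature path leaves empty, so that path appears to rely on the assert being relaxed in a follow-up change.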