fix: Allow full model compilation with collection inputs (input_signature) #1656

Merged · 2 commits · Mar 30, 2023

7 changes: 7 additions & 0 deletions README.md
@@ -73,6 +73,7 @@ import torch_tensorrt
...

trt_ts_module = torch_tensorrt.compile(torch_script_module,
# If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
inputs = [example_tensor, # Provide example tensor for input shape or...
torch_tensorrt.Input( # Specify input object with shape and dtype
min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
# For static size shape=[1, 3, 224, 224]
dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
],

# For inputs containing tuples or lists of tensors, use the `input_signature` argument:
# Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
# input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
# torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),

enabled_precisions = {torch.half}, # Run with FP16
)

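For context, here is a minimal runnable sketch of the usage the README comments describe, with full compilation enabled as this PR now allows. The module, shapes, and dtypes are illustrative assumptions rather than part of the change:

```python
from typing import Tuple

import torch
import torch_tensorrt


class TupleInput(torch.nn.Module):
    # Hypothetical module whose forward takes a Tuple[Tensor, Tensor]
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return xs[0] + xs[1]


scripted = torch.jit.script(TupleInput()).eval().cuda()

trt_mod = torch_tensorrt.compile(
    scripted,
    # Describe the tuple input via input_signature, mirroring the commented README example
    input_signature=(
        (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
         torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)),
    ),
    enabled_precisions={torch.half},
    min_block_size=1,
    require_full_compilation=True,  # allowed for collection inputs by this PR
)

x = torch.randn(1, 3, 224, 224, device="cuda").half()
print(trt_mod((x, x)).shape)
```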
3 changes: 2 additions & 1 deletion core/compiler.cpp
@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
// Determine if the block is convertible/has collection output, and based on the result,
// whether full compilation can be expected
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto inputIsCollection = conversion::InputIsCollection(g->block());
auto outputIsCollection = conversion::OutputIsCollection(g->block());
auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));

// Determine whether user specifications necessitate partitioning
auto isFallbackRequested = userRequestedFallback(cfg);
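A hedged paraphrase of the gate above, not the actual implementation: the change lets collection inputs, in addition to collection outputs, route a convertible graph through the collection-handling path.

```python
# Sketch of the decision in core/compiler.cpp after this change (paraphrase, not the real code)
def requires_collection_handling(block_convertible: bool,
                                 input_is_collection: bool,
                                 output_is_collection: bool) -> bool:
    # Previously only collection outputs were considered; collection inputs now count too,
    # so a fully convertible graph with tuple/list inputs can go through full compilation.
    return block_convertible and (input_is_collection or output_is_collection)
```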
12 changes: 11 additions & 1 deletion core/conversion/conversion.cpp
@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool InputIsCollection(const torch::jit::Block* b) {
for (auto in : b->inputs()) {
if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
return true;
}
}
return false;
}

bool OutputIsCollection(const torch::jit::Block* b) {
for (auto out : b->outputs()) {
if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
out->type()->kind() == torch::jit::TypeKind::ListType) {
out->type()->kind() == torch::jit::TypeKind::ListType ||
out->type()->kind() == torch::jit::TypeKind::DictType) {
return true;
}
}
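A rough Python analogue of what `InputIsCollection` and `OutputIsCollection` look for; this is only a sketch (the real check walks the block's C++ TypeKinds), and the module below is a made-up example:

```python
from typing import List

import torch


class ListIn(torch.nn.Module):
    # Hypothetical module whose graph block input is a ListType
    def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:
        return xs[0] + xs[1]


g = torch.jit.script(ListIn()).graph

# The first graph input is `self`; the second is `xs` with type List[Tensor],
# which corresponds to the TypeKind::ListType case the helpers above detect.
for v in g.inputs():
    print(v.debugName(), v.type())
for v in g.outputs():
    print(v.debugName(), v.type())
```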
2 changes: 2 additions & 0 deletions core/conversion/conversion.h
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool InputIsCollection(const torch::jit::Block* b);

bool OutputIsCollection(const torch::jit::Block* b);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
21 changes: 0 additions & 21 deletions cpp/src/compile_spec.cpp
@@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
torchtrt::core::CompileSpec internal(converted_input_signature);

TORCHTRT_CHECK(
!external.require_full_compilation,
"Grouped inputs currently requires partial compilation to be enabled, \
this restriction will be relaxed in a future release");

LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
LOG_DEBUG(
"Adding the following ops to torch_executed_ops:" << std::endl
<< " - aten::__getitem__" << std::endl
<< " - prim::ListConstruct" << std::endl
<< " - prim::ListUnpack" << std::endl
<< " - prim::TupleIndex" << std::endl
<< " - prim::TupleConstruct" << std::endl
<< " - prim::TupleUnpack");
external.torch_executed_ops.push_back("aten::__getitem__");
external.torch_executed_ops.push_back("prim::ListConstruct");
external.torch_executed_ops.push_back("prim::ListUnpack");
external.torch_executed_ops.push_back("prim::TupleIndex");
external.torch_executed_ops.push_back("prim::TupleConstruct");
external.torch_executed_ops.push_back("prim::TupleUnpack");
return internal;
}
}
30 changes: 27 additions & 3 deletions docsrc/getting_started/getting_started_with_python_api.rst
@@ -14,8 +14,7 @@ If given a ``torch.nn.Module`` and the ``ir`` flag is set to either ``default``

To compile your input ``torch.nn.Module`` with Torch-TensorRT, all you need to do is provide the module and inputs
to Torch-TensorRT and you will be returned an optimized TorchScript module to run or add into another PyTorch module. Inputs
is a list of ``torch_tensorrt.Input`` classes which define input's shape, datatype and memory format. You can also specify settings such as
operating precision for the engine or target device. After compilation you can save the module just like any other module
is a list of ``torch_tensorrt.Input`` classes which define input Tensors' shape, datatype and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the ``input_signature`` argument to specify a collection-based input, such as ``(List[Tensor], Tuple[Tensor, Tensor])``. See the second sample below for an example. You can also specify settings such as operating precision for the engine or target device. After compilation you can save the module just like any other module
to load in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``torch_tensorrt``.

.. code-block:: python
@@ -44,6 +43,32 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

.. code-block:: python

# Sample using collection-based inputs via the input_signature argument
import torch_tensorrt

...

model = MyModel().eval()

# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a forward signature of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
[torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
(torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}

trt_ts_module = torch_tensorrt.compile(
model, input_signature=input_signature, enabled_precisions=enabled_precisions
)

# Provide collection inputs matching the structure declared in input_signature
list_input = [torch.randn((64, 64)).to("cuda").half(), torch.randn((64, 64)).to("cuda").half()]
tuple_input = (torch.randn((64, 64)).to("cuda").half(), torch.randn((64, 64)).to("cuda").half())
result = trt_ts_module(list_input, tuple_input)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

.. code-block:: python

# Deployment application
@@ -55,4 +80,3 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
result = trt_ts_module(input_data)

Torch-TensorRT Python API also provides ``torch_tensorrt.ts.compile`` which accepts a TorchScript module as input and ``torch_tensorrt.fx.compile`` which accepts a FX GraphModule as input.

37 changes: 1 addition & 36 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
"Input signature parsing is an experimental feature, behavior and APIs may change",
)
signature = _parse_input_signature(compile_spec["input_signature"])
info.input_signature = _C.InputSignature(signature) # py_object

if not compile_spec["torch_fallback"]["enabled"]:
raise ValueError(
"Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
)

log(
Level.Debug,
"Grouped inputs currently requires additional settings to enable the feature",
)
log(
Level.Debug,
"""Adding the following ops to torch_executed_ops:
- aten::__getitem__
- prim::ListConstruct
- prim::ListUnpack
- prim::TupleIndex
- prim::TupleConstruct
- prim::TupleUnpack
""",
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"aten::__getitem__"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::ListConstruct"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::TupleConstruct"
)
compile_spec["torch_fallback"]["forced_fallback_ops"].append(
"prim::TupleUnpack"
)
info.input_signature = _C.InputSignature(signature)

else:
raise KeyError(
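The user-visible effect of dropping this check, sketched from the new tests below: a spec combining `input_signature` with `require_full_compilation=True`, which previously raised the "Grouped inputs currently requires partial compilation to be enabled" error, now compiles without forcing any ops to run in Torch. The module path and shapes here are borrowed from the test fixtures and are illustrative.

```python
import torch
import torch_tensorrt as torchtrt

# Path borrowed from the test fixtures; substitute your own scripted module.
model = torch.jit.load("tests/modules/tuple_input_output_scripted.jit.pt").eval().to("cuda")

spec = {
    "input_signature": (
        (torchtrt.Input(shape=[1, 3, 224, 224]), torchtrt.Input(shape=[1, 3, 224, 224])),
    ),
    "device": torchtrt.Device("gpu:0"),
    "min_block_size": 1,
    "require_full_compilation": True,  # rejected before this PR when input_signature was used
}

trt_mod = torchtrt.ts.compile(model, **spec)
x = torch.randn(1, 3, 224, 224, device="cuda")
print([t.shape for t in trt_mod((x, x))])
```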
62 changes: 62 additions & 0 deletions tests/cpp/test_collections.cpp
@@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
}

TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);

torch::jit::Module mod;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
mod = torch::jit::load(path);
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
}
mod.eval();
mod.to(torch::kCUDA);

std::vector<torch::jit::IValue> inputs_;

for (auto in : inputs) {
inputs_.push_back(torch::jit::IValue(in.clone()));
}

std::vector<torch::jit::IValue> complex_inputs;
auto input_list = c10::impl::GenericList(c10::TensorType::get());
input_list.push_back(inputs_[0]);
input_list.push_back(inputs_[0]);

torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);

complex_inputs.push_back(input_list_ivalue);

auto out = mod.forward(complex_inputs);

auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);

auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

c10::TypePtr elementType = input_shape_ivalue.type();
auto list = c10::impl::GenericList(elementType);
list.push_back(input_shape_ivalue);
list.push_back(input_shape_ivalue);

torch::jit::IValue complex_input_shape(list);
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
torch::jit::IValue complex_input_shape2(input_tuple2);

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.min_block_size = 1;
compile_settings.require_full_compilation = true;

// FP16 execution
compile_settings.enabled_precisions = {torch::kHalf};
// Compile module
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
auto trt_out = trt_mod.forward(complex_inputs);

ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
}
87 changes: 87 additions & 0 deletions tests/py/api/test_collections.py
@@ -194,6 +194,34 @@ def test_compile(self):
msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
(torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))
for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


class TestListInputOutput(unittest.TestCase):
def test_compile(self):
@@ -225,6 +253,36 @@ def test_compile(self):
msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
[torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))

for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


class TestListInputTupleOutput(unittest.TestCase):
def test_compile(self):
@@ -255,6 +313,35 @@ def test_compile(self):
msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_compile_full_compilation(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = (
torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
.eval()
.to("cuda")
)

compile_spec = {
"input_signature": (
[torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
),
"device": torchtrt.Device("gpu:0"),
"enabled_precisions": {torch.float},
"min_block_size": 1,
"require_full_compilation": True,
}

trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
trt_out = trt_mod((self.input, self.input))
pyt_out = self.model((self.input, self.input))
for (t, p) in zip(trt_out, pyt_out):
cos_sim = cosine_similarity(t, p)
self.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


if __name__ == "__main__":
unittest.main()