From 619b9a047479d5c5a2c5a5637a0e2b256c6a82cd Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 9 Feb 2023 12:05:58 -0800 Subject: [PATCH 1/2] fix: Allow full model compilation with collection Inputs - Allow users to specify full model compilation when using `input_signature`, which allows for complex collection-based inputs - Enable "psuedo-partitioning" phase for input collections as well as output collections - Update `OutputIsCollection` to include dictionary outputs, and add function `InputIsCollection` to detect collection-based inputs during graph compilation - Remove automatic fallback for collection pack/unpack operations when using `input_signature` argument - Add collections tests to ensure full compilation is respected for input and output collections --- core/compiler.cpp | 3 +- core/conversion/conversion.cpp | 12 +++- core/conversion/conversion.h | 2 + cpp/src/compile_spec.cpp | 21 ------- py/torch_tensorrt/ts/_compile_spec.py | 37 +----------- tests/cpp/test_collections.cpp | 62 +++++++++++++++++++ tests/py/api/test_collections.py | 87 +++++++++++++++++++++++++++ 7 files changed, 165 insertions(+), 59 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 3dd735a59e..bda4583664 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) // Determine if the block is convertible/has collection output, and based on the result, // whether full compilation can be expected auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true); + auto inputIsCollection = conversion::InputIsCollection(g->block()); auto outputIsCollection = conversion::OutputIsCollection(g->block()); - auto requires_collection_handling = (isBlockConvertible && outputIsCollection); + auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection)); // Determine whether user specifications necessitate partitioning auto isFallbackRequested = userRequestedFallback(cfg); diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 940e178850..b0e8174500 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -556,10 +556,20 @@ std::set ConvertableOpsInBlock(const torch::jit::Block* b) { return convertable_ops; } +bool InputIsCollection(const torch::jit::Block* b) { + for (auto in : b->inputs()) { + if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) { + return true; + } + } + return false; +} + bool OutputIsCollection(const torch::jit::Block* b) { for (auto out : b->outputs()) { if (out->type()->kind() == torch::jit::TypeKind::TupleType || - out->type()->kind() == torch::jit::TypeKind::ListType) { + out->type()->kind() == torch::jit::TypeKind::ListType || + out->type()->kind() == torch::jit::TypeKind::DictType) { return true; } } diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index a578c4288e..4ef092a1be 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -26,6 +26,8 @@ std::string ConvertBlockToEngine( bool OpSupported(const torch::jit::Node* n); +bool InputIsCollection(const torch::jit::Block* b); + bool OutputIsCollection(const torch::jit::Block* b); bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false); diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index 0fe56265e7..1954827893 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) { LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change"); to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature); torchtrt::core::CompileSpec internal(converted_input_signature); - - TORCHTRT_CHECK( - !external.require_full_compilation, - "Grouped inputs currently requires partial compilation to be enabled, \ - this restriction will be relaxed in a future release"); - - LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature"); - LOG_DEBUG( - "Adding the following ops to torch_executed_ops:" << std::endl - << " - aten::__getitem__" << std::endl - << " - prim::ListConstruct" << std::endl - << " - prim::ListUnpack" << std::endl - << " - prim::TupleIndex" << std::endl - << " - prim::TupleConstruct" << std::endl - << " - prim::TupleUnpack"); - external.torch_executed_ops.push_back("aten::__getitem__"); - external.torch_executed_ops.push_back("prim::ListConstruct"); - external.torch_executed_ops.push_back("prim::ListUnpack"); - external.torch_executed_ops.push_back("prim::TupleIndex"); - external.torch_executed_ops.push_back("prim::TupleConstruct"); - external.torch_executed_ops.push_back("prim::TupleUnpack"); return internal; } } diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 0e11d3bcd3..8f06e2ef71 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: "Input signature parsing is an experimental feature, behavior and APIs may change", ) signature = _parse_input_signature(compile_spec["input_signature"]) - info.input_signature = _C.InputSignature(signature) # py_object - - if not compile_spec["torch_fallback"]["enabled"]: - raise ValueError( - "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release" - ) - - log( - Level.Debug, - "Grouped inputs currently requires additional settings to enable the feature", - ) - log( - Level.Debug, - """Adding the following ops to torch_executed_ops: - - aten::__getitem__ - - prim::ListConstruct - - prim::ListUnpack - - prim::TupleIndex - - prim::TupleConstruct - - prim::TupleUnpack -""", - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "aten::__getitem__" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::ListConstruct" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack") - compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex") - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::TupleConstruct" - ) - compile_spec["torch_fallback"]["forced_fallback_ops"].append( - "prim::TupleUnpack" - ) + info.input_signature = _C.InputSignature(signature) else: raise KeyError( diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp index cbca9c7b98..943119977b 100644 --- a/tests/cpp/test_collections.cpp +++ b/tests/cpp/test_collections.cpp @@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) { ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor())); } + +TEST(CppAPITests, TestCollectionFullCompilationComplexModel) { + std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + std::vector complex_inputs; + auto input_list = c10::impl::GenericList(c10::TensorType::get()); + input_list.push_back(inputs_[0]); + input_list.push_back(inputs_[0]); + + torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list); + + complex_inputs.push_back(input_list_ivalue); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + c10::TypePtr elementType = input_shape_ivalue.type(); + auto list = c10::impl::GenericList(elementType); + list.push_back(input_shape_ivalue); + list.push_back(input_shape_ivalue); + + torch::jit::IValue complex_input_shape(list); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + compile_settings.require_full_compilation = true; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor())); + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual( + out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor())); +} diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py index 12c1ac9f50..64f46fa3e9 100644 --- a/tests/py/api/test_collections.py +++ b/tests/py/api/test_collections.py @@ -194,6 +194,34 @@ def test_compile(self): msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + def test_compile_full_compilation(self): + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = ( + torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt") + .eval() + .to("cuda") + ) + + compile_spec = { + "input_signature": ( + (torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)), + ), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1, + "require_full_compilation": True, + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + class TestListInputOutput(unittest.TestCase): def test_compile(self): @@ -225,6 +253,36 @@ def test_compile(self): msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + def test_compile_full_compilation(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = ( + torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt") + .eval() + .to("cuda") + ) + + compile_spec = { + "input_signature": ( + [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)], + ), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1, + "require_full_compilation": True, + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + class TestListInputTupleOutput(unittest.TestCase): def test_compile(self): @@ -255,6 +313,35 @@ def test_compile(self): msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + def test_compile_full_compilation(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = ( + torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt") + .eval() + .to("cuda") + ) + + compile_spec = { + "input_signature": ( + [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)], + ), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1, + "require_full_compilation": True, + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + for (t, p) in zip(trt_out, pyt_out): + cos_sim = cosine_similarity(t, p) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + if __name__ == "__main__": unittest.main() From 985f6a2e24de7df1bb9235bd849cc0c098764348 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 28 Mar 2023 20:41:07 -0700 Subject: [PATCH 2/2] chore: Add samples of `input_signature` usage to docs - Add documentation to `README` for usage of input signature - Add documentation to "Getting Started" page for usage of input signature --- README.md | 7 +++++ .../getting_started_with_python_api.rst | 30 +++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bc46646d70..bb3aa97b7e 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ import torch_tensorrt ... trt_ts_module = torch_tensorrt.compile(torch_script_module, + # If the inputs to the module are plain Tensors, specify them via the `inputs` argument: inputs = [example_tensor, # Provide example tensor for input shape or... torch_tensorrt.Input( # Specify input object with shape and dtype min_shape=[1, 3, 224, 224], @@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module, # For static size shape=[1, 3, 224, 224] dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool) ], + + # For inputs containing tuples or lists of tensors, use the `input_signature` argument: + # Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor]) + # input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half), + # torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ), + enabled_precisions = {torch.half}, # Run with FP16 ) diff --git a/docsrc/getting_started/getting_started_with_python_api.rst b/docsrc/getting_started/getting_started_with_python_api.rst index fece176156..9ca679d602 100644 --- a/docsrc/getting_started/getting_started_with_python_api.rst +++ b/docsrc/getting_started/getting_started_with_python_api.rst @@ -14,8 +14,7 @@ If given a ``torch.nn.Module`` and the ``ir`` flag is set to either ``default`` To compile your input ``torch.nn.Module`` with Torch-TensorRT, all you need to do is provide the module and inputs to Torch-TensorRT and you will be returned an optimized TorchScript module to run or add into another PyTorch module. Inputs -is a list of ``torch_tensorrt.Input`` classes which define input's shape, datatype and memory format. You can also specify settings such as -operating precision for the engine or target device. After compilation you can save the module just like any other module +is a list of ``torch_tensorrt.Input`` classes which define input Tensors' shape, datatype and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the ``input_signature`` argument to specify a collection-based input, such as ``(List[Tensor], Tuple[Tensor, Tensor])``. See the second sample below for an example. You can also specify settings such as operating precision for the engine or target device. After compilation you can save the module just like any other module to load in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``torch_tensorrt``. .. code-block:: python @@ -44,6 +43,32 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod result = trt_ts_module(input_data) torch.jit.save(trt_ts_module, "trt_ts_module.ts") +.. code-block:: python + + # Sample using collection-based inputs via the input_signature argument + import torch_tensorrt + + ... + + model = MyModel().eval() + + # input_signature expects a tuple of individual input arguments to the module + # The module below, for example, would have a docstring of the form: + # def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor]) + input_signature = ( + [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)], + (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)), + ) + enabled_precisions = {torch.float, torch.half} + + trt_ts_module = torch_tensorrt.compile( + model, input_signature=input_signature, enabled_precisions=enabled_precisions + ) + + input_data = input_data.to("cuda").half() + result = trt_ts_module(input_data) + torch.jit.save(trt_ts_module, "trt_ts_module.ts") + .. code-block:: python # Deployment application @@ -55,4 +80,3 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod result = trt_ts_module(input_data) Torch-TensorRT Python API also provides ``torch_tensorrt.ts.compile`` which accepts a TorchScript module as input and ``torch_tensorrt.fx.compile`` which accepts a FX GraphModule as input. -