diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h
index f7ed5c89d2..8587885eca 100644
--- a/core/conversion/conversionctx/ConversionCtx.h
+++ b/core/conversion/conversionctx/ConversionCtx.h
@@ -23,6 +23,7 @@ struct BuilderSettings {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index b72320b8da..5ce3d02978 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -270,7 +270,12 @@ auto aten_registrations TORCHTRT_UNUSED =
               if (tensor_var.isITensor()) {
                 auto tensor = tensor_var.ITensor();
                 if (ctx->input_is_dynamic) {
-                  return dynamic_size_layer(ctx, n, args);
+                  if (ctx->settings.allow_shape_tensors) {
+                    return dynamic_size_layer(ctx, n, args);
+                  } else {
+                    LOG_WARNING(
+                        "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+                  }
                 }
                 return util::toVec(tensor->getDimensions());
               } else if (tensor_var.IValue()->isTensor()) {
@@ -286,7 +291,12 @@ auto aten_registrations TORCHTRT_UNUSED =
               auto dim = args.at(n->input(1)).unwrapToInt();
               if (tensor_var.isITensor()) {
                 if (ctx->input_is_dynamic) {
-                  return dynamic_size_layer(ctx, n, args);
+                  if (ctx->settings.allow_shape_tensors) {
+                    return dynamic_size_layer(ctx, n, args);
+                  } else {
+                    LOG_WARNING(
+                        "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+                  }
                 }
                 auto tensor = tensor_var.ITensor();
                 auto dims = util::toVec(tensor->getDimensions());
@@ -605,7 +615,8 @@ auto aten_registrations TORCHTRT_UNUSED =
        .evaluator(
            {c10::Symbol::fromQualString("aten::numel"),
             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-              LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel");
+              LOG_WARNING(
+                  "There may be undefined behavior using dynamic shape and aten::numel without setting allow_shape_tensors");
               auto tensor_var = args.at(n->input(0));
               if (tensor_var.isITensor()) {
                 auto tensor = tensor_var.ITensor();
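[Reviewer note] For context, the evaluator changes above only matter for graphs that call aten::size on tensors whose dimensions are dynamic at compile time. A minimal illustrative module (hypothetical, not part of this patch) that exercises that path once scripted:

    import torch

    class Reflatten(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x.size(0) lowers to an aten::size node; with a dynamic batch
            # dimension its value is only known at runtime, which is what the
            # IShapeLayer path behind allow_shape_tensors resolves.
            return x.reshape(x.size(0), -1)

    scripted = torch.jit.script(Reflatten())

Without allow_shape_tensors, the evaluator now emits the warning and falls back to the static dimensions from getDimensions(), which may contain -1 placeholders for dynamic axes.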
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index 114689036f..0a0b97cfe1 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -32,7 +32,9 @@ nvinfer1::ITensor* index_layer(
 c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   LOG_DEBUG("Using dynamic version of aten::size evaluator");
   auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
-  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto input_dims = in->getDimensions();
+  LOG_DEBUG("Input dimensions: " << input_dims);
+
   auto shape_layer = ctx->net->addShape(*in);
   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
   auto shape_1d_tensor = shape_layer->getOutput(0);
@@ -44,15 +46,31 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
     dim = dim < 0 ? dim + maxDim : dim;
     LOG_DEBUG("Dimension to select: " << dim);
     shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
-  }
+    LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
 
-  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+    auto tensor_holder = TensorContainer();
+    tensor_holder.hold_tensor(shape_1d_tensor);
+    auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
 
-  auto tensor_holder = TensorContainer();
-  tensor_holder.hold_tensor(shape_1d_tensor);
-  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+    return shape_1d_ivalue;
 
-  return shape_1d_ivalue;
+  } else {
+    auto input_size = c10::impl::GenericList(c10::AnyType::get());
+    // Only express the dynamic dimension with a shape layer output.
+    // The static dimensions are preserved in the input size.
+    for (int32_t i = 0; i < input_dims.nbDims; i++) {
+      if (input_dims.d[i] == -1) {
+        auto dynamic_dim_tensor = index_layer(ctx, n, shape_1d_tensor, i);
+        auto dynamic_dim_holder = TensorContainer();
+        dynamic_dim_holder.hold_tensor(dynamic_dim_tensor);
+        auto dynamic_dim_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(dynamic_dim_holder)));
+        input_size.emplace_back(std::move(dynamic_dim_ivalue));
+      } else {
+        input_size.emplace_back(input_dims.d[i]);
+      }
+    }
+    return c10::IValue(input_size);
+  }
 }
 
 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
diff --git a/cpp/bin/torchtrtc/main.cpp b/cpp/bin/torchtrtc/main.cpp
index b5f30080b9..c36cfdd0fc 100644
--- a/cpp/bin/torchtrtc/main.cpp
+++ b/cpp/bin/torchtrtc/main.cpp
@@ -168,6 +168,12 @@ int main(int argc, char** argv) {
       "Truncate weights that are provided in 64bit to 32bit (Long, Double to Int, Float)",
       {"truncate", "truncate-long-double", "truncate-64bit"});
 
+  args::Flag allow_shape_tensors(
+      parser,
+      "allow-shape-tensors",
+      "(Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT",
+      {"allow-shape-tensors"});
+
   args::Flag save_engine(
       parser,
       "save_engine",
@@ -443,6 +449,10 @@ int main(int argc, char** argv) {
     compile_settings.truncate_long_and_double = true;
   }
 
+  if (allow_shape_tensors) {
+    compile_settings.allow_shape_tensors = true;
+  }
+
   torch::jit::Module mod;
   try {
     // Deserialize the ScriptModule from a file using torch::jit::load().
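[Reviewer note] With the flag registered in torchtrtc above, opting in from the command line is a single switch. A hypothetical invocation (the dynamic input-range syntax is illustrative of the documented min/opt/max format; consult torchtrtc --help in your build for the exact spelling):

    torchtrtc model.ts trt_model.ts "[(1,3,16,16);(1,3,32,32);(1,3,64,64)]" --allow-shape-tensors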
diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h
index dead16f6d9..29f860c8b3 100644
--- a/cpp/include/torch_tensorrt/torch_tensorrt.h
+++ b/cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -791,6 +791,11 @@ struct CompileSpec {
    */
   bool truncate_long_and_double = false;
 
+  /**
+   * Allow shape tensors (from IShape layer) in the graph
+   */
+  bool allow_shape_tensors = false;
+
   /**
    * Target Device
    */
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 1954827893..41dae65114 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -90,6 +90,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.refit = external.refit;
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
+  internal.convert_info.engine_settings.allow_shape_tensors = external.allow_shape_tensors;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
diff --git a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
index 528ffde23f..bae61881da 100644
--- a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
@@ -84,6 +84,7 @@ void RegisterTRTCompileSpec() {
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_global_dram_size);
   ADD_FIELD_GET_SET_REGISTRATION(
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, allow_shape_tensors);
 }
 
 struct TRTTSRegistrations {
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index a312832628..9488b963cf 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -373,6 +373,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.partitioning_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
+  info.convert_info.engine_settings.allow_shape_tensors = allow_shape_tensors;
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
 
   TORCHTRT_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
@@ -423,6 +424,7 @@ std::string CompileSpec::stringify() {
   ss << "    \"DLA Local DRAM Size\": " << dla_local_dram_size << std::endl;
   ss << "    \"DLA Global DRAM Size\": " << dla_global_dram_size << std::endl;
   ss << "    \"Truncate long and double\": " << truncate_long_and_double << std::endl;
+  ss << "    \"Allow Shape tensors\": " << allow_shape_tensors << std::endl;
   ss << "    \"Torch Fallback\": " << torch_fallback.to_str();
   ss << "}";
   return ss.str();
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h
index 0b42b68729..b570e456e9 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.h
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.h
@@ -167,6 +167,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(dla_local_dram_size, int64_t);
   ADD_FIELD_GET_SET(dla_global_dram_size, int64_t);
   ADD_FIELD_GET_SET(truncate_long_and_double, bool);
+  ADD_FIELD_GET_SET(allow_shape_tensors, bool);
   ADD_FIELD_GET_SET(device, Device);
   ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -180,6 +181,7 @@ struct CompileSpec : torch::CustomClassHolder {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   Device device;
   TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;
diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
index 142a316c05..f39888eb0f 100644
--- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
+++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
@@ -371,7 +371,8 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("dla_local_dram_size", &CompileSpec::dla_local_dram_size)
       .def_readwrite("dla_global_dram_size", &CompileSpec::dla_global_dram_size)
       .def_readwrite("torch_fallback", &CompileSpec::torch_fallback)
-      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
+      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double)
+      .def_readwrite("allow_shape_tensors", &CompileSpec::allow_shape_tensors);
 
   py::class_<TorchFallback>(ts_sub_mod, "TorchFallback")
       .def(py::init<>())
diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py
index f17a9fa5bf..08f18a22dd 100644
--- a/py/torch_tensorrt/ts/_compile_spec.py
+++ b/py/torch_tensorrt/ts/_compile_spec.py
@@ -298,6 +298,10 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
         assert isinstance(compile_spec["debug"], bool)
         info.debug = compile_spec["debug"]
 
+    if "allow_shape_tensors" in compile_spec:
+        assert isinstance(compile_spec["allow_shape_tensors"], bool)
+        info.allow_shape_tensors = compile_spec["allow_shape_tensors"]
+
     if "device" in compile_spec:
         info.device = _parse_device(compile_spec["device"])
 
@@ -354,6 +358,7 @@ def TensorRTCompileSpec(
     dla_global_dram_size=536870912,
     truncate_long_and_double=False,
     calibrator=None,
+    allow_shape_tensors=False,
 ) -> torch.classes.tensorrt.CompileSpec:
     """Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
 
@@ -388,6 +393,7 @@ def TensorRTCompileSpec(
         workspace_size (int): Maximum size of workspace given to TensorRT
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
@@ -410,6 +416,7 @@ def TensorRTCompileSpec(
         "dla_global_dram_size": dla_global_dram_size,  # Host RAM used by DLA to store weights and metadata for execution
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double,
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     parsed_spec = _parse_compile_spec(compile_spec)
@@ -461,6 +468,7 @@ def TensorRTCompileSpec(
     backend_spec._set_dla_local_dram_size(parsed_spec.dla_local_dram_size)
     backend_spec._set_dla_global_dram_size(parsed_spec.dla_global_dram_size)
     backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
+    backend_spec._set_allow_shape_tensors(parsed_spec.allow_shape_tensors)
     backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
 
     return backend_spec
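[Reviewer note] Through TensorRTCompileSpec the flag now reaches the registered TorchScript-backend class via the new _set_allow_shape_tensors setter. A hedged sketch of caller-side usage (input shape is illustrative; the resulting spec object is then handed to the TorchScript backend API referenced in the docstring):

    import torch
    import torch_tensorrt

    spec = torch_tensorrt.ts.TensorRTCompileSpec(
        inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
        allow_shape_tensors=True,  # field added by this change
    )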
diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py
index 19ea4ee802..9dc0731014 100644
--- a/py/torch_tensorrt/ts/_compiler.py
+++ b/py/torch_tensorrt/ts/_compiler.py
@@ -31,6 +31,7 @@ def compile(
     min_block_size=3,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    allow_shape_tensors=False,
 ) -> torch.jit.ScriptModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
@@ -94,6 +95,7 @@ def compile(
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
         torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
@@ -131,6 +133,7 @@ def compile(
             "forced_fallback_modules": torch_executed_modules,
             "min_block_size": min_block_size,
         },
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
@@ -156,6 +159,7 @@ def convert_method_to_trt_engine(
     dla_global_dram_size=536870912,
     truncate_long_and_double=False,
     calibrator=None,
+    allow_shape_tensors=False,
 ) -> bytearray:
     """Convert a TorchScript module method to a serialized TensorRT engine
 
@@ -214,6 +218,7 @@ def convert_method_to_trt_engine(
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         bytearray: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -236,6 +241,7 @@ def convert_method_to_trt_engine(
         "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double,
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     engine_str = _C.convert_graph_to_trt_engine(
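[Reviewer note] End to end, the typical opt-in for a dynamic-shape TorchScript compile would look like the sketch below (module name and shapes are illustrative, reusing the scripted module from the earlier note):

    import torch
    import torch_tensorrt

    trt_mod = torch_tensorrt.ts.compile(
        scripted,  # a torch.jit.ScriptModule with shape-dependent logic
        inputs=[
            torch_tensorrt.Input(
                min_shape=(1, 3, 16, 16),
                opt_shape=(1, 3, 32, 32),
                max_shape=(1, 3, 64, 64),
            )
        ],
        # Experimental: let aten::size produce IShapeLayer outputs instead of
        # baking static (possibly -1) dimensions into the engine.
        allow_shape_tensors=True,
    )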
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index afec847b3e..9e46842d9c 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -27,7 +27,8 @@ TEST(Converters, ATenResizeDynamicShapeCorrectly) {
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -53,7 +54,8 @@ TEST(Converters, ATenResizeDynamicInputCorrectly) {
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -83,7 +85,8 @@ TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -115,7 +118,8 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp
index 8358a3c570..62e1634623 100644
--- a/tests/util/run_graph_engine.cpp
+++ b/tests/util/run_graph_engine.cpp
@@ -94,12 +94,14 @@ std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::ir::StaticParams& named_params,
     std::vector<at::Tensor> inputs,
-    bool dynamic_batch) {
+    bool dynamic_batch = false,
+    bool allow_shape_tensors = false) {
   LOG_DEBUG("Running TRT version");
   auto var_ins = get_var_inputs(g->inputs(), named_params);
   auto in = core::ir::pair_input_vals_with_specs(var_ins, toInputsDynamic(inputs, dynamic_batch));
   auto info = core::conversion::ConversionInfo();
   info.inputs = std::move(in);
+  info.engine_settings.allow_shape_tensors = allow_shape_tensors;
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }
diff --git a/tests/util/util.h b/tests/util/util.h
index 7b1e46e083..9d3fc238d2 100644
--- a/tests/util/util.h
+++ b/tests/util/util.h
@@ -57,7 +57,8 @@ std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::ir::StaticParams& named_params,
     std::vector<at::Tensor> inputs,
-    bool dynamic_batch = false);
+    bool dynamic_batch = false,
+    bool allow_shape_tensors = false);
 
 // Run the forward method of a module and return results
 torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);