diff --git a/core/compiler.cpp b/core/compiler.cpp
index ccccc512ae..1095a88587 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -287,22 +287,45 @@ GraphAndMapping ConstructFallbackGraph(
   return {new_g, old_to_new_g};
 }
 
+void MapInputsAndDetermineDTypes(
+    CompileSpec& cfg,
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::StaticParams& static_params,
+    const util::InputTypeMap& first_use_type_map) {
+  // Associate input specs with inputs
+  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+
+  for (auto& in : g->inputs()) {
+    auto est_type_opt = first_use_type_map.find(in)->second;
+    ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+    if (est_type_opt && !spec.dtype_is_user_defined) {
+      // If we can calculate the type from the graph and the type was not defined by the user then use the calculated type
+      LOG_INFO(
+          "Since input type is not explicitly defined, inferring using first tensor calculation\n Found input "
+          << in->debugName() << " has type " << est_type_opt.value()
+          << ". If this is incorrect explicitly set dtype for input and file a bug");
+      spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+    } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+      // If we cannot calculate the type and the user did not define the type, then default to FP32
+      LOG_WARNING(
+          "Cannot determine input type from calculations in graph for input "
+          << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
+      spec.dtype = nvinfer1::DataType::kFLOAT;
+    } else {
+      // The user defined the type so no changes are necessary
+    }
+  }
+}
+
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
-  auto convert_cfg = std::move(cfg.convert_info);
   auto g = graph_and_parameters.first;
-
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
+  // Infer the type of an input from the weights of the calculation
+  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
 
-  LOG_INFO(*g << "(CompileGraph)\n");
+  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
-  // Move the user defined inputs to the convert_cfg since some might be static;
-  convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+  auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
-  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params);
   return std::move(engine);
 }
 
@@ -331,27 +354,12 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
 
     auto g = graph_and_parameters.first;
-    LOG_INFO("Lowered Graph: " << *g);
     auto params = graph_and_parameters.second;
    auto static_params = ir::get_static_params(g->inputs(), params);
-
-    cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
-    // If the user did not explicitly set the input type, then use the first
-    // tensor calculation to infer type.
+    // Infer the type of an input from the weights of the calculation
     auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
-    for (auto& in : g->inputs()) {
-      auto est_type_opt = first_use_types[in];
-      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-      if (est_type_opt && !spec.dtype_is_user_defined) {
-        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
-        LOG_WARNING(
-            "Cannot deterime input type from calcuations in graph for input "
-            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec.dtype = nvinfer1::DataType::kFLOAT;
-      }
-    }
+
+    MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
     if (cfg.partition_info.enabled) {
       auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 4be3e403aa..232dc08e19 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -97,6 +97,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // Is this necessary?
   // lowering::LowerBlock(g->block());
 
+  LOG_INFO("Lowered Graph: " << *(graph_and_ivalues.first));
   return graph_and_ivalues;
 }
 
diff --git a/core/util/jit_util.cpp b/core/util/jit_util.cpp
index 128546e5f8..91bb1bede0 100644
--- a/core/util/jit_util.cpp
+++ b/core/util/jit_util.cpp
@@ -96,9 +96,8 @@ c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block*
   return dtype;
 }
 
-std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
-    torch::jit::Block* b) {
-  std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> types;
+InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
+  InputTypeMap types;
 
   for (auto i : b->inputs()) {
     if (i->type() == c10::TensorType::get()) {
diff --git a/core/util/jit_util.h b/core/util/jit_util.h
index 7fa0739873..082441eeb1 100644
--- a/core/util/jit_util.h
+++ b/core/util/jit_util.h
@@ -9,6 +9,8 @@ namespace trtorch {
 namespace core {
 namespace util {
 
+using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
+
 inline std::string node_info(const torch::jit::Node* n) {
   std::stringstream ss;
   ss << *n;
@@ -61,8 +63,7 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
 }
 
 c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
-std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> get_block_first_calc_dtypes_opt(
-    torch::jit::Block* b);
+InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
 
 } // namespace util
 } // namespace core
diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp
index 1be2cfa3ce..0f7030193a 100644
--- a/core/util/logging/TRTorchLogger.cpp
+++ b/core/util/logging/TRTorchLogger.cpp
@@ -125,7 +125,7 @@ namespace {
 
 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
-  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
+  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
 #else
   static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
 #endif
diff --git a/cpp/include/trtorch/trtorch.h b/cpp/include/trtorch/trtorch.h
index 0c298a3321..874221e523 100644
--- a/cpp/include/trtorch/trtorch.h
+++ b/cpp/include/trtorch/trtorch.h
@@ -387,7 +387,7 @@ struct TRTORCH_API CompileSpec {
    * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
    *
    * @param shape Input tensor shape
-   * @param dtype Expected data type for the input (Defaults to Float32)
+   * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
    * @param format Expected tensor format for the input (Defaults to contiguous)
    */
   Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
@@ -398,7 +398,7 @@ struct TRTORCH_API CompileSpec {
    * tensor format
    *
    * @param shape Input tensor shape
-   * @param dtype Expected data type for the input (Defaults to Float32)
+   * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
    * @param format Expected tensor format for the input (Defaults to contiguous)
    */
   Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
@@ -421,7 +421,7 @@ struct TRTORCH_API CompileSpec {
    * allow the user to configure expected input shape tensor format
    *
    * @param shape Input tensor shape
-   * @param dtype Expected data type for the input (Defaults to Float32)
+   * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
    * @param format Expected tensor format for the input (Defaults to contiguous)
    */
   Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
@@ -451,7 +451,7 @@ struct TRTORCH_API CompileSpec {
    * @param min_shape Minimum shape for input tensor
    * @param opt_shape Target optimization shape for input tensor
    * @param max_shape Maximum acceptible shape for input tensor
-   * @param dtype Expected data type for the input (Defaults to Float32)
+   * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
    * @param format Expected tensor format for the input (Defaults to contiguous)
    */
   Input(
@@ -486,7 +486,7 @@ struct TRTORCH_API CompileSpec {
    * @param min_shape Minimum shape for input tensor
    * @param opt_shape Target optimization shape for input tensor
    * @param max_shape Maximum acceptible shape for input tensor
-   * @param dtype Expected data type for the input (Defaults to Float32)
+   * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor calculation if detectable else Float32)
    * @param format Expected tensor format for the input (Defaults to contiguous)
    */
   Input(
@@ -506,14 +506,9 @@ struct TRTORCH_API CompileSpec {
    */
   Input(at::Tensor tensor);
 
-  bool get_explicit_set_dtype() {
-    return explicit_set_dtype;
-  }
-
  private:
   friend std::ostream& operator<<(std::ostream& os, const Input& input);
   bool input_is_dynamic;
-  bool explicit_set_dtype;
 };
 
 /**
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index ff28bd2fe9..04bec052f6 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -73,7 +73,6 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
 }
 
 nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
-  TRTORCH_CHECK(!(value == CompileSpec::DataType::kUnknown), "Data type is unknown");
   switch (value) {
     case CompileSpec::DataType::kChar:
       return nvinfer1::DataType::kINT8;
@@ -162,8 +161,7 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, TensorFormat format) {
   this->min_shape = shape;
   this->max_shape = shape;
   this->shape = shape;
-  this->dtype = dtype;
-  this->explicit_set_dtype = false;
+  this->dtype = CompileSpec::DataType::kUnknown;
   this->format = format;
   this->input_is_dynamic = false;
 }
@@ -174,7 +172,6 @@ CompileSpec::Input::Input(std::vector<int64_t> shape, DataType dtype, TensorForm
   this->max_shape = shape;
   this->shape = shape;
   this->dtype = dtype;
-  this->explicit_set_dtype = true;
   this->format = format;
   this->input_is_dynamic = false;
 }
@@ -184,8 +181,7 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, TensorFormat format) {
   this->min_shape = core::util::toVec(shape);
   this->max_shape = core::util::toVec(shape);
   this->shape = core::util::toVec(shape);
-  this->dtype = DataType::kFloat;
-  this->explicit_set_dtype = false;
+  this->dtype = CompileSpec::DataType::kUnknown;
   this->format = format;
   this->input_is_dynamic = false;
 }
@@ -196,7 +192,6 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
   this->max_shape = core::util::toVec(shape);
   this->shape = core::util::toVec(shape);
   this->dtype = dtype;
-  this->explicit_set_dtype = true;
   this->format = format;
   this->input_is_dynamic = false;
 }
@@ -210,8 +205,7 @@ CompileSpec::Input::Input(
   this->min_shape = min_shape;
   this->max_shape = max_shape;
   this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
-  this->dtype = dtype;
-  this->explicit_set_dtype = false;
+  this->dtype = CompileSpec::DataType::kUnknown;
   this->format = format;
   this->input_is_dynamic = true;
 }
@@ -227,7 +221,6 @@ CompileSpec::Input::Input(
   this->max_shape = max_shape;
   this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
   this->dtype = dtype;
-  this->explicit_set_dtype = true;
   this->format = format;
   this->input_is_dynamic = true;
 }
@@ -241,8 +234,7 @@ CompileSpec::Input::Input(
   this->min_shape = core::util::toVec(min_shape);
   this->max_shape = core::util::toVec(max_shape);
   this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
-  this->dtype = dtype;
-  this->explicit_set_dtype = false;
+  this->dtype = CompileSpec::DataType::kUnknown;
   this->format = format;
   this->input_is_dynamic = true;
 }
@@ -258,7 +250,6 @@ CompileSpec::Input::Input(
   this->max_shape = core::util::toVec(max_shape);
   this->shape = core::util::toVec(core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape);
   this->dtype = dtype;
-  this->explicit_set_dtype = true;
   this->format = format;
   this->input_is_dynamic = true;
 }
@@ -269,7 +260,6 @@ CompileSpec::Input::Input(at::Tensor tensor) {
   this->max_shape = tensor.sizes().vec();
   this->shape = tensor.sizes().vec();
   this->dtype = tensor.scalar_type();
-  this->explicit_set_dtype = true;
   TRTORCH_ASSERT(
       tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
       "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
@@ -292,7 +282,7 @@ core::ir::Input to_internal_input(CompileSpec::Input& i) {
       i.max_shape,
       toTRTDataType(i.dtype),
       toTRTTensorFormat(i.format),
-      i.get_explicit_set_dtype());
+      !(i.dtype == CompileSpec::DataType::kUnknown));
 }
 
 std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Input>& external) {
diff --git a/py/trtorch/Input.py b/py/trtorch/Input.py
index 51cf4f6860..27daab0a27 100644
--- a/py/trtorch/Input.py
+++ b/py/trtorch/Input.py
@@ -30,7 +30,7 @@ class _ShapeMode(Enum):
 
     shape_mode = None  #: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
     shape = None  #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype = _types.dtype.float32  #: The expected data type of the input tensor (default: trtorch.dtype.float32)
+    dtype = _types.dtype.unknown  #: The expected data type of the input tensor (default: inferred from the first tensor calculation if detectable, else trtorch.dtype.float32)
     _explicit_set_dtype = False
     format = _types.TensorFormat.contiguous  #: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
 
@@ -133,16 +133,44 @@ def __str__(self) -> str:
 
     def _to_internal(self) -> trtorch._C.Input:
         internal_in = trtorch._C.Input()
         if self.shape_mode == Input._ShapeMode.DYNAMIC:
-            internal_in.min = self.shape["min_shape"]
-            internal_in.opt = self.shape["opt_shape"]
-            internal_in.max = self.shape["max_shape"]
+            if not Input._supported_input_size_type(self.shape["min_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["min_shape"])) + " for min_shape")
+            else:
+                internal_in.min = self.shape["min_shape"]
+
+            if not Input._supported_input_size_type(self.shape["opt_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["opt_shape"])) + " for opt_shape")
+            else:
+                internal_in.opt = self.shape["opt_shape"]
+
+            if not Input._supported_input_size_type(self.shape["max_shape"]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape["max_shape"])) + " for max_shape")
+            else:
+                internal_in.max = self.shape["max_shape"]
             internal_in.input_is_dynamic = True
         else:
-            internal_in.opt = self.shape
+            if not Input._supported_input_size_type(self.shape):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(self.shape)) + " for shape")
+            else:
+                internal_in.opt = self.shape
             internal_in.input_is_dynamic = False
-        internal_in.dtype = self.dtype
+
+        if self.dtype != _types.dtype.unknown:
+            self._explicit_set_dtype = True
+        else:
+            self._explicit_set_dtype = False
+
+        internal_in.dtype = Input._parse_dtype(self.dtype)
         internal_in._explicit_set_dtype = self._explicit_set_dtype
-        internal_in.format = self.format
+        internal_in.format = Input._parse_format(self.format)
         return internal_in
 
     @staticmethod
@@ -172,7 +200,7 @@ def _parse_dtype(dtype: Any) -> _types.dtype:
                 "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
                 + str(dtype))
 
-        elif isinstance(dtype, _types.DataTypes):
+        elif isinstance(dtype, _types.dtype):
             return dtype
 
         else:
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index cf037575a1..913493a414 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -33,6 +33,8 @@ nvinfer1::DataType toTRTDataType(DataType value) {
       return nvinfer1::DataType::kBOOL;
     case DataType::kFloat:
       return nvinfer1::DataType::kFLOAT;
+    case DataType::kUnknown:
+      return nvinfer1::DataType::kFLOAT;
     default:
       TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value));
   }
diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index b7b5d08873..815c2a0ce4 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -27,7 +27,7 @@ namespace pyapi {
     return static_cast<int64_t>(field_name); \
   }
 
-enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
+enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
 std::string to_str(DataType value);
 nvinfer1::DataType toTRTDataType(DataType value);
 
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index e8d3c9e696..beffc10dd3 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -186,6 +186,7 @@ PYBIND11_MODULE(_C, m) {
       .value("int8", DataType::kChar, "8 bit integer number")
       .value("int32", DataType::kInt32, "32 bit integer number")
       .value("bool", DataType::kChar, "Boolean value")
+      .value("unknown", DataType::kUnknown, "Unknown data type")
       .export_values();
 
   py::enum_<DeviceType>(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for")
diff --git a/tests/core/test_detecting_input_type.cpp b/tests/core/test_detecting_input_type.cpp
new file mode 100644
index 0000000000..c7a279d38a
--- /dev/null
+++ b/tests/core/test_detecting_input_type.cpp
@@ -0,0 +1,51 @@
+#include <string>
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+#include "core/util/prelude.h"
+#include "core/lowering/lowering.h"
+#include "trtorch/trtorch.h"
+
+TEST(CoreTest, DetectingInputTypeWorksCorrectFP32) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    ASSERT_TRUE(false);
+  }
+
+  auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {});
+  auto g = graph_and_parameters.first;
+
+  auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block());
+
+  for (auto in : input_types) {
+    c10::optional<at::ScalarType>& detected_type_opt = in.second;
+    ASSERT_TRUE(detected_type_opt);
+    ASSERT_TRUE(detected_type_opt.value() == at::kFloat);
+  }
+}
+
+TEST(CoreTest, DetectingInputTypeWorksCorrectFP16) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/mobilenet_v2_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    ASSERT_TRUE(false);
+  }
+
+  mod.to(at::kHalf);
+
+  auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {});
+  auto g = graph_and_parameters.first;
+
+  auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block());
+
+  for (auto in : input_types) {
+    c10::optional<at::ScalarType>& detected_type_opt = in.second;
+    ASSERT_TRUE(detected_type_opt);
+    ASSERT_TRUE(detected_type_opt.value() == at::kHalf);
+  }
+}
diff --git a/tests/cpp/test_default_input_types.cpp b/tests/cpp/test_default_input_types.cpp
index 1522126791..9fd37d936a 100644
--- a/tests/cpp/test_default_input_types.cpp
+++ b/tests/cpp/test_default_input_types.cpp
@@ -1,11 +1,34 @@
 #include "cpp_api_test.h"
+#include "trtorch/logging.h"
+
+TEST_P(CppAPITests, InputsUseDefaultFP32) {
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
+
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randn(in_shape, {at::kCUDA});
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  auto in = trtorch::CompileSpec::Input(input_shapes[0]);
+  auto spec = trtorch::CompileSpec({in});
+  spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf);
+
+  auto trt_mod = trtorch::CompileGraph(mod, spec);
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+  // If exits without error successfully defaults to FP32
+}
+
+TEST_P(CppAPITests, InputsUseDefaultFP16) {
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
-TEST_P(CppAPITests, InputsUseDefault) {
 
   std::vector<torch::jit::IValue> jit_inputs_ivalues;
   std::vector<torch::jit::IValue> trt_inputs_ivalues;
   for (auto in_shape : input_shapes) {
     auto in = at::randn(in_shape, {at::kCUDA});
-    jit_inputs_ivalues.push_back(in.clone().to(torch::kHalf));
     trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf));
   }
@@ -22,6 +45,74 @@ TEST_P(CppAPITests, InputsUseDefault) {
   // If exits without error successfully defaults to FP16
 }
 
+TEST_P(CppAPITests, InputsUseDefaultFP16WithoutFP16Enabled) {
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
+
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randn(in_shape, {at::kCUDA});
+    trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf));
+  }
+
+  auto in = trtorch::CompileSpec::Input(input_shapes[0]);
+  auto spec = trtorch::CompileSpec({in});
+
+  mod.to(torch::kHalf);
+
+  auto trt_mod = trtorch::CompileGraph(mod, spec);
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+  // If exits without error successfully defaults to FP16
+}
+
+TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) {
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
+
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randn(in_shape, {at::kCUDA});
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  auto in = trtorch::CompileSpec::Input(input_shapes[0]);
+  in.dtype = torch::kF32;
+  auto spec = trtorch::CompileSpec({in});
+  spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf);
+
+  mod.to(torch::kHalf);
+
+  auto trt_mod = trtorch::CompileGraph(mod, spec);
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+  // If exits without error the user-set FP32 input type was respected
+}
+
+TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
+  trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kINFO);
+
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randn(in_shape, {at::kCUDA});
+    trt_inputs_ivalues.push_back(in.clone().to(torch::kHalf));
+  }
+
+  auto in = trtorch::CompileSpec::Input(input_shapes[0]);
+  in.dtype = torch::kF16;
+  auto spec = trtorch::CompileSpec({in});
+  spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf);
+
+  auto trt_mod = trtorch::CompileGraph(mod, spec);
+  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
+  std::vector<at::Tensor> trt_results;
+  trt_results.push_back(trt_results_ivalues.toTensor());
+  // If exits without error the user-set FP16 input type was respected
+}
+
 INSTANTIATE_TEST_SUITE_P(
     CompiledModuleForwardIsCloseSuite,
     CppAPITests,
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index c28cdaa27b..241f9a7609 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -2,6 +2,7 @@
 import trtorch
 import torch
 import torchvision.models as models
+import copy
 
 from model_test_case import ModelTestCase
@@ -75,8 +76,6 @@ def test_compile_script_from_dict(self):
 
         self.assertTrue(same < 2e-2)
 
-
-
 class TestCompileHalf(ModelTestCase):
 
     def setUp(self):
@@ -135,7 +134,6 @@ def test_compile_script(self):
             "device": {
                 "device_type": trtorch.DeviceType.GPU,
                 "gpu_id": 0,
-                "dla_core": 0,
                 "allow_gpu_fallback": False,
                 "disable_tf32": False
             },
@@ -161,7 +159,6 @@ def test_compile_script(self):
             "device": {
                 "device_type": trtorch.DeviceType.GPU,
                 "gpu_id": 0,
-                "dla_core": 0,
                 "allow_gpu_fallback": False,
                 "disable_tf32": False
             },
@@ -187,7 +184,6 @@ def test_pt_to_trt_to_pt(self):
             "device": {
                 "device_type": trtorch.DeviceType.GPU,
                 "gpu_id": 0,
-                "dla_core": 0,
                 "allow_gpu_fallback": False,
                 "disable_tf32": False
             }
@@ -199,6 +195,80 @@ def test_pt_to_trt_to_pt(self):
 
         self.assertTrue(same < 2e-3)
 
 
+class TestInputTypeDefaultsFP32Model(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    def test_input_use_default_fp32(self):
+        ts_model = torch.jit.script(self.model)
+        trt_mod = trtorch.compile(ts_model,
+                                  inputs=[trtorch.Input(self.input.shape)],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input)
+
+    def test_input_respect_user_setting_fp32_weights_fp16_in(self):
+        ts_model = torch.jit.script(self.model)
+        trt_mod = trtorch.compile(ts_model,
+                                  inputs=[self.input.half()],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input.half())
+
+    def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self):
+        ts_model = torch.jit.script(self.model)
+        input_spec = trtorch.Input(self.input.shape)
+        input_spec.dtype = torch.half
+
+        trt_mod = trtorch.compile(ts_model,
+                                  inputs=[input_spec],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input.half())
+
+
+class TestInputTypeDefaultsFP16Model(ModelTestCase):
+
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    def test_input_use_default_fp16(self):
+        half_mod = torch.jit.script(self.model)
+        half_mod.half()
+
+        trt_mod = trtorch.compile(half_mod,
+                                  inputs=[trtorch.Input(self.input.shape)],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input.half())
+
+    def test_input_use_default_fp16_without_fp16_enabled(self):
+        half_mod = torch.jit.script(self.model)
+        half_mod.half()
+
+        trt_mod = trtorch.compile(half_mod,
+                                  inputs=[trtorch.Input(self.input.shape)])
+        trt_mod(self.input.half())
+
+    def test_input_respect_user_setting_fp16_weights_fp32_in(self):
+        half_mod = torch.jit.script(self.model)
+        half_mod.half()
+
+        trt_mod = trtorch.compile(half_mod,
+                                  inputs=[self.input],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input)
+
+    def test_input_respect_user_setting_fp16_weights_fp32_in_non_constructor(self):
+        half_mod = torch.jit.script(self.model)
+        half_mod.half()
+
+        input_spec = trtorch.Input(self.input.shape)
+        input_spec.dtype = torch.float
+
+        trt_mod = trtorch.compile(half_mod,
+                                  inputs=[input_spec],
+                                  enabled_precisions={torch.float, torch.half})
+        trt_mod(self.input)
+
+
 class TestCheckMethodOpSupport(unittest.TestCase):
 
     def setUp(self):
@@ -284,6 +354,8 @@ def test_suite():
     suite.addTest(TestCompileHalf.parametrize(TestCompileHalf, model=models.resnet18(pretrained=True)))
     suite.addTest(TestCompileHalfDefault.parametrize(TestCompileHalfDefault, model=models.resnet18(pretrained=True)))
     suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestInputTypeDefaultsFP32Model.parametrize(TestInputTypeDefaultsFP32Model, model=models.resnet18(pretrained=True)))
+    suite.addTest(TestInputTypeDefaultsFP16Model.parametrize(TestInputTypeDefaultsFP16Model, model=models.resnet18(pretrained=True)))
     suite.addTest(TestFallbackToTorch.parametrize(TestFallbackToTorch, model=models.resnet18(pretrained=True)))
     suite.addTest(
         TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
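
For reference, a minimal usage sketch of the behavior this patch introduces, mirroring the new Python tests above (the torchvision model, shapes, and variable names are illustrative assumptions, not part of the patch): when no dtype is given on a trtorch.Input, it now defaults to trtorch.dtype.unknown and the compiler infers the input type from the weights of the first tensor calculation, falling back to Float32; an explicitly set dtype still takes precedence.

    import torch
    import torchvision.models as models
    import trtorch

    # Hypothetical FP16 module; any scripted TorchScript module on CUDA behaves the same way.
    mod = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
    mod.half()

    # No dtype on the Input spec: the compiler infers FP16 from the module weights.
    trt_mod = trtorch.compile(mod,
                              inputs=[trtorch.Input((1, 3, 224, 224))],
                              enabled_precisions={torch.float, torch.half})
    trt_mod(torch.randn((1, 3, 224, 224)).half().to("cuda"))

    # Explicit dtype still wins over the inferred type: FP32 inputs against FP16 weights.
    spec = trtorch.Input((1, 3, 224, 224))
    spec.dtype = torch.float
    trt_mod_fp32_in = trtorch.compile(mod,
                                      inputs=[spec],
                                      enabled_precisions={torch.float, torch.half})
    trt_mod_fp32_in(torch.randn((1, 3, 224, 224)).to("cuda"))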