diff --git a/core/compiler.cpp b/core/compiler.cpp
index 1095a88587..a13bc1059c 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -253,6 +253,7 @@ GraphAndMapping ConstructFallbackGraph(
       }
       // update the input ranges for each segments
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
@@ -288,7 +289,7 @@ GraphAndMapping ConstructFallbackGraph(
 }

-void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
+void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, ir::TypeMap& first_use_type_map) {
   // Associate input specs with inputs
   cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
@@ -303,9 +304,31 @@ void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input: "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
         spec.dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+            std::stringstream ss;
+            ss <<"For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt.value() << std::endl;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+          }
+        }
       } else {
         // The user defined the type so no changes are necessary
       }
@@ -317,10 +340,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

   auto g = graph_and_parameters.first;
+  TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

@@ -357,11 +381,21 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

-  if (cfg.partition_info.enabled) {
+  if (cfg.partition_info.enabled
+      && (cfg.lower_info.forced_fallback_modules.size() == 0
+      && cfg.partition_info.forced_fallback_operators.size() == 0
+      && conversion::VerifyConverterSupportForBlock(g->block(), true))) {
+    LOG_INFO("Skipping partitioning since model is fully supported");
+  }
+
+  if (cfg.partition_info.enabled
+      && !(cfg.lower_info.forced_fallback_modules.size() == 0
+      && cfg.partition_info.forced_fallback_operators.size() == 0
+      && conversion::VerifyConverterSupportForBlock(g->block(), false))) {
     auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
     auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
     new_g = graph_and_mapping.first;
@@ -374,6 +408,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       return mod;
     }
   } else {
+    TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     auto device_spec = cfg.convert_info.engine_settings.device;
     auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 248fad1e41..d62264cf7b 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -491,7 +491,7 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }

-bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
   auto unsupported_ops = GetUnsupportedOpsInBlock(b);

   if (unsupported_ops.size() != 0) {
@@ -506,16 +506,20 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
     unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
     unsupported_msg << std::endl << "In Module:" << std::endl;

-    LOG_ERROR(unsupported_msg.str());
+    if (suppress_errors) {
+      LOG_ERROR(unsupported_msg.str());
+    }

     for (const auto n : b->nodes()) {
       auto schema = n->maybeSchema();
       if (schema) {
         for (const auto& x : unsupported_ops) {
           if (x.first == schema->operator_name()) {
-            LOG_ERROR(
-                "Unsupported operator: " << *schema << std::endl
-                                         << trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
+            if (suppress_errors) {
+              LOG_ERROR(
+                  "Unsupported operator: " << *schema << std::endl
+                                           << trtorch::core::util::GetPyTorchSourceCode(n) << std::endl);
+            }
           }
         }
       }
@@ -531,7 +535,9 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
     unsupported_msg
         << "This may be because there are no operators that can be added to the TensorRT graph or all operators have a resolved compile time value."
        << std::endl;
-    LOG_ERROR(unsupported_msg.str());
+    if (suppress_errors) {
+      LOG_ERROR(unsupported_msg.str());
+    }
     return false;
   }
diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h
index 2d60351edc..7d11478be7 100644
--- a/core/conversion/conversion.h
+++ b/core/conversion/conversion.h
@@ -25,7 +25,7 @@ std::string ConvertBlockToEngine(

 bool OpSupported(const torch::jit::Node* n);

-bool VerifyConverterSupportForBlock(const torch::jit::Block* b);
+bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors=false);

 c10::optional<torch::jit::IValue> EvaluateNode(
     ConversionCtx* ctx,
diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp
index 969b3b38c4..d3223bc4b6 100644
--- a/core/ir/ir.cpp
+++ b/core/ir/ir.cpp
@@ -45,6 +45,109 @@ std::vector<torch::jit::Value*> get_tensor_inputs(
   return input_tensors;
 }

+c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in) {
+  TRTORCH_ASSERT(in->owningGraph() == b->owningGraph(), "Provided input is not part of the provided graph");
+  c10::optional<at::ScalarType> dtype = {};
+
+  auto b_ins = b->inputs();
+  std::unordered_set<torch::jit::Value*> b_in_set(b_ins.begin(), b_ins.end());
+
+  TRTORCH_ASSERT(
+      in->type() == c10::TensorType::get(), "Input is not a tensor, cannot check for dtype based on calculation");
+
+  auto consumers = in->uses();
+  auto search_list = std::vector<torch::jit::Use>(consumers.begin(), consumers.end());
+
+  for (auto& u : search_list) {
+    auto n = u.user;
+    LOG_GRAPH("Node we are looking at: " << util::node_info(n));
+    auto ins = n->inputs();
+    auto outs = n->outputs();
+
+    bool outputs_tensor = false;
+    for (auto o : outs) {
+      if (o->type() == c10::TensorType::get()) {
+        outputs_tensor = true;
+        break;
+      }
+    }
+
+    if (!outputs_tensor) {
+      LOG_GRAPH("Node " << util::node_info(n) << " does not output a tensor, skipping");
+      continue;
+    }
+
+    LOG_GRAPH("Node " << util::node_info(n) << " outputs a tensor");
+
+    // If all input tensors are block inputs then this node will not give us useful type info so move to the next one
+    bool all_n_ins_are_b_ins = true;
+    for (auto in : ins) {
+      if (b_in_set.find(in) == b_in_set.end()) {
+        all_n_ins_are_b_ins = false;
+        break;
+      }
+    }
+
+    if (all_n_ins_are_b_ins) {
+      LOG_GRAPH(
+          "All inputs to Node " << util::node_info(n) << " are graph inputs, cannot be used to determine input type");
+      for (auto o : outs) {
+        if (o->type() == c10::TensorType::get()) {
+          auto o_uses = o->uses();
+          search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
+        }
+      }
+      continue;
+    }
+
+    // If node outputs a Tensor it might be a result of tensor calcuation so check to see
+    // if any inputs to the calculation can give us hints
+    c10::optional<torch::jit::Node*> const_tensor_n = {};
+
+    // Backtrace to constants which will immediately give us the Tensor type if possible
+    for (auto in : ins) {
+      LOG_GRAPH("Input to node: " << util::node_info(in->node()));
+      if (in->type()->isSubtypeOf(torch::jit::TensorType::get())) {
+        LOG_GRAPH("Input outputs a Tensor");
+        if (in->node()->kind() == torch::jit::prim::Constant) {
+          LOG_GRAPH("Input is a constant");
+          auto const_val = in->node()->t(c10::attr::value);
+          LOG_GRAPH("Found that constant tensor has type: " << const_val.scalar_type());
+          dtype = {const_val.scalar_type()};
+          goto exit_first_calc_dtype;
+        }
+      }
+    }
+
+    // Add all tensor outputs to search list if we still dont know
+    for (auto o : outs) {
+      if (o->type() == c10::TensorType::get()) {
+        auto o_uses = o->uses();
+        search_list.insert(search_list.end(), o_uses.begin(), o_uses.end());
+      }
+    }
+  }
+exit_first_calc_dtype:
+  if (dtype) {
+    LOG_GRAPH("Estimated input type is " << dtype.value());
+  } else {
+    LOG_GRAPH("Cannot determine input types from graph");
+  }
+  return dtype;
+}
+
+TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) {
+  TypeMap types;
+
+  for (auto i : b->inputs()) {
+    if (i->type() == c10::TensorType::get()) {
+      torch::jit::Value* in = i;
+      types.insert({in, get_value_first_calc_dtype_opt(b, i)});
+    }
+  }
+  return types;
+}
+
 } // namespace ir
 } // namespace core
 } // namespace trtorch
\ No newline at end of file
diff --git a/core/ir/ir.h b/core/ir/ir.h
index 7499ef794b..5d8bef78b6 100644
--- a/core/ir/ir.h
+++ b/core/ir/ir.h
@@ -52,6 +52,11 @@ std::vector<torch::jit::Value*> get_tensor_inputs(
     std::shared_ptr<torch::jit::Graph>& g,
     StaticParams& static_params);

+using TypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
+
+c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
+ir::TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
+
 } // namespace ir
 } // namespace core
 } // namespace trtorch
diff --git a/core/lowering/LowerInfo.cpp b/core/lowering/LowerInfo.cpp
index e60701bc00..0e4c566c65 100644
--- a/core/lowering/LowerInfo.cpp
+++ b/core/lowering/LowerInfo.cpp
@@ -10,7 +10,7 @@ namespace lowering {

 std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
   os << "Settings requested for Lowering:" << std::endl;
-  os << "  Forced Fallback Modules: [" << std::endl;
+  os << "  torch_executed_modules: [" << std::endl;
   for (auto i : l.forced_fallback_modules) {
     os << "    " << i << std::endl;
   }
diff --git a/core/partitioning/PartitionInfo.cpp b/core/partitioning/PartitionInfo.cpp
index a7d76928c2..0aa9d7d642 100644
--- a/core/partitioning/PartitionInfo.cpp
+++ b/core/partitioning/PartitionInfo.cpp
@@ -14,7 +14,7 @@ std::ostream& operator<<(std::ostream& os, const PartitionInfo& s) {
   if (s.enabled) {
     os << "True";
     os << "\n    \"min_block_size\": " << s.min_block_size \
-       << "\n    \"forced_fallback_operators\": [";
+       << "\n    \"torch_executed_operators\": [";
     for (auto i : s.forced_fallback_operators) {
       os <<"\n        " << i << ',';
     }
diff --git a/core/util/BUILD b/core/util/BUILD
index 5934bf3434..860cecfedc 100644
--- a/core/util/BUILD
+++ b/core/util/BUILD
@@ -27,9 +27,6 @@ cc_library(
     hdrs = [
         "jit_util.h",
     ],
-    srcs = [
-        "jit_util.cpp"
-    ],
     deps = [
         ":macros"
     ] + select({
diff --git a/core/util/jit_util.h b/core/util/jit_util.h
index 082441eeb1..b4145cd5ab 100644
--- a/core/util/jit_util.h
+++ b/core/util/jit_util.h
@@ -9,7 +9,6 @@ namespace trtorch {
 namespace core {
 namespace util {

-using InputTypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;

 inline std::string node_info(const torch::jit::Node* n) {
   std::stringstream ss;
@@ -62,9 +61,6 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
   return source_code;
 }

-c10::optional<at::ScalarType> get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in);
-InputTypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b);
-
 } // namespace util
 } // namespace core
 } // namespace trtorch
diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp
index 0f7030193a..ddd2918159 100644
--- a/core/util/logging/TRTorchLogger.cpp
+++ b/core/util/logging/TRTorchLogger.cpp
@@ -125,9 +125,9 @@ namespace {

 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
-  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
+  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
 #else
-  static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
+  static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kWARNING, false);
 #endif
   return global_logger;
 }
diff --git a/cpp/bin/trtorchc/README.md b/cpp/bin/trtorchc/README.md
index 22cc3bec77..cc483235d2 100644
--- a/cpp/bin/trtorchc/README.md
+++ b/cpp/bin/trtorchc/README.md
@@ -36,8 +36,8 @@ OPTIONS:
       --allow-gpu-fallback              (Only used when targeting DLA
                                         (device-type)) Lets engine run layers
                                         on GPU if they are not supported on DLA
-      --allow-torch-fallback            Enable layers to run in torch if they
-                                        are not supported in TensorRT
+      --require-full-compilation        Require that the model should be fully
+                                        compiled to TensorRT or throw an error
       --disable-tf32                    Prevent Float32 layers from using the
                                         TF32 data format
       --sparse-weights                  Enable sparsity for weights of conv and
@@ -63,18 +63,22 @@ OPTIONS:
       --calibration-cache-file=[file_path]
                                         Path to calibration cache file to use
                                         for post training quantization
-      --ffo=[forced_fallback_ops...],
-      --forced-fallback-op=[forced_fallback_ops...]
+      --teo=[torch-executed-ops...],
+      --torch-executed-ops=[torch-executed-ops...]
                                         (Repeatable) Operator in the graph that
-                                        should be forced to fallback to Pytorch
-                                        for execution (allow torch fallback must
-                                        be set)
-      --ffm=[forced_fallback_mods...],
-      --forced-fallback-mod=[forced_fallback_mods...]
-                                        (Repeatable) Module that should be
-                                        forced to fallback to Pytorch for
-                                        execution (allow torch fallback must be
-                                        set)
+                                        should always be run in PyTorch for
+                                        execution (partial compilation must be
+                                        enabled)
+      --tem=[torch-executed-mods...],
+      --torch-executed-mods=[torch-executed-mods...]
+                                        (Repeatable) Module that should always
+                                        be run in Pytorch for execution (partial
+                                        compilation must be enabled)
+      --mbs=[torch-executed-mods...],
+      --min-block-size=[torch-executed-mods...]
+                                        Minimum number of contiguous TensorRT
+                                        supported ops to compile a subgraph to
+                                        TensorRT
       --embed-engine                    Whether to treat input file as a
                                         serialized TensorRT engine and embed it
                                         into a TorchScript module (device spec
diff --git a/cpp/bin/trtorchc/main.cpp b/cpp/bin/trtorchc/main.cpp
index ee5076e434..b4aa9ec4d7 100644
--- a/cpp/bin/trtorchc/main.cpp
+++ b/cpp/bin/trtorchc/main.cpp
@@ -237,11 +237,11 @@ int main(int argc, char** argv) {
       "(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
       {"allow-gpu-fallback"});

-  args::Flag allow_torch_fallback(
+  args::Flag require_full_compilation(
       parser,
-      "allow-torch-fallback",
-      "Enable layers to run in torch if they are not supported in TensorRT",
-      {"allow-torch-fallback"});
+      "require-full-compilation",
+      "Require that the model should be fully compiled to TensorRT or throw an error",
+      {"require-full-compilation"});

   args::Flag disable_tf32(
       parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});
@@ -276,17 +276,23 @@ int main(int argc, char** argv) {
       "Path to calibration cache file to use for post training quantization",
       {"calibration-cache-file"});

-  args::ValueFlagList<std::string> forced_fallback_ops(
+  args::ValueFlagList<std::string> torch_executed_ops(
       parser,
-      "forced_fallback_ops",
-      "(Repeatable) Operator in the graph that should be forced to fallback to Pytorch for execution (allow torch fallback must be set)",
-      {"ffo", "forced-fallback-op"});
+      "torch-executed-ops",
+      "(Repeatable) Operator in the graph that should always be run in PyTorch for execution (partial compilation must be enabled)",
+      {"teo", "torch-executed-ops"});

-  args::ValueFlagList<std::string> forced_fallback_mods(
+  args::ValueFlagList<std::string> torch_executed_mods(
       parser,
-      "forced_fallback_mods",
-      "(Repeatable) Module that should be forced to fallback to Pytorch for execution (allow torch fallback must be set)",
-      {"ffm", "forced-fallback-mod"});
+      "torch-executed-mods",
+      "(Repeatable) Module that should always be run in Pytorch for execution (partial compilation must be enabled)",
+      {"tem", "torch-executed-mods"});
+
+  args::ValueFlagList<std::string> min_block_size(
+      parser,
+      "torch-executed-mods",
+      "Minimum number of contiguous TensorRT supported ops to compile a subgraph to TensorRT",
+      {"mbs", "min-block-size"});

   args::Flag embed_engine(
       parser,
@@ -478,24 +484,24 @@ int main(int argc, char** argv) {
   auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);

-  if (allow_torch_fallback) {
-    compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
-  }
+  compile_settings.require_full_compilation = require_full_compilation;

-  if (forced_fallback_ops || forced_fallback_mods) {
-    if (!allow_torch_fallback) {
+  if (torch_executed_ops || torch_executed_mods) {
+    if (require_full_compilation) {
       trtorch::logging::log(
           trtorch::logging::Level::kERROR,
-          "Forced fallback ops provided but allow-torch-fallback is not set. Please use --allow-torch-fallback to enable automatic fallback of operators.");
+          "Ops or modules to run in torch were provided but full compilation was requested. Please remove --require-full-compilation to run specified ops and modules in torch.");
       exit(1);
     }
-    for (const auto fallback_op : args::get(forced_fallback_ops)) {
-      compile_settings.torch_fallback.forced_fallback_ops.push_back(fallback_op);
+    compile_settings.min_block_size = min_block_size;
+
+    for (const auto _op : args::get(torch_executed_ops)) {
+      compile_settings.torch_executed_ops.push_back(_op);
     }
-    for (const auto fallback_mod : args::get(forced_fallback_mods)) {
-      compile_settings.torch_fallback.forced_fallback_modules.push_back(fallback_mod);
+    for (const auto _mod : args::get(torch_executed_mods)) {
+      compile_settings.torch_executed_modules.push_back(_mod);
     }
   }
@@ -609,7 +615,7 @@ int main(int argc, char** argv) {
     return 1;
   }

-  if (!allow_torch_fallback) {
+  if (require_full_compilation) {
     if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
       trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
       return 1;
diff --git a/docsrc/tutorials/trtorchc.rst b/docsrc/tutorials/trtorchc.rst
index ac6a248b8a..c98e84a64e 100644
--- a/docsrc/tutorials/trtorchc.rst
+++ b/docsrc/tutorials/trtorchc.rst
@@ -39,8 +39,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would run any other module
       --allow-gpu-fallback              (Only used when targeting DLA
                                         (device-type)) Lets engine run layers
                                         on GPU if they are not supported on DLA
-      --allow-torch-fallback            Enable layers to run in torch if they
-                                        are not supported in TensorRT
+      --require-full-compilation        Require that the model should be fully
+                                        compiled to TensorRT or throw an error
       --disable-tf32                    Prevent Float32 layers from using the
                                         TF32 data format
       --sparse-weights                  Enable sparsity for weights of conv and
@@ -66,18 +66,22 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would run any other module
       --calibration-cache-file=[file_path]
                                         Path to calibration cache file to use
                                         for post training quantization
-      --ffo=[forced_fallback_ops...],
-      --forced-fallback-op=[forced_fallback_ops...]
+      --teo=[torch-executed-ops...],
+      --torch-executed-ops=[torch-executed-ops...]
                                         (Repeatable) Operator in the graph that
-                                        should be forced to fallback to Pytorch
-                                        for execution (allow torch fallback must
-                                        be set)
-      --ffm=[forced_fallback_mods...],
-      --forced-fallback-mod=[forced_fallback_mods...]
-                                        (Repeatable) Module that should be
-                                        forced to fallback to Pytorch for
-                                        execution (allow torch fallback must be
-                                        set)
+                                        should always be run in PyTorch for
+                                        execution (partial compilation must be
+                                        enabled)
+      --tem=[torch-executed-mods...],
+      --torch-executed-mods=[torch-executed-mods...]
+                                        (Repeatable) Module that should always
+                                        be run in Pytorch for execution (partial
+                                        compilation must be enabled)
+      --mbs=[torch-executed-mods...],
+      --min-block-size=[torch-executed-mods...]
+                                        Minimum number of contiguous TensorRT
+                                        supported ops to compile a subgraph to
+                                        TensorRT
       --embed-engine                    Whether to treat input file as a
                                         serialized TensorRT engine and embed it
                                         into a TorchScript module (device spec
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index 52f10751ac..898a9aee3c 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -25,7 +25,7 @@ def compile(module: torch.jit.ScriptModule,
             max_batch_size=0,
             calibrator=None,
             truncate_long_and_double=False,
-            require_full_compilation=True,
+            require_full_compilation=False,
             min_block_size=3,
             torch_executed_ops=[],
             torch_executed_modules=[]) -> torch.jit.ScriptModule:
diff --git a/tests/core/BUILD b/tests/core/BUILD
index fc5f788a1b..3235215cd1 100644
--- a/tests/core/BUILD
+++ b/tests/core/BUILD
@@ -16,6 +16,7 @@ cc_test(
     deps = [
         "//tests/util",
         "//core",
+        "//core/ir",
         "//core/lowering",
         "//core/util:prelude",
         "@googletest//:gtest_main",
diff --git a/tests/core/test_detecting_input_type.cpp b/tests/core/test_detecting_input_type.cpp
index c7a279d38a..190697c52d 100644
--- a/tests/core/test_detecting_input_type.cpp
+++ b/tests/core/test_detecting_input_type.cpp
@@ -4,6 +4,7 @@
 #include "torch/script.h"
 #include "core/util/prelude.h"
 #include "core/lowering/lowering.h"
+#include "core/ir/ir.h"
 #include "trtorch/trtorch.h"

 TEST(CoreTest, DetectingInputTypeWorksCorrectFP32) {
@@ -18,7 +19,7 @@ TEST(CoreTest, DetectingInputTypeWorksCorrectFP32) {
   auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {});
   auto g = graph_and_parameters.first;

-  auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block());
+  auto input_types = trtorch::core::ir::get_block_first_calc_dtypes_opt(g->block());

   for (auto in : input_types) {
     c10::optional<at::ScalarType>& detected_type_opt = in.second;
@@ -36,16 +37,16 @@ TEST(CoreTest, DetectingInputTypeWorksCorrectFP16) {
     ASSERT_TRUE(false);
   }

-    mod.to(at::kHalf);
+  mod.to(at::kHalf);

-    auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {});
+  auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward", {});
   auto g = graph_and_parameters.first;

-    auto input_types = trtorch::core::util::get_block_first_calc_dtypes_opt(g->block());
+  auto input_types = trtorch::core::ir::get_block_first_calc_dtypes_opt(g->block());

-    for (auto in : input_types) {
-      c10::optional<at::ScalarType>& detected_type_opt = in.second;
-      ASSERT_TRUE(detected_type_opt);
-      ASSERT_TRUE(detected_type_opt.value() == at::kHalf);
-    }
+  for (auto in : input_types) {
+    c10::optional<at::ScalarType>& detected_type_opt = in.second;
+    ASSERT_TRUE(detected_type_opt);
+    ASSERT_TRUE(detected_type_opt.value() == at::kHalf);
+  }
 }
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index 241f9a7609..af10399fd8 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -211,6 +211,7 @@ def test_input_respect_user_setting_fp32_weights_fp16_in(self):
         ts_model = torch.jit.script(self.model)
         trt_mod = trtorch.compile(ts_model,
                                   inputs=[self.input.half()],
+                                  require_full_compilation=True,
                                   enabled_precisions={torch.float, torch.half})
         trt_mod(self.input.half())

@@ -221,6 +222,7 @@ def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self):
         trt_mod = trtorch.compile(ts_model,
                                   inputs=[input_spec],
+                                  require_full_compilation=True,
                                   enabled_precisions={torch.float, torch.half})
         trt_mod(self.input.half())

@@ -253,6 +255,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
         trt_mod = trtorch.compile(half_mod,
                                   inputs=[self.input],
+                                  require_full_compilation=True,
                                   enabled_precisions={torch.float, torch.half})
         trt_mod(self.input)

@@ -265,6 +268,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
         trt_mod = trtorch.compile(half_mod,
                                   inputs=[input_spec],
+                                  require_full_compilation=True,
                                   enabled_precisions={torch.float, torch.half})
         trt_mod(self.input)
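Example usage of the renamed partial-compilation options, as a minimal hypothetical sketch. It relies only on the trtorch.compile keyword arguments touched above (require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules) plus the inputs and enabled_precisions arguments exercised in tests/py/test_api.py; the toy module, input shape, and the aten::relu operator chosen for PyTorch execution are illustrative placeholders, not part of this change.

    import torch
    import torch.nn as nn
    import trtorch

    # Toy module used only for illustration; any scriptable module works.
    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    model = torch.jit.script(ToyModel().eval().cuda())
    example_input = torch.randn(1, 3, 32, 32).cuda()

    # require_full_compilation defaults to False after this change, so graphs with
    # unsupported or explicitly excluded ops fall back to PyTorch instead of erroring.
    trt_mod = trtorch.compile(
        model,
        inputs=[example_input],
        enabled_precisions={torch.float},
        require_full_compilation=False,
        min_block_size=3,
        torch_executed_ops=["aten::relu"],  # placeholder op forced to run in PyTorch
        torch_executed_modules=[])

    print(trt_mod(example_input).shape)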