diff --git a/core/ir/Input.cpp b/core/ir/Input.cpp
index f1f9f694eb..8841f4109c 100644
--- a/core/ir/Input.cpp
+++ b/core/ir/Input.cpp
@@ -128,8 +128,6 @@ Input::Input(std::vector shape, nvinfer1::DataType dtype, nvinfer1::Ten
   max = util::toDims(shape);
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
-  format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = dtype;
 
   TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
@@ -156,8 +154,6 @@ Input::Input(std::vector min_shape, std::vector opt_shape, std
   min = util::toDims(min_shape);
   opt = util::toDims(opt_shape);
   max = util::toDims(max_shape);
-  format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = nvinfer1::DataType::kFLOAT;
 
   std::vector<int64_t> dyn_shape;
   for (size_t i = 0; i < opt_shape.size(); i++) {
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h
index 514c218aa1..ca873af2ab 100644
--- a/cpp/api/include/trtorch/trtorch.h
+++ b/cpp/api/include/trtorch/trtorch.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include
+#include
 #include
 #include
 #include
@@ -66,10 +67,12 @@ struct TRTORCH_API CompileSpec {
       kHalf,
       /// INT8
       kChar,
-      /// INT32
-      kInt32,
+      /// INT
+      kInt,
       /// Bool
       kBool,
+      /// Sentinel value
+      kUnknown
     };
 
     /**
@@ -139,6 +142,7 @@ struct TRTORCH_API CompileSpec {
     }
 
    private:
+    friend std::ostream& operator<<(std::ostream& os, const DataType& dtype);
     Value value;
   };
 
@@ -278,6 +282,8 @@ struct TRTORCH_API CompileSpec {
       kContiguous,
       /// Channel Last / NHWC
       kChannelsLast,
+      /// Sentinel value
+      kUnknown,
     };
 
     /**
@@ -346,7 +352,9 @@ struct TRTORCH_API CompileSpec {
       return value != other;
     }
 
+
    private:
+    friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
     Value value;
   };
 
@@ -472,6 +480,7 @@ struct TRTORCH_API CompileSpec {
     bool get_explicit_set_dtype() {return explicit_set_dtype;}
 
    private:
+    friend std::ostream& operator<<(std::ostream& os, const Input& input);
     bool input_is_dynamic;
     bool explicit_set_dtype;
   };
diff --git a/cpp/api/src/compile_spec.cpp b/cpp/api/src/compile_spec.cpp
index 34a618e0ed..3e2bb1d728 100644
--- a/cpp/api/src/compile_spec.cpp
+++ b/cpp/api/src/compile_spec.cpp
@@ -9,13 +9,74 @@
 
 namespace trtorch {
 
+std::ostream& operator<<(std::ostream& os, const CompileSpec::DataType& dtype) {
+  switch (dtype) {
+    case CompileSpec::DataType::kChar:
+      os << "char";
+      break;
+    case CompileSpec::DataType::kHalf:
+      os << "half";
+      break;
+    case CompileSpec::DataType::kInt:
+      os << "int";
+      break;
+    case CompileSpec::DataType::kBool:
+      os << "bool";
+      break;
+    case CompileSpec::DataType::kFloat:
+      os << "float";
+      break;
+    case CompileSpec::DataType::kUnknown:
+    default:
+      os << "unknown";
+      break;
+  }
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const CompileSpec::TensorFormat& format) {
+  switch (format) {
+    case CompileSpec::TensorFormat::kChannelsLast:
+      os << "channels last";
+      break;
+    case CompileSpec::TensorFormat::kContiguous:
+      os << "contiguous";
+      break;
+    case CompileSpec::TensorFormat::kUnknown:
+    default:
+      os << "unknown";
+      break;
+  }
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
+  auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
+    std::stringstream ss;
+    ss << '[';
+    for (auto i : shape) {
+      ss << i << ',';
+    }
+    ss << ']';
+    return ss.str();
+  };
+
+  if (!input.input_is_dynamic) {
+    os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+  } else {
+    os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape) << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+  }
+  return os;
+}
+
+
 nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
   switch (value) {
     case CompileSpec::DataType::kChar:
       return nvinfer1::DataType::kINT8;
     case CompileSpec::DataType::kHalf:
       return nvinfer1::DataType::kHALF;
-    case CompileSpec::DataType::kInt32:
+    case CompileSpec::DataType::kInt:
       return nvinfer1::DataType::kINT32;
     case CompileSpec::DataType::kBool:
       return nvinfer1::DataType::kBOOL;
@@ -47,7 +108,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {
       value = DataType::kChar;
       break;
     case at::kInt:
-      value = DataType::kInt32;
+      value = DataType::kInt;
       break;
     case at::kBool:
       value = DataType::kBool;
@@ -250,7 +311,6 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   /* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype for
   inputs they will follow PyTorch convetions */
   for (size_t i = 0; i < external.inputs.size(); i++) {
-    std::cout << "EXPLICIT " << external.inputs[i].get_explicit_set_dtype() << std::endl;
     if (!external.inputs[i].get_explicit_set_dtype()) {
       auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
       auto& internal_ins = internal.convert_info.inputs;
@@ -261,9 +321,9 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
       } else {
         internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
       }
-      std::cout << "internal type: " << internal_ins[i].dtype;
     }
   }
+
   internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
   internal.convert_info.engine_settings.refit = external.refit;
   internal.convert_info.engine_settings.debug = external.debug;
diff --git a/cpp/trtorchc/README.md b/cpp/trtorchc/README.md
index 25a59efb27..3d065dd021 100644
--- a/cpp/trtorchc/README.md
+++ b/cpp/trtorchc/README.md
@@ -14,7 +14,7 @@ to standard TorchScript. Load with `torch.jit.load()` and run like you would run
 
 ```
 trtorchc [input_file_path] [output_file_path]
-    [input_shapes...] {OPTIONS}
+    [input_specs...] {OPTIONS}
 
     TRTorch is a compiler for TorchScript, it will compile and optimize
    TorchScript programs to run on NVIDIA GPUs using TensorRT
@@ -28,24 +28,29 @@ trtorchc [input_file_path] [output_file_path]
      -w, --warnings                    Disables warnings generated during
                                        compilation onto the console (warnings
                                        are on by default)
-      --info                            Dumps info messages generated during
+      --i, --info                       Dumps info messages generated during
                                        compilation onto the console
      --build-debuggable-engine         Creates a debuggable engine
      --use-strict-types                Restrict operating type to only use set
-                                        default operation precision
-                                        (op_precision)
+                                        operation precision
      --allow-gpu-fallback              (Only used when targeting DLA
                                        (device-type)) Lets engine run layers
                                        on GPU if they are not supported on DLA
-      -p[precision],
-      --default-op-precision=[precision]
-                                        Default operating precision for the
-                                        engine (Int8 requires a
+      --disable-tf32                    Prevent Float32 layers from using the
+                                        TF32 data format
+      -p[precision...],
+      --enabled-precison=[precision...] (Repeatable) Enabling an operating
+                                        precision for kernels to use when
+                                        building the engine (Int8 requires a
                                        calibration-cache argument) [ float |
                                        float32 | f32 | half | float16 | f16 |
                                        int8 | i8 ] (default: float)
      -d[type], --device-type=[type]    The type of device the engine should be
                                        built for [ gpu | dla ] (default: gpu)
+      --gpu-id=[gpu_id]                 GPU id if running on multi-GPU platform
+                                        (defaults to 0)
+      --dla-core=[dla_core]             DLACore id if running on available DLA
+                                        (defaults to 0)
      --engine-capability=[capability]  The type of device the engine should be
                                        built for [ default | safe_gpu |
                                        safe_dla ]
@@ -72,16 +77,21 @@ trtorchc [input_file_path] [output_file_path]
      input_file_path                   Path to input TorchScript file
      output_file_path                  Path for compiled TorchScript (or
                                        TensorRT engine) file
-      input_shapes...                   Sizes for inputs to engine, can either
+      input_specs...                    Specs for inputs to engine, can either
                                        be a single size or a range defined by
                                        Min, Optimal, Max sizes, e.g.
                                        "(N,..,C,H,W)"
-                                        "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
+                                        "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]".
+                                        Data Type and format can be specified by
+                                        adding an "@" followed by dtype and "%"
+                                        followed by format to the end of the
+                                        shape spec. e.g. "(3, 3, 32,
+                                        32)@f16%NHWC"
      "--" can be used to terminate flag options and force all following
      arguments to be treated as positional options
 ```
 
 e.g.
 ```
-trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
+trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16
 ```
\ No newline at end of file
diff --git a/cpp/trtorchc/main.cpp b/cpp/trtorchc/main.cpp
index 93dbe3382f..794238c7eb 100644
--- a/cpp/trtorchc/main.cpp
+++ b/cpp/trtorchc/main.cpp
@@ -42,6 +42,43 @@ bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
   return checkRtol(a - b, {a, b}, threshold);
 }
 
+trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
+  std::transform(
+      str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });
+
+  if (str == "linear" || str == "nchw" || str == "chw" || str == "contiguous") {
+    return trtorch::CompileSpec::TensorFormat::kContiguous;
+  } else if (str == "nhwc" || str == "hwc" || str == "channels_last") {
+    return trtorch::CompileSpec::TensorFormat::kChannelsLast;
+  } else {
+    trtorch::logging::log(
+        trtorch::logging::Level::kERROR,
+        "Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
+    return trtorch::CompileSpec::TensorFormat::kUnknown;
+  }
+}
+
+trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
+  std::transform(
+      dtype_str.begin(), dtype_str.end(), dtype_str.begin(), [](unsigned char c) { return std::tolower(c); });
+  if (dtype_str == "float" || dtype_str == "float32" || dtype_str == "f32") {
+    return trtorch::CompileSpec::DataType::kFloat;
+  } else if (dtype_str == "half" || dtype_str == "float16" || dtype_str == "f16") {
+    return trtorch::CompileSpec::DataType::kHalf;
+  } else if (dtype_str == "char" || dtype_str == "int8" || dtype_str == "i8") {
+    return trtorch::CompileSpec::DataType::kChar;
+  } else if (dtype_str == "int" || dtype_str == "int32" || dtype_str == "i32") {
+    return trtorch::CompileSpec::DataType::kInt;
+  } else if (dtype_str == "bool" || dtype_str == "b") {
+    return trtorch::CompileSpec::DataType::kBool;
+  } else {
+    trtorch::logging::log(
+        trtorch::logging::Level::kERROR,
+        "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
+    return trtorch::CompileSpec::DataType::kUnknown;
+  }
+}
+
 std::vector<int64_t> parseSingleDim(std::string shape_str) {
   std::vector<int64_t> shape;
   std::stringstream ss;
@@ -71,7 +108,7 @@ std::vector parseSingleDim(std::string shape_str) {
   return {};
 }
 
-trtorch::CompileSpec::InputRange parseDynamicDim(std::string shape_str) {
+std::vector<std::vector<int64_t>> parseDynamicDim(std::string shape_str) {
   shape_str = shape_str.substr(1, shape_str.size() - 2);
   std::vector<std::vector<int64_t>> shape;
   std::stringstream ss;
@@ -96,7 +133,7 @@ trtorch::CompileSpec::InputRange parseDynamicDim(std::string shape_str) {
     exit(1);
   }
 
-  return trtorch::CompileSpec::InputRange(shape[0], shape[1], shape[2]);
+  return shape;
 }
 
 std::string get_cwd() {
@@ -155,7 +192,7 @@ int main(int argc, char** argv) {
   args::Flag use_strict_types(
       parser,
       "use-strict-types",
-      "Restrict operating type to only use set default operation precision (op_precision)",
+      "Restrict operating type to only use set operation precision",
       {"use-strict-types"});
   args::Flag allow_gpu_fallback(
       parser,
@@ -166,11 +203,11 @@ int main(int argc, char** argv) {
 
   args::Flag disable_tf32(
       parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});
-  args::ValueFlag<std::string> op_precision(
+  args::ValueFlagList<std::string> enabled_precision(
       parser,
       "precision",
-      "Default operating precision for the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float)",
-      {'p', "default-op-precision"});
+      "(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float)",
+      {'p', "enabled-precison"});
   args::ValueFlag<std::string> device_type(
       parser,
       "type",
@@ -215,8 +252,8 @@ int main(int argc, char** argv) {
       parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
   args::PositionalList<std::string> input_shapes(
       parser,
-      "input_shapes",
-      "Sizes for inputs to engine, can either be a single size or a range defined by Min, Optimal, Max sizes, e.g. \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"");
+      "input_specs",
+      "Specs for inputs to engine, can either be a single size or a range defined by Min, Optimal, Max sizes, e.g. \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\". Data Type and format can be specified by adding an \"@\" followed by dtype and \"%\" followed by format to the end of the shape spec. e.g. \"(3, 3, 32, 32)@f16\%NHWC\"");
 
   try {
     parser.ParseCLI(argc, argv);
@@ -237,19 +274,103 @@ int main(int argc, char** argv) {
     trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
   }
 
-  std::vector<trtorch::CompileSpec::InputRange> ranges;
-  for (const auto shapes : args::get(input_shapes)) {
-    if (shapes.rfind("(", 0) == 0) {
-      ranges.push_back(trtorch::CompileSpec::InputRange(parseSingleDim(shapes)));
-    } else if (shapes.rfind("[", 0) == 0) {
-      ranges.push_back(parseDynamicDim(shapes));
+  std::vector<trtorch::CompileSpec::Input> ranges;
+  const std::string spec_err_str = "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"\nTo specify input type append an @ followed by the precision\n e.g. \"(3,3,300,300)@f32\"\nTo specify input format append an \% followed by the format [contiguous | channel_last]\n e.g. \"(3,3,300,300)@f32\%channel_last\"";
+  for (const auto spec : args::get(input_shapes)) {
+    std::string shapes;
+    std::string dtype;
+    std::string format;
+    // THERE IS A SPEC FOR DTYPE
+    if (spec.find('@') != std::string::npos) {
+      // THERE IS ALSO A SPEC FOR FORMAT
+      if (spec.find('%') != std::string::npos) {
+        auto dtype_delim = spec.find('@');
+        auto format_delim = spec.find('%');
+        std::string shapes = spec.substr(0, dtype_delim);
+        std::string dtype = spec.substr(dtype_delim + 1, format_delim - (dtype_delim + 1));
+        std::string format = spec.substr(format_delim + 1, spec.size());
+
+        auto parsed_dtype = parseDataType(dtype);
+        if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown) {
+          trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid datatype for input specification " + spec);
+          std::cerr << parser;
+          exit(1);
+        }
+        auto parsed_format = parseTensorFormat(format);
+        if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown) {
+          trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid format for input specification " + spec);
+          std::cerr << parser;
+          exit(1);
+        }
+        if (shapes.rfind("(", 0) == 0) {
+          ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(shapes), parsed_dtype, parsed_format));
+        } else if (shapes.rfind("[", 0) == 0) {
+          auto dyn_shapes = parseDynamicDim(shapes);
+          ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
+        } else {
+          trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
+          std::cerr << parser;
+          exit(1);
+        }
+        // THERE IS NO SPEC FOR FORMAT
+      } else {
+        std::string shapes = spec.substr(0, spec.find('@'));
+        std::string dtype = spec.substr(spec.find('@') + 1, spec.size());
+
+        auto parsed_dtype = parseDataType(dtype);
+        if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown) {
+          trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid datatype for input specification " + spec);
+          std::cerr << parser;
+          exit(1);
+        }
+        if (shapes.rfind("(", 0) == 0) {
+          ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(shapes), parsed_dtype));
+        } else if (shapes.rfind("[", 0) == 0) {
+          auto dyn_shapes = parseDynamicDim(shapes);
+          ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype));
+        } else {
+          trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
+          std::cerr << parser;
+          exit(1);
+        }
+      }
+      // THERE IS A SPEC FOR FORMAT BUT NOT DTYPE
+    } else if (spec.find('%') != std::string::npos) {
+      std::string shapes = spec.substr(0, spec.find('%'));
+      std::string format = spec.substr(spec.find('%') + 1, spec.size());
+
+      auto parsed_format = parseTensorFormat(format);
+      if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown) {
+        trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid format for input specification " + spec);
+        std::cerr << parser;
+        exit(1);
+      }
+      if (shapes.rfind("(", 0) == 0) {
+        ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(shapes), parsed_format));
+      } else if (shapes.rfind("[", 0) == 0) {
+        auto dyn_shapes = parseDynamicDim(shapes);
+        ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_format));
+      } else {
+        trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
+        std::cerr << parser;
+        exit(1);
+      }
+      // JUST SHAPE USE DEFAULT DTYPE
     } else {
-      trtorch::logging::log(
-          trtorch::logging::Level::kERROR,
-          "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"");
-      std::cerr << parser;
-      exit(1);
+      if (spec.rfind("(", 0) == 0) {
+        ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(spec)));
+      } else if (spec.rfind("[", 0) == 0) {
+        auto dyn_shapes = parseDynamicDim(spec);
+        ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2]));
+      } else {
+        trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
+        std::cerr << parser;
+        exit(1);
+      }
     }
+    std::stringstream ss;
+    ss << "Parsed Input: " << ranges.back();
+    trtorch::logging::log(trtorch::logging::Level::kDEBUG, ss.str());
   }
 
   auto compile_settings = trtorch::CompileSpec(ranges);
@@ -277,57 +398,58 @@ int main(int argc, char** argv) {
 
   auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);
 
-  if (op_precision) {
-    auto precision = args::get(op_precision);
-    std::transform(
-        precision.begin(), precision.end(), precision.begin(), [](unsigned char c) { return std::tolower(c); });
-    if (precision == "float" || precision == "float32" || precision == "f32") {
-      compile_settings.op_precision = torch::kF32;
-    } else if (precision == "half" || precision == "float16" || precision == "f16") {
-      compile_settings.op_precision = torch::kF16;
-    } else if (precision == "int8" || precision == "i8") {
-      compile_settings.op_precision = torch::kI8;
-      if (calibration_cache_file) {
-        compile_settings.ptq_calibrator = calibrator;
+  if (enabled_precision) {
+    for (const auto precision : args::get(enabled_precision)) {
+      auto dtype = parseDataType(precision);
+      if (dtype == trtorch::CompileSpec::DataType::kFloat) {
+        compile_settings.enabled_precisions.insert(torch::kF32);
+      } else if (dtype == trtorch::CompileSpec::DataType::kHalf) {
+        compile_settings.enabled_precisions.insert(torch::kF16);
+      } else if (dtype == trtorch::CompileSpec::DataType::kChar) {
+        compile_settings.enabled_precisions.insert(torch::kI8);
+        if (calibration_cache_file) {
+          compile_settings.ptq_calibrator = calibrator;
+        } else {
+          trtorch::logging::log(
+              trtorch::logging::Level::kERROR,
+              "If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided");
+          std::cerr << parser;
+          return 1;
+        }
       } else {
         trtorch::logging::log(
             trtorch::logging::Level::kERROR,
-            "If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided");
+            "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ]");
         std::cerr << parser;
         return 1;
       }
-    } else {
-      trtorch::logging::log(
-          trtorch::logging::Level::kERROR,
-          "Invalid default operating precision, options are [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ]");
-      std::cerr << parser;
-      return 1;
     }
+  }
 
-    if (device_type) {
-      auto device = args::get(device_type);
-      std::transform(device.begin(), device.end(), device.begin(), [](unsigned char c) { return std::tolower(c); });
+  if (device_type) {
+    auto device = args::get(device_type);
+    std::transform(device.begin(), device.end(), device.begin(), [](unsigned char c) { return std::tolower(c); });
 
-      if (gpu_id) {
-        compile_settings.device.gpu_id = args::get(gpu_id);
-        trtorch::set_device(compile_settings.device.gpu_id);
-      }
+    if (gpu_id) {
+      compile_settings.device.gpu_id = args::get(gpu_id);
+      trtorch::set_device(compile_settings.device.gpu_id);
+    }
 
-      if (device == "gpu") {
-        compile_settings.device.device_type = trtorch::CompileSpec::Device::DeviceType::kGPU;
-      } else if (device == "dla") {
-        compile_settings.device.device_type = trtorch::CompileSpec::Device::DeviceType::kDLA;
-        if (dla_core) {
-          compile_settings.device.dla_core = args::get(dla_core);
-        }
-      } else {
-        trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ]");
-        std::cerr << parser;
-        return 1;
+    if (device == "gpu") {
+      compile_settings.device.device_type = trtorch::CompileSpec::Device::DeviceType::kGPU;
+    } else if (device == "dla") {
+      compile_settings.device.device_type = trtorch::CompileSpec::Device::DeviceType::kDLA;
+      if (dla_core) {
+        compile_settings.device.dla_core = args::get(dla_core);
      }
+    } else {
+      trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ]");
+      std::cerr << parser;
+      return 1;
    }
  }
 
+
   if (engine_capability) {
     auto capability = args::get(engine_capability);
     std::transform(
@@ -388,7 +510,7 @@ int main(int argc, char** argv) {
   } else {
     auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
 
-    if (compile_settings.op_precision == trtorch::CompileSpec::DataType::kFloat) {
+    if (compile_settings.enabled_precisions.size() == 1 && compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) != compile_settings.enabled_precisions.end()) {
       double threshold_val = 2e-5;
       if (threshold) {
         threshold_val = args::get(threshold);
@@ -398,7 +520,7 @@ int main(int argc, char** argv) {
 
     std::vector<torch::jit::IValue> trt_inputs_ivalues;
     for (auto i : ranges) {
-      auto in = at::randn(i.opt, {at::kCUDA});
+      auto in = at::randn(i.opt_shape, {at::kCUDA});
       jit_inputs_ivalues.push_back(in.clone());
       trt_inputs_ivalues.push_back(in.clone());
     }
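
For readers trying out the API changes above, here is a minimal, hypothetical usage sketch. It is not part of the diff: the model paths, the exact headers, and the `save()` call are assumptions, while the `CompileSpec::Input` constructors, the `enabled_precisions` set, and `CompileGraph` are taken directly from the changes shown.

```
// Hypothetical sketch of the new Input / enabled_precisions API introduced by this diff.
// Assumptions: header name, model paths, and the trt_mod.save() call.
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  auto mod = torch::jit::load("model.ts");

  // Static input with an explicit dtype and tensor format.
  auto fixed_in = trtorch::CompileSpec::Input(
      {1, 3, 32, 32}, trtorch::CompileSpec::DataType::kHalf, trtorch::CompileSpec::TensorFormat::kContiguous);

  // Dynamic input defined by min/opt/max shapes; dtype and format keep their defaults.
  auto dyn_in = trtorch::CompileSpec::Input({1, 3, 224, 224}, {8, 3, 224, 224}, {32, 3, 224, 224});

  auto spec = trtorch::CompileSpec({fixed_in, dyn_in});
  // enabled_precisions replaces the old single op_precision field; several precisions may be enabled at once.
  spec.enabled_precisions.insert(torch::kF32);
  spec.enabled_precisions.insert(torch::kF16);

  auto trt_mod = trtorch::CompileGraph(mod, spec);
  trt_mod.save("model_trt.ts");
  return 0;
}
```

The trtorchc equivalent of the per-input dtype and format settings is the `@dtype%format` suffix on an input spec, as documented in the README hunk above, and `-p` may now be repeated to enable multiple precisions.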