Skip to content

Commit

Permalink
feat(trtorchc): Adding more dtype aliases
Browse files Browse the repository at this point in the history
Adding more aliases for f32 and f16 to make it easier on users

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 3, 2021
1 parent e1e7812 commit 652fb13
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions cpp/trtorchc/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
std::transform(
dtype_str.begin(), dtype_str.end(), dtype_str.begin(), [](unsigned char c) { return std::tolower(c); });
if (dtype_str == "float" || dtype_str == "float32" || dtype_str == "f32") {
if (dtype_str == "float" || dtype_str == "float32" || dtype_str == "f32" || dtype_str == "fp32") {
return trtorch::CompileSpec::DataType::kFloat;
} else if (dtype_str == "half" || dtype_str == "float16" || dtype_str == "f16") {
} else if (dtype_str == "half" || dtype_str == "float16" || dtype_str == "f16" || dtype_str == "fp16") {
return trtorch::CompileSpec::DataType::kHalf;
} else if (dtype_str == "char" || dtype_str == "int8" || dtype_str == "i8") {
return trtorch::CompileSpec::DataType::kChar;
Expand All @@ -73,7 +73,7 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
} else {
trtorch::logging::log(
trtorch::logging::Level::kERROR,
"Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str);
"Invalid precision, options are [ float | float32 | fp32 | f32 | half | float16 | fp16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str);
return trtorch::CompileSpec::DataType::kUnknown;
}
}
Expand Down Expand Up @@ -214,7 +214,7 @@ int main(int argc, char** argv) {
args::ValueFlagList<std::string> enabled_precision(
parser,
"precision",
"(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float)",
"(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | int8 | i8 | char ] (default: float)",
{'p', "enabled-precison"});
args::ValueFlag<std::string> device_type(
parser,
Expand Down Expand Up @@ -434,7 +434,7 @@ int main(int argc, char** argv) {
}
} else {
std::stringstream ss;
ss << "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ], found: ";
ss << "Invalid precision given for enabled kernel precision, options are [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | char | int8 | i8 ], found: ";
ss << dtype;
trtorch::logging::log(
trtorch::logging::Level::kERROR, ss.str());
Expand Down

0 comments on commit 652fb13

Please sign in to comment.