From d731ac81213e0946a7def4510fbae6ad1ed51104 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 25 Jan 2022 12:35:45 +0000 Subject: [PATCH 1/3] [microNPU] Refactor type inference data type checks Aims to improve readability, extendibility and error message unification for data type checks across NPU operators. A follow up for the comments in #9576. Change-Id: I83fb89a56677003f7abebb7985ad60d92cfa8df1 --- .../op/contrib/ethosu/binary_elementwise.cc | 102 ++++-------------- src/relay/op/contrib/ethosu/common.cc | 53 +++++++++ src/relay/op/contrib/ethosu/common.h | 36 +++++++ src/relay/op/contrib/ethosu/convolution.cc | 26 +---- src/relay/op/contrib/ethosu/depthwise.cc | 51 ++------- src/relay/op/contrib/ethosu/identity.cc | 10 +- src/relay/op/contrib/ethosu/pooling.cc | 18 ++-- .../op/contrib/ethosu/unary_elementwise.cc | 28 ++--- 8 files changed, 142 insertions(+), 182 deletions(-) diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index e7622452166c..8966540e0703 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -143,98 +143,34 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr."; - String operator_type = param->operator_type; - auto ifm_dtype = ifm->dtype; - auto ifm2_dtype = ifm2->dtype; - DataType ofm_dtype; + const String operator_name = "ethosu_binary_elementwise"; + const String operator_type = param->operator_type; + const DataType ifm_dtype = ifm->dtype; + const DataType ifm2_dtype = ifm2->dtype; + const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); - if (param->ofm_dtype == "int8") { - ofm_dtype = DataType::Int(8); - } else if (param->ofm_dtype == "uint8") { - ofm_dtype = DataType::UInt(8); - } else if (param->ofm_dtype == "int16") { - ofm_dtype = DataType::Int(16); - } else if (param->ofm_dtype == "int32") { - ofm_dtype = DataType::Int(32); - } - - if (ifm_dtype != ifm2_dtype) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << "type for ifm2 be the same of ifm but was " << ifm2_dtype - << " instead of " << ifm_dtype); - return false; - } + CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type); if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && - ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype); - return false; - } - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8), + DataType::Int(16), DataType::Int(32)}; + CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type); } else if (operator_type == "MIN" || operator_type == "MAX") { - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8) or type(int8) for ifm but was " << ifm_dtype); - return false; - } - if (ifm_dtype != ofm_dtype) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type - << " type for ofm be the same of ifm but was " << ofm_dtype - << " instead of " << ifm_dtype); - return false; - } + std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8)}; + CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); + CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type); } else if (operator_type == "SHR") { - if (ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ifm but was " - << ifm_dtype); - return false; - } - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)}, + operator_name, "ofm", operator_type); } else if (operator_type == "SHL") { - if (ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ifm but was " - << ifm_dtype); - - return false; - } - if (ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ofm but was " - << ofm_dtype); - return false; - } + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type); } else { reporter->GetDiagCtx().EmitFatal( Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or " + << "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or " << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type); return false; } diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index eac576257721..8e705b66bcb5 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -24,6 +24,9 @@ #include "common.h" +#include +#include + #include "../../op_common.h" namespace tvm { @@ -92,6 +95,56 @@ Array EthosuInferUpscaledInput(Array ifm_shape, String ifm return new_ifm_shape; } +DataType DataTypeFromString(const String& dtype) { + DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype); + return DataType(dl_dtype); +} + +void CheckDataType(const TypeReporter& reporter, const DataType& data_type, + const std::unordered_set& allowed_data_types, + const String& operator_name, const String& tensor_name, + const String& operator_type) { + if (allowed_data_types.find(data_type) != allowed_data_types.end()) { + return; + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != "") { + message << operator_type << " "; + } + message << "to have type in {"; + for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) { + message << *it; + if (std::next(it) != allowed_data_types.end()) { + message << ", "; + } + } + message << "}"; + message << " for " << tensor_name << " but was " << data_type << "."; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + +void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, + const DataType& data_type2, const String& operator_name, + const String& tensor_name, const String& tensor_name2, + const String& operator_type) { + if (data_type == data_type2) { + return; + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != " ") { + message << operator_type << " "; + } + message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was " + << data_type << " and " << data_type2; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index 001b596c0949..9238a7db95b1 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -27,6 +27,8 @@ #include +#include + namespace tvm { namespace relay { namespace op { @@ -65,6 +67,40 @@ Array EthosuInferKernelOutput(Array ifm_shape, String ifm_ */ Array EthosuInferUpscaledInput(Array ifm_shape, String ifm_layout); +/*! \brief Get data type from string representation. + * \param dtype Data type in lower case format followed by number of bits e.g. "int8". + */ +DataType DataTypeFromString(const String& dtype); + +/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a + * type inference error if not. + * \param reporter The infer type reporter. + * \param data_type The data ntype to check. + * \param allowed_data_types An unordered set of allowed data types. + * \param operator_name The name of the operator to report. + * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckDataType(const TypeReporter& reporter, const DataType& data_type, + const std::unordered_set& allowed_data_types, + const String& operator_name, const String& tensor_name, + const String& operator_type = ""); + +/*! \brief Check the data type matches that of the second data type provided. Raise a type inference + * error if not. + * \param reporter The infer type reporter. + * \param data_type The data type to check. + * \param data_type2 The second data type to check. + * \param operator_name The name of the operator to report. + * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". + * \param tensor_name2 The name of the second tensor to report e.g. "ifm2". + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, + const DataType& data_type2, const String& operator_name, + const String& tensor_name, const String& tensor_name2, + const String& operator_type = ""); + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 7b11f61acc12..4d9541ced816 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -131,28 +131,12 @@ bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attr if (ifm == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr."; + const String operator_name = "ethosu_conv2d"; - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d input data type " - << "of type(uint8) or type(int8) but was " << ifm->dtype); - return false; - } - - if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d weight data type " - << "of type(uint8) or type(int8) but was " << weight->dtype); - return false; - } - - if (scale_bias->dtype != DataType::UInt(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d scale bias data type " - << "of type(uint8) but was " << scale_bias->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); + CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, + "weight"); + CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; if (upscale_methods.find(param->upscale) == upscale_methods.end()) { diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index c95385ad95d8..abfe0e3856a1 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -136,50 +136,17 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; - DataType ofm_dtype; + const String operator_name = "ethosu_depthwise_conv2d"; - if (param->ofm_dtype == "int8") { - ofm_dtype = DataType::Int(8); - } else if (param->ofm_dtype == "uint8") { - ofm_dtype = DataType::UInt(8); - } else if (param->ofm_dtype == "int16") { - ofm_dtype = DataType::Int(16); - } else if (param->ofm_dtype == "int32") { - ofm_dtype = DataType::Int(32); - } - - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d input data type " - << "of type(uint8) or type(int8) but was " << ifm->dtype); - return false; - } - - if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d weight data type " - << "of type(uint8) or type(int8) but was " << weight->dtype); - return false; - } - - if (scale_bias->dtype != DataType::UInt(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type " - << "of type(uint8) but was " << scale_bias->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); + CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, + "weight"); + CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d output data type " - << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); + std::unordered_set ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), + DataType::Int(32)}; + CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm"); // Collect the ifm, weight and ofm tensors for using in the inference function Array tensor_types = {types[0], types[1], types[4]}; diff --git a/src/relay/op/contrib/ethosu/identity.cc b/src/relay/op/contrib/ethosu/identity.cc index c2b67477cfe9..350e8028f201 100644 --- a/src/relay/op/contrib/ethosu/identity.cc +++ b/src/relay/op/contrib/ethosu/identity.cc @@ -69,15 +69,11 @@ bool EthosuIdentityRel(const Array& types, int num_inputs, const Attrs& at if (ifm == nullptr) return false; const auto* param = attrs.as(); - ICHECK(param != nullptr) << "EthosuIdentityAttrs cannot be nullptr."; - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected type(uint8) or type(int8) for ifm but was " << ifm->dtype); - return false; - } + const String operator_name = "ethosu_identity"; + + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); if (ifm->shape.size() > 4) { reporter->GetDiagCtx().EmitFatal( diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index dc16c072ebe2..d9861954ac98 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -123,21 +123,17 @@ bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& att const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuPoolingAttrs cannot be nullptr."; + const String operator_name = "ethosu_pooling"; + if (param->pooling_type != "AVG" && param->pooling_type != "MAX") { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected pooling_type 'AVG' or 'MAX' but was " - << param->pooling_type); + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected " << operator_name + << " type 'AVG' or 'MAX' but was " << param->pooling_type); return false; } - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected pool type(uint8) or type(int8) for ifm but was " - << ifm->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", + param->pooling_type); const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; if (upscale_methods.find(param->upscale) == upscale_methods.end()) { diff --git a/src/relay/op/contrib/ethosu/unary_elementwise.cc b/src/relay/op/contrib/ethosu/unary_elementwise.cc index 9dc07e031d75..a346f095283c 100644 --- a/src/relay/op/contrib/ethosu/unary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/unary_elementwise.cc @@ -104,30 +104,22 @@ bool EthosuUnaryElementwiseRel(const Array& types, int num_inputs, const A const auto* param = attrs.as(); CHECK(param != nullptr) << "EthosuUnaryElementwiseAttrs cannot be nullptr."; - String operator_type = param->operator_type; + const String operator_name = "ethosu_unary_elementwise"; + const String operator_type = param->operator_type; if (operator_type != "ABS" && operator_type != "CLZ") { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise 'ABS' " - "or 'CLZ' for operator_type but was" + << "Invalid operator: expected << " << operator_name + << " 'ABS' or 'CLZ' for operator_type but was" << operator_type); return false; } - auto ifm_dtype = ifm->dtype; - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && operator_type == "ABS") { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise " - << operator_type << "input data type " - << "of type(uint8) or type(int8) but was " << ifm_dtype); - return false; - } - - if (ifm_dtype != DataType::Int(32) && operator_type == "CLZ") { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise CLZ input data type " - << "of type(int32) but was " << ifm_dtype); - return false; + const DataType ifm_dtype = ifm->dtype; + if (operator_type == "CLZ") { + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + } else { + CheckDataType(reporter, ifm_dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", + operator_type); } // Assign ofm type From 3536c4e0165b6ba3696886dbe5d2fd0c0253460d Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 7 Feb 2022 11:25:00 +0000 Subject: [PATCH 2/3] unordered_set -> initializer_list and use new format for upscale check Change-Id: Icf3d68d5cc7d5e1d5af42b1af193db89faea155e --- .../op/contrib/ethosu/binary_elementwise.cc | 6 +-- src/relay/op/contrib/ethosu/common.cc | 37 +++++++++++++++++-- src/relay/op/contrib/ethosu/common.h | 18 +++++++-- src/relay/op/contrib/ethosu/convolution.cc | 9 +---- src/relay/op/contrib/ethosu/depthwise.cc | 4 +- src/relay/op/contrib/ethosu/pooling.cc | 9 +---- 6 files changed, 55 insertions(+), 28 deletions(-) diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index 8966540e0703..258c84c70660 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -152,12 +152,12 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type); if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { - std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8), - DataType::Int(16), DataType::Int(32)}; + std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8), + DataType::Int(16), DataType::Int(32)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type); } else if (operator_type == "MIN" || operator_type == "MAX") { - std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8)}; + std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type); } else if (operator_type == "SHR") { diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index 8e705b66bcb5..a9fcd4301e81 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -101,11 +101,13 @@ DataType DataTypeFromString(const String& dtype) { } void CheckDataType(const TypeReporter& reporter, const DataType& data_type, - const std::unordered_set& allowed_data_types, + const std::initializer_list& allowed_data_types, const String& operator_name, const String& tensor_name, const String& operator_type) { - if (allowed_data_types.find(data_type) != allowed_data_types.end()) { - return; + for (const auto& i : allowed_data_types) { + if (data_type == i) { + return; + } } std::ostringstream message; @@ -126,6 +128,33 @@ void CheckDataType(const TypeReporter& reporter, const DataType& data_type, reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); } +void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method, + const std::initializer_list& allowed_upscale_methods, + const String& operator_name, const String& operator_type) { + for (const auto& i : allowed_upscale_methods) { + if (upscale_method == i) { + return; + } + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != "") { + message << operator_type << " "; + } + message << "to have upscale method in {"; + for (auto it = allowed_upscale_methods.begin(); it != allowed_upscale_methods.end(); ++it) { + message << *it; + if (std::next(it) != allowed_upscale_methods.end()) { + message << ", "; + } + } + message << "}"; + message << " but was " << upscale_method << "."; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, const DataType& data_type2, const String& operator_name, const String& tensor_name, const String& tensor_name2, @@ -136,7 +165,7 @@ void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, std::ostringstream message; message << "Invalid operator: expected " << operator_name << " "; - if (operator_type != " ") { + if (operator_type != "") { message << operator_type << " "; } message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was " diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index 9238a7db95b1..b993b78749ae 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -75,17 +75,29 @@ DataType DataTypeFromString(const String& dtype); /*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a * type inference error if not. * \param reporter The infer type reporter. - * \param data_type The data ntype to check. - * \param allowed_data_types An unordered set of allowed data types. + * \param data_type The data type to check. + * \param allowed_data_types An initializer list of allowed data types. * \param operator_name The name of the operator to report. * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. */ void CheckDataType(const TypeReporter& reporter, const DataType& data_type, - const std::unordered_set& allowed_data_types, + const std::initializer_list& allowed_data_types, const String& operator_name, const String& tensor_name, const String& operator_type = ""); +/*! \brief Check the upscale method matches one given in allowed_upscale_methods. Raise a type + * inference error if not. + * \param reporter The infer type reporter. + * \param upscale_method The upscale method string to check. + * \param allowed_upscale_methods An initializer list of allowed upscale methods. + * \param operator_name The name of the operator to report. + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method, + const std::initializer_list& allowed_upscale_methods, + const String& operator_name, const String& operator_type = ""); + /*! \brief Check the data type matches that of the second data type provided. Raise a type inference * error if not. * \param reporter The infer type reporter. diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 4d9541ced816..90bbf90d13c7 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -138,14 +138,7 @@ bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attr "weight"); CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); - const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; - if (upscale_methods.find(param->upscale) == upscale_methods.end()) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected upsample method to be 'NONE', " - "'ZEROS' or 'NEAREST' but got " - << param->upscale); - return false; - } + CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name); // The scale_bias should be provided as a tensor of size {ofm_channels, 10} reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8))); diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index abfe0e3856a1..b631f5b8e6a4 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -144,8 +144,8 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); - std::unordered_set ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), - DataType::Int(32)}; + std::initializer_list ofm_dtypes = {DataType::UInt(8), DataType::Int(8), + DataType::Int(16), DataType::Int(32)}; CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm"); // Collect the ifm, weight and ofm tensors for using in the inference function diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index d9861954ac98..3175e4ddffc4 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -135,14 +135,7 @@ bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& att CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", param->pooling_type); - const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; - if (upscale_methods.find(param->upscale) == upscale_methods.end()) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected upsample method to be 'NONE', " - "'ZEROS' or 'NEAREST' but got " - << param->upscale); - return false; - } + CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name); Array ifm_shape = ifm->shape; if (param->upscale != "NONE") { From 86ac9c08b0b31368317f7559023804fb1c136809 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 7 Feb 2022 12:51:09 +0000 Subject: [PATCH 3/3] remove unused header and use auto for initializer type Change-Id: I10311b718c3abd0ed75dd88b5ec9de6e0742f047 --- src/relay/op/contrib/ethosu/binary_elementwise.cc | 6 +++--- src/relay/op/contrib/ethosu/common.cc | 1 - src/relay/op/contrib/ethosu/common.h | 2 -- src/relay/op/contrib/ethosu/depthwise.cc | 3 +-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index 258c84c70660..9a681b7cdc88 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -152,12 +152,12 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type); if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { - std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8), - DataType::Int(16), DataType::Int(32)}; + auto allowed_types = {DataType::Int(8), DataType::UInt(8), DataType::Int(16), + DataType::Int(32)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type); } else if (operator_type == "MIN" || operator_type == "MAX") { - std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8)}; + auto allowed_types = {DataType::Int(8), DataType::UInt(8)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type); } else if (operator_type == "SHR") { diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index a9fcd4301e81..5e957957bc1e 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -25,7 +25,6 @@ #include "common.h" #include -#include #include "../../op_common.h" diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index b993b78749ae..a399a2e53aa4 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -27,8 +27,6 @@ #include -#include - namespace tvm { namespace relay { namespace op { diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index b631f5b8e6a4..7e9fed5041be 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -144,8 +144,7 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); - std::initializer_list ofm_dtypes = {DataType::UInt(8), DataType::Int(8), - DataType::Int(16), DataType::Int(32)}; + auto ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), DataType::Int(32)}; CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm"); // Collect the ifm, weight and ofm tensors for using in the inference function