From 3536c4e0165b6ba3696886dbe5d2fd0c0253460d Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 7 Feb 2022 11:25:00 +0000 Subject: [PATCH] unordered_set -> initializer_list and use new format for upscale check Change-Id: Icf3d68d5cc7d5e1d5af42b1af193db89faea155e --- .../op/contrib/ethosu/binary_elementwise.cc | 6 +-- src/relay/op/contrib/ethosu/common.cc | 37 +++++++++++++++++-- src/relay/op/contrib/ethosu/common.h | 18 +++++++-- src/relay/op/contrib/ethosu/convolution.cc | 9 +---- src/relay/op/contrib/ethosu/depthwise.cc | 4 +- src/relay/op/contrib/ethosu/pooling.cc | 9 +---- 6 files changed, 55 insertions(+), 28 deletions(-) diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index 8966540e0703..258c84c70660 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -152,12 +152,12 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type); if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { - std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8), - DataType::Int(16), DataType::Int(32)}; + std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8), + DataType::Int(16), DataType::Int(32)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type); } else if (operator_type == "MIN" || operator_type == "MAX") { - std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8)}; + std::initializer_list allowed_types = {DataType::Int(8), DataType::UInt(8)}; CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type); } else if (operator_type == "SHR") { diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index 8e705b66bcb5..a9fcd4301e81 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -101,11 +101,13 @@ DataType DataTypeFromString(const String& dtype) { } void CheckDataType(const TypeReporter& reporter, const DataType& data_type, - const std::unordered_set& allowed_data_types, + const std::initializer_list& allowed_data_types, const String& operator_name, const String& tensor_name, const String& operator_type) { - if (allowed_data_types.find(data_type) != allowed_data_types.end()) { - return; + for (const auto& i : allowed_data_types) { + if (data_type == i) { + return; + } } std::ostringstream message; @@ -126,6 +128,33 @@ void CheckDataType(const TypeReporter& reporter, const DataType& data_type, reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); } +void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method, + const std::initializer_list& allowed_upscale_methods, + const String& operator_name, const String& operator_type) { + for (const auto& i : allowed_upscale_methods) { + if (upscale_method == i) { + return; + } + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != "") { + message << operator_type << " "; + } + message << "to have upscale method in {"; + for (auto it = allowed_upscale_methods.begin(); it != allowed_upscale_methods.end(); ++it) { + message << *it; + if (std::next(it) != allowed_upscale_methods.end()) { + message << ", "; + } + } + message << "}"; + message << " but was " << upscale_method << "."; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, const DataType& data_type2, const String& operator_name, const String& tensor_name, const String& tensor_name2, @@ -136,7 +165,7 @@ void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, std::ostringstream message; message << "Invalid operator: expected " << operator_name << " "; - if (operator_type != " ") { + if (operator_type != "") { message << operator_type << " "; } message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was " diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index 9238a7db95b1..b993b78749ae 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -75,17 +75,29 @@ DataType DataTypeFromString(const String& dtype); /*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a * type inference error if not. * \param reporter The infer type reporter. - * \param data_type The data ntype to check. - * \param allowed_data_types An unordered set of allowed data types. + * \param data_type The data type to check. + * \param allowed_data_types An initializer list of allowed data types. * \param operator_name The name of the operator to report. * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. */ void CheckDataType(const TypeReporter& reporter, const DataType& data_type, - const std::unordered_set& allowed_data_types, + const std::initializer_list& allowed_data_types, const String& operator_name, const String& tensor_name, const String& operator_type = ""); +/*! \brief Check the upscale method matches one given in allowed_upscale_methods. Raise a type + * inference error if not. + * \param reporter The infer type reporter. + * \param upscale_method The upscale method string to check. + * \param allowed_upscale_methods An initializer list of allowed upscale methods. + * \param operator_name The name of the operator to report. + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method, + const std::initializer_list& allowed_upscale_methods, + const String& operator_name, const String& operator_type = ""); + /*! \brief Check the data type matches that of the second data type provided. Raise a type inference * error if not. * \param reporter The infer type reporter. diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 4d9541ced816..90bbf90d13c7 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -138,14 +138,7 @@ bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attr "weight"); CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); - const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; - if (upscale_methods.find(param->upscale) == upscale_methods.end()) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected upsample method to be 'NONE', " - "'ZEROS' or 'NEAREST' but got " - << param->upscale); - return false; - } + CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name); // The scale_bias should be provided as a tensor of size {ofm_channels, 10} reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8))); diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index abfe0e3856a1..b631f5b8e6a4 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -144,8 +144,8 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); - std::unordered_set ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), - DataType::Int(32)}; + std::initializer_list ofm_dtypes = {DataType::UInt(8), DataType::Int(8), + DataType::Int(16), DataType::Int(32)}; CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm"); // Collect the ifm, weight and ofm tensors for using in the inference function diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index d9861954ac98..3175e4ddffc4 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -135,14 +135,7 @@ bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& att CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", param->pooling_type); - const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; - if (upscale_methods.find(param->upscale) == upscale_methods.end()) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected upsample method to be 'NONE', " - "'ZEROS' or 'NEAREST' but got " - << param->upscale); - return false; - } + CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name); Array ifm_shape = ifm->shape; if (param->upscale != "NONE") {