Skip to content

Commit

Permalink
unordered_set -> initializer_list and use new format for upscale check
Browse files Browse the repository at this point in the history
Change-Id: Icf3d68d5cc7d5e1d5af42b1af193db89faea155e
  • Loading branch information
lhutton1 committed Feb 7, 2022
1 parent d731ac8 commit 3536c4e
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,12 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type);

if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
std::unordered_set<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8),
DataType::Int(16), DataType::Int(32)};
std::initializer_list<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8),
DataType::Int(16), DataType::Int(32)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type);
} else if (operator_type == "MIN" || operator_type == "MAX") {
std::unordered_set<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8)};
std::initializer_list<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type);
} else if (operator_type == "SHR") {
Expand Down
37 changes: 33 additions & 4 deletions src/relay/op/contrib/ethosu/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ DataType DataTypeFromString(const String& dtype) {
}

void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::unordered_set<DataType>& allowed_data_types,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type) {
if (allowed_data_types.find(data_type) != allowed_data_types.end()) {
return;
for (const auto& i : allowed_data_types) {
if (data_type == i) {
return;
}
}

std::ostringstream message;
Expand All @@ -126,6 +128,33 @@ void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type) {
for (const auto& i : allowed_upscale_methods) {
if (upscale_method == i) {
return;
}
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have upscale method in {";
for (auto it = allowed_upscale_methods.begin(); it != allowed_upscale_methods.end(); ++it) {
message << *it;
if (std::next(it) != allowed_upscale_methods.end()) {
message << ", ";
}
}
message << "}";
message << " but was " << upscale_method << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
Expand All @@ -136,7 +165,7 @@ void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != " ") {
if (operator_type != "") {
message << operator_type << " ";
}
message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was "
Expand Down
18 changes: 15 additions & 3 deletions src/relay/op/contrib/ethosu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,29 @@ DataType DataTypeFromString(const String& dtype);
/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a
* type inference error if not.
* \param reporter The infer type reporter.
* \param data_type The data ntype to check.
* \param allowed_data_types An unordered set of allowed data types.
* \param data_type The data type to check.
* \param allowed_data_types An initializer list of allowed data types.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::unordered_set<DataType>& allowed_data_types,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type = "");

/*! \brief Check the upscale method matches one given in allowed_upscale_methods. Raise a type
* inference error if not.
* \param reporter The infer type reporter.
* \param upscale_method The upscale method string to check.
* \param allowed_upscale_methods An initializer list of allowed upscale methods.
* \param operator_name The name of the operator to report.
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type = "");

/*! \brief Check the data type matches that of the second data type provided. Raise a type inference
* error if not.
* \param reporter The infer type reporter.
Expand Down
9 changes: 1 addition & 8 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,7 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected upsample method to be 'NONE', "
"'ZEROS' or 'NEAREST' but got "
<< param->upscale);
return false;
}
CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);

// The scale_bias should be provided as a tensor of size {ofm_channels, 10}
reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
std::unordered_set<DataType> ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16),
DataType::Int(32)};
std::initializer_list<DataType> ofm_dtypes = {DataType::UInt(8), DataType::Int(8),
DataType::Int(16), DataType::Int(32)};
CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm");

// Collect the ifm, weight and ofm tensors for using in the inference function
Expand Down
9 changes: 1 addition & 8 deletions src/relay/op/contrib/ethosu/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,7 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
param->pooling_type);

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected upsample method to be 'NONE', "
"'ZEROS' or 'NEAREST' but got "
<< param->upscale);
return false;
}
CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);

Array<IndexExpr> ifm_shape = ifm->shape;
if (param->upscale != "NONE") {
Expand Down

0 comments on commit 3536c4e

Please sign in to comment.