Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Refactor type inference data type checks #10060

Merged
merged 3 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 19 additions & 83 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,98 +143,34 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr.";

String operator_type = param->operator_type;
auto ifm_dtype = ifm->dtype;
auto ifm2_dtype = ifm2->dtype;
DataType ofm_dtype;
const String operator_name = "ethosu_binary_elementwise";
const String operator_type = param->operator_type;
const DataType ifm_dtype = ifm->dtype;
const DataType ifm2_dtype = ifm2->dtype;
const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm_dtype != ifm2_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< "type for ifm2 be the same of ifm but was " << ifm2_dtype
<< " instead of " << ifm_dtype);
return false;
}
CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type);

if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
auto allowed_types = {DataType::Int(8), DataType::UInt(8), DataType::Int(16),
DataType::Int(32)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type);
} else if (operator_type == "MIN" || operator_type == "MAX") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) for ifm but was " << ifm_dtype);
return false;
}
if (ifm_dtype != ofm_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type
<< " type for ofm be the same of ifm but was " << ofm_dtype
<< " instead of " << ifm_dtype);
return false;
}
auto allowed_types = {DataType::Int(8), DataType::UInt(8)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type);
} else if (operator_type == "SHR") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)},
operator_name, "ofm", operator_type);
} else if (operator_type == "SHL") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);

return false;
}
if (ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ofm but was "
<< ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type);
} else {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or "
<< "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or "
<< "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type);
return false;
}
Expand Down
81 changes: 81 additions & 0 deletions src/relay/op/contrib/ethosu/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "common.h"

#include <sstream>

#include "../../op_common.h"

namespace tvm {
Expand Down Expand Up @@ -92,6 +94,85 @@ Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm
return new_ifm_shape;
}

DataType DataTypeFromString(const String& dtype) {
DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype);
return DataType(dl_dtype);
}

void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type) {
for (const auto& i : allowed_data_types) {
if (data_type == i) {
return;
}
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have type in {";
for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) {
message << *it;
if (std::next(it) != allowed_data_types.end()) {
message << ", ";
}
}
message << "}";
message << " for " << tensor_name << " but was " << data_type << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type) {
for (const auto& i : allowed_upscale_methods) {
if (upscale_method == i) {
return;
}
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have upscale method in {";
for (auto it = allowed_upscale_methods.begin(); it != allowed_upscale_methods.end(); ++it) {
message << *it;
if (std::next(it) != allowed_upscale_methods.end()) {
message << ", ";
}
}
message << "}";
message << " but was " << upscale_method << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type) {
if (data_type == data_type2) {
return;
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was "
<< data_type << " and " << data_type2;

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
46 changes: 46 additions & 0 deletions src/relay/op/contrib/ethosu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,52 @@ Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String ifm_
*/
Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm_layout);

/*! \brief Get data type from string representation.
* \param dtype Data type in lower case format followed by number of bits e.g. "int8".
*/
DataType DataTypeFromString(const String& dtype);

/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a
* type inference error if not.
* \param reporter The infer type reporter.
* \param data_type The data type to check.
* \param allowed_data_types An initializer list of allowed data types.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type = "");

/*! \brief Check the upscale method matches one given in allowed_upscale_methods. Raise a type
* inference error if not.
* \param reporter The infer type reporter.
* \param upscale_method The upscale method string to check.
* \param allowed_upscale_methods An initializer list of allowed upscale methods.
* \param operator_name The name of the operator to report.
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type = "");

/*! \brief Check the data type matches that of the second data type provided. Raise a type inference
* error if not.
* \param reporter The infer type reporter.
* \param data_type The data type to check.
* \param data_type2 The second data type to check.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param tensor_name2 The name of the second tensor to report e.g. "ifm2".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type = "");

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
35 changes: 6 additions & 29 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,37 +131,14 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (ifm == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<EthosuConv2DAttrs>();
CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
const String operator_name = "ethosu_conv2d";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected upsample method to be 'NONE', "
"'ZEROS' or 'NEAREST' but got "
<< param->upscale);
return false;
}
CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);

// The scale_bias should be provided as a tensor of size {ofm_channels, 10}
reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
Expand Down
50 changes: 8 additions & 42 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,50 +136,16 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";

DataType ofm_dtype;
const String operator_name = "ethosu_depthwise_conv2d";

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d output data type "
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
auto ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), DataType::Int(32)};
CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm");

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};
Expand Down
10 changes: 3 additions & 7 deletions src/relay/op/contrib/ethosu/identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,11 @@ bool EthosuIdentityRel(const Array<Type>& types, int num_inputs, const Attrs& at
if (ifm == nullptr) return false;

const auto* param = attrs.as<EthosuIdentityAttrs>();

ICHECK(param != nullptr) << "EthosuIdentityAttrs cannot be nullptr.";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected type(uint8) or type(int8) for ifm but was " << ifm->dtype);
return false;
}
const String operator_name = "ethosu_identity";

CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");

if (ifm->shape.size() > 4) {
reporter->GetDiagCtx().EmitFatal(
Expand Down
Loading