Skip to content

Commit

Permalink
notifyMatchFailure for tosa.floor
Browse files Browse the repository at this point in the history
Signed-off-by: Philipp Braun <[email protected]>
  • Loading branch information
philippb-amd committed Jul 22, 2022
1 parent c334758 commit 49ad859
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
13 changes: 1 addition & 12 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@ using namespace mlir;

namespace onnx_mlir {

static bool isSignedInt(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && intType.isSigned() &&
(intWidth.find(intType.getWidth()) != intWidth.end());
}

static bool isFloat(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

void populateONNXToTOSAConversionPattern(ConversionTarget &target,
RewritePatternSet &patterns, TypeConverter &typeConverter,
MLIRContext *ctx) {
Expand Down Expand Up @@ -65,7 +54,7 @@ void FrontendToTosaLoweringPass::runOnOperation() {
// conversion failures. Quantized types are not supported right now.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) -> Optional<Type> {
if (isSignedInt(type) || isFloat(type))
if (isTOSASignedInt(type) || isTOSAFloat(type))
return type;
return llvm::None;
});
Expand Down
21 changes: 20 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// This file lowers ONNX element-wise operators to TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/TypeUtilities.h"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;
Expand All @@ -32,6 +32,25 @@ class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern<ONNXOpT> {
}
};

template <typename ONNXOpT>
class ONNXUnaryOpLoweringToTOSA<ONNXOpT, tosa::FloorOp>
: public OpConversionPattern<ONNXOpT> {
public:
using OpConversionPattern<ONNXOpT>::OpConversionPattern;
using OpAdaptor = typename ONNXOpT::Adaptor;
LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto scalarType = getElementTypeOrSelf(adaptor.X());
if (!isTOSAFloat(scalarType))
return rewriter.notifyMatchFailure(
op, "`tosa.floor` only supports float types");

rewriter.replaceOpWithNewOp<tosa::FloorOp>(op, op.getType(), adaptor.X());
return success();
}
};

class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down
15 changes: 15 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// Check for valid TOSA types.
//===----------------------------------------------------------------------===//

inline bool isTOSASignedInt(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && intType.isSigned() &&
(intWidth.find(intType.getWidth()) != intWidth.end());
}

inline bool isTOSAFloat(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 49ad859

Please sign in to comment.