Skip to content

Commit

Permalink
Fix issue llvm#1145 - error lowering onnx.Cast
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Tiotto <[email protected]>
  • Loading branch information
Ettore Tiotto authored and etiotto committed Feb 4, 2022
1 parent 6297a58 commit de92188
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
61 changes: 29 additions & 32 deletions src/Dialect/ONNX/MLIRDialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,45 +120,39 @@ Value MathBuilder::max(Value lhs, Value rhs) const {
}

Value MathBuilder::sgt(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::sgt);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::sgt);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::OGT);
}

Value MathBuilder::sge(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::sge);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::sge);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::OGE);
}

Value MathBuilder::slt(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::slt);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::slt);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::OLT);
}

Value MathBuilder::sle(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::sle);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLE, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::sle);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::OLE);
}

Value MathBuilder::eq(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::eq);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::eq);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::OEQ);
}

Value MathBuilder::neq(Value lhs, Value rhs) const {
assert(lhs.getType() == rhs.getType() && "expected same type");
if (lhs.getType().isa<IntegerType>() || lhs.getType().isa<IndexType>())
return createArithICmp(lhs, rhs, arith::CmpIPredicate::ne);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, lhs, rhs);
return createArithCmp(lhs, rhs, arith::CmpIPredicate::ne);
return createArithCmp(lhs, rhs, arith::CmpFPredicate::ONE);
}

Value MathBuilder::select(Value cmp, Value lhs, Value rhs) const {
Expand Down Expand Up @@ -206,21 +200,24 @@ Value MathBuilder::constantIndex(int64_t val) const {
return b.create<arith::ConstantOp>(loc, constantAttr);
}

Value MathBuilder::createArithICmp(
Value MathBuilder::createArithCmp(
Value lhs, Value rhs, arith::CmpIPredicate pred) const {
Type type = lhs.getType();
assert(type == rhs.getType() && "operands should have the same type");
assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
"Expecting IntegerType or IndexType");

if (type.isa<IndexType>())
return b.create<arith::CmpIOp>(loc, pred, lhs, rhs);

assert(type.isSignlessInteger() &&
"arith::CmpIOp requires signless integer types.");
assert(type == rhs.getType() && "Operands should have the same type");
assert(((type.isa<IntegerType>() && type.isSignlessInteger()) ||
type.isa<IndexType>()) &&
"Expecting a signless IntegerType or an IndexType");
return b.create<arith::CmpIOp>(loc, pred, lhs, rhs);
}

Value MathBuilder::createArithCmp(
Value lhs, Value rhs, arith::CmpFPredicate pred) const {
Type type = lhs.getType();
assert(type == rhs.getType() && "Operands should have the same type");
assert(type.isa<FloatType>() && "Expecting a FloatType");
return b.create<arith::CmpFOp>(loc, pred, lhs, rhs);
}

// Several operations in the arith dialect require signless integers. This
// cast remove the sign of integer types for successful processing, to the
// best of my understanding.
Expand Down Expand Up @@ -261,8 +258,8 @@ Value MathBuilder::cast(Type destType, Value src) const {
destIsIndex = true;
}

// Only support Integer or Float type at this stage. Index were transformed to
// signless int.
// Only support Integer or Float type at this stage. Index were transformed
// to signless int.
// TODO: add support for shaped tensor (MemRef, Vector, Tensor?) if needed.
assert((srcType.isa<IntegerType>() || srcType.isa<FloatType>()) &&
"support only float or int");
Expand Down Expand Up @@ -342,8 +339,8 @@ Value MathBuilder::cast(Type destType, Value src) const {
// Int to int conversion.
if (srcType.isa<IntegerType>() && destType.isa<IntegerType>()) {
if (srcType.isUnsignedInteger()) {
// Unsigned to unsigned conversion. Has to convert to signless first, and
// recovert output to unsigned.
// Unsigned to unsigned conversion. Has to convert to signless first,
// and recovert output to unsigned.
assert(destType.isUnsignedInteger() && "no unsigned/signed conversion");
assert((bitExtend || bitTrunc) && "expected extend or trunc");
Value cast = castToSignless(src, srcWidth);
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/MLIRDialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ struct MathBuilder final : DialectBuilder {
Value castToIndex(Value val) const;

private:
Value createArithICmp(Value lhs, Value rhs, arith::CmpIPredicate pred) const;
Value createArithCmp(Value lhs, Value rhs, arith::CmpIPredicate pred) const;
Value createArithCmp(Value lhs, Value rhs, arith::CmpFPredicate pred) const;
Value castToSignless(Value source, int64_t width) const;
Value castToUnsigned(Value source, int64_t width) const;
};
Expand Down

0 comments on commit de92188

Please sign in to comment.