Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assign each question mark a unique negative integer value #1757

Merged
merged 4 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/Dialect/Mlir/IndexExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ bool IndexExpr::canBeUsedInScope() const {

int64_t IndexExpr::getLiteral() const { return getObj().getLiteral(); }

int64_t IndexExpr::getQuestionmark() const {
return getObj().getQuestionmark();
}

AffineExpr IndexExpr::getAffineExpr() const { return getObj().getAffineExpr(); }

Value IndexExpr::getValue() const { return getObj().getValue(); }
Expand Down Expand Up @@ -399,15 +403,19 @@ void IndexExpr::debugPrint(
//===----------------------------------------------------------------------===//

/*static*/ void IndexExpr::getShape(SmallVectorImpl<IndexExpr> &indexExprList,
SmallVectorImpl<int64_t> &intDimList) {
SmallVectorImpl<int64_t> &intDimList, bool uniqueQuestionMark) {
intDimList.clear();
for (IndexExpr &expr : indexExprList) {
if (expr.isLiteral()) {
int64_t val = expr.getLiteral();
assert(val >= 0 && "expected positive values only");
intDimList.emplace_back(val);
} else
intDimList.emplace_back(-1);
} else {
if (uniqueQuestionMark)
intDimList.emplace_back(expr.getQuestionmark());
else
intDimList.emplace_back(-1);
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/Dialect/Mlir/IndexExpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ class IndexExpr {
mlir::OpBuilder &getRewriter() const { return getScope().getRewriter(); }
mlir::Location getLoc() const { return getScope().getLoc(); }
int64_t getLiteral() const;
int64_t getQuestionmark() const;
mlir::AffineExpr getAffineExpr() const;
void getAffineMapAndOperands(
mlir::AffineMap &map, llvm::SmallVectorImpl<mlir::Value> &operands) const;
Expand All @@ -457,7 +458,8 @@ class IndexExpr {
// the (list of) Shape/Value/OpFoldResult corresponding to the original (list
// of) IndexExpr.
static void getShape(llvm::SmallVectorImpl<IndexExpr> &indexExprList,
llvm::SmallVectorImpl<int64_t> &intDimList);
llvm::SmallVectorImpl<int64_t> &intDimList,
bool uniqueQuestionMark = false);
static void getValues(mlir::ArrayRef<IndexExpr> indexExprArray,
llvm::SmallVectorImpl<mlir::Value> &valueList);
static void getOpOrFoldResults(
Expand Down
11 changes: 9 additions & 2 deletions src/Dialect/Mlir/IndexExprDetail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"

int64_t IndexExpr_gQuestionMarkCounter = -2;

using namespace mlir;

namespace onnx_mlir {
Expand All @@ -48,8 +50,8 @@ void IndexExprImpl::initAsUndefined() {
}

void IndexExprImpl::initAsQuestionmark() {
init(/*isDefined*/ true, /*literal*/ false, IndexExprKind::Questionmark, 0,
AffineExpr(nullptr), Value(nullptr));
init(/*isDefined*/ true, /*literal*/ false, IndexExprKind::Questionmark,
IndexExpr_gQuestionMarkCounter++, AffineExpr(nullptr), Value(nullptr));
}

void IndexExprImpl::initAsLiteral(int64_t const val, const IndexExprKind kind) {
Expand Down Expand Up @@ -243,6 +245,11 @@ int64_t IndexExprImpl::getLiteral() const {
return intLit;
}

int64_t IndexExprImpl::getQuestionmark() const {
assert(isQuestionmark() && "expected a question mark index expression");
return intLit;
}

//===----------------------------------------------------------------------===//
// IndexExprExpr transformational getters.
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 5 additions & 1 deletion src/Dialect/Mlir/IndexExprDetail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include "src/Dialect/Mlir/IndexExpr.hpp"

extern int64_t IndexExpr_gQuestionMarkCounter;

namespace onnx_mlir {

// Implementation of the IndexExpr. In nearly all cases, the value described by
Expand Down Expand Up @@ -66,6 +68,7 @@ class IndexExprImpl {
mlir::Location getLoc() const { return getScope().getLoc(); }
IndexExprKind getKind() const;
int64_t getLiteral() const;
int64_t getQuestionmark() const;
mlir::AffineExpr getAffineExpr();
void getAffineMapAndOperands(
mlir::AffineMap &map, llvm::SmallVectorImpl<mlir::Value> &operands);
Expand All @@ -80,7 +83,8 @@ class IndexExprImpl {
bool literal;
// Type of IndexExpr. Literal are by default affine.
IndexExprKind kind;
// Integer value, valid when "literal" is true.
// Integer value, valid when "literal" or "question mark" is true. Negative
// value in case of question mark.
int64_t intLit;
// Affine expression, may be defined for literal, symbols, dims, or affine
// expr.
Expand Down