Skip to content

Commit

Permalink
Migrate from Std to Func dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
aqjune committed Mar 28, 2022
1 parent d04da45 commit 23832af
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 40 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ option(USE_LIBC "Use libc++ in case the MLIR (and CVC5) is linked against libc++
set(MLIR_INC_DIR "${MLIR_DIR}/include")
set(MLIR_LIB_DIR "${MLIR_DIR}/lib")
set(Z3_INC_DIR "${Z3_DIR}/include")
set(Z3_LIB_DIR "${Z3_DIR}/lib")
set(Z3_LIB_DIR "${Z3_DIR}/bin")
set(CVC5_INC_DIR "${CVC5_DIR}/include")
set(CVC5_LIB_DIR "${CVC5_DIR}/lib")

Expand Down Expand Up @@ -107,7 +107,7 @@ else()
MLIRIR MLIRDialect MLIRDialectUtils MLIRLinalg MLIRAffine MLIRMemRef
MLIRShape MLIRMath MLIRSparseTensor MLIRSCF MLIRArithmetic MLIRBufferization
MLIRComplex MLIRArithmeticUtils MLIRQuantUtils
MLIRStandard MLIRMemRefUtils MLIRTensor MLIRTosa MLIRQuant MLIRParser MLIRSupport MLIRControlFlow
MLIRFunc MLIRMemRefUtils MLIRTensor MLIRTosa MLIRQuant MLIRParser MLIRSupport MLIRControlFlow
LLVMSupport LLVMDemangle pthread m curses)
if (APPLE) # Apple LLD does not support 'group' flags
target_link_libraries(${PROJECT_LIB} PUBLIC ${LIB_LIST})
Expand Down
2 changes: 1 addition & 1 deletion src/analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ void analyzeBlock(
}
}

AnalysisResult analyze(mlir::FuncOp &fn) {
AnalysisResult analyze(mlir::func::FuncOp &fn) {
AnalysisResult res;

auto &region = fn.getRegion();
Expand Down
3 changes: 2 additions & 1 deletion src/analysis.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "utils.h"
Expand Down Expand Up @@ -29,4 +30,4 @@ struct AnalysisResult {
bool isElementwiseFPOps = true;
};

AnalysisResult analyze(mlir::FuncOp &fn);
AnalysisResult analyze(mlir::func::FuncOp &fn);
19 changes: 10 additions & 9 deletions src/encode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
Expand Down Expand Up @@ -813,12 +813,6 @@ void encodeOp(State &st, mlir::AffineApplyOp op, bool) {
st.regs.add(op, Index(move(*res)));
}

template<>
void encodeOp(State &st, mlir::ReturnOp op, bool) {
for (unsigned i = 0; i < op.getNumOperands(); ++i)
st.retValues.push_back(st.regs.findOrCrash(op.getOperand(i)));
}

template<>
void encodeOp(State &st, mlir::arith::SelectOp op, bool) {
auto condTy = op.getCondition().getType();
Expand Down Expand Up @@ -888,6 +882,12 @@ void encodeOp(State &st, mlir::arith::SelectOp op, bool) {
}
}

// Encoder specialization for the Func-dialect return op
// (mlir::func::ReturnOp, which replaces the removed mlir::ReturnOp
// handler elsewhere in this commit): pushes the encoded value of each
// return operand onto the state's retValues list.
template<>
void encodeOp(State &st, mlir::func::ReturnOp op, bool) {
// findOrCrash fetches the expression previously registered for the
// operand; NOTE(review): it presumably aborts/throws if the operand's
// producer was never encoded — confirm against RegFile's definition.
for (unsigned i = 0; i < op.getNumOperands(); ++i)
st.retValues.push_back(st.regs.findOrCrash(op.getOperand(i)));
}

template<>
void encodeOp(State &st, mlir::shape::ShapeOfOp op, bool) {
if (!op.getType().isa<mlir::TensorType>())
Expand Down Expand Up @@ -3388,7 +3388,6 @@ static void encodeBlock(
// Encode ops. Alphabetically sorted.
ENCODE(st, op, mlir::AffineApplyOp, encodeMemWriteOps);
ENCODE(st, op, mlir::arith::SelectOp, encodeMemWriteOps);
ENCODE(st, op, mlir::ReturnOp, encodeMemWriteOps);

ENCODE(st, op, mlir::arith::AddFOp, encodeMemWriteOps);
ENCODE(st, op, mlir::arith::AddIOp, encodeMemWriteOps);
Expand All @@ -3414,6 +3413,8 @@ static void encodeBlock(
ENCODE(st, op, mlir::bufferization::ToMemrefOp, encodeMemWriteOps);
ENCODE(st, op, mlir::bufferization::ToTensorOp, encodeMemWriteOps);

ENCODE(st, op, mlir::func::ReturnOp, encodeMemWriteOps);

ENCODE(st, op, mlir::math::AbsOp, encodeMemWriteOps);
ENCODE(st, op, mlir::math::ExpOp, encodeMemWriteOps);

Expand Down Expand Up @@ -3495,7 +3496,7 @@ static void encodeBlock(
llvm::outs() << "\n";
}

void encode(State &st, mlir::FuncOp &fn, bool printOps) {
void encode(State &st, mlir::func::FuncOp &fn, bool printOps) {
auto &region = fn.getRegion();
if (!llvm::hasSingleElement(region))
throw UnsupportedException(
Expand Down
3 changes: 2 additions & 1 deletion src/encode.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#pragma once

#include "state.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include <optional>
#include <string>

// encode can throw UnsupportedException.
void encode(State &st, mlir::FuncOp &fn, bool printOps);
void encode(State &st, mlir::func::FuncOp &fn, bool printOps);
9 changes: 4 additions & 5 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include <string>

Expand Down Expand Up @@ -69,13 +68,13 @@ static unsigned validateBuffer(unique_ptr<llvm::MemoryBuffer> srcBuffer,
src_sourceMgr.AddNewSourceBuffer(move(srcBuffer), llvm::SMLoc());
tgt_sourceMgr.AddNewSourceBuffer(move(tgtBuffer), llvm::SMLoc());

auto ir_before = parseSourceFile(src_sourceMgr, context);
auto ir_before = parseSourceFile<ModuleOp>(src_sourceMgr, context);
if (!ir_before) {
llvm::errs() << "Cannot parse source file\n";
return 81;
}

auto ir_after = parseSourceFile(tgt_sourceMgr, context);
auto ir_after = parseSourceFile<ModuleOp>(tgt_sourceMgr, context);
if (!ir_after) {
llvm::errs() << "Cannot parse target file\n";
return 82;
Expand Down Expand Up @@ -110,10 +109,10 @@ int main(int argc, char* argv[]) {
DialectRegistry registry;
// NOTE: we cannot use mlir::registerAllDialects because IREE does not have
// dependency on some of those dialects
registry.insert<StandardOpsDialect>();
registry.insert<AffineDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<bufferization::BufferizationDialect>();
registry.insert<func::FuncDialect>();
registry.insert<linalg::LinalgDialect>();
registry.insert<math::MathDialect>();
registry.insert<memref::MemRefDialect>();
Expand Down
12 changes: 7 additions & 5 deletions src/print.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "abstractops.h"
#include "opts.h"
#include "print.h"
Expand All @@ -16,7 +18,7 @@ static string intToStr(Expr e) {
}
}

static void printInputs(Model m, mlir::FuncOp src, const State &st_src) {
static void printInputs(Model m, mlir::func::FuncOp src, const State &st_src) {
unsigned n = src.getNumArguments();
for (unsigned i = 0; i < n; ++i) {
auto argsrc = src.getArgument(i);
Expand Down Expand Up @@ -66,7 +68,7 @@ static Expr evalFromModel(Model m, Expr e) {
return wb;
}

void printOperations(Model m, mlir::FuncOp fn, const State &st) {
void printOperations(Model m, mlir::func::FuncOp fn, const State &st) {
for (auto &op: fn.getRegion().front()) {
llvm::outs() << "\t" << op << "\n";

Expand Down Expand Up @@ -94,8 +96,8 @@ void printOperations(Model m, mlir::FuncOp fn, const State &st) {
}

void printCounterEx(
Model m, const vector<Expr> &params, mlir::FuncOp src,
mlir::FuncOp tgt, const State &st_src, const State &st_tgt,
Model m, const vector<Expr> &params, mlir::func::FuncOp src,
mlir::func::FuncOp tgt, const State &st_src, const State &st_tgt,
VerificationStep step, unsigned retvalidx, optional<mlir::Type> memElemTy) {
llvm::outs() << "<Inputs>\n";
printInputs(m, src, st_src);
Expand All @@ -108,7 +110,7 @@ void printCounterEx(


if (step == VerificationStep::RetValue) {
if (src.getType().getResult(retvalidx).isa<mlir::TensorType>()) {
if (src.getResultTypes()[retvalidx].isa<mlir::TensorType>()) {
llvm::outs() << "\n<Returned tensor>\n";

auto t_src = get<Tensor>(st_src.retValues[retvalidx]).eval(m);
Expand Down
8 changes: 5 additions & 3 deletions src/print.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"

#include "smt.h"
#include "state.h"
#include "vcgen.h"

#include <vector>
#include "mlir/IR/BuiltinOps.h"

void printOperations(smt::Model m, mlir::FuncOp fn, const State &st);
void printOperations(smt::Model m, mlir::func::FuncOp fn, const State &st);

void printCounterEx(
smt::Model model, const std::vector<smt::Expr> &params,
mlir::FuncOp src, mlir::FuncOp tgt,
mlir::func::FuncOp src, mlir::func::FuncOp tgt,
const State &st_src, const State &st_tgt,
VerificationStep step, unsigned retvalidx = -1,
std::optional<mlir::Type> memElemTy = std::nullopt);
21 changes: 11 additions & 10 deletions src/vcgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Defer {

class ValidationInput {
public:
mlir::FuncOp src, tgt;
mlir::func::FuncOp src, tgt;
string dumpSMTPath;

TypeMap<size_t> numBlocksPerType;
Expand Down Expand Up @@ -147,7 +147,7 @@ llvm::cl::opt<string> arg_verify_fn_name("compare-fn-name",
llvm::cl::cat(MlirTvCategory));

static optional<string> checkFunctionSignatures(
mlir::FuncOp src, mlir::FuncOp tgt) {
mlir::func::FuncOp src, mlir::func::FuncOp tgt) {
if (src.getNumArguments() != tgt.getNumArguments())
return "The source and target program have different number of arguments.";

Expand All @@ -167,7 +167,7 @@ static optional<string> checkFunctionSignatures(
}

static State createInputState(
mlir::FuncOp fn, std::unique_ptr<Memory> &&initMem,
mlir::func::FuncOp fn, std::unique_ptr<Memory> &&initMem,
ArgInfo &args, vector<Expr> &preconds) {
State s(move(initMem));
unsigned n = fn.getNumArguments();
Expand Down Expand Up @@ -294,8 +294,8 @@ static Results checkRefinement(
const ValidationInput &vinput,
const State &st_src, const State &st_tgt, Expr &&precond,
bool useAllLogic, int64_t &elapsedMillisec) {
mlir::FuncOp src = vinput.src;
mlir::FuncOp tgt = vinput.tgt;
mlir::func::FuncOp src = vinput.src;
mlir::func::FuncOp tgt = vinput.tgt;
auto fnname = src.getName().str();

auto printErrorMsg = [&](Solver &s, CheckResult res, const char *msg,
Expand Down Expand Up @@ -433,7 +433,7 @@ static void printUnsupported(const UnsupportedException &ue) {
static State encodeFinalState(
const ValidationInput &vinput, unique_ptr<Memory> &&initMem,
bool printOps, bool issrc, ArgInfo &args, vector<Expr> &preconds) {
mlir::FuncOp fn = issrc ? vinput.src : vinput.tgt;
mlir::func::FuncOp fn = issrc ? vinput.src : vinput.tgt;

State st = createInputState(fn, move(initMem), args, preconds);

Expand Down Expand Up @@ -504,7 +504,7 @@ static Results tryValidation(
static void checkIsSrcAlwaysUB(
const ValidationInput &vinput, bool wasSuccess, bool useAllLogic,
int64_t &elapsedMillisec) {
mlir::FuncOp src = vinput.src;
mlir::func::FuncOp src = vinput.src;
string fnname = src.getName().str();

// Set the abstract level to be as concrete as possible because we may not
Expand Down Expand Up @@ -745,9 +745,10 @@ static vector<mlir::memref::GlobalOp> mergeGlobals(
Results validate(
mlir::OwningOpRef<mlir::ModuleOp> &src,
mlir::OwningOpRef<mlir::ModuleOp> &tgt) {
map<llvm::StringRef, mlir::FuncOp> srcfns, tgtfns;
auto fillFns = [](map<llvm::StringRef, mlir::FuncOp> &m, mlir::Operation &op) {
auto fnop = mlir::dyn_cast<mlir::FuncOp>(op);
map<llvm::StringRef, mlir::func::FuncOp> srcfns, tgtfns;
auto fillFns = [](map<llvm::StringRef, mlir::func::FuncOp> &m,
mlir::Operation &op) {
auto fnop = mlir::dyn_cast<mlir::func::FuncOp>(op);
if (fnop && !fnop.isDeclaration()) {
m[fnop.getName()] = fnop;
}
Expand Down
6 changes: 3 additions & 3 deletions unittests/state_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
}
)"";

std::vector<mlir::FuncOp> parseIR(std::string IR, mlir::MLIRContext *ctx) {
std::vector<mlir::func::FuncOp> parseIR(std::string IR, mlir::MLIRContext *ctx) {
src = mlir::parseSourceString(sourceIR, ctx);
std::vector<mlir::FuncOp> srcFns;
llvm::for_each(*src, [&](auto &op) { srcFns.push_back(mlir::dyn_cast<mlir::FuncOp>(op)); });
std::vector<mlir::func::FuncOp> srcFns;
llvm::for_each(*src, [&](auto &op) { srcFns.push_back(mlir::dyn_cast<mlir::func::FuncOp>(op)); });
return srcFns;
}

Expand Down

0 comments on commit 23832af

Please sign in to comment.