Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add verif.bmc VerifToSMT lowering #7603

Merged
merged 11 commits into from
Sep 24, 2024
7 changes: 6 additions & 1 deletion include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,12 @@ def ConvertHWToSMT : Pass<"convert-hw-to-smt", "mlir::ModuleOp"> {

def ConvertVerifToSMT : Pass<"convert-verif-to-smt", "mlir::ModuleOp"> {
let summary = "Convert Verif ops to SMT ops";
let dependentDialects = ["smt::SMTDialect", "mlir::arith::ArithDialect"];
let dependentDialects = [
"smt::SMTDialect",
"mlir::arith::ArithDialect",
"mlir::scf::SCFDialect",
"mlir::func::FuncDialect"
];
}

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 3 additions & 1 deletion include/circt/Conversion/VerifToSMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
#include <memory>

namespace circt {
class Namespace;

#define GEN_PASS_DECL_CONVERTVERIFTOSMT
#include "circt/Conversion/Passes.h.inc"

/// Get the Verif to SMT conversion patterns.
void populateVerifToSMTConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
Namespace &names);

} // namespace circt

Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/HWToSMT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_circt_conversion_library(CIRCTHWToSMT
LINK_LIBS PUBLIC
CIRCTHW
CIRCTSMT
CIRCTSeq
MLIRFuncDialect
MLIRTransforms
MLIRTransformUtils
Expand Down
24 changes: 23 additions & 1 deletion lib/Conversion/HWToSMT/HWToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/SMT/SMTOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -92,6 +93,20 @@ struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
}
};

/// Remove redundant (seq::FromClock and seq::ToClock) ops.
template <typename OpTy>
struct ReplaceWithInput : OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;

LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -118,6 +133,9 @@ void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
return std::nullopt;
return smt::BitVectorType::get(type.getContext(), type.getWidth());
});
converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
return smt::BitVectorType::get(type.getContext(), 1);
});

// Default target materialization to convert from illegal types to legal
// types, e.g., at the boundary of an inlined child block.
Expand Down Expand Up @@ -200,12 +218,16 @@ void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
InstanceOpConversion>(converter, patterns.getContext());
InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
ReplaceWithInput<seq::FromClockOp>>(converter,
patterns.getContext());
}

void ConvertHWToSMTPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<hw::HWDialect>();
target.addIllegalOp<seq::FromClockOp>();
target.addIllegalOp<seq::ToClockOp>();
target.addLegalDialect<smt::SMTDialect>();
target.addLegalDialect<mlir::func::FuncDialect>();

Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/VerifToSMT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_circt_conversion_library(CIRCTVerifToSMT
CIRCTSMT
CIRCTVerif
MLIRArithDialect
MLIRSCFDialect
TaoBi22 marked this conversation as resolved.
Show resolved Hide resolved
MLIRTransforms
MLIRTransformUtils
MLIRReconcileUnrealizedCasts
Expand Down
244 changes: 237 additions & 7 deletions lib/Conversion/VerifToSMT/VerifToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
#include "circt/Conversion/VerifToSMT.h"
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/SMT/SMTOps.h"
#include "circt/Dialect/Seq/SeqTypes.h"
#include "circt/Dialect/Verif/VerifOps.h"
#include "circt/Support/Namespace.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"

namespace circt {
#define GEN_PASS_DEF_CONVERTVERIFTOSMT
Expand All @@ -30,7 +35,8 @@ using namespace hw;
//===----------------------------------------------------------------------===//

namespace {
/// Lower a verif::AssertOp operation with an i1 operand to a smt::AssertOp.
/// Lower a verif::AssertOp operation with an i1 operand to a smt::AssertOp,
/// negated to check for unsatisfiability.
TaoBi22 marked this conversation as resolved.
Show resolved Hide resolved
struct VerifAssertOpConversion : OpConversionPattern<verif::AssertOp> {
using OpConversionPattern<verif::AssertOp>::OpConversionPattern;

Expand All @@ -40,6 +46,22 @@ struct VerifAssertOpConversion : OpConversionPattern<verif::AssertOp> {
Value cond = typeConverter->materializeTargetConversion(
rewriter, op.getLoc(), smt::BoolType::get(getContext()),
adaptor.getProperty());
Value notCond = rewriter.create<smt::NotOp>(op.getLoc(), cond);
rewriter.replaceOpWithNewOp<smt::AssertOp>(op, notCond);
return success();
}
};

/// Lower a verif::AssumeOp operation with an i1 operand to a smt::AssertOp
struct VerifAssumeOpConversion : OpConversionPattern<verif::AssumeOp> {
using OpConversionPattern<verif::AssumeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(verif::AssumeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cond = typeConverter->materializeTargetConversion(
rewriter, op.getLoc(), smt::BoolType::get(getContext()),
adaptor.getProperty());
rewriter.replaceOpWithNewOp<smt::AssertOp>(op, cond);
return success();
}
Expand Down Expand Up @@ -146,6 +168,204 @@ struct LogicEquivalenceCheckingOpConversion
}
};

/// Lower a verif::BMCOp operation to an MLIR program that performs the bounded
/// model check
struct VerifBoundedModelCheckingOpConversion
: OpConversionPattern<verif::BoundedModelCheckingOp> {
using OpConversionPattern<verif::BoundedModelCheckingOp>::OpConversionPattern;

VerifBoundedModelCheckingOpConversion(TypeConverter &converter,
MLIRContext *context, Namespace &names)
: OpConversionPattern(converter, context), names(names) {}

LogicalResult
matchAndRewrite(verif::BoundedModelCheckingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
SmallVector<Type> oldLoopInputTy(op.getLoop().getArgumentTypes());
SmallVector<Type> oldCircuitInputTy(op.getCircuit().getArgumentTypes());
// TODO: the init and loop regions should be able to be concrete instead of
// symbolic which is probably preferable - just need to convert back and
// forth
SmallVector<Type> loopInputTy, circuitInputTy, initOutputTy,
circuitOutputTy;
if (failed(typeConverter->convertTypes(oldLoopInputTy, loopInputTy)))
return failure();
if (failed(typeConverter->convertTypes(oldCircuitInputTy, circuitInputTy)))
return failure();
if (failed(typeConverter->convertTypes(
op.getInit().front().back().getOperandTypes(), initOutputTy)))
return failure();
if (failed(typeConverter->convertTypes(
op.getCircuit().front().back().getOperandTypes(), circuitOutputTy)))
return failure();
if (failed(rewriter.convertRegionTypes(&op.getInit(), *typeConverter)))
return failure();
if (failed(rewriter.convertRegionTypes(&op.getLoop(), *typeConverter)))
return failure();
if (failed(rewriter.convertRegionTypes(&op.getCircuit(), *typeConverter)))
return failure();

unsigned numRegs =
cast<IntegerAttr>(op->getAttr("num_regs")).getValue().getZExtValue();

auto initFuncTy = rewriter.getFunctionType({}, initOutputTy);
// Loop and init output types are necessarily the same, so just use init
// output types
auto loopFuncTy = rewriter.getFunctionType(loopInputTy, initOutputTy);
auto circuitFuncTy =
rewriter.getFunctionType(circuitInputTy, circuitOutputTy);

func::FuncOp initFuncOp, loopFuncOp, circuitFuncOp;

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(
op->getParentOfType<ModuleOp>().getBody());
initFuncOp = rewriter.create<func::FuncOp>(loc, names.newName("bmc_init"),
initFuncTy);
rewriter.inlineRegionBefore(op.getInit(), initFuncOp.getFunctionBody(),
initFuncOp.end());
loopFuncOp = rewriter.create<func::FuncOp>(loc, names.newName("bmc_loop"),
loopFuncTy);
rewriter.inlineRegionBefore(op.getLoop(), loopFuncOp.getFunctionBody(),
loopFuncOp.end());
circuitFuncOp = rewriter.create<func::FuncOp>(
loc, names.newName("bmc_circuit"), circuitFuncTy);
rewriter.inlineRegionBefore(op.getCircuit(),
circuitFuncOp.getFunctionBody(),
circuitFuncOp.end());
auto funcOps = {&initFuncOp, &loopFuncOp, &circuitFuncOp};
// initOutputTy is the same as loop output types
auto outputTys = {initOutputTy, initOutputTy, circuitOutputTy};
for (auto [funcOp, outputTy] : llvm::zip(funcOps, outputTys)) {
auto operands = funcOp->getBody().front().back().getOperands();
rewriter.eraseOp(&funcOp->getFunctionBody().front().back());
rewriter.setInsertionPointToEnd(&funcOp->getBody().front());
SmallVector<Value> toReturn;
for (unsigned i = 0; i < outputTy.size(); ++i)
toReturn.push_back(typeConverter->materializeTargetConversion(
rewriter, loc, outputTy[i], operands[i]));
rewriter.create<func::ReturnOp>(loc, toReturn);
}
}

auto solver =
rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
rewriter.createBlock(&solver.getBodyRegion());

// Call init func to get initial clock values
ValueRange initVals =
rewriter.create<func::CallOp>(loc, initFuncOp)->getResults();

// InputDecls order should be <circuit arguments> <state arguments>
// <wasViolated>
// Get list of clock indexes in circuit args
size_t initIndex = 0, curIndex = 0;
SmallVector<Value> inputDecls;
SmallVector<int> clockIndexes;
for (auto [oldTy, newTy] : llvm::zip(oldCircuitInputTy, circuitInputTy)) {
if (isa<seq::ClockType>(oldTy)) {
inputDecls.push_back(initVals[initIndex++]);
clockIndexes.push_back(curIndex);
} else {
inputDecls.push_back(rewriter.create<smt::DeclareFunOp>(loc, newTy));
}
curIndex++;
TaoBi22 marked this conversation as resolved.
Show resolved Hide resolved
}

auto numStateArgs = initVals.size() - initIndex;
// Add the rest of the init vals (state args)
for (; initIndex < initVals.size(); ++initIndex)
inputDecls.push_back(initVals[initIndex]);

Value lowerBound =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
Value step =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
Value upperBound =
rewriter.create<arith::ConstantOp>(loc, adaptor.getBoundAttr());
Value constFalse =
rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
Value constTrue =
rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
inputDecls.push_back(constFalse); // wasViolated?

// TODO: swapping to a whileOp here would allow early exit once the property
// is violated
// Perform model check up to the provided bound
auto forOp = rewriter.create<scf::ForOp>(
loc, lowerBound, upperBound, step, inputDecls,
[&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
// Execute the circuit
ValueRange circuitCallOuts =
builder
.create<func::CallOp>(
loc, circuitFuncOp,
iterArgs.take_front(circuitFuncOp.getNumArguments()))
->getResults();
auto checkOp =
rewriter.create<smt::CheckOp>(loc, builder.getI1Type());
{
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&checkOp.getSatRegion());
builder.create<smt::YieldOp>(loc, constTrue);
builder.createBlock(&checkOp.getUnknownRegion());
builder.create<smt::YieldOp>(loc, constTrue);
builder.createBlock(&checkOp.getUnsatRegion());
builder.create<smt::YieldOp>(loc, constFalse);
}

Value violated = builder.create<arith::OrIOp>(
loc, checkOp.getResult(0), iterArgs.back());

// Call loop func to update clock & state arg values
SmallVector<Value> loopCallInputs;
// Fetch clock values to feed to loop
for (auto index : clockIndexes) {
loopCallInputs.push_back(iterArgs[index]);
}
// Fetch state args to feed to loop
for (auto stateArg : iterArgs.drop_back().take_back(numStateArgs)) {
loopCallInputs.push_back(stateArg);
}
TaoBi22 marked this conversation as resolved.
Show resolved Hide resolved
ValueRange loopVals =
builder.create<func::CallOp>(loc, loopFuncOp, loopCallInputs)
->getResults();

size_t loopIndex = 0;
// Collect decls to yield at end of iteration
SmallVector<Value> newDecls;
for (auto [oldTy, newTy] :
llvm::zip(TypeRange(oldCircuitInputTy).drop_back(numRegs),
TypeRange(circuitInputTy).drop_back(numRegs))) {
if (isa<seq::ClockType>(oldTy))
newDecls.push_back(loopVals[loopIndex++]);
else
newDecls.push_back(builder.create<smt::DeclareFunOp>(loc, newTy));
}
newDecls.append(
SmallVector<Value>(circuitCallOuts.take_back(numRegs)));

// Add the rest of the loop state args
for (; loopIndex < loopVals.size(); ++loopIndex)
newDecls.push_back(loopVals[loopIndex]);

newDecls.push_back(violated);

builder.create<scf::YieldOp>(loc, newDecls);
});

Value res = rewriter.create<arith::XOrIOp>(loc, forOp->getResults().back(),
constTrue);
rewriter.create<smt::YieldOp>(loc, res);
rewriter.replaceOp(op, solver.getResults());
return success();
}

Namespace &names;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -160,21 +380,31 @@ struct ConvertVerifToSMTPass
} // namespace

void circt::populateVerifToSMTConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<VerifAssertOpConversion, LogicEquivalenceCheckingOpConversion>(
converter, patterns.getContext());
RewritePatternSet &patterns,
Namespace &names) {
patterns.add<VerifAssertOpConversion, VerifAssumeOpConversion,
LogicEquivalenceCheckingOpConversion>(converter,
patterns.getContext());
patterns.add<VerifBoundedModelCheckingOpConversion>(
converter, patterns.getContext(), names);
}

void ConvertVerifToSMTPass::runOnOperation() {
ConversionTarget target(getContext());
target.addIllegalDialect<verif::VerifDialect>();
target.addLegalDialect<smt::SMTDialect, arith::ArithDialect>();
target.addLegalDialect<smt::SMTDialect, arith::ArithDialect, scf::SCFDialect,
func::FuncDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();

RewritePatternSet patterns(&getContext());
TypeConverter converter;
populateHWToSMTTypeConverter(converter);
populateVerifToSMTConversionPatterns(converter, patterns);

SymbolCache symCache;
symCache.addDefinitions(getOperation());
Namespace names;
names.add(symCache);
populateVerifToSMTConversionPatterns(converter, patterns, names);

if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down
Loading
Loading