diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 11738c2363bf..07f262ca563c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -24,6 +24,7 @@ set(CIRCT_TEST_DEPENDS circt-capi-firtool-test circt-as circt-dis + circt-lec circt-opt circt-translate circt-reduce diff --git a/test/Tools/circt-lec/Inputs/a.mlir b/test/Tools/circt-lec/Inputs/a.mlir new file mode 100644 index 000000000000..61d1215bdaba --- /dev/null +++ b/test/Tools/circt-lec/Inputs/a.mlir @@ -0,0 +1,9 @@ +hw.module @foo(in %a : i8, out b : i8) { + %c1_i8 = hw.constant 1 : i8 + %add = comb.add %a, %c1_i8: i8 + hw.output %add : i8 +} +hw.module @top_a(in %a : i8, out b : i8) { + %foo.b = hw.instance "foo" @foo(a: %a: i8) -> (b: i8) + hw.output %foo.b : i8 +} diff --git a/test/Tools/circt-lec/Inputs/b.mlir b/test/Tools/circt-lec/Inputs/b.mlir new file mode 100644 index 000000000000..9067409c945b --- /dev/null +++ b/test/Tools/circt-lec/Inputs/b.mlir @@ -0,0 +1,9 @@ +hw.module @foo(in %a : i8, out b : i8) { + %c2_i8 = hw.constant 2 : i8 + %add = comb.add %a, %c2_i8: i8 + hw.output %add : i8 +} +hw.module @top_b(in %a : i8, out b : i8) { + %foo.b = hw.instance "foo" @foo(a: %a: i8) -> (b: i8) + hw.output %foo.b : i8 +} diff --git a/test/Tools/circt-lec/merge-inputs.mlir b/test/Tools/circt-lec/merge-inputs.mlir new file mode 100644 index 000000000000..97cae970d02f --- /dev/null +++ b/test/Tools/circt-lec/merge-inputs.mlir @@ -0,0 +1,18 @@ +// RUN: circt-lec %S/Inputs/a.mlir %S/Inputs/b.mlir --c1 top_a --c2 top_b --emit-mlir | FileCheck %s + +// Check that a.mlir and b.mlir are properly by comparing constants +// CHECK-LABEL: func.func @foo_0(%arg0: !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %c2_bv8 = smt.bv.constant #smt.bv<2> +// CHECK-NEXT: %0 = smt.bv.add %arg0, %c2_bv8 +// CHECK-NEXT: return %0 + +// CHECK-LABEL: func.func @foo(%arg0: !smt.bv<8>) -> !smt.bv<8> +// CHECK-NEXT: %c1_bv8 = smt.bv.constant #smt.bv<1> +// CHECK-NEXT: %0 = smt.bv.add %arg0, %c1_bv8 +// CHECK-NEXT: return %0 + +// CHECK-LABEL: func.func @top_a +// CHECK: %[[RESULT1:.+]] = func.call @foo(%[[ARG:.+]]) +// CHECK-NEXT: %[[RESULT2:.+]] = func.call @foo_0(%[[ARG]]) +// CHECK-NEXT: %[[VAL:.+]] = smt.distinct %[[RESULT1]], %[[RESULT2]] +// CHECK-NEXT: smt.assert %[[VAL]] diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 0340cc25f881..26b9a164dead 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -60,8 +60,8 @@ tools = [ 'arcilator', 'circt-as', 'circt-capi-ir-test', 'circt-capi-om-test', 'circt-capi-firrtl-test', 'circt-capi-firtool-test', 'circt-dis', - 'circt-reduce', 'circt-translate', 'firtool', 'hlstool', 'om-linker', - 'ibistool' + 'circt-lec', 'circt-reduce', 'circt-translate', 'firtool', 'hlstool', + 'om-linker', 'ibistool' ] if "CIRCT_OPT_CHECK_IR_ROUNDTRIP" in os.environ: diff --git a/tools/circt-lec/circt-lec.cpp b/tools/circt-lec/circt-lec.cpp index ec03638e7e2a..ce2b586d6ae3 100644 --- a/tools/circt-lec/circt-lec.cpp +++ b/tools/circt-lec/circt-lec.cpp @@ -71,9 +71,9 @@ static cl::opt secondModuleName( cl::desc("Specify a named module for the second circuit of the comparison"), cl::value_desc("module name"), cl::cat(mainCategory)); -static cl::opt inputFilename(cl::Positional, cl::Required, - cl::desc(""), - cl::cat(mainCategory)); +static cl::list inputFilenames(cl::Positional, cl::OneOrMore, + cl::desc(""), + cl::cat(mainCategory)); static cl::opt outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), @@ -122,6 +122,65 @@ static cl::opt outputFormat( // Tool implementation //===----------------------------------------------------------------------===// +// Move all operations in `src` to `dest`. Rename all symbols in `src` to avoid +// conflict. +static FailureOr mergeModules(ModuleOp dest, ModuleOp src, + StringAttr name) { + + SymbolTable destTable(dest), srcTable(src); + StringAttr renamed = {}; + for (auto &op : src.getOps()) { + if (SymbolOpInterface symbol = dyn_cast(op)) { + auto oldSymbol = symbol.getNameAttr(); + auto result = srcTable.renameToUnique(&op, {&destTable}); + if (failed(result)) + return src->emitError() << "failed to rename symbol " << oldSymbol; + + if (oldSymbol == name) { + assert(!renamed && "symbol must be unique"); + renamed = *result; + } + } + } + + if (!name) + return src->emitError() + << "module " << name << " was not found in the second module"; + + dest.getBody()->getOperations().splice(dest.getBody()->begin(), + src.getBody()->getOperations()); + return renamed; +} + +// Parse one or two MLIR modules and merge it into a single module. +static FailureOr> +parseAndMergeModules(MLIRContext &context, TimingScope &ts) { + auto parserTimer = ts.nest("Parse and merge MLIR input(s)"); + + if (inputFilenames.size() > 2) { + llvm::errs() << "more than 2 files are provided!\n"; + return failure(); + } + + auto module = parseSourceFile(inputFilenames[0], &context); + if (!module) + return failure(); + + if (inputFilenames.size() == 2) { + auto moduleOpt = parseSourceFile(inputFilenames[1], &context); + if (!moduleOpt) + return failure(); + auto result = mergeModules(module.get(), moduleOpt.get(), + StringAttr::get(&context, secondModuleName)); + if (failed(result)) + return failure(); + + secondModuleName.setValue(result->getValue().str()); + } + + return module; +} + /// This functions initializes the various components of the tool and /// orchestrates the work to be done. static LogicalResult executeLEC(MLIRContext &context) { @@ -130,15 +189,12 @@ static LogicalResult executeLEC(MLIRContext &context) { applyDefaultTimingManagerCLOptions(tm); auto ts = tm.getRootScope(); - OwningOpRef module; - { - auto parserTimer = ts.nest("Parse MLIR input"); - // Parse the provided input files. - module = parseSourceFile(inputFilename, &context); - } - if (!module) + auto parsedModule = parseAndMergeModules(context, ts); + if (failed(parsedModule)) return failure(); + OwningOpRef module = std::move(parsedModule.value()); + // Create the output directory or output file depending on our mode. std::optional> outputFile; std::string errorMessage;