diff --git a/jlm/mlir/Makefile.sub b/jlm/mlir/Makefile.sub index d3403c616..0f6edd754 100644 --- a/jlm/mlir/Makefile.sub +++ b/jlm/mlir/Makefile.sub @@ -16,6 +16,7 @@ endif LIBMLIR_SRC = \ jlm/mlir/backend/mlirgen.cpp \ + jlm/mlir/frontend/rvsdggen.cpp \ .PHONY: libmlir-debug libmlir-debug: CXXFLAGS += $(CXXFLAGS_DEBUG) diff --git a/jlm/mlir/frontend/rvsdggen.cpp b/jlm/mlir/frontend/rvsdggen.cpp new file mode 100644 index 000000000..9f30a0a3a --- /dev/null +++ b/jlm/mlir/frontend/rvsdggen.cpp @@ -0,0 +1,237 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. + */ + +#ifdef MLIR_ENABLED + +#include "jlm/mlir/frontend/rvsdggen.hpp" + +#include +#include +#include +#include + +//#include +//#include +//#include +//#include +//#include +//#include + +#include "llvm/Support/raw_os_ostream.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" + +namespace jlm::mlirrvsdg +{ + +std::unique_ptr +RVSDGGen::readRvsdgMlir(const util::filepath &filePath) +{ + // Configer the parser + mlir::ParserConfig config = mlir::ParserConfig(Context_.get()); + // Variable for storing the result + std::unique_ptr block = std::make_unique(); + // Read the input file + auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config); + if (result.failed()) + { + throw util::error("Parsing MLIR input file failed."); + } + return block; +} + +std::unique_ptr +RVSDGGen::convertMlir(std::unique_ptr & block) +{ + // Create RVSDG module + std::string dataLayout; + std::string targetTriple; + util::filepath sourceFileName(""); + auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout); + + // Get the root region + auto & graph = rvsdgModule->Rvsdg(); + auto root = graph.root(); + + // Convert the MLIR into an RVSDG graph + convertBlock(*block.get(), *root); + + return rvsdgModule; +} + +void +RVSDGGen::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) +{ + for (mlir::Operation & op : block.getOperations()) + { + ::llvm::outs() << "- Current operation " + << op.getName() << "\n"; + + if (op.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName()) + { + convertOmega(op, rvsdgRegion); + } + else if (op.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName()) + { + convertLambda(op, rvsdgRegion); + } + else if (op.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName()) + { + auto constant = static_cast(&op); + + // Need the type to know its width + auto type = constant.getType(); + JLM_ASSERT(type.getTypeID() == mlir::IntegerType::getTypeID()); + auto *integerType = static_cast(&type); + + rvsdg::create_bitconstant(&rvsdgRegion, integerType->getWidth(), constant.value()); + } else { + ::llvm::outs() << "- Operation not implemented: " + << op.getName() << "\n"; +// throw util::error("Operation is not implemented:" + op.getName() + "\n"); + } + } +} + +void +RVSDGGen::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " - Converting region\n"; + for (mlir::Block &block : region.getBlocks()) + { + convertBlock(block, rvsdgRegion); + } +} + +void +RVSDGGen::convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " ** Converting Omega **\n"; + for (mlir::Value operand : omega.getOperands()) { + if (mlir::Operation *producer = operand.getDefiningOp()) { + ::llvm::outs() << " - Operand produced by operation '" + << producer->getName() << "'\n"; + } + else + { + // If there is no defining op, the Value is necessarily a Block + // argument. + auto blockArg = operand.cast(); + ::llvm::outs() << " - Operand produced by Block argument, number " + << blockArg.getArgNumber() << " " + << blockArg.getType() << "\n"; + } + } + + // TODO + // There should only exist one region in an omega + for (mlir::Region ®ion : omega.getRegions()) + { + convertRegion(region, rvsdgRegion); + } +} + +void +RVSDGGen::convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " ** Converting Lambda **\n"; + + // Get the name of the function + auto functionNameAttribute = lambda.getAttr(::llvm::StringRef("sym_name")); + JLM_ASSERT(functionNameAttribute != NULL); + mlir::StringAttr * functionName = static_cast(&functionNameAttribute); + ::llvm::outs() << "Function name: " << functionName->getValue().str() << "\n"; + + // A lambda node has only the function signature as the result + JLM_ASSERT(lambda.getNumResults() == 1); + + // Get the function signature + auto result = lambda.getResult(0).getType(); + ::llvm::outs() << " - Function signature: '" << result << "'\n"; + if (result.getTypeID() == mlir::rvsdg::LambdaRefType::getTypeID()) + { + std::vector arguments; + std::vector results; + mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast(&result); + for (auto argumentType : lambdaRefType->getParameterTypes()) + { + auto argument = convertType(argumentType); + ::llvm::outs() << " - Argument: '" << argument->debug_string() << "\n"; + arguments.push_back(argument); + } + for (auto returnType : lambdaRefType->getReturnTypes()) + { + auto result = convertType(returnType); + ::llvm::outs() << " - Result: '" << result->debug_string() << "\n"; + results.push_back(result); + } + llvm::FunctionType functionType(arguments, results); + auto rvsdgLambda = llvm::lambda::node::create( + &rvsdgRegion, + functionType, + functionName->getValue().str(), + llvm::linkage::external_linkage); + auto lambdaRegion = rvsdgLambda->subregion(); + + for (mlir::Value operand : lambda.getOperands()) { + if (mlir::Operation *producer = operand.getDefiningOp()) { + ::llvm::outs() << " - Operand produced by operation '" + << producer->getName() << "'\n"; + } + else + { + // If there is no defining op, the Value is necessarily a Block + // argument. + auto blockArg = operand.cast(); + ::llvm::outs() << " - Operand produced by Block argument, number " + << blockArg.getArgNumber() << " " + << blockArg.getType() << "\n"; + } + } + + for (mlir::Region ®ion : lambda.getRegions()) + { + convertRegion(region, *lambdaRegion); + } + + } + else + { + throw util::error("The result from lambda node is not a LambdaRefType\n"); + } +/* + FunctionType functionType({&vt}, {&vt}); + + auto lambda = lambda::node::create(rvsdgRegion, functionType, "f", linkage::external_linkage); + lambda->finalize({lambda->fctargument(0)}); + + std::vector functionArguments; + for (auto & argument : lambda->fctarguments()) + functionArguments.push_back(&argument); +*/ +} + +rvsdg::type * +RVSDGGen::convertType(mlir::Type & type) +{ + // TODO + // Fix memory leak + if (type.getTypeID() == mlir::IntegerType::getTypeID()) { + auto * intType = static_cast(&type); + return new rvsdg::bittype(intType->getWidth()); + } else if (type.getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID()) { + return new llvm::loopstatetype(); + } else if (type.getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID()) { + return new llvm::MemoryStateType(); + } else if (type.getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID()) { + return new llvm::iostatetype(); + } else { + throw util::error("Type conversion not implemented\n"); + } +} + +} // jlm::mlirrvsdg + +#endif // MLIR_ENABLED diff --git a/jlm/mlir/frontend/rvsdggen.hpp b/jlm/mlir/frontend/rvsdggen.hpp new file mode 100644 index 000000000..d55f2aa13 --- /dev/null +++ b/jlm/mlir/frontend/rvsdggen.hpp @@ -0,0 +1,93 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_MLIR_FRONTEND_RVSDGGEN_HPP +#define JLM_MLIR_FRONTEND_RVSDGGEN_HPP + +#ifdef MLIR_ENABLED + +//#include +//#include +//#include +//#include +//#include +//#include +//#include + +#include +#include + +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" + +namespace jlm::mlirrvsdg { + + class RVSDGGen final { + public: + RVSDGGen() { + Context_ = std::make_unique(); + // Load the RVSDG dialect + Context_->getOrLoadDialect(); + // Load the JLM dialect + Context_->getOrLoadDialect(); + // Load the Arith dialect + Context_->getOrLoadDialect(); + } + + RVSDGGen(const RVSDGGen &) = delete; + + RVSDGGen(RVSDGGen &&) = delete; + + RVSDGGen & + operator=(const RVSDGGen &) = delete; + + RVSDGGen & + operator=(RVSDGGen &&) = delete; + + std::unique_ptr + readRvsdgMlir( + const util::filepath &filePath); + + std::unique_ptr + convertMlir( + std::unique_ptr & block); + + private: + void + convertBlock( + mlir::Block & block, + rvsdg::region & rvsdgRegion); + + void + convertRegion( + mlir::Region & region, + rvsdg::region & rvsdgRegion); + + void + convertOmega( + mlir::Operation & omega, + rvsdg::region & rvsdgRegion); + + void + convertLambda( + mlir::Operation & lambda, + rvsdg::region & rvsdgRegion); + + rvsdg::type * + convertType( + mlir::Type & type); + + + std::unique_ptr Context_; + }; + +} // namespace jlm::mlirrvsdg + +#endif // MLIR_ENABLED + +#endif // JLM_MLIR_FRONTEND_RVSDGGEN_HPP \ No newline at end of file diff --git a/jlm/tooling/Command.cpp b/jlm/tooling/Command.cpp index 9839734da..8b6052578 100644 --- a/jlm/tooling/Command.cpp +++ b/jlm/tooling/Command.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -309,21 +310,12 @@ JlmOptCommand::ToString() const void JlmOptCommand::Run() const { - ::llvm::LLVMContext llvmContext; - auto llvmModule = ParseLlvmIrFile(CommandLineOptions_.GetInputFile(), llvmContext); + jlm::util::StatisticsCollector statisticsCollector(CommandLineOptions_.GetStatisticsCollectorSettings()); - auto interProceduralGraphModule = llvm::ConvertLlvmModule(*llvmModule); - - /* - * Dispose of Llvm module. It is no longer needed. - */ - llvmModule.reset(); - - jlm::util::StatisticsCollector statisticsCollector( - CommandLineOptions_.GetStatisticsCollectorSettings()); - - auto rvsdgModule = - llvm::ConvertInterProceduralGraphModule(*interProceduralGraphModule, statisticsCollector); + auto rvsdgModule = ParseInputFile( + CommandLineOptions_.GetInputFile(), + CommandLineOptions_.GetInputFormat(), + statisticsCollector); llvm::OptimizationSequence::CreateAndRun( *rvsdgModule, @@ -339,21 +331,63 @@ JlmOptCommand::Run() const statisticsCollector.PrintStatistics(); } -std::unique_ptr<::llvm::Module> -JlmOptCommand::ParseLlvmIrFile(const util::filepath & llvmIrFile, ::llvm::LLVMContext & llvmContext) - const +std::unique_ptr +JlmOptCommand::ParseInputFile( + const util::filepath & inputFile, + const JlmOptCommandLineOptions::InputFormat & inputFormat, + util::StatisticsCollector & statisticsCollector) + const { - ::llvm::SMDiagnostic diagnostic; - if (auto module = ::llvm::parseIRFile(llvmIrFile.to_str(), diagnostic, llvmContext)) + auto parseLlvmIrFile = [=] ( + const util::filepath & llvmIrFile, + util::StatisticsCollector & statisticsCollector) + -> std::unique_ptr { - return module; - } + ::llvm::LLVMContext llvmContext; + ::llvm::SMDiagnostic diagnostic; + if (auto llvmModule = ::llvm::parseIRFile(llvmIrFile.to_str(), diagnostic, llvmContext)) + { + auto interProceduralGraphModule = llvm::ConvertLlvmModule(*llvmModule); - std::string errors; - ::llvm::raw_string_ostream os(errors); - diagnostic.print(ProgramName_.c_str(), os); - throw util::error(errors); -} + // Dispose of Llvm module. It is no longer needed. + llvmModule.reset(); + + auto rvsdgModule = + llvm::ConvertInterProceduralGraphModule(*interProceduralGraphModule, statisticsCollector); + + return rvsdgModule; + } + + std::string errors; + ::llvm::raw_string_ostream os(errors); + diagnostic.print(ProgramName_.c_str(), os); + throw util::error(errors); + }; + + auto parseMlirIrFile = [] ( + const util::filepath & mlirIrFile, + util::StatisticsCollector & statisticsCollector) + -> std::unique_ptr + { +#ifdef MLIR_ENABLED + jlm::mlirrvsdg::RVSDGGen rvsdggen; + auto block = rvsdggen.readRvsdgMlir(mlirIrFile); + return rvsdggen.convertMlir(block); +#else + throw util::error("This version of jlm-opt has not been compiled with support for the MLIR backend\n"); +#endif + }; + + static std::unordered_map< + JlmOptCommandLineOptions::InputFormat, + std::function< + std::unique_ptr(const util::filepath &, util::StatisticsCollector &)>> + parsers({ { tooling::JlmOptCommandLineOptions::InputFormat::Llvm, parseLlvmIrFile }, + { tooling::JlmOptCommandLineOptions::InputFormat::Mlir, parseMlirIrFile }}); + + JLM_ASSERT(parsers.find(inputFormat) != parsers.end()); + return parsers[inputFormat](inputFile, statisticsCollector); +}; void JlmOptCommand::PrintRvsdgModule( diff --git a/jlm/tooling/Command.hpp b/jlm/tooling/Command.hpp index a3b9fc873..35954d513 100644 --- a/jlm/tooling/Command.hpp +++ b/jlm/tooling/Command.hpp @@ -370,8 +370,21 @@ class JlmOptCommand final : public Command } private: - std::unique_ptr<::llvm::Module> - ParseLlvmIrFile(const util::filepath & llvmIrFile, ::llvm::LLVMContext & llvmContext) const; + std::unique_ptr + ParseInputFile( + const util::filepath & inputFile, + const JlmOptCommandLineOptions::InputFormat & inputFormat, + util::StatisticsCollector & statisticsCollector) const; + + std::unique_ptr + ParseLlvmIrFile( + const util::filepath & llvmIrFile, + util::StatisticsCollector & statisticsCollector) const; + + std::unique_ptr + ParseMlirIrFile( + const util::filepath & rvsdgIrFile, + util::StatisticsCollector & statisticsCollector) const; static void PrintRvsdgModule( diff --git a/jlm/tooling/CommandGraphGenerator.cpp b/jlm/tooling/CommandGraphGenerator.cpp index e553f8da6..a44085731 100644 --- a/jlm/tooling/CommandGraphGenerator.cpp +++ b/jlm/tooling/CommandGraphGenerator.cpp @@ -125,6 +125,7 @@ JlcCommandGraphGenerator::GenerateCommandGraph(const JlcCommandLineOptions & com JlmOptCommandLineOptions jlmOptCommandLineOptions( CreateParserCommandOutputFile(compilation.InputFile()), + JlmOptCommandLineOptions::InputFormat::Llvm, CreateJlmOptCommandOutputFile(compilation.InputFile()), JlmOptCommandLineOptions::OutputFormat::Llvm, statisticsCollectorSettings, diff --git a/jlm/tooling/CommandLine.cpp b/jlm/tooling/CommandLine.cpp index c3c69825b..51a5bcd2b 100644 --- a/jlm/tooling/CommandLine.cpp +++ b/jlm/tooling/CommandLine.cpp @@ -270,6 +270,18 @@ JlmOptCommandLineOptions::ToCommandLineArgument(jlm::util::Statistics::Id statis throw util::error("Unknown statistics identifier"); } +const char * +JlmOptCommandLineOptions::ToCommandLineArgument(InputFormat inputFormat) +{ + static std::unordered_map map( + { { InputFormat::Llvm, "input-llvm" }, {InputFormat::Mlir, "input-mlir"} }); + + if (map.find(inputFormat) != map.end()) + return map[inputFormat]; + + throw util::error("Unknown input format"); +} + const char * JlmOptCommandLineOptions::ToCommandLineArgument(OutputFormat outputFormat) { @@ -871,6 +883,22 @@ JlmOptCommandLineParser::ParseCommandLineArguments(int argc, char ** argv) "Write theta-gamma inversion statistics to file.")), cl::desc("Write statistics")); + auto llvmInputFormat = JlmOptCommandLineOptions::InputFormat::Llvm; + auto mlirInputFormat = JlmOptCommandLineOptions::InputFormat::Mlir; + + cl::opt inputFormat( + cl::values( + ::clEnumValN( + llvmInputFormat, + JlmOptCommandLineOptions::ToCommandLineArgument(llvmInputFormat), + "Input LLVM IR [default]"), + ::clEnumValN( + mlirInputFormat, + JlmOptCommandLineOptions::ToCommandLineArgument(mlirInputFormat), + "Input MLIR")), + cl::init(llvmInputFormat), + cl::desc("Select input format")); + auto llvmOutputFormat = JlmOptCommandLineOptions::OutputFormat::Llvm; auto xmlOutputFormat = JlmOptCommandLineOptions::OutputFormat::Xml; auto mlirOutputFormat = JlmOptCommandLineOptions::OutputFormat::Mlir; @@ -1018,6 +1046,7 @@ JlmOptCommandLineParser::ParseCommandLineArguments(int argc, char ** argv) CommandLineOptions_ = JlmOptCommandLineOptions::Create( std::move(inputFilePath), + inputFormat, outputFile, outputFormat, std::move(statisticsCollectorSettings), diff --git a/jlm/tooling/CommandLine.hpp b/jlm/tooling/CommandLine.hpp index d750e9ee0..9056c350a 100644 --- a/jlm/tooling/CommandLine.hpp +++ b/jlm/tooling/CommandLine.hpp @@ -40,6 +40,12 @@ class optimization; class JlmOptCommandLineOptions final : public CommandLineOptions { public: + enum class InputFormat + { + Llvm, + Mlir, + }; + enum class OutputFormat { Llvm, @@ -116,11 +122,13 @@ class JlmOptCommandLineOptions final : public CommandLineOptions JlmOptCommandLineOptions( util::filepath inputFile, + InputFormat inputFormat, util::filepath outputFile, OutputFormat outputFormat, util::StatisticsCollectorSettings statisticsCollectorSettings, std::vector optimizations) : InputFile_(std::move(inputFile)), + InputFormat_(inputFormat), OutputFile_(std::move(outputFile)), OutputFormat_(outputFormat), StatisticsCollectorSettings_(std::move(statisticsCollectorSettings)), @@ -136,6 +144,12 @@ class JlmOptCommandLineOptions final : public CommandLineOptions return InputFile_; } + [[nodiscard]] InputFormat + GetInputFormat() const noexcept + { + return InputFormat_; + } + [[nodiscard]] const util::filepath & GetOutputFile() const noexcept { @@ -175,6 +189,9 @@ class JlmOptCommandLineOptions final : public CommandLineOptions static const char * ToCommandLineArgument(util::Statistics::Id statisticsId); + static const char * + ToCommandLineArgument(InputFormat inputFormat); + static const char * ToCommandLineArgument(OutputFormat outputFormat); @@ -184,6 +201,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions static std::unique_ptr Create( util::filepath inputFile, + InputFormat inputFormat, util::filepath outputFile, OutputFormat outputFormat, util::StatisticsCollectorSettings statisticsCollectorSettings, @@ -191,6 +209,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions { return std::make_unique( std::move(inputFile), + inputFormat, std::move(outputFile), outputFormat, std::move(statisticsCollectorSettings), @@ -199,6 +218,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions private: util::filepath InputFile_; + InputFormat InputFormat_; util::filepath OutputFile_; OutputFormat OutputFormat_; util::StatisticsCollectorSettings StatisticsCollectorSettings_; diff --git a/tests/jlm/tooling/TestJlmOptCommand.cpp b/tests/jlm/tooling/TestJlmOptCommand.cpp index a0b720399..648933d27 100644 --- a/tests/jlm/tooling/TestJlmOptCommand.cpp +++ b/tests/jlm/tooling/TestJlmOptCommand.cpp @@ -22,6 +22,7 @@ TestStatistics() JlmOptCommandLineOptions commandLineOptions( jlm::util::filepath("inputFile.ll"), + JlmOptCommandLineOptions::InputFormat::Llvm, jlm::util::filepath("outputFile.ll"), JlmOptCommandLineOptions::OutputFormat::Llvm, statisticsCollectorSettings,