From 945976d9465c9dc0ba90472c865bd3da546098e8 Mon Sep 17 00:00:00 2001 From: Magnus Sjalander Date: Fri, 13 Oct 2023 16:44:50 +0200 Subject: [PATCH] mlir: Initial MLIR/RVSDG frontend --- jlm/mlir/Makefile.sub | 1 + jlm/mlir/frontend/rvsdggen.cpp | 245 ++++++++++++++++++++++++ jlm/mlir/frontend/rvsdggen.hpp | 84 ++++++++ jlm/tooling/Command.cpp | 81 +++++--- jlm/tooling/Command.hpp | 17 +- jlm/tooling/CommandGraphGenerator.cpp | 1 + jlm/tooling/CommandLine.cpp | 29 +++ jlm/tooling/CommandLine.hpp | 20 ++ tests/jlm/tooling/TestJlmOptCommand.cpp | 1 + 9 files changed, 453 insertions(+), 26 deletions(-) create mode 100644 jlm/mlir/frontend/rvsdggen.cpp create mode 100644 jlm/mlir/frontend/rvsdggen.hpp diff --git a/jlm/mlir/Makefile.sub b/jlm/mlir/Makefile.sub index d3403c616..0f6edd754 100644 --- a/jlm/mlir/Makefile.sub +++ b/jlm/mlir/Makefile.sub @@ -16,6 +16,7 @@ endif LIBMLIR_SRC = \ jlm/mlir/backend/mlirgen.cpp \ + jlm/mlir/frontend/rvsdggen.cpp \ .PHONY: libmlir-debug libmlir-debug: CXXFLAGS += $(CXXFLAGS_DEBUG) diff --git a/jlm/mlir/frontend/rvsdggen.cpp b/jlm/mlir/frontend/rvsdggen.cpp new file mode 100644 index 000000000..15c9af786 --- /dev/null +++ b/jlm/mlir/frontend/rvsdggen.cpp @@ -0,0 +1,245 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. 
+ */ + +#ifdef MLIR_ENABLED + +#include "jlm/mlir/frontend/rvsdggen.hpp" + +#include +#include +#include +#include + +// #include +// #include +// #include +// #include +// #include +// #include + +#include "llvm/Support/raw_os_ostream.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" + +namespace jlm::mlirrvsdg +{ + +std::unique_ptr +RVSDGGen::readRvsdgMlir(const util::filepath & filePath) +{ + // Configure the parser + mlir::ParserConfig config = mlir::ParserConfig(Context_.get()); + // Variable for storing the result + std::unique_ptr block = std::make_unique(); + // Read the input file + auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config); + if (result.failed()) + { + throw util::error("Parsing MLIR input file failed."); + } + return block; +} + +std::unique_ptr +RVSDGGen::convertMlir(std::unique_ptr & block) +{ + // Create RVSDG module + std::string dataLayout; + std::string targetTriple; + util::filepath sourceFileName(""); + auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout); + + // Get the root region + auto & graph = rvsdgModule->Rvsdg(); + auto root = graph.root(); + + // Convert the MLIR into an RVSDG graph + convertBlock(*block.get(), *root); + + return rvsdgModule; +} + +void +RVSDGGen::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) +{ + for (mlir::Operation & op : block.getOperations()) + { + ::llvm::outs() << "- Current operation " << op.getName() << "\n"; + + if (op.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName()) + { + convertOmega(op, rvsdgRegion); + } + else if (op.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName()) + { + convertLambda(op, rvsdgRegion); + } + else if (op.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName()) + { + auto constant = static_cast(&op); + + // Need the type to know its width + auto type = constant.getType(); + JLM_ASSERT(type.getTypeID() == 
mlir::IntegerType::getTypeID()); + auto * integerType = static_cast(&type); + + rvsdg::create_bitconstant(&rvsdgRegion, integerType->getWidth(), constant.value()); + } + else + { + ::llvm::outs() << "- Operation not implemented: " << op.getName() << "\n"; + // throw util::error("Operation is not implemented:" + op.getName() + "\n"); + } + } +} + +void +RVSDGGen::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " - Converting region\n"; + for (mlir::Block & block : region.getBlocks()) + { + convertBlock(block, rvsdgRegion); + } +} + +void +RVSDGGen::convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " ** Converting Omega **\n"; + for (mlir::Value operand : omega.getOperands()) + { + if (mlir::Operation * producer = operand.getDefiningOp()) + { + ::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n"; + } + else + { + // If there is no defining op, the Value is necessarily a Block + // argument. 
+ auto blockArg = operand.cast(); + ::llvm::outs() << " - Operand produced by Block argument, number " << blockArg.getArgNumber() + << " " << blockArg.getType() << "\n"; + } + } + + // TODO + // There should only exist one region in an omega + for (mlir::Region & region : omega.getRegions()) + { + convertRegion(region, rvsdgRegion); + } +} + +void +RVSDGGen::convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion) +{ + ::llvm::outs() << " ** Converting Lambda **\n"; + + // Get the name of the function + auto functionNameAttribute = lambda.getAttr(::llvm::StringRef("sym_name")); + JLM_ASSERT(functionNameAttribute != nullptr); + mlir::StringAttr * functionName = static_cast(&functionNameAttribute); + ::llvm::outs() << "Function name: " << functionName->getValue().str() << "\n"; + + // A lambda node has only the function signature as the result + JLM_ASSERT(lambda.getNumResults() == 1); + + // Get the function signature + auto result = lambda.getResult(0).getType(); + ::llvm::outs() << " - Function signature: '" << result << "'\n"; + if (result.getTypeID() == mlir::rvsdg::LambdaRefType::getTypeID()) + { + std::vector arguments; + std::vector results; + mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast(&result); + for (auto argumentType : lambdaRefType->getParameterTypes()) + { + auto argument = convertType(argumentType); + ::llvm::outs() << " - Argument: '" << argument->debug_string() << "'\n"; + arguments.push_back(argument); + } + for (auto returnType : lambdaRefType->getReturnTypes()) + { + auto result = convertType(returnType); + ::llvm::outs() << " - Result: '" << result->debug_string() << "'\n"; + results.push_back(result); + } + llvm::FunctionType functionType(arguments, results); + auto rvsdgLambda = llvm::lambda::node::create( + &rvsdgRegion, + functionType, + functionName->getValue().str(), + llvm::linkage::external_linkage); + auto lambdaRegion = rvsdgLambda->subregion(); + + for (mlir::Value operand : lambda.getOperands()) + { + if 
(mlir::Operation * producer = operand.getDefiningOp()) + { + ::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n"; + } + else + { + // If there is no defining op, the Value is necessarily a Block + // argument. + auto blockArg = operand.cast(); + ::llvm::outs() << " - Operand produced by Block argument, number " + << blockArg.getArgNumber() << " " << blockArg.getType() << "\n"; + } + } + + for (mlir::Region & region : lambda.getRegions()) + { + convertRegion(region, *lambdaRegion); + } + } + else + { + throw util::error("The result from lambda node is not a LambdaRefType\n"); + } + /* + FunctionType functionType({&vt}, {&vt}); + + auto lambda = lambda::node::create(rvsdgRegion, functionType, "f", linkage::external_linkage); + lambda->finalize({lambda->fctargument(0)}); + + std::vector functionArguments; + for (auto & argument : lambda->fctarguments()) + functionArguments.push_back(&argument); + */ +} + +rvsdg::type * +RVSDGGen::convertType(mlir::Type & type) +{ + // TODO + // Fix memory leak + if (type.getTypeID() == mlir::IntegerType::getTypeID()) + { + auto * intType = static_cast(&type); + return new rvsdg::bittype(intType->getWidth()); + } + else if (type.getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID()) + { + return new llvm::loopstatetype(); + } + else if (type.getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID()) + { + return new llvm::MemoryStateType(); + } + else if (type.getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID()) + { + return new llvm::iostatetype(); + } + else + { + throw util::error("Type conversion not implemented\n"); + } +} + +} // jlm::mlirrvsdg + +#endif // MLIR_ENABLED diff --git a/jlm/mlir/frontend/rvsdggen.hpp b/jlm/mlir/frontend/rvsdggen.hpp new file mode 100644 index 000000000..6a4f17cea --- /dev/null +++ b/jlm/mlir/frontend/rvsdggen.hpp @@ -0,0 +1,84 @@ +/* + * Copyright 2023 Magnus Sjalander + * See COPYING for terms of redistribution. 
+ */ + +#ifndef JLM_MLIR_FRONTEND_RVSDGGEN_HPP +#define JLM_MLIR_FRONTEND_RVSDGGEN_HPP + +#ifdef MLIR_ENABLED + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +#include +#include + +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" + +namespace jlm::mlirrvsdg +{ + +class RVSDGGen final +{ +public: + RVSDGGen() + { + Context_ = std::make_unique(); + // Load the RVSDG dialect + Context_->getOrLoadDialect(); + // Load the JLM dialect + Context_->getOrLoadDialect(); + // Load the Arith dialect + Context_->getOrLoadDialect(); + } + + RVSDGGen(const RVSDGGen &) = delete; + + RVSDGGen(RVSDGGen &&) = delete; + + RVSDGGen & + operator=(const RVSDGGen &) = delete; + + RVSDGGen & + operator=(RVSDGGen &&) = delete; + + std::unique_ptr + readRvsdgMlir(const util::filepath & filePath); + + std::unique_ptr + convertMlir(std::unique_ptr & block); + +private: + void + convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion); + + void + convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion); + + void + convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion); + + void + convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion); + + rvsdg::type * + convertType(mlir::Type & type); + + std::unique_ptr Context_; +}; + +} // namespace jlm::mlirrvsdg + +#endif // MLIR_ENABLED + +#endif // JLM_MLIR_FRONTEND_RVSDGGEN_HPP diff --git a/jlm/tooling/Command.cpp b/jlm/tooling/Command.cpp index 223212ff1..7197d1ad0 100644 --- a/jlm/tooling/Command.cpp +++ b/jlm/tooling/Command.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -309,21 +310,13 @@ JlmOptCommand::ToString() const void JlmOptCommand::Run() const { - ::llvm::LLVMContext llvmContext; - auto llvmModule = ParseLlvmIrFile(CommandLineOptions_.GetInputFile(), llvmContext); - - auto interProceduralGraphModule = llvm::ConvertLlvmModule(*llvmModule); - - /* - * Dispose of Llvm module. 
It is no longer needed. - */ - llvmModule.reset(); - jlm::util::StatisticsCollector statisticsCollector( CommandLineOptions_.GetStatisticsCollectorSettings()); - auto rvsdgModule = - llvm::ConvertInterProceduralGraphModule(*interProceduralGraphModule, statisticsCollector); + auto rvsdgModule = ParseInputFile( + CommandLineOptions_.GetInputFile(), + CommandLineOptions_.GetInputFormat(), + statisticsCollector); llvm::OptimizationSequence::CreateAndRun( *rvsdgModule, @@ -339,21 +332,61 @@ JlmOptCommand::Run() const statisticsCollector.PrintStatistics(); } -std::unique_ptr<::llvm::Module> -JlmOptCommand::ParseLlvmIrFile(const util::filepath & llvmIrFile, ::llvm::LLVMContext & llvmContext) - const +std::unique_ptr +JlmOptCommand::ParseInputFile( + const util::filepath & inputFile, + const JlmOptCommandLineOptions::InputFormat & inputFormat, + util::StatisticsCollector & statisticsCollector) const { - ::llvm::SMDiagnostic diagnostic; - if (auto module = ::llvm::parseIRFile(llvmIrFile.to_str(), diagnostic, llvmContext)) + auto parseLlvmIrFile = + [=](const util::filepath & llvmIrFile, + util::StatisticsCollector & statisticsCollector) -> std::unique_ptr { - return module; - } + ::llvm::LLVMContext llvmContext; + ::llvm::SMDiagnostic diagnostic; + if (auto llvmModule = ::llvm::parseIRFile(llvmIrFile.to_str(), diagnostic, llvmContext)) + { + auto interProceduralGraphModule = llvm::ConvertLlvmModule(*llvmModule); - std::string errors; - ::llvm::raw_string_ostream os(errors); - diagnostic.print(ProgramName_.c_str(), os); - throw util::error(errors); -} + // Dispose of Llvm module. It is no longer needed. 
+ llvmModule.reset(); + + auto rvsdgModule = + llvm::ConvertInterProceduralGraphModule(*interProceduralGraphModule, statisticsCollector); + + return rvsdgModule; + } + + std::string errors; + ::llvm::raw_string_ostream os(errors); + diagnostic.print(ProgramName_.c_str(), os); + throw util::error(errors); + }; + + auto parseMlirIrFile = + [](const util::filepath & mlirIrFile, + util::StatisticsCollector & statisticsCollector) -> std::unique_ptr + { +#ifdef MLIR_ENABLED + jlm::mlirrvsdg::RVSDGGen rvsdggen; + auto block = rvsdggen.readRvsdgMlir(mlirIrFile); + return rvsdggen.convertMlir(block); +#else + throw util::error( + "This version of jlm-opt has not been compiled with support for the MLIR backend\n"); +#endif + }; + + static std::unordered_map< + JlmOptCommandLineOptions::InputFormat, + std::function< + std::unique_ptr(const util::filepath &, util::StatisticsCollector &)>> + parsers({ { tooling::JlmOptCommandLineOptions::InputFormat::Llvm, parseLlvmIrFile }, + { tooling::JlmOptCommandLineOptions::InputFormat::Mlir, parseMlirIrFile } }); + + JLM_ASSERT(parsers.find(inputFormat) != parsers.end()); + return parsers[inputFormat](inputFile, statisticsCollector); +}; void JlmOptCommand::PrintRvsdgModule( diff --git a/jlm/tooling/Command.hpp b/jlm/tooling/Command.hpp index a3b9fc873..8aedde90e 100644 --- a/jlm/tooling/Command.hpp +++ b/jlm/tooling/Command.hpp @@ -370,8 +370,21 @@ class JlmOptCommand final : public Command } private: - std::unique_ptr<::llvm::Module> - ParseLlvmIrFile(const util::filepath & llvmIrFile, ::llvm::LLVMContext & llvmContext) const; + std::unique_ptr + ParseInputFile( + const util::filepath & inputFile, + const JlmOptCommandLineOptions::InputFormat & inputFormat, + util::StatisticsCollector & statisticsCollector) const; + + std::unique_ptr + ParseLlvmIrFile( + const util::filepath & llvmIrFile, + util::StatisticsCollector & statisticsCollector) const; + + std::unique_ptr + ParseMlirIrFile( + const util::filepath & rvsdgIrFile, + 
util::StatisticsCollector & statisticsCollector) const; static void PrintRvsdgModule( diff --git a/jlm/tooling/CommandGraphGenerator.cpp b/jlm/tooling/CommandGraphGenerator.cpp index e553f8da6..a44085731 100644 --- a/jlm/tooling/CommandGraphGenerator.cpp +++ b/jlm/tooling/CommandGraphGenerator.cpp @@ -125,6 +125,7 @@ JlcCommandGraphGenerator::GenerateCommandGraph(const JlcCommandLineOptions & com JlmOptCommandLineOptions jlmOptCommandLineOptions( CreateParserCommandOutputFile(compilation.InputFile()), + JlmOptCommandLineOptions::InputFormat::Llvm, CreateJlmOptCommandOutputFile(compilation.InputFile()), JlmOptCommandLineOptions::OutputFormat::Llvm, statisticsCollectorSettings, diff --git a/jlm/tooling/CommandLine.cpp b/jlm/tooling/CommandLine.cpp index feeb6c382..e27be9c88 100644 --- a/jlm/tooling/CommandLine.cpp +++ b/jlm/tooling/CommandLine.cpp @@ -270,6 +270,18 @@ JlmOptCommandLineOptions::ToCommandLineArgument(jlm::util::Statistics::Id statis throw util::error("Unknown statistics identifier"); } +const char * +JlmOptCommandLineOptions::ToCommandLineArgument(InputFormat inputFormat) +{ + static std::unordered_map map( + { { InputFormat::Llvm, "input-llvm" }, { InputFormat::Mlir, "input-mlir" } }); + + if (map.find(inputFormat) != map.end()) + return map[inputFormat]; + + throw util::error("Unknown input format"); +} + const char * JlmOptCommandLineOptions::ToCommandLineArgument(OutputFormat outputFormat) { @@ -872,6 +884,22 @@ JlmOptCommandLineParser::ParseCommandLineArguments(int argc, char ** argv) "Write theta-gamma inversion statistics to file.")), cl::desc("Write statistics")); + auto llvmInputFormat = JlmOptCommandLineOptions::InputFormat::Llvm; + auto mlirInputFormat = JlmOptCommandLineOptions::InputFormat::Mlir; + + cl::opt inputFormat( + cl::values( + ::clEnumValN( + llvmInputFormat, + JlmOptCommandLineOptions::ToCommandLineArgument(llvmInputFormat), + "Input LLVM IR [default]"), + ::clEnumValN( + mlirInputFormat, + 
JlmOptCommandLineOptions::ToCommandLineArgument(mlirInputFormat), + "Input MLIR")), + cl::init(llvmInputFormat), + cl::desc("Select input format")); + auto llvmOutputFormat = JlmOptCommandLineOptions::OutputFormat::Llvm; auto xmlOutputFormat = JlmOptCommandLineOptions::OutputFormat::Xml; auto mlirOutputFormat = JlmOptCommandLineOptions::OutputFormat::Mlir; @@ -1019,6 +1047,7 @@ JlmOptCommandLineParser::ParseCommandLineArguments(int argc, char ** argv) CommandLineOptions_ = JlmOptCommandLineOptions::Create( std::move(inputFilePath), + inputFormat, outputFile, outputFormat, std::move(statisticsCollectorSettings), diff --git a/jlm/tooling/CommandLine.hpp b/jlm/tooling/CommandLine.hpp index d750e9ee0..9056c350a 100644 --- a/jlm/tooling/CommandLine.hpp +++ b/jlm/tooling/CommandLine.hpp @@ -40,6 +40,12 @@ class optimization; class JlmOptCommandLineOptions final : public CommandLineOptions { public: + enum class InputFormat + { + Llvm, + Mlir, + }; + enum class OutputFormat { Llvm, @@ -116,11 +122,13 @@ class JlmOptCommandLineOptions final : public CommandLineOptions JlmOptCommandLineOptions( util::filepath inputFile, + InputFormat inputFormat, util::filepath outputFile, OutputFormat outputFormat, util::StatisticsCollectorSettings statisticsCollectorSettings, std::vector optimizations) : InputFile_(std::move(inputFile)), + InputFormat_(inputFormat), OutputFile_(std::move(outputFile)), OutputFormat_(outputFormat), StatisticsCollectorSettings_(std::move(statisticsCollectorSettings)), @@ -136,6 +144,12 @@ class JlmOptCommandLineOptions final : public CommandLineOptions return InputFile_; } + [[nodiscard]] InputFormat + GetInputFormat() const noexcept + { + return InputFormat_; + } + [[nodiscard]] const util::filepath & GetOutputFile() const noexcept { @@ -175,6 +189,9 @@ class JlmOptCommandLineOptions final : public CommandLineOptions static const char * ToCommandLineArgument(util::Statistics::Id statisticsId); + static const char * + ToCommandLineArgument(InputFormat 
inputFormat); + static const char * ToCommandLineArgument(OutputFormat outputFormat); @@ -184,6 +201,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions static std::unique_ptr Create( util::filepath inputFile, + InputFormat inputFormat, util::filepath outputFile, OutputFormat outputFormat, util::StatisticsCollectorSettings statisticsCollectorSettings, @@ -191,6 +209,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions { return std::make_unique( std::move(inputFile), + inputFormat, std::move(outputFile), outputFormat, std::move(statisticsCollectorSettings), @@ -199,6 +218,7 @@ class JlmOptCommandLineOptions final : public CommandLineOptions private: util::filepath InputFile_; + InputFormat InputFormat_; util::filepath OutputFile_; OutputFormat OutputFormat_; util::StatisticsCollectorSettings StatisticsCollectorSettings_; diff --git a/tests/jlm/tooling/TestJlmOptCommand.cpp b/tests/jlm/tooling/TestJlmOptCommand.cpp index a0b720399..648933d27 100644 --- a/tests/jlm/tooling/TestJlmOptCommand.cpp +++ b/tests/jlm/tooling/TestJlmOptCommand.cpp @@ -22,6 +22,7 @@ TestStatistics() JlmOptCommandLineOptions commandLineOptions( jlm::util::filepath("inputFile.ll"), + JlmOptCommandLineOptions::InputFormat::Llvm, jlm::util::filepath("outputFile.ll"), JlmOptCommandLineOptions::OutputFormat::Llvm, statisticsCollectorSettings,