-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
453 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
/* | ||
* Copyright 2023 Magnus Sjalander <[email protected]> | ||
* See COPYING for terms of redistribution. | ||
*/ | ||
|
||
#ifdef MLIR_ENABLED | ||
|
||
#include "jlm/mlir/frontend/rvsdggen.hpp" | ||
|
||
#include <jlm/rvsdg/bitstring/comparison.hpp> | ||
#include <jlm/rvsdg/bitstring/constant.hpp> | ||
#include <jlm/rvsdg/node.hpp> | ||
#include <jlm/rvsdg/traverser.hpp> | ||
|
||
// #include <jlm/llvm/ir/operators/GetElementPtr.hpp> | ||
// #include <jlm/llvm/ir/operators/load.hpp> | ||
// #include <jlm/llvm/ir/operators/operators.hpp> | ||
// #include <jlm/llvm/ir/operators/sext.hpp> | ||
// #include <jlm/llvm/ir/operators/store.hpp> | ||
// #include <jlm/rvsdg/bitstring/type.hpp> | ||
|
||
#include "llvm/Support/raw_os_ostream.h" | ||
#include "mlir/IR/Verifier.h" | ||
#include "mlir/Parser/Parser.h" | ||
|
||
namespace jlm::mlirrvsdg | ||
{ | ||
|
||
std::unique_ptr<mlir::Block> | ||
RVSDGGen::readRvsdgMlir(const util::filepath & filePath) | ||
{ | ||
// Configer the parser | ||
mlir::ParserConfig config = mlir::ParserConfig(Context_.get()); | ||
// Variable for storing the result | ||
std::unique_ptr<mlir::Block> block = std::make_unique<mlir::Block>(); | ||
// Read the input file | ||
auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config); | ||
if (result.failed()) | ||
{ | ||
throw util::error("Parsing MLIR input file failed."); | ||
} | ||
return block; | ||
} | ||
|
||
std::unique_ptr<llvm::RvsdgModule> | ||
RVSDGGen::convertMlir(std::unique_ptr<mlir::Block> & block) | ||
{ | ||
// Create RVSDG module | ||
std::string dataLayout; | ||
std::string targetTriple; | ||
util::filepath sourceFileName(""); | ||
auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout); | ||
|
||
// Get the root region | ||
auto & graph = rvsdgModule->Rvsdg(); | ||
auto root = graph.root(); | ||
|
||
// Convert the MLIR into an RVSDG graph | ||
convertBlock(*block.get(), *root); | ||
|
||
return rvsdgModule; | ||
} | ||
|
||
void | ||
RVSDGGen::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) | ||
{ | ||
for (mlir::Operation & op : block.getOperations()) | ||
{ | ||
::llvm::outs() << "- Current operation " << op.getName() << "\n"; | ||
|
||
if (op.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName()) | ||
{ | ||
convertOmega(op, rvsdgRegion); | ||
} | ||
else if (op.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName()) | ||
{ | ||
convertLambda(op, rvsdgRegion); | ||
} | ||
else if (op.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName()) | ||
{ | ||
auto constant = static_cast<mlir::arith::ConstantIntOp>(&op); | ||
|
||
// Need the type to know its width | ||
auto type = constant.getType(); | ||
JLM_ASSERT(type.getTypeID() == mlir::IntegerType::getTypeID()); | ||
auto * integerType = static_cast<mlir::IntegerType *>(&type); | ||
|
||
rvsdg::create_bitconstant(&rvsdgRegion, integerType->getWidth(), constant.value()); | ||
} | ||
else | ||
{ | ||
::llvm::outs() << "- Operation not implemented: " << op.getName() << "\n"; | ||
// throw util::error("Operation is not implemented:" + op.getName() + "\n"); | ||
} | ||
} | ||
} | ||
|
||
void | ||
RVSDGGen::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion) | ||
{ | ||
::llvm::outs() << " - Converting region\n"; | ||
for (mlir::Block & block : region.getBlocks()) | ||
{ | ||
convertBlock(block, rvsdgRegion); | ||
} | ||
} | ||
|
||
void | ||
RVSDGGen::convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion) | ||
{ | ||
::llvm::outs() << " ** Converting Omega **\n"; | ||
for (mlir::Value operand : omega.getOperands()) | ||
{ | ||
if (mlir::Operation * producer = operand.getDefiningOp()) | ||
{ | ||
::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n"; | ||
} | ||
else | ||
{ | ||
// If there is no defining op, the Value is necessarily a Block | ||
// argument. | ||
auto blockArg = operand.cast<mlir::BlockArgument>(); | ||
::llvm::outs() << " - Operand produced by Block argument, number " << blockArg.getArgNumber() | ||
<< " " << blockArg.getType() << "\n"; | ||
} | ||
} | ||
|
||
// TODO | ||
// There should only exist one region in an omega | ||
for (mlir::Region & region : omega.getRegions()) | ||
{ | ||
convertRegion(region, rvsdgRegion); | ||
} | ||
} | ||
|
||
void | ||
RVSDGGen::convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion) | ||
{ | ||
::llvm::outs() << " ** Converting Lambda **\n"; | ||
|
||
// Get the name of the function | ||
auto functionNameAttribute = lambda.getAttr(::llvm::StringRef("sym_name")); | ||
JLM_ASSERT(functionNameAttribute != NULL); | ||
mlir::StringAttr * functionName = static_cast<mlir::StringAttr *>(&functionNameAttribute); | ||
::llvm::outs() << "Function name: " << functionName->getValue().str() << "\n"; | ||
|
||
// A lambda node has only the function signature as the result | ||
JLM_ASSERT(lambda.getNumResults() == 1); | ||
|
||
// Get the function signature | ||
auto result = lambda.getResult(0).getType(); | ||
::llvm::outs() << " - Function signature: '" << result << "'\n"; | ||
if (result.getTypeID() == mlir::rvsdg::LambdaRefType::getTypeID()) | ||
{ | ||
std::vector<const jlm::rvsdg::type *> arguments; | ||
std::vector<const jlm::rvsdg::type *> results; | ||
mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast<mlir::rvsdg::LambdaRefType *>(&result); | ||
for (auto argumentType : lambdaRefType->getParameterTypes()) | ||
{ | ||
auto argument = convertType(argumentType); | ||
::llvm::outs() << " - Argument: '" << argument->debug_string() << "\n"; | ||
arguments.push_back(argument); | ||
} | ||
for (auto returnType : lambdaRefType->getReturnTypes()) | ||
{ | ||
auto result = convertType(returnType); | ||
::llvm::outs() << " - Result: '" << result->debug_string() << "\n"; | ||
results.push_back(result); | ||
} | ||
llvm::FunctionType functionType(arguments, results); | ||
auto rvsdgLambda = llvm::lambda::node::create( | ||
&rvsdgRegion, | ||
functionType, | ||
functionName->getValue().str(), | ||
llvm::linkage::external_linkage); | ||
auto lambdaRegion = rvsdgLambda->subregion(); | ||
|
||
for (mlir::Value operand : lambda.getOperands()) | ||
{ | ||
if (mlir::Operation * producer = operand.getDefiningOp()) | ||
{ | ||
::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n"; | ||
} | ||
else | ||
{ | ||
// If there is no defining op, the Value is necessarily a Block | ||
// argument. | ||
auto blockArg = operand.cast<mlir::BlockArgument>(); | ||
::llvm::outs() << " - Operand produced by Block argument, number " | ||
<< blockArg.getArgNumber() << " " << blockArg.getType() << "\n"; | ||
} | ||
} | ||
|
||
for (mlir::Region & region : lambda.getRegions()) | ||
{ | ||
convertRegion(region, *lambdaRegion); | ||
} | ||
} | ||
else | ||
{ | ||
throw util::error("The result from lambda node is not a LambdaRefType\n"); | ||
} | ||
/* | ||
FunctionType functionType({&vt}, {&vt}); | ||
auto lambda = lambda::node::create(rvsdgRegion, functionType, "f", linkage::external_linkage); | ||
lambda->finalize({lambda->fctargument(0)}); | ||
std::vector<jlm::rvsdg::argument*> functionArguments; | ||
for (auto & argument : lambda->fctarguments()) | ||
functionArguments.push_back(&argument); | ||
*/ | ||
} | ||
|
||
rvsdg::type * | ||
RVSDGGen::convertType(mlir::Type & type) | ||
{ | ||
// TODO | ||
// Fix memory leak | ||
if (type.getTypeID() == mlir::IntegerType::getTypeID()) | ||
{ | ||
auto * intType = static_cast<mlir::IntegerType *>(&type); | ||
return new rvsdg::bittype(intType->getWidth()); | ||
} | ||
else if (type.getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID()) | ||
{ | ||
return new llvm::loopstatetype(); | ||
} | ||
else if (type.getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID()) | ||
{ | ||
return new llvm::MemoryStateType(); | ||
} | ||
else if (type.getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID()) | ||
{ | ||
return new llvm::iostatetype(); | ||
} | ||
else | ||
{ | ||
throw util::error("Type conversion not implemented\n"); | ||
} | ||
} | ||
|
||
} // jlm::mlirrvsdg | ||
|
||
#endif // MLIR_ENABLED |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Copyright 2023 Magnus Sjalander <[email protected]> | ||
* See COPYING for terms of redistribution. | ||
*/ | ||
|
||
#ifndef JLM_MLIR_FRONTEND_RVSDGGEN_HPP | ||
#define JLM_MLIR_FRONTEND_RVSDGGEN_HPP | ||
|
||
#ifdef MLIR_ENABLED | ||
|
||
// #include <jlm/llvm/ir/operators/GetElementPtr.hpp> | ||
// #include <jlm/llvm/ir/operators/load.hpp> | ||
// #include <jlm/llvm/ir/operators/operators.hpp> | ||
// #include <jlm/llvm/ir/operators/sext.hpp> | ||
// #include <jlm/llvm/ir/operators/store.hpp> | ||
// #include <jlm/rvsdg/bitstring/comparison.hpp> | ||
// #include <jlm/rvsdg/bitstring/type.hpp> | ||
|
||
#include <jlm/llvm/ir/operators/lambda.hpp> | ||
#include <jlm/llvm/ir/RvsdgModule.hpp> | ||
|
||
#include <JLM/JLMDialect.h> | ||
#include <RVSDG/RVSDGDialect.h> | ||
#include <RVSDG/RVSDGPasses.h> | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
|
||
namespace jlm::mlirrvsdg | ||
{ | ||
|
||
class RVSDGGen final | ||
{ | ||
public: | ||
RVSDGGen() | ||
{ | ||
Context_ = std::make_unique<mlir::MLIRContext>(); | ||
// Load the RVSDG dialect | ||
Context_->getOrLoadDialect<mlir::rvsdg::RVSDGDialect>(); | ||
// Load the JLM dialect | ||
Context_->getOrLoadDialect<mlir::jlm::JLMDialect>(); | ||
// Load the Arith dialect | ||
Context_->getOrLoadDialect<mlir::arith::ArithDialect>(); | ||
} | ||
|
||
RVSDGGen(const RVSDGGen &) = delete; | ||
|
||
RVSDGGen(RVSDGGen &&) = delete; | ||
|
||
RVSDGGen & | ||
operator=(const RVSDGGen &) = delete; | ||
|
||
RVSDGGen & | ||
operator=(RVSDGGen &&) = delete; | ||
|
||
std::unique_ptr<mlir::Block> | ||
readRvsdgMlir(const util::filepath & filePath); | ||
|
||
std::unique_ptr<llvm::RvsdgModule> | ||
convertMlir(std::unique_ptr<mlir::Block> & block); | ||
|
||
private: | ||
void | ||
convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion); | ||
|
||
void | ||
convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion); | ||
|
||
void | ||
convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion); | ||
|
||
void | ||
convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion); | ||
|
||
rvsdg::type * | ||
convertType(mlir::Type & type); | ||
|
||
std::unique_ptr<mlir::MLIRContext> Context_; | ||
}; | ||
|
||
} // namespace jlm::mlirrvsdg | ||
|
||
#endif // MLIR_ENABLED | ||
|
||
#endif // JLM_MLIR_FRONTEND_RVSDGGEN_HPP |
Oops, something went wrong.