Skip to content

Commit

Permalink
mlir: Initial MLIR/RVSDG frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
sjalander committed Nov 29, 2023
1 parent 3a2fdd6 commit 945976d
Show file tree
Hide file tree
Showing 9 changed files with 453 additions and 26 deletions.
1 change: 1 addition & 0 deletions jlm/mlir/Makefile.sub
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ endif

LIBMLIR_SRC = \
jlm/mlir/backend/mlirgen.cpp \
jlm/mlir/frontend/rvsdggen.cpp \

.PHONY: libmlir-debug
libmlir-debug: CXXFLAGS += $(CXXFLAGS_DEBUG)
Expand Down
245 changes: 245 additions & 0 deletions jlm/mlir/frontend/rvsdggen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright 2023 Magnus Sjalander <[email protected]>
* See COPYING for terms of redistribution.
*/

#ifdef MLIR_ENABLED

#include "jlm/mlir/frontend/rvsdggen.hpp"

#include <jlm/rvsdg/bitstring/comparison.hpp>
#include <jlm/rvsdg/bitstring/constant.hpp>
#include <jlm/rvsdg/node.hpp>
#include <jlm/rvsdg/traverser.hpp>

// #include <jlm/llvm/ir/operators/GetElementPtr.hpp>
// #include <jlm/llvm/ir/operators/load.hpp>
// #include <jlm/llvm/ir/operators/operators.hpp>
// #include <jlm/llvm/ir/operators/sext.hpp>
// #include <jlm/llvm/ir/operators/store.hpp>
// #include <jlm/rvsdg/bitstring/type.hpp>

#include "llvm/Support/raw_os_ostream.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"

namespace jlm::mlirrvsdg
{

std::unique_ptr<mlir::Block>
RVSDGGen::readRvsdgMlir(const util::filepath & filePath)
{
// Configer the parser
mlir::ParserConfig config = mlir::ParserConfig(Context_.get());
// Variable for storing the result
std::unique_ptr<mlir::Block> block = std::make_unique<mlir::Block>();
// Read the input file
auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config);
if (result.failed())
{
throw util::error("Parsing MLIR input file failed.");
}
return block;
}

std::unique_ptr<llvm::RvsdgModule>
RVSDGGen::convertMlir(std::unique_ptr<mlir::Block> & block)
{
// Create RVSDG module
std::string dataLayout;
std::string targetTriple;
util::filepath sourceFileName("");
auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout);

// Get the root region
auto & graph = rvsdgModule->Rvsdg();
auto root = graph.root();

// Convert the MLIR into an RVSDG graph
convertBlock(*block.get(), *root);

return rvsdgModule;
}

void
RVSDGGen::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion)
{
for (mlir::Operation & op : block.getOperations())
{
::llvm::outs() << "- Current operation " << op.getName() << "\n";

if (op.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName())
{
convertOmega(op, rvsdgRegion);
}
else if (op.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName())
{
convertLambda(op, rvsdgRegion);
}
else if (op.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName())
{
auto constant = static_cast<mlir::arith::ConstantIntOp>(&op);

// Need the type to know its width
auto type = constant.getType();
JLM_ASSERT(type.getTypeID() == mlir::IntegerType::getTypeID());
auto * integerType = static_cast<mlir::IntegerType *>(&type);

rvsdg::create_bitconstant(&rvsdgRegion, integerType->getWidth(), constant.value());
}
else
{
::llvm::outs() << "- Operation not implemented: " << op.getName() << "\n";
// throw util::error("Operation is not implemented:" + op.getName() + "\n");
}
}
}

void
RVSDGGen::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion)
{
::llvm::outs() << " - Converting region\n";
for (mlir::Block & block : region.getBlocks())
{
convertBlock(block, rvsdgRegion);
}
}

void
RVSDGGen::convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion)
{
::llvm::outs() << " ** Converting Omega **\n";
for (mlir::Value operand : omega.getOperands())
{
if (mlir::Operation * producer = operand.getDefiningOp())
{
::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n";
}
else
{
// If there is no defining op, the Value is necessarily a Block
// argument.
auto blockArg = operand.cast<mlir::BlockArgument>();
::llvm::outs() << " - Operand produced by Block argument, number " << blockArg.getArgNumber()
<< " " << blockArg.getType() << "\n";
}
}

// TODO
// There should only exist one region in an omega
for (mlir::Region & region : omega.getRegions())
{
convertRegion(region, rvsdgRegion);
}
}

void
RVSDGGen::convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion)
{
::llvm::outs() << " ** Converting Lambda **\n";

// Get the name of the function
auto functionNameAttribute = lambda.getAttr(::llvm::StringRef("sym_name"));
JLM_ASSERT(functionNameAttribute != NULL);
mlir::StringAttr * functionName = static_cast<mlir::StringAttr *>(&functionNameAttribute);
::llvm::outs() << "Function name: " << functionName->getValue().str() << "\n";

// A lambda node has only the function signature as the result
JLM_ASSERT(lambda.getNumResults() == 1);

// Get the function signature
auto result = lambda.getResult(0).getType();
::llvm::outs() << " - Function signature: '" << result << "'\n";
if (result.getTypeID() == mlir::rvsdg::LambdaRefType::getTypeID())
{
std::vector<const jlm::rvsdg::type *> arguments;
std::vector<const jlm::rvsdg::type *> results;
mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast<mlir::rvsdg::LambdaRefType *>(&result);
for (auto argumentType : lambdaRefType->getParameterTypes())
{
auto argument = convertType(argumentType);
::llvm::outs() << " - Argument: '" << argument->debug_string() << "\n";
arguments.push_back(argument);
}
for (auto returnType : lambdaRefType->getReturnTypes())
{
auto result = convertType(returnType);
::llvm::outs() << " - Result: '" << result->debug_string() << "\n";
results.push_back(result);
}
llvm::FunctionType functionType(arguments, results);
auto rvsdgLambda = llvm::lambda::node::create(
&rvsdgRegion,
functionType,
functionName->getValue().str(),
llvm::linkage::external_linkage);
auto lambdaRegion = rvsdgLambda->subregion();

for (mlir::Value operand : lambda.getOperands())
{
if (mlir::Operation * producer = operand.getDefiningOp())
{
::llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n";
}
else
{
// If there is no defining op, the Value is necessarily a Block
// argument.
auto blockArg = operand.cast<mlir::BlockArgument>();
::llvm::outs() << " - Operand produced by Block argument, number "
<< blockArg.getArgNumber() << " " << blockArg.getType() << "\n";
}
}

for (mlir::Region & region : lambda.getRegions())
{
convertRegion(region, *lambdaRegion);
}
}
else
{
throw util::error("The result from lambda node is not a LambdaRefType\n");
}
/*
FunctionType functionType({&vt}, {&vt});
auto lambda = lambda::node::create(rvsdgRegion, functionType, "f", linkage::external_linkage);
lambda->finalize({lambda->fctargument(0)});
std::vector<jlm::rvsdg::argument*> functionArguments;
for (auto & argument : lambda->fctarguments())
functionArguments.push_back(&argument);
*/
}

rvsdg::type *
RVSDGGen::convertType(mlir::Type & type)
{
// TODO
// Fix memory leak
if (type.getTypeID() == mlir::IntegerType::getTypeID())
{
auto * intType = static_cast<mlir::IntegerType *>(&type);
return new rvsdg::bittype(intType->getWidth());
}
else if (type.getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID())
{
return new llvm::loopstatetype();
}
else if (type.getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID())
{
return new llvm::MemoryStateType();
}
else if (type.getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID())
{
return new llvm::iostatetype();
}
else
{
throw util::error("Type conversion not implemented\n");
}
}

} // jlm::mlirrvsdg

#endif // MLIR_ENABLED
84 changes: 84 additions & 0 deletions jlm/mlir/frontend/rvsdggen.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2023 Magnus Sjalander <[email protected]>
* See COPYING for terms of redistribution.
*/

#ifndef JLM_MLIR_FRONTEND_RVSDGGEN_HPP
#define JLM_MLIR_FRONTEND_RVSDGGEN_HPP

#ifdef MLIR_ENABLED

// #include <jlm/llvm/ir/operators/GetElementPtr.hpp>
// #include <jlm/llvm/ir/operators/load.hpp>
// #include <jlm/llvm/ir/operators/operators.hpp>
// #include <jlm/llvm/ir/operators/sext.hpp>
// #include <jlm/llvm/ir/operators/store.hpp>
// #include <jlm/rvsdg/bitstring/comparison.hpp>
// #include <jlm/rvsdg/bitstring/type.hpp>

#include <jlm/llvm/ir/operators/lambda.hpp>
#include <jlm/llvm/ir/RvsdgModule.hpp>

#include <JLM/JLMDialect.h>
#include <RVSDG/RVSDGDialect.h>
#include <RVSDG/RVSDGPasses.h>

#include "mlir/Dialect/Arith/IR/Arith.h"

namespace jlm::mlirrvsdg
{

class RVSDGGen final
{
public:
RVSDGGen()
{
Context_ = std::make_unique<mlir::MLIRContext>();
// Load the RVSDG dialect
Context_->getOrLoadDialect<mlir::rvsdg::RVSDGDialect>();
// Load the JLM dialect
Context_->getOrLoadDialect<mlir::jlm::JLMDialect>();
// Load the Arith dialect
Context_->getOrLoadDialect<mlir::arith::ArithDialect>();
}

RVSDGGen(const RVSDGGen &) = delete;

RVSDGGen(RVSDGGen &&) = delete;

RVSDGGen &
operator=(const RVSDGGen &) = delete;

RVSDGGen &
operator=(RVSDGGen &&) = delete;

std::unique_ptr<mlir::Block>
readRvsdgMlir(const util::filepath & filePath);

std::unique_ptr<llvm::RvsdgModule>
convertMlir(std::unique_ptr<mlir::Block> & block);

private:
void
convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion);

void
convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion);

void
convertOmega(mlir::Operation & omega, rvsdg::region & rvsdgRegion);

void
convertLambda(mlir::Operation & lambda, rvsdg::region & rvsdgRegion);

rvsdg::type *
convertType(mlir::Type & type);

std::unique_ptr<mlir::MLIRContext> Context_;
};

} // namespace jlm::mlirrvsdg

#endif // MLIR_ENABLED

#endif // JLM_MLIR_FRONTEND_RVSDGGEN_HPP
Loading

0 comments on commit 945976d

Please sign in to comment.