Skip to content

Commit

Permalink
Cleaned up code and added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
sjalander committed Jan 16, 2024
1 parent dd72459 commit 1b13e77
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 153 deletions.
65 changes: 4 additions & 61 deletions jlm/mlir/backend/RvsdgToMlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ namespace jlm::rvsdgmlir
void
RvsdgToMlir::print(mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath)
{
// Verify the module
if (failed(mlir::verify(omega)))
{
omega.emitError("module verification error");
throw util::error("Verification of RVSDG-MLIR failed");
}
// Print the module
if (filePath == "")
{
::llvm::raw_os_ostream os(std::cout);
Expand All @@ -49,18 +47,13 @@ RvsdgToMlir::convertModule(const llvm::RvsdgModule & rvsdgModule)
mlir::rvsdg::OmegaNode
RvsdgToMlir::convertOmega(const rvsdg::graph & graph)
{
// Create the MLIR omega node
auto omega = Builder_->create<mlir::rvsdg::OmegaNode>(Builder_->getUnknownLoc());

// Create a block for the region
mlir::Region & region = omega.getRegion();
auto & omegaBlock = region.emplaceBlock();

// Convert the region of the omega
auto subregion = graph.root();
::llvm::SmallVector<mlir::Value> regionResults = convertSubregion(*subregion, omegaBlock);
::llvm::SmallVector<mlir::Value> regionResults = convertRegion(*subregion, omegaBlock);

// Handle the result of the omega
auto omegaResult =
Builder_->create<mlir::rvsdg::OmegaResult>(Builder_->getUnknownLoc(), regionResults);
omegaBlock.push_back(omegaResult);
Expand All @@ -69,9 +62,8 @@ RvsdgToMlir::convertOmega(const rvsdg::graph & graph)
}

::llvm::SmallVector<mlir::Value>
RvsdgToMlir::convertSubregion(rvsdg::region & region, mlir::Block & block)
RvsdgToMlir::convertRegion(rvsdg::region & region, mlir::Block & block)
{
// Handle arguments of the region
for (size_t i = 0; i < region.narguments(); ++i)
{
auto type = convertType(region.argument(i)->type());
Expand All @@ -92,15 +84,11 @@ RvsdgToMlir::convertSubregion(rvsdg::region & region, mlir::Block & block)
nodes[rvsdgNode] = convertNode(*rvsdgNode, block);
}

// Handle results of the region
::llvm::SmallVector<mlir::Value> results;
for (size_t i = 0; i < region.nresults(); ++i)
{
// Get the result of the RVSDG region
auto result = region.result(i);
// Get the output of the RVSDG node driving the result
auto output = result->origin();
// Get the RVSDG node that generates the output
rvsdg::node * outputNode = rvsdg::node_output::node(output);
if (outputNode == nullptr)
{
Expand All @@ -127,16 +115,6 @@ RvsdgToMlir::convertNode(const rvsdg::node & node, mlir::Block & block)
else if (auto lambda = dynamic_cast<const llvm::lambda::node *>(&node))
{
return convertLambda(*lambda, block);
/*
} else if (auto gamma = dynamic_cast<llvm::gamma_node *>(&node)) {
convertGamma(*gamma, block);
} else if (auto theta = dynamic_cast<llvm::theta_node *>(&node)) {
convertTheta(*theta, block);
} else if (auto delta = dynamic_cast<llvm::delta::node *>(&node)) {
convertDelta(*delta, block);
} else if (auto phi = dynamic_cast<llvm::phi::node *>(&node)) {
convertPhi(*phi, block);
*/
}
else
{
Expand Down Expand Up @@ -167,68 +145,45 @@ RvsdgToMlir::convertSimpleNode(const rvsdg::simple_node & node, mlir::Block & bl
mlir::Value
RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & block)
{

// Handle function arguments
::llvm::SmallVector<mlir::Type> arguments;
for (size_t i = 0; i < lambdaNode.nfctarguments(); ++i)
{
arguments.push_back(convertType(lambdaNode.fctargument(i)->type()));
}
::llvm::ArrayRef argumentsArray(arguments);

// Handle function results
::llvm::SmallVector<mlir::Type> results;
for (size_t i = 0; i < lambdaNode.nfctresults(); ++i)
{
results.push_back(convertType(lambdaNode.fctresult(i)->type()));
}
::llvm::ArrayRef resultsArray(results);

/*
// Context arguments
for (size_t i = 0; i < node->ncvarguments(); ++i) {
// s << print_input_origin(node.cvargument(i)->input()) << ": " <<
print_type(&ln.cvargument(i)->type()); throw util::error("Context arguments in convertLambda()
has not been implemented");
}
*/
// TODO
// Consider replacing the lambda ref creation with
// mlir::rvsdg::LambdaRefTyp::get();
// static LambdaRefType get(::mlir::MLIRContext *context, ::llvm::ArrayRef<mlir::Type>
// parameterTypes, ::llvm::ArrayRef<mlir::Type> returnTypes);

// LambdaNodes return a LambdaRefType
::llvm::SmallVector<mlir::Type> lambdaRef;
auto refType = Builder_->getType<::mlir::rvsdg::LambdaRefType>(argumentsArray, resultsArray);
lambdaRef.push_back(refType);

// TODO
// Add the inputs to the function
::llvm::SmallVector<mlir::Value> inputs;

// Add function attributes
// Add function attributes, e.g., the function name
::llvm::SmallVector<mlir::NamedAttribute> attributes;
auto attributeName = Builder_->getStringAttr("sym_name");
auto attributeValue = Builder_->getStringAttr(lambdaNode.name());
auto symbolName = Builder_->getNamedAttr(attributeName, attributeValue);
attributes.push_back(symbolName);
::llvm::ArrayRef<::mlir::NamedAttribute> attributesRef(attributes);

// Create the lambda node and add it to the region/block it resides in
auto lambda = Builder_->create<mlir::rvsdg::LambdaNode>(
Builder_->getUnknownLoc(),
lambdaRef,
inputs,
attributesRef);
block.push_back(lambda);

// Create a block for the region
mlir::Region & region = lambda.getRegion();
auto & lambdaBlock = region.emplaceBlock();

// Convert the region and get all the results generated by the region
auto regionResults = convertSubregion(*lambdaNode.subregion(), lambdaBlock);
auto regionResults = convertRegion(*lambdaNode.subregion(), lambdaBlock);
auto lambdaResult =
Builder_->create<mlir::rvsdg::LambdaResult>(Builder_->getUnknownLoc(), regionResults);
lambdaBlock.push_back(lambdaResult);
Expand All @@ -254,18 +209,6 @@ RvsdgToMlir::convertType(const rvsdg::type & type)
else if (dynamic_cast<const llvm::MemoryStateType *>(&type))
{
return Builder_->getType<::mlir::rvsdg::MemStateEdgeType>();
/*
} else if (auto varargType =dynamic_cast<const jlm::varargtype*>(&type)) {
s << "!jlm.varargList";
} else if (auto pointerType = dynamic_cast<const jlm::PointerType*>(&type)){
s << print_pointer_type(pointer_type);
} else if (auto arrayType = dynamic_cast<const jlm::arraytype*>(&type)){
s << print_array_type(array_type);
} else if (auto structType = dynamic_cast<const jlm::structtype*>(&type)){
s << print_struct_type(struct_type);
} else if (auto controlType = dynamic_cast<const jive::ctltype*>(&type)){
s << "!rvsdg.ctrl<" << control_type->nalternatives() << ">";
*/
}
else
{
Expand Down
52 changes: 50 additions & 2 deletions jlm/mlir/backend/RvsdgToMlir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,76 @@ class RvsdgToMlir final
RvsdgToMlir &
operator=(RvsdgToMlir &&) = delete;

/**
* Prints MLIR RVSDG to a file.
* \param omega The MLIR RVSDG Omega node to be printed.
* \param filePath The path to the file to print the MLIR to.
*/
void
print(mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath);

/**
* Converts an RVSDG module to MLIR RVSDG.
* \param rvsdgModule The RVSDG module to be converted.
* \return An MLIR RVSDG OmegaNode containing the whole graph of the rvsdgModule. It is
* the responsibility of the caller to call ->destroy() on the returned omega, once it is no
* longer needed.
*/
mlir::rvsdg::OmegaNode
convertModule(const llvm::RvsdgModule & rvsdgModule);

private:
/**
* Converts an omega and all nodes in its (sub)region(s) to an MLIR RVSDG OmegaNode.
* \param graph The root RVSDG graph.
* \return An MLIR RVSDG OmegaNode.
*/
mlir::rvsdg::OmegaNode
convertOmega(const rvsdg::graph & graph);

/**
* Converts all nodes in an RVSDG region. Conversion of structural nodes cause their regions to
* also be converted.
* \param region The RVSDG region to be converted
* \param block The MLIR RVSDG block that corresponds to this RVSDG region, and to which
* converted nodes are insterted.
* \return A list of outputs of the converted region/block.
*/
::llvm::SmallVector<mlir::Value>
convertSubregion(rvsdg::region & region, mlir::Block & block);

convertRegion(rvsdg::region & region, mlir::Block & block);

/**
* Converts an RVSDG node to an MLIR RVSDG operation
* \param node The RVSDG node to be converted
* \param block The MLIR RVSDG block to insert the converted node.
* \return The converted MLIR RVSDG operation.
*/
mlir::Value
convertNode(const rvsdg::node & node, mlir::Block & block);

/**
* Converts an RVSDG simple_node to an MLIR RVSDG operation
* \param node The RVSDG node to be converted
* \param block The MLIR RVSDG block to insert the converted node.
* \return The converted MLIR RVSDG operation.
*/
mlir::Value
convertSimpleNode(const rvsdg::simple_node & node, mlir::Block & block);

/**
* Converts an RVSDG lambda node to an MLIR RVSDG LambdaNode..
* \param node The RVSDG lambda node to be converted
* \param block The MLIR RVSDG block to insert the lambda node.
* \return The converted MLIR RVSDG LambdaNode.
*/
mlir::Value
convertLambda(const llvm::lambda::node & node, mlir::Block & block);

/**
* Converts an RVSDG type to and MLIR RVSDG type
* \param type The RVSDG type to be converted.
* \result The corresponding MLIR RVSDG type.
*/
mlir::Type
convertType(const rvsdg::type & type);

Expand Down
39 changes: 9 additions & 30 deletions jlm/mlir/frontend/MlirToRvsdg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
#include <jlm/rvsdg/node.hpp>
#include <jlm/rvsdg/traverser.hpp>

// #include <jlm/llvm/ir/operators/GetElementPtr.hpp>
// #include <jlm/llvm/ir/operators/load.hpp>
// #include <jlm/llvm/ir/operators/operators.hpp>
// #include <jlm/llvm/ir/operators/sext.hpp>
// #include <jlm/llvm/ir/operators/store.hpp>
// #include <jlm/rvsdg/bitstring/type.hpp>

#include "llvm/Support/raw_os_ostream.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
Expand All @@ -28,11 +21,8 @@ namespace jlm::mlirrvsdg
std::unique_ptr<mlir::Block>
MlirToRvsdg::readRvsdgMlir(const util::filepath & filePath)
{
// Configer the parser
mlir::ParserConfig config = mlir::ParserConfig(Context_.get());
// Variable for storing the result
std::unique_ptr<mlir::Block> block = std::make_unique<mlir::Block>();
// Read the input file
auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config);
if (result.failed())
{
Expand All @@ -44,17 +34,13 @@ MlirToRvsdg::readRvsdgMlir(const util::filepath & filePath)
std::unique_ptr<llvm::RvsdgModule>
MlirToRvsdg::convertMlir(std::unique_ptr<mlir::Block> & block)
{
// Create RVSDG module
std::string dataLayout;
std::string targetTriple;
util::filepath sourceFileName("");
auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout);

// Get the root region
auto & graph = rvsdgModule->Rvsdg();
auto root = graph.root();

// Convert the MLIR into an RVSDG graph
convertBlock(*block.get(), *root);

return rvsdgModule;
Expand All @@ -72,15 +58,13 @@ MlirToRvsdg::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion)
std::unique_ptr<std::vector<jlm::rvsdg::output *>>
MlirToRvsdg::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion)
{
// Transform the block such that operations are in topological order
mlir::sortTopologically(&block);

// Create an RVSDG node for each MLIR operation and store each pair in a
// hash map for easy lookup of corresponding RVSDG nodes
std::unordered_map<mlir::Operation *, rvsdg::node *> operations;
for (mlir::Operation & mlirOp : block.getOperations())
{
// Get the inputs of the MLIR operation
std::vector<const jlm::rvsdg::output *> inputs;
for (mlir::Value operand : mlirOp.getOperands())
{
Expand All @@ -103,7 +87,7 @@ MlirToRvsdg::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion)
}
}

// The results of the block are encoded in the terminator operation
// The results of the region/block are encoded in the terminator operation
auto terminator = block.getTerminator();
std::unique_ptr<std::vector<jlm::rvsdg::output *>> results =
std::make_unique<std::vector<jlm::rvsdg::output *>>();
Expand Down Expand Up @@ -134,8 +118,8 @@ MlirToRvsdg::convertOperation(
if (mlirOperation.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName())
{
convertOmega(mlirOperation, rvsdgRegion);
// Omega doesn't have a corresponding RVSDG node so we return NULL
return NULL;
// Omega doesn't have a corresponding RVSDG node so we return nullptr
return nullptr;
}
else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName())
{
Expand All @@ -144,8 +128,6 @@ MlirToRvsdg::convertOperation(
else if (mlirOperation.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName())
{
auto constant = static_cast<mlir::arith::ConstantIntOp>(&mlirOperation);

// Need the type to know the width of the constant
auto type = constant.getType();
JLM_ASSERT(type.getTypeID() == mlir::IntegerType::getTypeID());
auto * integerType = static_cast<mlir::IntegerType *>(&type);
Expand All @@ -156,12 +138,12 @@ MlirToRvsdg::convertOperation(
else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::LambdaResult::getOperationName())
{
// This is a terminating operation that doesn't have a corresponding RVSDG node
return NULL;
return nullptr;
}
else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::OmegaResult::getOperationName())
{
// This is a terminating operation that doesn't have a corresponding RVSDG node
return NULL;
return nullptr;
}
else
{
Expand All @@ -173,7 +155,7 @@ MlirToRvsdg::convertOperation(
void
MlirToRvsdg::convertOmega(mlir::Operation & mlirOmega, rvsdg::region & rvsdgRegion)
{
// The Omega consists of a single region
// The Omega has a single region
JLM_ASSERT(mlirOmega.getRegions().size() == 1);
convertRegion(mlirOmega.getRegion(0), rvsdgRegion);
}
Expand All @@ -183,13 +165,11 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe
{
// Get the name of the function
auto functionNameAttribute = mlirLambda.getAttr(::llvm::StringRef("sym_name"));
JLM_ASSERT(functionNameAttribute != NULL);
mlir::StringAttr * functionName = static_cast<mlir::StringAttr *>(&functionNameAttribute);
JLM_ASSERT(functionNameAttribute != nullptr);
auto * functionName = static_cast<mlir::StringAttr *>(&functionNameAttribute);

// A lambda node has only the function signature as the result
JLM_ASSERT(mlirLambda.getNumResults() == 1);

// Get the MLIR function signature
auto result = mlirLambda.getResult(0).getType();

if (result.getTypeID() != mlir::rvsdg::LambdaRefType::getTypeID())
Expand All @@ -198,7 +178,7 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe
}

// Create the RVSDG function signature
mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast<mlir::rvsdg::LambdaRefType *>(&result);
auto * lambdaRefType = static_cast<mlir::rvsdg::LambdaRefType *>(&result);
std::vector<std::unique_ptr<rvsdg::type>> argumentTypes;
for (auto argumentType : lambdaRefType->getParameterTypes())
{
Expand All @@ -217,7 +197,6 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe
functionName->getValue().str(),
llvm::linkage::external_linkage);

// Get the region and convert it
JLM_ASSERT(mlirLambda.getRegions().size() == 1);
auto lambdaRegion = rvsdgLambda->subregion();
auto regionResults = convertRegion(mlirLambda.getRegion(0), *lambdaRegion);
Expand Down
Loading

0 comments on commit 1b13e77

Please sign in to comment.