From 1b13e77ae74181d29587d66972c4341c7bf1d2b8 Mon Sep 17 00:00:00 2001 From: Magnus Sjalander Date: Tue, 16 Jan 2024 17:20:24 +0100 Subject: [PATCH] Cleaned up code and added documentation --- jlm/mlir/backend/RvsdgToMlir.cpp | 65 +--------------- jlm/mlir/backend/RvsdgToMlir.hpp | 52 ++++++++++++- jlm/mlir/frontend/MlirToRvsdg.cpp | 39 +++------- tests/jlm/mlir/backend/TestRvsdgToMlir.cpp | 27 +++---- tests/jlm/mlir/frontend/TestMlirToRvsdg.cpp | 84 +++++++++------------ 5 files changed, 114 insertions(+), 153 deletions(-) diff --git a/jlm/mlir/backend/RvsdgToMlir.cpp b/jlm/mlir/backend/RvsdgToMlir.cpp index 48eca0bf7..600a8b658 100644 --- a/jlm/mlir/backend/RvsdgToMlir.cpp +++ b/jlm/mlir/backend/RvsdgToMlir.cpp @@ -19,13 +19,11 @@ namespace jlm::rvsdgmlir void RvsdgToMlir::print(mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath) { - // Verify the module if (failed(mlir::verify(omega))) { omega.emitError("module verification error"); throw util::error("Verification of RVSDG-MLIR failed"); } - // Print the module if (filePath == "") { ::llvm::raw_os_ostream os(std::cout); @@ -49,18 +47,13 @@ RvsdgToMlir::convertModule(const llvm::RvsdgModule & rvsdgModule) mlir::rvsdg::OmegaNode RvsdgToMlir::convertOmega(const rvsdg::graph & graph) { - // Create the MLIR omega node auto omega = Builder_->create(Builder_->getUnknownLoc()); - - // Create a block for the region mlir::Region & region = omega.getRegion(); auto & omegaBlock = region.emplaceBlock(); - // Convert the region of the omega auto subregion = graph.root(); - ::llvm::SmallVector regionResults = convertSubregion(*subregion, omegaBlock); + ::llvm::SmallVector regionResults = convertRegion(*subregion, omegaBlock); - // Handle the result of the omega auto omegaResult = Builder_->create(Builder_->getUnknownLoc(), regionResults); omegaBlock.push_back(omegaResult); @@ -69,9 +62,8 @@ RvsdgToMlir::convertOmega(const rvsdg::graph & graph) } ::llvm::SmallVector -RvsdgToMlir::convertSubregion(rvsdg::region & region, mlir::Block & block) +RvsdgToMlir::convertRegion(rvsdg::region & region, mlir::Block & block) { - // Handle arguments of the region for (size_t i = 0; i < region.narguments(); ++i) { auto type = convertType(region.argument(i)->type()); @@ -92,15 +84,11 @@ RvsdgToMlir::convertSubregion(rvsdg::region & region, mlir::Block & block) nodes[rvsdgNode] = convertNode(*rvsdgNode, block); } - // Handle results of the region ::llvm::SmallVector results; for (size_t i = 0; i < region.nresults(); ++i) { - // Get the result of the RVSDG region auto result = region.result(i); - // Get the output of the RVSDG node driving the result auto output = result->origin(); - // Get the RVSDG node that generates the output rvsdg::node * outputNode = rvsdg::node_output::node(output); if (outputNode == nullptr) { @@ -127,16 +115,6 @@ RvsdgToMlir::convertNode(const rvsdg::node & node, mlir::Block & block) else if (auto lambda = dynamic_cast(&node)) { return convertLambda(*lambda, block); - /* - } else if (auto gamma = dynamic_cast(&node)) { - convertGamma(*gamma, block); - } else if (auto theta = dynamic_cast(&node)) { - convertTheta(*theta, block); - } else if (auto delta = dynamic_cast(&node)) { - convertDelta(*delta, block); - } else if (auto phi = dynamic_cast(&node)) { - convertPhi(*phi, block); - */ } else { @@ -167,8 +145,6 @@ RvsdgToMlir::convertSimpleNode(const rvsdg::simple_node & node, mlir::Block & bl mlir::Value RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & block) { - - // Handle function arguments ::llvm::SmallVector arguments; for (size_t i = 0; i < lambdaNode.nfctarguments(); ++i) { @@ -176,7 +152,6 @@ RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & } ::llvm::ArrayRef argumentsArray(arguments); - // Handle function results ::llvm::SmallVector results; for (size_t i = 0; i < lambdaNode.nfctresults(); ++i) { @@ -184,30 +159,13 @@ RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & } ::llvm::ArrayRef resultsArray(results); - /* - // Context arguments - for (size_t i = 0; i < node->ncvarguments(); ++i) { - // s << print_input_origin(node.cvargument(i)->input()) << ": " << - print_type(&ln.cvargument(i)->type()); throw util::error("Context arguments in convertLambda() - has not been implemented"); - } - */ - // TODO - // Consider replacing the lambda ref creation with - // mlir::rvsdg::LambdaRefTyp::get(); - // static LambdaRefType get(::mlir::MLIRContext *context, ::llvm::ArrayRef - // parameterTypes, ::llvm::ArrayRef returnTypes); - - // LambdaNodes return a LambdaRefType ::llvm::SmallVector lambdaRef; auto refType = Builder_->getType<::mlir::rvsdg::LambdaRefType>(argumentsArray, resultsArray); lambdaRef.push_back(refType); - // TODO - // Add the inputs to the function ::llvm::SmallVector inputs; - // Add function attributes + // Add function attributes, e.g., the function name ::llvm::SmallVector attributes; auto attributeName = Builder_->getStringAttr("sym_name"); auto attributeValue = Builder_->getStringAttr(lambdaNode.name()); @@ -215,7 +173,6 @@ RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & attributes.push_back(symbolName); ::llvm::ArrayRef<::mlir::NamedAttribute> attributesRef(attributes); - // Create the lambda node and add it to the region/block it resides in auto lambda = Builder_->create( Builder_->getUnknownLoc(), lambdaRef, @@ -223,12 +180,10 @@ RvsdgToMlir::convertLambda(const llvm::lambda::node & lambdaNode, mlir::Block & attributesRef); block.push_back(lambda); - // Create a block for the region mlir::Region & region = lambda.getRegion(); auto & lambdaBlock = region.emplaceBlock(); - // Convert the region and get all the results generated by the region - auto regionResults = convertSubregion(*lambdaNode.subregion(), lambdaBlock); + auto regionResults = convertRegion(*lambdaNode.subregion(), lambdaBlock); auto lambdaResult = Builder_->create(Builder_->getUnknownLoc(), regionResults); lambdaBlock.push_back(lambdaResult); @@ -254,18 +209,6 @@ RvsdgToMlir::convertType(const rvsdg::type & type) else if (dynamic_cast(&type)) { return Builder_->getType<::mlir::rvsdg::MemStateEdgeType>(); - /* - } else if (auto varargType =dynamic_cast(&type)) { - s << "!jlm.varargList"; - } else if (auto pointerType = dynamic_cast(&type)){ - s << print_pointer_type(pointer_type); - } else if (auto arrayType = dynamic_cast(&type)){ - s << print_array_type(array_type); - } else if (auto structType = dynamic_cast(&type)){ - s << print_struct_type(struct_type); - } else if (auto controlType = dynamic_cast(&type)){ - s << "!rvsdg.ctrl<" << control_type->nalternatives() << ">"; - */ } else { diff --git a/jlm/mlir/backend/RvsdgToMlir.hpp b/jlm/mlir/backend/RvsdgToMlir.hpp index 8e2da5c65..382760359 100644 --- a/jlm/mlir/backend/RvsdgToMlir.hpp +++ b/jlm/mlir/backend/RvsdgToMlir.hpp @@ -46,28 +46,76 @@ class RvsdgToMlir final RvsdgToMlir & operator=(RvsdgToMlir &&) = delete; + /** + * Prints MLIR RVSDG to a file. + * \param omega The MLIR RVSDG Omega node to be printed. + * \param filePath The path to the file to print the MLIR to. + */ void print(mlir::rvsdg::OmegaNode & omega, const util::filepath & filePath); + /** + * Converts an RVSDG module to MLIR RVSDG. + * \param rvsdgModule The RVSDG module to be converted. + * \return An MLIR RVSDG OmegaNode containing the whole graph of the rvsdgModule. It is + * the responsibility of the caller to call ->destroy() on the returned omega, once it is no + * longer needed. + */ mlir::rvsdg::OmegaNode convertModule(const llvm::RvsdgModule & rvsdgModule); private: + /** + * Converts an omega and all nodes in its (sub)region(s) to an MLIR RVSDG OmegaNode. + * \param graph The root RVSDG graph. + * \return An MLIR RVSDG OmegaNode. + */ mlir::rvsdg::OmegaNode convertOmega(const rvsdg::graph & graph); + /** + * Converts all nodes in an RVSDG region. Conversion of structural nodes cause their regions to + * also be converted. + * \param region The RVSDG region to be converted + * \param block The MLIR RVSDG block that corresponds to this RVSDG region, and to which + * converted nodes are insterted. + * \return A list of outputs of the converted region/block. + */ ::llvm::SmallVector - convertSubregion(rvsdg::region & region, mlir::Block & block); - + convertRegion(rvsdg::region & region, mlir::Block & block); + + /** + * Converts an RVSDG node to an MLIR RVSDG operation + * \param node The RVSDG node to be converted + * \param block The MLIR RVSDG block to insert the converted node. + * \return The converted MLIR RVSDG operation. + */ mlir::Value convertNode(const rvsdg::node & node, mlir::Block & block); + /** + * Converts an RVSDG simple_node to an MLIR RVSDG operation + * \param node The RVSDG node to be converted + * \param block The MLIR RVSDG block to insert the converted node. + * \return The converted MLIR RVSDG operation. + */ mlir::Value convertSimpleNode(const rvsdg::simple_node & node, mlir::Block & block); + /** + * Converts an RVSDG lambda node to an MLIR RVSDG LambdaNode.. + * \param node The RVSDG lambda node to be converted + * \param block The MLIR RVSDG block to insert the lambda node. + * \return The converted MLIR RVSDG LambdaNode. + */ mlir::Value convertLambda(const llvm::lambda::node & node, mlir::Block & block); + /** + * Converts an RVSDG type to and MLIR RVSDG type + * \param type The RVSDG type to be converted. + * \result The corresponding MLIR RVSDG type. + */ mlir::Type convertType(const rvsdg::type & type); diff --git a/jlm/mlir/frontend/MlirToRvsdg.cpp b/jlm/mlir/frontend/MlirToRvsdg.cpp index c11a4ef76..49a733312 100644 --- a/jlm/mlir/frontend/MlirToRvsdg.cpp +++ b/jlm/mlir/frontend/MlirToRvsdg.cpp @@ -10,13 +10,6 @@ #include #include -// #include -// #include -// #include -// #include -// #include -// #include - #include "llvm/Support/raw_os_ostream.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" @@ -28,11 +21,8 @@ namespace jlm::mlirrvsdg std::unique_ptr MlirToRvsdg::readRvsdgMlir(const util::filepath & filePath) { - // Configer the parser mlir::ParserConfig config = mlir::ParserConfig(Context_.get()); - // Variable for storing the result std::unique_ptr block = std::make_unique(); - // Read the input file auto result = mlir::parseSourceFile(filePath.to_str(), block.get(), config); if (result.failed()) { @@ -44,17 +34,13 @@ MlirToRvsdg::readRvsdgMlir(const util::filepath & filePath) std::unique_ptr MlirToRvsdg::convertMlir(std::unique_ptr & block) { - // Create RVSDG module std::string dataLayout; std::string targetTriple; util::filepath sourceFileName(""); auto rvsdgModule = llvm::RvsdgModule::Create(sourceFileName, targetTriple, dataLayout); - - // Get the root region auto & graph = rvsdgModule->Rvsdg(); auto root = graph.root(); - // Convert the MLIR into an RVSDG graph convertBlock(*block.get(), *root); return rvsdgModule; @@ -72,7 +58,6 @@ MlirToRvsdg::convertRegion(mlir::Region & region, rvsdg::region & rvsdgRegion) std::unique_ptr> MlirToRvsdg::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) { - // Transform the block such that operations are in topological order mlir::sortTopologically(&block); // Create an RVSDG node for each MLIR operation and store each pair in a @@ -80,7 +65,6 @@ MlirToRvsdg::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) std::unordered_map operations; for (mlir::Operation & mlirOp : block.getOperations()) { - // Get the inputs of the MLIR operation std::vector inputs; for (mlir::Value operand : mlirOp.getOperands()) { @@ -103,7 +87,7 @@ MlirToRvsdg::convertBlock(mlir::Block & block, rvsdg::region & rvsdgRegion) } } - // The results of the block are encoded in the terminator operation + // The results of the region/block are encoded in the terminator operation auto terminator = block.getTerminator(); std::unique_ptr> results = std::make_unique>(); @@ -134,8 +118,8 @@ MlirToRvsdg::convertOperation( if (mlirOperation.getName().getStringRef() == mlir::rvsdg::OmegaNode::getOperationName()) { convertOmega(mlirOperation, rvsdgRegion); - // Omega doesn't have a corresponding RVSDG node so we return NULL - return NULL; + // Omega doesn't have a corresponding RVSDG node so we return nullptr + return nullptr; } else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName()) { @@ -144,8 +128,6 @@ MlirToRvsdg::convertOperation( else if (mlirOperation.getName().getStringRef() == mlir::arith::ConstantIntOp::getOperationName()) { auto constant = static_cast(&mlirOperation); - - // Need the type to know the width of the constant auto type = constant.getType(); JLM_ASSERT(type.getTypeID() == mlir::IntegerType::getTypeID()); auto * integerType = static_cast(&type); @@ -156,12 +138,12 @@ MlirToRvsdg::convertOperation( else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::LambdaResult::getOperationName()) { // This is a terminating operation that doesn't have a corresponding RVSDG node - return NULL; + return nullptr; } else if (mlirOperation.getName().getStringRef() == mlir::rvsdg::OmegaResult::getOperationName()) { // This is a terminating operation that doesn't have a corresponding RVSDG node - return NULL; + return nullptr; } else { @@ -173,7 +155,7 @@ MlirToRvsdg::convertOperation( void MlirToRvsdg::convertOmega(mlir::Operation & mlirOmega, rvsdg::region & rvsdgRegion) { - // The Omega consists of a single region + // The Omega has a single region JLM_ASSERT(mlirOmega.getRegions().size() == 1); convertRegion(mlirOmega.getRegion(0), rvsdgRegion); } @@ -183,13 +165,11 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe { // Get the name of the function auto functionNameAttribute = mlirLambda.getAttr(::llvm::StringRef("sym_name")); - JLM_ASSERT(functionNameAttribute != NULL); - mlir::StringAttr * functionName = static_cast(&functionNameAttribute); + JLM_ASSERT(functionNameAttribute != nullptr); + auto * functionName = static_cast(&functionNameAttribute); // A lambda node has only the function signature as the result JLM_ASSERT(mlirLambda.getNumResults() == 1); - - // Get the MLIR function signature auto result = mlirLambda.getResult(0).getType(); if (result.getTypeID() != mlir::rvsdg::LambdaRefType::getTypeID()) @@ -198,7 +178,7 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe } // Create the RVSDG function signature - mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast(&result); + auto * lambdaRefType = static_cast(&result); std::vector> argumentTypes; for (auto argumentType : lambdaRefType->getParameterTypes()) { @@ -217,7 +197,6 @@ MlirToRvsdg::convertLambda(mlir::Operation & mlirLambda, rvsdg::region & rvsdgRe functionName->getValue().str(), llvm::linkage::external_linkage); - // Get the region and convert it JLM_ASSERT(mlirLambda.getRegions().size() == 1); auto lambdaRegion = rvsdgLambda->subregion(); auto regionResults = convertRegion(mlirLambda.getRegion(0), *lambdaRegion); diff --git a/tests/jlm/mlir/backend/TestRvsdgToMlir.cpp b/tests/jlm/mlir/backend/TestRvsdgToMlir.cpp index 70b265d81..8cec7dd3b 100644 --- a/tests/jlm/mlir/backend/TestRvsdgToMlir.cpp +++ b/tests/jlm/mlir/backend/TestRvsdgToMlir.cpp @@ -14,6 +14,7 @@ static void TestLambda() { using namespace jlm::llvm; + using namespace mlir::rvsdg; auto rvsdgModule = RvsdgModule::Create(jlm::util::filepath(""), "", ""); auto graph = &rvsdgModule->Rvsdg(); @@ -45,41 +46,41 @@ TestLambda() auto omega = mlirgen.convertModule(*rvsdgModule); // Validate the generated MLIR - mlir::Region & omegaRegion = omega.getRegion(); + auto & omegaRegion = omega.getRegion(); assert(omegaRegion.getBlocks().size() == 1); - mlir::Block & omegaBlock = omegaRegion.front(); + auto & omegaBlock = omegaRegion.front(); // Lamda + terminating operation assert(omegaBlock.getOperations().size() == 2); auto & mlirLambda = omegaBlock.front(); - assert(mlirLambda.getName().getStringRef() == mlir::rvsdg::LambdaNode::getOperationName()); + assert(mlirLambda.getName().getStringRef() == LambdaNode::getOperationName()); // Verify function name auto functionNameAttribute = mlirLambda.getAttr(::llvm::StringRef("sym_name")); - mlir::StringAttr * functionName = static_cast(&functionNameAttribute); - std::string string = functionName->getValue().str(); + auto * functionName = static_cast(&functionNameAttribute); + auto string = functionName->getValue().str(); assert(string == "test"); // Verify function signature auto result = mlirLambda.getResult(0).getType(); - assert(result.getTypeID() == mlir::rvsdg::LambdaRefType::getTypeID()); - mlir::rvsdg::LambdaRefType * lambdaRefType = static_cast(&result); + assert(result.getTypeID() == LambdaRefType::getTypeID()); + auto * lambdaRefType = static_cast(&result); std::vector arguments; for (auto argumentType : lambdaRefType->getParameterTypes()) { arguments.push_back(argumentType); } - assert(arguments[0].getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID()); - assert(arguments[1].getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID()); - assert(arguments[2].getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID()); + assert(arguments[0].getTypeID() == IOStateEdgeType::getTypeID()); + assert(arguments[1].getTypeID() == MemStateEdgeType::getTypeID()); + assert(arguments[2].getTypeID() == LoopStateEdgeType::getTypeID()); std::vector results; for (auto returnType : lambdaRefType->getReturnTypes()) { results.push_back(returnType); } assert(results[0].getTypeID() == mlir::IntegerType::getTypeID()); - assert(results[1].getTypeID() == mlir::rvsdg::IOStateEdgeType::getTypeID()); - assert(results[2].getTypeID() == mlir::rvsdg::MemStateEdgeType::getTypeID()); - assert(results[3].getTypeID() == mlir::rvsdg::LoopStateEdgeType::getTypeID()); + assert(results[1].getTypeID() == IOStateEdgeType::getTypeID()); + assert(results[2].getTypeID() == MemStateEdgeType::getTypeID()); + assert(results[3].getTypeID() == LoopStateEdgeType::getTypeID()); auto & lambdaRegion = mlirLambda.getRegion(0); auto & lambdaBlock = lambdaRegion.front(); diff --git a/tests/jlm/mlir/frontend/TestMlirToRvsdg.cpp b/tests/jlm/mlir/frontend/TestMlirToRvsdg.cpp index 34465afaa..836771424 100644 --- a/tests/jlm/mlir/frontend/TestMlirToRvsdg.cpp +++ b/tests/jlm/mlir/frontend/TestMlirToRvsdg.cpp @@ -16,42 +16,39 @@ static void TestLambda() { { + using namespace mlir::rvsdg; + using namespace mlir::jlm; + + // Setup MLIR Context and load dialects auto context = std::make_unique(); - // Load the RVSDG dialect - context->getOrLoadDialect(); - // Load the JLM dialect - context->getOrLoadDialect(); - // Load the Arith dialect + context->getOrLoadDialect(); + context->getOrLoadDialect(); context->getOrLoadDialect(); auto builder = std::make_unique(context.get()); - // Create the MLIR omega node - mlir::rvsdg::OmegaNode omega = - builder->create(builder->getUnknownLoc()); - - // Create a block for the region as this is currently not done automatically - mlir::Region & omegaRegion = omega.getRegion(); - mlir::Block * omegaBlock = new mlir::Block; + auto omega = builder->create(builder->getUnknownLoc()); + auto & omegaRegion = omega.getRegion(); + auto * omegaBlock = new mlir::Block; omegaRegion.push_back(omegaBlock); // Handle function arguments ::llvm::SmallVector arguments; - arguments.push_back(builder->getType<::mlir::rvsdg::IOStateEdgeType>()); - arguments.push_back(builder->getType<::mlir::rvsdg::MemStateEdgeType>()); - arguments.push_back(builder->getType<::mlir::rvsdg::LoopStateEdgeType>()); + arguments.push_back(builder->getType()); + arguments.push_back(builder->getType()); + arguments.push_back(builder->getType()); ::llvm::ArrayRef argumentsArray(arguments); // Handle function results ::llvm::SmallVector results; results.push_back(builder->getIntegerType(32)); - results.push_back(builder->getType<::mlir::rvsdg::IOStateEdgeType>()); - results.push_back(builder->getType<::mlir::rvsdg::MemStateEdgeType>()); - results.push_back(builder->getType<::mlir::rvsdg::LoopStateEdgeType>()); + results.push_back(builder->getType()); + results.push_back(builder->getType()); + results.push_back(builder->getType()); ::llvm::ArrayRef resultsArray(results); // LambdaNodes return a LambdaRefType ::llvm::SmallVector lambdaRef; - auto refType = builder->getType<::mlir::rvsdg::LambdaRefType>(argumentsArray, resultsArray); + auto refType = builder->getType(argumentsArray, resultsArray); lambdaRef.push_back(refType); // Add function attributes @@ -62,32 +59,21 @@ TestLambda() attributes.push_back(symbolName); ::llvm::ArrayRef<::mlir::NamedAttribute> attributesRef(attributes); - // Add the inputs to the function + // Add inputs to the function ::llvm::SmallVector inputs; // Create the lambda node and add it to the region/block it resides in - auto lambda = builder->create( - builder->getUnknownLoc(), - lambdaRef, - inputs, - attributesRef); + auto lambda = + builder->create(builder->getUnknownLoc(), lambdaRef, inputs, attributesRef); omegaBlock->push_back(lambda); - - // Create a block for the region as this is not done automatically - mlir::Region & lambdaRegion = lambda.getRegion(); - mlir::Block * lambdaBlock = new mlir::Block; + auto & lambdaRegion = lambda.getRegion(); + auto * lambdaBlock = new mlir::Block; lambdaRegion.push_back(lambdaBlock); // Add arguments to the region - lambdaBlock->addArgument( - builder->getType<::mlir::rvsdg::IOStateEdgeType>(), - builder->getUnknownLoc()); - lambdaBlock->addArgument( - builder->getType<::mlir::rvsdg::MemStateEdgeType>(), - builder->getUnknownLoc()); - lambdaBlock->addArgument( - builder->getType<::mlir::rvsdg::LoopStateEdgeType>(), - builder->getUnknownLoc()); + lambdaBlock->addArgument(builder->getType(), builder->getUnknownLoc()); + lambdaBlock->addArgument(builder->getType(), builder->getUnknownLoc()); + lambdaBlock->addArgument(builder->getType(), builder->getUnknownLoc()); auto constOp = builder->create(builder->getUnknownLoc(), 1, 32); lambdaBlock->push_back(constOp); @@ -99,15 +85,13 @@ TestLambda() regionResults.push_back(lambdaBlock->getArgument(2)); // Handle the result of the lambda - auto lambdaResult = - builder->create(builder->getUnknownLoc(), regionResults); + auto lambdaResult = builder->create(builder->getUnknownLoc(), regionResults); lambdaBlock->push_back(lambdaResult); // Handle the result of the omega ::llvm::SmallVector omegaRegionResults; omegaRegionResults.push_back(lambda); - auto omegaResult = - builder->create(builder->getUnknownLoc(), omegaRegionResults); + auto omegaResult = builder->create(builder->getUnknownLoc(), omegaRegionResults); omegaBlock->push_back(omegaResult); std::unique_ptr rootBlock = std::make_unique(); @@ -116,14 +100,20 @@ TestLambda() // Convert the MLIR to RVSDG jlm::mlirrvsdg::MlirToRvsdg rvsdggen; auto rvsdgModule = rvsdggen.convertMlir(rootBlock); - auto & graph = rvsdgModule->Rvsdg(); auto region = graph.root(); - assert(region->nnodes() == 1); - assert(jlm::rvsdg::region::Contains(*region, false)); - assert(!jlm::rvsdg::region::Contains(*region, false)); - assert(jlm::rvsdg::region::Contains(*region, true)); + { + using namespace jlm::rvsdg; + + assert(region->nnodes() == 1); + auto convertedLambda = + jlm::util::AssertedCast(region->nodes.first()); + assert(is(convertedLambda)); + + assert(convertedLambda->subregion()->nnodes() == 1); + assert(is(convertedLambda->subregion()->nodes.first())); + } } }