From b79656e3f3299c63d4220463b376c9f81ffdf5b2 Mon Sep 17 00:00:00 2001 From: Nico Reissmann Date: Tue, 7 Jan 2025 19:56:53 +0100 Subject: [PATCH] Add FlattenBitConcatOperation normalization --- jlm/rvsdg/bitstring/concat.cpp | 35 +++++++++++++++++++++++++ jlm/rvsdg/bitstring/concat.hpp | 5 ++++ tests/jlm/rvsdg/bitstring/bitstring.cpp | 11 +++++++- 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/jlm/rvsdg/bitstring/concat.cpp b/jlm/rvsdg/bitstring/concat.cpp index 156b0e201..4187078b2 100644 --- a/jlm/rvsdg/bitstring/concat.cpp +++ b/jlm/rvsdg/bitstring/concat.cpp @@ -368,4 +368,39 @@ bitconcat_op::copy() const return std::make_unique(*this); } +static std::vector> +GetTypesFromOperands(const std::vector & args) +{ + std::vector> types; + for (const auto arg : args) + { + types.push_back(std::dynamic_pointer_cast(arg->Type())); + } + return types; +} + +std::optional> +FlattenBitConcatOperation(const bitconcat_op &, const std::vector & operands) +{ + JLM_ASSERT(!operands.empty()); + + const auto newOperands = base::detail::associative_flatten( + operands, + [](jlm::rvsdg::output * arg) + { + // FIXME: switch to comparing operator, not just typeid, after + // converting "concat" to not be a binary operator anymore + return is(output::GetNode(*arg)); + }); + + if (operands == newOperands) + { + JLM_ASSERT(newOperands.size() == 2); + return std::nullopt; + } + + JLM_ASSERT(newOperands.size() > 2); + return outputs(&CreateOpNode(newOperands, GetTypesFromOperands(newOperands))); +} + } diff --git a/jlm/rvsdg/bitstring/concat.hpp b/jlm/rvsdg/bitstring/concat.hpp index ae5ccb20d..df1198350 100644 --- a/jlm/rvsdg/bitstring/concat.hpp +++ b/jlm/rvsdg/bitstring/concat.hpp @@ -55,6 +55,11 @@ class bitconcat_op final : public BinaryOperation jlm::rvsdg::output * bitconcat(const std::vector & operands); +std::optional> +FlattenBitConcatOperation( + const bitconcat_op & operation, + const std::vector & operands); + } #endif diff --git a/tests/jlm/rvsdg/bitstring/bitstring.cpp b/tests/jlm/rvsdg/bitstring/bitstring.cpp index f2f589c07..6736be078 100644 --- a/tests/jlm/rvsdg/bitstring/bitstring.cpp +++ b/tests/jlm/rvsdg/bitstring/bitstring.cpp @@ -8,6 +8,7 @@ #include #include +#include #include static int @@ -1176,8 +1177,10 @@ ConcatFlattening() { using namespace jlm::rvsdg; - // Arrange & Act + // Arrange Graph graph; + const auto nf = graph.GetNodeNormalForm(typeid(bitconcat_op)); + nf->set_mutable(false); auto x = &jlm::tests::GraphImport::Create(graph, bittype::Create(8), "x"); auto y = &jlm::tests::GraphImport::Create(graph, bittype::Create(8), "y"); @@ -1189,6 +1192,12 @@ ConcatFlattening() auto & ex = jlm::tests::GraphExport::Create(*concatResult2, "dummy"); view(graph, stdout); + // Act + const auto concatNode = output::GetNode(*ex.origin()); + ReduceNode(FlattenBitConcatOperation, *concatNode); + + view(graph, stdout); + // Assert auto node = output::GetNode(*ex.origin()); assert(dynamic_cast(&node->GetOperation()));