From 0ac4805788532968df3ee816f2db50090561252e Mon Sep 17 00:00:00 2001 From: Nico Reissmann Date: Sat, 21 Dec 2024 10:23:13 +0100 Subject: [PATCH] Convert binary operation normalizations to new normalization interface (#693) --- jlm/rvsdg/binary.cpp | 63 ++++++++ jlm/rvsdg/binary.hpp | 39 +++++ tests/jlm/rvsdg/test-binary.cpp | 267 +++++++++++++++++++++++++++++++- 3 files changed, 364 insertions(+), 5 deletions(-) diff --git a/jlm/rvsdg/binary.cpp b/jlm/rvsdg/binary.cpp index 67407691e..60c961e5d 100644 --- a/jlm/rvsdg/binary.cpp +++ b/jlm/rvsdg/binary.cpp @@ -319,6 +319,69 @@ binary_op::flags() const noexcept return jlm::rvsdg::binary_op::flags::none; } +std::optional> +FlattenAssociativeBinaryOperation( + const binary_op & operation, + const std::vector & operands) +{ + JLM_ASSERT(!operands.empty()); + auto region = operands[0]->region(); + + if (!operation.is_associative()) + { + return std::nullopt; + } + + auto newOperands = base::detail::associative_flatten( + operands, + [&operation](rvsdg::output * operand) + { + auto node = TryGetOwnerNode(*operand); + if (node == nullptr) + return false; + + auto flattenedBinaryOperation = + dynamic_cast(&node->GetOperation()); + return node->GetOperation() == operation + || (flattenedBinaryOperation && flattenedBinaryOperation->bin_operation() == operation); + }); + + if (operands == newOperands) + { + JLM_ASSERT(newOperands.size() == 2); + return std::nullopt; + } + + JLM_ASSERT(newOperands.size() > 2); + auto flattenedBinaryOperation = + std::make_unique(operation, newOperands.size()); + return outputs(SimpleNode::create(region, *flattenedBinaryOperation, newOperands)); +} + +std::optional> +NormalizeBinaryOperation(const binary_op & operation, const std::vector & operands) +{ + JLM_ASSERT(!operands.empty()); + auto region = operands[0]->region(); + + auto newOperands = reduce_operands(operation, operands); + + if (newOperands == operands) + { + // The operands did not change, which means that none of the normalizations triggered. + return std::nullopt; + } + + if (newOperands.size() == 1) + { + // The operands could be reduced to a single value by applying constant folding. + return newOperands; + } + + JLM_ASSERT(newOperands.size() == 2); + return outputs(SimpleNode::create(region, operation, newOperands)); +} + /* flattened binary operator */ flattened_binary_op::~flattened_binary_op() noexcept diff --git a/jlm/rvsdg/binary.hpp b/jlm/rvsdg/binary.hpp index be2ef06c2..07f0fbe74 100644 --- a/jlm/rvsdg/binary.hpp +++ b/jlm/rvsdg/binary.hpp @@ -12,6 +12,8 @@ #include #include +#include + namespace jlm::rvsdg { @@ -167,6 +169,43 @@ class binary_op : public SimpleOperation } }; +/** + * \brief Flattens a cascade of the same binary operations into a single flattened binary operation. + * + * o1 = binaryNode i1 i2 + * o2 = binaryNode o1 i3 + * => + * o2 = flattenedBinaryNode i1 i2 i3 + * + * \pre The binary operation must be associative. + * + * @param operation The binary operation on which the transformation is performed. + * @param operands The operands of the binary node. + * @return If the normalization could be applied, then the results of the binary operation after + * the transformation. Otherwise, std::nullopt. + */ +std::optional> +FlattenAssociativeBinaryOperation( + const binary_op & operation, + const std::vector & operands); + +/** + * \brief Applies the reductions implemented in the binary operations reduction functions. + * + * @param operation The binary operation on which the transformation is performed. + * @param operands The operands of the binary node. + * + * @return If the normalization could be applied, then the results of the binary operation after + * the transformation. Otherwise, std::nullopt. + * + * \see binary_op::can_reduce_operand_pair() + * \see binary_op::reduce_operand_pair() + */ +std::optional> +NormalizeBinaryOperation( + const binary_op & operation, + const std::vector & operands); + class flattened_binary_op final : public SimpleOperation { public: diff --git a/tests/jlm/rvsdg/test-binary.cpp b/tests/jlm/rvsdg/test-binary.cpp index eee0cd3ff..64ba71182 100644 --- a/tests/jlm/rvsdg/test-binary.cpp +++ b/tests/jlm/rvsdg/test-binary.cpp @@ -7,10 +7,80 @@ #include "test-registry.hpp" #include "test-types.hpp" +#include #include -static void -test_flattened_binary_reduction() +class BinaryOperation final : public jlm::rvsdg::binary_op +{ +public: + BinaryOperation( + const std::shared_ptr operandType, + const std::shared_ptr resultType, + const enum jlm::rvsdg::binary_op::flags & flags) + : jlm::rvsdg::binary_op({ operandType, operandType }, resultType), + Flags_(flags) + {} + + jlm::rvsdg::binop_reduction_path_t + can_reduce_operand_pair(const jlm::rvsdg::output * operand1, const jlm::rvsdg::output * operand2) + const noexcept override + { + auto n1 = jlm::rvsdg::TryGetOwnerNode(*operand1); + auto n2 = jlm::rvsdg::TryGetOwnerNode(*operand2); + + if (jlm::rvsdg::is(n1) && jlm::rvsdg::is(n2)) + { + return 1; + } + + return 0; + } + + jlm::rvsdg::output * + reduce_operand_pair( + jlm::rvsdg::unop_reduction_path_t path, + jlm::rvsdg::output * op1, + jlm::rvsdg::output * op2) const override + { + + if (path == 1) + { + return op2; + } + + return nullptr; + } + + [[nodiscard]] enum jlm::rvsdg::binary_op::flags + flags() const noexcept override + { + return Flags_; + } + + bool + operator==(const Operation & other) const noexcept override + { + JLM_UNREACHABLE("Not implemented."); + } + + [[nodiscard]] std::string + debug_string() const override + { + return "BinaryOperation"; + } + + [[nodiscard]] std::unique_ptr + copy() const override + { + return std::make_unique(this->argument(0), this->result(0), Flags_); + } + +private: + enum jlm::rvsdg::binary_op::flags Flags_; +}; + +static int +FlattenedBinaryReduction() { using namespace jlm::rvsdg; @@ -84,14 +154,201 @@ test_flattened_binary_reduction() auto node2 = output::GetNode(*node1->input(0)->origin()); assert(is(node2)); } + + return 0; +} + +JLM_UNIT_TEST_REGISTER("jlm/rvsdg/test-binary-FlattenedBinaryReduction", FlattenedBinaryReduction) + +static int +FlattenAssociativeBinaryOperation_NotAssociativeBinary() +{ + using namespace jlm::rvsdg; + + // Arrange + auto valueType = jlm::tests::valuetype::Create(); + + Graph graph; + auto i0 = &jlm::tests::GraphImport::Create(graph, valueType, "i0"); + auto i1 = &jlm::tests::GraphImport::Create(graph, valueType, "i1"); + auto i2 = &jlm::tests::GraphImport::Create(graph, valueType, "i2"); + + jlm::tests::binary_op binaryOperation(valueType, valueType, binary_op::flags::none); + auto o1 = SimpleNode::create(graph.root(), binaryOperation, { i0, i1 }); + auto o2 = SimpleNode::create(graph.root(), binaryOperation, { o1->output(0), i2 }); + + auto & ex = jlm::tests::GraphExport::Create(*o2->output(0), "o2"); + + jlm::rvsdg::view(graph, stdout); + + // Act + auto node = TryGetOwnerNode(*ex.origin()); + auto success = ReduceNode(FlattenAssociativeBinaryOperation, *node); + + jlm::rvsdg::view(graph, stdout); + + // Assert + assert(success == false); + assert(TryGetOwnerNode(*ex.origin()) == node); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/rvsdg/test-binary-FlattenAssociatedBinaryOperation_NotAssociativeBinary", + FlattenAssociativeBinaryOperation_NotAssociativeBinary) + +static int +FlattenAssociativeBinaryOperation_NoNewOperands() +{ + using namespace jlm::rvsdg; + + // Arrange + auto valueType = jlm::tests::valuetype::Create(); + + Graph graph; + auto i0 = &jlm::tests::GraphImport::Create(graph, valueType, "i0"); + auto i1 = &jlm::tests::GraphImport::Create(graph, valueType, "i1"); + + jlm::tests::unary_op unaryOperation(valueType, valueType); + jlm::tests::binary_op binaryOperation(valueType, valueType, binary_op::flags::associative); + auto u1 = SimpleNode::create(graph.root(), unaryOperation, { i0 }); + auto u2 = SimpleNode::create(graph.root(), unaryOperation, { i1 }); + auto b2 = SimpleNode::create(graph.root(), binaryOperation, { u1->output(0), u2->output(0) }); + + auto & ex = jlm::tests::GraphExport::Create(*b2->output(0), "o2"); + + jlm::rvsdg::view(graph, stdout); + + // Act + auto node = TryGetOwnerNode(*ex.origin()); + auto success = ReduceNode(FlattenAssociativeBinaryOperation, *node); + + jlm::rvsdg::view(graph, stdout); + + // Assert + assert(success == false); + assert(TryGetOwnerNode(*ex.origin()) == node); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/rvsdg/test-binary-FlattenAssociatedBinaryOperation_NoNewOperands", + FlattenAssociativeBinaryOperation_NoNewOperands) + +static int +FlattenAssociativeBinaryOperation_Success() +{ + using namespace jlm::rvsdg; + + // Arrange + auto valueType = jlm::tests::valuetype::Create(); + + Graph graph; + auto i0 = &jlm::tests::GraphImport::Create(graph, valueType, "i0"); + auto i1 = &jlm::tests::GraphImport::Create(graph, valueType, "i1"); + auto i2 = &jlm::tests::GraphImport::Create(graph, valueType, "i2"); + + jlm::tests::binary_op binaryOperation(valueType, valueType, binary_op::flags::associative); + auto o1 = SimpleNode::create(graph.root(), binaryOperation, { i0, i1 }); + auto o2 = SimpleNode::create(graph.root(), binaryOperation, { o1->output(0), i2 }); + + auto & ex = jlm::tests::GraphExport::Create(*o2->output(0), "o2"); + + jlm::rvsdg::view(graph, stdout); + + // Act + auto node = TryGetOwnerNode(*ex.origin()); + auto success = ReduceNode(FlattenAssociativeBinaryOperation, *node); + + jlm::rvsdg::view(graph, stdout); + + // Assert + assert(success); + auto flattenedBinaryNode = TryGetOwnerNode(*ex.origin()); + assert(is(flattenedBinaryNode)); + assert(flattenedBinaryNode->ninputs() == 3); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/rvsdg/test-binary-FlattenAssociatedBinaryOperation_Success", + FlattenAssociativeBinaryOperation_Success) + +static int +NormalizeBinaryOperation_NoNewOperands() +{ + using namespace jlm::rvsdg; + + // Arrange + auto valueType = jlm::tests::valuetype::Create(); + + Graph graph; + auto i0 = &jlm::tests::GraphImport::Create(graph, valueType, "i0"); + auto i1 = &jlm::tests::GraphImport::Create(graph, valueType, "i1"); + + jlm::tests::binary_op binaryOperation(valueType, valueType, binary_op::flags::associative); + auto o1 = SimpleNode::create(graph.root(), binaryOperation, { i0, i1 }); + + auto & ex = jlm::tests::GraphExport::Create(*o1->output(0), "o2"); + + jlm::rvsdg::view(graph, stdout); + + // Act + auto node = TryGetOwnerNode(*ex.origin()); + auto success = ReduceNode(NormalizeBinaryOperation, *node); + + jlm::rvsdg::view(graph, stdout); + + // Assert + assert(success == false); + + return 0; } +JLM_UNIT_TEST_REGISTER( + "jlm/rvsdg/test-binary-NormalizeBinaryOperation_NoNewOperands", + NormalizeBinaryOperation_NoNewOperands) + static int -test_main() +NormalizeBinaryOperation_SingleOperand() { - test_flattened_binary_reduction(); + using namespace jlm::rvsdg; + + // Arrange + auto valueType = jlm::tests::valuetype::Create(); + + jlm::tests::unary_op unaryOperation(valueType, valueType); + BinaryOperation binaryOperation(valueType, valueType, binary_op::flags::none); + + Graph graph; + auto s0 = &jlm::tests::GraphImport::Create(graph, valueType, "s0"); + auto s1 = &jlm::tests::GraphImport::Create(graph, valueType, "s1"); + + auto u1 = SimpleNode::create(graph.root(), unaryOperation, { s0 }); + auto u2 = SimpleNode::create(graph.root(), unaryOperation, { s1 }); + + auto o1 = SimpleNode::create(graph.root(), binaryOperation, { u1->output(0), u2->output(0) }); + + auto & ex = jlm::tests::GraphExport::Create(*o1->output(0), "ex"); + + jlm::rvsdg::view(graph, stdout); + + // Act + auto node = TryGetOwnerNode(*ex.origin()); + auto success = ReduceNode(NormalizeBinaryOperation, *node); + + jlm::rvsdg::view(graph, stdout); + + // Assert + assert(success == true); + assert(ex.origin() == u2->output(0)); return 0; } -JLM_UNIT_TEST_REGISTER("jlm/rvsdg/test-binary", test_main) +JLM_UNIT_TEST_REGISTER( + "jlm/rvsdg/test-binary-NormalizeBinaryOperation_SingleOperand", + NormalizeBinaryOperation_SingleOperand)