Skip to content

Commit

Permalink
Covert mux normalizations to new normalization interface
Browse files Browse the repository at this point in the history
  • Loading branch information
phate committed Dec 19, 2024
1 parent 24cfa78 commit ba8e2a9
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 31 deletions.
22 changes: 22 additions & 0 deletions jlm/rvsdg/statemux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <jlm/rvsdg/graph.hpp>
#include <jlm/rvsdg/statemux.hpp>

#include <optional>

namespace jlm::rvsdg
{

Expand Down Expand Up @@ -195,6 +197,26 @@ mux_normal_form::set_multiple_origin_reducible(bool enable)
graph()->mark_denormalized();
}

std::optional<std::vector<rvsdg::output *>>
NormalizeMuxMux(const mux_op & operation, const std::vector<rvsdg::output *> & operands)
{
if (auto muxNode = is_mux_mux_reducible(operands))
return perform_mux_mux_reduction(operation, muxNode, operands);

return std::nullopt;
}

std::optional<std::vector<rvsdg::output *>>
NormalizeMuxDuplicateOperands(
const mux_op & operation,
const std::vector<rvsdg::output *> & operands)
{
if (is_multiple_origin_reducible(operands))
return perform_multiple_origin_reduction(operation, operands);

return std::nullopt;
}

}

namespace
Expand Down
38 changes: 38 additions & 0 deletions jlm/rvsdg/statemux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <jlm/rvsdg/simple-node.hpp>
#include <jlm/rvsdg/simple-normal-form.hpp>

#include <optional>

namespace jlm::rvsdg
{

Expand Down Expand Up @@ -126,6 +128,42 @@ create_state_split(
return create_state_mux(std::move(type), { operand }, nresults);
}

/**
* \brief Merges multiple mux operations into a single operation.
*
* so1 = mux_op si1 si2
* so2 = mux_op si3 si4
* so3 = mux_op so1 so2
* =>
* so3 = mux_op si1 si2 si3 si4
*
* @param operation The mux operation on which the transformation is performed.
* @param operands The operands of the mux node.
*
* @return If the normalization could be applied, then the results of the mux operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeMuxMux(const mux_op & operation, const std::vector<rvsdg::output *> & operands);

/**
* \brief Remove duplicated operands
*
* so1 = mux_op si1 si1 si2 si1 si3
* =>
* so1 = mux_op si1 si2 si3
*
* @param operation The mux operation on which the transformation is performed.
* @param operands The operands of the mux node.
*
* @return If the normalization could be applied, then the results of the mux operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeMuxDuplicateOperands(
const mux_op & operation,
const std::vector<rvsdg::output *> & operands);

}

#endif
72 changes: 41 additions & 31 deletions tests/jlm/rvsdg/test-statemux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,85 +8,95 @@
#include "test-registry.hpp"
#include "test-types.hpp"

#include <jlm/rvsdg/NodeNormalization.hpp>
#include <jlm/rvsdg/statemux.hpp>
#include <jlm/rvsdg/view.hpp>

static void
test_mux_mux_reduction()
#include <cassert>

static int
MuxMuxReduction()
{
using namespace jlm::rvsdg;

auto st = jlm::tests::statetype::Create();
auto stateType = jlm::tests::statetype::Create();

// Arrange
Graph graph;
auto nf = graph.node_normal_form(typeid(jlm::rvsdg::mux_op));
auto mnf = static_cast<jlm::rvsdg::mux_normal_form *>(nf);
mnf->set_mutable(false);
mnf->set_mux_mux_reducible(false);

auto x = &jlm::tests::GraphImport::Create(graph, st, "x");
auto y = &jlm::tests::GraphImport::Create(graph, st, "y");
auto z = &jlm::tests::GraphImport::Create(graph, st, "z");
auto x = &jlm::tests::GraphImport::Create(graph, stateType, "x");
auto y = &jlm::tests::GraphImport::Create(graph, stateType, "y");
auto z = &jlm::tests::GraphImport::Create(graph, stateType, "z");

auto mux1 = jlm::rvsdg::create_state_merge(st, { x, y });
auto mux2 = jlm::rvsdg::create_state_split(st, z, 2);
auto mux3 = jlm::rvsdg::create_state_merge(st, { mux1, mux2[0], mux2[1], z });
auto mux1 = jlm::rvsdg::create_state_merge(stateType, { x, y });
auto mux2 = jlm::rvsdg::create_state_split(stateType, z, 2);
auto mux3 = jlm::rvsdg::create_state_merge(stateType, { mux1, mux2[0], mux2[1], z });

auto & ex = jlm::tests::GraphExport::Create(*mux3, "m");

// jlm::rvsdg::view(graph.root(), stdout);
view(graph.root(), stdout);

mnf->set_mutable(true);
mnf->set_mux_mux_reducible(true);
graph.normalize();
graph.prune();
// Act
bool success = false;
do
{
auto muxNode = output::GetNode(*ex.origin());
success = ReduceNode<mux_op>(NormalizeMuxMux, *muxNode);
} while (success);

// jlm::rvsdg::view(graph.root(), stdout);
view(graph.root(), stdout);

// Assert
auto node = output::GetNode(*ex.origin());
assert(node->ninputs() == 4);
assert(node->input(0)->origin() == x);
assert(node->input(1)->origin() == y);
assert(node->input(2)->origin() == z);
assert(node->input(3)->origin() == z);

return 0;
}

static void
test_multiple_origin_reduction()
JLM_UNIT_TEST_REGISTER("jlm/rvsdg/test-statemux-MuxMuxReduction", MuxMuxReduction)

static int
DuplicateOperandReduction()
{
using namespace jlm::rvsdg;

auto st = jlm::tests::statetype::Create();
auto stateType = jlm::tests::statetype::Create();

// Arrange
Graph graph;
auto nf = graph.node_normal_form(typeid(jlm::rvsdg::mux_op));
auto mnf = static_cast<jlm::rvsdg::mux_normal_form *>(nf);
mnf->set_mutable(false);
mnf->set_multiple_origin_reducible(false);

auto x = &jlm::tests::GraphImport::Create(graph, st, "x");
auto mux1 = jlm::rvsdg::create_state_merge(st, { x, x });
auto x = &jlm::tests::GraphImport::Create(graph, stateType, "x");
auto mux1 = jlm::rvsdg::create_state_merge(stateType, { x, x });
auto & ex = jlm::tests::GraphExport::Create(*mux1, "m");

view(graph.root(), stdout);

mnf->set_mutable(true);
mnf->set_multiple_origin_reducible(true);
graph.normalize();
// Act
auto muxNode = output::GetNode(*ex.origin());
auto success = ReduceNode<mux_op>(NormalizeMuxDuplicateOperands, *muxNode);
graph.prune();

view(graph.root(), stdout);

// Assert
assert(success);
assert(output::GetNode(*ex.origin())->ninputs() == 1);
}

static int
test_main(void)
{
test_mux_mux_reduction();
test_multiple_origin_reduction();

return 0;
}

JLM_UNIT_TEST_REGISTER("jlm/rvsdg/test-statemux", test_main)
JLM_UNIT_TEST_REGISTER(
"jlm/rvsdg/test-statemux-DuplicateOperandReduction",
DuplicateOperandReduction)

0 comments on commit ba8e2a9

Please sign in to comment.