Skip to content

Commit

Permalink
Convert store normalizations to new normalization interface (#688)
Browse files Browse the repository at this point in the history
  • Loading branch information
phate authored Dec 19, 2024
1 parent 0c35186 commit 24cfa78
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 52 deletions.
44 changes: 44 additions & 0 deletions jlm/llvm/ir/operators/Store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,50 @@ store_normal_form::set_multiple_origin_reducible(bool enable)
graph()->mark_denormalized();
}

std::optional<std::vector<rvsdg::output *>>
NormalizeStoreMux(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands)
{
if (is_store_mux_reducible(operands))
return perform_store_mux_reduction(operation, operands);

return std::nullopt;
}

std::optional<std::vector<rvsdg::output *>>
NormalizeStoreStore(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands)
{
if (is_store_store_reducible(operation, operands))
return perform_store_store_reduction(operation, operands);

return std::nullopt;
}

std::optional<std::vector<rvsdg::output *>>
NormalizeStoreAlloca(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands)
{
if (is_store_alloca_reducible(operands))
return perform_store_alloca_reduction(operation, operands);

return std::nullopt;
}

std::optional<std::vector<rvsdg::output *>>
NormalizeStoreDuplicateState(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands)
{
if (is_multiple_origin_reducible(operands))
return perform_multiple_origin_reduction(operation, operands);

return std::nullopt;
}

}

namespace
Expand Down
85 changes: 85 additions & 0 deletions jlm/llvm/ir/operators/Store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <jlm/rvsdg/simple-node.hpp>
#include <jlm/rvsdg/simple-normal-form.hpp>

#include <optional>

namespace jlm::llvm
{

Expand Down Expand Up @@ -569,6 +571,89 @@ class StoreVolatileNode final : public StoreNode
}
};

/**
* \brief Swaps a memory state merge operation and a store operation.
*
* sx1 = MemStateMerge si1 ... siM
* sl1 = StoreNonVolatile a v sx1
* =>
* sl1 ... slM = StoreNonVolatile a v si1 ... siM
* sx1 = MemStateMerge sl1 ... slM
*
* FIXME: The reduction can be generalized: A store node can have multiple operands from different
* merge nodes.
*
* @param operation The operation of the StoreNonVolatile node.
* @param operands The operands of the StoreNonVolatile node.
*
* @return If the normalization could be applied, then the results of the store operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeStoreMux(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands);

/**
* \brief Removes a duplicated store to the same address.
*
* so1 so2 = StoreNonVolatile a v1 si1 si2
* sx1 sx2 = StoreNonVolatile a v2 so1 so2
* =>
* sx1 sx2 = StoreNonVolatile a v2 si1 si2
*
* @param operation The operation of the StoreNonVolatile node.
* @param operands The operands of the StoreNonVolatile node.
*
* @return If the normalization could be applied, then the results of the store operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeStoreStore(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands);

/**
* \brief Removes unnecessary state from a store node when its address originates directly from an
* alloca node.
*
* a s = Alloca b
* so1 so2 = StoreNonVolatile a v s si1 si2
* ... = AnyOp so1 so2
* =>
* a s = Alloca b
* so1 = StoreNonVolatile a v s
* ... = AnyOp so1 so1
*
* @param operation The operation of the StoreNonVolatile node.
* @param operands The operands of the StoreNonVolatile node.
*
* @return If the normalization could be applied, then the results of the store operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeStoreAlloca(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands);

/**
* \brief Remove duplicated state operands
*
* so1 so2 so3 = StoreNonVolatile a v si1 si1 si1
* =>
* so1 = StoreNonVolatile a v si1
*
* @param operation The load operation on which the transformation is performed.
* @param operands The operands of the load node.
*
* @return If the normalization could be applied, then the results of the load operation after
* the transformation. Otherwise, std::nullopt.
*/
std::optional<std::vector<rvsdg::output *>>
NormalizeStoreDuplicateState(
const StoreNonVolatileOperation & operation,
const std::vector<rvsdg::output *> & operands);

}

#endif
104 changes: 52 additions & 52 deletions tests/jlm/llvm/ir/operators/StoreTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
#include <test-registry.hpp>
#include <test-types.hpp>

#include <jlm/rvsdg/bitstring/type.hpp>
#include <jlm/rvsdg/view.hpp>

#include <jlm/llvm/ir/operators/alloca.hpp>
#include <jlm/llvm/ir/operators/MemoryStateOperations.hpp>
#include <jlm/llvm/ir/operators/Store.hpp>
#include <jlm/llvm/ir/RvsdgModule.hpp>
#include <jlm/rvsdg/bitstring/type.hpp>
#include <jlm/rvsdg/NodeNormalization.hpp>
#include <jlm/rvsdg/view.hpp>

static int
StoreNonVolatileOperationEquality()
Expand Down Expand Up @@ -209,7 +209,7 @@ TestCopy()
JLM_UNIT_TEST_REGISTER("jlm/llvm/ir/operators/StoreTests-TestCopy", TestCopy)

static int
TestStoreMuxReduction()
TestStoreMuxNormalization()
{
using namespace jlm::llvm;

Expand All @@ -219,10 +219,9 @@ TestStoreMuxReduction()
auto mt = MemoryStateType::Create();

jlm::rvsdg::Graph graph;
auto nf = graph.node_normal_form(typeid(StoreNonVolatileOperation));
auto snf = static_cast<jlm::llvm::store_normal_form *>(nf);
snf->set_mutable(false);
snf->set_store_mux_reducible(false);
auto nf = StoreNonVolatileOperation::GetNormalForm(&graph);
nf->set_mutable(false);
nf->set_store_mux_reducible(false);

auto a = &jlm::tests::GraphImport::Create(graph, pt, "a");
auto v = &jlm::tests::GraphImport::Create(graph, vt, "v");
Expand All @@ -231,27 +230,26 @@ TestStoreMuxReduction()
auto s3 = &jlm::tests::GraphImport::Create(graph, mt, "s3");

auto mux = MemoryStateMergeOperation::Create({ s1, s2, s3 });
auto state = StoreNonVolatileNode::Create(a, v, { mux }, 4);
auto & storeNode = StoreNonVolatileNode::CreateNode(*a, *v, { mux }, 4);

auto & ex = GraphExport::Create(*state[0], "s");
auto & ex = GraphExport::Create(*storeNode.output(0), "s");

// jlm::rvsdg::view(graph.root(), stdout);
jlm::rvsdg::view(graph.root(), stdout);

// Act
snf->set_mutable(true);
snf->set_store_mux_reducible(true);
graph.normalize();
auto success = jlm::rvsdg::ReduceNode<StoreNonVolatileOperation>(NormalizeStoreMux, storeNode);
graph.prune();

// jlm::rvsdg::view(graph.root(), stdout);
jlm::rvsdg::view(graph.root(), stdout);

// Assert
auto muxnode = jlm::rvsdg::output::GetNode(*ex.origin());
assert(is<MemoryStateMergeOperation>(muxnode));
assert(muxnode->ninputs() == 3);
auto n0 = jlm::rvsdg::output::GetNode(*muxnode->input(0)->origin());
auto n1 = jlm::rvsdg::output::GetNode(*muxnode->input(1)->origin());
auto n2 = jlm::rvsdg::output::GetNode(*muxnode->input(2)->origin());
assert(success);
auto muxNode = jlm::rvsdg::output::GetNode(*ex.origin());
assert(is<MemoryStateMergeOperation>(muxNode));
assert(muxNode->ninputs() == 3);
auto n0 = jlm::rvsdg::output::GetNode(*muxNode->input(0)->origin());
auto n1 = jlm::rvsdg::output::GetNode(*muxNode->input(1)->origin());
auto n2 = jlm::rvsdg::output::GetNode(*muxNode->input(2)->origin());
assert(jlm::rvsdg::is<StoreNonVolatileOperation>(n0->GetOperation()));
assert(jlm::rvsdg::is<StoreNonVolatileOperation>(n1->GetOperation()));
assert(jlm::rvsdg::is<StoreNonVolatileOperation>(n2->GetOperation()));
Expand All @@ -260,8 +258,8 @@ TestStoreMuxReduction()
}

JLM_UNIT_TEST_REGISTER(
"jlm/llvm/ir/operators/StoreTests-TestStoreMuxReduction",
TestStoreMuxReduction)
"jlm/llvm/ir/operators/StoreTests-TestStoreMuxNormalization",
TestStoreMuxNormalization)

static int
TestDuplicateStateReduction()
Expand All @@ -284,25 +282,25 @@ TestDuplicateStateReduction()
auto s2 = &jlm::tests::GraphImport::Create(graph, memoryStateType, "s2");
auto s3 = &jlm::tests::GraphImport::Create(graph, memoryStateType, "s3");

auto states = StoreNonVolatileNode::Create(a, v, { s1, s2, s1, s2, s3 }, 4);
auto & storeNode = StoreNonVolatileNode::CreateNode(*a, *v, { s1, s2, s1, s2, s3 }, 4);

auto & exS1 = GraphExport::Create(*states[0], "exS1");
auto & exS2 = GraphExport::Create(*states[1], "exS2");
auto & exS3 = GraphExport::Create(*states[2], "exS3");
auto & exS4 = GraphExport::Create(*states[3], "exS4");
auto & exS5 = GraphExport::Create(*states[4], "exS5");
auto & exS1 = GraphExport::Create(*storeNode.output(0), "exS1");
auto & exS2 = GraphExport::Create(*storeNode.output(1), "exS2");
auto & exS3 = GraphExport::Create(*storeNode.output(2), "exS3");
auto & exS4 = GraphExport::Create(*storeNode.output(3), "exS4");
auto & exS5 = GraphExport::Create(*storeNode.output(4), "exS5");

view(graph.root(), stdout);

// Act
nf->set_mutable(true);
nf->set_multiple_origin_reducible(true);
graph.normalize();
auto success =
jlm::rvsdg::ReduceNode<StoreNonVolatileOperation>(NormalizeStoreDuplicateState, storeNode);
graph.prune();

view(graph.root(), stdout);

// Assert
assert(success);
auto node = jlm::rvsdg::output::GetNode(*exS1.origin());
assert(is<StoreNonVolatileOperation>(node));
assert(node->ninputs() == 5);
Expand Down Expand Up @@ -331,35 +329,38 @@ TestStoreAllocaReduction()
auto bt = jlm::rvsdg::bittype::Create(32);

jlm::rvsdg::Graph graph;
auto nf = graph.node_normal_form(typeid(StoreNonVolatileOperation));
auto snf = static_cast<jlm::llvm::store_normal_form *>(nf);
snf->set_mutable(false);
snf->set_store_alloca_reducible(false);
auto nf = StoreNonVolatileOperation::GetNormalForm(&graph);
nf->set_mutable(false);
nf->set_store_alloca_reducible(false);

auto size = &jlm::tests::GraphImport::Create(graph, bt, "size");
auto value = &jlm::tests::GraphImport::Create(graph, vt, "value");
auto s = &jlm::tests::GraphImport::Create(graph, mt, "s");

auto alloca1 = alloca_op::create(vt, size, 4);
auto alloca2 = alloca_op::create(vt, size, 4);
auto states1 = StoreNonVolatileNode::Create(alloca1[0], value, { alloca1[1], alloca2[1], s }, 4);
auto states2 = StoreNonVolatileNode::Create(alloca2[0], value, states1, 4);
auto & storeNode1 =
StoreNonVolatileNode::CreateNode(*alloca1[0], *value, { alloca1[1], alloca2[1], s }, 4);
auto & storeNode2 =
StoreNonVolatileNode::CreateNode(*alloca2[0], *value, outputs(&storeNode1), 4);

GraphExport::Create(*states2[0], "s1");
GraphExport::Create(*states2[1], "s2");
GraphExport::Create(*states2[2], "s3");
GraphExport::Create(*storeNode2.output(0), "s1");
GraphExport::Create(*storeNode2.output(1), "s2");
GraphExport::Create(*storeNode2.output(2), "s3");

// jlm::rvsdg::view(graph.root(), stdout);
view(graph.root(), stdout);

// Act
snf->set_mutable(true);
snf->set_store_alloca_reducible(true);
graph.normalize();
auto success1 =
jlm::rvsdg::ReduceNode<StoreNonVolatileOperation>(NormalizeStoreAlloca, storeNode1);
auto success2 =
jlm::rvsdg::ReduceNode<StoreNonVolatileOperation>(NormalizeStoreAlloca, storeNode2);
graph.prune();

// jlm::rvsdg::view(graph.root(), stdout);
view(graph.root(), stdout);

// Assert
assert(success1 && success2);
bool has_add_import = false;
for (size_t n = 0; n < graph.root()->nresults(); n++)
{
Expand Down Expand Up @@ -391,22 +392,21 @@ TestStoreStoreReduction()
auto v2 = &jlm::tests::GraphImport::Create(graph, vt, "value");
auto s = &jlm::tests::GraphImport::Create(graph, mt, "state");

auto s1 = StoreNonVolatileNode::Create(a, v1, { s }, 4)[0];
auto s2 = StoreNonVolatileNode::Create(a, v2, { s1 }, 4)[0];
auto & storeNode1 = StoreNonVolatileNode::CreateNode(*a, *v1, { s }, 4);
auto & storeNode2 = StoreNonVolatileNode::CreateNode(*a, *v2, outputs(&storeNode1), 4);

auto & ex = GraphExport::Create(*s2, "state");
auto & ex = GraphExport::Create(*storeNode2.output(0), "state");

jlm::rvsdg::view(graph.root(), stdout);

// Act
auto nf = StoreNonVolatileOperation::GetNormalForm(&graph);
nf->set_store_store_reducible(true);
graph.normalize();
auto success = jlm::rvsdg::ReduceNode<StoreNonVolatileOperation>(NormalizeStoreStore, storeNode2);
graph.prune();

jlm::rvsdg::view(graph.root(), stdout);

// Assert
assert(success);
assert(graph.root()->nnodes() == 1);
assert(jlm::rvsdg::output::GetNode(*ex.origin())->input(1)->origin() == v2);

Expand Down

0 comments on commit 24cfa78

Please sign in to comment.