Skip to content

Commit

Permalink
[XLS] Modernize to next-value nodes as soon as optimization starts
Browse files Browse the repository at this point in the history
Since proc inlining is the only remaining pass that cannot work with next-value nodes, we let it convert the procs it's inlining to use next-state elements, then convert them back to next_value nodes at the end.

This is the first step in our new plan for removing next-state element support:

1. (this change) Modernize to next-value nodes as soon as optimization starts, and let proc-inlining convert procs to next-state elements as needed
2. Switch all frontends to use next-value nodes natively
3. Make XLS IR verification reject nontrivial next-state elements, and remove the modernization pass from optimization
4. Switch proc-inlining from using next-state elements in procs to modeling them internally
5. Remove next-state element support from procs

PiperOrigin-RevId: 706026298
  • Loading branch information
ericastor authored and copybara-github committed Dec 13, 2024
1 parent 5e5e2e0 commit 82b203e
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 49 deletions.
14 changes: 7 additions & 7 deletions xls/flows/testdata/ir_wrapper_test_DslxProcsToIrOk.ir
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ chan test_package__output(bits[32], id=2, kind=streaming, ops=send_only, flow_co

proc __top__foo_0_next() {
tok: token = after_all(id=4)
receive.17: (token, bits[32]) = receive(tok, channel=test_package__in_0, id=17)
tok__1: token = tuple_index(receive.17, index=0, id=7, pos=[(0,12,13)])
receive.18: (token, bits[32]) = receive(tok__1, channel=test_package__in_1, id=18)
a: bits[32] = tuple_index(receive.17, index=1, id=8, pos=[(0,12,18)])
b: bits[32] = tuple_index(receive.18, index=1, id=12, pos=[(0,13,18)])
tok__2: token = tuple_index(receive.18, index=0, id=11, pos=[(0,13,13)])
receive.18: (token, bits[32]) = receive(tok, channel=test_package__in_0, id=18)
tok__1: token = tuple_index(receive.18, index=0, id=7, pos=[(0,12,13)])
receive.19: (token, bits[32]) = receive(tok__1, channel=test_package__in_1, id=19)
a: bits[32] = tuple_index(receive.18, index=1, id=8, pos=[(0,12,18)])
b: bits[32] = tuple_index(receive.19, index=1, id=12, pos=[(0,13,18)])
tok__2: token = tuple_index(receive.19, index=0, id=11, pos=[(0,13,13)])
add.13: bits[32] = add(a, b, id=13, pos=[(0,14,36)])
tok__3: token = send(tok__2, add.13, channel=test_package__output, id=19)
tok__3: token = send(tok__2, add.13, channel=test_package__output, id=20)
}
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1792,9 +1792,11 @@ cc_library(
"//xls/ir:node_util",
"//xls/ir:op",
"//xls/ir:source_location",
"//xls/ir:state_element",
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
5 changes: 3 additions & 2 deletions xls/passes/optimization_pass_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ PreInliningPassGroup::PreInliningPassGroup()
"pre-inlining passes") {
Add<DeadFunctionEliminationPass>();
Add<DeadCodeEliminationPass>();
// TODO: google/xls#1795 - Remove once full transition to next-op is complete.
Add<NextNodeModernizePass>();
// At this stage in the pipeline only optimizations up to level 2 should
// run. 'opt_level' is the maximum level of optimization which should be run
// in the entire pipeline so set the level of the simplification pass to the
Expand Down Expand Up @@ -225,6 +227,7 @@ class PostInliningOptPassGroup : public OptimizationCompoundPass {
Add<TokenDependencyPass>();
// Simplify the adapter procs before inlining.
Add<CapOptLevel<2, FixedPointSimplificationPass>>();

// TODO(allight): It might be worthwhile to split the pipeline here as well.
// Since proc-inlining is being phased out in favor of multi-proc codegen
// however this seems unnecessary.
Expand All @@ -236,8 +239,6 @@ class PostInliningOptPassGroup : public OptimizationCompoundPass {
Add<ProcStateFlatteningFixedPointPass>();
Add<IdentityRemovalPass>();
Add<DataflowSimplificationPass>();
// TODO(allight): Remove once full transition to next-op is complete.
Add<NextNodeModernizePass>();
Add<CapOptLevel<3, NextValueOptimizationPass>>();

Add<ProcStateNarrowingPass>();
Expand Down
101 changes: 101 additions & 0 deletions xls/passes/proc_inlining_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand All @@ -49,6 +50,7 @@
#include "xls/ir/op.h"
#include "xls/ir/proc.h"
#include "xls/ir/source_location.h"
#include "xls/ir/state_element.h"
#include "xls/ir/topo_sort.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"
Expand Down Expand Up @@ -1782,6 +1784,95 @@ absl::Status SetProcState(Proc* proc,
return absl::OkStatus();
}

absl::Status ConvertToNextStateElements(Proc* proc) {
for (int64_t index = 0; index < proc->GetStateElementCount(); ++index) {
StateElement* state_element = proc->GetStateElement(index);
StateRead* state_read = proc->GetStateRead(index);
const absl::btree_set<Next*, Node::NodeIdLessThan>& nexts =
proc->next_values(state_read);
if (nexts.empty()) {
continue;
}

// Check that either all or none of the next_value nodes are predicated.
XLS_RET_CHECK(absl::c_all_of(nexts, [&](Next* next) {
return next->predicate().has_value() ==
(*nexts.begin())->predicate().has_value();
}));
const bool predicated =
!nexts.empty() && (*nexts.begin())->predicate().has_value();
if (!predicated) {
XLS_RET_CHECK_EQ(nexts.size(), 1);
}

std::vector<Node*> values;
std::optional<std::vector<Node*>> predicates;
values.reserve(nexts.size());
if (predicated) {
predicates.emplace();
predicates->reserve(nexts.size());
}
for (Next* next : nexts) {
values.push_back(next->value());
if (predicated) {
XLS_RET_CHECK(next->predicate().has_value());
predicates->push_back(*next->predicate());
}
}

Node* next_state;
if (predicated) {
SourceInfo loc = state_read->loc();
absl::c_reverse(*predicates);
XLS_ASSIGN_OR_RETURN(
Node * selector,
proc->MakeNodeWithName<Concat>(
loc, *predicates,
absl::StrCat(state_element->name(), "_next_selector")));
XLS_ASSIGN_OR_RETURN(
next_state,
proc->MakeNodeWithName<PrioritySelect>(
loc, selector, /*cases=*/values, /*default_value=*/state_read,
absl::StrCat(state_element->name(), "_next_state")));
} else {
next_state = values.front();
}

std::vector<Next*> nexts_to_remove(nexts.begin(), nexts.end());
for (Next* next : nexts_to_remove) {
XLS_RETURN_IF_ERROR(
next->ReplaceUsesWithNew<Tuple>(absl::Span<Node* const>{}).status());
XLS_RETURN_IF_ERROR(proc->RemoveNode(next));
}
XLS_RETURN_IF_ERROR(proc->SetNextStateElement(index, next_state));
}

return absl::OkStatus();
}

absl::Status ConvertToNextValueNodes(Proc* proc) {
for (int64_t index = 0; index < proc->GetStateElementCount(); ++index) {
StateRead* state_read = proc->GetStateRead(index);
if (proc->GetNextStateElement(index) == state_read) {
continue;
}

// Nontrivial next-state element; switch it to a next-value node, then
// remove it so we pass verification.
CHECK(proc->next_values(state_read).empty());
Node* next_value = proc->GetNextStateElement(index);
XLS_RETURN_IF_ERROR(
proc->MakeNodeWithName<Next>(
state_read->loc(), state_read,
/*value=*/next_value,
/*predicate=*/std::nullopt,
absl::StrCat(state_read->state_element()->name(), "_next"))
.status());
XLS_RETURN_IF_ERROR(proc->SetNextStateElement(index, state_read));
}
return absl::OkStatus();
}

} // namespace

absl::StatusOr<bool> ProcInliningPass::RunInternal(
Expand Down Expand Up @@ -1810,6 +1901,12 @@ absl::StatusOr<bool> ProcInliningPass::RunInternal(
procs_to_inline.push_back(proc.get());
}

for (Proc* proc : procs_to_inline) {
XLS_RETURN_IF_ERROR(ConvertToNextStateElements(proc));
}

VLOG(3) << "After switching to next-state elements:\n" << p->DumpIr();

{
int64_t top_ii = 1;
if (top_func_base->GetInitiationInterval().has_value()) {
Expand Down Expand Up @@ -1952,6 +2049,10 @@ absl::StatusOr<bool> ProcInliningPass::RunInternal(

VLOG(3) << "After deleting inlined I/O:\n" << p->DumpIr();

XLS_RETURN_IF_ERROR(ConvertToNextValueNodes(container_proc));

VLOG(3) << "After switching back to next-value nodes:\n" << p->DumpIr();

return true;
}

Expand Down
Loading

0 comments on commit 82b203e

Please sign in to comment.