Skip to content

Commit

Permalink
Extend conditional specialization pass to forward through OneHotSelec…
Browse files Browse the repository at this point in the history
…ts & logical ops too

We already supported forwarding cases through case-at-a-time selects where the context guaranteed which case would be selected.

We now extend this logic to also forward inputs through OneHotSelects & ANDs/ORs/XORs where the context guarantees that one input passes identically through the operation. Applied to AND/OR/XOR, this should give us a chance to see through various forms of operations that amount to selects or gates in context.

PiperOrigin-RevId: 648509568
  • Loading branch information
ericastor authored and copybara-github committed Jul 1, 2024
1 parent 329dcda commit 8a9f598
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 22 deletions.
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1545,13 +1545,15 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//xls/common:module_initializer",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:op",
"//xls/ir:ternary",
"//xls/ir:value",
"//xls/ir:value_utils",
],
)

Expand Down
135 changes: 126 additions & 9 deletions xls/passes/conditional_specialization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <sstream>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/btree_set.h"
Expand All @@ -34,6 +35,7 @@
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xls/common/module_initializer.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
Expand All @@ -44,6 +46,7 @@
#include "xls/ir/ternary.h"
#include "xls/ir/topo_sort.h"
#include "xls/ir/value.h"
#include "xls/ir/value_utils.h"
#include "xls/passes/bdd_function.h"
#include "xls/passes/bdd_query_engine.h"
#include "xls/passes/optimization_pass.h"
Expand Down Expand Up @@ -215,6 +218,18 @@ class ConditionMap {
edge_conditions_.insert({key, std::move(condition_set)});
}

// Returns the conditions which can be assumed along the edge to `node` from
// its operand index `operand_no`.
const ConditionSet& GetEdgeConditionSet(Node* node, int64_t operand_no) {
std::pair<Node*, int64_t> key = {node, operand_no};
if (!edge_conditions_.contains(key)) {
// There are no special conditions for this edge. Return the conditions on
// the target of the edge which necessarily hold on the edge as well.
return node_conditions_.at(node);
}
return edge_conditions_.at(key);
}

// Returns the conditions which can be assumed along the edge(s) from node to
// user. This interface is asymmetric to SetEdgeCondition (which takes a node
// and operand number) to make it easier to use because at a particular node
Expand All @@ -238,13 +253,7 @@ class ConditionMap {
}
CHECK(operand_index.has_value()) << absl::StreamFormat(
"%s is not a user of %s", user->GetName(), node->GetName());
std::pair<Node*, int64_t> key = {user, operand_index.value()};
if (!edge_conditions_.contains(key)) {
// There are no special conditions for this edge. Return the conditions on
// the target of the edge which necessarily hold on the edge as well.
return node_conditions_.at(user);
}
return edge_conditions_.at(key);
return GetEdgeConditionSet(user, *operand_index);
}

std::string ToString() const {
Expand Down Expand Up @@ -368,6 +377,27 @@ std::optional<Node*> GetSelectedCase(PrioritySelect* select,
return select->default_value();
}

struct ZeroValue : std::monostate {};
std::optional<std::variant<Node*, ZeroValue>> GetSelectedCase(
OneHotSelect* ohs, const TernaryVector& selector_value) {
if (!ternary_ops::IsFullyKnown(selector_value)) {
// We can't be sure which case is selected.
return std::nullopt;
}
Bits selector_bits = ternary_ops::ToKnownBitsValues(selector_value);
if (selector_bits.PopCount() > 1) {
// We aren't selecting just one state.
return std::nullopt;
}
for (int64_t i = 0; i < selector_value.size(); ++i) {
if (selector_bits.Get(i)) {
return ohs->get_case(i);
}
}
// All bits of the selector are zero.
return ZeroValue{};
}

} // namespace

absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
Expand Down Expand Up @@ -643,10 +673,15 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
// It may be possible to bypass multiple selects so walk the edge up the
// graph as far as possible. For example, in the diagram above `b` may
// also be a select with a selector whose value is implied by `s`.
if (operand->Is<Select>() || operand->Is<PrioritySelect>()) {
//
// This also applies to ANDs, ORs, and XORs, if the condition set implies
// that all but one operand is the identity for the operation.
if (operand->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel, Op::kAnd,
Op::kOr, Op::kXor})) {
std::optional<Node*> replacement;
Node* src = operand;
while (src->Is<Select>() || src->Is<PrioritySelect>()) {
while (src->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel, Op::kAnd,
Op::kOr, Op::kXor})) {
if (src->Is<Select>()) {
Select* select = src->As<Select>();
if (select->selector()->Is<Literal>()) {
Expand Down Expand Up @@ -690,6 +725,88 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
xls::ToString(*implied_selector));
src = *implied_case;
replacement = src;
} else if (src->Is<OneHotSelect>()) {
XLS_RET_CHECK(src->Is<OneHotSelect>());
OneHotSelect* ohs = src->As<OneHotSelect>();
if (ohs->selector()->Is<Literal>()) {
break;
}
std::optional<TernaryVector> implied_selector =
ImpliedNodeTernary(edge_set, ohs->selector(), query_engine);
if (!implied_selector.has_value()) {
break;
}
for (int64_t case_no = 0; case_no < ohs->cases().size();
++case_no) {
if (implied_selector.value()[case_no] ==
TernaryValue::kKnownZero) {
continue;
}

// This case could be selected - but if it's definitely zero when
// selected, then we can ignore it.
std::optional<Bits> implied_case =
ImpliedNodeValue(condition_map.GetEdgeConditionSet(
ohs, /*operand_no=*/case_no + 1),
ohs->cases()[case_no], query_engine);
if (implied_case.has_value() && implied_case->IsZero()) {
implied_selector.value()[case_no] = TernaryValue::kKnownZero;
}
}
std::optional<std::variant<Node*, ZeroValue>> implied_case =
GetSelectedCase(ohs, *implied_selector);
if (!implied_case.has_value()) {
break;
}
VLOG(3) << absl::StreamFormat(
"Conditions for edge (%s, %s) imply selector %s of select %s "
"has value %s",
operand->GetName(), node->GetName(), ohs->selector()->GetName(),
ohs->GetName(), xls::ToString(*implied_selector));
if (std::holds_alternative<Node*>(*implied_case)) {
src = std::get<Node*>(*implied_case);
} else {
XLS_RET_CHECK(std::holds_alternative<ZeroValue>(*implied_case));
XLS_ASSIGN_OR_RETURN(
src,
f->MakeNode<Literal>(src->loc(), ZeroOfType(src->GetType())));
}
replacement = src;
} else {
XLS_RET_CHECK(src->OpIn({Op::kAnd, Op::kOr, Op::kXor}));
auto is_identity = [&](const Bits& b) {
if (operand->op() == Op::kAnd) {
return b.IsAllOnes();
}
return b.IsZero();
};
NaryOp* bitwise_op = src->As<NaryOp>();
std::optional<Node*> nonidentity_operand = std::nullopt;
for (Node* potential_src : bitwise_op->operands()) {
XLS_RET_CHECK(potential_src->GetType()->IsBits());
std::optional<Bits> implied_src =
ImpliedNodeValue(edge_set, potential_src, query_engine);
if (implied_src.has_value() && is_identity(*implied_src)) {
continue;
}
if (nonidentity_operand.has_value()) {
// There's more than one potentially-non-zero operand; we're
// done, there's nothing to do.
nonidentity_operand = std::nullopt;
break;
}
nonidentity_operand = potential_src;
}
if (!nonidentity_operand.has_value()) {
break;
}
VLOG(3) << absl::StreamFormat(
"Conditions for edge (%s, %s) imply that bitwise operation "
"%s has only one non-identity operand: %s",
operand->GetName(), node->GetName(), bitwise_op->GetName(),
nonidentity_operand.value()->GetName());
src = *nonidentity_operand;
replacement = src;
}
}
if (replacement.has_value()) {
Expand Down
Loading

0 comments on commit 8a9f598

Please sign in to comment.