Skip to content

Commit

Permalink
Narrow select elements using range analysis
Browse files Browse the repository at this point in the history
Sometimes we can narrow select elements more effectively using range analysis instead of the pure ternary analysis we were using. This allows us to recognize when a select is near signed-zero and perform signed narrowing.

Since range analysis is relatively expensive this is only performed once in the pipeline.

PiperOrigin-RevId: 699306623
  • Loading branch information
allight authored and copybara-github committed Nov 22, 2024
1 parent f47eb95 commit 74dad6c
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 115 deletions.
2 changes: 2 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,7 @@ cc_library(
":optimization_pass_registry",
":pass_base",
":query_engine",
":range_query_engine",
":stateless_query_engine",
":ternary_query_engine",
":union_query_engine",
Expand All @@ -961,6 +962,7 @@ cc_library(
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
"//xls/ir:interval_ops",
"//xls/ir:node_util",
"//xls/ir:op",
"//xls/ir:ternary",
Expand Down
6 changes: 6 additions & 0 deletions xls/passes/optimization_pass_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ class PostInliningOptPassGroup : public OptimizationCompoundPass {

Add<CapOptLevel<3, FixedPointSimplificationPass>>();

// Range based select simplification is heavier so we only do it once.
Add<SelectRangeSimplificationPass>();
Add<DeadCodeEliminationPass>();

Add<CapOptLevel<3, FixedPointSimplificationPass>>();

Add<CapOptLevel<3, BddSimplificationPass>>();
Add<DeadCodeEliminationPass>();
Add<BddCsePass>();
Expand Down
26 changes: 26 additions & 0 deletions xls/passes/range_query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2546,6 +2546,32 @@ TEST_F(RangeQueryEngineTest, MultipleRangeGivenValue) {
BitsLTT(ltxyz.node(), {Interval::Precise(UBits(1, 1))}));
}

TEST_F(RangeQueryEngineTest, SmallNegativeSelectChain) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue inp = fb.Param("inp", u1);
// exp is [[-1], [0]]
BValue exp = fb.SignExtend(inp, 32);
BValue s1 = fb.Select(fb.Param("a", u1), exp, fb.Literal(UBits(0, 32)));
BValue s2 = fb.Select(fb.Param("b", u1), s1, fb.Literal(UBits(1, 32)));
BValue s3 = fb.Select(fb.Param("c", u1), s2, fb.Literal(UBits(2, 32)));
BValue s4 = fb.Select(fb.Param("d", u1), s3, fb.Literal(UBits(3, 32)));
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());
RangeQueryEngine engine;
XLS_ASSERT_OK(engine.Populate(f).status());

EXPECT_EQ(engine.GetIntervalSetTree(exp.node()),
BitsLTT(exp.node(), {Interval::Precise(SBits(-1, 32)),
Interval::Precise(UBits(0, 32))}));
EXPECT_EQ(engine.GetIntervalSetTree(s4.node()),
BitsLTT(s4.node(), {Interval::Precise(SBits(-1, 32)),
Interval(UBits(0, 32), UBits(3, 32))}));
EXPECT_EQ(interval_ops::MinimumSignedBitCount(
engine.GetIntervalSetTree(s4.node()).Get({})),
3);
}

TEST_F(RangeQueryEngineTest, MaxMinUnsignedValue) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Expand Down
135 changes: 107 additions & 28 deletions xls/passes/select_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
#include "xls/passes/select_simplification_pass.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <deque>
#include <functional>
#include <ios>
#include <iterator>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -50,6 +52,7 @@
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/function_base.h"
#include "xls/ir/interval_ops.h"
#include "xls/ir/lsb_or_msb.h"
#include "xls/ir/node.h"
#include "xls/ir/node_util.h"
Expand All @@ -65,6 +68,7 @@
#include "xls/passes/optimization_pass_registry.h"
#include "xls/passes/pass_base.h"
#include "xls/passes/query_engine.h"
#include "xls/passes/range_query_engine.h"
#include "xls/passes/stateless_query_engine.h"
#include "xls/passes/ternary_query_engine.h"
#include "xls/passes/union_query_engine.h"
Expand Down Expand Up @@ -217,7 +221,7 @@ absl::Status SqueezeSelect(SelectT* select, SqueezeF squeeze,
}
VLOG(2) << absl::StreamFormat("Squeezed select %s into %d bits",
select->ToString(),
squeezed_sel->BitCountOrDie());
squeezed_sel->GetType()->GetFlatBitCount());
std::optional<std::string> orig_name =
sel_node->HasAssignedName() ? std::make_optional(sel_node->GetName())
: std::nullopt;
Expand Down Expand Up @@ -263,8 +267,22 @@ template <typename SelectT>
std::is_same_v<SelectT, PrioritySelect>)
absl::StatusOr<bool> TrySqueezeSelect(SelectT* sel,
const QueryEngine& query_engine,
const BitProvenanceAnalysis& provenance) {
const BitProvenanceAnalysis& provenance,
bool with_range_analysis) {
Node* node = sel;
// If we have range analysis check for signed reduction. We want to see if the
// values are all around signed-zero
int64_t min_signed_size;
if (with_range_analysis) {
min_signed_size = interval_ops::MinimumSignedBitCount(
query_engine.GetIntervals(node).Get({}));
} else {
// If we don't have range analysis this will at best be the same as
// ternary-based constant leading-trailing bits analysis. Therefore, don't
// bother to do the expensive interval calculation.
min_signed_size = node->BitCountOrDie();
}
// Figure out common known constant MSB & LSB bits
auto is_squeezable_mux = [&](Bits* msb, Bits* lsb) {
std::optional<SharedLeafTypeTree<TernaryVector>> ternary =
query_engine.GetTernary(node);
Expand Down Expand Up @@ -309,27 +327,83 @@ absl::StatusOr<bool> TrySqueezeSelect(SelectT* sel,
return v;
});
if ((!const_squeezable || const_mux_width == node->BitCountOrDie()) &&
prov_width == node->BitCountOrDie()) {
prov_width == node->BitCountOrDie() &&
min_signed_size == node->BitCountOrDie()) {
// Can't narrow
return false;
}
// Technically we could do both (since these might narrow different bits) but
// Technically we could do all 3 (since these might narrow different bits) but
// (1) that makes whichever goes second more complicated and (2) this pass is
// run in fixed-point anyway so all it will do is save a pass-group run.
// run in fixed-point (for non-range analysis) anyway so all it will do is
// save a pass-group run. Hopefully one range based run will suffice since it
// is too slow to run in a fixed-point.
//
// const-mux is prioritized mostly to avoid having to rewrite all the tests
// for it.
if (const_mux_width <= prov_width) {
XLS_RETURN_IF_ERROR(SqueezeSelect(
sel,
SqueezeConstantBits{.const_msb = const_msb, .const_lsb = const_lsb},
UnsqueezeConstantBits{.const_msb = const_msb, .const_lsb = const_lsb},
MakeSelect<SelectT>));
} else {
XLS_RETURN_IF_ERROR(SqueezeSelect(
sel, RemoveUnchangedBits{.source = prov_bits, .changed_bit_src = sel},
RestoreUnchangedBits{.source = prov_bits, .changed_bit_src = sel},
MakeSelect<SelectT>));
// When multiple choices have the same net effect, const-mux is prioritized
// mostly to avoid having to rewrite all the tests for it. Next provenance is
// used finally we use sign-reduction.
enum class SqueezeType : int8_t {
kConstants = 1,
kBitProvenance = 2,
kSignExtend = 3
};
struct SqueezeOption {
int64_t bit_count;
SqueezeType type;
};
auto options = std::to_array<SqueezeOption>(
{{.bit_count = const_mux_width, .type = SqueezeType::kConstants},
{.bit_count = prov_width, .type = SqueezeType::kBitProvenance},
{.bit_count = min_signed_size, .type = SqueezeType::kSignExtend}});
SqueezeType option =
absl::c_min_element(options, [](const SqueezeOption& l,
const SqueezeOption& r) {
return l.bit_count < r.bit_count ||
(l.bit_count == r.bit_count &&
static_cast<int8_t>(l.type) < static_cast<int8_t>(r.type));
})->type;

VLOG(3) << "Options of squeeze for " << sel
<< " are: mux: " << const_mux_width << ", provenance: " << prov_width
<< " sign_ext: " << min_signed_size;
switch (option) {
case SqueezeType::kConstants:
VLOG(2) << "Squeezing select using constants : " << sel << " to "
<< const_mux_width << " bits";
XLS_RETURN_IF_ERROR(SqueezeSelect(
sel,
SqueezeConstantBits{.const_msb = const_msb, .const_lsb = const_lsb},
UnsqueezeConstantBits{.const_msb = const_msb, .const_lsb = const_lsb},
MakeSelect<SelectT>));
break;
case SqueezeType::kBitProvenance:
VLOG(2) << "Squeezing select using bit-prov: " << sel << " to "
<< prov_width << " bits";
XLS_RETURN_IF_ERROR(SqueezeSelect(
sel, RemoveUnchangedBits{.source = prov_bits, .changed_bit_src = sel},
RestoreUnchangedBits{.source = prov_bits, .changed_bit_src = sel},
MakeSelect<SelectT>));
break;
case SqueezeType::kSignExtend:
VLOG(2) << "Squeezing select using sign-ext: " << sel << " to "
<< min_signed_size << " bits";
XLS_RETURN_IF_ERROR(SqueezeSelect(
sel,
[&](Node* src) -> absl::StatusOr<Node*> {
return src->function_base()->MakeNodeWithName<BitSlice>(
src->loc(), src, /*start=*/0, /*width=*/min_signed_size,
src->HasAssignedName()
? absl::StrFormat("%s_squeezed", src->GetName())
: "");
},
[&](Node* src) -> absl::StatusOr<Node*> {
return src->function_base()->MakeNodeWithName<ExtendOp>(
src->loc(), src, node->BitCountOrDie(), Op::kSignExt,
src->HasAssignedName()
? absl::StrFormat("%s_unsqueezed", src->GetName())
: "");
},
MakeSelect<SelectT>));
break;
}
return true;
}
Expand Down Expand Up @@ -902,7 +976,7 @@ absl::StatusOr<bool> MaybeReorderSelect(Node* node,

absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
const BitProvenanceAnalysis& provenance,
int64_t opt_level) {
int64_t opt_level, bool range_analysis) {
// Select with a constant selector can be replaced with the respective
// case.
if (node->Is<Select>() &&
Expand Down Expand Up @@ -1754,17 +1828,17 @@ absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
if (node->GetType()->IsBits()) {
bool squeezed = false;
if (node->Is<Select>()) {
XLS_ASSIGN_OR_RETURN(
squeezed,
TrySqueezeSelect(node->As<Select>(), query_engine, provenance));
XLS_ASSIGN_OR_RETURN(squeezed,
TrySqueezeSelect(node->As<Select>(), query_engine,
provenance, range_analysis));
} else if (node->Is<OneHotSelect>()) {
XLS_ASSIGN_OR_RETURN(
squeezed, TrySqueezeSelect(node->As<OneHotSelect>(), query_engine,
provenance));
provenance, range_analysis));
} else if (node->Is<PrioritySelect>()) {
XLS_ASSIGN_OR_RETURN(
squeezed, TrySqueezeSelect(node->As<PrioritySelect>(), query_engine,
provenance));
provenance, range_analysis));
}
if (squeezed) {
return true;
Expand Down Expand Up @@ -2128,12 +2202,16 @@ absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,

} // namespace

absl::StatusOr<bool> SelectSimplificationPass::RunOnFunctionBaseInternal(
absl::StatusOr<bool> SelectSimplificationPassBase::RunOnFunctionBaseInternal(
FunctionBase* func, const OptimizationPassOptions& options,
PassResults* results) const {
std::vector<std::unique_ptr<QueryEngine>> query_engines;
query_engines.push_back(std::make_unique<StatelessQueryEngine>());
query_engines.push_back(std::make_unique<TernaryQueryEngine>());
if (range_analysis_) {
query_engines.push_back(std::make_unique<RangeQueryEngine>());
}
VLOG(2) << "Range analysis is " << std::boolalpha << range_analysis_;

UnionQueryEngine query_engine(std::move(query_engines));
XLS_RETURN_IF_ERROR(query_engine.Populate(func).status());
Expand All @@ -2143,9 +2221,9 @@ absl::StatusOr<bool> SelectSimplificationPass::RunOnFunctionBaseInternal(

bool changed = false;
for (Node* node : TopoSort(func)) {
XLS_ASSIGN_OR_RETURN(
bool node_changed,
SimplifyNode(node, query_engine, provenance, options.opt_level));
XLS_ASSIGN_OR_RETURN(bool node_changed,
SimplifyNode(node, query_engine, provenance,
options.opt_level, range_analysis_));
changed = changed || node_changed;
}

Expand Down Expand Up @@ -2189,5 +2267,6 @@ absl::StatusOr<bool> SelectSimplificationPass::RunOnFunctionBaseInternal(
}

REGISTER_OPT_PASS(SelectSimplificationPass);
REGISTER_OPT_PASS(SelectRangeSimplificationPass);

} // namespace xls
44 changes: 37 additions & 7 deletions xls/passes/select_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,49 @@

namespace xls {

// Pass which simplifies selects and one-hot-selects. Example optimizations
// include removing dead arms and eliminating selects with constant selectors.
class SelectSimplificationPass : public OptimizationFunctionBasePass {
// Base class which simplifies selects and one-hot-selects. Example
// optimizations include removing dead arms and eliminating selects with
// constant selectors.
class SelectSimplificationPassBase : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "select_simp";
explicit SelectSimplificationPass()
: OptimizationFunctionBasePass(kName, "Select Simplification") {}
~SelectSimplificationPass() override = default;
~SelectSimplificationPassBase() override = default;

protected:
explicit SelectSimplificationPassBase(std::string_view short_name,
std::string_view name,
bool with_range_analysis = false)
: OptimizationFunctionBasePass(short_name, name),
range_analysis_(with_range_analysis) {}

absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;

bool range_analysis_;
};

// Pass which simplifies selects and one-hot-selects. Example optimizations
// include removing dead arms and eliminating selects with constant selectors.
// Uses ternary analysis to determine possible values.
class SelectSimplificationPass : public SelectSimplificationPassBase {
public:
static constexpr std::string_view kName = "select_simp";
SelectSimplificationPass()
: SelectSimplificationPassBase(kName, "Select Simplification",
/*with_range_analysis=*/false) {}
~SelectSimplificationPass() override = default;
};

// Pass which simplifies selects and one-hot-selects. Example optimizations
// include removing dead arms and eliminating selects with constant selectors.
// Uses range analysis to determine possible values.
class SelectRangeSimplificationPass : public SelectSimplificationPassBase {
public:
static constexpr std::string_view kName = "select_range_simp";
SelectRangeSimplificationPass()
: SelectSimplificationPassBase(kName, "Select Range Simplification",
/*with_range_analysis=*/true) {}
~SelectRangeSimplificationPass() override = default;
};

} // namespace xls
Expand Down
Loading

0 comments on commit 74dad6c

Please sign in to comment.