Skip to content

Commit

Permalink
Narrow proc-state elements using sign-ext
Browse files Browse the repository at this point in the history
Proc state elements which start at -1 or other small negative numbers can only be narrowed if we keep track of the signed value. This adds support for doing that.

PiperOrigin-RevId: 698875794
  • Loading branch information
allight authored and copybara-github committed Nov 21, 2024
1 parent 6a8587b commit 2bdaded
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 43 deletions.
1 change: 1 addition & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1728,6 +1728,7 @@ cc_library(
"//xls/data_structures:leaf_type_tree",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:interval_ops",
"//xls/ir:op",
"//xls/ir:state_element",
"//xls/ir:ternary",
Expand Down
125 changes: 90 additions & 35 deletions xls/passes/proc_state_narrowing_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/ir/bits.h"
#include "xls/ir/interval_ops.h"
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
Expand All @@ -49,49 +50,87 @@ namespace {
// Struct which transforms a state element into a slice of its trailing bits.
struct ProcStateNarrowTransform : public Proc::StateElementTransformer {
public:
explicit ProcStateNarrowTransform(Bits known_leading)
: Proc::StateElementTransformer(),
known_leading_(std::move(known_leading)) {}
explicit ProcStateNarrowTransform(int64_t known_leading)
: Proc::StateElementTransformer(), known_leading_(known_leading) {}

int64_t known_leading() const { return known_leading_; }
absl::StatusOr<Node*> TransformNextValue(Proc* proc,
StateRead* new_state_read,
Next* old_next) final {
XLS_RET_CHECK_EQ(
new_state_read->GetType()->GetFlatBitCount() + known_leading_,
old_next->state_read()->GetType()->GetFlatBitCount());
return proc->MakeNodeWithName<BitSlice>(
old_next->loc(), old_next->value(), /*start=*/0,
/*width=*/new_state_read->GetType()->GetFlatBitCount(),
absl::StrFormat("unexpand_for_%s", old_next->GetName()));
}

private:
int64_t known_leading_;
};

class ProcStateConcatNarrowTransform : public ProcStateNarrowTransform {
public:
explicit ProcStateConcatNarrowTransform(Bits leading_bits)
: ProcStateNarrowTransform(leading_bits.bit_count()),
leading_bits_(std::move(leading_bits)) {}

absl::StatusOr<Node*> TransformStateRead(Proc* proc,
StateRead* new_state_read,
StateRead* old_state_read) final {
XLS_RET_CHECK_EQ(new_state_read->GetType()->GetFlatBitCount() +
known_leading_.bit_count(),
old_state_read->GetType()->GetFlatBitCount());
XLS_RET_CHECK_EQ(
new_state_read->GetType()->GetFlatBitCount() + known_leading(),
old_state_read->GetType()->GetFlatBitCount());
XLS_ASSIGN_OR_RETURN(
Node * leading,
proc->MakeNodeWithName<Literal>(
old_state_read->loc(), Value(known_leading_),
old_state_read->loc(), Value(leading_bits_),
absl::StrFormat("leading_bits_%s",
old_state_read->state_element()->name())));
return proc->MakeNodeWithName<Concat>(
new_state_read->loc(), std::array<Node*, 2>{leading, new_state_read},
absl::StrFormat("extended_%s",
old_state_read->state_element()->name()));
}
absl::StatusOr<Node*> TransformNextValue(Proc* proc,
StateRead* new_state_read,
Next* old_next) final {
XLS_RET_CHECK_EQ(new_state_read->GetType()->GetFlatBitCount() +
known_leading_.bit_count(),
old_next->state_read()->GetType()->GetFlatBitCount());
return proc->MakeNodeWithName<BitSlice>(
old_next->loc(), old_next->value(), /*start=*/0,
/*width=*/new_state_read->GetType()->GetFlatBitCount(),
absl::StrFormat("unexpand_for_%s", old_next->GetName()));
}

private:
Bits known_leading_;
Bits leading_bits_;
};

absl::Status RemoveLeadingBits(StateRead* state_read,
const Value& orig_init_value,
const Bits& known_leading) {
Value new_init_value(orig_init_value.bits().Slice(
0, orig_init_value.bits().bit_count() - known_leading.bit_count()));
ProcStateNarrowTransform transform(known_leading);
ProcStateConcatNarrowTransform transform(known_leading);
return state_read->function_base()
->AsProcOrDie()
->TransformStateElement(state_read, new_init_value, transform)
.status();
}

class ProcStateSignExtendNarrowTransform : public ProcStateNarrowTransform {
public:
explicit ProcStateSignExtendNarrowTransform(int64_t known_leading)
: ProcStateNarrowTransform(known_leading) {}

absl::StatusOr<Node*> TransformStateRead(Proc* proc,
StateRead* new_state_read,
StateRead* old_state_read) final {
return proc->MakeNodeWithName<ExtendOp>(
new_state_read->loc(), new_state_read, old_state_read->BitCountOrDie(),
Op::kSignExt,
absl::StrFormat("extended_%s",
old_state_read->state_element()->name()));
}
};

absl::Status RemoveSignBits(StateRead* state_read, const Value& orig_init_value,
int64_t real_size) {
Value new_init_value(orig_init_value.bits().Slice(0, real_size));
ProcStateSignExtendNarrowTransform transform(state_read->BitCountOrDie() -
real_size);
return state_read->function_base()
->AsProcOrDie()
->TransformStateElement(state_read, new_init_value, transform)
Expand Down Expand Up @@ -131,24 +170,40 @@ absl::StatusOr<bool> ProcStateNarrowingPass::RunOnProcInternal(
}
int64_t known_leading =
ternary_ops::ToKnownBits(ternary->Get({})).CountLeadingOnes();
if (known_leading == 0) {
if (known_leading != 0) {
// TODO(allight): We could also narrow internal/trailing bits.
VLOG(2) << "Unable to narrow " << state_element->name()
<< " due to finding that no leading bits are known.";
TernarySpan known_leading_tern =
absl::MakeConstSpan(ternary->Get({})).last(known_leading);
XLS_RET_CHECK(ternary_ops::IsFullyKnown(known_leading_tern));
Value orig_init_value = state_element->initial_value();
VLOG(2) << "Narrowing state_read " << state_read << " from "
<< state_read->BitCountOrDie() << " to "
<< (state_read->BitCountOrDie() - known_leading)
<< " bits (removing " << known_leading
<< " bits) using known-leading bits.";
XLS_RETURN_IF_ERROR(RemoveLeadingBits(
state_read, orig_init_value,
ternary_ops::ToKnownBitsValues(known_leading_tern)));
made_changes = true;
continue;
}
TernarySpan known_leading_tern =
absl::MakeConstSpan(ternary->Get({})).last(known_leading);
XLS_RET_CHECK(ternary_ops::IsFullyKnown(known_leading_tern));
Value orig_init_value = state_element->initial_value();
VLOG(2) << "Narrowing state element " << state_element->name() << " from "
<< state_element->type()->GetFlatBitCount() << " to "
<< (state_element->type()->GetFlatBitCount() - known_leading)
<< " bits (removing " << known_leading << " bits).";
XLS_RETURN_IF_ERROR(
RemoveLeadingBits(state_read, orig_init_value,
ternary_ops::ToKnownBitsValues(known_leading_tern)));
made_changes = true;
int64_t signed_bits = interval_ops::MinimumSignedBitCount(
qe.GetIntervals(state_read).Get({}));
int64_t signed_bits_removed = state_read->BitCountOrDie() - signed_bits;
if (signed_bits_removed != 0) {
Value orig_init_value = state_element->initial_value();
VLOG(2) << "Narrowing state_read " << state_read << " from "
<< state_read->BitCountOrDie() << " to "
<< (state_read->BitCountOrDie() - known_leading)
<< " bits (removing " << known_leading
<< " bits) using sign-extend.";
XLS_RETURN_IF_ERROR(RemoveSignBits(state_read, orig_init_value,
/*real_size=*/signed_bits));
made_changes = true;
}
VLOG(2) << "Unable to narrow " << state_read
<< " due to finding that no leading bits are known and signed "
"interval is not narrowable.";
}

return made_changes;
Expand Down
15 changes: 7 additions & 8 deletions xls/passes/proc_state_narrowing_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,6 @@ TEST_F(ProcStateNarrowingPassTest, StateExplorationWithPauses) {
}

TEST_F(ProcStateNarrowingPassTest, NegativeNumbersAreNotRemoved) {
// TODO(allight): Technically a valid transform would be to narrow this with a
// sign-extend. We don't have the ability to see this transformation in our
// analysis at the moment however.
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(
auto* chan, p->CreateStreamingChannel("test_chan", ChannelOps::kSendOnly,
Expand All @@ -343,23 +340,25 @@ TEST_F(ProcStateNarrowingPassTest, NegativeNumbersAreNotRemoved) {
pb.Send(chan, pb.Literal(Value::Token()), state);
// State just counts up 1 to 7 then goes from -7 to 7 repeating
// NB Limit is exactly 7 and comparison is LT so that however the transform is
// done the state fits in 3 bits.
// done the state fits in 4 bits.
// NB This is a signed comparison so naieve contextual narrowing will see
// range as [[0, 7], [INT_MIN, -1]].
// range as [[0, 7], [INT_MIN, -1]]. We need interval exploration to get the
// -7 lower bound.
auto in_loop = pb.SLt(state, pb.Literal(UBits(7, 32)));
pb.Next(state, pb.Add(state, pb.Literal(UBits(1, 32))), in_loop);
// If we aren't looping the value goes to -8
pb.Next(state, pb.Literal(SBits(-7, 32)), pb.Not(in_loop));

XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

solvers::z3::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/16,
solvers::z3::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/32,
/*include_state=*/false);
ScopedRecordIr sri(p.get());
EXPECT_THAT(RunPass(proc), IsOkAndHolds(false));
EXPECT_THAT(RunPass(proc), IsOkAndHolds(true));
EXPECT_THAT(RunProcStateCleanup(proc), IsOkAndHolds(true));

EXPECT_THAT(proc->StateElements(), UnorderedElementsAre(m::StateElement(
"the_state", p->GetBitsType(32))));
"the_state", p->GetBitsType(4))));
}

TEST_F(ProcStateNarrowingPassTest, StateExplorationWithPartialBackProp) {
Expand Down

0 comments on commit 2bdaded

Please sign in to comment.