diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 46193e27d9..aee142ec31 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1728,6 +1728,7 @@ cc_library( "//xls/data_structures:leaf_type_tree", "//xls/ir", "//xls/ir:bits", + "//xls/ir:interval_ops", "//xls/ir:op", "//xls/ir:state_element", "//xls/ir:ternary", diff --git a/xls/passes/proc_state_narrowing_pass.cc b/xls/passes/proc_state_narrowing_pass.cc index f19cdf44eb..202483df9a 100644 --- a/xls/passes/proc_state_narrowing_pass.cc +++ b/xls/passes/proc_state_narrowing_pass.cc @@ -30,6 +30,7 @@ #include "xls/common/status/status_macros.h" #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/bits.h" +#include "xls/ir/interval_ops.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" @@ -49,20 +50,42 @@ namespace { // Struct which transforms a state element into a slice of its trailing bits. struct ProcStateNarrowTransform : public Proc::StateElementTransformer { public: - explicit ProcStateNarrowTransform(Bits known_leading) - : Proc::StateElementTransformer(), - known_leading_(std::move(known_leading)) {} + explicit ProcStateNarrowTransform(int64_t known_leading) + : Proc::StateElementTransformer(), known_leading_(known_leading) {} + + int64_t known_leading() const { return known_leading_; } + absl::StatusOr TransformNextValue(Proc* proc, + StateRead* new_state_read, + Next* old_next) final { + XLS_RET_CHECK_EQ( + new_state_read->GetType()->GetFlatBitCount() + known_leading_, + old_next->state_read()->GetType()->GetFlatBitCount()); + return proc->MakeNodeWithName( + old_next->loc(), old_next->value(), /*start=*/0, + /*width=*/new_state_read->GetType()->GetFlatBitCount(), + absl::StrFormat("unexpand_for_%s", old_next->GetName())); + } + + private: + int64_t known_leading_; +}; + +class ProcStateConcatNarrowTransform : public ProcStateNarrowTransform { + public: + explicit ProcStateConcatNarrowTransform(Bits leading_bits) + : ProcStateNarrowTransform(leading_bits.bit_count()), + leading_bits_(std::move(leading_bits)) {} absl::StatusOr TransformStateRead(Proc* proc, StateRead* new_state_read, StateRead* old_state_read) final { - XLS_RET_CHECK_EQ(new_state_read->GetType()->GetFlatBitCount() + - known_leading_.bit_count(), - old_state_read->GetType()->GetFlatBitCount()); + XLS_RET_CHECK_EQ( + new_state_read->GetType()->GetFlatBitCount() + known_leading(), + old_state_read->GetType()->GetFlatBitCount()); XLS_ASSIGN_OR_RETURN( Node * leading, proc->MakeNodeWithName( - old_state_read->loc(), Value(known_leading_), + old_state_read->loc(), Value(leading_bits_), absl::StrFormat("leading_bits_%s", old_state_read->state_element()->name()))); return proc->MakeNodeWithName( @@ -70,20 +93,9 @@ struct ProcStateNarrowTransform : public Proc::StateElementTransformer { absl::StrFormat("extended_%s", old_state_read->state_element()->name())); } - absl::StatusOr TransformNextValue(Proc* proc, - StateRead* new_state_read, - Next* old_next) final { - XLS_RET_CHECK_EQ(new_state_read->GetType()->GetFlatBitCount() + - known_leading_.bit_count(), - old_next->state_read()->GetType()->GetFlatBitCount()); - return proc->MakeNodeWithName( - old_next->loc(), old_next->value(), /*start=*/0, - /*width=*/new_state_read->GetType()->GetFlatBitCount(), - absl::StrFormat("unexpand_for_%s", old_next->GetName())); - } private: - Bits known_leading_; + Bits leading_bits_; }; absl::Status RemoveLeadingBits(StateRead* state_read, @@ -91,7 +103,34 @@ absl::Status RemoveLeadingBits(StateRead* state_read, const Bits& known_leading) { Value new_init_value(orig_init_value.bits().Slice( 0, orig_init_value.bits().bit_count() - known_leading.bit_count())); - ProcStateNarrowTransform transform(known_leading); + ProcStateConcatNarrowTransform transform(known_leading); + return state_read->function_base() + ->AsProcOrDie() + ->TransformStateElement(state_read, new_init_value, transform) + .status(); +} + +class ProcStateSignExtendNarrowTransform : public ProcStateNarrowTransform { + public: + explicit ProcStateSignExtendNarrowTransform(int64_t known_leading) + : ProcStateNarrowTransform(known_leading) {} + + absl::StatusOr TransformStateRead(Proc* proc, + StateRead* new_state_read, + StateRead* old_state_read) final { + return proc->MakeNodeWithName( + new_state_read->loc(), new_state_read, old_state_read->BitCountOrDie(), + Op::kSignExt, + absl::StrFormat("extended_%s", + old_state_read->state_element()->name())); + } +}; + +absl::Status RemoveSignBits(StateRead* state_read, const Value& orig_init_value, + int64_t real_size) { + Value new_init_value(orig_init_value.bits().Slice(0, real_size)); + ProcStateSignExtendNarrowTransform transform(state_read->BitCountOrDie() - + real_size); return state_read->function_base() ->AsProcOrDie() ->TransformStateElement(state_read, new_init_value, transform) @@ -131,24 +170,40 @@ absl::StatusOr ProcStateNarrowingPass::RunOnProcInternal( } int64_t known_leading = ternary_ops::ToKnownBits(ternary->Get({})).CountLeadingOnes(); - if (known_leading == 0) { + if (known_leading != 0) { // TODO(allight): We could also narrow internal/trailing bits. - VLOG(2) << "Unable to narrow " << state_element->name() - << " due to finding that no leading bits are known."; + TernarySpan known_leading_tern = + absl::MakeConstSpan(ternary->Get({})).last(known_leading); + XLS_RET_CHECK(ternary_ops::IsFullyKnown(known_leading_tern)); + Value orig_init_value = state_element->initial_value(); + VLOG(2) << "Narrowing state_read " << state_read << " from " + << state_read->BitCountOrDie() << " to " + << (state_read->BitCountOrDie() - known_leading) + << " bits (removing " << known_leading + << " bits) using known-leading bits."; + XLS_RETURN_IF_ERROR(RemoveLeadingBits( + state_read, orig_init_value, + ternary_ops::ToKnownBitsValues(known_leading_tern))); + made_changes = true; continue; } - TernarySpan known_leading_tern = - absl::MakeConstSpan(ternary->Get({})).last(known_leading); - XLS_RET_CHECK(ternary_ops::IsFullyKnown(known_leading_tern)); - Value orig_init_value = state_element->initial_value(); - VLOG(2) << "Narrowing state element " << state_element->name() << " from " - << state_element->type()->GetFlatBitCount() << " to " - << (state_element->type()->GetFlatBitCount() - known_leading) - << " bits (removing " << known_leading << " bits)."; - XLS_RETURN_IF_ERROR( - RemoveLeadingBits(state_read, orig_init_value, - ternary_ops::ToKnownBitsValues(known_leading_tern))); - made_changes = true; + int64_t signed_bits = interval_ops::MinimumSignedBitCount( + qe.GetIntervals(state_read).Get({})); + int64_t signed_bits_removed = state_read->BitCountOrDie() - signed_bits; + if (signed_bits_removed != 0) { + Value orig_init_value = state_element->initial_value(); + VLOG(2) << "Narrowing state_read " << state_read << " from " + << state_read->BitCountOrDie() << " to " + << (state_read->BitCountOrDie() - known_leading) + << " bits (removing " << known_leading + << " bits) using sign-extend."; + XLS_RETURN_IF_ERROR(RemoveSignBits(state_read, orig_init_value, + /*real_size=*/signed_bits)); + made_changes = true; + } + VLOG(2) << "Unable to narrow " << state_read + << " due to finding that no leading bits are known and signed " + "interval is not narrowable."; } return made_changes; diff --git a/xls/passes/proc_state_narrowing_pass_test.cc b/xls/passes/proc_state_narrowing_pass_test.cc index 631a080700..3217a04358 100644 --- a/xls/passes/proc_state_narrowing_pass_test.cc +++ b/xls/passes/proc_state_narrowing_pass_test.cc @@ -331,9 +331,6 @@ TEST_F(ProcStateNarrowingPassTest, StateExplorationWithPauses) { } TEST_F(ProcStateNarrowingPassTest, NegativeNumbersAreNotRemoved) { - // TODO(allight): Technically a valid transform would be to narrow this with a - // sign-extend. We don't have the ability to see this transformation in our - // analysis at the moment however. auto p = CreatePackage(); XLS_ASSERT_OK_AND_ASSIGN( auto* chan, p->CreateStreamingChannel("test_chan", ChannelOps::kSendOnly, @@ -343,9 +340,10 @@ TEST_F(ProcStateNarrowingPassTest, NegativeNumbersAreNotRemoved) { pb.Send(chan, pb.Literal(Value::Token()), state); // State just counts up 1 to 7 then goes from -7 to 7 repeating // NB Limit is exactly 7 and comparison is LT so that however the transform is - // done the state fits in 3 bits. + // done the state fits in 4 bits. // NB This is a signed comparison so naieve contextual narrowing will see - // range as [[0, 7], [INT_MIN, -1]]. + // range as [[0, 7], [INT_MIN, -1]]. We need interval exploration to get the + // -7 lower bound. auto in_loop = pb.SLt(state, pb.Literal(UBits(7, 32))); pb.Next(state, pb.Add(state, pb.Literal(UBits(1, 32))), in_loop); // If we aren't looping the value goes to -8 @@ -353,13 +351,14 @@ TEST_F(ProcStateNarrowingPassTest, NegativeNumbersAreNotRemoved) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); - solvers::z3::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/16, + solvers::z3::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/32, /*include_state=*/false); ScopedRecordIr sri(p.get()); - EXPECT_THAT(RunPass(proc), IsOkAndHolds(false)); + EXPECT_THAT(RunPass(proc), IsOkAndHolds(true)); + EXPECT_THAT(RunProcStateCleanup(proc), IsOkAndHolds(true)); EXPECT_THAT(proc->StateElements(), UnorderedElementsAre(m::StateElement( - "the_state", p->GetBitsType(32)))); + "the_state", p->GetBitsType(4)))); } TEST_F(ProcStateNarrowingPassTest, StateExplorationWithPartialBackProp) {