From e7373a8c8f267b487d9af3cb0de9da46e9e00c93 Mon Sep 17 00:00:00 2001 From: Simone Campanoni Date: Fri, 22 Nov 2024 12:33:15 -0800 Subject: [PATCH] Extended the select lifting pass to handle select nodes with no cases. PiperOrigin-RevId: 699255258 --- xls/passes/select_lifting_pass.cc | 22 ++++++++++++++++---- xls/passes/select_lifting_pass_test.cc | 28 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/xls/passes/select_lifting_pass.cc b/xls/passes/select_lifting_pass.cc index 458528662f..50fa8fc17b 100644 --- a/xls/passes/select_lifting_pass.cc +++ b/xls/passes/select_lifting_pass.cc @@ -212,9 +212,16 @@ absl::StatusOr> CanLiftSelect( // Only "select" nodes with specific properties can be optimized by this // transformation. // - // Shared property that must hold for all cases: - // Only "select" nodes with the same node type for all its inputs can - // be optimized. + // Shared properties that must hold for all cases: + // + // Property A: + // Only "select" nodes with at least one input case can be optimized. + // + // Property B: + // Only "select" nodes with the same node type for all its inputs can + // be optimized. + // + // // // There are more properties that must hold for the transformation to be // applicable. Such properties are specific to the node type of the inputs of @@ -229,7 +236,14 @@ absl::StatusOr> CanLiftSelect( absl::Span select_cases = GetCases(select_to_optimize); std::optional default_value = GetDefaultValue(select_to_optimize); - // Check the shared property + // Check the shared property A + if (select_cases.empty()) { + VLOG(3) << " The transformation is not applicable: the select does not " + "have input cases"; + return std::nullopt; + } + + // Check the shared property B std::optional shared_input_op = SharedOperation(select_cases, default_value); if (!shared_input_op) { diff --git a/xls/passes/select_lifting_pass_test.cc b/xls/passes/select_lifting_pass_test.cc index 6b89a8cb0f..f84f959f5f 100644 --- a/xls/passes/select_lifting_pass_test.cc +++ b/xls/passes/select_lifting_pass_test.cc @@ -14,6 +14,8 @@ #include "xls/passes/select_lifting_pass.h" +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/log/log.h" @@ -214,6 +216,32 @@ TEST_F(SelectLiftingPassTest, LiftSingleSelectWithIndicesOfDifferentBitwidth) { EXPECT_EQ(f->node_count(), 9); } +TEST_F(SelectLiftingPassTest, LiftSingleSelectWithNoCases) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + + // Fetch the types + Type* u32_type = p->GetBitsType(32); + + // Create the parameters of the IR function + BValue a = fb.Param("array", p->GetArrayType(16, u32_type)); + BValue c = fb.Param("condition", u32_type); + BValue i = fb.Param("first_index", u32_type); + + // Create the body of the IR function + BValue condition_constant = fb.Literal(UBits(10, 32)); + BValue selector = fb.AddCompareOp(Op::kUGt, c, condition_constant); + BValue array_index_i = fb.ArrayIndex(a, {i}); + std::vector cases; + BValue select_node = fb.Select(selector, cases, array_index_i); + + // Build the function + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(select_node)); + + // Set the expected outputs + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + } // namespace } // namespace xls