Skip to content

Commit

Permalink
Refactor opt-level to make it an OptimizationPassOption instead of se…
Browse files Browse the repository at this point in the history
…t by pass construction.

This makes it simpler to create custom pipelines.

PiperOrigin-RevId: 698418593
  • Loading branch information
allight authored and copybara-github committed Nov 20, 2024
1 parent f2b5f2c commit 7dad99e
Show file tree
Hide file tree
Showing 40 changed files with 399 additions and 277 deletions.
5 changes: 3 additions & 2 deletions xls/contrib/xlscc/unit_tests/translator_proc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7979,10 +7979,11 @@ TEST_F(TranslatorProcTestWithoutFSMParam, OpDuplicationAcrossIO) {
// Don't do cse so that the duplication shows
// bdd_cse pass wants a delay estimator
std::unique_ptr<xls::OptimizationCompoundPass> pipeline =
xls::GetOptimizationPipelineGenerator(/*opt_level=*/3)
xls::GetOptimizationPipelineGenerator()
.GeneratePipeline("inlining dce")
.value();
xls::OptimizationPassOptions options;
xls::OptimizationPassOptions options =
xls::OptimizationPassOptions().WithOptLevel(3);
xls::PassResults results;

XLS_ASSERT_OK(pipeline->Run(package_.get(), options, &results).status());
Expand Down
2 changes: 1 addition & 1 deletion xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)

Expand Down Expand Up @@ -872,7 +873,6 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
Expand Down
4 changes: 2 additions & 2 deletions xls/passes/arith_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ absl::StatusOr<bool> ArithSimplificationPass::RunOnFunctionBaseInternal(
}
XLS_ASSIGN_OR_RETURN(
bool node_changed,
MatchArithPatterns(opt_level_, n, StatelessQueryEngine()));
MatchArithPatterns(options.opt_level, n, StatelessQueryEngine()));
if (node_changed) {
pass_changed = true;
}
Expand All @@ -1549,6 +1549,6 @@ absl::StatusOr<bool> ArithSimplificationPass::RunOnFunctionBaseInternal(
return changed;
}

REGISTER_OPT_PASS(ArithSimplificationPass, pass_config::kOptLevel);
REGISTER_OPT_PASS(ArithSimplificationPass);

} // namespace xls
7 changes: 2 additions & 5 deletions xls/passes/arith_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@ namespace xls {
class ArithSimplificationPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "arith_simp";
explicit ArithSimplificationPass(int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "Arithmetic Simplifications"),
opt_level_(opt_level) {}
explicit ArithSimplificationPass()
: OptimizationFunctionBasePass(kName, "Arithmetic Simplifications") {}
~ArithSimplificationPass() override = default;

protected:
int64_t opt_level_;

absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
11 changes: 6 additions & 5 deletions xls/passes/arith_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class ArithSimplificationPassTest : public IrTestBase {

absl::StatusOr<bool> Run(Package* p) {
PassResults results;
return ArithSimplificationPass(kMaxOptLevel)
.Run(p, OptimizationPassOptions(), &results);
return ArithSimplificationPass().Run(
p, OptimizationPassOptions().WithOptLevel(kMaxOptLevel), &results);
}

void CheckUnsignedDivide(int n, int divisor);
Expand Down Expand Up @@ -1905,9 +1905,10 @@ void UmulFuzz(const Bits& multiplicand, const Bits& result, int64_t var_width,
ScopedVerifyEquivalence sve(f);
ScopedRecordIr sri(&p);
PassResults results;
ASSERT_THAT(ArithSimplificationPass(kMaxOptLevel)
.Run(&p, OptimizationPassOptions(), &results),
absl_testing::IsOk());
ASSERT_THAT(
ArithSimplificationPass().Run(
&p, OptimizationPassOptions().WithOptLevel(kMaxOptLevel), &results),
absl_testing::IsOk());
}

FUZZ_TEST(ArithSimplificationPassFuzzTest, UmulFuzz)
Expand Down
10 changes: 5 additions & 5 deletions xls/passes/array_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,7 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
// transforming selects of array to array of selects, etc.
XLS_ASSIGN_OR_RETURN(
bool flatten_changed,
FlattenSequentialUpdates(func, query_engine, opt_level_));
FlattenSequentialUpdates(func, query_engine, options.opt_level));
if (flatten_changed) {
changed = true;
XLS_RETURN_IF_ERROR(query_engine.Populate(func).status());
Expand Down Expand Up @@ -1684,12 +1684,12 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
SimplifyResult result = {.changed = false, .new_worklist_nodes = {}};
if (node->Is<ArrayIndex>()) {
ArrayIndex* array_index = node->As<ArrayIndex>();
XLS_ASSIGN_OR_RETURN(
result, SimplifyArrayIndex(array_index, query_engine, opt_level_));
XLS_ASSIGN_OR_RETURN(result, SimplifyArrayIndex(array_index, query_engine,
options.opt_level));
} else if (node->Is<ArrayUpdate>()) {
XLS_ASSIGN_OR_RETURN(
result, SimplifyArrayUpdate(node->As<ArrayUpdate>(), query_engine,
opt_level_));
options.opt_level));
} else if (node->Is<Array>()) {
XLS_ASSIGN_OR_RETURN(result,
SimplifyArray(node->As<Array>(), query_engine));
Expand All @@ -1709,6 +1709,6 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
return changed;
}

REGISTER_OPT_PASS(ArraySimplificationPass, pass_config::kOptLevel);
REGISTER_OPT_PASS(ArraySimplificationPass);

} // namespace xls
7 changes: 2 additions & 5 deletions xls/passes/array_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef XLS_PASSES_ARRAY_SIMPLIFICATION_PASS_H_
#define XLS_PASSES_ARRAY_SIMPLIFICATION_PASS_H_

#include <cstdint>
#include <string_view>

#include "absl/status/statusor.h"
Expand All @@ -30,12 +29,10 @@ namespace xls {
class ArraySimplificationPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "array_simp";
explicit ArraySimplificationPass(int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "Array Simplification"),
opt_level_(opt_level) {}
explicit ArraySimplificationPass()
: OptimizationFunctionBasePass(kName, "Array Simplification") {}

protected:
int64_t opt_level_;
absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
13 changes: 6 additions & 7 deletions xls/passes/bdd_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ absl::StatusOr<bool> BddSimplificationPass::RunOnFunctionBaseInternal(
bool modified = false;
for (Node* node : TopoSort(f)) {
XLS_ASSIGN_OR_RETURN(bool node_modified,
SimplifyNode(node, query_engine, opt_level_));
SimplifyNode(node, query_engine, options.opt_level));
modified |= node_modified;
}

Expand All @@ -463,12 +463,11 @@ absl::StatusOr<bool> BddSimplificationPass::RunOnFunctionBaseInternal(
}

XLS_REGISTER_MODULE_INITIALIZER(bdd_simp, {
CHECK_OK(RegisterOptimizationPass<BddSimplificationPass>(
"bdd_simp", pass_config::kOptLevel));
CHECK_OK(RegisterOptimizationPass<BddSimplificationPass>(
"bdd_simp(2)", pass_config::CappedOptLevel{2}));
CHECK_OK(RegisterOptimizationPass<BddSimplificationPass>(
"bdd_simp(3)", pass_config::CappedOptLevel{3}));
CHECK_OK(RegisterOptimizationPass<BddSimplificationPass>("bdd_simp"));
CHECK_OK((RegisterOptimizationPass<CapOptLevel<2, BddSimplificationPass>>(
"bdd_simp(2)")));
CHECK_OK((RegisterOptimizationPass<CapOptLevel<3, BddSimplificationPass>>(
"bdd_simp(3)")));
});

} // namespace xls
9 changes: 2 additions & 7 deletions xls/passes/bdd_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef XLS_PASSES_BDD_SIMPLIFICATION_PASS_H_
#define XLS_PASSES_BDD_SIMPLIFICATION_PASS_H_

#include <cstdint>
#include <string_view>

#include "absl/status/statusor.h"
Expand All @@ -32,19 +31,15 @@ namespace xls {
class BddSimplificationPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "bdd_simp";
explicit BddSimplificationPass(int64_t opt_level)
: OptimizationFunctionBasePass(kName, "BDD-based Simplification"),
opt_level_(opt_level) {}
explicit BddSimplificationPass()
: OptimizationFunctionBasePass(kName, "BDD-based Simplification") {}
~BddSimplificationPass() override = default;

protected:
// Run all registered passes in order of registration.
absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;

private:
int64_t opt_level_;
};

} // namespace xls
Expand Down
7 changes: 4 additions & 3 deletions xls/passes/bdd_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ class BddSimplificationPassTest : public IrTestBase {
protected:
absl::StatusOr<bool> Run(Function* f, int64_t opt_level = kMaxOptLevel) {
PassResults results;
XLS_ASSIGN_OR_RETURN(bool changed,
BddSimplificationPass(opt_level).RunOnFunctionBase(
f, OptimizationPassOptions(), &results));
XLS_ASSIGN_OR_RETURN(
bool changed,
BddSimplificationPass().RunOnFunctionBase(
f, OptimizationPassOptions().WithOptLevel(opt_level), &results));
return changed;
}
};
Expand Down
9 changes: 5 additions & 4 deletions xls/passes/bit_slice_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ absl::StatusOr<bool> BitSliceSimplificationPass::RunOnFunctionBaseInternal(
bool changed = false;

XLS_ASSIGN_OR_RETURN(std::unique_ptr<QueryEngine> query_engine,
GetQueryEngine(f, opt_level_));
GetQueryEngine(f, options.opt_level));

// Iterating through these operations in reverse topological order makes sure
// we don't need to re-populate the query engine between nodes.
Expand Down Expand Up @@ -1098,14 +1098,15 @@ absl::StatusOr<bool> BitSliceSimplificationPass::RunOnFunctionBaseInternal(
while (!worklist.empty()) {
BitSlice* bit_slice = worklist.front();
worklist.pop_front();
XLS_ASSIGN_OR_RETURN(bool node_changed,
SimplifyBitSlice(bit_slice, opt_level_, &worklist));
XLS_ASSIGN_OR_RETURN(
bool node_changed,
SimplifyBitSlice(bit_slice, options.opt_level, &worklist));
changed = changed || node_changed;
}

return changed;
}

REGISTER_OPT_PASS(BitSliceSimplificationPass, pass_config::kOptLevel);
REGISTER_OPT_PASS(BitSliceSimplificationPass);

} // namespace xls
6 changes: 2 additions & 4 deletions xls/passes/bit_slice_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ namespace xls {
class BitSliceSimplificationPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "bitslice_simp";
explicit BitSliceSimplificationPass(int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "Bit-slice simplification"),
opt_level_(opt_level) {}
explicit BitSliceSimplificationPass()
: OptimizationFunctionBasePass(kName, "Bit-slice simplification") {}
~BitSliceSimplificationPass() override = default;

protected:
int64_t opt_level_;
absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
6 changes: 3 additions & 3 deletions xls/passes/concat_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -570,13 +570,13 @@ absl::StatusOr<bool> ConcatSimplificationPass::RunOnFunctionBaseInternal(
Concat* concat = worklist.front();
worklist.pop_front();
XLS_ASSIGN_OR_RETURN(bool node_changed,
SimplifyConcat(concat, opt_level_, &worklist));
SimplifyConcat(concat, options.opt_level, &worklist));
changed = changed || node_changed;
}

// For optimizations which optimize around concats, just iterate through once
// and find all opportunities.
if (NarrowingEnabled(opt_level_)) {
if (options.narrowing_enabled()) {
for (Node* node : TopoSort(f)) {
if (OpIsBitWise(node->op())) {
XLS_ASSIGN_OR_RETURN(bool bitwise_changed,
Expand All @@ -597,6 +597,6 @@ absl::StatusOr<bool> ConcatSimplificationPass::RunOnFunctionBaseInternal(
return changed;
}

REGISTER_OPT_PASS(ConcatSimplificationPass, pass_config::kOptLevel);
REGISTER_OPT_PASS(ConcatSimplificationPass);

} // namespace xls
6 changes: 2 additions & 4 deletions xls/passes/concat_simplification_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ namespace xls {
class ConcatSimplificationPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "concat_simp";
explicit ConcatSimplificationPass(int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "Concat simplification"),
opt_level_(opt_level) {}
explicit ConcatSimplificationPass()
: OptimizationFunctionBasePass(kName, "Concat simplification") {}
~ConcatSimplificationPass() override = default;

protected:
int64_t opt_level_;
absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
6 changes: 3 additions & 3 deletions xls/passes/lut_conversion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ absl::StatusOr<bool> SimplifyNode(
absl::StatusOr<bool> LutConversionPass::RunOnFunctionBaseInternal(
FunctionBase* func, const OptimizationPassOptions& options,
PassResults* results) const {
if (!NarrowingEnabled(opt_level_)) {
if (!options.narrowing_enabled()) {
return false;
}

Expand All @@ -267,13 +267,13 @@ absl::StatusOr<bool> LutConversionPass::RunOnFunctionBaseInternal(
continue;
}
XLS_ASSIGN_OR_RETURN(bool node_changed,
SimplifyNode(node, query_engine, opt_level_,
SimplifyNode(node, query_engine, options.opt_level,
dataflow_dominator_analysis));
changed = changed || node_changed;
}
return changed;
}

REGISTER_OPT_PASS(LutConversionPass, pass_config::kOptLevel);
REGISTER_OPT_PASS(LutConversionPass);

} // namespace xls
6 changes: 2 additions & 4 deletions xls/passes/lut_conversion_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ namespace xls {
class LutConversionPass : public OptimizationFunctionBasePass {
public:
static constexpr std::string_view kName = "lut_conversion";
explicit LutConversionPass(int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "LUT Conversion"),
opt_level_(opt_level) {}
explicit LutConversionPass()
: OptimizationFunctionBasePass(kName, "LUT Conversion") {}
~LutConversionPass() override = default;

protected:
int64_t opt_level_;
absl::StatusOr<bool> RunOnFunctionBaseInternal(
FunctionBase* f, const OptimizationPassOptions& options,
PassResults* results) const override;
Expand Down
14 changes: 5 additions & 9 deletions xls/passes/narrowing_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,7 @@ absl::StatusOr<bool> NarrowingPass::RunOnFunctionBaseInternal(
SpecializedQueryEngines sqe(RealAnalysis(options), pda, query_engine);

NarrowVisitor narrower(sqe, RealAnalysis(options), options,
SplitsEnabled(opt_level_));
options.splits_enabled());

for (Node* node : TopoSort(f)) {
// We specifically want gate ops to be eligible for being reduced to a
Expand Down Expand Up @@ -1883,18 +1883,14 @@ std::ostream& operator<<(std::ostream& os, NarrowingPass::AnalysisType a) {
XLS_REGISTER_MODULE_INITIALIZER(narrowing_pass, {
CHECK_OK(RegisterOptimizationPass<NarrowingPass>("narrow"));
CHECK_OK(RegisterOptimizationPass<NarrowingPass>(
"narrow(Ternary)", NarrowingPass::AnalysisType::kTernary,
pass_config::kOptLevel));
"narrow(Ternary)", NarrowingPass::AnalysisType::kTernary));
CHECK_OK(RegisterOptimizationPass<NarrowingPass>(
"narrow(Range)", NarrowingPass::AnalysisType::kRange,
pass_config::kOptLevel));
"narrow(Range)", NarrowingPass::AnalysisType::kRange));
CHECK_OK(RegisterOptimizationPass<NarrowingPass>(
"narrow(Context)", NarrowingPass::AnalysisType::kRangeWithContext,
pass_config::kOptLevel));
"narrow(Context)", NarrowingPass::AnalysisType::kRangeWithContext));
CHECK_OK(RegisterOptimizationPass<NarrowingPass>(
"narrow(OptionalContext)",
NarrowingPass::AnalysisType::kRangeWithOptionalContext,
pass_config::kOptLevel));
NarrowingPass::AnalysisType::kRangeWithOptionalContext));
});

} // namespace xls
8 changes: 2 additions & 6 deletions xls/passes/narrowing_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,12 @@ class NarrowingPass : public OptimizationFunctionBasePass {
kRangeWithOptionalContext,
};
static constexpr std::string_view kName = "narrow";
explicit NarrowingPass(AnalysisType analysis = AnalysisType::kRange,
int64_t opt_level = kMaxOptLevel)
: OptimizationFunctionBasePass(kName, "Narrowing"),
analysis_(analysis),
opt_level_(opt_level) {}
explicit NarrowingPass(AnalysisType analysis = AnalysisType::kRange)
: OptimizationFunctionBasePass(kName, "Narrowing"), analysis_(analysis) {}
~NarrowingPass() override = default;

protected:
AnalysisType analysis_;
int64_t opt_level_;

AnalysisType RealAnalysis(const OptimizationPassOptions& options) const;
absl::StatusOr<bool> RunOnFunctionBaseInternal(
Expand Down
Loading

0 comments on commit 7dad99e

Please sign in to comment.