Skip to content

Commit

Permalink
remove duplicated expr
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Nov 8, 2024
1 parent 845cbf5 commit 60a6ddd
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3155,6 +3155,8 @@ bool improveViaHerbie(

bool InitialValuesSet = false;

std::unordered_set<std::string> seenExprs;

for (const auto &BaseArgs : BaseArgsList) {
SmallString<32> tmpin, tmpout;

Expand Down Expand Up @@ -3235,6 +3237,11 @@ bool improveViaHerbie(
continue;
}

if (seenExprs.count(bestExpr.str()) != 0) {
continue; // Expression already seen, skip it
}
seenExprs.insert(bestExpr.str());

double bits = tests[0].getAsObject()->getNumber("bits").getValue();
json::Array &costAccuracy =
*tests[0].getAsObject()->getArray("cost-accuracy");
Expand Down Expand Up @@ -3265,9 +3272,16 @@ bool improveViaHerbie(
// Handle alternatives
for (size_t i = 0; i < alternatives.size(); ++i) {
json::Array &entry = *alternatives[i].getAsArray();
StringRef expr = entry[2].getAsString().getValue();

if (seenExprs.count(expr.str()) != 0) {
continue;
}
seenExprs.insert(expr.str());

double cost = entry[0].getAsNumber().getValue() / initialCostVal;
double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits;
StringRef expr = entry[2].getAsString().getValue();

RewriteCandidate candidate(cost, accuracy, expr.str());
candidate.CompCost =
getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
Expand Down Expand Up @@ -4327,9 +4341,13 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
if (EnzymePrintFPOpt && !opsToChange.empty()) {
llvm::errs() << "Created PrecisionChange for " << percent
<< "% of Funcs (" << numToChange << ")\n";
llvm::errs() << "Subset gradient range: ["
<< std::fabs(opsToChange.front()->grad) << ", "
<< std::fabs(opsToChange.back()->grad) << "]\n";
llvm::errs() << "Subset sensitivity score range: ["
<< std::fabs(opsToChange.front()->grad *
opsToChange.front()->geometricAvg)
<< ", "
<< std::fabs(opsToChange.back()->grad *
opsToChange.back()->geometricAvg)
<< "]\n";
}

for (auto prec : precTypes) {
Expand Down Expand Up @@ -4373,9 +4391,13 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) {
if (EnzymePrintFPOpt && !opsToChange.empty()) {
llvm::errs() << "Created PrecisionChange for " << percent
<< "% of all operations (" << numToChange << ")\n";
llvm::errs() << "Subset gradient range: ["
<< std::fabs(opsToChange.front()->grad) << ", "
<< std::fabs(opsToChange.back()->grad) << "]\n";
llvm::errs() << "Subset sensitivity score range: ["
<< std::fabs(opsToChange.front()->grad *
opsToChange.front()->geometricAvg)
<< ", "
<< std::fabs(opsToChange.back()->grad *
opsToChange.back()->geometricAvg)
<< "]\n";
}

for (auto prec : precTypes) {
Expand Down

0 comments on commit 60a6ddd

Please sign in to comment.