Skip to content

Commit

Permalink
fix up
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Nov 8, 2024
1 parent bd013be commit f4f7335
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3165,7 +3165,7 @@ bool improveViaHerbie(
BaseArgsList.push_back(BaseArgs);
}

std::unordered_set<std::string> seenExprs;
std::vector<std::unordered_set<std::string>> seenExprs;
bool success = false;

for (const auto &BaseArgs : BaseArgsList) {
Expand Down Expand Up @@ -3270,30 +3270,31 @@ bool improveViaHerbie(
AO.initialHerbieCost = initialCost;
AO.initialHerbieAccuracy = initialAccuracy;

json::Array &best = *costAccuracy[1].getAsArray();
double bestCost = best[0].getAsNumber().getValue() / initialCostVal;
double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits;
if (seenExprs[i].count(bestExpr.str()) == 0) {
seenExprs[i].insert(bestExpr.str());

RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str());
bestCandidate.CompCost =
getCompCost(bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(bestCandidate);
json::Array &best = *costAccuracy[1].getAsArray();
double bestCost = best[0].getAsNumber().getValue() / initialCostVal;
double bestAccuracy = 1.0 - best[1].getAsNumber().getValue() / bits;

json::Array &alternatives = *costAccuracy[2].getAsArray();
RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str());
bestCandidate.CompCost = getCompCost(
bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap,
cast<Instruction>(AO.oldOutput)->getFastMathFlags());
AO.candidates.push_back(bestCandidate);
}

std::unordered_set<std::string> seenExprs;
seenExprs.insert(bestExpr.str());
json::Array &alternatives = *costAccuracy[2].getAsArray();

// Handle alternatives
for (size_t j = 0; j < alternatives.size(); ++j) {
json::Array &entry = *alternatives[j].getAsArray();
StringRef expr = entry[2].getAsString().getValue();

if (seenExprs.count(expr.str()) != 0) {
if (seenExprs[i].count(expr.str()) != 0) {
continue;
}
seenExprs.insert(expr.str());
seenExprs[i].insert(expr.str());

double cost = entry[0].getAsNumber().getValue() / initialCostVal;
double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits;
Expand Down

0 comments on commit f4f7335

Please sign in to comment.