From 60a6ddd15e9ee3e98d6e7779238abf3b1e4dcd7a Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Thu, 7 Nov 2024 19:38:05 -0600 Subject: [PATCH] remove duplicated expr --- enzyme/Enzyme/Herbie.cpp | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 1f80181a85b..3a6ada459fa 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -3155,6 +3155,8 @@ bool improveViaHerbie( bool InitialValuesSet = false; + std::unordered_set seenExprs; + for (const auto &BaseArgs : BaseArgsList) { SmallString<32> tmpin, tmpout; @@ -3235,6 +3237,11 @@ bool improveViaHerbie( continue; } + if (seenExprs.count(bestExpr.str()) != 0) { + continue; // Expression already seen, skip it + } + seenExprs.insert(bestExpr.str()); + double bits = tests[0].getAsObject()->getNumber("bits").getValue(); json::Array &costAccuracy = *tests[0].getAsObject()->getArray("cost-accuracy"); @@ -3265,9 +3272,16 @@ bool improveViaHerbie( // Handle alternatives for (size_t i = 0; i < alternatives.size(); ++i) { json::Array &entry = *alternatives[i].getAsArray(); + StringRef expr = entry[2].getAsString().getValue(); + + if (seenExprs.count(expr.str()) != 0) { + continue; + } + seenExprs.insert(expr.str()); + double cost = entry[0].getAsNumber().getValue() / initialCostVal; double accuracy = 1.0 - entry[1].getAsNumber().getValue() / bits; - StringRef expr = entry[2].getAsString().getValue(); + RewriteCandidate candidate(cost, accuracy, expr.str()); candidate.CompCost = getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, @@ -4327,9 +4341,13 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { if (EnzymePrintFPOpt && !opsToChange.empty()) { llvm::errs() << "Created PrecisionChange for " << percent << "% of Funcs (" << numToChange << ")\n"; - llvm::errs() << "Subset gradient range: [" - << std::fabs(opsToChange.front()->grad) << ", " - << std::fabs(opsToChange.back()->grad) << "]\n"; + llvm::errs() << "Subset sensitivity score range: [" + << std::fabs(opsToChange.front()->grad * + opsToChange.front()->geometricAvg) + << ", " + << std::fabs(opsToChange.back()->grad * + opsToChange.back()->geometricAvg) + << "]\n"; } for (auto prec : precTypes) { @@ -4373,9 +4391,13 @@ bool fpOptimize(Function &F, const TargetTransformInfo &TTI) { if (EnzymePrintFPOpt && !opsToChange.empty()) { llvm::errs() << "Created PrecisionChange for " << percent << "% of all operations (" << numToChange << ")\n"; - llvm::errs() << "Subset gradient range: [" - << std::fabs(opsToChange.front()->grad) << ", " - << std::fabs(opsToChange.back()->grad) << "]\n"; + llvm::errs() << "Subset sensitivity score range: [" + << std::fabs(opsToChange.front()->grad * + opsToChange.front()->geometricAvg) + << ", " + << std::fabs(opsToChange.back()->grad * + opsToChange.back()->geometricAvg) + << "]\n"; } for (auto prec : precTypes) {