From fe0b354196d57a5adbc680861173fc15d59d055f Mon Sep 17 00:00:00 2001 From: Brant-Skywalker Date: Thu, 7 Nov 2024 22:28:12 -0600 Subject: [PATCH] accuracy cost evaluation: arithmetic avg --> geometric avg --- enzyme/Enzyme/Herbie.cpp | 143 ++++++++++++++++++++------------------- 1 file changed, 75 insertions(+), 68 deletions(-) diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 3a6ada459fa..a4a899a980f 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -2925,12 +2925,12 @@ void setUnifiedAccuracyCost( // llvm::errs() << "DEBUG AO real value: " << realVal << "\n"; if (!std::isnan(goldVal) && !std::isnan(realVal)) { - initAC += std::fabs(goldVal - realVal); + initAC += std::log1p(std::fabs(goldVal - realVal)); numValidSamples++; } } - AO.initialAccCost = initAC / numValidSamples * std::fabs(AO.grad); + AO.initialAccCost = std::expm1(initAC / numValidSamples) * std::fabs(AO.grad); // llvm::errs() << "DEBUG calculated AO initial accuracy cost: " // << AO.initialAccCost << "\n"; assert(numValidSamples && "No valid samples for AO -- try increasing the " @@ -2968,13 +2968,14 @@ void setUnifiedAccuracyCost( // llvm::errs() << "Real value: " << realVal << "\n"; double goldVal = goldVals[pair.index()]; if (!std::isnan(goldVal) && !std::isnan(realVal)) { - ac += std::fabs(goldVal - realVal); + ac += std::log1p(std::fabs(goldVal - realVal)); numValidSamples++; } } assert(numValidSamples && "No valid samples for AO -- try increasing the " "number of samples"); - candidate.accuracyCost = ac / numValidSamples * std::fabs(AO.grad); + candidate.accuracyCost = + std::expm1(ac / numValidSamples) * std::fabs(AO.grad); assert(!std::isnan(candidate.accuracyCost)); } } @@ -3024,7 +3025,7 @@ void setUnifiedAccuracyCost( double goldVal = goldVals[output][pair.index()]; if (!std::isnan(goldVal) && !std::isnan(result)) { double diff = std::fabs(goldVal - result); - ACC.perOutputInitialAccCost[output] += diff; + ACC.perOutputInitialAccCost[output] += std::log1p(diff); numValidSamplesPerOutput[output]++; } } @@ -3036,9 +3037,10 @@ void setUnifiedAccuracyCost( unsigned numValidSamples = numValidSamplesPerOutput[output]; assert(numValidSamples && "No valid samples for at least one output node " "-- try increasing the number of samples"); - ACC.perOutputInitialAccCost[output] /= numValidSamples; // Local error --> global error - ACC.perOutputInitialAccCost[output] *= std::fabs(output->grad); + ACC.perOutputInitialAccCost[output] = + std::expm1(ACC.perOutputInitialAccCost[output] / numValidSamples) * + std::fabs(output->grad); // llvm::errs() << "DEBUG calculated ACC per output initial accuracy cost: " // << ACC.perOutputInitialAccCost[output] << "\n"; ACC.initialAccCost += ACC.perOutputInitialAccCost[output]; @@ -3062,7 +3064,7 @@ void setUnifiedAccuracyCost( if (!std::isnan(goldVal) && !std::isnan(result)) { double diff = std::fabs(goldVal - result); // Sum up local errors - candidate.perOutputAccCost[output] += diff; + candidate.perOutputAccCost[output] += std::log1p(diff); numValidSamplesPerOutput[output]++; } } @@ -3074,9 +3076,10 @@ void setUnifiedAccuracyCost( unsigned numValidSamples = numValidSamplesPerOutput[output]; assert(numValidSamples && "No valid samples for output -- try increasing " "the number of samples"); - candidate.perOutputAccCost[output] /= numValidSamples; // Local error --> global error - candidate.perOutputAccCost[output] *= std::fabs(output->grad); + candidate.perOutputAccCost[output] = + std::expm1(candidate.perOutputAccCost[output] / numValidSamples) * + std::fabs(output->grad); // llvm::errs() // << "DEBUG calculated ACC per output candidate accuracy cost: " // << candidate.perOutputAccCost[output] << "\n"; @@ -3585,10 +3588,10 @@ bool accuracyDPSolver( InstructionCost newCompCost = currCompCost + candCompCost; double newAccCost = currAccCost + candAccCost; - if (EnzymePrintFPOpt) - llvm::errs() << "AO candidate " << i - << " has accuracy cost: " << candAccCost - << " and computation cost: " << candCompCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "AO candidate " << i + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; if (newCostToAccuracyMap.find(newCompCost) == newCostToAccuracyMap.end() || @@ -3596,10 +3599,10 @@ bool accuracyDPSolver( newCostToAccuracyMap[newCompCost] = newAccCost; newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; newCostToSolutionMap[newCompCost].emplace_back(&AO, i); - if (EnzymePrintFPOpt) - llvm::errs() << "Updating accuracy map (AO candidate " << i - << "): computation cost " << newCompCost - << " -> accuracy cost " << newAccCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "Updating accuracy map (AO candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; } } } @@ -3629,13 +3632,14 @@ bool accuracyDPSolver( otherCompCost.getValue().getValue()) && currAccCost - otherAccCost > std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { - if (EnzymePrintFPOpt) - llvm::errs() << "AO candidate with computation cost: " - << currCompCost - << " and accuracy cost: " << currAccCost - << " is dominated by candidate with computation cost:" - << otherCompCost - << " and accuracy cost: " << otherAccCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "AO candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; dominated = true; break; } @@ -3679,11 +3683,11 @@ bool accuracyDPSolver( InstructionCost newCompCost = currCompCost + candCompCost; double newAccCost = currAccCost + candAccCost; - if (EnzymePrintFPOpt) - llvm::errs() << "ACC candidate " << i << " (" - << candidate.value().desc - << ") has accuracy cost: " << candAccCost - << " and computation cost: " << candCompCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "ACC candidate " << i << " (" + // << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; if (newCostToAccuracyMap.find(newCompCost) == newCostToAccuracyMap.end() || @@ -3691,10 +3695,10 @@ bool accuracyDPSolver( newCostToAccuracyMap[newCompCost] = newAccCost; newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; newCostToSolutionMap[newCompCost].emplace_back(&ACC, i); - if (EnzymePrintFPOpt) - llvm::errs() << "Updating accuracy map (ACC candidate " << i - << "): computation cost " << newCompCost - << " -> accuracy cost " << newAccCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "Updating accuracy map (ACC candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; } } } @@ -3716,13 +3720,14 @@ bool accuracyDPSolver( otherCompCost.getValue().getValue()) && currAccCost - otherAccCost > std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { - if (EnzymePrintFPOpt) - llvm::errs() << "ACC candidate with computation cost: " - << currCompCost - << " and accuracy cost: " << currAccCost - << " is dominated by candidate with computation cost:" - << otherCompCost - << " and accuracy cost: " << otherAccCost << "\n"; + // if (EnzymePrintFPOpt) + // llvm::errs() << "ACC candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; dominated = true; break; } @@ -3740,33 +3745,35 @@ bool accuracyDPSolver( } if (EnzymePrintFPOpt) { - llvm::errs() << "\n*** DP Table ***\n"; - for (const auto &pair : costToAccuracyMap) { - llvm::errs() << "Computation cost: " << pair.first - << ", Accuracy cost: " << pair.second << "\n"; - llvm::errs() << "\tSolution steps: \n"; - for (const auto &step : costToSolutionMap[pair.first]) { - std::visit( - [&](auto *item) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - llvm::errs() - << "\t\t" << item->expr << " --(" << step.candidateIndex - << ")-> " << item->candidates[step.candidateIndex].expr - << "\n"; - } else if constexpr (std::is_same_v) { - llvm::errs() - << "\t\tACC: " << item->candidates[step.candidateIndex].desc - << " (#" << step.candidateIndex << ")\n"; - } else { - llvm_unreachable( - "accuracyDPSolver: Unexpected type of solution step"); - } - }, - step.item); - } - } - llvm::errs() << "*** End of DP Table ***\n\n"; + // llvm::errs() << "\n*** DP Table ***\n"; + // for (const auto &pair : costToAccuracyMap) { + // llvm::errs() << "Computation cost: " << pair.first + // << ", Accuracy cost: " << pair.second << "\n"; + // llvm::errs() << "\tSolution steps: \n"; + // for (const auto &step : costToSolutionMap[pair.first]) { + // std::visit( + // [&](auto *item) { + // using T = std::decay_t; + // if constexpr (std::is_same_v) { + // llvm::errs() + // << "\t\t" << item->expr << " --(" << + // step.candidateIndex + // << ")-> " << item->candidates[step.candidateIndex].expr + // << "\n"; + // } else if constexpr (std::is_same_v) { + // llvm::errs() + // << "\t\tACC: " << + // item->candidates[step.candidateIndex].desc + // << " (#" << step.candidateIndex << ")\n"; + // } else { + // llvm_unreachable( + // "accuracyDPSolver: Unexpected type of solution step"); + // } + // }, + // step.item); + // } + // } + // llvm::errs() << "*** End of DP Table ***\n\n"; llvm::errs() << "*** Critical Computation Costs ***\n"; // Just print all computation costs in the DP table for (const auto &pair : costToAccuracyMap) {