From 825396db00d1e54e4f290b4bd41039785f7f080e Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Wed, 1 Jan 2025 09:50:58 +0000 Subject: [PATCH] Simplify error estimation by not storing m_ParamTypes --- include/clad/Differentiator/ErrorEstimator.h | 1 - lib/Differentiator/ErrorEstimator.cpp | 16 +++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/include/clad/Differentiator/ErrorEstimator.h b/include/clad/Differentiator/ErrorEstimator.h index 381fe42af..1a4bc3e89 100644 --- a/include/clad/Differentiator/ErrorEstimator.h +++ b/include/clad/Differentiator/ErrorEstimator.h @@ -48,7 +48,6 @@ class ErrorEstimationHandler : public ExternalRMVSource { std::stack m_ShouldEmit; ReverseModeVisitor* m_RMV; - llvm::SmallVectorImpl* m_ParamTypes = nullptr; llvm::SmallVectorImpl* m_Params = nullptr; public: diff --git a/lib/Differentiator/ErrorEstimator.cpp b/lib/Differentiator/ErrorEstimator.cpp index 368076483..9a1ae4adc 100644 --- a/lib/Differentiator/ErrorEstimator.cpp +++ b/lib/Differentiator/ErrorEstimator.cpp @@ -281,25 +281,23 @@ void ErrorEstimationHandler::ActBeforeCreatingDerivedFnParamTypes( void ErrorEstimationHandler::ActAfterCreatingDerivedFnParamTypes( llvm::SmallVectorImpl& paramTypes) { - m_ParamTypes = ¶mTypes; // If we are performing error estimation, our gradient function // will have an extra argument which will hold the final error value - paramTypes.push_back( - m_RMV->m_Context.getLValueReferenceType(m_RMV->m_Context.DoubleTy)); + ASTContext& C = m_RMV->m_Context; + paramTypes.push_back(C.getLValueReferenceType(C.DoubleTy)); } void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams( llvm::SmallVectorImpl& params) { m_Params = ¶ms; // If in error estimation mode, create the error parameter - ASTContext& context = m_RMV->m_Context; + ASTContext& C = m_RMV->m_Context; // Repeat the above but for the error ouput var "_final_error" + QualType LastParamTy = C.getLValueReferenceType(C.DoubleTy); ParmVarDecl* errorVarDecl = ParmVarDecl::Create( - context, m_RMV->m_Derivative, noLoc, noLoc, - &context.Idents.get("_final_error"), m_ParamTypes->back(), - context.getTrivialTypeSourceInfo(m_ParamTypes->back(), noLoc), - params.front()->getStorageClass(), - /*DefArg=*/nullptr); + C, m_RMV->m_Derivative, noLoc, noLoc, &C.Idents.get("_final_error"), + LastParamTy, C.getTrivialTypeSourceInfo(LastParamTy, noLoc), + params.front()->getStorageClass(), /*DefArg=*/nullptr); params.push_back(errorVarDecl); m_RMV->m_Sema.PushOnScopeChains(params.back(), m_RMV->getCurrentScope(), /*AddToContext=*/false);