-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moving DerivedFnCollector out of ClangPlugin
- Loading branch information
1 parent
962e5d9
commit d879f1b
Showing
13 changed files
with
262 additions
and
209 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H | ||
#define CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H | ||
|
||
#include "clad/Differentiator/DerivedFnInfo.h" | ||
|
||
#include "clang/AST/Decl.h" | ||
|
||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
|
||
namespace clad { | ||
/// This class is designed to store collection of `DerivedFnInfo` objects. | ||
/// It's purpose is to avoid repeated generation of same derivatives by | ||
/// making it possible to reuse previously computed derivatives. | ||
class DerivedFnCollector { | ||
using DerivedFns = llvm::SmallVector<DerivedFnInfo, 16>; | ||
/// Mapping to efficiently find out information about all the derivatives of | ||
/// a function. | ||
llvm::DenseMap<const clang::FunctionDecl*, DerivedFns> | ||
m_DerivedFnInfoCollection; | ||
|
||
public: | ||
/// Adds a derived function to the collection. | ||
void Add(const DerivedFnInfo& DFI); | ||
|
||
/// Finds a `DerivedFnInfo` object in the collection that satisfies the | ||
/// given differentiation request. | ||
DerivedFnInfo Find(const DiffRequest& request) const; | ||
|
||
bool IsDerivative(const clang::FunctionDecl* FD) const; | ||
|
||
private: | ||
/// Returns true if the collection already contains a `DerivedFnInfo` | ||
/// object that represents the same derivative object as the provided | ||
/// argument `DFI`. | ||
bool AlreadyExists(const DerivedFnInfo& DFI) const; | ||
}; | ||
} // namespace clad | ||
|
||
#endif // CLAD_DIFFERENTIATOR_DERIVEDFNCOLLECTOR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#ifndef CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H | ||
#define CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H | ||
|
||
#include "clang/AST/Decl.h" | ||
#include "clad/Differentiator/DiffMode.h" | ||
#include "clad/Differentiator/ParseDiffArgsTypes.h" | ||
|
||
namespace clad { | ||
struct DiffRequest; | ||
|
||
/// `DerivedFnInfo` is designed to effectively store information about a | ||
/// derived function. | ||
struct DerivedFnInfo { | ||
const clang::FunctionDecl* m_OriginalFn = nullptr; | ||
clang::FunctionDecl* m_DerivedFn = nullptr; | ||
clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; | ||
DiffMode m_Mode = DiffMode::unknown; | ||
unsigned m_DerivativeOrder = 0; | ||
DiffInputVarsInfo m_DiffVarsInfo; | ||
bool m_UsesEnzyme = false; | ||
bool m_DeclarationOnly = false; | ||
|
||
DerivedFnInfo() = default; | ||
DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, | ||
clang::FunctionDecl* overloadedDerivedFn); | ||
|
||
/// Returns true if the derived function represented by the object, | ||
/// satisfies the requirements of the given differentiation request. | ||
bool SatisfiesRequest(const DiffRequest& request) const; | ||
|
||
/// Returns true if the object represents any derived function; otherwise | ||
/// returns false. | ||
bool IsValid() const; | ||
|
||
const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } | ||
clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } | ||
clang::FunctionDecl* OverloadedDerivedFn() const { | ||
return m_OverloadedDerivedFn; | ||
} | ||
bool DeclarationOnly() const { return m_DeclarationOnly; } | ||
|
||
/// Returns true if `lhs` and `rhs` represents same derivative. | ||
/// Here derivative is any function derived by clad. | ||
static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, | ||
const DerivedFnInfo& rhs); | ||
}; | ||
} // namespace clad | ||
|
||
#endif // CLAD_DIFFERENTIATOR_DERIVEDFNINFO_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include "clad/Differentiator/DerivedFnCollector.h" | ||
#include "clad/Differentiator/DiffPlanner.h" | ||
|
||
namespace clad { | ||
void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { | ||
assert(!AlreadyExists(DFI) && | ||
"We are generating same derivative more than once, or calling " | ||
"`DerivedFnCollector::Add` more than once for the same derivative " | ||
". Ideally, we shouldn't do either."); | ||
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); | ||
} | ||
|
||
bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { | ||
auto subCollectionIt = m_DerivedFnInfoCollection.find(DFI.OriginalFn()); | ||
if (subCollectionIt == m_DerivedFnInfoCollection.end()) | ||
return false; | ||
const auto& subCollection = subCollectionIt->second; | ||
const auto* it = | ||
std::find_if(subCollection.begin(), subCollection.end(), | ||
[&DFI](const DerivedFnInfo& info) { | ||
return DerivedFnInfo::RepresentsSameDerivative(DFI, info); | ||
}); | ||
return it != subCollection.end(); | ||
} | ||
|
||
DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const { | ||
auto subCollectionIt = m_DerivedFnInfoCollection.find(request.Function); | ||
if (subCollectionIt == m_DerivedFnInfoCollection.end()) | ||
return DerivedFnInfo(); | ||
const auto& subCollection = subCollectionIt->second; | ||
const auto* it = std::find_if(subCollection.begin(), subCollection.end(), | ||
[&request](const DerivedFnInfo& DFI) { | ||
return DFI.SatisfiesRequest(request); | ||
}); | ||
if (it == subCollection.end()) | ||
return DerivedFnInfo(); | ||
return *it; | ||
} | ||
} // namespace clad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.