diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 30b483b7e..917a75b4f 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -132,15 +132,8 @@ struct DiffRequest { const clang::FunctionDecl* operator->() const { return Function; } - // String operator for printing the node. - operator std::string() const { - std::string res = BaseFunctionName + "__order_" + - std::to_string(CurrentDerivativeOrder) + "__mode_" + - DiffModeToString(Mode); - if (EnableTBRAnalysis) - res += "__TBR"; - return res; - } + void print(llvm::raw_ostream& Out) const; + void dump() const { print(llvm::errs()); } bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index c8aaaa286..60e4b014c 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -14,7 +14,6 @@ #include "BuiltinDerivativesCUDA.cuh" #endif #include "CladConfig.h" -#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 2ef8cf992..e8a095ca8 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -109,23 +109,27 @@ template class DynamicGraph { const std::vector& getNodes() const { return m_nodes; } std::vector& getNodes() { return m_nodes; } + /// Dump the nodes and edges. + void dump() const { print(llvm::errs()); } + /// Print the nodes and edges in the graph. - void print() { + void print(llvm::raw_ostream& Out) const { // First print the nodes with their insertion order. for (const T& node : m_nodes) { - std::pair nodeInfo = m_nodeMap[node]; - std::cout << (std::string)node << ": #" << nodeInfo.second; + std::pair nodeInfo = m_nodeMap.at(node); + node.print(Out); + Out << ": #" << nodeInfo.second; if (m_sources.find(nodeInfo.second) != m_sources.end()) - std::cout << " (source)"; + Out << " (source)"; if (nodeInfo.first) - std::cout << ", (done)\n"; + Out << ", (done)\n"; else - std::cout << ", (unprocessed)\n"; + Out << ", (unprocessed)\n"; } // Then print the edges. for (int i = 0; i < m_nodes.size(); i++) - for (size_t dest : m_adjList[i]) - std::cout << i << " -> " << dest << "\n"; + for (size_t dest : m_adjList.at(i)) + Out << i << " -> " << dest << "\n"; } /// Get the next node to be processed from the queue of nodes to be diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d2c39c1d9..40fddabef 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -601,6 +601,19 @@ namespace clad { return; } + void DiffRequest::print(llvm::raw_ostream& Out) const { + Out << '<'; + PrintingPolicy Policy(Function->getASTContext().getLangOpts()); + Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true); + Out << ">[name=" << BaseFunctionName << ", " + << "order=" << CurrentDerivativeOrder << ", " + << "mode=" << DiffModeToString(Mode); + if (EnableTBRAnalysis) + Out << ", tbr"; + Out << ']'; + Out.flush(); + } + bool DiffRequest::shouldBeRecorded(Expr* E) const { if (!EnableTBRAnalysis) return true; diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 9e1977e0b..d446d9d49 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -526,7 +526,7 @@ namespace clad { // Print the graph of the diff requests. llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n"; - m_DiffRequestGraph.print(); + m_DiffRequestGraph.print(llvm::errs()); m_Multiplexer->PrintStats(); }