From c5f787bb3016a9dfc9ef48d31d7f09b2db455372 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Sun, 22 Dec 2024 16:47:43 +0000 Subject: [PATCH] Improve DiffRequest and DynamicGraph printing. --- include/clad/Differentiator/DiffPlanner.h | 19 +++++++++++-------- include/clad/Differentiator/Differentiator.h | 2 +- include/clad/Differentiator/DynamicGraph.h | 19 +++++++++++-------- lib/Differentiator/DiffPlanner.cpp | 15 +++++++++++++++ tools/ClangPlugin.cpp | 2 +- unittests/Misc/DynamicGraph.cpp | 3 ++- 6 files changed, 41 insertions(+), 19 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 30b483b7e..54c4ee3eb 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -1,14 +1,17 @@ #ifndef CLAD_DIFF_PLANNER_H #define CLAD_DIFF_PLANNER_H -#include "clang/AST/RecursiveASTVisitor.h" -#include "llvm/ADT/SmallSet.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clang/AST/RecursiveASTVisitor.h" + +#include "llvm/Support/raw_ostream.h" + #include #include + namespace clang { class CallExpr; class CompilerInstance; @@ -132,15 +135,15 @@ struct DiffRequest { const clang::FunctionDecl* operator->() const { return Function; } - // String operator for printing the node. operator std::string() const { - std::string res = BaseFunctionName + "__order_" + - std::to_string(CurrentDerivativeOrder) + "__mode_" + - DiffModeToString(Mode); - if (EnableTBRAnalysis) - res += "__TBR"; + std::string res; + llvm::raw_string_ostream s(res); + print(s); + s.flush(); return res; } + void print(llvm::raw_ostream& Out) const; + void dump() const { print(llvm::errs()); } bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index c8aaaa286..5fc19781b 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -14,12 +14,12 @@ #include "BuiltinDerivativesCUDA.cuh" #endif #include "CladConfig.h" -#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" #include "Tape.h" +#include #include #include #include diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 2ef8cf992..343d832c2 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -109,23 +109,26 @@ template class DynamicGraph { const std::vector& getNodes() const { return m_nodes; } std::vector& getNodes() { return m_nodes; } + /// Dump the nodes and edges. + void dump() const { print(std::cerr); } + /// Print the nodes and edges in the graph. - void print() { + void print(std::ostream& Out) const { // First print the nodes with their insertion order. for (const T& node : m_nodes) { - std::pair nodeInfo = m_nodeMap[node]; - std::cout << (std::string)node << ": #" << nodeInfo.second; + std::pair nodeInfo = m_nodeMap.at(node); + Out << (std::string)node << ": #" << nodeInfo.second; if (m_sources.find(nodeInfo.second) != m_sources.end()) - std::cout << " (source)"; + Out << " (source)"; if (nodeInfo.first) - std::cout << ", (done)\n"; + Out << ", (done)\n"; else - std::cout << ", (unprocessed)\n"; + Out << ", (unprocessed)\n"; } // Then print the edges. for (int i = 0; i < m_nodes.size(); i++) - for (size_t dest : m_adjList[i]) - std::cout << i << " -> " << dest << "\n"; + for (size_t dest : m_adjList.at(i)) + Out << i << " -> " << dest << "\n"; } /// Get the next node to be processed from the queue of nodes to be diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index d2c39c1d9..cd90e0c01 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,5 +1,7 @@ #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DiffMode.h" + #include "ActivityAnalyzer.h" #include "TBRAnalyzer.h" @@ -601,6 +603,19 @@ namespace clad { return; } + void DiffRequest::print(llvm::raw_ostream& Out) const { + Out << '<'; + PrintingPolicy Policy(Function->getASTContext().getLangOpts()); + Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true); + Out << ">[name=" << BaseFunctionName << ", " + << "order=" << CurrentDerivativeOrder << ", " + << "mode=" << DiffModeToString(Mode); + if (EnableTBRAnalysis) + Out << ", tbr"; + Out << ']'; + Out.flush(); + } + bool DiffRequest::shouldBeRecorded(Expr* E) const { if (!EnableTBRAnalysis) return true; diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 9e1977e0b..ab995c2d5 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -526,7 +526,7 @@ namespace clad { // Print the graph of the diff requests. llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n"; - m_DiffRequestGraph.print(); + m_DiffRequestGraph.print(std::cerr); m_Multiplexer->PrintStats(); } diff --git a/unittests/Misc/DynamicGraph.cpp b/unittests/Misc/DynamicGraph.cpp index 6954a6698..04d8f90fa 100644 --- a/unittests/Misc/DynamicGraph.cpp +++ b/unittests/Misc/DynamicGraph.cpp @@ -1,3 +1,4 @@ +#include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/Differentiator.h" #include @@ -46,7 +47,7 @@ TEST(DynamicGraphTest, Printing) { std::stringstream ss; std::streambuf* coutbuf = std::cout.rdbuf(); std::cout.rdbuf(ss.rdbuf()); - G.print(); + G.print(std::cout); std::cout.rdbuf(coutbuf); std::string expectedOutput = "node0: #0 (source), (unprocessed)\n" "node1: #1, (unprocessed)\n"