Skip to content

Commit

Permalink
Improve DiffRequest and DynamicGraph printing.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 22, 2024
1 parent 36dbc6e commit bb87743
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 26 deletions.
19 changes: 8 additions & 11 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
#ifndef CLAD_DIFF_PLANNER_H
#define CLAD_DIFF_PLANNER_H

#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/RecursiveASTVisitor.h"

#include "llvm/ADT/SmallSet.h"
#include <llvm/Support/raw_ostream.h>

#include <iterator>
#include <set>

namespace clang {
class CallExpr;
class CompilerInstance;
Expand Down Expand Up @@ -132,15 +136,8 @@ struct DiffRequest {

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
std::to_string(CurrentDerivativeOrder) + "__mode_" +
DiffModeToString(Mode);
if (EnableTBRAnalysis)
res += "__TBR";
return res;
}
void print(llvm::raw_ostream& Out) const;
void dump() const { print(llvm::errs()); }

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
Expand Down
1 change: 0 additions & 1 deletion include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "BuiltinDerivativesCUDA.cuh"
#endif
#include "CladConfig.h"
#include "DynamicGraph.h"
#include "FunctionTraits.h"
#include "Matrix.h"
#include "NumericalDiff.h"
Expand Down
22 changes: 14 additions & 8 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H
#define CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H

#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <functional>
#include <iostream>
Expand Down Expand Up @@ -109,23 +111,27 @@ template <typename T> class DynamicGraph {
const std::vector<T>& getNodes() const { return m_nodes; }
std::vector<T>& getNodes() { return m_nodes; }

/// Dump the nodes and edges.
void dump() const { print(llvm::errs()); }

/// Print the nodes and edges in the graph.
void print() {
void print(llvm::raw_ostream& Out) const {
// First print the nodes with their insertion order.
for (const T& node : m_nodes) {
std::pair<bool, int> nodeInfo = m_nodeMap[node];
std::cout << (std::string)node << ": #" << nodeInfo.second;
std::pair<bool, int> nodeInfo = m_nodeMap.at(node);
node.print(Out);
Out << ": #" << nodeInfo.second;
if (m_sources.find(nodeInfo.second) != m_sources.end())
std::cout << " (source)";
Out << " (source)";
if (nodeInfo.first)
std::cout << ", (done)\n";
Out << ", (done)\n";
else
std::cout << ", (unprocessed)\n";
Out << ", (unprocessed)\n";
}
// Then print the edges.
for (int i = 0; i < m_nodes.size(); i++)
for (size_t dest : m_adjList[i])
std::cout << i << " -> " << dest << "\n";
for (size_t dest : m_adjList.at(i))
Out << i << " -> " << dest << "\n";
}

/// Get the next node to be processed from the queue of nodes to be
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "clad/Differentiator/DiffPlanner.h"

#include "clad/Differentiator/DiffMode.h"

#include "ActivityAnalyzer.h"
#include "TBRAnalyzer.h"

Expand Down Expand Up @@ -601,6 +603,19 @@ namespace clad {
return;
}

void DiffRequest::print(llvm::raw_ostream& Out) const {
Out << '<';
PrintingPolicy Policy(Function->getASTContext().getLangOpts());
Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true);
Out << ">[name=" << BaseFunctionName << ", "
<< "order=" << CurrentDerivativeOrder << ", "
<< "mode=" << DiffModeToString(Mode);
if (EnableTBRAnalysis)
Out << ", tbr";
Out << ']';
Out.flush();
}

bool DiffRequest::shouldBeRecorded(Expr* E) const {
if (!EnableTBRAnalysis)
return true;
Expand Down
2 changes: 1 addition & 1 deletion tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ namespace clad {

// Print the graph of the diff requests.
llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n";
m_DiffRequestGraph.print();
m_DiffRequestGraph.print(llvm::errs());

m_Multiplexer->PrintStats();
}
Expand Down
15 changes: 10 additions & 5 deletions unittests/Misc/DynamicGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/DynamicGraph.h"

#include "llvm/Support/raw_ostream.h"

#include <iostream>
#include <string>
Expand All @@ -16,7 +19,9 @@ struct Node {
}

// String operator for printing the node.
operator std::string() const { return name + std::to_string(id); }
void print(llvm::raw_ostream &Out) const {
Out << name << std::to_string(id);
}
};

// Specialize std::hash for the Node type.
Expand Down Expand Up @@ -44,10 +49,10 @@ TEST(DynamicGraphTest, Printing) {

// Check the printed output.
std::stringstream ss;
std::streambuf* coutbuf = std::cout.rdbuf();
std::cout.rdbuf(ss.rdbuf());
G.print();
std::cout.rdbuf(coutbuf);
std::streambuf* coutbuf = std::cerr.rdbuf();
std::cerr.rdbuf(ss.rdbuf());
G.dump();
std::cerr.rdbuf(coutbuf);
std::string expectedOutput = "node0: #0 (source), (unprocessed)\n"
"node1: #1, (unprocessed)\n"
"node2: #2, (unprocessed)\n"
Expand Down

0 comments on commit bb87743

Please sign in to comment.