Skip to content

Commit

Permalink
Improve DiffRequest and DynamicGraph printing.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 23, 2024
1 parent 36dbc6e commit c5f787b
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 19 deletions.
19 changes: 11 additions & 8 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#ifndef CLAD_DIFF_PLANNER_H
#define CLAD_DIFF_PLANNER_H

#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/RecursiveASTVisitor.h"

#include "llvm/Support/raw_ostream.h"

#include <iterator>
#include <set>

namespace clang {
class CallExpr;
class CompilerInstance;
Expand Down Expand Up @@ -132,15 +135,15 @@ struct DiffRequest {

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
std::to_string(CurrentDerivativeOrder) + "__mode_" +
DiffModeToString(Mode);
if (EnableTBRAnalysis)
res += "__TBR";
std::string res;
llvm::raw_string_ostream s(res);
print(s);
s.flush();

Check warning on line 142 in include/clad/Differentiator/DiffPlanner.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DiffPlanner.h#L139-L142

Added lines #L139 - L142 were not covered by tests
return res;
}
void print(llvm::raw_ostream& Out) const;
void dump() const { print(llvm::errs()); }

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#include "BuiltinDerivativesCUDA.cuh"
#endif
#include "CladConfig.h"
#include "DynamicGraph.h"
#include "FunctionTraits.h"
#include "Matrix.h"
#include "NumericalDiff.h"
#include "Tape.h"

#include <array>
#include <assert.h>
#include <stddef.h>
#include <cstring>
Expand Down
19 changes: 11 additions & 8 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,26 @@ template <typename T> class DynamicGraph {
const std::vector<T>& getNodes() const { return m_nodes; }
std::vector<T>& getNodes() { return m_nodes; }

/// Dump the nodes and edges.
void dump() const { print(std::cerr); }

/// Print the nodes and edges in the graph.
void print() {
void print(std::ostream& Out) const {
// First print the nodes with their insertion order.
for (const T& node : m_nodes) {
std::pair<bool, int> nodeInfo = m_nodeMap[node];
std::cout << (std::string)node << ": #" << nodeInfo.second;
std::pair<bool, int> nodeInfo = m_nodeMap.at(node);
Out << (std::string)node << ": #" << nodeInfo.second;

Check warning on line 120 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L119-L120

Added lines #L119 - L120 were not covered by tests
if (m_sources.find(nodeInfo.second) != m_sources.end())
std::cout << " (source)";
Out << " (source)";

Check warning on line 122 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L122

Added line #L122 was not covered by tests
if (nodeInfo.first)
std::cout << ", (done)\n";
Out << ", (done)\n";

Check warning on line 124 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L124

Added line #L124 was not covered by tests
else
std::cout << ", (unprocessed)\n";
Out << ", (unprocessed)\n";

Check warning on line 126 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L126

Added line #L126 was not covered by tests
}
// Then print the edges.
for (int i = 0; i < m_nodes.size(); i++)
for (size_t dest : m_adjList[i])
std::cout << i << " -> " << dest << "\n";
for (size_t dest : m_adjList.at(i))
Out << i << " -> " << dest << "\n";

Check warning on line 131 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L130-L131

Added lines #L130 - L131 were not covered by tests
}

/// Get the next node to be processed from the queue of nodes to be
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "clad/Differentiator/DiffPlanner.h"

#include "clad/Differentiator/DiffMode.h"

#include "ActivityAnalyzer.h"
#include "TBRAnalyzer.h"

Expand Down Expand Up @@ -601,6 +603,19 @@ namespace clad {
return;
}

void DiffRequest::print(llvm::raw_ostream& Out) const {
Out << '<';
PrintingPolicy Policy(Function->getASTContext().getLangOpts());
Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true);
Out << ">[name=" << BaseFunctionName << ", "
<< "order=" << CurrentDerivativeOrder << ", "
<< "mode=" << DiffModeToString(Mode);
if (EnableTBRAnalysis)
Out << ", tbr";
Out << ']';
Out.flush();
}

Check warning on line 617 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L606-L617

Added lines #L606 - L617 were not covered by tests

bool DiffRequest::shouldBeRecorded(Expr* E) const {
if (!EnableTBRAnalysis)
return true;
Expand Down
2 changes: 1 addition & 1 deletion tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ namespace clad {

// Print the graph of the diff requests.
llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n";
m_DiffRequestGraph.print();
m_DiffRequestGraph.print(std::cerr);

m_Multiplexer->PrintStats();
}
Expand Down
3 changes: 2 additions & 1 deletion unittests/Misc/DynamicGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
Expand Down Expand Up @@ -46,7 +47,7 @@ TEST(DynamicGraphTest, Printing) {
std::stringstream ss;
std::streambuf* coutbuf = std::cout.rdbuf();
std::cout.rdbuf(ss.rdbuf());
G.print();
G.print(std::cout);
std::cout.rdbuf(coutbuf);
std::string expectedOutput = "node0: #0 (source), (unprocessed)\n"
"node1: #1, (unprocessed)\n"
Expand Down

0 comments on commit c5f787b

Please sign in to comment.