Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve DiffRequest and DynamicGraph printing. #1183

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#ifndef CLAD_DIFF_PLANNER_H
#define CLAD_DIFF_PLANNER_H

#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

#include "clang/AST/RecursiveASTVisitor.h"

#include "llvm/Support/raw_ostream.h"

#include <iterator>
#include <set>

namespace clang {
class CallExpr;
class CompilerInstance;
Expand Down Expand Up @@ -132,15 +135,15 @@ struct DiffRequest {

const clang::FunctionDecl* operator->() const { return Function; }

// String operator for printing the node.
operator std::string() const {
std::string res = BaseFunctionName + "__order_" +
std::to_string(CurrentDerivativeOrder) + "__mode_" +
DiffModeToString(Mode);
if (EnableTBRAnalysis)
res += "__TBR";
std::string res;
llvm::raw_string_ostream s(res);
print(s);
s.flush();
return res;
}
void print(llvm::raw_ostream& Out) const;
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
void dump() const { print(llvm::errs()); }
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
Expand Down
4 changes: 1 addition & 3 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
#include "BuiltinDerivativesCUDA.cuh"
#endif
#include "CladConfig.h"
#include "DynamicGraph.h"
#include "FunctionTraits.h"
#include "Matrix.h"
#include "NumericalDiff.h"
#include "Tape.h"

#include <assert.h>
#include <stddef.h>
#include <array>
#include <cstring>
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

namespace clad {
Expand Down
20 changes: 12 additions & 8 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -109,23 +110,26 @@
const std::vector<T>& getNodes() const { return m_nodes; }
std::vector<T>& getNodes() { return m_nodes; }

/// Dump the nodes and edges.
void dump() const { print(std::cerr); }

/// Print the nodes and edges in the graph.
void print() {
void print(std::ostream& Out) const {
// First print the nodes with their insertion order.
for (const T& node : m_nodes) {
std::pair<bool, int> nodeInfo = m_nodeMap[node];
std::cout << (std::string)node << ": #" << nodeInfo.second;
std::pair<bool, int> nodeInfo = m_nodeMap.at(node);
Out << (std::string)node << ": #" << nodeInfo.second;
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
if (m_sources.find(nodeInfo.second) != m_sources.end())
std::cout << " (source)";
Out << " (source)";
if (nodeInfo.first)
std::cout << ", (done)\n";
Out << ", (done)\n";
else
std::cout << ", (unprocessed)\n";
Out << ", (unprocessed)\n";

Check warning on line 127 in include/clad/Differentiator/DynamicGraph.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/DynamicGraph.h#L127

Added line #L127 was not covered by tests
}
// Then print the edges.
for (int i = 0; i < m_nodes.size(); i++)
for (size_t dest : m_adjList[i])
std::cout << i << " -> " << dest << "\n";
for (size_t dest : m_adjList.at(i))
Out << i << " -> " << dest << "\n";
}

/// Get the next node to be processed from the queue of nodes to be
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "clad/Differentiator/DiffPlanner.h"

#include "clad/Differentiator/DiffMode.h"

#include "ActivityAnalyzer.h"
#include "TBRAnalyzer.h"

Expand Down Expand Up @@ -601,6 +603,19 @@ namespace clad {
return;
}

void DiffRequest::print(llvm::raw_ostream& Out) const {
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
Out << '<';
PrintingPolicy Policy(Function->getASTContext().getLangOpts());
Function->getNameForDiagnostic(Out, Policy, /*Qualified=*/true);
Out << ">[name=" << BaseFunctionName << ", "
<< "order=" << CurrentDerivativeOrder << ", "
<< "mode=" << DiffModeToString(Mode);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved
if (EnableTBRAnalysis)
Out << ", tbr";
Out << ']';
Out.flush();
}

bool DiffRequest::shouldBeRecorded(Expr* E) const {
if (!EnableTBRAnalysis)
return true;
Expand Down
11 changes: 11 additions & 0 deletions test/Misc/TimingsReport.C
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
// RUN: %cladclang %s -I%S/../../include -oTimingsReport.out -ftime-report 2>&1 | %filecheck %s
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS %s
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -print-stats 2>&1 | %filecheck -check-prefix=CHECK_STATS_TBR %s

#include "clad/Differentiator/Differentiator.h"
// CHECK: Timers for Clad Funcs
// CHECK_STATS: *** INFORMATION ABOUT THE DIFF REQUESTS
// CHECK_STATS-NEXT: <test1>[name=test1, order=1, mode=forward]: #0 (source), (done)
// CHECK_STATS-NEXT: <test2>[name=test2, order=1, mode=reverse]: #1 (source), (done)
// CHECK_STATS-NEXT: <nested1>[name=nested1, order=1, mode=pushforward]: #2, (done)
// CHECK_STATS-NEXT: <nested2>[name=nested2, order=1, mode=pullback]: #3, (done)
// CHECK_STATS-NEXT: 0 -> 2
// CHECK_STATS-NEXT: 1 -> 3

// CHECK_STATS_TBR: <test1>[name=test1, order=1, mode=forward, tbr]: #0 (source), (done)

double nested1(double c){
return c*3*c;
Expand Down
3 changes: 2 additions & 1 deletion tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "clad/Differentiator/Compatibility.h"

#include <algorithm>
#include <iostream> // for std::cerr

using namespace clang;

Expand Down Expand Up @@ -526,7 +527,7 @@ namespace clad {

// Print the graph of the diff requests.
llvm::errs() << "\n*** INFORMATION ABOUT THE DIFF REQUESTS\n";
m_DiffRequestGraph.print();
m_DiffRequestGraph.print(std::cerr);
vgvassilev marked this conversation as resolved.
Show resolved Hide resolved

m_Multiplexer->PrintStats();
}
Expand Down
3 changes: 2 additions & 1 deletion unittests/Misc/DynamicGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
Expand Down Expand Up @@ -46,7 +47,7 @@ TEST(DynamicGraphTest, Printing) {
std::stringstream ss;
std::streambuf* coutbuf = std::cout.rdbuf();
std::cout.rdbuf(ss.rdbuf());
G.print();
G.print(std::cout);
std::cout.rdbuf(coutbuf);
std::string expectedOutput = "node0: #0 (source), (unprocessed)\n"
"node1: #1, (unprocessed)\n"
Expand Down
Loading