forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#14 from Superjomn/init-graph-utils
init graph utils
- Loading branch information
Showing
11 changed files
with
304 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#pragma once | ||
|
||
#include "cinn/common/domain.h" | ||
#include "cinn/common/graph_utils.h" | ||
#include "cinn/common/pod_value.h" | ||
#include "cinn/common/shared.h" | ||
#include "cinn/common/type.h" | ||
|
||
namespace cinn { | ||
|
||
// export some general concepts. | ||
using common::Float; | ||
using common::Int; | ||
using common::Object; | ||
using common::Shared; | ||
using common::make_shared; | ||
|
||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,94 @@ | ||
#include "cinn/common/graph_utils.h" | ||
#include <glog/logging.h> | ||
#include <functional> | ||
#include <set> | ||
#include <stack> | ||
|
||
namespace cinn { | ||
namespace common {} // namespace common | ||
} // namespace cinn | ||
namespace common { | ||
|
||
namespace { | ||
|
||
void TopologicalSortUtil(GraphNode *node, | ||
std::set<GraphNode *> *visited, | ||
std::stack<GraphNode *> *stack, | ||
std::vector<GraphNode *> *order, | ||
std::vector<GraphEdge *> *edge_order) { | ||
node->VisitOnce(); | ||
if (!node->visited()) return; | ||
CHECK(!visited->count(node)) << "duplicate visit current node"; | ||
|
||
// Mark the current node as visited. | ||
visited->insert(node); | ||
order->push_back(node); | ||
|
||
for (auto &e : node->outlinks()) { | ||
if (!visited->count(e->sink())) { | ||
edge_order->push_back(e.get()); | ||
TopologicalSortUtil(e->sink(), visited, stack, order, edge_order); | ||
} | ||
} | ||
|
||
stack->push(node); | ||
} | ||
|
||
std::tuple<Graph::node_order_t, Graph::edge_order_t> TopologicalSort(const std::vector<GraphNode *> &nodes) { | ||
std::stack<GraphNode *> stack; | ||
std::set<GraphNode *> visited; // Tell whether a node is visited | ||
std::vector<GraphNode *> order; // nodes visited in order | ||
std::vector<GraphEdge *> edges; // edges visited in order | ||
|
||
for (auto *node : nodes) { | ||
if (!visited.count(node)) { | ||
TopologicalSortUtil(node, &visited, &stack, &order, &edges); | ||
} | ||
} | ||
return std::make_tuple(std::move(order), std::move(edges)); | ||
} | ||
|
||
void DFSSortUtil(const GraphNode *node, std::vector<GraphNode *> *order) {} | ||
|
||
std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {} | ||
|
||
} // namespace | ||
|
||
std::vector<const GraphNode *> Graph::nodes() const { | ||
std::vector<const GraphNode *> res; | ||
for (auto &s : nodes_) res.push_back(s.get()); | ||
return res; | ||
} | ||
std::vector<GraphNode *> Graph::nodes() { | ||
std::vector<GraphNode *> res; | ||
for (auto &s : nodes_) res.push_back(s.get()); | ||
return res; | ||
} | ||
|
||
std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() { | ||
return TopologicalSort(nodes()); | ||
} | ||
|
||
std::vector<GraphNode *> Graph::dfs_order() { return std::vector<GraphNode *>(); } | ||
|
||
std::vector<const GraphNode *> Graph::start_points() const { | ||
std::vector<const GraphNode *> res; | ||
for (auto *node : nodes()) { | ||
res.push_back(node); | ||
} | ||
return res; | ||
} | ||
|
||
void Graph::RegisterNode(size_t key, GraphNode *node) { | ||
registry_.emplace(key, node); | ||
nodes_.emplace_back(node); | ||
} | ||
void Graph::RegisterNode(const std::string &key, GraphNode *node) { RegisterNode(std::hash<std::string>{}(key), node); } | ||
|
||
GraphNode *Graph::RetriveNode(size_t key) const { | ||
auto it = registry_.find(key); | ||
return it == registry_.end() ? nullptr : it->second; | ||
} | ||
|
||
GraphNode *Graph::RetriveNode(const std::string &key) const { return RetriveNode(std::hash<std::string>()(key)); } | ||
|
||
} // namespace common | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#include <gtest/gtest.h> | ||
#include "cinn/common/common.h" | ||
|
||
namespace cinn { | ||
namespace common { | ||
|
||
struct GraphNodeWithName : public GraphNode { | ||
explicit GraphNodeWithName(std::string name) : name(name) {} | ||
|
||
std::string name; | ||
}; | ||
|
||
TEST(Graph, basic) { | ||
// Create nodes: A, B, C, D, E | ||
Graph graph; | ||
|
||
auto* A = make_shared<GraphNodeWithName>("A"); | ||
auto* B = make_shared<GraphNodeWithName>("B"); | ||
auto* C = make_shared<GraphNodeWithName>("C"); | ||
auto* D = make_shared<GraphNodeWithName>("D"); | ||
auto* E = make_shared<GraphNodeWithName>("E"); | ||
|
||
A->LinkTo(B); | ||
A->LinkTo(C); | ||
|
||
B->LinkTo(D); | ||
C->LinkTo(D); | ||
C->LinkTo(E); | ||
|
||
LOG(INFO) << "B: " << B->inlinks().size() << " -> " << B->outlinks().size(); | ||
|
||
graph.RegisterNode("A", A); | ||
graph.RegisterNode("B", B); | ||
graph.RegisterNode("C", C); | ||
graph.RegisterNode("D", D); | ||
graph.RegisterNode("E", E); | ||
|
||
Graph::node_order_t node_order; | ||
Graph::edge_order_t edge_order; | ||
std::tie(node_order, edge_order) = graph.topological_order(); | ||
|
||
for (auto* e : edge_order) { | ||
LOG(INFO) << "visit edge: " << e->source()->As<GraphNodeWithName>()->name << " -> " | ||
<< e->sink()->As<GraphNodeWithName>()->name; | ||
} | ||
|
||
for (auto* n : node_order) { | ||
LOG(INFO) << "visit node: " << n->As<GraphNodeWithName>()->name; | ||
} | ||
} | ||
|
||
} // namespace common | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.