Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#14 from Superjomn/init-graph-utils
Browse files Browse the repository at this point in the history
init graph utils
  • Loading branch information
Superjomn authored Feb 5, 2020
2 parents 06105cc + 62e46de commit 2818c10
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 25 deletions.
1 change: 1 addition & 0 deletions cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ cc_library(common

cc_test(test_pod_value SRCS pod_value_test.cc DEPS common)
cc_test(test_shared SRCS shared_test.cc DEPS common)
cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS common)
18 changes: 18 additions & 0 deletions cinn/common/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include "cinn/common/domain.h"
#include "cinn/common/graph_utils.h"
#include "cinn/common/pod_value.h"
#include "cinn/common/shared.h"
#include "cinn/common/type.h"

namespace cinn {

// export some general concepts.
using common::Float;
using common::Int;
using common::Object;
using common::Shared;
using common::make_shared;

} // namespace cinn
93 changes: 91 additions & 2 deletions cinn/common/graph_utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,94 @@
#include "cinn/common/graph_utils.h"
#include <glog/logging.h>
#include <functional>
#include <set>
#include <stack>

namespace cinn {
namespace common {} // namespace common
} // namespace cinn
namespace common {

namespace {

void TopologicalSortUtil(GraphNode *node,
std::set<GraphNode *> *visited,
std::stack<GraphNode *> *stack,
std::vector<GraphNode *> *order,
std::vector<GraphEdge *> *edge_order) {
node->VisitOnce();
if (!node->visited()) return;
CHECK(!visited->count(node)) << "duplicate visit current node";

// Mark the current node as visited.
visited->insert(node);
order->push_back(node);

for (auto &e : node->outlinks()) {
if (!visited->count(e->sink())) {
edge_order->push_back(e.get());
TopologicalSortUtil(e->sink(), visited, stack, order, edge_order);
}
}

stack->push(node);
}

std::tuple<Graph::node_order_t, Graph::edge_order_t> TopologicalSort(const std::vector<GraphNode *> &nodes) {
std::stack<GraphNode *> stack;
std::set<GraphNode *> visited; // Tell whether a node is visited
std::vector<GraphNode *> order; // nodes visited in order
std::vector<GraphEdge *> edges; // edges visited in order

for (auto *node : nodes) {
if (!visited.count(node)) {
TopologicalSortUtil(node, &visited, &stack, &order, &edges);
}
}
return std::make_tuple(std::move(order), std::move(edges));
}

void DFSSortUtil(const GraphNode *node, std::vector<GraphNode *> *order) {}

std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {}

} // namespace

std::vector<const GraphNode *> Graph::nodes() const {
std::vector<const GraphNode *> res;
for (auto &s : nodes_) res.push_back(s.get());
return res;
}
std::vector<GraphNode *> Graph::nodes() {
std::vector<GraphNode *> res;
for (auto &s : nodes_) res.push_back(s.get());
return res;
}

std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() {
return TopologicalSort(nodes());
}

std::vector<GraphNode *> Graph::dfs_order() { return std::vector<GraphNode *>(); }

std::vector<const GraphNode *> Graph::start_points() const {
std::vector<const GraphNode *> res;
for (auto *node : nodes()) {
res.push_back(node);
}
return res;
}

void Graph::RegisterNode(size_t key, GraphNode *node) {
registry_.emplace(key, node);
nodes_.emplace_back(node);
}
void Graph::RegisterNode(const std::string &key, GraphNode *node) { RegisterNode(std::hash<std::string>{}(key), node); }

GraphNode *Graph::RetriveNode(size_t key) const {
auto it = registry_.find(key);
return it == registry_.end() ? nullptr : it->second;
}

GraphNode *Graph::RetriveNode(const std::string &key) const { return RetriveNode(std::hash<std::string>()(key)); }

} // namespace common
} // namespace cinn
93 changes: 83 additions & 10 deletions cinn/common/graph_utils.h
Original file line number Diff line number Diff line change
@@ -1,50 +1,123 @@
#pragma once
//! \file This file contains the utilities of graph.

#include <glog/logging.h>
#include <list>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/common/object.h"
#include "cinn/common/shared.h"

namespace cinn {
namespace common {

class GraphNode;

/**
* Edge in the graph, which can hold some attributes.
*/
class GraphEdge : public Object {
public:
GraphEdge(GraphNode* source, GraphNode* sink) : source_(source), sink_(sink) {}

GraphNode* source() const { return source_; }
GraphNode* sink() const { return sink_; }
const char* type_info() const override { return "graph_edge"; }

private:
//! Source of this edge.
GraphNode* source_{};
//! End of this edge.
GraphNode* sink_{};
};

/**
* @brief The base class of all node of graph.
* This is used to normalize and share the graph operations.
*/
class GraphNode {
class GraphNode : public Object {
public:
GraphNode() = default;
//! Links from this to other.
template <typename EdgeT = GraphEdge>
std::tuple<EdgeT*, EdgeT*> LinkTo(GraphNode* other) {
CHECK_NE(other, this) << "cannot link to itself";
other->inlinks_.push_back(make_shared<GraphEdge>(other, this));
outlinks_.push_back(make_shared<GraphEdge>(this, other));
return std::make_tuple(static_cast<EdgeT*>(outlinks_.back().get()),
static_cast<EdgeT*>(other->inlinks().back().get()));
}

//! Get the input links of the node.
virtual std::list<GraphNode*> inlinks() const { return inlinks_; }
virtual std::list<Shared<GraphEdge>> inlinks() const { return inlinks_; }
//! Get the output links of the node.
virtual std::list<GraphNode*> outlinks() const { return outlinks_; }
virtual std::list<Shared<GraphEdge>> outlinks() const { return outlinks_; }
//! Get a derived pointer.
template <typename Derived>
Derived* As() {
static_assert(std::is_base_of<GraphNode, Derived>::value);
return static_cast<Derived*>(this);
}

//! Reset graph traversal meta info.
void ResetVisitMeta() { visited_time_ = 0; }
void VisitOnce() const { visited_time_++; }
bool visited() const { return inlinks_.empty() || visited_time_ == inlinks_.size(); }

const char* type_info() const override { return "graph_node"; }

GraphNode() = default;

protected:
//! The input links of the node.
//! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
std::list<GraphNode*> inlinks_;
std::list<common::Shared<GraphEdge>> inlinks_;
//! The output links of the node.
//! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
std::list<GraphNode*> outlinks_;
std::list<common::Shared<GraphEdge>> outlinks_;

mutable int visited_time_{};
};

/**
* @brief The base class of all the graph.
*/
class Graph {
public:
size_t num_nodes() const { return nodes_.size(); }
using node_order_t = std::vector<GraphNode*>;
using edge_order_t = std::vector<GraphEdge*>;

//! Add a node to the graph.
//! @{
void RegisterNode(size_t key, GraphNode* node);
void RegisterNode(const std::string& key, GraphNode* node);
//! @}

//! Return the graph's topological order.
std::vector<GraphNode*> topological_order() const;
//! Retrive a node.
//! @{
GraphNode* RetriveNode(size_t key) const;
GraphNode* RetriveNode(const std::string& key) const;
//! @}

//! Get the start point of the graph (the nodes those has no inlinks).
std::vector<const GraphNode*> start_points() const;

//! Return the graph's nodes and edges(visited) in topological order.
std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> topological_order();

//! Return the graph's DFS order.
std::vector<GraphNode*> dfs_order() const;
std::vector<GraphNode*> dfs_order();

std::vector<const GraphNode*> nodes() const;
std::vector<GraphNode*> nodes();

size_t num_nodes() const { return nodes_.size(); }

protected:
//! A lookup table that map from hash key to graph node, note that it doesn't own the graph node.
std::unordered_map<size_t, GraphNode*> registry_;
//! A list owns the graph nodes.
std::vector<Shared<GraphNode>> nodes_;
};

Expand Down
53 changes: 53 additions & 0 deletions cinn/common/graph_utils_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <gtest/gtest.h>
#include "cinn/common/common.h"

namespace cinn {
namespace common {

struct GraphNodeWithName : public GraphNode {
explicit GraphNodeWithName(std::string name) : name(name) {}

std::string name;
};

TEST(Graph, basic) {
// Create nodes: A, B, C, D, E
Graph graph;

auto* A = make_shared<GraphNodeWithName>("A");
auto* B = make_shared<GraphNodeWithName>("B");
auto* C = make_shared<GraphNodeWithName>("C");
auto* D = make_shared<GraphNodeWithName>("D");
auto* E = make_shared<GraphNodeWithName>("E");

A->LinkTo(B);
A->LinkTo(C);

B->LinkTo(D);
C->LinkTo(D);
C->LinkTo(E);

LOG(INFO) << "B: " << B->inlinks().size() << " -> " << B->outlinks().size();

graph.RegisterNode("A", A);
graph.RegisterNode("B", B);
graph.RegisterNode("C", C);
graph.RegisterNode("D", D);
graph.RegisterNode("E", E);

Graph::node_order_t node_order;
Graph::edge_order_t edge_order;
std::tie(node_order, edge_order) = graph.topological_order();

for (auto* e : edge_order) {
LOG(INFO) << "visit edge: " << e->source()->As<GraphNodeWithName>()->name << " -> "
<< e->sink()->As<GraphNodeWithName>()->name;
}

for (auto* n : node_order) {
LOG(INFO) << "visit node: " << n->As<GraphNodeWithName>()->name;
}
}

} // namespace common
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/poly/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ cc_library(poly SRCS
DEPS common)

cc_test(test_poly_element SRCS element_test.cc DEPS poly)
cc_test(test_schedule SRCS schedule_test.cc DEPS poly)
3 changes: 2 additions & 1 deletion cinn/poly/isl_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cinn/poly/isl_utils.h"
#include <isl/cpp.h>

namespace cinn {
namespace poly {
Expand All @@ -11,7 +12,7 @@ std::vector<std::string> GetDimNames(const isl::set &x) {
return res;
}

std::vector<std::string> poly::GetDimNames(const isl::map &x, isl_dim_type dim_type) {
std::vector<std::string> GetDimNames(const isl::map &x, isl_dim_type dim_type) {
std::vector<std::string> res;
for (int i = 0; i < isl_map_dim(x.get(), dim_type); i++) {
res.push_back(isl_map_get_dim_name(x.get(), dim_type, i));
Expand Down
41 changes: 33 additions & 8 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
#include "cinn/poly/schedule.h"
#include "cinn/utils/string.h"

#include <sstream>
#include "cinn/common/graph_utils.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace poly {

/**
* Node in the schedule graph.
*/
struct ScheduleGraphNode : public common::GraphNode {
TimeSchedule time_schedule;

explicit ScheduleGraphNode(const std::vector<std::string> &dims) : time_schedule(dims) {}
};

struct ScheduleGraphEdge : public common::GraphEdge {
ScheduleGraphEdge(common::GraphNode *a, common::GraphNode *b) : common::GraphEdge(a, b) {}

//! Dependency level.
int level;
};

std::string TimeSchedule::__str__() const {
CHECK(!time_dims.empty());

Expand Down Expand Up @@ -36,22 +52,31 @@ void Scheduler::RegisterElement(const Element &x) {
// Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is
// like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range.
TimeSchedule schedule(GetDimNames(x.schedule(), isl_dim_out));
schedule_.emplace(x.id(), std::move(schedule));
schedule_graph_.RegisterNode(x.id(), common::make_shared<ScheduleGraphNode>(GetDimNames(x.schedule(), isl_dim_out)));
}

void Scheduler::FinalizeRegistration() {
CHECK_GT(space_size_, 0) << "No valid dimension is collected, use RegisterElement to collect some elements";
CHECK(!schedule_.empty()) << "No valid dimension is collected, use RegisterElement to collect some elements";
registration_finalized_ = false;
CHECK(!schedule_graph_.nodes().empty())
<< "No node is registered to the graph, use RegisterElement to collect some elements";
registration_finalized_ = true;

for (auto &item : schedule_) {
item.second.ResizeTimeSpace(space_size_);
for (auto &item : schedule_graph_.nodes()) {
item->As<ScheduleGraphNode>()->time_schedule.ResizeTimeSpace(space_size_);
}
}

Scheduler &Scheduler::After(const Element &a, const Element &b, int level) {
CHECK_LT(level, space_size_);
depend_flow_graph_[b.id()].depend_level[a.id()] = level;
auto *a_node = schedule_graph_.RetriveNode(a.id())->As<ScheduleGraphNode>();
auto *b_node = schedule_graph_.RetriveNode(a.id())->As<ScheduleGraphNode>();
CHECK(a_node) << "no node called " << a.id() << " registered in the graph";
CHECK(b_node) << "no node called " << b.id() << " registered in the graph";

common::GraphEdge *a_edge, *b_edge;
std::tie(a_edge, b_edge) = a_node->LinkTo<ScheduleGraphEdge>(b_node);
a_edge->As<ScheduleGraphEdge>()->level = level;
b_edge->As<ScheduleGraphEdge>()->level = level;
return *this;
}

Expand Down
Loading

0 comments on commit 2818c10

Please sign in to comment.