Skip to content

Commit

Permalink
Index to track topological order within a block (pytorch#12748)
Browse files Browse the repository at this point in the history
Summary:
Simple index to track topological order. Replaced `topological_index` in the graph fuser with this.
Pull Request resolved: pytorch#12748

Differential Revision: D10502983

Pulled By: michaelsuo

fbshipit-source-id: 5855e5add3c9742fe07e86d854260baa34beab3b
  • Loading branch information
suo authored and facebook-github-bot committed Oct 23, 2018
1 parent dd823cc commit 27af265
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 45 deletions.
1 change: 1 addition & 0 deletions test/cpp/jit/gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ JIT_TEST(DifferentiateWithRequiresGrad)
JIT_TEST(FromQualString)
JIT_TEST(InternedStrings)
JIT_TEST(IValue)
JIT_TEST(TopologicalIndex)

#define JIT_TEST_CUDA(name) \
TEST(JitTest, name##_CUDA) { \
Expand Down
63 changes: 63 additions & 0 deletions test/cpp/jit/tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
} catch (const std::exception& e) { \
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
}
// Asserts that evaluating `statement` throws an exception of ANY type.
// Wrapped in do { } while (false) so the macro expands to a single statement
// (safe inside an unbraced if/else) and so the local flag neither leaks into
// nor collides with the enclosing scope — the previous form broke when the
// macro was used twice in one scope. catch (...) is used because a thrown
// value that is not derived from std::exception must still count as a throw
// (previously it escaped the macro and aborted the test).
#define ASSERT_ANY_THROW(statement)  \
  do {                               \
    bool jit_did_throw = false;      \
    try {                            \
      (void)statement;               \
    } catch (...) {                  \
      jit_did_throw = true;          \
    }                                \
    ASSERT_TRUE(jit_did_throw);      \
  } while (false)

#endif // defined(USE_GTEST)

#include "torch/csrc/autograd/variable.h"
Expand Down Expand Up @@ -1157,6 +1166,60 @@ void testSchemaParser() {

}

void testTopologicalIndex() {
{
Graph graph;
auto node1 = graph.create(prim::Undefined);
auto node2 = graph.create(prim::Undefined);
auto node3 = graph.create(prim::Undefined);
auto node4 = graph.create(prim::Undefined);

graph.appendNode(node4);
graph.prependNode(node1);
node2->insertAfter(node1);
node3->insertBefore(node4);

// nodes should be in numerical order
ASSERT_TRUE(node1->isBefore(node2));
ASSERT_TRUE(node1->isBefore(node3));
ASSERT_TRUE(node1->isBefore(node4));
ASSERT_TRUE(node2->isAfter(node1));
ASSERT_TRUE(node2->isBefore(node3));
ASSERT_TRUE(node2->isBefore(node4));
ASSERT_FALSE(node3->isBefore(node1));
ASSERT_FALSE(node3->isBefore(node2));
ASSERT_FALSE(node3->isAfter(node4));

// make sure things don't blow up on deletions
node2->destroy();
auto node2p = graph.create(prim::Undefined);
node2p->insertAfter(node1);
ASSERT_TRUE(node1->isBefore(node2p));
ASSERT_TRUE(node2p->isBefore(node3));
}
{
// Induce reindexing to test that path
Graph graph;
std::map<size_t, Node*> nodes;

auto anchor = graph.create(prim::Undefined);
graph.appendNode(anchor);
// Inserting to the same place a lot will trigger reindexing
for (auto i = 0; i < 100; ++i) {
auto n = graph.create(prim::Undefined);
n->insertAfter(anchor);
nodes[i] = n;
}

// Nodes should be in reverse order
for (auto i = 0; i < 100; ++i) {
for (auto j = i + 1; j < 100; ++j) {
ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
}
}
}
}

} // namespace
} // namespace jit
} // namespace torch
2 changes: 1 addition & 1 deletion torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ install(TARGETS torch
LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}"
ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}")

if (BUILD_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM)
if (BUILD_TEST AND NOT MSVC AND NOT USE_ROCM)
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
endif()

Expand Down
102 changes: 102 additions & 0 deletions torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
#include <string>

namespace torch { namespace jit {
// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
static constexpr topo_position_t kLowerBound = INT64_MIN;
static constexpr topo_position_t kUpperBound = INT64_MAX;
// Position assigned to the first node appended to an empty block.
static constexpr topo_position_t kMidPoint = 0;
// How far apart to space nodes that are appended to the graph.
// Should be 2^n, where:
//   - n is the maximum number of repeated insertions without a re-index
//   - 2^(64-n) is the maximum number of appends to the end without a re-index
// Expressed as a signed shift (2^40) so the value's intent is visible and no
// implicit unsigned -> signed conversion is involved (topo_position_t is
// int64_t).
static constexpr topo_position_t kAppendInterval = topo_position_t(1) << 40;

// Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
constexpr Symbol PythonOp::Kind;
Expand Down Expand Up @@ -460,6 +471,27 @@ void LintGraph(std::shared_ptr<Graph>& graph) {
graph->lint();
}

// Construct a Block owned by `node_` (may be null for a graph's top-level
// block). The sentinel output (prim::Return) and input (prim::Param) nodes
// are created here and pinned to the extreme ends of the topological index,
// so every real node inserted later receives a position strictly between
// them.
Block::Block(Graph* graph_, Node* node_)
    : graph_(graph_),
      output_(initOutput(graph_->create(prim::Return, 0))),
      input_(graph_->create(prim::Param, 0)),
      owning_node_(node_) {
  graph_->all_blocks.emplace(this);
  output_->owning_block_ = this;
  // Output sentinel sits past every real node...
  output_->topo_position_ = kUpperBound;
  input_->owning_block_ = this;
  // ...and the input sentinel before every real node.
  input_->topo_position_ = kLowerBound;
}

// Respace all nodes in this block evenly across the index range, restoring
// kAppendInterval of headroom between adjacent nodes. Called by
// assignTopoPosition() when an insert or append finds no room at its target
// position.
void Block::reIndexTopology() {
  auto curPos = kLowerBound;
  for (auto node : nodes()) {
    // Guard against overflowing past the upper sentinel; with a 2^40
    // interval this permits roughly 2^23 nodes per block.
    AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
    curPos += kAppendInterval;
    node->topo_position_ = curPos;
  }
}

void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) {
std::unordered_map<Value*, Value*> local_map;
auto env = [&](Value * v) {
Expand Down Expand Up @@ -650,6 +682,62 @@ bool Node::isNondeterministic() const {
return true;
}

// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list.
//
// The basic scheme is: assign every node a position (topo_position_t, an
// int64_t). The common case (appending to the end of the graph) is made more
// efficient by advancing a fixed interval past the previous node and placing
// `this` there. Otherwise, assign `this` a position at the midpoint between
// its prev() and next() nodes.
//
// If we ever run out of space (by, e.g. inserting too much in place), we
// reindex by spreading out all the nodes again. Since `this` is already
// linked into the list at that point, reIndexTopology() assigns it a
// position too, so the early returns below are safe.
void Node::assignTopoPosition() {
  auto returnNode = owningBlock()->return_node();
  // Sentinel positions (kLowerBound/kUpperBound on the input/output nodes)
  // make these reads well-defined even at the ends of the list.
  const auto prevPos = prev()->topo_position_;
  const auto nextPos = next()->topo_position_;

  // Append to the end of the graph
  if (next() == returnNode) {
    if (next() == prev()) {
      // the node list is empty, assign the first position
      topo_position_ = kMidPoint;
      return;
    }

    if (prevPos >= (kUpperBound - kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = prevPos + kAppendInterval;

    // Prepend to the graph
  } else if (prev() == returnNode) {
    // next() is the first element in the block list
    if (nextPos <= (kLowerBound + kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = nextPos - kAppendInterval;

    // insert between two existing nodes
  } else {
    // Midpoint written this way cannot overflow, since prevPos < nextPos.
    const auto posBetween = prevPos + (nextPos - prevPos) / 2;
    if (posBetween == prevPos) {
      // There was no room
      owningBlock()->reIndexTopology();
      return;
    }
    topo_position_ = posBetween;
  }
}

Node::Node(Graph * graph_, NodeKind kind_) :
kind_(kind_),
graph_(graph_),
Expand Down Expand Up @@ -774,6 +862,19 @@ Value* Node::insertOutput(size_t i) {
return outputs_.at(i);
}

// Is `this` strictly before `n` in the topological order? A node is never
// before itself; for distinct nodes in the same block the positions form a
// total order, so "not after" is exactly "before". Delegating to isAfter()
// keeps the same-block assertion in one place.
bool Node::isBefore(Node * n) const {
  return this != n && !isAfter(n);
}

// Is `this` strictly after `n` in the topological order? Positions are only
// comparable for nodes that live in the same block.
bool Node::isAfter(Node * n) const {
  JIT_ASSERT(this->owningBlock() == n->owningBlock());
  return n->topo_position_ < this->topo_position_;
}

Node* Node::insertBefore(Node * n) {
JIT_ASSERT(n->inBlockList());
insertAfter(n->prev());
Expand All @@ -789,6 +890,7 @@ Node* Node::insertAfter(Node * n) {
this->prev() = n;
this->next() = next;
next->prev() = this;
assignTopoPosition();
return this;
}

Expand Down
25 changes: 14 additions & 11 deletions torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ using pyobj_list = std::vector<THPObjectPtr>;
template<typename T>
using ArrayRef = at::ArrayRef<T>;
using NodeKind = Symbol;
using topo_position_t = int64_t;

struct Value {
TH_DISALLOW_COPY_AND_ASSIGN(Value);
Expand Down Expand Up @@ -278,6 +279,7 @@ struct Node : public Attributes<Node> {
// the schema.
// note: mutable because schema_ is effectively a cache
mutable const FunctionSchema* schema_;
topo_position_t topo_position_;
protected:
TORCH_API Node(Graph * graph_, NodeKind kind_); //defined after graph
public:
Expand Down Expand Up @@ -469,6 +471,12 @@ struct Node : public Attributes<Node> {
return {blocks_.data(), blocks_.size()};
}

// Is 'this' before 'n' in the topological order?
TORCH_API bool isBefore(Node * n) const;

// Is 'this' after 'n' in the topological order?
TORCH_API bool isAfter(Node * n) const;

// Insert unattached 'this' node before 'n' in the topological order.
// Returns this (for chaining).
//
Expand Down Expand Up @@ -607,6 +615,9 @@ struct Node : public Attributes<Node> {

TORCH_API void removeFromList();
TORCH_API void lint() const;

void assignTopoPosition();

protected:
// subclasses must override
// this function is used by createClone to initialize a new version
Expand All @@ -628,7 +639,7 @@ struct Block {
friend struct Node;
friend struct Graph;
TH_DISALLOW_COPY_AND_ASSIGN(Block);
Block(Graph * graph_, Node * node_);
TORCH_API Block(Graph * graph_, Node * node_);
at::ArrayRef<Value*> inputs() {
return input_->outputs();
}
Expand Down Expand Up @@ -707,6 +718,8 @@ struct Block {
// in src to look up its corresponding value
TORCH_API void cloneFrom(Block * src, std::function<Value*(Value*)> value_map);
private:
void reIndexTopology();

// should only be called in the constructor
Node* initOutput(Node* p) {
p->next() = p;
Expand Down Expand Up @@ -977,16 +990,6 @@ inline const Graph * Value::owningGraph() const {
return node()->owningGraph();
}

inline Block::Block(Graph* graph_, Node* node_)
: graph_(graph_),
output_(initOutput(graph_->create(prim::Return, 0))),
input_(graph_->create(prim::Param, 0)),
owning_node_(node_) {
graph_->all_blocks.emplace(this);
output_->owning_block_ = this;
input_->owning_block_ = this;
}

// Helper macros for constructing switch statements over Node types
// instead of heavy-weight visitors
// read 'between' these defines to see how they turn into a big switch
Expand Down
35 changes: 2 additions & 33 deletions torch/csrc/jit/passes/graph_fuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,6 @@ struct Device {
struct GraphFuser {
Block * block;

// Used to order nodes so we always consider producer-consumer fusions
// in reverse topological order.
// If topological_index[a] > topological_index[b] then a occurs after b.
// Because nodes can be added to this graph during optimization, this mapping is not bijective.
// Newly generated nodes will copy the location where they are inserted.
std::unordered_map<Node*,size_t> topological_index;

GraphFuser(Block * block)
: block(block) {}

Expand Down Expand Up @@ -279,7 +272,7 @@ struct GraphFuser {
auto defining_node = producer->node();
for(auto o : defining_node->outputs()) {
for(auto u : o->uses()) {
if(u.user != consumer && topological_index.at(consumer) > topological_index.at(u.user))
if(u.user != consumer && consumer->isAfter(u.user))
return false;
}
}
Expand Down Expand Up @@ -480,7 +473,6 @@ struct GraphFuser {
auto group = block->owningGraph()->createFusionGroup(getDevice(n).index());
// propogate position information for the new node so we can always
// have a valid mapping
topological_index[group] = topological_index[n];
group->insertBefore(n);
Node * mergedNode = mergeNodeIntoGroup(group,n);
getSubgraph(group).registerOutput(mergedNode->output());
Expand All @@ -492,7 +484,6 @@ struct GraphFuser {
}
void insertAfter(Node * n, Node * after) {
n->insertAfter(after);
topological_index[n] = topological_index[after];
}

void insertAt(Node ** insertion_point, Node * n) {
Expand All @@ -511,7 +502,6 @@ struct GraphFuser {
fused_cat->insertBefore(list_construct);
fused_cat->output()->copyMetadata(consumer->output());
consumer->output()->replaceAllUsesWith(fused_cat->output());
topological_index[fused_cat] = topological_index[list_construct];

// NB: this deletes the fused_cat node from the original graph
group = createSingletonFusionGroup(fused_cat);
Expand Down Expand Up @@ -635,12 +625,11 @@ struct GraphFuser {
for (auto i : inputs) {
if (i->node()->owningBlock() == block) {
result.push_back(i);
JIT_ASSERT(topological_index.count(i->node()) > 0);
}
}
// Sort in reverse topological order
std::sort(result.begin(), result.end(), [&](Value * a, Value * b) {
return topological_index.at(a->node()) > topological_index.at(b->node());
return a->node()->isAfter(b->node());
});
return result;
}
Expand All @@ -663,17 +652,6 @@ struct GraphFuser {
auto tensors = tensorInputs(node);
auto new_tensors = SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));

// Fix up topological_index
Node * unpack_node = new_tensors.at(0).value()->node();
JIT_ASSERT(unpack_node->kind() == prim::ListUnpack);
Node * broadcast_node = unpack_node->input()->node();
JIT_ASSERT(broadcast_node->kind() == aten::broadcast_tensors);
Node * construct_node = broadcast_node->namedInput(attr::tensors)->node();
JIT_ASSERT(construct_node->kind() == prim::ListConstruct);
topological_index[unpack_node] = topological_index[node];
topological_index[broadcast_node] = topological_index[node];
topological_index[construct_node] = topological_index[node];

// Replace tensors inputs with broadcasted values
auto new_tensors_it = new_tensors.begin();
for (size_t i = 0; i < node->inputs().size(); ++i) {
Expand Down Expand Up @@ -821,15 +799,6 @@ struct GraphFuser {
}

void run() {
for(auto p : block->inputs()) {
topological_index[p->node()] = 0;
}
size_t i = 1;
for(auto consumer : block->nodes()) {
topological_index[consumer] = i++;
}
topological_index[block->return_node()] = i++;

// Run the pass until no changes are made.
// This is neccessary, because the algorithm can miss out on certain fusion
// opportunities if ran only once. Consider this graph:
Expand Down

0 comments on commit 27af265

Please sign in to comment.