diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp index ec7f133e736c9..e4589a17b022a 100644 --- a/test/cpp/jit/gtest.cpp +++ b/test/cpp/jit/gtest.cpp @@ -23,6 +23,7 @@ JIT_TEST(DifferentiateWithRequiresGrad) JIT_TEST(FromQualString) JIT_TEST(InternedStrings) JIT_TEST(IValue) +JIT_TEST(TopologicalIndex) #define JIT_TEST_CUDA(name) \ TEST(JitTest, name##_CUDA) { \ diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 780e249fe727c..16be17e7fbaed 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -16,6 +16,17 @@ } catch (const std::exception& e) { \ ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ } +#define ASSERT_ANY_THROW(statement) \ + do { \ + bool threw = false; \ + try { \ + (void)statement; \ + } catch (const std::exception& e) { \ + threw = true; \ + } \ + ASSERT_TRUE(threw); \ + } while (false) + #endif // defined(USE_GTEST) #include "torch/csrc/autograd/variable.h" @@ -1157,6 +1168,60 @@ void testSchemaParser() { } +void testTopologicalIndex() { + { + Graph graph; + auto node1 = graph.create(prim::Undefined); + auto node2 = graph.create(prim::Undefined); + auto node3 = graph.create(prim::Undefined); + auto node4 = graph.create(prim::Undefined); + + graph.appendNode(node4); + graph.prependNode(node1); + node2->insertAfter(node1); + node3->insertBefore(node4); + + // nodes should be in numerical order + ASSERT_TRUE(node1->isBefore(node2)); + ASSERT_TRUE(node1->isBefore(node3)); + ASSERT_TRUE(node1->isBefore(node4)); + ASSERT_TRUE(node2->isAfter(node1)); + ASSERT_TRUE(node2->isBefore(node3)); + ASSERT_TRUE(node2->isBefore(node4)); + ASSERT_FALSE(node3->isBefore(node1)); + ASSERT_FALSE(node3->isBefore(node2)); + ASSERT_FALSE(node3->isAfter(node4)); + + // make sure things don't blow up on deletions + node2->destroy(); + auto node2p = graph.create(prim::Undefined); + node2p->insertAfter(node1); + ASSERT_TRUE(node1->isBefore(node2p)); + ASSERT_TRUE(node2p->isBefore(node3)); + } + { + // Induce reindexing to test that path + Graph 
graph; + std::map<int, Node*> nodes; + + auto anchor = graph.create(prim::Undefined); + graph.appendNode(anchor); + // Inserting to the same place a lot will trigger reindexing + for (auto i = 0; i < 100; ++i) { + auto n = graph.create(prim::Undefined); + n->insertAfter(anchor); + nodes[i] = n; + } + + // Nodes should be in reverse order + for (auto i = 0; i < 100; ++i) { + for (auto j = i + 1; j < 100; ++j) { + ASSERT_TRUE(nodes[i]->isAfter(nodes[j])); + } + } + } +} + } // namespace } // namespace jit } // namespace torch diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 9660f843e090e..f2048defaadfd 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -423,7 +423,7 @@ install(TARGETS torch LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}" ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}") -if (BUILD_TEST AND NOT MSVC AND NOT APPLE AND NOT USE_ROCM) +if (BUILD_TEST AND NOT MSVC AND NOT USE_ROCM) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) endif() diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 37f20df69efd1..5e5b2e42f6f82 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -18,6 +18,17 @@ #include namespace torch { namespace jit { +// Constants relating to maintaining the topological index of nodes. +// +// Lower and upper bounds of the index. Inclusive range. +static constexpr topo_position_t kLowerBound = INT64_MIN; +static constexpr topo_position_t kUpperBound = INT64_MAX; +static constexpr topo_position_t kMidPoint = 0; +// How far away to space nodes that are appended to the graph. 
+// should be 2^n, where: +// - n is the maximum number of repeated insertions without a re-index +// - 2^(64-n) is the maximum number of appends to the end without reindex +static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */; // Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char constexpr Symbol PythonOp::Kind; @@ -460,6 +471,27 @@ void LintGraph(std::shared_ptr<Graph>& graph) { graph->lint(); } +Block::Block(Graph* graph_, Node* node_) + : graph_(graph_), + output_(initOutput(graph_->create(prim::Return, 0))), + input_(graph_->create(prim::Param, 0)), + owning_node_(node_) { + graph_->all_blocks.emplace(this); + output_->owning_block_ = this; + output_->topo_position_ = kUpperBound; + input_->owning_block_ = this; + input_->topo_position_ = kLowerBound; +} + +void Block::reIndexTopology() { + auto curPos = kLowerBound; + for (auto node : nodes()) { + AT_ASSERT(curPos <= (kUpperBound - kAppendInterval)); + curPos += kAppendInterval; + node->topo_position_ = curPos; + } +} + void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) { std::unordered_map<Value*, Value*> local_map; auto env = [&](Value * v) { @@ -650,6 +682,62 @@ bool Node::isNondeterministic() const { return true; } +// Assign this node a topological position, to facilitate fast isBefore() and +// isAfter() queries. Must be called right after a node is inserted into the +// node list. +// +// The basic scheme is: assign every node a position (int64_t). The common +// case (appending to the end of the graph) is made more efficient by advancing +// a fixed interval past the previous node and placing `this` there. Otherwise, +// assign `this` a position at the midpoint between its prev() and next() +// nodes. +// +// If we ever run out of space (by, e.g. inserting too much in place), we +// reindex by spreading out all the nodes again. 
+void Node::assignTopoPosition() { + auto returnNode = owningBlock()->return_node(); + const auto prevPos = prev()->topo_position_; + const auto nextPos = next()->topo_position_; + + // Append to the end of the graph + if (next() == returnNode) { + if (next() == prev()) { + // the node list is empty, assign the first position + topo_position_ = kMidPoint; + return; + } + + if (prevPos >= (kUpperBound - kAppendInterval)) { + // we're running off the edge + owningBlock()->reIndexTopology(); + return; + } + + topo_position_ = prevPos + kAppendInterval; + + // Prepend to the graph + } else if (prev() == returnNode) { + // next() is the first element in the block list + if (nextPos <= (kLowerBound + kAppendInterval)) { + // we're running off the edge + owningBlock()->reIndexTopology(); + return; + } + + topo_position_ = nextPos - kAppendInterval; + + // insert between two existing nodes + } else { + const auto posBetween = prevPos + (nextPos - prevPos) / 2; + if (posBetween == prevPos) { + // There was no room + owningBlock()->reIndexTopology(); + return; + } + topo_position_ = posBetween; + } +} + Node::Node(Graph * graph_, NodeKind kind_) : kind_(kind_), graph_(graph_), @@ -774,6 +862,19 @@ Value* Node::insertOutput(size_t i) { return outputs_.at(i); } +bool Node::isBefore(Node * n) const { + if (this == n) { + return false; + } + return !isAfter(n); +} + +bool Node::isAfter(Node * n) const { + JIT_ASSERT(this->owningBlock() == n->owningBlock()); + + return this->topo_position_ > n->topo_position_; +} + Node* Node::insertBefore(Node * n) { JIT_ASSERT(n->inBlockList()); insertAfter(n->prev()); @@ -789,6 +890,7 @@ Node* Node::insertAfter(Node * n) { this->prev() = n; this->next() = next; next->prev() = this; + assignTopoPosition(); return this; } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index ec7e90ed964e7..1a6599e1bcbe7 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -161,6 +161,7 @@ using pyobj_list = std::vector<THPObjectPtr>; template<typename T> using ArrayRef = 
at::ArrayRef<T>; using NodeKind = Symbol; +using topo_position_t = int64_t; struct Value { TH_DISALLOW_COPY_AND_ASSIGN(Value); @@ -278,6 +279,7 @@ struct Node : public Attributes<Node> { // the schema. // note: mutable because schema_ is effectively a cache mutable const FunctionSchema* schema_; + topo_position_t topo_position_; protected: TORCH_API Node(Graph * graph_, NodeKind kind_); //defined after graph public: @@ -469,6 +471,12 @@ struct Node : public Attributes<Node> { return {blocks_.data(), blocks_.size()}; } + // Is 'this' before 'n' in the topological order? + TORCH_API bool isBefore(Node * n) const; + + // Is 'this' after 'n' in the topological order? + TORCH_API bool isAfter(Node * n) const; + // Insert unattached 'this' node before 'n' in the topological order. // Returns this (for chaining). // @@ -607,6 +615,9 @@ struct Node : public Attributes<Node> { TORCH_API void removeFromList(); TORCH_API void lint() const; + + void assignTopoPosition(); + protected: // subclasses must override // this function is used by createClone to initialize a new version @@ -628,7 +639,7 @@ struct Block { friend struct Node; friend struct Graph; TH_DISALLOW_COPY_AND_ASSIGN(Block); - Block(Graph * graph_, Node * node_); + TORCH_API Block(Graph * graph_, Node * node_); at::ArrayRef<Value*> inputs() { return input_->outputs(); } @@ -707,6 +718,8 @@ struct Block { // in src to look up its corresponding value TORCH_API void cloneFrom(Block * src, std::function<Value*(Value*)> value_map); private: + void reIndexTopology(); + // should only be called in the constructor Node* initOutput(Node* p) { p->next() = p; @@ -977,16 +990,6 @@ inline const Graph * Value::owningGraph() const { return node()->owningGraph(); } -inline Block::Block(Graph* graph_, Node* node_) - : graph_(graph_), - output_(initOutput(graph_->create(prim::Return, 0))), - input_(graph_->create(prim::Param, 0)), - owning_node_(node_) { - graph_->all_blocks.emplace(this); - output_->owning_block_ = this; - input_->owning_block_ = this; -} - // Helper macros 
for constructing switch statements over Node types // instead of heavy-weight visitors // read 'between' these defines to see how they turn into a big switch diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index b959f6eb7ff29..0090d9e99e318 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -135,13 +135,6 @@ struct Device { struct GraphFuser { Block * block; - // Used to order nodes so we always consider producer-consumer fusions - // in reverse topological order. - // If topological_index[a] > topological_index[b] then a occurs after b. - // Because nodes can be added to this graph during optimization, this mapping is not bijective. - // Newly generated nodes will copy the location where they are inserted. - std::unordered_map topological_index; - GraphFuser(Block * block) : block(block) {} @@ -279,7 +272,7 @@ struct GraphFuser { auto defining_node = producer->node(); for(auto o : defining_node->outputs()) { for(auto u : o->uses()) { - if(u.user != consumer && topological_index.at(consumer) > topological_index.at(u.user)) + if(u.user != consumer && consumer->isAfter(u.user)) return false; } } @@ -480,7 +473,6 @@ struct GraphFuser { auto group = block->owningGraph()->createFusionGroup(getDevice(n).index()); // propogate position information for the new node so we can always // have a valid mapping - topological_index[group] = topological_index[n]; group->insertBefore(n); Node * mergedNode = mergeNodeIntoGroup(group,n); getSubgraph(group).registerOutput(mergedNode->output()); @@ -492,7 +484,6 @@ struct GraphFuser { } void insertAfter(Node * n, Node * after) { n->insertAfter(after); - topological_index[n] = topological_index[after]; } void insertAt(Node ** insertion_point, Node * n) { @@ -511,7 +502,6 @@ struct GraphFuser { fused_cat->insertBefore(list_construct); fused_cat->output()->copyMetadata(consumer->output()); consumer->output()->replaceAllUsesWith(fused_cat->output()); 
- topological_index[fused_cat] = topological_index[list_construct]; // NB: this deletes the fused_cat node from the original graph group = createSingletonFusionGroup(fused_cat); @@ -635,12 +625,11 @@ struct GraphFuser { for (auto i : inputs) { if (i->node()->owningBlock() == block) { result.push_back(i); - JIT_ASSERT(topological_index.count(i->node()) > 0); } } // Sort in reverse topological order std::sort(result.begin(), result.end(), [&](Value * a, Value * b) { - return topological_index.at(a->node()) > topological_index.at(b->node()); + return a->node()->isAfter(b->node()); }); return result; } @@ -663,17 +652,6 @@ struct GraphFuser { auto tensors = tensorInputs(node); auto new_tensors = SymbolicVariable::broadcast_tensors(fmap(tensors)); - // Fix up topological_index - Node * unpack_node = new_tensors.at(0).value()->node(); - JIT_ASSERT(unpack_node->kind() == prim::ListUnpack); - Node * broadcast_node = unpack_node->input()->node(); - JIT_ASSERT(broadcast_node->kind() == aten::broadcast_tensors); - Node * construct_node = broadcast_node->namedInput(attr::tensors)->node(); - JIT_ASSERT(construct_node->kind() == prim::ListConstruct); - topological_index[unpack_node] = topological_index[node]; - topological_index[broadcast_node] = topological_index[node]; - topological_index[construct_node] = topological_index[node]; - // Replace tensors inputs with broadcasted values auto new_tensors_it = new_tensors.begin(); for (size_t i = 0; i < node->inputs().size(); ++i) { @@ -821,15 +799,6 @@ struct GraphFuser { } void run() { - for(auto p : block->inputs()) { - topological_index[p->node()] = 0; - } - size_t i = 1; - for(auto consumer : block->nodes()) { - topological_index[consumer] = i++; - } - topological_index[block->return_node()] = i++; - // Run the pass until no changes are made. // This is neccessary, because the algorithm can miss out on certain fusion // opportunities if ran only once. Consider this graph: