Skip to content

Commit

Permalink
Set proper scope on nodes added by JIT (pytorch#12400)
Browse files Browse the repository at this point in the history
Summary:
In order to support tensorboardX and other visualization tools, we need to make sure a non-empty scope is set on all nodes added by the JIT. This attempts to do this, but is still a WIP.

This is a new version of pytorch#10749
Pull Request resolved: pytorch#12400

Reviewed By: ezyang

Differential Revision: D10224380

Pulled By: orionr

fbshipit-source-id: d1bccd0eee9ef7c4354112c6a39a5987bfac2994
  • Loading branch information
orionr authored and facebook-github-bot committed Oct 24, 2018
1 parent cf235e0 commit 046672e
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 91 deletions.
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/passes/pretty_print.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/scope.cpp
${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/jit/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
namespace torch { namespace jit {

// IValue -> Constant node
Value* insertConstant(Graph& g, IValue val, c10::optional<SourceRange> loc) {
Value* insertConstant(
Graph& g,
IValue val,
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
Node * n = g.create(prim::Constant);
if(val.isTensor()) {
at::Tensor ref = std::move(val).toTensor();
Expand Down Expand Up @@ -53,6 +57,8 @@ Value* insertConstant(Graph& g, IValue val, c10::optional<SourceRange> loc) {
}
if(loc)
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
if(scope)
n->setScope(*scope);
return g.insertNode(n)->output();
}

Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/jit/constants.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "torch/csrc/jit/ivalue.h"
#include "torch/csrc/jit/scope.h"
#include "torch/csrc/jit/source_range.h"
#include "torch/csrc/WindowsTorchApiMacro.h"

Expand All @@ -22,7 +23,9 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error {
TORCH_API Value* insertConstant(
Graph& g,
IValue val,
c10::optional<SourceRange> loc = c10::nullopt);
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);


//////////////////////////////////////////////////////////////////////////////////
// Helper for retrieving constants
Expand Down
32 changes: 3 additions & 29 deletions torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,33 +196,6 @@ void Graph::dumpPretty() {
PrettyPrint(std::cout, *this);
}

ScopePtr Scope::push(Symbol name) {
return c10::make_intrusive<Scope>(intrusive_from_this(), name);
}

ScopePtr Scope::getRoot() {
ScopePtr current = intrusive_from_this();
while (current->parent_) {
current = current->parent_;
}
return current;
}

std::string Scope::namesFromRoot(const std::string& separator) const {
// TODO: I think the answer is we shouldn't have used Symbol here
std::string out = this->name_.toUnqualString();
if (this->isRoot()) {
return out;
}
ScopePtr parent = this->parent_;
while (!parent->isRoot()) {
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
out = std::string(parent->name_.toUnqualString()) + separator + out;
parent = parent->parent_;
}
return out;
}

static void checkSameDevice(const Node* node) {
bool has_device = false;
int device;
Expand Down Expand Up @@ -1125,8 +1098,9 @@ Node* Graph::createClone(Node * n, std::function<Value*(Value*)> value_map, bool

Value* Graph::insertConstant(
IValue val,
c10::optional<SourceRange> loc) {
return jit::insertConstant(*this, std::move(val), loc);
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
return jit::insertConstant(*this, std::move(val), loc, scope);
}

Value* Graph::insertDummyWorld() {
Expand Down
59 changes: 3 additions & 56 deletions torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "torch/csrc/jit/graph_node_list.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/resource_guard.h"
#include "torch/csrc/jit/scope.h"
#include "torch/csrc/jit/source_location.h"
#include "torch/csrc/jit/source_range.h"
#include "torch/csrc/jit/constants.h"
Expand Down Expand Up @@ -97,61 +98,6 @@ struct Use {
// If you are looking for "use induced by an input", it's best to use
// findUseForInput() to get it.


// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope;
using ScopePtr = c10::intrusive_ptr<Scope>;

struct TORCH_API Scope : public c10::intrusive_ptr_target {
private:
ScopePtr parent_;
Symbol name_;
ScopePtr intrusive_from_this() {
c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
// from a raw `this` pointer
// so we need to bump the refcount
// to account for this ownership
return c10::intrusive_ptr<Scope>::reclaim(this);
}
public:
Scope() {
name_ = Symbol::scope("");
}
Scope(ScopePtr parent, Symbol name) {
name_ = name;
parent_ = parent;
}
ScopePtr push(Symbol name);

ScopePtr parent() {
if (!parent_) {
throw std::runtime_error("Cannot get parent from Scope with no parent");
}
return parent_;
}
bool isRoot() const {
return !parent_;
}
bool isBlank() const {
static const Symbol blank = Symbol::scope("");
return isRoot() && name() == blank;
}

ScopePtr getRoot();

Symbol name() const {
return name_;
}

std::string namesFromRoot(const std::string& separator="/") const;
};

// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
using node_list = std::vector<Node*>;
Expand Down Expand Up @@ -868,7 +814,8 @@ friend struct Block;

TORCH_API Value* insertConstant(
IValue val,
c10::optional<SourceRange> loc = c10::nullopt);
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);

TORCH_API Value* insertDummyWorld();

Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/erase_number_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ static void EraseNumberTypesOnBlock(Block* block) {
it->output()->type()->isSubtypeOf(BoolType::get())) {
auto s = *constant_as<at::Scalar>(it->output());
WithInsertPoint guard(*it);
Value* r = block->owningGraph()->insertConstant(scalar_to_tensor(s));
Value* r = block->owningGraph()->insertConstant(
scalar_to_tensor(s), c10::nullopt, it->scope());
it->output()->replaceAllUsesWith(r);
}
} break;
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/jit/passes/onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
// Unfortunately, they are on the hook for all internal nodes
// (though in practice, the types are not computed.)
outputs[i]->setType(old->type());
// Copy over source location information to all nodes created by
// the symbolic
// Copy over source location and scope information to all nodes
// created by the symbolic
outputs[i]->node()->setSourceLocation(node->getSourceLocation());
outputs[i]->node()->setScope(node->scope());
env[old] = outputs[i];
} else {
// Null output means that the ONNX op doesn't have outputs corresponding
Expand Down
59 changes: 59 additions & 0 deletions torch/csrc/jit/scope.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "ir.h"


#include "torch/csrc/jit/operator.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/passes/pretty_print.h"

#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <set>
#include <stack>
#include <sstream>
#include <algorithm>
#include <string>

namespace torch { namespace jit {

ScopePtr Scope::push(Symbol name) {
return c10::make_intrusive<Scope>(intrusive_from_this(), name);
}

ScopePtr Scope::getRoot() {
ScopePtr current = intrusive_from_this();
while (current->parent_) {
current = current->parent_;
}
return current;
}

size_t Scope::getDepth() {
size_t d = 1;
ScopePtr current = intrusive_from_this();
while (current->parent_) {
current = current->parent_;
d += 1;
}
return d;
}

std::string Scope::namesFromRoot(const std::string& separator) const {
// TODO: I think the answer is we shouldn't have used Symbol here
std::string out = this->name_.toUnqualString();
if (this->isRoot()) {
return out;
}
ScopePtr parent = this->parent_;
while (!parent->isRoot()) {
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
out = std::string(parent->name_.toUnqualString()) + separator + out;
parent = parent->parent_;
}
return out;
}

}} // namespace torch::jit
69 changes: 69 additions & 0 deletions torch/csrc/jit/scope.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/WindowsTorchApiMacro.h"
#include "c10/macros/Macros.h"

#include <memory>

namespace torch {
namespace jit {

// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope;
using ScopePtr = c10::intrusive_ptr<Scope>;

struct TORCH_API Scope : public c10::intrusive_ptr_target {
private:
ScopePtr parent_;
Symbol name_;
ScopePtr intrusive_from_this() {
c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
// from a raw `this` pointer
// so we need to bump the refcount
// to account for this ownership
return c10::intrusive_ptr<Scope>::reclaim(this);
}
public:
Scope() {
name_ = Symbol::scope("");
}
Scope(ScopePtr parent, Symbol name) {
name_ = name;
parent_ = parent;
}
ScopePtr push(Symbol name);

ScopePtr parent() {
if (!parent_) {
throw std::runtime_error("Cannot get parent from Scope with no parent");
}
return parent_;
}
bool isRoot() const {
return !parent_;
}
bool isBlank() const {
static const Symbol blank = Symbol::scope("");
return isRoot() && name() == blank;
}

ScopePtr getRoot();

size_t getDepth();

Symbol name() const {
return name_;
}

std::string namesFromRoot(const std::string& separator="/") const;
};

} // namespace jit
} // namespace torch
13 changes: 12 additions & 1 deletion torch/csrc/jit/symbolic_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@ struct SymbolicVariable {
if(g == nullptr) {
g = inputs.at(0).value()->owningGraph();
}
Node * n = g->insertNode(g->create(kind, num_outputs));
Node* n = g->insertNode(g->create(kind, num_outputs));
size_t max_depth = 0;
ScopePtr s;
for(auto n : inputs) {
size_t d = n.value()->node()->scope()->getDepth();
if(d > max_depth) {
max_depth = d;
s = n.value()->node()->scope();
}
}
n->setScope(s);

for(auto i : inputs) {
n->addInput(i.value());
}
Expand Down

0 comments on commit 046672e

Please sign in to comment.