Fixes for interpreter and ONNX export for translation (pytorch#7044)

Address comments
James Reed authored Apr 28, 2018
1 parent fc6a846 commit 4667983
Showing 29 changed files with 276 additions and 96 deletions.
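Many of the test updates below pass example_outputs when exporting a torch.jit.ScriptModule to ONNX: the module is run once and its outputs are handed to the exporter. A minimal sketch of that call pattern, mirroring the updated tests (the Doubler module and the tensor shapes are illustrative, not part of this commit):

import torch

class Doubler(torch.jit.ScriptModule):  # illustrative ScriptModule, not from this commit
    @torch.jit.script_method
    def forward(self, x):
        return x + x

mte = Doubler()
example = torch.zeros(1, 2, 3)
outputs = mte(example)  # run once so the exporter has example outputs

# Same call shape as the updated tests: passing None as the file argument
# returns the pretty-printed ONNX model as a string.
pretty = torch.onnx._export_to_pretty_string(
    mte, (example,), None, verbose=False, example_outputs=outputs)
print(pretty)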
1 change: 1 addition & 0 deletions setup.py
@@ -651,6 +651,7 @@ def run(self):
"torch/csrc/jit/passes/canonicalize.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp",
"torch/csrc/jit/generated/aten_dispatch.cpp",
"torch/csrc/jit/script/lexer.cpp",
"torch/csrc/jit/script/compiler.cpp",
8 changes: 4 additions & 4 deletions test/expect/TestJit.test_shape_analysis_broadcast.expect
@@ -1,7 +1,7 @@
-graph(%a : Double(3, 1, 5)
-      %b : Double(4, 1, 8, 5)) {
-  %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5]](%a)
-  %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5]](%b)
+graph(%a.1 : Double(3, 1, 5)
+      %b.1 : Double(4, 1, 8, 5)) {
+  %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5]](%a.1)
+  %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5]](%b.1)
   %2 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%3, %4)
   return (%2);
 }
6 changes: 3 additions & 3 deletions test/expect/TestScript.test_index_select_shape_prop.expect
@@ -1,5 +1,5 @@
-graph(%x : Double(2, 2)
-      %y : Long(4)) {
-  %2 : Double(2, 4) = aten::index_select[dim=1](%x, %y)
+graph(%x.1 : Double(2, 2)
+      %y.1 : Long(4)) {
+  %2 : Double(2, 4) = aten::index_select[dim=1](%x.1, %y.1)
   return (%2);
 }
@@ -5,12 +5,12 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}]
+inputs: [{name: "x.1", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}]
 outputs: [{name: "6", type:Tensor dims: 2 4}]
 initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]]
 nodes: [
 Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]},
+Node {type: "Gemm", inputs: [x.1,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]},
 Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
 Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]}
 ]
@@ -5,11 +5,11 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 2 3}]
+inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 outputs: [{name: "2", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
-Node {type: "Neg", inputs: [x], outputs: [1], attributes: []},
+Node {type: "Neg", inputs: [x.1], outputs: [1], attributes: []},
 Node {type: "Add", inputs: [1,1], outputs: [2], attributes: []}
 ]
 }
@@ -5,11 +5,11 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 2 3}]
+inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 outputs: [{name: "2", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
-Node {type: "Neg", inputs: [x], outputs: [1], attributes: []},
+Node {type: "Neg", inputs: [x.1], outputs: [1], attributes: []},
 Node {type: "Add", inputs: [1,1], outputs: [2], attributes: []}
 ]
 }
4 changes: 2 additions & 2 deletions test/expect/TestScript.test_onnx_export_script_module.expect
@@ -5,11 +5,11 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 2 3}]
+inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 outputs: [{name: "1", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
-Node {type: "Add", inputs: [x,x], outputs: [1], attributes: []}
+Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []}
 ]
 }
 opset_import: [OperatorSetIdProto { domain: }],
10 changes: 5 additions & 5 deletions test/expect/TestScript.test_onnx_export_script_module_if.expect
@@ -5,29 +5,29 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 2 3}]
+inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 outputs: [{name: "4", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
-Node {type: "Sum", inputs: [x], outputs: [1], attributes: []},
+Node {type: "Sum", inputs: [x.1], outputs: [1], attributes: []},
 Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
 Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []},
 Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value:
 GraphProto {
 name: "torch-jit-export1"
 inputs: []
-outputs: [{name: "5", type:Tensor dims: 1 2 3}]
+outputs: [{name: "5", type:Tensor dims: }]
 initializers: []
 nodes: [
-Node {type: "Neg", inputs: [x], outputs: [5], attributes: []}
+Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []}
 ]
 }
 
 },{ name: 'else_branch', type: graph, value:
 GraphProto {
 name: "torch-jit-export2"
 inputs: []
-outputs: [{name: "x", type:Tensor dims: 1 2 3}]
+outputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
 
@@ -5,21 +5,21 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 2 3}]
+inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
 outputs: [{name: "3", type:Tensor dims: 1 2 3}]
 initializers: []
 nodes: [
 Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
 Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-Node {type: "Loop", inputs: [1,2,x], outputs: [3], attributes: [{ name: 'body', type: graph, value:
+Node {type: "Loop", inputs: [1,2,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value:
 GraphProto {
 name: "torch-jit-export1"
-inputs: [{name: "4", type:Tensor dims: },{name: "5", type:Tensor dims: 1 2 3}]
-outputs: [{name: "7", type:Tensor dims: },{name: "6", type:Tensor dims: 1 2 3}]
+inputs: [{name: "_", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "6", type:Tensor dims: }]
+outputs: [{name: "8", type:Tensor dims: },{name: "7", type:Tensor dims: }]
 initializers: []
 nodes: [
-Node {type: "Add", inputs: [5,5], outputs: [6], attributes: []},
-Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
+Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []},
+Node {type: "Constant", inputs: [], outputs: [8], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}
 ]
 }
 
19 changes: 19 additions & 0 deletions test/expect/TestScript.test_onnx_export_shape_reshape.expect
@@ -0,0 +1,19 @@
+ModelProto {
+producer_name: "pytorch"
+domain: ""
+doc_string: ""
+graph:
+GraphProto {
+name: "torch-jit-export"
+inputs: [{name: "0", type:Tensor dims: 1 2 3}]
+outputs: [{name: "4", type:Tensor dims: 5 2 3}]
+initializers: []
+nodes: [
+Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [3]}]},
+Node {type: "Tile", inputs: [0,1], outputs: [2], attributes: []},
+Node {type: "Shape", inputs: [2], outputs: [3], attributes: []},
+Node {type: "Reshape", inputs: [2,3], outputs: [4], attributes: []}
+]
+}
+opset_import: [OperatorSetIdProto { domain: }],
+}
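For reference, the graph above is what the new test traces: repeat lowers to Tile, and the tensor's shape is round-tripped through the torch.onnx.operators helpers, which lower to Shape and Reshape. A minimal sketch of those helpers outside the exporter (the shapes are illustrative):

import torch
import torch.onnx.operators

x = torch.zeros(1, 2, 3).repeat(5, 1, 1)                      # exported as Tile
shape = torch.onnx.operators.shape_as_tensor(x)               # int64 tensor [5, 2, 3], exported as Shape
y = torch.onnx.operators.reshape_from_tensor_shape(x, shape)  # exported as Reshape
assert y.shape == x.shape  # reshaping to the tensor's own shape leaves it unchanged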
12 changes: 6 additions & 6 deletions test/expect/TestScript.test_onnx_export_speculate-f1.expect
@@ -5,11 +5,11 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 10}]
+inputs: [{name: "x.1", type:Tensor dims: 1 10}]
 outputs: [{name: "6", type:Tensor dims: 10 1}]
 initializers: []
 nodes: [
-Node {type: "Add", inputs: [x,x], outputs: [1], attributes: []},
+Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []},
 Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
 Node {type: "Transpose", inputs: [1], outputs: [3], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
 Node {type: "Transpose", inputs: [1], outputs: [4], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
@@ -18,15 +18,15 @@ ModelProto {
 GraphProto {
 name: "torch-jit-export1"
 inputs: []
-outputs: [{name: "8", type:Tensor dims: 10 1}]
+outputs: [{name: "8", type:Tensor dims: }]
 initializers: []
 nodes: [
 Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
 Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
 GraphProto {
 name: "torch-jit-export2"
 inputs: []
-outputs: [{name: "3", type:Tensor dims: 10 1}]
+outputs: [{name: "3", type:Tensor dims: }]
 initializers: []
 nodes: [
 
@@ -37,7 +37,7 @@ ModelProto {
 GraphProto {
 name: "torch-jit-export3"
 inputs: []
-outputs: [{name: "4", type:Tensor dims: 10 1}]
+outputs: [{name: "4", type:Tensor dims: }]
 initializers: []
 nodes: [
 
@@ -52,7 +52,7 @@ ModelProto {
 GraphProto {
 name: "torch-jit-export4"
 inputs: []
-outputs: [{name: "5", type:Tensor dims: 10 1}]
+outputs: [{name: "5", type:Tensor dims: }]
 initializers: []
 nodes: [
 
4 changes: 2 additions & 2 deletions test/expect/TestScript.test_onnx_export_speculate-f2.expect
@@ -5,11 +5,11 @@ ModelProto {
 graph:
 GraphProto {
 name: "torch-jit-export"
-inputs: [{name: "x", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
+inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
 outputs: [{name: "5", type:Tensor dims: 1 20}]
 initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
 nodes: [
-Node {type: "Add", inputs: [x,x], outputs: [3], attributes: []},
+Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
 Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
 Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value:
 GraphProto {
50 changes: 40 additions & 10 deletions test/test_jit.py
@@ -2695,8 +2695,10 @@ def forward(self, x):
                 return x + x
 
         mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False))
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
     def test_onnx_export_script_python_fail(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -2717,9 +2719,11 @@ def forward(self, x):
                 return y + y
 
         mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
         f = io.BytesIO()
         with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
-            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
+            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
+                               example_outputs=outputs)
 
     def test_onnx_export_script_inline_trace(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -2740,8 +2744,10 @@ def forward(self, x):
                 return y + y
 
         mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False))
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
     def test_onnx_export_script_inline_script(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -2763,8 +2769,10 @@ def forward(self, x):
                 return y + y
 
         mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False))
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
     def test_onnx_export_script_module_loop(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -2778,9 +2786,10 @@ def forward(self, x):
                 return x
 
         mte = ModuleToExport()
-        f = io.BytesIO()
+        outputs = mte(torch.zeros(1, 2, 3))
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False))
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
     def test_onnx_export_script_module_if(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -2794,8 +2803,10 @@ def forward(self, x):
                 return x
 
         mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False))
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
     def test_onnx_export_script_inline_params(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -2824,7 +2835,8 @@ def forward(self, x):
         reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
         self.assertEqual(result, reference)
         self.assertExpected(torch.onnx._export_to_pretty_string(
-            mte, (torch.ones(2, 3),), None, verbose=False))
+            mte, (torch.ones(2, 3),), None, verbose=False,
+            example_outputs=result, propagate=True))
 
     def test_trace_with_size(self):
         @torch.jit.trace(torch.zeros(1, 1))
@@ -2877,19 +2889,37 @@ def transpose(x):
             return x.t()
 
         f1 = Foo(transpose)
+        outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
         f2 = Foo(linear)
+        outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
 
         onnx_ish = torch.onnx._export_to_pretty_string(
             f1,
             (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False)
+            None, verbose=False, example_outputs=outputs_f1)
         self.assertExpected(onnx_ish, subname='f1')
         onnx_ish = torch.onnx._export_to_pretty_string(
             f2,
             (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False)
+            None, verbose=False, example_outputs=outputs_f2)
         self.assertExpected(onnx_ish, subname='f2')
 
+    def test_onnx_export_shape_reshape(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                import torch.onnx.operators
+                x = x.repeat(5, 1, 1)
+                shape = torch.onnx.operators.shape_as_tensor(x)
+                reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
+                return reshaped
+
+        foo = torch.jit.trace(torch.zeros(1, 2, 3))(Foo())
+        outputs = foo(torch.zeros(1, 2, 3))
+        f = io.BytesIO()
+        s = torch.onnx._export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
+                                                example_outputs=outputs)
+        self.assertExpected(s)
+
     def test_shape_analysis_loop(self):
         def foo(a, b, x):
             c = a
2 changes: 1 addition & 1 deletion tools/jit/templates/aten_dispatch.cpp
@@ -94,7 +94,7 @@ std::unordered_map<std::string, operator_constructor> constructors = {
 
 std::string getDescriptor(jit::Node* n) {
   std::stringstream s;
-  JIT_ASSERT(n->kind().is_aten());
+  JIT_ASSERTM(n->kind().is_aten(), "%s is not an ATen op", n->kind().toDisplayString());
   s << n->kind().toUnqualString();
   if (tensor_vararg_fns.count(n->kind()) == 0)
     s << "-" << n->inputs().size();