From 4667983f0f35a7336f2556ce9eeb6511f6bf3ac0 Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 27 Apr 2018 22:23:57 -0700 Subject: [PATCH] Fixes for interpreter and ONNX export for translation (#7044) Fixes for interpreter and ONNX export for translation Address comments --- setup.py | 1 + ...stJit.test_shape_analysis_broadcast.expect | 8 +- ...Script.test_index_select_shape_prop.expect | 6 +- ...st_onnx_export_script_inline_params.expect | 4 +- ...st_onnx_export_script_inline_script.expect | 4 +- ...est_onnx_export_script_inline_trace.expect | 4 +- ...ript.test_onnx_export_script_module.expect | 4 +- ...t.test_onnx_export_script_module_if.expect | 10 +-- ...test_onnx_export_script_module_loop.expect | 12 +-- ...ript.test_onnx_export_shape_reshape.expect | 19 +++++ ...cript.test_onnx_export_speculate-f1.expect | 12 +-- ...cript.test_onnx_export_speculate-f2.expect | 4 +- test/test_jit.py | 50 ++++++++++--- tools/jit/templates/aten_dispatch.cpp | 2 +- torch/csrc/jit/export.cpp | 74 ++++++++++--------- torch/csrc/jit/init.cpp | 4 +- torch/csrc/jit/interned_strings.h | 3 +- torch/csrc/jit/interpreter.cpp | 29 ++++++++ torch/csrc/jit/passes/onnx.cpp | 2 + .../csrc/jit/passes/onnx/fixup_onnx_loop.cpp | 22 ++++++ torch/csrc/jit/passes/onnx/fixup_onnx_loop.h | 9 +++ torch/csrc/jit/passes/shape_analysis.cpp | 12 +++ torch/csrc/jit/python_ir.cpp | 3 +- torch/csrc/jit/script/compiler.cpp | 2 + torch/csrc/jit/script/init.cpp | 1 + torch/csrc/jit/script/module.h | 30 +++++++- torch/onnx/operators.py | 1 + torch/onnx/symbolic.py | 16 ++-- torch/onnx/utils.py | 24 ++++-- 29 files changed, 276 insertions(+), 96 deletions(-) create mode 100644 test/expect/TestScript.test_onnx_export_shape_reshape.expect create mode 100644 torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp create mode 100644 torch/csrc/jit/passes/onnx/fixup_onnx_loop.h diff --git a/setup.py b/setup.py index cdfb797e6b366..fa658b4c4efe2 100644 --- a/setup.py +++ b/setup.py @@ -651,6 +651,7 @@ def run(self): "torch/csrc/jit/passes/canonicalize.cpp", "torch/csrc/jit/passes/batch_mm.cpp", "torch/csrc/jit/passes/onnx/peephole.cpp", + "torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp", "torch/csrc/jit/generated/aten_dispatch.cpp", "torch/csrc/jit/script/lexer.cpp", "torch/csrc/jit/script/compiler.cpp", diff --git a/test/expect/TestJit.test_shape_analysis_broadcast.expect b/test/expect/TestJit.test_shape_analysis_broadcast.expect index 9af3289d267fc..868c268f24950 100644 --- a/test/expect/TestJit.test_shape_analysis_broadcast.expect +++ b/test/expect/TestJit.test_shape_analysis_broadcast.expect @@ -1,7 +1,7 @@ -graph(%a : Double(3, 1, 5) - %b : Double(4, 1, 8, 5)) { - %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5]](%a) - %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5]](%b) +graph(%a.1 : Double(3, 1, 5) + %b.1 : Double(4, 1, 8, 5)) { + %3 : Double(4!, 3!, 8!, 5) = aten::expand[size=[4, 3, 8, 5]](%a.1) + %4 : Double(4!, 3!, 8, 5) = aten::expand[size=[4, 3, 8, 5]](%b.1) %2 : Double(4, 3, 8, 5) = aten::add[alpha={1}](%3, %4) return (%2); } diff --git a/test/expect/TestScript.test_index_select_shape_prop.expect b/test/expect/TestScript.test_index_select_shape_prop.expect index 32a9d7744e52c..c3c114063bdec 100644 --- a/test/expect/TestScript.test_index_select_shape_prop.expect +++ b/test/expect/TestScript.test_index_select_shape_prop.expect @@ -1,5 +1,5 @@ -graph(%x : Double(2, 2) - %y : Long(4)) { - %2 : Double(2, 4) = aten::index_select[dim=1](%x, %y) +graph(%x.1 : Double(2, 2) + %y.1 : Long(4)) { + %2 : Double(2, 4) = aten::index_select[dim=1](%x.1, %y.1) return (%2); } diff --git a/test/expect/TestScript.test_onnx_export_script_inline_params.expect b/test/expect/TestScript.test_onnx_export_script_inline_params.expect index 2d6f3150b15ba..54853d73e061f 100644 --- a/test/expect/TestScript.test_onnx_export_script_inline_params.expect +++ b/test/expect/TestScript.test_onnx_export_script_inline_params.expect @@ -5,12 +5,12 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}] + inputs: [{name: "x.1", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}] outputs: [{name: "6", type:Tensor dims: 2 4}] initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]] nodes: [ Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, - Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]}, + Node {type: "Gemm", inputs: [x.1,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]}, Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'broadcast', type: int, value: 1}]} ] diff --git a/test/expect/TestScript.test_onnx_export_script_inline_script.expect b/test/expect/TestScript.test_onnx_export_script_inline_script.expect index dc43cdc7747bb..db62b86cd2654 100644 --- a/test/expect/TestScript.test_onnx_export_script_inline_script.expect +++ b/test/expect/TestScript.test_onnx_export_script_inline_script.expect @@ -5,11 +5,11 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 2 3}] + inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] outputs: [{name: "2", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ - Node {type: "Neg", inputs: [x], outputs: [1], attributes: []}, + Node {type: "Neg", inputs: [x.1], outputs: [1], attributes: []}, Node {type: "Add", inputs: [1,1], outputs: [2], attributes: []} ] } diff --git a/test/expect/TestScript.test_onnx_export_script_inline_trace.expect b/test/expect/TestScript.test_onnx_export_script_inline_trace.expect index dc43cdc7747bb..db62b86cd2654 100644 --- a/test/expect/TestScript.test_onnx_export_script_inline_trace.expect +++ b/test/expect/TestScript.test_onnx_export_script_inline_trace.expect @@ -5,11 +5,11 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 2 3}] + inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] outputs: [{name: "2", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ - Node {type: "Neg", inputs: [x], outputs: [1], attributes: []}, + Node {type: "Neg", inputs: [x.1], outputs: [1], attributes: []}, Node {type: "Add", inputs: [1,1], outputs: [2], attributes: []} ] } diff --git a/test/expect/TestScript.test_onnx_export_script_module.expect b/test/expect/TestScript.test_onnx_export_script_module.expect index ea3d1496c34b3..aa95eab88f810 100644 --- a/test/expect/TestScript.test_onnx_export_script_module.expect +++ b/test/expect/TestScript.test_onnx_export_script_module.expect @@ -5,11 +5,11 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 2 3}] + inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] outputs: [{name: "1", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ - Node {type: "Add", inputs: [x,x], outputs: [1], attributes: []} + Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []} ] } opset_import: [OperatorSetIdProto { domain: }], diff --git a/test/expect/TestScript.test_onnx_export_script_module_if.expect b/test/expect/TestScript.test_onnx_export_script_module_if.expect index 9af338c40b10c..ebece547e5fc5 100644 --- a/test/expect/TestScript.test_onnx_export_script_module_if.expect +++ b/test/expect/TestScript.test_onnx_export_script_module_if.expect @@ -5,21 +5,21 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 2 3}] + inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] outputs: [{name: "4", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ - Node {type: "Sum", inputs: [x], outputs: [1], attributes: []}, + Node {type: "Sum", inputs: [x.1], outputs: [1], attributes: []}, Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []}, Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export1" inputs: [] - outputs: [{name: "5", type:Tensor dims: 1 2 3}] + outputs: [{name: "5", type:Tensor dims: }] initializers: [] nodes: [ - Node {type: "Neg", inputs: [x], outputs: [5], attributes: []} + Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []} ] } @@ -27,7 +27,7 @@ ModelProto { GraphProto { name: "torch-jit-export2" inputs: [] - outputs: [{name: "x", type:Tensor dims: 1 2 3}] + outputs: [{name: "x.1", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ diff --git a/test/expect/TestScript.test_onnx_export_script_module_loop.expect b/test/expect/TestScript.test_onnx_export_script_module_loop.expect index e300e06b2fa5d..c3b982d1f6dfe 100644 --- a/test/expect/TestScript.test_onnx_export_script_module_loop.expect +++ b/test/expect/TestScript.test_onnx_export_script_module_loop.expect @@ -5,21 +5,21 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 2 3}] + inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] outputs: [{name: "3", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, - Node {type: "Loop", inputs: [1,2,x], outputs: [3], attributes: [{ name: 'body', type: graph, value: + Node {type: "Loop", inputs: [1,2,x.1], outputs: [3], attributes: [{ name: 'body', type: graph, value: GraphProto { name: "torch-jit-export1" - inputs: [{name: "4", type:Tensor dims: },{name: "5", type:Tensor dims: 1 2 3}] - outputs: [{name: "7", type:Tensor dims: },{name: "6", type:Tensor dims: 1 2 3}] + inputs: [{name: "_", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "6", type:Tensor dims: }] + outputs: [{name: "8", type:Tensor dims: },{name: "7", type:Tensor dims: }] initializers: [] nodes: [ - Node {type: "Add", inputs: [5,5], outputs: [6], attributes: []}, - Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]} + Node {type: "Add", inputs: [6,6], outputs: [7], attributes: []}, + Node {type: "Constant", inputs: [], outputs: [8], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]} ] } diff --git a/test/expect/TestScript.test_onnx_export_shape_reshape.expect b/test/expect/TestScript.test_onnx_export_shape_reshape.expect new file mode 100644 index 0000000000000..05d657fdcf71d --- /dev/null +++ b/test/expect/TestScript.test_onnx_export_shape_reshape.expect @@ -0,0 +1,19 @@ +ModelProto { + producer_name: "pytorch" + domain: "" + doc_string: "" + graph: + GraphProto { + name: "torch-jit-export" + inputs: [{name: "0", type:Tensor dims: 1 2 3}] + outputs: [{name: "4", type:Tensor dims: 5 2 3}] + initializers: [] + nodes: [ + Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [3]}]}, + Node {type: "Tile", inputs: [0,1], outputs: [2], attributes: []}, + Node {type: "Shape", inputs: [2], outputs: [3], attributes: []}, + Node {type: "Reshape", inputs: [2,3], outputs: [4], attributes: []} + ] + } + opset_import: [OperatorSetIdProto { domain: }], +} diff --git a/test/expect/TestScript.test_onnx_export_speculate-f1.expect b/test/expect/TestScript.test_onnx_export_speculate-f1.expect index 449f0e417af5c..47f55eb41ccda 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f1.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f1.expect @@ -5,11 +5,11 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 10}] + inputs: [{name: "x.1", type:Tensor dims: 1 10}] outputs: [{name: "6", type:Tensor dims: 10 1}] initializers: [] nodes: [ - Node {type: "Add", inputs: [x,x], outputs: [1], attributes: []}, + Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []}, Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Transpose", inputs: [1], outputs: [3], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, Node {type: "Transpose", inputs: [1], outputs: [4], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, @@ -18,7 +18,7 @@ ModelProto { GraphProto { name: "torch-jit-export1" inputs: [] - outputs: [{name: "8", type:Tensor dims: 10 1}] + outputs: [{name: "8", type:Tensor dims: }] initializers: [] nodes: [ Node {type: "Constant", inputs: [], outputs: [7], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, @@ -26,7 +26,7 @@ ModelProto { GraphProto { name: "torch-jit-export2" inputs: [] - outputs: [{name: "3", type:Tensor dims: 10 1}] + outputs: [{name: "3", type:Tensor dims: }] initializers: [] nodes: [ @@ -37,7 +37,7 @@ ModelProto { GraphProto { name: "torch-jit-export3" inputs: [] - outputs: [{name: "4", type:Tensor dims: 10 1}] + outputs: [{name: "4", type:Tensor dims: }] initializers: [] nodes: [ @@ -52,7 +52,7 @@ ModelProto { GraphProto { name: "torch-jit-export4" inputs: [] - outputs: [{name: "5", type:Tensor dims: 10 1}] + outputs: [{name: "5", type:Tensor dims: }] initializers: [] nodes: [ diff --git a/test/expect/TestScript.test_onnx_export_speculate-f2.expect b/test/expect/TestScript.test_onnx_export_speculate-f2.expect index de3d721c1e8ed..34e7dadc2ad7a 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f2.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f2.expect @@ -5,11 +5,11 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}] + inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}] outputs: [{name: "5", type:Tensor dims: 1 20}] initializers: [TensorProto shape: [20 10],TensorProto shape: [20]] nodes: [ - Node {type: "Add", inputs: [x,x], outputs: [3], attributes: []}, + Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []}, Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { diff --git a/test/test_jit.py b/test/test_jit.py index bbf27db8f092c..aeab8448d6cad 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2695,8 +2695,10 @@ def forward(self, x): return x + x mte = ModuleToExport() + outputs = mte(torch.zeros(1, 2, 3)) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.zeros(1, 2, 3),), None, verbose=False)) + mte, (torch.zeros(1, 2, 3),), None, verbose=False, + example_outputs=outputs)) def test_onnx_export_script_python_fail(self): class ModuleToInline(torch.jit.ScriptModule): @@ -2717,9 +2719,11 @@ def forward(self, x): return y + y mte = ModuleToExport() + outputs = mte(torch.zeros(1, 2, 3)) f = io.BytesIO() with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"): - torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False) + torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False, + example_outputs=outputs) def test_onnx_export_script_inline_trace(self): class ModuleToInline(torch.jit.ScriptModule): @@ -2740,8 +2744,10 @@ def forward(self, x): return y + y mte = ModuleToExport() + outputs = mte(torch.zeros(1, 2, 3)) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.zeros(1, 2, 3),), None, verbose=False)) + mte, (torch.zeros(1, 2, 3),), None, verbose=False, + example_outputs=outputs)) def test_onnx_export_script_inline_script(self): class ModuleToInline(torch.jit.ScriptModule): @@ -2763,8 +2769,10 @@ def forward(self, x): return y + y mte = ModuleToExport() + outputs = mte(torch.zeros(1, 2, 3)) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.zeros(1, 2, 3),), None, verbose=False)) + mte, (torch.zeros(1, 2, 3),), None, verbose=False, + example_outputs=outputs)) def test_onnx_export_script_module_loop(self): class ModuleToExport(torch.jit.ScriptModule): @@ -2778,9 +2786,10 @@ def forward(self, x): return x mte = ModuleToExport() - f = io.BytesIO() + outputs = mte(torch.zeros(1, 2, 3)) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.zeros(1, 2, 3),), None, verbose=False)) + mte, (torch.zeros(1, 2, 3),), None, verbose=False, + example_outputs=outputs)) def test_onnx_export_script_module_if(self): class ModuleToExport(torch.jit.ScriptModule): @@ -2794,8 +2803,10 @@ def forward(self, x): return x mte = ModuleToExport() + outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long)) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.zeros(1, 2, 3),), None, verbose=False)) + mte, (torch.zeros(1, 2, 3),), None, verbose=False, + example_outputs=outputs)) def test_onnx_export_script_inline_params(self): class ModuleToInline(torch.jit.ScriptModule): @@ -2824,7 +2835,8 @@ def forward(self, x): reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)) self.assertEqual(result, reference) self.assertExpected(torch.onnx._export_to_pretty_string( - mte, (torch.ones(2, 3),), None, verbose=False)) + mte, (torch.ones(2, 3),), None, verbose=False, + example_outputs=result, propagate=True)) def test_trace_with_size(self): @torch.jit.trace(torch.zeros(1, 1)) @@ -2877,19 +2889,37 @@ def transpose(x): return x.t() f1 = Foo(transpose) + outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float)) f2 = Foo(linear) + outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float)) onnx_ish = torch.onnx._export_to_pretty_string( f1, (torch.ones(1, 10, dtype=torch.float), ), - None, verbose=False) + None, verbose=False, example_outputs=outputs_f1) self.assertExpected(onnx_ish, subname='f1') onnx_ish = torch.onnx._export_to_pretty_string( f2, (torch.ones(1, 10, dtype=torch.float), ), - None, verbose=False) + None, verbose=False, example_outputs=outputs_f2) self.assertExpected(onnx_ish, subname='f2') + def test_onnx_export_shape_reshape(self): + class Foo(torch.nn.Module): + def forward(self, x): + import torch.onnx.operators + x = x.repeat(5, 1, 1) + shape = torch.onnx.operators.shape_as_tensor(x) + reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape) + return reshaped + + foo = torch.jit.trace(torch.zeros(1, 2, 3))(Foo()) + outputs = foo(torch.zeros(1, 2, 3)) + f = io.BytesIO() + s = torch.onnx._export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f, + example_outputs=outputs) + self.assertExpected(s) + def test_shape_analysis_loop(self): def foo(a, b, x): c = a diff --git a/tools/jit/templates/aten_dispatch.cpp b/tools/jit/templates/aten_dispatch.cpp index 877d6b1117078..437b5f348f9f8 100644 --- a/tools/jit/templates/aten_dispatch.cpp +++ b/tools/jit/templates/aten_dispatch.cpp @@ -94,7 +94,7 @@ std::unordered_map constructors = { std::string getDescriptor(jit::Node* n) { std::stringstream s; - JIT_ASSERT(n->kind().is_aten()); + JIT_ASSERTM(n->kind().is_aten(), "%s is not an ATen op", n->kind().toDisplayString()); s << n->kind().toUnqualString(); if (tensor_vararg_fns.count(n->kind()) == 0) s << "-" << n->inputs().size(); diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index e5a5f700bfceb..1f9445cbbf635 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -51,6 +51,8 @@ void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor, p->add_dims(d); } onnx::DataType onnx_type; + // Most integral types and float16 need to be serialized as int32 + at::ScalarType cast_type = tensor.type().scalarType(); switch(tensor.type().scalarType()) { case at::kDouble: onnx_type = onnx::kDOUBLE; @@ -60,13 +62,16 @@ void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor, break; case at::kHalf: onnx_type = onnx::kFLOAT16; + cast_type = at::kInt; break; case at::kByte: case at::kChar: onnx_type = onnx::kINT8; + cast_type = at::kInt; break; case at::kShort: onnx_type = onnx::kINT16; + cast_type = at::kInt; break; case at::kInt: onnx_type = onnx::kINT32; @@ -80,7 +85,7 @@ void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor, } p->set_data_type(onnx_type); // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous() - auto t = tensor.contiguous().toBackend(at::kCPU); + auto t = tensor.contiguous().toBackend(at::kCPU).toType(cast_type); // Add a buffer to the raw_data_export_map for the caller to dump into an // external data store. If external_ref is not specified, we instead dump // the contiguous data into the protobuf itself @@ -158,40 +163,41 @@ void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, Export void encodeTypeProtoTensorType(onnx::TypeProtoTensor* tensor_type, Value* n) { onnx::TensorShapeProto* shape = tensor_type->mutable_shape(); - TensorType* node_type = n->type()->expect(); - const std::vector& sizes = node_type->sizes(); - for (std::int64_t s : sizes) { - shape->add_dim(s); - } - onnx::DataType onnx_type; - switch(node_type->scalarType()) { - case at::kDouble: - onnx_type = onnx::kDOUBLE; - break; - case at::kFloat: - onnx_type = onnx::kFLOAT; - break; - case at::kHalf: - onnx_type = onnx::kFLOAT16; - break; - case at::kByte: - case at::kChar: - onnx_type = onnx::kINT8; - break; - case at::kShort: - onnx_type = onnx::kINT16; - break; - case at::kInt: - onnx_type = onnx::kINT32; - break; - case at::kLong: - onnx_type = onnx::kINT64; - break; - default: - torch::barf("unexpected tensor scalar type"); - break; + if (TensorType* node_type = n->type()->cast()) { + const std::vector& sizes = node_type->sizes(); + for (std::int64_t s : sizes) { + shape->add_dim(s); + } + onnx::DataType onnx_type; + switch(node_type->scalarType()) { + case at::kDouble: + onnx_type = onnx::kDOUBLE; + break; + case at::kFloat: + onnx_type = onnx::kFLOAT; + break; + case at::kHalf: + onnx_type = onnx::kFLOAT16; + break; + case at::kByte: + case at::kChar: + onnx_type = onnx::kINT8; + break; + case at::kShort: + onnx_type = onnx::kINT16; + break; + case at::kInt: + onnx_type = onnx::kINT32; + break; + case at::kLong: + onnx_type = onnx::kINT64; + break; + default: + torch::barf("unexpected tensor scalar type"); + break; + } + tensor_type->set_data_type(onnx_type); } - tensor_type->set_data_type(onnx_type); } void encodeValueInfo(onnx::ValueInfoProto* v, Value* n) { diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 1e4078ab9d0e2..643dfd530352e 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -14,6 +14,7 @@ #include "torch/csrc/jit/passes/peephole.h" #include "torch/csrc/jit/passes/canonicalize.h" #include "torch/csrc/jit/passes/onnx/peephole.h" +#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h" #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/script/init.h" @@ -82,7 +83,8 @@ void initJITBindings(PyObject *module) { .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) { return py::reinterpret_steal(python::unflatten(vars, desc)); }) - .def("_jit_pass_onnx_block", BlockToONNX); + .def("_jit_pass_onnx_block", BlockToONNX) + .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops); py::class_(m, "GraphExecutor") .def( diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index 9d2581a3bf632..f59b9f4f31fea 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -104,7 +104,8 @@ _(onnx, Sub) \ _(onnx, Transpose) \ _(onnx, Unsqueeze) \ _(onnx, Loop) \ -_(onnx, If) +_(onnx, If) \ +_(onnx, Reshape) /* end */ // These symbols are attribute keys. They are shared between both ONNX and ATen diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index d5a152b190f2e..9ef3e24dcf251 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -736,6 +736,35 @@ Operation getOperation(jit::Node* node, bool values_are_variables) { return 0; }; IR_ELSE() + switch (node->kind()) { + case onnx::Reshape: { + return [=](Stack& stack) { + auto shape = pop(stack).contiguous(); + auto input = pop(stack); + JIT_ASSERT(shape.ndimension() == 1); + at::IntList shape_list(shape.data(), shape.size(0)); + stack.push_back(input.reshape(shape_list)); + return 0; + }; + } break; + case onnx::Shape: { + return [=](Stack& stack) { + auto t = pop(stack); + at::IntList sizes = t.sizes(); + auto sizes_tensor = at::CPU(at::kLong).tensor(sizes.size()); + auto accessor = sizes_tensor.accessor(); + for (size_t i=0; iappendNode(ctx.block->owningGraph()->createClone(node, envFn)); for(size_t i = 0; i < node->outputs().size(); i++) { + // n_->outputs()[i]->setType(node->outputs()[i]->type()); env[node->outputs()[i]] = n_->outputs()[i]; } }; @@ -217,6 +218,7 @@ void BlockToONNX(Block* old_block, Block* new_block, bool aten, std::unordered_m } for (auto output : old_block->outputs()) { ctx.block->registerOutput(env.at(output)); + env.at(output)->setType(output->type()); } // Copy stage from original graph diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp new file mode 100644 index 0000000000000..72ff528c97205 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp @@ -0,0 +1,22 @@ +#include "torch/csrc/jit/passes/onnx/fixup_onnx_loop.h" + +namespace torch { namespace jit { + +void FixupONNXLoops(Block *block) { + for (auto *node : block->nodes()) { + if (node->kind() == torch::jit::onnx::Loop) { + JIT_ASSERT(node->blocks().size() == 1); + auto *sub_block = node->blocks()[0]; + sub_block->insertInput(1, "cond"); + } + for (Block * block : node->blocks()) { + FixupONNXLoops(block); + } + } +} + +void FixupONNXLoops(std::shared_ptr& graph) { + FixupONNXLoops(graph->block()); +} + +}} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_loop.h b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.h new file mode 100644 index 0000000000000..fe311ac6f2946 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_loop.h @@ -0,0 +1,9 @@ +#pragma once + +#include "torch/csrc/jit/ir.h" + +namespace torch { namespace jit { + +void FixupONNXLoops(std::shared_ptr& graph); + +}} diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 7ee83cae5aa30..94fa31ca6bf2f 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -342,6 +342,18 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { setDynamicType(node); handled = true; } break; + case onnx::Shape: { + if (check_overload(/*num_inputs=*/1, /*num_outputs=*/1, {})) { + std::vector dim_vec = {(int64_t)types.at(0)->sizes().size()}; + at::IntList dims(dim_vec); + node->output()->setType( + std::make_shared(at::kLong, -1, dims)); + } + } break; + case onnx::Reshape: { + setDynamicType(node); + handled = true; + } default: { } break; } diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 828210a7f7906..442de7e00bc01 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -142,6 +142,7 @@ void initPythonIRBindings(PyObject * module_) { return node; }) .VS(copyMetadata) + .VS(isTensor) ; #undef VS @@ -284,7 +285,7 @@ void initPythonIRBindings(PyObject * module_) { ; #define TS(name) \ - def(#name,&Node :: name) + def(#name,&Type :: name) py::class_>(m,"Type") .def("__repr__",[](Type & t) { std::stringstream ss; diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 12e2b8c13cede..8858367cab36d 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -131,6 +131,8 @@ struct Environment { void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) { Value* as_simple_value = asSimple(value); + if (as_simple_value) + as_simple_value->setUniqueName(name); // prevent re-assignment involving any sugared values // any reassignment like: // a = ... diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 13ff958af4105..9c2dee00db1bd 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -388,6 +388,7 @@ void initJitScriptBindings(PyObject* module) { return m.graph(); }) .def("propagate_shapes", &Method::propagate_shapes) + .def("propagate_and_assign_input_and_output_shapes", &Method::propagate_and_assign_input_and_output_shapes) .def("params", &Method::params); m.def("_jit_script_compile", [](Def def, ResolutionCallback rcb) { diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index cd5f647b91770..1f02b1368d9cd 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -85,7 +85,35 @@ struct Method { for (auto inp : member_inputs) { inputs.push_back(*inp); } - PropagateInputShapes(*retval, ArgumentSpec(with_grad, variable_tensor_list(std::move(inputs)))); + PropagateInputShapes( + *retval, + ArgumentSpec(with_grad, variable_tensor_list(std::move(inputs)))); + return retval; + } + + std::shared_ptr propagate_and_assign_input_and_output_shapes(std::vector inputs, std::vector outputs, bool with_grad=false, bool propagate=true) { + auto retval = graph_->copy(); + for (auto inp : member_inputs) { + inputs.push_back(*inp); + } + if (propagate) { + auto inputs_copy = inputs; + PropagateInputShapes(*retval, ArgumentSpec(with_grad, variable_tensor_list(std::move(inputs_copy)))); + } + JIT_ASSERT(retval->inputs().size() == inputs.size()); + for (size_t i=0; i < retval->inputs().size(); ++i) { + auto scalar_type = inputs[i].type().scalarType(); + auto sizes = inputs[i].sizes(); + auto type = std::make_shared(scalar_type, -1, sizes); + retval->inputs()[i]->setType(type); + } + JIT_ASSERT(retval->outputs().size() == outputs.size()); + for (size_t i=0; i < retval->outputs().size(); ++i) { + auto scalar_type = outputs[i].type().scalarType(); + auto sizes = outputs[i].sizes(); + auto type = std::make_shared(scalar_type, -1, sizes); + retval->outputs()[i]->setType(type); + } return retval; } diff --git a/torch/onnx/operators.py b/torch/onnx/operators.py index b589996a83303..d0ecc2879ecb7 100644 --- a/torch/onnx/operators.py +++ b/torch/onnx/operators.py @@ -6,6 +6,7 @@ import torch import torch.onnx +import torch.onnx.utils def _shape_as_tensor(g, input): diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 5c550cb48c34c..80521ae81594a 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -277,9 +277,10 @@ def permute(g, self, dims): def view(g, self, size): - self_sizes = self.type().sizes() - if self_sizes and len(size) == 2 and self_sizes[0] == size[0]: - return g.op("Flatten", self, axis_i=1) + if self.isTensor(): + self_sizes = self.type().sizes() + if self_sizes and len(size) == 2 and self_sizes[0] == size[0]: + return g.op("Flatten", self, axis_i=1) shape = g.op("Constant", value_t=torch.LongTensor(size)) return g.op("Reshape", self, shape) @@ -700,10 +701,11 @@ def topk(g, self, k, dim=None, largest=True, sorted=True, out=None): def repeat(g, self, repeats): - sizes = self.type().sizes() - diff_dims = len(repeats) - len(sizes) - if diff_dims > 0: - self = view(g, self, [1] * diff_dims + sizes) + if self.isTensor(): + sizes = self.type().sizes() + diff_dims = len(repeats) - len(sizes) + if diff_dims > 0: + self = view(g, self, [1] * diff_dims + sizes) return g.op("Tile", self, g.op("Constant", value_t=torch.LongTensor(repeats))) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index bb96a2a11ed2e..e0ae4a559efbe 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -87,6 +87,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False, def _optimize_graph(graph, aten): # run dce first to eliminate dead parts of the graph that might have been # left behind by things like symbolic_override + torch._C._jit_pass_dce(graph) torch._C._jit_pass_lint(graph) @@ -98,6 +99,8 @@ def _optimize_graph(graph, aten): torch._C._jit_pass_lint(graph) torch._C._jit_pass_dce(graph) torch._C._jit_pass_lint(graph) + torch._C._jit_pass_fixup_onnx_loops(graph) + torch._C._jit_pass_lint(graph) graph = torch._C._jit_pass_canonicalize(graph) torch._C._jit_pass_lint(graph) return graph @@ -136,16 +139,21 @@ def _trace_and_get_graph_from_model(model, args, training): def _model_to_graph(model, args, f, verbose=False, training=False, - input_names=None, output_names=None, aten=False): + input_names=None, output_names=None, aten=False, + example_outputs=None, propagate=False): # Special case for common case of passing a single Variable if isinstance(args, torch.Tensor): args = (args, ) if isinstance(model, torch.jit.ScriptModule): torch_out = None + assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule" + if isinstance(example_outputs, torch.Tensor): + example_outputs = [example_outputs] try: method = model.__getattr__('forward') - graph = method.propagate_shapes(args, False) + graph = method.propagate_and_assign_input_and_output_shapes( + args, example_outputs, False, propagate) params = method.params() except AttributeError: # TODO: just trace it @@ -164,10 +172,12 @@ def _model_to_graph(model, args, f, verbose=False, training=False, def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False, - input_names=None, output_names=None, aten=False, export_type=ExportTypes.PROTOBUF_FILE): + input_names=None, output_names=None, aten=False, export_type=ExportTypes.PROTOBUF_FILE, + example_outputs=None, propagate=False): graph, params, torch_out = _model_to_graph(model, args, f, verbose, training, input_names, - output_names, aten) + output_names, aten, + example_outputs, propagate) from torch.onnx.symbolic import _onnx_opset_version return graph.prettyPrintExport(params, _onnx_opset_version, False) @@ -178,10 +188,12 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, # this output will be None, since we are not doing any tracing but rather # directly extracting the graph. def _export(model, args, f, export_params=True, verbose=False, training=False, - input_names=None, output_names=None, aten=False, export_type=ExportTypes.PROTOBUF_FILE): + input_names=None, output_names=None, aten=False, export_type=ExportTypes.PROTOBUF_FILE, + example_outputs=None, propagate=False): graph, params, torch_out = _model_to_graph(model, args, f, verbose, training, input_names, - output_names, aten) + output_names, aten, + example_outputs, propagate) # TODO: Don't allocate a in-memory string for the protobuf from torch.onnx.symbolic import _onnx_opset_version defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE