From 66b427a9b1925476a2c5be5d61bf74f0c3181ef2 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 12 May 2022 06:07:59 -0700 Subject: [PATCH] Revert "[ONNX] Support optional type (#68793) (#73284)" This reverts commit 679fc90cdb126223e159a1ba3baeeb30fcc34b03. --- aten/src/ATen/core/interned_strings.h | 3 - caffe2/python/onnx/tests/onnx_backend_test.py | 3 - test/onnx/test_models.py | 18 +- test/onnx/test_pytorch_common.py | 15 +- test/onnx/test_pytorch_onnx_caffe2.py | 13 +- test/onnx/test_pytorch_onnx_onnxruntime.py | 814 +++++++----------- .../test_pytorch_onnx_onnxruntime_cuda.py | 10 +- test/onnx/test_utility_funs.py | 2 +- torch/_C/__init__.pyi.in | 2 +- .../passes/onnx/fixup_onnx_controlflow.cpp | 230 +---- torch/csrc/jit/passes/onnx/peephole.h | 2 +- .../jit/passes/onnx/scalar_type_analysis.cpp | 7 +- .../jit/passes/onnx/shape_type_inference.cpp | 108 ++- .../jit/passes/onnx/shape_type_inference.h | 27 +- torch/csrc/jit/python/python_arg_flatten.cpp | 6 +- torch/csrc/jit/python/python_ir.cpp | 10 - torch/csrc/jit/python/script_init.cpp | 48 +- torch/csrc/jit/serialization/export.cpp | 172 +--- torch/csrc/jit/serialization/onnx.cpp | 12 - torch/csrc/onnx/init.cpp | 10 +- torch/onnx/symbolic_opset15.py | 35 - torch/onnx/symbolic_opset9.py | 9 +- torch/onnx/utils.py | 144 ++-- 23 files changed, 533 insertions(+), 1167 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index e4c95fe43a33c..0b307552d160a 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -279,9 +279,6 @@ namespace c10 { _(onnx, Range) \ _(onnx, Tile) \ _(onnx, Where) \ - _(onnx, Optional) \ - _(onnx, OptionalGetElement) \ - _(onnx, OptionalHasElement) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ _(attr, ReverseSubgraph) \ diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index 42262d269695d..50a350a6faaa5 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -165,9 +165,6 @@ '|test_optional_.*' '|test_shape_end_.*' '|test_shape_start_.*' - '|test_identity_opt_*' - '|test_loop16_seq_none_*' - '|test_if_opt_*' ')') # Unsupported ops in opset 16 diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py index dc849528842a8..391f6309b46e8 100644 --- a/test/onnx/test_models.py +++ b/test/onnx/test_models.py @@ -20,7 +20,7 @@ run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, - skipScriptTest, + disableScriptTest, ) from torchvision.models import shufflenet_v2_x1_0 from torchvision.models.alexnet import alexnet @@ -82,7 +82,7 @@ def test_prelu(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(PReluNet(), x) - @skipScriptTest() + @disableScriptTest() def test_concat(self): input_a = Variable(torch.randn(BATCH_SIZE, 3)) input_b = Variable(torch.randn(BATCH_SIZE, 3)) @@ -93,12 +93,12 @@ def test_permute(self): x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12)) self.exportTest(PermuteNet(), x) - @skipScriptTest() + @disableScriptTest() def test_embedding_sequential_1(self): x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3))) self.exportTest(EmbeddingNetwork1(), x) - @skipScriptTest() + @disableScriptTest() def test_embedding_sequential_2(self): x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3))) self.exportTest(EmbeddingNetwork2(), x) @@ -152,7 +152,7 @@ def test_resnet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) 
self.exportTest(toC(resnet50()), toC(x), atol=1e-6) - @skipScriptTest(min_opset_version=15) # None type in outputs + @disableScriptTest() # None type in outputs def test_inception(self): x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299)) self.exportTest(toC(inception_v3()), toC(x)) @@ -175,14 +175,14 @@ def test_densenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5) - @skipScriptTest() + @disableScriptTest() def test_dcgan_netD(self): netD = _netD(1) netD.apply(weights_init) input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1)) self.exportTest(toC(netD), toC(input)) - @skipScriptTest() + @disableScriptTest() def test_dcgan_netG(self): netG = _netG(1) netG.apply(weights_init) @@ -239,7 +239,7 @@ def test_qat_resnet_per_channel(self): self.exportTest(toC(qat_resnet50), toC(x)) - @skipScriptTest(min_opset_version=15) # None type in outputs + @disableScriptTest() # None type in outputs def test_googlenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5) @@ -252,7 +252,7 @@ def test_mobilenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5) - @skipScriptTest() # prim_data + @disableScriptTest() # prim_data def test_shufflenet(self): x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)) self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5) diff --git a/test/onnx/test_pytorch_common.py b/test/onnx/test_pytorch_common.py index 44ccc303cff74..c4d270ba29411 100644 --- a/test/onnx/test_pytorch_common.py +++ b/test/onnx/test_pytorch_common.py @@ -91,12 +91,21 @@ def wrapper(self): return skip_dec +# Enables tests for scripting, instead of only tracing the model. +def enableScriptTest(): + def script_dec(func): + def wrapper(self): + self.is_script_test_enabled = True + return func(self) + return wrapper + return script_dec + -# skips tests for scripting. -def skipScriptTest(min_opset_version=float("inf")): +# Disable tests for scripting. 
+def disableScriptTest(): def script_dec(func): def wrapper(self): - self.is_script_test_enabled = self.opset_version >= min_opset_version + self.is_script_test_enabled = False return func(self) return wrapper diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 79ae0a36f37b8..41228100affcd 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -2573,14 +2573,11 @@ def forward(self, lstm_in): bias=has_bias, num_layers=num_layers, ) - lstm_in = ( - [ - torch.from_numpy(inputs), - torch.from_numpy(hx), - torch.from_numpy(hx), - ] - + [param.detach() for param in torch_lstm._flat_weights], - ) + lstm_in = [ + torch.from_numpy(inputs), + torch.from_numpy(hx), + torch.from_numpy(hx), + ] + [param.detach() for param in torch_lstm._flat_weights] self.run_model_test( MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index e03566453f8f8..8c32d5fca781b 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -12,29 +12,21 @@ import model_defs.word_language_model as word_language_model import numpy as np import onnx -import onnxruntime -import torchvision -from model_defs.lstm_flattening_result import ( - LstmFlatteningResultWithoutSeqLength, - LstmFlatteningResultWithSeqLength, -) -from model_defs.rnn_model_with_packed_sequence import ( - RnnModelWithPackedSequence, - RnnModelWithPackedSequenceWithoutState, - RnnModelWithPackedSequenceWithState, -) -from test_pytorch_common import ( - BATCH_SIZE, - RNN_BATCH_SIZE, - RNN_HIDDEN_SIZE, - RNN_INPUT_SIZE, - RNN_SEQUENCE_LENGTH, - skipIfNoLapack, - skipIfUnsupportedMaxOpsetVersion, - skipIfUnsupportedMinOpsetVersion, - skipIfUnsupportedOpsetVersion, - skipScriptTest, -) + +import torch.nn.functional as F +from torch.nn.utils import rnn as rnn_utils +from model_defs.lstm_flattening_result import (LstmFlatteningResultWithSeqLength, + LstmFlatteningResultWithoutSeqLength) +from model_defs.rnn_model_with_packed_sequence import (RnnModelWithPackedSequence, + RnnModelWithPackedSequenceWithState, + RnnModelWithPackedSequenceWithoutState) +from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion, + skipIfNoLapack, disableScriptTest, skipIfUnsupportedMaxOpsetVersion) +from test_pytorch_common import BATCH_SIZE +from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE +from typing import List, Tuple, Optional, Dict, Union +from torch import Tensor + from torchvision import ops from torchvision.models.detection.faster_rcnn import ( FastRCNNPredictor, @@ -66,22 +58,22 @@ def flatten_tuples(elem): - flattened = [] + tup = [] for t in elem: - if isinstance(t, tuple): - flattened.extend(flatten_tuples(t)) + if isinstance(t, (tuple)): + tup += flatten_tuples(t) else: - flattened.append(t) - return flattened + tup += [t] + return tup def to_numpy(elem): - if isinstance(elem, Tensor): + if isinstance(elem, torch.Tensor): if elem.requires_grad: return elem.detach().cpu().numpy() else: return elem.cpu().numpy() - elif isinstance(elem, (list, tuple)): + elif isinstance(elem, list) or isinstance(elem, tuple): return [to_numpy(inp) for inp in elem] elif isinstance(elem, bool): return np.array(elem, dtype=bool) @@ -90,11 +82,12 @@ def to_numpy(elem): elif isinstance(elem, float): return np.array(elem, dtype=float) elif isinstance(elem, dict): - 
flattened = [] + dict_ = [] for k in elem: - flattened += [to_numpy(k)] + [to_numpy(elem[k])] - return flattened - return elem + dict_ += [to_numpy(k)] + [to_numpy(elem[k])] + return dict_ + else: + return RuntimeError("Input has unknown type.") def convert_to_onnx( @@ -144,30 +137,17 @@ def inline_flatten_list(inputs, res_list): return res_list -def unpack_to_numpy(values): +def unpack_to_numpy(value): value_unpacked = [] - for value in values: - value_unpacked.extend(unpack_quantized_tensor(value)) - return [to_numpy(v) for v in value_unpacked] - - -def run_ort(ort_sess, inputs): - kw_inputs = {} - if inputs and isinstance(inputs[-1], dict): - kw_inputs = inputs[-1] - inputs = inputs[:-1] - inputs = unpack_to_numpy(flatten_tuples(inputs)) - ort_inputs = {} - for input_name, input in kw_inputs.items(): - ort_inputs[input_name] = to_numpy(input) - inputs = to_numpy(inputs) - ort_sess_inputs = ort_sess.get_inputs() - for i, input in enumerate(inputs): - if i == len(ort_sess_inputs) or ort_sess_inputs[i].name in ort_inputs: - raise ValueError( - f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}" - ) - ort_inputs[ort_sess_inputs[i].name] = input + for value_ in value: + value_unpacked.extend(unpack_quantized_tensor(value_)) + value_final = [to_numpy(v) for v in value_unpacked] + return value_final + + +def run_ort(ort_sess, input): + input = unpack_to_numpy(flatten_tuples(input)) + ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(input)) ort_outs = ort_sess.run(None, ort_inputs) return inline_flatten_list(ort_outs, []) @@ -214,10 +194,12 @@ def run_model_test( if input is None: input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) with torch.no_grad(): - if isinstance(input, (Tensor, dict)): + if isinstance(input, torch.Tensor): input = (input,) # In-place operators will update input tensor data as well. # Thus inputs are replicated before every forward call. + if isinstance(input, dict): + input = (input,) input_args = copy.deepcopy(input) input_kwargs = {} if dict_check and isinstance(input_args[-1], dict): @@ -228,7 +210,7 @@ def run_model_test( output = model_copy(*input_args, **input_kwargs) except Exception: output = model(*input_args, **input_kwargs) - if isinstance(output, Tensor): + if isinstance(output, torch.Tensor): output = (output,) if not dict_check and isinstance(input[-1], dict): @@ -257,9 +239,7 @@ def run_model_test( input_copy = copy.deepcopy(input) if flatten: input_copy, _ = torch.jit._flatten(input_copy) - elif input_copy and input_copy[-1] == {}: - # Handle empty kwargs (normally removed by flatten). 
- input_copy = input_copy[:-1] + ort_outs = run_ort(ort_sess, input_copy) ort_compare_with_pytorch(ort_outs, output, rtol, atol) @@ -267,11 +247,11 @@ def run_model_test( # model with these inputs and check the outputs if test_with_inputs is not None: for test_input in test_with_inputs: - if isinstance(test_input, Tensor): + if isinstance(test_input, torch.Tensor): test_input = (test_input,) test_input_copy = copy.deepcopy(test_input) output = model(*test_input_copy) - if isinstance(output, Tensor): + if isinstance(output, torch.Tensor): output = (output,) if remained_onnx_input_idx is not None: test_input_onnx = [] @@ -461,22 +441,17 @@ def _run_test(m, remained_onnx_input_idx, flatten=True): ) if isinstance(remained_onnx_input_idx, dict): - scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"] - tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"] + scripting_remained_onnx_input_idx = remained_onnx_input_idx['scripting'] + tracing_remained_onnx_input_idx = remained_onnx_input_idx['tracing'] else: scripting_remained_onnx_input_idx = remained_onnx_input_idx tracing_remained_onnx_input_idx = remained_onnx_input_idx - is_script = isinstance( - model, (torch.jit.ScriptModule, torch.jit.ScriptFunction) - ) - - if self.is_script_test_enabled: - script_model = model if is_script else torch.jit.script(model) + if self.is_script_test_enabled and not isinstance(model, torch.jit.ScriptModule): + script_model = torch.jit.script(model) _run_test(script_model, scripting_remained_onnx_input_idx, flatten=False) - if not is_script: - _run_test(model, tracing_remained_onnx_input_idx) + _run_test(model, tracing_remained_onnx_input_idx) def run_model_test_with_external_data( self, @@ -499,13 +474,13 @@ def run_model_test_with_external_data( elif training is None or training == torch.onnx.TrainingMode.EVAL: model.eval() with torch.no_grad(): - if isinstance(input, Tensor): + if isinstance(input, torch.Tensor): input = (input,) # In-place operators will update input tensor data as well. # Thus inputs are replicated before every forward call. input_copy = copy.deepcopy(input) output = model(*input_copy) - if isinstance(output, Tensor): + if isinstance(output, torch.Tensor): output = (output,) # export the model to ONNX @@ -692,14 +667,9 @@ def forward(self, x): model = Fuse() x = torch.randn(2, 5, 9, requires_grad=True) - self.run_test( - torch.jit.script(model), - (x,), - input_names=["x"], - dynamic_axes={"x": [0, 2]}, - rtol=1e-3, - atol=1e-6, - ) + self.run_test(torch.jit.script(model), (x,), + input_names=['x'], dynamic_axes={'x': [0, 2]}, + rtol=1e-3, atol=1e-6) def test_conv_tbc(self): from torch.nn.modules.utils import _single @@ -713,9 +683,9 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0): self.padding = _single(padding) self.weight = torch.nn.Parameter( - Tensor(self.kernel_size[0], in_channels, out_channels) + torch.Tensor(self.kernel_size[0], in_channels, out_channels) ) - self.bias = torch.nn.Parameter(Tensor(out_channels)) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self): @@ -776,7 +746,7 @@ def run_word_language_model(self, model_name): # Only support CPU version, since tracer is not working in GPU RNN. 
self.run_test(model, (x, model.hidden)) - def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor: + def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: import os from PIL import Image @@ -788,14 +758,12 @@ def get_image(self, rel_path: str, size: Tuple[int, int]) -> Tensor: return transforms.ToTensor()(image) - def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]: - return ( - [self.get_image("grace_hopper_517x606.jpg", (100, 320))], - [self.get_image("rgb_pytorch.png", (250, 380))], - ) + def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + return ([self.get_image("grace_hopper_517x606.jpg", (100, 320))], + [self.get_image("rgb_pytorch.png", (250, 380))]) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() # Faster RCNN model is not scriptable + @disableScriptTest() # Faster RCNN model is not scriptable def test_faster_rcnn(self): model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn( pretrained=False, pretrained_backbone=True, min_size=200, max_size=300 @@ -870,7 +838,7 @@ def test_paste_mask_in_image(self): assert torch.all(out2.eq(out_trace2)) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_mask_rcnn(self): model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn( pretrained=False, pretrained_backbone=True, min_size=200, max_size=300 @@ -954,7 +922,7 @@ def test_heatmaps_to_keypoints(self): @unittest.skip("Failing, see https://github.com/pytorch/pytorch/issues/66528") @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_keypoint_rcnn(self): model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn( pretrained=False, pretrained_backbone=False, min_size=200, max_size=300 @@ -993,7 +961,7 @@ def test_keypoint_rcnn(self): ) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_shufflenet_v2_dynamic_axes(self): model = torchvision.models.shufflenet_v2_x0_5(pretrained=False) dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) @@ -1012,7 +980,7 @@ def test_shufflenet_v2_dynamic_axes(self): atol=1e-5, ) - @skipScriptTest() + @disableScriptTest() def test_mobilenet_v3(self): model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False) dummy_input = torch.randn(1, 3, 224, 224) @@ -1022,7 +990,7 @@ def test_mobilenet_v3(self): "Unstable loading pretrained quantized mobilenet v3: https://github.com/pytorch/vision/issues/5303" ) @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() + @disableScriptTest() def test_mobilenet_v3_quant(self): model = torchvision.models.quantization.mobilenet_v3_large( pretrained=True, quantize=True @@ -1064,15 +1032,15 @@ def forward(self, x): model = torch.jit.trace(TopPredictor(model), input_tensor) self.run_test(model, (input_tensor,)) - @skipScriptTest() + @disableScriptTest() def test_word_language_model_RNN_TANH(self): self.run_word_language_model("RNN_TANH") - @skipScriptTest() + @disableScriptTest() def test_word_language_model_RNN_RELU(self): self.run_word_language_model("RNN_RELU") - @skipScriptTest() # scripting prim::unchecked_cast prim::setattr + @disableScriptTest() # scripting prim::unchecked_cast prim::setattr def test_word_language_model_LSTM(self): self.run_word_language_model("LSTM") @@ -1147,7 +1115,7 @@ def forward(self, input): m1 = torch.randn(3, 4, 5, 6, 7) self.run_test(MyModel(), m1) - @skipScriptTest() + @disableScriptTest() def test_dict(self): class 
MyModel(torch.nn.Module): def forward(self, x_in): @@ -1160,7 +1128,7 @@ def forward(self, x_in): x = {torch.tensor(1.0): torch.randn(1, 2, 3)} self.run_test(MyModel(), (x, {})) - @skipScriptTest() + @disableScriptTest() def test_dict_str(self): class MyModel(torch.nn.Module): def forward(self, x_in): @@ -1171,12 +1139,12 @@ def forward(self, x_in): x = {"test_key_in": torch.randn(1, 2, 3)} self.run_test(MyModel(), (x, {})) - @skipScriptTest() # User-defined class not supported + @disableScriptTest() # User-defined class not supported def test_dict_output(self): class DictModelOutput(OrderedDict): - tensor_out: Tensor - tuple_out: Optional[Tuple[Tensor]] = None - list_out: Optional[List[Tensor]] = None + tensor_out: torch.Tensor + tuple_out: Optional[Tuple[torch.Tensor]] = None + list_out: Optional[List[torch.Tensor]] = None class MyModel(torch.nn.Module): def forward(self, a, b, c, d): @@ -1216,7 +1184,7 @@ def forward(self, a, b, c, d): def test_tuple_input(self): class TupleModel(torch.nn.Module): - def forward(self, a: Tuple[Tensor, Tensor]): + def forward(self, a: Tuple[torch.Tensor, torch.Tensor]): return a x = (torch.randn(3, 4), torch.randn(4, 3)) @@ -1224,7 +1192,7 @@ def forward(self, a: Tuple[Tensor, Tensor]): def test_tuple_primitive_input(self): class TupleModel(torch.nn.Module): - def forward(self, a: Tuple[int, Tensor], b): + def forward(self, a: Tuple[int, torch.Tensor], b): return a[0], a[1] + b x = (3, torch.randn(4, 3)) @@ -1233,27 +1201,30 @@ def forward(self, a: Tuple[int, Tensor], b): def test_nested_tuple_input(self): class NestedTupleModel(torch.nn.Module): - def forward(self, a, b: Tuple[Tensor, Tuple[Tensor, Tensor]]): + def forward(self, a, b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): return a + b[0] + b[1][0] + b[1][1] x = torch.randn(4, 5) y = (torch.randn(4, 5), (torch.randn(1, 5), torch.randn(4, 1))) self.run_test(NestedTupleModel(), input=(x, y)) - def test_empty_kwargs(self): - class IdentityModel(torch.nn.Module): + @disableScriptTest() + def test_optional_inputs_with_no_optionals(self): + class NoOptionalModel(torch.nn.Module): def forward(self, input): return input - self.run_test(IdentityModel(), (torch.randn(2, 3), {})) + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,)) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) - @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 - @skipIfUnsupportedMinOpsetVersion(15) - def test_mixed_optional_default_none(self): - class Model(torch.nn.Module): - def forward( - self, x, y: Optional[Tensor] = None, z: Optional[Tensor] = None - ): + @disableScriptTest() # ScriptModule could not be exported without the Input Descriptor for optional inputs + def test_optional_inputs_with_mixed_optionals(self): + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): if y is not None: return x + y if z is not None: @@ -1263,49 +1234,45 @@ def forward( x = torch.randn(2, 3) y = torch.randn(2, 3) z = torch.randn(2, 3) - model = Model() - # Without kwargs dict. - self.run_test(model, (x, y, None)) - self.run_test(model, (x, None, z)) - # With kwargs dict. - self.run_test(model, (x, {"y": y, "z": None})) - self.run_test(model, (x, {"y": None, "z": z})) - self.run_test(model, (x, {"z": z})) - self.run_test(model, (x, {"y": y})) - - @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. 
- @skipIfUnsupportedMinOpsetVersion(15) - def test_mixed_optional_default_tensor(self): - class Model(torch.nn.Module): - def forward( - self, - x, - y: Optional[Tensor] = torch.ones(2, 3), - z: Optional[Tensor] = torch.zeros(2, 3), - ): + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None)) + self.run_test(MixedModel(), (x, None, z)) + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {"y": y, "z": None})) + self.run_test(MixedModel(), (x, {"y": None, "z": z})) + self.run_test(MixedModel(), (x, {"z": z})) + self.run_test(MixedModel(), (x, {"y": y})) + + @disableScriptTest() # ScriptModule could not be exported without the Input Descriptor for optional inputs + def test_optional_inputs_with_all_optionals(self): + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): if y is not None: - return x + y + return y if z is not None: - return x + z - return x + return z - x = torch.randn(2, 3) y = torch.randn(2, 3) - z = torch.randn(2, 3) - model = Model() + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None)) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {"y": y, "z": None}) - self.run_test(model, (x, y, None)) - self.run_test(model, (x, None, z)) + @disableScriptTest() + def test_input_names_with_optional_args(self): + class NoOptionalModel(torch.nn.Module): + def forward(self, input): + return input - @skipIfUnsupportedMinOpsetVersion(15) - def test_mixed_optional_default_tensor_script(self): - class Model(torch.nn.Module): - def forward( - self, - x, - y: Optional[Tensor] = torch.ones(2, 3), - z: Optional[Tensor] = torch.zeros(2, 3), - ): + # Without empty optional arguments dictionary + x = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (x,), input_names=["input_x"]) + # With empty optional arguments dictionary + y = torch.randn(2, 3) + self.run_test(NoOptionalModel(), (y, {})) + + class MixedModel(torch.nn.Module): + def forward(self, x, y=None, z=None): if y is not None: return x + y if z is not None: @@ -1315,128 +1282,54 @@ def forward( x = torch.randn(2, 3) y = torch.randn(2, 3) z = torch.randn(2, 3) - model = torch.jit.script(Model()) - - self.run_test(model, (x, y, z), input_names=("x", "y", "z")) - self.run_test(model, (x, {"y": y, "z": z}), input_names=("x", "y", "z")) - - # Requires input_names to be set so that we can feed the inputs properly into ORT. - # TODO: Export default values as ONNX initializers, then this should not raise. - # https://msdata.visualstudio.com/Vienna/_workitems/edit/969268 - # Default values are accessible via FunctionSchema. - with self.assertRaisesRegex( - ValueError, "Model requires 3 inputs. Input Feed contains 2" - ): - self.run_test(model, (x, {"y": y}), input_names=("x", "y")) + # Without optional arguments dictionary + self.run_test(MixedModel(), (x, y, None), input_names=["input_x", "input_y"]) + self.run_test(MixedModel(), (x, None, z), input_names=["input_x", "input_z"]) - for example_inputs in ( - (x, y, None), - (x, None, z), - (x, {"y": y, "z": None}), - (x, {"y": None, "z": z}), - ): - with self.assertRaisesRegex( - ValueError, "args contained 1 None's after flattening." 
- ): - self.run_test(model, example_inputs, input_names=("x", "y", "z")) + # With optional arguments dictionary + self.run_test(MixedModel(), (x, {"y": y, "z": None}), input_names=["input_x", "input_y"]) + self.run_test(MixedModel(), (x, {"y": None, "z": z}), input_names=["input_x", "input_z"]) - @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 - @skipIfUnsupportedMinOpsetVersion(15) - def test_all_optional_default_none(self): - class Model(torch.nn.Module): - def forward(self, x: Optional[Tensor] = None, y: Optional[Tensor] = None): - if x is not None: - return x + class AllOptionalModel(torch.nn.Module): + def forward(self, y=None, z=None): if y is not None: return y - else: - return torch.tensor(-1.0) - - x = torch.randn(2, 3) - model = Model() - self.run_test(model, (x, None)) - self.run_test( - model, - ({"x": x, "y": None},), - # y disappears in tracing. - input_names=("x",), - ) - - @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. - @skipIfUnsupportedMinOpsetVersion(15) - def test_all_optional_default_tensor(self): - class Model(torch.nn.Module): - def forward( - self, - x: Optional[Tensor] = torch.ones(2, 3), - y: Optional[Tensor] = torch.zeros(2, 3), - ): - if x is not None: - return x - elif y is not None: - return y - else: - return torch.tensor(-1.0) + if z is not None: + return z - x = torch.randn(2, 3) y = torch.randn(2, 3) - model = Model() - self.run_test(model, (x, None)) - self.run_test(model, (None, y)) - # tracing means y is never used so it's removed from the exported model inputs, - # and we fail when trying to run ORT. - with self.assertRaisesRegex(ValueError, "got too many positional inputs"): - self.run_test(model, (x, y)) - - @skipIfUnsupportedMinOpsetVersion(15) - def test_all_optional_default_tensor_script(self): + z = torch.randn(2, 3) + # Without optional arguments dictionary + self.run_test(AllOptionalModel(), (y, None), input_names=["input_y"]) + self.run_test(AllOptionalModel(), (None, z), input_names=["input_z"]) + # With optional arguments dictionary + self.run_test(AllOptionalModel(), {"y": y, "z": None}, input_names=["input_y"]) + self.run_test(AllOptionalModel(), {"y": None, "z": z}, input_names=["input_z"]) + + def test_input_as_output(self): class Model(torch.nn.Module): - def forward( - self, - x: Optional[Tensor] = torch.ones(2, 3), - y: Optional[Tensor] = torch.zeros(2, 3), - ): - if x is not None: - return x - elif y is not None: - return y - else: - return torch.tensor(-1.0) + def forward(self, x, y): + return x, y x = torch.randn(2, 3) - y = torch.randn(2, 3) - model = torch.jit.script(Model()) - - # TODO: Export default values as ONNX initializers, then this should not raise. - # https://msdata.visualstudio.com/Vienna/_workitems/edit/969268 - # Default values are accessible via FunctionSchema. - with self.assertRaisesRegex( - ValueError, "Model requires 2 inputs. 
Input Feed contains 1" - ): - self.run_test(model, (x,)) - self.run_test(model, ({"y": y},)) - self.run_test(model, (x, y)) - self.run_test(model, ({"x": x, "y": y},), input_names=("x", "y")) + y = torch.randn(3, 4) + self.run_test(Model(), (x, y), input_names=["x", "y"], output_names=["x_out", "y_out"]) - @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 - @skipIfUnsupportedMinOpsetVersion(15) - def test_mixed_optional(self): + @disableScriptTest() + def test_none_as_input(self): class Model(torch.nn.Module): - def forward(self, x, y: Optional[Tensor]): + def forward(self, x, y): if y is not None: return x + y return x x = torch.randn(2, 3) - model = Model() - self.run_test(model, (x, None)) - self.run_test(model, (x, x)) + self.run_test(Model(), (x, None)) - @skipScriptTest() # Needs https://github.com/pytorch/rfcs/pull/21 - @skipIfUnsupportedMinOpsetVersion(15) - def test_tuple_of_optional(self): + @disableScriptTest() # ScriptModule could not be exported without the Input Descriptor for optional inputs + def test_none_as_tuple_input(self): class Model(torch.nn.Module): - def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]): + def forward(self, x, y): if y[0] is not None: return x + y[0] if y[1] is not None: @@ -1444,67 +1337,28 @@ def forward(self, x, y: Tuple[Optional[Tensor], Optional[Tensor]]): return x x = torch.randn(2, 3) - y1 = torch.randn(2, 3) - self.run_test(Model(), (x, (None, y1))) - - @skipScriptTest() # tracing eliminates None inputs so it works differently. See _script version below. - @skipIfUnsupportedMinOpsetVersion(15) - def test_tuple_of_optional_default_tensor(self): - class Model(torch.nn.Module): - def forward( - self, - x, - y: Tuple[Optional[Tensor], Optional[Tensor]] = ( - torch.zeros(2, 3), - torch.zeros(2, 3), - ), - ): - y0, y1 = y - if y0 is not None: - return x + y0 - if y1 is not None: - return x + y1 - return x - - x = torch.randn(2, 3) - y1 = torch.randn(2, 3) - self.run_test(Model(), (x, (None, y1))) + y = torch.randn(2, 3) + self.run_test(Model(), (x, (None, y))) - @skipIfUnsupportedMinOpsetVersion(15) - def test_tuple_of_optional_default_tensor_script(self): + @disableScriptTest() # ScriptModule could not be exported without the Input Descriptor for optional inputs + def test_none_as_named_input(self): class Model(torch.nn.Module): - def forward( - self, - x, - y: Tuple[Optional[Tensor], Optional[Tensor]] = ( - torch.zeros(2, 3), - torch.zeros(2, 3), - ), - ): - y0, y1 = y - if y0 is not None: - return x + y0 - if y1 is not None: - return x + y1 + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z return x x = torch.randn(2, 3) - y0 = torch.randn(2, 3) - y1 = torch.randn(2, 3) - model = torch.jit.script(Model()) - with self.assertRaisesRegex( - ValueError, "args contained 1 None's after flattening." - ): - self.run_test(model, (x, (None, y1))) - self.run_test(model, (x, (y0, y1))) - # export succeeds, but running ORT through run_test would fail because the exported model - # has the inputs flattened into 3 inputs. 
- torch.onnx.export( - model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version - ) + z = torch.randn(2, 3) + self.run_test(Model(), (x, None, z)) def test_primitive_input_integer(self): class Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x: int, y): return x + y @@ -2135,7 +1989,7 @@ def forward(self, x): @skipIfUnsupportedMinOpsetVersion(12) def test_prim_min(self): @torch.jit.script - def list_append(boxes: List[Tensor]): + def list_append(boxes: List[torch.Tensor]): temp = [] for i, b in enumerate( boxes @@ -2240,15 +2094,12 @@ def forward(self, x: int, y: int): y = 2 self.run_test(ArithmeticModule(), (x, y)) - # In tracing, None outputs are removed. In scripting they're kept but - # we don't know Optional.elem_type, so we can't construct a valid Optional. - # Tests for Optional outputs (control flow with None in one branch, - # not-None in another) are in test_pytorch_onnx_no_runtime.py. - @skipScriptTest() + @disableScriptTest() def test_tuple_with_none_outputs(self): class TupleModel(torch.nn.Module): def forward(self, x): - return (x, (x, None, (x, None))) + l = (x, None, (x, None)) + return (x, l) x = torch.randn(3, 4) self.run_test(TupleModel(), (x,)) @@ -2464,7 +2315,7 @@ def forward(self, x, y): self.run_test(InputIndexSlice(), (x, y)) @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() # scripting tuple/list append + @disableScriptTest() # scripting tuple/list append def test_slice_dynamic(self): class DynamicSliceExportMod(torch.nn.Module): def forward(self, x): @@ -2507,7 +2358,7 @@ def forward(self, x): self.run_test(DynamicSliceModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() # scripting tuple/list append + @disableScriptTest() # scripting tuple/list append def test_slice_dynamic_to_end(self): class DynamicSliceExportMod(torch.nn.Module): def forward(self, x): @@ -2759,7 +2610,7 @@ def forward(self, input): self.run_test(SizeModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() # x.stride() not scriptable + @disableScriptTest() # x.stride() not scriptable def test_as_strided(self): class Model(torch.nn.Module): def forward(self, x): @@ -2774,7 +2625,7 @@ def forward(self, x): x = torch.randn(5, 8, 7) self.run_test(Model(), x) - @skipScriptTest() # Ellipses followed by tensor indexing not scriptable + @disableScriptTest() # Ellipses followed by tensor indexing not scriptable def test_tensor_index_advanced_indexing_ellipsis(self): class MyModel(torch.nn.Module): def forward(self, input): @@ -2966,7 +2817,7 @@ def forward(self, x, ind, update): self.run_test(IndexPutModel10(), (x, ind, update)) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() # Ellipses followed by tensor indexing not scriptable + @disableScriptTest() # Ellipses followed by tensor indexing not scriptable def test_index_put_ellipsis(self): class IndexPutModel(torch.nn.Module): def forward(self, x, update): @@ -3098,7 +2949,7 @@ def forward(self, x, mask): self.run_test(CopyModel5(), (x, mask)) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape) + @disableScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape) def test_copy_tracing(self): class CopyModel(torch.nn.Module): def forward(self, x, data): @@ -3364,7 +3215,7 @@ def test_interpolate_upsample(self): self._interpolate_tests(True) @skipIfUnsupportedMaxOpsetVersion(8) - 
@skipScriptTest() # Scripting supported for opsets > 8. See test_interpolate_upsample + @disableScriptTest() # Scripting supported for opsets > 8. See test_interpolate_upsample def test_interpolate_upsample_trace(self): self._interpolate_tests(True) @@ -3427,7 +3278,7 @@ def forward(self, x, y): ) self.run_test(MyModel(), (x, y), remained_onnx_input_idx=[0]) - @skipScriptTest() # scripting raises OnnxRuntimeError + @disableScriptTest() # scripting throws the ONNXRuntimeError def test_interpolate_adaptive_pooling_error(self): x = torch.randn(1, 2, 6, requires_grad=True) with self.assertRaises(RuntimeError) as cm: @@ -4103,7 +3954,7 @@ def forward(self, x, k): k = torch.tensor(3) self.run_test(MyModuleDynamic(), [x, k]) - @skipScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript. + @disableScriptTest() # Python builtin apply of FunctionMeta object is currently not supported in Torchscript. @skipIfUnsupportedMinOpsetVersion(11) # Clip op min is an input since opset 11. def test_auto_grad(self): class MyClip(torch.autograd.Function): @@ -4367,7 +4218,7 @@ def forward(self, input, indices, values): self.run_test(ScatterModel(), input=(input, indices, values)) @torch.jit.script - def scatter_sum(src: Tensor, index: Tensor): + def scatter_sum(src: torch.Tensor, index: torch.Tensor): size = src.size() out = torch.zeros(size, dtype=src.dtype) return out.scatter_add_(1, index, src) @@ -4424,7 +4275,7 @@ def forward(self, input, indices): indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) self.run_test(GatherModel(), input=(input, indices)) - @skipScriptTest() # Scripting error: Cannot instantiate nn module + @disableScriptTest() # Scripting error: Cannot instantiate nn module def test_gather_constant_fold(self): class GatherModule(torch.nn.Module): def __init__(self): @@ -4470,16 +4321,10 @@ def forward(self, x): return x x = torch.randn(1, 3, 224, 224) - self.run_test( - GatherModule(), - (x,), - dynamic_axes={ - "input": {0: "batch", 2: "height", 3: "width"}, - "output": {0: "batch", 1: "class", 2: "height", 3: "width"}, - }, - input_names=["input"], - output_names=["output"], - ) + self.run_test(GatherModule(), (x,), + dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"}, + "output": {0: "batch", 1: "class", 2: "height", 3: "width"}}, + input_names=['input'], output_names=['output']) @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @@ -4848,7 +4693,7 @@ def __init__(self, input_size, hidden_size, num_layers, bidirectional): input_size, hidden_size, num_layers, bidirectional=bidirectional ) - def forward(self, input, initial_state: Tuple[Tensor, Tensor]): + def forward(self, input, initial_state: Tuple[torch.Tensor, torch.Tensor]): return self.lstm(input, initial_state) def get_LstmNet_model_and_inputs( @@ -4882,7 +4727,7 @@ def __init__(self, num_layers, bidirectional): bidirectional=bidirectional, ) - def forward(self, input, initial_state: Tuple[Tensor, Tensor]): + def forward(self, input, initial_state: Tuple[torch.Tensor, torch.Tensor]): return self.lstm(input, initial_state) def get_LstmNet_model_and_inputs(num_layers, bidirectional): @@ -4931,7 +4776,7 @@ def forward(self, input): }, ) - @skipScriptTest() + @disableScriptTest() def test_rnn_no_bias(self): def make_model(layers, packed_sequence): batch_first = True if packed_sequence == 2 else False @@ -5682,7 +5527,7 @@ def forward(self, input, weight, bias): z = torch.randn(1) self.run_test(LinearModel(), (x, y, z)) - @skipScriptTest() + 
@disableScriptTest() def test_weight_norm(self): # addmm for 3-d inputs converts to onnx::MatMul model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=1) @@ -5706,7 +5551,7 @@ def test_weight_norm(self): x = torch.randn(3, 3, 5, requires_grad=True) self.run_test(model, x) - @skipScriptTest() + @disableScriptTest() def test_weight_norm_nodim(self): # addmm for 3-d inputs converts to onnx::MatMul model = torch.nn.utils.weight_norm(torch.nn.Linear(5, 10), dim=None) @@ -5796,7 +5641,7 @@ def forward(self, x, y, i: int): i = 3 self.run_test(torch.jit.script(M()), (x, y, i)) - @skipScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable. + @disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable. @skipIfUnsupportedMinOpsetVersion(9) def test_nonzero(self): class NonzeroModel(torch.nn.Module): @@ -5875,7 +5720,7 @@ def forward(self, input): x = torch.randn(3, 4, 5) self.run_test(UnbindModel2(), x) - @skipScriptTest() # scripting tests run for opsets > 11. See: test_split_script + @disableScriptTest() # scripting tests run for opsets > 11. See: test_split_script def test_split(self): class SplitModel(torch.nn.Module): def forward(self, input): @@ -5922,12 +5767,12 @@ def forward(self, input): self.run_test(SplitModel3(), x) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_split_size_as_list(self): class SplitModel(torch.nn.Module): def forward(self, input, split_sizes: List[int]): out = [] - split_list: List[Tensor] = input.split(split_sizes) + split_list: List[torch.Tensor] = input.split(split_sizes) for ob in split_list: out.append(ob) @@ -5982,12 +5827,8 @@ def forward(self, x): x = torch.randn(4, 384, 2) input_names = ["logits"] - self.run_test( - Split(), - x, - input_names=input_names, - dynamic_axes={input_names[0]: {0: "batch"}}, - ) + self.run_test(Split(), x, input_names=input_names, + dynamic_axes={input_names[0]: {0: 'batch'}}) @skipIfUnsupportedMinOpsetVersion(11) def test_chunk(self): @@ -6541,7 +6382,7 @@ def forward(self, x): self.run_test(OnesModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable. + @disableScriptTest() # torch.zeros/torch.ones with size tensor of dim != 0 not scriptable. 
def test_zeros_ones_with_tensor_input(self): class ZeroAndOnes(torch.nn.Module): def forward(self, x): @@ -6838,7 +6679,7 @@ def forward(self, x): @skipIfUnsupportedMinOpsetVersion(14) # Need onnx::Identity of sequence in opset 14 def test_inplace_sequence_with_loop(self): class M(torch.nn.Module): - def process(self, beam_hyps: List[Tensor], done: Tensor, x): + def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x): batch_size = x.shape[0] for i in range(batch_size): if done[i]: @@ -6857,7 +6698,7 @@ def process(self, beam_hyps: List[Tensor], done: Tensor, x): return beam_hyps, done def forward(self, x): - beam_hyps: List[Tensor] = [] + beam_hyps: List[torch.Tensor] = [] batch_size = x.shape[0] cur_len = 0 max_len = x.shape[1] @@ -6872,7 +6713,8 @@ def forward(self, x): x = torch.randn(8, 4, 3) self.run_test(torch.jit.script(M()), (x)) - @skipScriptTest() # Sort with dynamic dim not supported in ONNX + + @disableScriptTest() # Sort with dynamic dim not supported in ONNX def test_sort(self): class SortModel(torch.nn.Module): def forward(self, x): @@ -6885,7 +6727,7 @@ def forward(self, x): self.run_test(SortModel(), x) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() # Sort with dynamic dim not supported in ONNX + @disableScriptTest() # Sort with dynamic dim not supported in ONNX def test_sort_ascending(self): class SortModel(torch.nn.Module): def forward(self, x): @@ -7067,10 +6909,9 @@ def forward(self, x): class CatModel(torch.nn.Module): def forward(self, fp16, fp32): return torch.cat([fp16, fp32]) - - fp16 = Tensor([0.5]) + fp16 = torch.Tensor([0.5]) fp16 = fp16.half() - fp32 = Tensor([1.5]) + fp32 = torch.Tensor([1.5]) self.run_test(CatModel(), (fp16, fp32)) @skipIfUnsupportedMinOpsetVersion(9) @@ -7198,7 +7039,7 @@ def forward(self, input, other): y = torch.randint(10, (5,)) self.run_test(MatmulModel(), (x, y)) - @skipScriptTest() # SpectralNorm not TorchScript compatible. + @disableScriptTest() # SpectralNorm not TorchScript compatible. def test_spectral_norm(self): m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4)) @@ -7239,13 +7080,9 @@ def forward(self, x): x = torch.randn(2, 3, 4) * 100.0 y = torch.randn(2, 4, 5) * 100.0 - self.run_test( - Relu6Model(), - x, - input_names=["x"], - dynamic_axes={"x": [1, 2]}, - test_with_inputs=[y], - ) + self.run_test(Relu6Model(), x, input_names=['x'], + dynamic_axes={'x': [1, 2]}, + test_with_inputs=[y]) def test_silu(self): class SiLUModel(torch.nn.Module): @@ -7479,7 +7316,7 @@ def forward(self, input): x = torch.tensor([False, True, True]) self.run_test(model, x) - @skipScriptTest() # error in propagate as assign input shape + @disableScriptTest() # error in propagate as assign input shape @skipIfUnsupportedMinOpsetVersion(10) def test_embedding_bag(self): model = torch.nn.EmbeddingBag(10, 5, mode="sum", scale_grad_by_freq=True) @@ -7540,7 +7377,7 @@ def forward(self, embedding_matrix, input, weights): test_with_inputs=[(embedding_matrix, x2, w2)], ) - @skipScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast + @disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast @skipIfUnsupportedMinOpsetVersion(11) @unittest.skip( "Due to ONNX Loop shape inference issue. 
" @@ -7837,7 +7674,7 @@ def forward(self, x, pad: List[int]): self.run_test(Pad(), (x, y)) @skipIfUnsupportedMaxOpsetVersion(10) - @skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script + @disableScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script def test_unsupported_pad(self): class Pad(torch.nn.Module): def forward(self, x, pad: List[int]): @@ -8337,7 +8174,7 @@ def forward(self, poses): return batch_boxes dummy_inputs = torch.rand(2, 2, 3) - self.run_test(M(), (dummy_inputs,), input_names=["x"], dynamic_axes={"x": [0]}) + self.run_test(M(), (dummy_inputs, ), input_names=['x'], dynamic_axes={"x": [0]}) @skipIfUnsupportedMinOpsetVersion(12) def test_outer(self): @@ -8742,17 +8579,11 @@ def test_nllloss_dynamic_ignore_index(self): def linear_combination(x, y, epsilon): return epsilon * x + (1 - epsilon) * y - def reduce_loss(loss, reduction="mean"): - return ( - loss.mean() - if reduction == "mean" - else loss.sum() - if reduction == "sum" - else loss - ) + def reduce_loss(loss, reduction='mean'): + return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss class LabelSmoothingCrossEntropy(torch.nn.Module): - def __init__(self, epsilon: float = 0.1, reduction="mean"): + def __init__(self, epsilon: float = 0.1, reduction='mean'): super().__init__() self.epsilon = epsilon self.reduction = reduction @@ -9251,7 +9082,7 @@ def forward(self, cond, input, other): self.run_test(Model(), (x, y, z)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() # scripting tests run for opsets > 11. See: test_where_condition_script + @disableScriptTest() # scripting tests run for opsets > 11. See: test_where_condition_script def test_where_condition(self): class Model1(torch.nn.Module): def forward(self, input): @@ -9306,7 +9137,7 @@ def forward(self, input): @skipIfUnsupportedMinOpsetVersion(11) def test_derive_index_scripting(self): class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(len(x) - 1, -len(x), -2): y = x[idx] @@ -9317,7 +9148,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(-len(x), len(x) - 1, 2): y = x[idx] @@ -9328,7 +9159,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(len(x) - 1, -len(x), -3): y = x[idx] @@ -9338,7 +9169,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(-len(x), len(x) - 1, 3): y = x[idx] @@ -9347,10 +9178,10 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) - @skipScriptTest() # Scripting fails for add lists for opsets < 11. Chek test_derive_index_scripting + @disableScriptTest() # Scripting fails for add lists for opsets < 11. 
Chek test_derive_index_scripting def test_derive_index(self): class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(len(x) - 1, -len(x), -2): y = x[idx] @@ -9361,7 +9192,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(-len(x), len(x) - 1, 2): y = x[idx] @@ -9372,7 +9203,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(len(x) - 1, -len(x), -3): y = x[idx] @@ -9382,7 +9213,7 @@ def forward(self, x: Tensor): self.run_test(MyModule(), x) class MyModule(torch.nn.Module): - def forward(self, x: Tensor): + def forward(self, x: torch.Tensor): j = [] for idx in range(-len(x), len(x) - 1, 3): y = x[idx] @@ -9465,7 +9296,16 @@ def check_proto(): self.assertRaises(RuntimeError, check_proto) - @skipScriptTest(min_opset_version=11) # dynamic split support addded in 11 + @skipIfUnsupportedMinOpsetVersion(11) + def test_split_tensor_scalar_scripting(self): + class SplitModel(torch.nn.Module): + def forward(self, x): + return torch.split(x, x.size(1)) + + x = torch.randn(1, 2, 3, requires_grad=True) + self.run_test(SplitModel(), x) + + @disableScriptTest() # Scripting fails to export dynamic split for opsets < 11 def test_split_tensor_scalar(self): class SplitModel(torch.nn.Module): def forward(self, x): @@ -9846,7 +9686,7 @@ def make_input(batch_size): other_input = make_input(RNN_BATCH_SIZE + 1) self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1) - @skipScriptTest() # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1253950 + @disableScriptTest() # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported def test_transformer_encoder(self): from torch.nn import TransformerEncoder, TransformerEncoderLayer @@ -9911,7 +9751,7 @@ def forward(self, input): self.run_test(FakeQuantizePerChannelModel(), (x)) @skipIfUnsupportedMinOpsetVersion(13) - @skipScriptTest() # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear + @disableScriptTest() # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear def test_fake_quantize_activation(self): from torch import quantization @@ -10173,24 +10013,16 @@ def forward(self, x): x = torch.randn(10) model.train() - ort_sess = convert_to_onnx( - model, - input=(x,), - opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING, - ) - ort_outs = run_ort(ort_sess, (x,)) + ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) script_model = torch.jit.script(model) output = model(x) - ort_sess = convert_to_onnx( - script_model, - input=(x,), - opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING, - ) - ort_outs = run_ort(ort_sess, (x,)) + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) @skipIfUnsupportedMinOpsetVersion(12) @@ -10214,13 +10046,9 @@ def forward(self, x): nb_elements = torch.numel(input) 
model.train() - ort_sess = convert_to_onnx( - model, - input=(x,), - opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING, - ) - ort_outs = run_ort(ort_sess, (x,)) + ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) y = model(input) output = y.cpu().numpy() @@ -10235,13 +10063,9 @@ def forward(self, x): script_model = torch.jit.script(model) y = model(input) output = y.cpu().numpy() - ort_sess = convert_to_onnx( - script_model, - input=(x,), - opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING, - ) - ort_outs = run_ort(ort_sess, (x,)) + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) ort_mask = np.where(ort_outs[0] != 0, 1, 0) pyt_mask = np.where(output != 0, 1, 0) @@ -10339,7 +10163,7 @@ def __init__(self): super(MyModule, self).__init__() self.box_coder = BoxCoder(1.4) - def forward(self, box_regression: Tensor, proposals: List[Tensor]): + def forward(self, box_regression: torch.Tensor, proposals: List[torch.Tensor]): return self.box_coder.decode(box_regression, proposals) model = torch.jit.script(MyModule()) @@ -10473,7 +10297,6 @@ def forward(self, boxes, scores, idxs): self.run_test(Module(), (boxes, scores, idxs)) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() def test_clip_boxes_to_image(self): boxes = torch.randn(5, 4) * 500 boxes[:, 2:] += boxes[:, :2] @@ -10555,14 +10378,14 @@ def forward(self, images): ) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_transform_images(self): class TransformModule(torch.nn.Module): def __init__(self): super(TransformModule, self).__init__() self.transform = _init_test_generalized_rcnn_transform() - def forward(self, images: List[Tensor]): + def forward(self, images: List[torch.Tensor]): return self.transform(images)[0].tensors input = torch.rand(3, 100, 200), torch.rand(3, 200, 200) @@ -10584,7 +10407,7 @@ def get_features(self, images): return features @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_rpn(self): set_rng_seed(0) @@ -10593,10 +10416,8 @@ def __init__(self): super(RPNModule, self).__init__() self.rpn = _init_test_rpn() - def forward(self, images, features: Dict[str, Tensor]): - images_m = ImageList( - images, [(i.shape[-1], i.shape[-2]) for i in images] - ) + def forward(self, images, features: Dict[str, torch.Tensor]): + images_m = ImageList(images, [(i.shape[-1], i.shape[-2]) for i in images]) return self.rpn(images_m, features) images = torch.rand(2, 3, 150, 150) @@ -10625,7 +10446,7 @@ def forward(self, images, features: Dict[str, Tensor]): @skipIfUnsupportedMaxOpsetVersion(15) # TODO: Opset 16 RoiAlign result mismatch @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_multi_scale_roi_align(self): class TransformModule(torch.nn.Module): def __init__(self): @@ -10667,7 +10488,7 @@ def forward(self, input: Dict[str, Tensor], boxes: List[Tensor]) -> Tensor: ) @skipIfUnsupportedMinOpsetVersion(11) - @skipScriptTest() + @disableScriptTest() def test_roi_heads(self): class RoiHeadsModule(torch.nn.Module): def __init__(self): @@ -10676,10 +10497,8 @@ def __init__(self): self.rpn = _init_test_rpn() self.roi_heads = _init_test_roi_heads_faster_rcnn() - def forward(self, images, features: Dict[str, Tensor]): - 
original_image_sizes = [ - (img.shape[-1], img.shape[-2]) for img in images - ] + def forward(self, images, features: Dict[str, torch.Tensor]): + original_image_sizes = [(img.shape[-1], img.shape[-2]) for img in images] images_m = ImageList( images, [(i.shape[-1], i.shape[-2]) for i in images] @@ -10729,14 +10548,9 @@ def forward(self, x, y): self.run_test(M(), (x, y), remained_onnx_input_idx=[1]) y2 = torch.randn(5, 2) - self.run_test( - M(), - (x, y), - remained_onnx_input_idx=[1], - input_names=["x", "y"], - dynamic_axes={"x": [0, 1], "y": [0, 1]}, - test_with_inputs=[(y, y2)], - ) + self.run_test(M(), (x, y), remained_onnx_input_idx=[1], input_names=['x', 'y'], + dynamic_axes={'x': [0, 1], 'y': [0, 1]}, + test_with_inputs=[(y, y2)]) @skipIfUnsupportedMinOpsetVersion(9) def test_set_attr_modules(self): @@ -10755,7 +10569,7 @@ def get_embedding(embedding_dim: int): ) return emb - def forward(self, input, incremental_state: Optional[Tensor] = None): + def forward(self, input, incremental_state: Optional[torch.Tensor] = None): bsz, seq_len = input.shape[0], input.shape[1] self.const = 3 if self.weights is None: @@ -10816,7 +10630,7 @@ def get_embedding(embedding_dim: int): ) return emb - def forward(self, input, incremental_state: Optional[Tensor] = None): + def forward(self, input, incremental_state: Optional[torch.Tensor] = None): bsz, seq_len = input.shape[0], input.shape[1] self.const = 1.5 self.weights = InnerModule.get_embedding(self.embedding_dim) @@ -10877,7 +10691,7 @@ def set_cell_anchors(self, anchors): self.conv.weight = torch.randn(3, 10) self.conv.bias = self.conv.weight[:] - def forward(self, anchors) -> Optional[Tensor]: + def forward(self, anchors) -> Optional[torch.Tensor]: self.set_cell_anchors(anchors) return self.conv.bias @@ -10901,7 +10715,7 @@ def set_cell_anchors(self, anchors, boxes): self.conv.weight = anchors + self.conv.weight boxes[:] = torch.zeros(2, 3) - def forward(self, anchors) -> Tuple[Tensor, Tensor]: + def forward(self, anchors) -> Tuple[torch.Tensor, torch.Tensor]: boxes = torch.ones(2, 2, 3) self.set_cell_anchors(anchors, boxes) if self.conv.bias is not None: @@ -10929,7 +10743,7 @@ def set_cell_anchors(self, anchors): else: self.conv.bias = torch.ones(3, 10, 3) - def forward(self, feature_maps, anchors) -> Tuple[Tensor, Tensor]: + def forward(self, feature_maps, anchors) -> Tuple[torch.Tensor, torch.Tensor]: self.set_cell_anchors(anchors) result = [] if self.conv.bias is not None: @@ -10992,7 +10806,7 @@ def set_cell_anchors(self, anchors, boxes): self.conv.weight = anchors * i boxes[j] += torch.ones(3, 3) - def forward(self, anchors) -> Tuple[Tensor, Tensor]: + def forward(self, anchors) -> Tuple[torch.Tensor, torch.Tensor]: boxes = torch.ones(10, 3, 3) self.set_cell_anchors(anchors, boxes) if self.conv.bias is not None: @@ -11011,9 +10825,7 @@ def __init__(self): self.conv = torch.nn.Conv1d(10, 3, 3) self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10)) self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3)) - self.boxes: List[Tensor] = [ - torch.ones(1) - ] # Workaround placeholder for TorchScript + self.boxes : List[torch.Tensor] = [torch.ones(1)] # Workaround placeholder for TorchScript def set_cell_anchors(self, anchors): self.conv.weight = torch.randn(3, 10) @@ -11485,7 +11297,7 @@ def forward(self, model_input_1, model_input_2, y): ) self.run_test(m, (x1, x2, y)) - @skipScriptTest() + @disableScriptTest() def test_unsafe_chunk(self): class ChunkModel(torch.nn.Module): def forward(self, x): @@ -11624,9 +11436,7 @@ def forward(self, 
start): return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.int64) x = torch.randn(2, 3, 4) - self.run_test( - ArangeModel(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]} - ) + self.run_test(ArangeModel(), (x,), input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) self.run_test(ArangeModel(), (x,), remained_onnx_input_idx=[]) class ArangeModel2(torch.nn.Module): @@ -11634,9 +11444,7 @@ def forward(self, start): return torch.arange(start.size(0), 8.5, 1.5, dtype=torch.double) x = torch.randn(2, 3, 4) - self.run_test( - ArangeModel2(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]} - ) + self.run_test(ArangeModel2(), (x,), input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) self.run_test(ArangeModel2(), (x,), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) @@ -11652,12 +11460,10 @@ def forward(self, x): return torch.nonzero(ones) x = torch.randn(2) - self.run_test(OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]}) + self.run_test(OneLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0]}) self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) x = torch.randn(2, 3, 4) - self.run_test( - OneLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} - ) + self.run_test(OneLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) self.run_test(OneLikeModel(), x, remained_onnx_input_idx=[]) class ZeroLikeModel(torch.nn.Module): @@ -11671,12 +11477,10 @@ def forward(self, x): return torch.nonzero(zeros) x = torch.randn(2) - self.run_test(ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0]}) + self.run_test(ZeroLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0]}) self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) x = torch.randn(2, 3, 4) - self.run_test( - ZeroLikeModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2]} - ) + self.run_test(ZeroLikeModel(), x, input_names=['x'], dynamic_axes={"x": [0, 1, 2]}) self.run_test(ZeroLikeModel(), x, remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(9) @@ -11689,7 +11493,7 @@ def forward(self, x): self.run_test(ExpandModel(), (x,)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() # Test code not scriptable + @disableScriptTest() # Test code not scriptable def test_symbolic_shape_inference_expand_2(self): class M(torch.nn.Module): def forward(self, x): @@ -11707,7 +11511,7 @@ def forward(self, x): self.run_test(M(), (x,), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() # Test code not scriptable + @disableScriptTest() # Test code not scriptable def test_symbolic_shape_inference_slice(self): class M(torch.nn.Module): def forward(self, x, position_bias): @@ -11736,7 +11540,7 @@ def forward(self, position_bias): self.run_test(M(), (position_bias,)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() + @disableScriptTest() def test_symbolic_shape_inference_time(self): input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) @@ -11823,7 +11627,7 @@ def forward(self, x, window_length: int): self.run_test(module, (x, win_length)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() + @disableScriptTest() def test_hann_window_default_values(self): class HannWindowModule(torch.nn.Module): def __init__(self): @@ -11844,7 +11648,7 @@ def forward(self, x, window_length: int): self.run_test(module, (x, win_length)) @skipIfUnsupportedMinOpsetVersion(12) - @skipScriptTest() + @disableScriptTest() def test_tensordot_dim_count(self): class 
M(torch.nn.Module): def forward(self, x, y): @@ -11869,7 +11673,7 @@ def forward(self, x, y): self.run_test(M(), (x, y)) @skipIfUnsupportedMinOpsetVersion(12) - @skipScriptTest() + @disableScriptTest() def test_tensordot_dynamic_dim(self): class M(torch.nn.Module): def forward(self, x, y): @@ -11907,7 +11711,7 @@ def forward(self, x, y): self.run_test(M_ToDeviceDtype(), (x, y)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() + @disableScriptTest() def test_fill(self): class FillModule(torch.nn.Module): def forward(self, x, filled_value: int): @@ -12045,13 +11849,9 @@ def forward(self, x): ) index = torch.tensor([0, 2, 3, 1]) - self.run_test( - M(1, index, updates), - (x,), - test_with_inputs=[y], - input_names=["input_1"], - dynamic_axes={"input_1": [0, 1]}, - ) + self.run_test(M(1, index, updates), (x,), test_with_inputs=[y], + input_names=['input_1'], + dynamic_axes={'input_1': [0, 1]}) def test_roll(self): class M(torch.nn.Module): @@ -12075,7 +11875,7 @@ def forward(self, x): return torch.sum(x) x = torch.ones(12, 3) - self.run_test(M(), (x,), input_names=["x"], dynamic_axes={"x": [0]}) + self.run_test(M(), (x,), input_names=['x'], dynamic_axes={'x': [0]}) def test_sum_empty_tensor(self): class M(torch.nn.Module): @@ -12113,7 +11913,7 @@ def forward(self, x, y): self.run_test(M(), (x, y)) - @skipScriptTest() + @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_normal(self): class M(torch.nn.Module): @@ -12131,7 +11931,7 @@ def forward(self, x, y): ), ) - @skipScriptTest() + @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_normal_correctness(self): class M(torch.nn.Module): @@ -12150,7 +11950,7 @@ def forward(self, x, y): training=torch.onnx.TrainingMode.EVAL, ) - ort_out = run_ort(ort_sess, inputs=dummy_input) + ort_out = run_ort(ort_sess, input=dummy_input) actual_std = np.std(ort_out) actual_mean = np.mean(ort_out) @@ -12162,7 +11962,7 @@ def forward(self, x, y): abs(abs(actual_std) - expected_std) <= expected_std * 0.1 ), "the gap of variance between ort outputs and expected one is unacceptable." - @skipScriptTest() + @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_uniform(self): class M(torch.nn.Module): @@ -12175,7 +11975,7 @@ def forward(self, x, y): M(), (torch.tensor([1.0]), torch.tensor([[10.0], [7.0], [9.0], [20.0]])) ) - @skipScriptTest() + @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_uniform_correctness(self): class M(torch.nn.Module): @@ -12195,7 +11995,7 @@ def forward(self, x, y): training=torch.onnx.TrainingMode.EVAL, ) - ort_out = run_ort(ort_sess, inputs=dummy_input) + ort_out = run_ort(ort_sess, input=dummy_input) actual_min = np.min(ort_out) actual_max = np.max(ort_out) actual_mean = np.mean(ort_out) @@ -12391,7 +12191,7 @@ def forward(self, input): self.run_test(FlattenModel(), x) @skipIfUnsupportedMinOpsetVersion(10) - @skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: + @disableScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function: def test_quantized_arithmetic_qfunctional(self): x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8) @@ -12704,7 +12504,7 @@ def make_test( # - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382 # Operator aten::_pack_padded_sequence is not supported by exporter yet. 
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384 - @skipScriptTest() + @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(9) def f(self): self.is_script_test_enabled = ( @@ -12758,17 +12558,15 @@ def setup_rnn_tests(): # Need Add between list of tensors script_test_min_opset_version = 11 - if ( # compiling in script mode fails with errors like: - # torch.jit.frontend.UnsupportedNodeError: annotated assignments - # without assigned value aren't supported - # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 - base == "elman" - or - # compiling in script mode fails with errors like: - # RuntimeError: Arguments for call are not valid. - # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 - base == "lstm" - ): + if ( # compiling in script mode fails with errors like: + # torch.jit.frontend.UnsupportedNodeError: annotated assignments + # without assigned value aren't supported + # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 + base == 'elman' or + # compiling in script mode fails with errors like: + # RuntimeError: Arguments for call are not valid. + # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723 + base == 'lstm'): script_test_min_opset_version = float("inf") make_test( name, diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py index c6279e11945cd..a3dc283541264 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py @@ -14,6 +14,10 @@ import torch from torch.cuda.amp import autocast +from test_pytorch_common import disableScriptTest, skipIfUnsupportedMinOpsetVersion +from test_pytorch_common import skipIfNoCuda, skipIfNoBFloat16Cuda + +from test_pytorch_onnx_onnxruntime import TestONNXRuntime class TestONNXRuntime_cuda(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version @@ -42,7 +46,7 @@ def forward(self, x): @skipIfUnsupportedMinOpsetVersion(9) @skipIfNoCuda - @skipScriptTest() + @disableScriptTest() def test_layer_norm_fp16(self): class LayerNormModel(torch.nn.Module): def __init__(self): @@ -66,7 +70,7 @@ def forward(self, x): @skipIfUnsupportedMinOpsetVersion(12) @skipIfNoCuda - @skipScriptTest() + @disableScriptTest() def test_softmaxCrossEntropy_fusion_fp16(self): class FusionModel(torch.nn.Module): def __init__(self): @@ -90,7 +94,7 @@ def forward(self, input, target): self.run_test(FusionModel(), (input, target)) @skipIfNoCuda - @skipScriptTest() + @disableScriptTest() def test_apex_o2(self): class LinearModel(torch.nn.Module): def __init__(self): diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 1c39140645f25..f1b0bc1bb5b3f 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1447,7 +1447,7 @@ def forward(self, x): module = RenamedIntermediateModule() - g, p, o = utils._model_to_graph(module, torch.ones(1, 10), output_names=["y"]) + g, p, o = utils._model_to_graph(module, torch.ones(1, 10), output_names=['y']) renamed_intermediate = 0 for n in g.nodes(): for v in n.inputs(): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 80828056585e0..1b38139a68499 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -322,7 +322,7 @@ def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... 
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ... def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ... -def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool, is_script: _bool) -> None: ... +def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ... def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Module) -> None: ... def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index afeb2ac7f3f50..66641057448d1 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -185,34 +185,6 @@ std::vector ConvertSequenceDependencies(Node* node, int opset_version) { return new_outputs; } -Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) { - TORCH_INTERNAL_ASSERT(opt_type); - TypePtr elem_type = opt_type->getElementType(); - Node* opt_node = g->create(::c10::onnx::Optional, 1); - opt_node->ty_(Symbol::attr("type"), elem_type); - opt_node->output()->setType(OptionalType::create(elem_type)); - return opt_node; -} - -// Replaces block output i with an onnx::Optional -// with `type` taken from opt_type. -// Needed when control flow has multiple branches, one of which -// is defined by `block` and returns a None and another branch -// returns not-None. The passed-in opt_type should be from the other branch. -void ReplaceBlockOutputWithOptional( - OptionalTypePtr opt_type, - Block* block, - size_t i) { - Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph()); - opt_node->insertBefore(block->return_node()); - Value* block_output = block->outputs().at(i); - block_output->replaceAllUsesWith(opt_node->output()); - if (!block_output->type()->cast()) { - opt_node->addInput(block_output); - opt_node->copyMetadata(block_output->node()); - } -} - // Resolving limitation from ONNX that the block output can not be // a value from outside the block. Inserting an Identity node inside // the block, linking with the value outside as workaround. @@ -220,17 +192,9 @@ void FixupONNXSubblockOutputs(Node* n) { for (Block* block : n->blocks()) { for (Value* output : block->outputs()) { if (output->node()->owningBlock() != block) { - Node* id_node = nullptr; - // Simplify graph by creating an empty optional rather than - // Identity(None). Also enables shape inference later on, since - // ONNX shape inference doesn't handle None. - if (output->type()->cast()) { - id_node = block->owningGraph()->create(onnx::Optional); - } else { - id_node = block->owningGraph()->create(onnx::Identity); - id_node->addInput(output); - } + Node* id_node = block->owningGraph()->create(onnx::Identity); id_node->insertBefore(block->return_node()); + id_node->addInput(output); id_node->output()->copyMetadata(output); id_node->copyMetadata(n); block->return_node()->replaceInputWith(output, id_node->output()); @@ -239,45 +203,7 @@ void FixupONNXSubblockOutputs(Node* n) { } } -// Infer type of optional inputs from outputs. 
-void FixupONNXLoopBlockInputs(Node* n) { - for (Block* block : n->blocks()) { - for (const auto i : c10::irange(1, block->inputs().size())) { - // input i corresponds to output i until we run FixupONNXLoopNodeInputs. - Value* input_i = block->inputs().at(i); - if (input_i->type()->cast() && - !block->outputs().at(i)->type()->cast()) { - TypePtr merged_type; - bool inferred = false; - std::tie(merged_type, inferred) = MergeInferredType( - input_i->type()->cast()->getElementType(), - block->outputs().at(i)->type()); - if (inferred) { - input_i->setType(OptionalType::create(merged_type)); - } - } - } - } -} - -// Replace None in outputs with Optional. -void FixupONNXLoopBlockOutputs(Node* n) { - for (Block* block : n->blocks()) { - // output 0 is continue_condition, never None. - for (const auto i : c10::irange(1, block->outputs().size())) { - if (block->outputs().at(i)->type()->cast()) { - ReplaceBlockOutputWithOptional( - // Output 0 is continue_condition. - // Inputs (0, 1) are (loop_counter, cond). So input i + 1 - // corresponds to output i. - block->inputs().at(i + 1)->type()->cast(), - block, - i); - } - } - } - FixupONNXSubblockOutputs(n); -} +} // anonymous namespace void FixupONNXLoopNodeInputs(Node* node) { if (node->kind() != ::c10::onnx::Loop) { @@ -309,53 +235,18 @@ void FixupONNXLoopNodeInputs(Node* node) { InsertCastForCond(next_cond_val, graph, sub_block->return_node()); cast_node->copyMetadata(node); } - - // Inputs (0, 1) are (max_trip_count, start_condition). Skip them - // since they're never None or Optional. - for (const auto i : c10::irange(2, node->inputs().size())) { - Value* input = node->inputs().at(i); - OptionalTypePtr sub_block_input_optional = - sub_block->inputs().at(i)->type()->cast(); - // If loop input is not optional but block input is, wrap loop input with - // Optional. Happens when the loop takes in None and outputs not-None, or - // vice-versa. - if (!input->type()->cast() && sub_block_input_optional) { - if (!input->type()->cast()) { - TypePtr merged_type; - bool inferred = false; - std::tie(merged_type, inferred) = MergeInferredType( - sub_block_input_optional->getElementType(), input->type()); - if (inferred) { - sub_block_input_optional = OptionalType::create(merged_type); - sub_block->inputs().at(i)->setType(sub_block_input_optional); - } - } - Node* opt_node = ONNXOptionalNode(sub_block_input_optional, graph); - if (!input->type()->cast()) { - opt_node->addInput(input); - } - opt_node->insertBefore(node); - node->replaceInputWith(input, opt_node->output()); - } - } } -} // anonymous namespace std::vector FixupONNXLoopNode(Node* node, int opset_version) { auto output_size = node->outputs().size(); - GRAPH_DEBUG("before FixupONNXLoopBlockInputs: ", *node->owningGraph()); - FixupONNXLoopBlockInputs(node); - GRAPH_DEBUG("after FixupONNXLoopBlockInputs: ", *node->owningGraph()); FixupONNXLoopNodeInputs(node); - GRAPH_DEBUG("after FixupONNXLoopNodeInputs: ", *node->owningGraph()); - FixupONNXLoopBlockOutputs(node); - GRAPH_DEBUG("after FixupONNXLoopBlockOutputs: ", *node->owningGraph()); + FixupONNXSubblockOutputs(node); // NOTE: the output order is deliberately changed to match expected order // since onnx loop requires scan outputs to be the last outputs. auto new_outputs = ConvertSequenceDependencies(node, opset_version); + // Copy type of block output to node output. 
FixupONNXControlflowNodeOutputs(node); - GRAPH_DEBUG("after FixupONNXControlflowNodeOutputs: ", *node->owningGraph()); TORCH_INTERNAL_ASSERT(output_size == new_outputs.size()); return new_outputs; } @@ -373,7 +264,8 @@ bool IsUninitializedNode(Node* n) { // Infer shape and type of the uninitialized_output from the corresponding // output of the other subblock. prim::Uninitialized node is proven to be -// unused. So replace this node with one of the inferred shape and type. +// unused. So replace this node with a Constant (TensorType) or +// Sequence (ListType) of the inferred shape and type. void InferShapeTypeForUninitializedOutput( Graph* graph, Block* block, @@ -404,24 +296,14 @@ void InferShapeTypeForUninitializedOutput( auto onnx_type = ATenTypeToOnnxType(scalar_type); const_node->i_(attr::dtype, onnx_type); const_node->output()->setType(other_output->type()); - } else if (elem->cast()) { - auto scalar_type = at::kLong; - auto onnx_type = ATenTypeToOnnxType(scalar_type); - const_node->i_(attr::dtype, onnx_type); - const_node->output()->setType(other_output->type()); } else { std::cerr << "Warning: UninitializedOutput - Invalid elem Type of ListTensor found." << std::endl; const_node->output()->setType(other_output->type()); } - } else if (auto output_type = other_output->type()->cast()) { - const_node = ONNXOptionalNode(output_type, graph); } - TORCH_CHECK( - const_node, - "Inferring type for prim::Uninitialized node from " + - other_output->type()->repr_str() + " not supported.") + const ParamMap empty_params_dict = {}; ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); const_node->insertBefore(block->return_node()); @@ -572,90 +454,23 @@ void ONNXMergeIfBlockOutputShapes(Node* node) { return nullptr; }; - auto mergeOptionalType = [&mergeTensorType, &mergeListType]( - OptionalTypePtr a, - OptionalTypePtr b) -> OptionalTypePtr { - if (a && b) { - if (a->getElementType()->cast()) { - auto a_tensor_type = a->getElementType()->cast(); - auto b_tensor_type = b->getElementType()->cast(); - auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type); - if (tensor_type) { - return a->withContained({tensor_type})->cast(); - } - // Both branches produce OptionalType without tensor shape. - return a; - } else if (a->getElementType()->cast()) { - auto a_list_type = a->getElementType()->cast(); - auto b_list_type = b->getElementType()->cast(); - auto list_type = mergeListType(a_list_type, b_list_type); - if (list_type) { - return a->withContained({list_type})->cast(); - } - // Both branches produce OptionalType without tensor shape. 
- return a; - } - } else if (a) { - return a; - } else if (b) { - return b; - } - return nullptr; - }; - for (const auto i : c10::irange(else_block->outputs().size())) { - Value* output_i = node->output(i); auto then_type = then_block->outputs().at(i)->type(); auto else_type = else_block->outputs().at(i)->type(); auto then_tensor_type = then_type->cast(); auto else_tensor_type = else_type->cast(); auto then_list_type = then_type->cast(); auto else_list_type = else_type->cast(); - auto then_optional_type = then_type->cast(); - auto else_optional_type = else_type->cast(); - auto then_none_type = then_type->cast(); - auto else_none_type = else_type->cast(); if (then_tensor_type || else_tensor_type) { - if (TypePtr merged_type = + if (auto tensor_type = mergeTensorType(then_tensor_type, else_tensor_type)) { - if (else_optional_type || else_none_type || then_optional_type || - then_none_type) { - merged_type = OptionalType::create(merged_type); - } - output_i->setType(merged_type); + node->output(i)->setType(tensor_type); } } else if (then_list_type || else_list_type) { - if (TypePtr merged_type = mergeListType(then_list_type, else_list_type)) { - if (else_optional_type || else_none_type || then_optional_type || - then_none_type) { - merged_type = OptionalType::create(merged_type); - } - output_i->setType(merged_type); - } - } - - if (then_optional_type || else_optional_type) { - if (auto optional_type = - mergeOptionalType(then_optional_type, else_optional_type)) { - output_i->setType(optional_type); - // Both branches output types must match. - if (!then_optional_type) { - ReplaceBlockOutputWithOptional(optional_type, then_block, i); - } else if (!else_optional_type) { - ReplaceBlockOutputWithOptional(optional_type, else_block, i); - } + if (auto list_type = mergeListType(then_list_type, else_list_type)) { + node->output(i)->setType(list_type); } } - - if (then_none_type && !else_optional_type) { - ReplaceBlockOutputWithOptional( - output_i->type()->cast(), then_block, i); - } - - if (else_none_type && !then_optional_type) { - ReplaceBlockOutputWithOptional( - output_i->type()->cast(), else_block, i); - } } } @@ -666,6 +481,7 @@ std::vector FixupONNXIfNode(Node* node, int opset_version) { GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph()); FixupONNXSubblockOutputs(node); ONNXFixupUninitializedOutput(node, opset_version); + // Copy type of block output to node output. ONNXMergeIfBlockOutputShapes(node); GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph()); @@ -688,28 +504,12 @@ std::vector FixupONNXControlflowNode(Node* n, int opset_version) { void FixupONNXControlflowNodeOutputs(Node* n) { switch (n->kind()) { case ::c10::onnx::Loop: { - Block* loop_block = n->blocks().at(0); - // inputs (0, 1) are (i, cond), remainder are carried outputs. - size_t loop_carried_output_size = loop_block->inputs().size() - 2; - + auto loop_carried_output_size = n->blocks().at(0)->inputs().size() - 2; for (auto i : c10::irange(n->outputs().size())) { + auto type = n->blocks().at(0)->outputs().at(i + 1)->type(); if (i < loop_carried_output_size) { - const TypePtr block_input_type = - loop_block->inputs().at(i + 2)->type(); - const TypePtr block_output_type = - loop_block->outputs().at(i + 1)->type(); - TypePtr type = block_output_type; - // Handle the case where a block input is Optional but the - // output is not (i.e. if the loop executes > 0 times, the - // output will not be None). 
- if (block_input_type->cast<OptionalType>() &&
- !block_output_type->cast<OptionalType>()) {
- type = OptionalType::create(block_output_type);
- }
 n->output(i)->setType(type);
 } else {
- // scan output, should be a Tensor type
- TypePtr type = loop_block->outputs().at(i + 1)->type();
 if (auto t_type = type->cast<TensorType>()) {
 auto sizes = t_type->symbolic_sizes().sizes();
 if (sizes.has_value()) {
diff --git a/torch/csrc/jit/passes/onnx/peephole.h b/torch/csrc/jit/passes/onnx/peephole.h
index 7d23267310ab8..2eb38b7e99334 100644
--- a/torch/csrc/jit/passes/onnx/peephole.h
+++ b/torch/csrc/jit/passes/onnx/peephole.h
@@ -10,5 +10,5 @@ void PeepholeOptimizeONNX(
 int opset_version,
 bool fixed_batch_size);
-} // namespace jit
+}
 } // namespace torch
diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
index cba23985a3bb4..b3d5bc763967e 100644
--- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
+++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
@@ -290,10 +290,9 @@ static void UpdateScalarTypeForInputs(
 static void UpdateScalarTypeForOutput(
 Node* n,
 const c10::ScalarType& scalar_type) {
- if (auto output_tensor_type = n->output()->type()->cast<TensorType>()) {
- n->output()->setType(CreateProfiledTensorTypeWithScalarType(
- output_tensor_type, scalar_type));
- }
+ auto output_tensor_type = n->output()->type()->cast<TensorType>();
+ n->output()->setType(
+ CreateProfiledTensorTypeWithScalarType(output_tensor_type, scalar_type));
 }

 static void RecoverScalarTypeForOutput(
diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
index a8d88c79deafc..fb9ed28fdcf2d 100644
--- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
+++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -17,15 +17,26 @@
 #include
 #include
 #include
-#include

 namespace torch {
 namespace jit {

-inline bool PyNone_Check(PyObject* o) {
- return o == Py_None;
-}
-
+// Return a new TypePtr, merging the ONNX inferred type with the existing type.
+// The inferred type takes higher precedence, since it is produced by ONNX
+// shape inference and is more compatible with ONNX. In cases where ONNX shape
+// inference fails to produce an inferred type, or produces an inferred type
+// that is incomplete, refer to the existing type and fill in the missing gap.
+// Currently the following cases are supported.
+// 1. existing type: Tensor[], inferred type: Tensor[]
+// For a list of tensors, the existing type does not store datatype or shape
+// for the inner tensor. Thus the inferred type always contains more
+// information, and is returned.
+// 2. existing type: Tensor, inferred type: Tensor
+// Fill in missing info (shape, data type) for the inferred type from the
+// existing type.
+// 3. existing type: Scalar[], inferred type: Tensor
+// ONNX represents a list of scalars by a 1-d Tensor. Return the inferred type
+// since it is more compatible with ONNX.
std::pair<TypePtr, bool> MergeInferredType(
 TypePtr existing_type,
 TypePtr inferred_type) {
@@ -226,12 +237,6 @@ bool IsValidONNXNode(const Node* n) {
 }
 }

- for (auto inp : n->inputs()) {
- if (inp->type() == NoneType::get()) {
- return false;
- }
- }
-
 return true;
 }

@@ -326,27 +331,27 @@ Node* CloneNodeToGraph(
 return clone_node;
 }

-bool HasValidType(TypePtr type, std::string name) {
- if (auto t_type = type->cast<TensorType>()) {
- if (!t_type->scalarType().has_value()) {
- GRAPH_UPDATE("Input ", name, " is missing tensor datatype.");
- return false;
- }
- } else if (auto s_type = type->cast<ListType>()) {
- auto e_type = s_type->getElementType();
- return HasValidType(e_type, name);
- } else if (auto o_type = type->cast<OptionalType>()) {
- auto e_type = o_type->getElementType();
- return HasValidType(e_type, name);
- }
- return true;
-}
-
 bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
- // Verify if every input has type (either Tensor, Sequence or Optional) and
- // scalar type. This is a requirement for ONNX graph inputs.
+ // Verify that every input has a type (either Tensor or Sequence) and a
+ // scalar type. This is a requirement for ONNX graph inputs.
 for (auto in : graph->inputs()) {
- return HasValidType(in->type(), in->debugName());
+ if (auto t_type = in->type()->cast<TensorType>()) {
+ if (!t_type->scalarType().has_value()) {
+ GRAPH_UPDATE(
+ "Input ", in->debugName(), " is a tensor type, but is missing a datatype.");
+ return false;
+ }
+ } else if (auto s_type = in->type()->cast<ListType>()) {
+ auto e_type = s_type->getElementType();
+ if (auto t_type = e_type->cast<TensorType>()) {
+ if (t_type->scalarType().has_value()) {
+ continue;
+ }
+ }
+ GRAPH_UPDATE(
+ "Input ", in->debugName(), " is a sequence type, but is missing a datatype.");
+ return false;
+ }
 }
 return true;
 }
@@ -2034,7 +2039,7 @@ void ONNXShapeTypeInference(
 }
 UpdateReliable(n);

- // For the node type that does not have ComputeConstant logic, it may have
+ // For a node type that does not have ComputeConstant logic, it may have
 // reliable shape but its shape is not in ConstantValueMap. So we need this
 // logic to update ConstantValueMap.
 for (auto node_output : n->outputs()) {
@@ -2141,11 +2146,10 @@ size_t ONNXAssignOutputShape(
 std::shared_ptr<Graph>& graph,
 size_t outputs_index,
 PyObject* output_obj,
- bool onnx_shape_inference,
- bool is_script) {
+ bool onnx_shape_inference) {
 auto index_check = [&]() {
 TORCH_INTERNAL_ASSERT(
- outputs_index <= graph->outputs().size(),
+ outputs_index >= 0 && outputs_index <= graph->outputs().size(),
 "Incorrect number of elements provided as example outputs.");
 };

@@ -2163,8 +2167,7 @@
 graph,
 outputs_index,
 PyTuple_GET_ITEM(output_obj, i),
- onnx_shape_inference,
- is_script);
+ onnx_shape_inference);
 }
 } else if (PyList_Check(output_obj)) {
 const auto list_len = PyList_GET_SIZE(output_obj);
@@ -2211,8 +2214,7 @@
 graph,
 outputs_index,
 PyList_GET_ITEM(output_obj, i),
- onnx_shape_inference,
- is_script);
+ onnx_shape_inference);
 }
 }
 } else if (PyDict_Check(output_obj)) {
@@ -2227,26 +2229,15 @@
 graph,
 outputs_index,
 PyList_GET_ITEM(unrolled_dict.ptr(), i),
- onnx_shape_inference,
- is_script);
+ onnx_shape_inference);
 }
 } else if (THPUtils_checkString(output_obj)) {
 // Ignore string, since they are not supported as output in ONNX.
- } else if (PyNone_Check(output_obj)) {
- // TODO: Currently there's no one thing to do here that works for
- // both tracing and scripting.
- // If we don't increment outputs_index here, then scripting fails - // for - // `python test/onnx/test_pytorch_onnx_no_runtime.py`. - // If we do increment it, then tracing fails for - // `python test/onnx/test_pytorch_onnx_onnxruntime.py - // TestONNXRuntime.test_tuple_with_none_outputs`. - // Cause: in tracing we flatten the outputs in ONNXTracedModule.forward - // in torch/jit/_trace.py while tracing. This means the output has None - // objects omitted. But then the outputs passed in here are un-flattened, - // which means they contain None objects. - // Ideally we'd remove this difference. - outputs_index += static_cast(is_script); + } else if (strcmp(THPUtils_typename(output_obj), "NoneType") == 0) { + // For cases with tracing, simply ignore NoneType outputs + // For cases with scripting, TODO: Add logic to handle NoneType outputs + // when such output types are supported. For now test cases with NoneType + // outputs have been disabled. } else { std::string msg = ("Model output has unsupported type. See " @@ -2264,14 +2255,13 @@ void ONNXAssignOutputShape( std::shared_ptr& graph, at::ArrayRef outputs, const python::IODescriptor& desc, - bool onnx_shape_inference, - bool is_script) { + bool onnx_shape_inference) { size_t outputs_index = 0; PyObject* py_obj = unflatten(outputs, desc); TORCH_INTERNAL_ASSERT(PyTuple_Check(py_obj)); - outputs_index = ONNXAssignOutputShape( - graph, outputs_index, py_obj, onnx_shape_inference, is_script); + outputs_index = + ONNXAssignOutputShape(graph, outputs_index, py_obj, onnx_shape_inference); TORCH_INTERNAL_ASSERT( outputs_index == graph->outputs().size(), diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index 34248b351e86a..73cfd8e150cd0 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -4,33 +4,9 @@ #include #include -#include - namespace torch { namespace jit { -// Merges existing_type and inferred_type. -// Returns {merged type, whether or not inferred_type was used}. -// -// The inferred type will take higher precedence, since it is produced by ONNX -// shape inference, and is more compatible with ONNX. In cases where ONNX shape -// inference fails to produce an inferred type, or produces an inferred type -// that is incomplete, refer to existing type and fill in the gap that is -// missing. Currently the following cases are supported. -// 1. existing type: Tensor[], inferred type: Tensor[] -// For list of tensors, existing type does not store datatype nor shape for -// inner tensor. Thus inferred type always contain more information, and is -// returned. -// 2. existing type: Tensor, inferred type: Tensor -// Fill in missing info (shape, data type) for inferred type from existing -// type. -// 3. existing type: Scalar[], inferred type: Tensor -// ONNX represents list of scalars by 1-d Tensor. Return inferred type since -// it is more compatible with ONNX. -std::pair MergeInferredType( - TypePtr existing_type, - TypePtr inferred_type); - void MergeInferredTypeAndSetMap( Value* dest_v, TypePtr existing_type, @@ -56,8 +32,7 @@ TORCH_API void ONNXAssignOutputShape( std::shared_ptr& graph, at::ArrayRef outputs, const python::IODescriptor& desc, - bool onnx_shape_inference, - bool is_script); + bool onnx_shape_inference); // Utilize ONNX Shape Inference for node. // The node must have ONNX namespace, and is valid ONNX node according to spec. 
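The `is_script` parameter deleted above existed because of the tracing/scripting asymmetry spelled out in the removed TODO: `torch.jit._flatten` records a `None` output slot in the IODescriptor but emits no Variable for it, so flattened (traced) outputs and unflattened (scripted) outputs disagree on the slot count. A minimal sketch of that behavior, using the same private `torch.jit._flatten` helper the patch itself calls (illustration only, not part of the patch):

    import torch

    # A tuple output in which one slot is always None.
    outputs = (torch.ones(2, 3), None)

    # _flatten returns the flat tensors plus a structure descriptor.
    flat_vars, desc = torch.jit._flatten(outputs)

    # Only one Variable survives: the None contributes nothing to flat_vars,
    # while the descriptor still encodes the None slot. This is the count
    # mismatch the removed is_script bookkeeping had to reconcile.
    print(len(flat_vars))  # 1
    print(desc)
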
diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index 68cb44777828e..c755446afc49d 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -30,10 +30,6 @@ static constexpr char NoneType = 'n'; namespace { -inline bool PyNone_Check(PyObject* o) { - return o == Py_None; -} - template py::object cast_handle_sequence(std::vector objs) { auto num_objs = objs.size(); @@ -72,7 +68,7 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) { args.vars.push_back(var); args.desc.metadata.emplace_back(var); args.desc.structure.push_back(D::Variable); - } else if (PyNone_Check(obj)) { + } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) { args.desc.structure.push_back(D::NoneType); } else if (PyBool_Check(obj)) { // Wrap bools in Bool tensors at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj))); diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 4e9f787a9020b..723b307f240d2 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -742,16 +742,6 @@ void initPythonIRBindings(PyObject* module_) { .def( "z", [](Node& n, const char* name) { return n.t(Symbol::attr(name)); }) - .def( - "ty_", - [](Node& n, const char* name, const TypePtr& type) { - return n.ty_(Symbol::attr(name), type); - }) - .def( - "tys_", - [](Node& n, const char* name, const std::vector& types) { - return n.tys_(Symbol::attr(name), types); - }) .def( "zs_", [](Node& n, const char* name, TensorsAttr::ValueType v) { diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 77c0604e15109..f1a59945f6755 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -447,29 +447,34 @@ static TypePtr inferShapeAndTypeForInput( TypePtr input_type, Stack::const_iterator& s_iter, const Stack::const_iterator& s_iter_end, + bool complete); + +static TupleTypePtr getTupleTensorType( + Stack::const_iterator& s_iter, + const Stack::const_iterator& s_iter_end, + const TypePtr& tupleType, bool complete) { - if (auto tuple_type = input_type->cast()) { - std::vector types; - for (const auto& sub_type : tuple_type->containedTypes()) { - TORCH_INTERNAL_ASSERT(s_iter != s_iter_end); - types.emplace_back( - inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete)); - } - return TupleType::create(types); - } else if (auto list_type = input_type->cast()) { - const TypePtr& sub_type = list_type->getElementType(); - auto elem_type = - inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete); - return ListType::create(elem_type); - } else if (auto tensor_type = input_type->cast()) { + TORCH_INTERNAL_ASSERT(tupleType->kind() == TupleType::Kind); + std::vector types; + for (const auto& subType : tupleType->containedTypes()) { + TORCH_INTERNAL_ASSERT(s_iter != s_iter_end); + types.emplace_back( + inferShapeAndTypeForInput(subType, s_iter, s_iter_end, complete)); + } + return TupleType::create(types); +} + +static TypePtr inferShapeAndTypeForInput( + TypePtr input_type, + Stack::const_iterator& s_iter, + const Stack::const_iterator& s_iter_end, + bool complete) { + if (input_type->kind() == TupleType::Kind) { + return getTupleTensorType(s_iter, s_iter_end, input_type, complete); + } else if (input_type->kind() == TensorType::Kind) { auto type = getTensorType(s_iter->toTensor(), complete); s_iter++; return type; - } else if (auto optional_type = 
input_type->cast()) { - const TypePtr& sub_type = optional_type->getElementType(); - auto elem_type = - inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete); - return OptionalType::create(elem_type); } else { // Primitive type, keep as is. s_iter++; @@ -489,6 +494,7 @@ static void setInputTensorTypes( TORCH_INTERNAL_ASSERT(input_values.size() == param_count_list.size()); } for (auto v : input_values) { + AT_ASSERT(s_iter != stack.end()); // Leave packed param types alone. This is needed for downstream passes // (like alias analysis) to work properly. This will be unpacked later // in unpackQuantizedWeights. @@ -496,12 +502,8 @@ static void setInputTensorTypes( if (auto qualname = named_type->name()) { if (getCustomClass(qualname->qualifiedName())) { if (param_count_list.empty()) { - AT_ASSERT(s_iter != stack.end()); s_iter++; } else { - if (param_count_list[list_idx] > 0) { - AT_ASSERT(s_iter != stack.end()); - } s_iter += param_count_list[list_idx]; } list_idx++; diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 471bcc23595c4..420d73014ae33 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -297,11 +297,6 @@ class GraphEncoder { bool use_external_data_format = false, const std::string& onnx_file_path = std::string()); - void EncodeTypeProto( - onnx::TypeProto* type_proto, - const TypePtr& node_type, - const std::string& name); - void EncodeLocalFunctionOpsetImport( onnx::FunctionProto* func_proto, const Node* n, @@ -359,16 +354,6 @@ class GraphEncoder { void AddAttribute(onnx::FunctionProto* func_proto, const std::string& name); - void TensorTypeToONNXType( - const TensorTypePtr& tensor_type, - const std::string& dim_name_prefix, - const std::string& name, - const std::unordered_map< - std::string, - std::unordered_map>& dynamic_axes, - onnx::TypeProto_Tensor* onnx_tensor_type, - bool assign_dim_param = true); - SymbolDimMap symbol_dim_map_; onnx::ModelProto model_proto_; size_t num_blocks_; @@ -446,10 +431,6 @@ onnx::AttributeProto_AttributeType ATenAttributeKindToOnnxAttributeType( return onnx::AttributeProto_AttributeType_TENSOR; case AttributeKind::ts: return onnx::AttributeProto_AttributeType_TENSORS; - case AttributeKind::ty: - return onnx::AttributeProto_AttributeType_TYPE_PROTO; - case AttributeKind::tys: - return onnx::AttributeProto_AttributeType_TYPE_PROTOS; case AttributeKind::g: return onnx::AttributeProto_AttributeType_GRAPH; case AttributeKind::gs: @@ -560,43 +541,6 @@ GraphEncoder::GraphEncoder( } } -void GraphEncoder::TensorTypeToONNXType( - const TensorTypePtr& tensor_type, - const std::string& dim_name_prefix, - const std::string& name, - const std::unordered_map< - std::string, - std::unordered_map>& dynamic_axes, - onnx::TypeProto_Tensor* onnx_tensor_type, - bool assign_dim_param) { - if (tensor_type->dim()) { - onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape(); - auto sizes = tensor_type->symbolic_sizes().sizes().value(); - for (const auto i : c10::irange(sizes.size())) { - shape->add_dim(); - if ((dynamic_axes.find(name) != dynamic_axes.end()) && - (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) { - shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i)); - if (!sizes[i].is_static()) { - symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i); - } - } else if (sizes[i].is_static()) { - shape->mutable_dim(i)->set_dim_value(sizes[i].static_size()); - } else if (assign_dim_param) { - if 
(symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) { - symbol_dim_map_[sizes[i]] = - dim_name_prefix + name + "_dim_" + std::to_string(i); - } - shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]); - } - } - } - if (tensor_type->scalarType()) { - onnx_tensor_type->set_elem_type( - ATenTypeToOnnxType(tensor_type->scalarType().value())); - } -} - void GraphEncoder::EncodeValueInfoType( onnx::TypeProto* onnx_type, const TypePtr node_type, @@ -604,10 +548,44 @@ void GraphEncoder::EncodeValueInfoType( const std::unordered_map< std::string, std::unordered_map>& dynamic_axes) { - std::string dim_name_prefix; - if (n->node()->kind() != prim::Param) { - dim_name_prefix = n->node()->kind().toUnqualString(); - } + auto tensorTypeToONNXType = [&dynamic_axes, n, this]( + const TensorTypePtr& t, + onnx::TypeProto_Tensor* onnx_tensor_type, + bool assign_dim_param) { + std::string name = n->debugName(); + if (t->dim()) { + onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape(); + auto sizes = t->symbolic_sizes().sizes().value(); + for (const auto i : c10::irange(sizes.size())) { + shape->add_dim(); + if ((dynamic_axes.find(name) != dynamic_axes.end()) && + (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) { + shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i)); + if (!sizes[i].is_static()) { + symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i); + } + } else if (sizes[i].is_static()) { + shape->mutable_dim(i)->set_dim_value(sizes[i].static_size()); + } else if (assign_dim_param) { + if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) { + if (n->node()->kind() == prim::Param) { + symbol_dim_map_[sizes[i]] = name + "_dim_" + std::to_string(i); + } else { + std::string op_type = n->node()->kind().toUnqualString(); + symbol_dim_map_[sizes[i]] = + op_type + name + "_dim_" + std::to_string(i); + } + } + shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]); + } + } + } + if (t->scalarType()) { + onnx_tensor_type->set_elem_type( + ATenTypeToOnnxType(t->scalarType().value())); + } + }; + if (TensorTypePtr tensor_type = node_type->cast()) { if (tensor_type->dim() || tensor_type->scalarType()) { // Encode type if either shape or dtype exists. @@ -619,13 +597,7 @@ void GraphEncoder::EncodeValueInfoType( // to denote an unknown dimension. // Create and assign dim_param for normal tensor type. 
auto is_sequence_tensor = static_cast(n->type()->cast()); - TensorTypeToONNXType( - tensor_type, - dim_name_prefix, - n->debugName(), - dynamic_axes, - onnx_tensor_type, - !is_sequence_tensor); + tensorTypeToONNXType(tensor_type, onnx_tensor_type, !is_sequence_tensor); } } else if (BoolTypePtr bool_type = node_type->cast()) { onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type(); @@ -642,37 +614,6 @@ void GraphEncoder::EncodeValueInfoType( onnx_type->mutable_sequence_type(); onnx::TypeProto* onnx_tensor_type = sequence_type->mutable_elem_type(); EncodeValueInfoType(onnx_tensor_type, list_elem_type, n, dynamic_axes); - } else if (OptionalTypePtr optional_type = node_type->cast()) { - auto elem_type = optional_type->getElementType(); - if (TensorTypePtr tensor_type = elem_type->cast()) { - onnx::TypeProto_Optional* onnx_optional_type = - onnx_type->mutable_optional_type(); - onnx::TypeProto_Tensor* onnx_tensor_type = - onnx_optional_type->mutable_elem_type()->mutable_tensor_type(); - TensorTypeToONNXType( - tensor_type, - dim_name_prefix, - n->debugName(), - dynamic_axes, - onnx_tensor_type); - } else if (ListTypePtr inner_node_type = elem_type->cast()) { - auto list_elem_type = inner_node_type->getElementType(); - if (TensorTypePtr tensor_type = list_elem_type->cast()) { - onnx::TypeProto_Optional* onnx_optional_type = - onnx_type->mutable_optional_type(); - onnx::TypeProto_Sequence* onnx_optional_sequence_type = - onnx_optional_type->mutable_elem_type()->mutable_sequence_type(); - onnx::TypeProto_Tensor* onnx_tensor_type = - onnx_optional_sequence_type->mutable_elem_type() - ->mutable_tensor_type(); - TensorTypeToONNXType( - tensor_type, - dim_name_prefix, - n->debugName(), - dynamic_axes, - onnx_tensor_type); - } - } } } @@ -1038,26 +979,6 @@ void GraphEncoder::AddAttribute( EncodeTensor(t, v, {}, use_external_data_format, onnx_file_path); } break; - case AttributeKind::ty: { - attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTO); - auto tp = attr->mutable_tp(); - const TypePtr& node_type = node->ty(name); - EncodeTypeProto( - tp, node_type, node_proto->op_type() + "_" + name.toDisplayString()); - } break; - case AttributeKind::tys: { - attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTOS); - size_t index = 0; - for (auto& v : node->tys(name)) { - auto tp = attr->add_type_protos(); - EncodeTypeProto( - tp, - v, - node_proto->op_type() + "_" + name.toDisplayString() + "_" + - std::to_string(index)); - index++; - } - } break; case AttributeKind::g: { auto g = attr->mutable_g(); EncodeGraph( @@ -1180,21 +1101,6 @@ void GraphEncoder::EncodeLocalFunction( } } -void GraphEncoder::EncodeTypeProto( - onnx::TypeProto* type_proto, - const TypePtr& node_type, - const std::string& name) { - if (TensorTypePtr tensor_type = node_type->cast()) { - onnx::TypeProto_Tensor* onnx_tensor_type = - type_proto->mutable_tensor_type(); - TensorTypeToONNXType(tensor_type, "", name, {}, onnx_tensor_type); - } else if (ListTypePtr list_type = node_type->cast()) { - onnx::TypeProto_Sequence* seq_type = type_proto->mutable_sequence_type(); - auto elem_type = list_type->getElementType(); - EncodeTypeProto(seq_type->mutable_elem_type(), elem_type, name); - } -} - void GraphEncoder::EncodeTensor( onnx::TensorProto* tensor_proto, const at::Tensor& tensor, diff --git a/torch/csrc/jit/serialization/onnx.cpp b/torch/csrc/jit/serialization/onnx.cpp index aaf91a5c71adc..9e45d78c0ea33 100644 --- a/torch/csrc/jit/serialization/onnx.cpp +++ b/torch/csrc/jit/serialization/onnx.cpp @@ -61,16 
+61,6 @@ void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) { void dump(const onnx::TypeProto& type, std::ostream& stream); -void dump(const onnx::TypeProto_Optional& optional_type, std::ostream& stream) { - stream << "Optional<"; - if (optional_type.has_elem_type()) { - dump(optional_type.elem_type(), stream); - } else { - stream << "None"; - } - stream << ">"; -} - void dump(const onnx::TypeProto_Sequence& sequence_type, std::ostream& stream) { stream << "Sequence<"; if (sequence_type.has_elem_type()) { @@ -86,8 +76,6 @@ void dump(const onnx::TypeProto& type, std::ostream& stream) { dump(type.tensor_type(), stream); } else if (type.has_sequence_type()) { dump(type.sequence_type(), stream); - } else if (type.has_optional_type()) { - dump(type.optional_type(), stream); } else { stream << "None"; } diff --git a/torch/csrc/onnx/init.cpp b/torch/csrc/onnx/init.cpp index ad163094865a1..2e400c2a97631 100644 --- a/torch/csrc/onnx/init.cpp +++ b/torch/csrc/onnx/init.cpp @@ -41,10 +41,8 @@ void initONNXBindings(PyObject* module) { [](std::shared_ptr& graph, const std::vector& tensors, const python::IODescriptor& desc, - bool onnx_shape_inference, - bool is_script) { - ONNXAssignOutputShape( - graph, tensors, desc, onnx_shape_inference, is_script); + bool onnx_shape_inference = false) { + ONNXAssignOutputShape(graph, tensors, desc, onnx_shape_inference); }) .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution) .def( @@ -196,9 +194,7 @@ void initONNXBindings(PyObject* module) { m.def( "_check_onnx_proto", - [](const std::string& proto_string, bool full_check) { - check_onnx_proto(proto_string, full_check); - }, + [](const std::string& proto_string, bool full_check) { check_onnx_proto(proto_string, full_check); }, py::arg("proto_string"), py::arg("full_check") = false); diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index 32222673fdc65..f02afcb97737c 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -23,38 +23,3 @@ # Shape https://github.com/onnx/onnx/pull/3580 # Backwards compatible # TODO: optional start/end attribute. - - -import torch -from torch._C import OptionalType -from torch.onnx.symbolic_helper import _is_none -from torch.onnx.symbolic_opset9 import eq, wrap_logical_op_with_negation - - -def __is_(g, self, other): - if _is_none(other): - if isinstance(self.type(), OptionalType): - none = g.op("OptionalHasElement", self) - return g.op("Not", none) - else: - return g.op("Constant", value_t=torch.BoolTensor([0])) - return eq(g, self, other) - - -@wrap_logical_op_with_negation -def __isnot_(g, self, other): - return __is_(g, self, other) - - -class Prim: - domain = "prim" - - @staticmethod - def unchecked_cast(g, self): - # exists to refine the type of the Value - # if x is Optional[Tensor], unchecked_cast will cast - # x to Tensor, so the rest of the graph knows that x is a Tensor. 
- if isinstance(self.type(), OptionalType): - return g.op("OptionalGetElement", self) - - return self diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index d6a2ad6cbc215..c41107a2c313d 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -4839,14 +4839,7 @@ def Loop(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs): for i, b_in in enumerate(b.inputs()): if i == 0 and i < len(inputs): b_in.setType(inputs[i].type()) - # For optional block inputs, they may switch between None not-None inside - # the loop body, so if the loop input is not optional, the block input may - # still need to be optional. - if ( - i > 0 - and (i + 1) < len(inputs) - and not isinstance(b_in.type(), OptionalType) - ): + if i > 0 and (i + 1) < len(inputs): b_in.setType(inputs[i + 1].type()) torch._C._jit_pass_onnx_block( b, new_block, operator_export_type, env, False # type:ignore[arg-type] diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f25d56d6a17f2..a0b9113e8e7a0 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -290,6 +290,7 @@ def _optimize_graph( # onnx only supports tensors, so we turn all out number types into tensors torch._C._jit_pass_erase_number_types(graph) + if _onnx_shape_inference: input_names = [] if input_names is None else input_names dynamic_axes = {} if dynamic_axes is None else dynamic_axes @@ -439,51 +440,37 @@ def _decide_constant_folding(do_constant_folding, operator_export_type, training return do_constant_folding -def _signature(model) -> inspect.Signature: - should_be_callable = getattr(model, "forward", model) - if callable(should_be_callable): - return inspect.signature(should_be_callable) - raise ValueError("model has no forward method and is not callable") - - def _decide_input_format(model, args): try: - sig = _signature(model) - except ValueError as e: - warnings.warn("%s, skipping _decide_input_format" % e) - return args - try: + sig = inspect.signature(model.forward) ordered_list_keys = list(sig.parameters.keys()) - if ordered_list_keys[0] == "self": - ordered_list_keys = ordered_list_keys[1:] - args_dict: Dict = {} - if isinstance(args, list): - args_list = args - elif isinstance(args, tuple): - args_list = list(args) - else: - args_list = [args] - if isinstance(args_list[-1], dict): - args_dict = args_list[-1] - args_list = args_list[:-1] - n_nonkeyword = len(args_list) - for optional_arg in ordered_list_keys[n_nonkeyword:]: - if optional_arg in args_dict: - args_list.append(args_dict[optional_arg]) - # Check if this arg has a default value - else: - param = sig.parameters[optional_arg] - if param.default != param.empty: - args_list.append(param.default) - args = args_list if isinstance(args, list) else tuple(args_list) + if isinstance(args[-1], dict): + args_dict = args[-1] + args = list(args)[:-1] + n_nonkeyword = len(args) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default is param.empty: + args.append(None) + else: + args.append(param.default) + args = tuple(args) + return args + # Cases of models without forward functions and dict inputs + except (AttributeError, ValueError): + warnings.warn("Model has no forward function") + return args # Cases of models with no input args except IndexError: - warnings.warn("No input args, skipping _decide_input_format") + warnings.warn("No input args") + return args except 
Exception as e: warnings.warn("Skipping _decide_input_format\n {}".format(e.args[0])) - - return args - + return args def _trace(func, args, operator_export_type, return_outs=False): # Special case for common case of passing a single Tensor @@ -527,66 +514,37 @@ def _get_param_count_list(method_graph, args_params): in_vars, _ = torch.jit._flatten(arg_params_) param_count_list.append(len(in_vars)) else: - param_count_list.append(arg_params_ is not None) - + param_count_list.append(1) return param_count_list -def _check_flatten_did_not_remove(original, jit_flattened): - """torch.jit._flatten removes None. Check if it did so in this case.""" - - def flatten(x): - if isinstance(x, (list, tuple)): - for inner in x: - for y in flatten(inner): - yield y - elif isinstance(x, dict): - for inner in x.values(): - for y in flatten(inner): - yield y - else: - yield x - - flattened_with_none = list(flatten(original)) - num_none = len(flattened_with_none) - len(jit_flattened) - assert num_none >= 0 - if num_none: - raise ValueError( - f"args contained {num_none} None's after flattening. " - "When exporting a ScriptModule or ScriptFunction, no args may " - "be None because that breaks type propagation." - ) - - def _create_jit_graph(model, args): torch_out = None params: Union[List, Tuple] - if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): - flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) - _check_flatten_did_not_remove(args, flattened_args) if isinstance(model, torch.jit.ScriptModule): try: graph = model.forward.graph + torch._C._jit_pass_onnx_function_substitution(graph) + freezed_m = torch._C._freeze_module(model._c, preserveParameters=True) + module, params = torch._C._jit_onnx_list_model_parameters(freezed_m) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = torch._C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) except AttributeError as e: raise RuntimeError("'forward' method must be a script method") from e - torch._C._jit_pass_onnx_function_substitution(graph) - freezed_m = torch._C._freeze_module(model._c, preserveParameters=True) - module, params = torch._C._jit_onnx_list_model_parameters(freezed_m) - method_graph = module._get_method("forward").graph - args_params = tuple(args) + tuple(params) - param_count_list = _get_param_count_list(method_graph, args_params) - in_vars, _ = torch.jit._flatten(args_params) - graph = torch._C._propagate_and_assign_input_shapes( - method_graph, tuple(in_vars), param_count_list, False, False - ) return graph, params, torch_out, module elif isinstance(model, torch.jit.ScriptFunction): params = () + in_vars, in_desc = torch.jit._flatten(tuple(args)) graph = model.graph torch._C._jit_pass_onnx_function_substitution(graph) param_count_list = _get_param_count_list(graph, args) graph = torch._C._propagate_and_assign_input_shapes( - graph, flattened_args, param_count_list, False, False + graph, tuple(in_vars), param_count_list, False, False ) return graph, params, torch_out, None else: @@ -619,11 +577,11 @@ def _get_example_outputs(model, args): input_args = input_args[:-1] example_outputs = model(*input_args, **input_kwargs) - if isinstance(example_outputs, list): - example_outputs = [example_outputs] - elif not isinstance(example_outputs, tuple): + if isinstance(example_outputs, (torch.Tensor, int, float, 
bool)): example_outputs = (example_outputs,) + if isinstance(example_outputs, list): + example_outputs = [example_outputs] return example_outputs @@ -718,6 +676,7 @@ def _model_to_graph( model = _pre_trace_quant_model(model, args) graph, params, torch_out, module = _create_jit_graph(model, args) + params_dict = _get_named_param_dict(graph, params) try: @@ -736,21 +695,26 @@ def _model_to_graph( raise from torch.onnx.symbolic_helper import _onnx_shape_inference - is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) - if is_script: + if isinstance(model, torch.jit.ScriptModule) or isinstance( + model, torch.jit.ScriptFunction + ): example_outputs = _get_example_outputs(model, args) example_outputs_final = () for example_output in example_outputs: example_outputs_final += unpack_quantized_tensor(example_output) out_vars, desc = torch.jit._flatten(example_outputs_final) torch._C._jit_pass_onnx_assign_output_shape( - graph, out_vars, desc, _onnx_shape_inference, is_script + graph, out_vars, desc, _onnx_shape_inference ) + else: + flatten_args, _ = torch._C._jit_flatten(args) + # make sure that the param dict and the graph match each other + assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs()) # NB: ONNX requires complete information about output types, which might be # erased by some optimizations, so we need to set it explicitly again. - else: - if not isinstance(torch_out, (list, tuple)): + if torch_out is not None: + if not (isinstance(torch_out, list) or isinstance(torch_out, tuple)): output_wrapped = [torch_out] else: output_wrapped = torch_out # type: ignore[assignment] @@ -761,7 +725,7 @@ def _model_to_graph( # single value in PyTorch. if not any(getattr(out, "is_quantized", False) for out in output_tensors): torch._C._jit_pass_onnx_assign_output_shape( - graph, output_tensors, out_desc, _onnx_shape_inference, is_script + graph, output_tensors, out_desc, _onnx_shape_inference ) _set_input_and_output_names(graph, input_names, output_names) @@ -1256,7 +1220,7 @@ def set_names(node_list, name_list, descriptor): set_names(list(graph.outputs()), output_names, "output") -_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$") +_attr_pattern = re.compile("^(.+)_([ifstgz])$") def _run_symbolic_method(g, op_name, symbolic_fn, args):
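The final hunk narrows `_attr_pattern` back to `^(.+)_([ifstgz])$`. The one-letter suffix on each keyword argument passed to `g.op(...)` selects the ONNX attribute kind (`i` int, `f` float, `s` string, `t` tensor, `g` graph, and so on), and the `ty` alternative for TYPE_PROTO attributes disappears together with the rest of the `onnx::Optional` machinery. A sketch of the suffix convention as a symbolic function would use it (`example_symbolic` is a hypothetical name; illustration only):

    import torch

    def example_symbolic(g, x):
        # value_t selects a TENSOR attribute and axis_i an INT attribute.
        # After this revert, a kwarg such as value_ty (TYPE_PROTO) no longer
        # matches _attr_pattern and is rejected when the op is emitted.
        const = g.op("Constant", value_t=torch.tensor([1.0]))
        return g.op("Concat", x, const, axis_i=0)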