Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support optional type (#68793) #73284

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
625 changes: 315 additions & 310 deletions aten/src/ATen/core/interned_strings.h

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions caffe2/python/onnx/tests/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@
'|test_optional_.*'
'|test_shape_end_.*'
'|test_shape_start_.*'
'|test_identity_opt_*'
'|test_loop16_seq_none_*'
'|test_if_opt_*'
')')

# Unsupported ops in opset 16
Expand Down
18 changes: 9 additions & 9 deletions test/onnx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from model_defs.op_test import DummyNet, ConcatNet, PermuteNet, PReluNet, FakeQuantNet
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2

from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, disableScriptTest
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, skipScriptTest

import torch
import torch.onnx
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_prelu(self):
)
self.exportTest(PReluNet(), x)

@disableScriptTest()
@skipScriptTest()
def test_concat(self):
input_a = Variable(torch.randn(BATCH_SIZE, 3))
input_b = Variable(torch.randn(BATCH_SIZE, 3))
Expand All @@ -79,12 +79,12 @@ def test_permute(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
self.exportTest(PermuteNet(), x)

@disableScriptTest()
@skipScriptTest()
def test_embedding_sequential_1(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork1(), x)

@disableScriptTest()
@skipScriptTest()
def test_embedding_sequential_2(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork2(), x)
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_resnet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

@disableScriptTest() # None type in outputs
@skipScriptTest(min_opset_version=15) # None type in outputs
def test_inception(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
self.exportTest(toC(inception_v3()), toC(x))
Expand All @@ -163,14 +163,14 @@ def test_densenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

@disableScriptTest()
@skipScriptTest()
def test_dcgan_netD(self):
netD = _netD(1)
netD.apply(weights_init)
input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
self.exportTest(toC(netD), toC(input))

@disableScriptTest()
@skipScriptTest()
def test_dcgan_netG(self):
netG = _netG(1)
netG.apply(weights_init)
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_qat_resnet_per_channel(self):

self.exportTest(toC(qat_resnet50), toC(x))

@disableScriptTest() # None type in outputs
@skipScriptTest(min_opset_version=15) # None type in outputs
def test_googlenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
Expand All @@ -237,7 +237,7 @@ def test_mobilenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest() # prim_data
@skipScriptTest() # prim_data
def test_shufflenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)
Expand Down
18 changes: 4 additions & 14 deletions test/onnx/test_pytorch_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,11 @@ def wrapper(self):
return wrapper
return skip_dec

# Enables tests for scripting, instead of only tracing the model.
def enableScriptTest():
# skips tests for scripting.
def skipScriptTest(min_opset_version=float("inf")):
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = True
return func(self)
return wrapper
return script_dec


# Disable tests for scripting.
def disableScriptTest():
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = False
self.is_script_test_enabled = self.opset_version >= min_opset_version
return func(self)
return wrapper
return script_dec
Expand All @@ -111,7 +101,7 @@ def skipIfONNXShapeInference(onnx_shape_inference):
def skip_dec(func):
def wrapper(self):
if self.onnx_shape_inference is onnx_shape_inference:
raise unittest.SkipTest("Skip verify test for unsupported opset_version")
raise unittest.SkipTest("Skip test due to onnx_shape_inference")
return func(self)
return wrapper
return skip_dec
Expand Down
4 changes: 2 additions & 2 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,11 +1998,11 @@ def forward(self, lstm_in):
bias=has_bias,
num_layers=num_layers,
)
lstm_in = [
lstm_in = ([
torch.from_numpy(inputs),
torch.from_numpy(hx),
torch.from_numpy(hx),
] + [param.detach() for param in torch_lstm._flat_weights]
] + [param.detach() for param in torch_lstm._flat_weights],)

self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)

Expand Down
100 changes: 100 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Owner(s): ["module: onnx"]
garymm marked this conversation as resolved.
Show resolved Hide resolved

"""Tests for onnx export that don't run the exported model."""

import io
import unittest

import onnx
import torch
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torch import Tensor
from torch.onnx import symbolic_helper

from typing import Optional, Type


class TestOptionalOutput(unittest.TestCase):
# TODO: Move these tests to test_pytorch_onnx_onnxruntime once
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
# ONNX Runtime 1.11 is released and supports opset 16.

class IfNoneInput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = None
if x.size(0) > 1:
y = x
return y

class IfNoneOutput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = x
if x.size(0) > 1:
y = None
return y


class LoopNoneInput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = None
for _ in range(x.size(0)):
y = x
return y

class LoopNoneOutput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = x
for _ in range(x.size(0)):
y = None
return y


@parametrize(
"module_class",
(IfNoneInput, IfNoneOutput, LoopNoneInput, LoopNoneOutput),
name_fn=lambda module_class: module_class.__name__)
@parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
# Need scripting to preserve control flow for this test to be meaningful.
model = torch.jit.script(module_class())
f = io.BytesIO()
x = torch.ones(x_size)
dynamic_axis_name = "condition"
torch.onnx.export(
model, (x,), f, opset_version=15,
# Ensure condition is not constant
dynamic_axes={"x": {0: dynamic_axis_name}}, input_names=["x"])
exported = onnx.load_from_string(f.getvalue())
expected_elem_type = symbolic_helper.scalar_type_to_onnx[
symbolic_helper.scalar_type_to_pytorch_type.index(x.dtype)].value
expected_output_type = onnx.helper.make_optional_type_proto(
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)))
self.assertEqual(expected_output_type, exported.graph.output[0].type)
for node in exported.graph.node:
# Both branches output types should match.
if node.op_type == "If":
for attr in node.attribute:
if attr.name in ("then_branch", "else_branch"):
self.assertEqual(expected_output_type, attr.g.output[0].type)

def test_uninitialized_optional(self):
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
class Module(torch.nn.Module):
def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
if y is not None:
if y.shape[1] < 5:
if y.size(0) == 1:
y = y + 4
else:
return y
return y

y = torch.ones((3, 4), dtype=torch.int)
torch.onnx.export(
torch.jit.script(Module()), y, io.BytesIO(), opset_version=15,
dynamic_axes={"y": {0: "y0", 1: "y1"}}, input_names=["y"])


instantiate_parametrized_tests(TestOptionalOutput)
BowenBao marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
unittest.main()
Loading