Skip to content

Commit

Permalink
[ONNX] Support optional type (pytorch#68793) (pytorch#73284)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#73284

Some important ops won't support optional type until opset 16,
so we can't fully test things end-to-end, but I believe this should
be all that's needed. Once ONNX Runtime supports opset 16,
we can do more testing and fix any remaining bugs.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34625646

Pulled By: malfet

fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b

Co-authored-by: BowenBao <[email protected]>
Co-authored-by: neginraoof <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
(cherry picked from commit 822e79f)
  • Loading branch information
BowenBao authored and pytorchmergebot committed May 4, 2022
1 parent b8776e1 commit 679fc90
Show file tree
Hide file tree
Showing 24 changed files with 1,098 additions and 487 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ namespace c10 {
_(onnx, Range) \
_(onnx, Tile) \
_(onnx, Where) \
_(onnx, Optional) \
_(onnx, OptionalGetElement) \
_(onnx, OptionalHasElement) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
Expand Down
3 changes: 3 additions & 0 deletions caffe2/python/onnx/tests/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@
'|test_optional_.*'
'|test_shape_end_.*'
'|test_shape_start_.*'
'|test_identity_opt_*'
'|test_loop16_seq_none_*'
'|test_if_opt_*'
')')

# Unsupported ops in opset 16
Expand Down
18 changes: 9 additions & 9 deletions test/onnx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from model_defs.op_test import DummyNet, ConcatNet, PermuteNet, PReluNet, FakeQuantNet
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2

from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, disableScriptTest
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, skipIfUnsupportedMinOpsetVersion, skipScriptTest

import torch
import torch.onnx
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_prelu(self):
)
self.exportTest(PReluNet(), x)

@disableScriptTest()
@skipScriptTest()
def test_concat(self):
input_a = Variable(torch.randn(BATCH_SIZE, 3))
input_b = Variable(torch.randn(BATCH_SIZE, 3))
Expand All @@ -79,12 +79,12 @@ def test_permute(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
self.exportTest(PermuteNet(), x)

@disableScriptTest()
@skipScriptTest()
def test_embedding_sequential_1(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork1(), x)

@disableScriptTest()
@skipScriptTest()
def test_embedding_sequential_2(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork2(), x)
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_resnet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

@disableScriptTest() # None type in outputs
@skipScriptTest(min_opset_version=15) # None type in outputs
def test_inception(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
self.exportTest(toC(inception_v3()), toC(x))
Expand All @@ -163,14 +163,14 @@ def test_densenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

@disableScriptTest()
@skipScriptTest()
def test_dcgan_netD(self):
netD = _netD(1)
netD.apply(weights_init)
input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
self.exportTest(toC(netD), toC(input))

@disableScriptTest()
@skipScriptTest()
def test_dcgan_netG(self):
netG = _netG(1)
netG.apply(weights_init)
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_qat_resnet_per_channel(self):

self.exportTest(toC(qat_resnet50), toC(x))

@disableScriptTest() # None type in outputs
@skipScriptTest(min_opset_version=15) # None type in outputs
def test_googlenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
Expand All @@ -237,7 +237,7 @@ def test_mobilenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest() # prim_data
@skipScriptTest() # prim_data
def test_shufflenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)
Expand Down
16 changes: 3 additions & 13 deletions test/onnx/test_pytorch_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,11 @@ def wrapper(self):
return wrapper
return skip_dec

# Enables tests for scripting, instead of only tracing the model.
def enableScriptTest():
# skips tests for scripting.
def skipScriptTest(min_opset_version=float("inf")):
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = True
return func(self)
return wrapper
return script_dec


# Disable tests for scripting.
def disableScriptTest():
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = False
self.is_script_test_enabled = self.opset_version >= min_opset_version
return func(self)
return wrapper
return script_dec
Expand Down
4 changes: 2 additions & 2 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1997,11 +1997,11 @@ def forward(self, lstm_in):
bias=has_bias,
num_layers=num_layers,
)
lstm_in = [
lstm_in = ([
torch.from_numpy(inputs),
torch.from_numpy(hx),
torch.from_numpy(hx),
] + [param.detach() for param in torch_lstm._flat_weights]
] + [param.detach() for param in torch_lstm._flat_weights],)

self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)

Expand Down
100 changes: 100 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Owner(s): ["module: onnx"]

"""Tests for onnx export that don't run the exported model."""

import io
import unittest

import onnx
import torch
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torch import Tensor
from torch.onnx import symbolic_helper

from typing import Optional, Type


class TestOptionalOutput(unittest.TestCase):
# TODO: Move these tests to test_pytorch_onnx_onnxruntime once
# ONNX Runtime 1.11 is released and supports opset 16.

class IfNoneInput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = None
if x.size(0) > 1:
y = x
return y

class IfNoneOutput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = x
if x.size(0) > 1:
y = None
return y


class LoopNoneInput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = None
for _ in range(x.size(0)):
y = x
return y

class LoopNoneOutput(torch.nn.Module):
def forward(self, x) -> Optional[Tensor]:
y: Optional[Tensor] = x
for _ in range(x.size(0)):
y = None
return y


@parametrize(
"module_class",
(IfNoneInput, IfNoneOutput, LoopNoneInput, LoopNoneOutput),
name_fn=lambda module_class: module_class.__name__)
@parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int):
# Need scripting to preserve control flow for this test to be meaningful.
model = torch.jit.script(module_class())
f = io.BytesIO()
x = torch.ones(x_size)
dynamic_axis_name = "condition"
torch.onnx.export(
model, (x,), f, opset_version=15,
# Ensure condition is not constant
dynamic_axes={"x": {0: dynamic_axis_name}}, input_names=["x"])
exported = onnx.load_from_string(f.getvalue())
expected_elem_type = symbolic_helper.scalar_type_to_onnx[
symbolic_helper.scalar_type_to_pytorch_type.index(x.dtype)].value
expected_output_type = onnx.helper.make_optional_type_proto(
onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)))
self.assertEqual(expected_output_type, exported.graph.output[0].type)
for node in exported.graph.node:
# Both branches output types should match.
if node.op_type == "If":
for attr in node.attribute:
if attr.name in ("then_branch", "else_branch"):
self.assertEqual(expected_output_type, attr.g.output[0].type)

def test_uninitialized_optional(self):
class Module(torch.nn.Module):
def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
if y is not None:
if y.shape[1] < 5:
if y.size(0) == 1:
y = y + 4
else:
return y
return y

y = torch.ones((3, 4), dtype=torch.int)
torch.onnx.export(
torch.jit.script(Module()), y, io.BytesIO(), opset_version=15,
dynamic_axes={"y": {0: "y0", 1: "y1"}}, input_names=["y"])


instantiate_parametrized_tests(TestOptionalOutput)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 679fc90

Please sign in to comment.