Skip to content

Commit

Permalink
Revert "[ONNX] Support optional type (pytorch#68793) (pytorch#73284)"
Browse files Browse the repository at this point in the history
This reverts commit 679fc90.
  • Loading branch information
atalman committed May 12, 2022
1 parent e06400e commit 66b427a
Show file tree
Hide file tree
Showing 23 changed files with 533 additions and 1,167 deletions.
3 changes: 0 additions & 3 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,6 @@ namespace c10 {
_(onnx, Range) \
_(onnx, Tile) \
_(onnx, Where) \
_(onnx, Optional) \
_(onnx, OptionalGetElement) \
_(onnx, OptionalHasElement) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
Expand Down
3 changes: 0 additions & 3 deletions caffe2/python/onnx/tests/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@
'|test_optional_.*'
'|test_shape_end_.*'
'|test_shape_start_.*'
'|test_identity_opt_*'
'|test_loop16_seq_none_*'
'|test_if_opt_*'
')')

# Unsupported ops in opset 16
Expand Down
18 changes: 9 additions & 9 deletions test/onnx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
run_tests,
skipIfNoLapack,
skipIfUnsupportedMinOpsetVersion,
skipScriptTest,
disableScriptTest,
)
from torchvision.models import shufflenet_v2_x1_0
from torchvision.models.alexnet import alexnet
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_prelu(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(PReluNet(), x)

@skipScriptTest()
@disableScriptTest()
def test_concat(self):
input_a = Variable(torch.randn(BATCH_SIZE, 3))
input_b = Variable(torch.randn(BATCH_SIZE, 3))
Expand All @@ -93,12 +93,12 @@ def test_permute(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
self.exportTest(PermuteNet(), x)

@skipScriptTest()
@disableScriptTest()
def test_embedding_sequential_1(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork1(), x)

@skipScriptTest()
@disableScriptTest()
def test_embedding_sequential_2(self):
x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
self.exportTest(EmbeddingNetwork2(), x)
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_resnet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

@skipScriptTest(min_opset_version=15) # None type in outputs
@disableScriptTest() # None type in outputs
def test_inception(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
self.exportTest(toC(inception_v3()), toC(x))
Expand All @@ -175,14 +175,14 @@ def test_densenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

@skipScriptTest()
@disableScriptTest()
def test_dcgan_netD(self):
netD = _netD(1)
netD.apply(weights_init)
input = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
self.exportTest(toC(netD), toC(input))

@skipScriptTest()
@disableScriptTest()
def test_dcgan_netG(self):
netG = _netG(1)
netG.apply(weights_init)
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_qat_resnet_per_channel(self):

self.exportTest(toC(qat_resnet50), toC(x))

@skipScriptTest(min_opset_version=15) # None type in outputs
@disableScriptTest() # None type in outputs
def test_googlenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
Expand All @@ -252,7 +252,7 @@ def test_mobilenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

@skipScriptTest() # prim_data
@disableScriptTest() # prim_data
def test_shufflenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)
Expand Down
15 changes: 12 additions & 3 deletions test/onnx/test_pytorch_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,21 @@ def wrapper(self):

return skip_dec

# Enables tests for scripting, instead of only tracing the model.
def enableScriptTest():
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = True
return func(self)
return wrapper
return script_dec


# skips tests for scripting.
def skipScriptTest(min_opset_version=float("inf")):
# Disable tests for scripting.
def disableScriptTest():
def script_dec(func):
def wrapper(self):
self.is_script_test_enabled = self.opset_version >= min_opset_version
self.is_script_test_enabled = False
return func(self)

return wrapper
Expand Down
13 changes: 5 additions & 8 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2573,14 +2573,11 @@ def forward(self, lstm_in):
bias=has_bias,
num_layers=num_layers,
)
lstm_in = (
[
torch.from_numpy(inputs),
torch.from_numpy(hx),
torch.from_numpy(hx),
]
+ [param.detach() for param in torch_lstm._flat_weights],
)
lstm_in = [
torch.from_numpy(inputs),
torch.from_numpy(hx),
torch.from_numpy(hx),
] + [param.detach() for param in torch_lstm._flat_weights]

self.run_model_test(
MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False
Expand Down
Loading

0 comments on commit 66b427a

Please sign in to comment.