forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_verify.py
129 lines (103 loc) · 3.93 KB
/
test_verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch.autograd import Function
from torch.nn import Module, Parameter
import caffe2.python.onnx.backend as backend
from verify import verify
from test_pytorch_common import TestCase, run_tests
import unittest
class TestVerify(TestCase):
    """Negative tests for ``verify``.

    Each test builds a small model that ``verify`` is expected to reject
    when run against the Caffe2 ONNX ``backend``.
    """

    maxDiff = None

    def assertVerifyExpectFail(self, *args, **kwargs):
        """Assert that ``verify(*args, **kwargs)`` raises a non-empty AssertionError.

        An AssertionError whose message is empty is re-raised, failing the
        test, instead of being accepted as the expected failure. If
        ``verify`` returns normally, the test fails.
        """
        try:
            verify(*args, **kwargs)
        except AssertionError as e:
            # Only the presence of a message is checked, not its text.
            # numpy keeps changing its error formatting, so pinning even a
            # prefix of the message (formerly via assertExpected on the first
            # 60 characters) required constant expect-file updates.
            if not str(e):
                raise
            return
        # Deliberately outside the try block: self.fail raises AssertionError,
        # which the handler above would otherwise catch.
        self.fail("verify() did not fail when expected to")

    def test_result_different(self):
        """Reject a Function whose exported op disagrees with its eager result."""

        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                # The exported graph claims to compute a + b ...
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                # ... but eager execution actually computes a - b.
                return a.sub(b)

        class MyModel(Module):
            def forward(self, x, y):
                # autograd Functions are invoked through the class itself;
                # instantiating them is deprecated.
                return BrokenAdd.apply(x, y)

        x = torch.tensor([1, 2])
        y = torch.tensor([3, 4])
        self.assertVerifyExpectFail(MyModel(), (x, y), backend)

    def test_jumbled_params(self):
        """Assigning a new Parameter inside forward() must raise "state_dict changed"."""

        class MyModel(Module):
            def forward(self, x):
                y = x * x
                # Side effect under test: the module's parameters depend on
                # whether (and how often) forward has run.
                self.param = Parameter(torch.tensor([2.0]))
                return y

        x = torch.tensor([1, 2])
        with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
            verify(MyModel(), x, backend)

    def test_modifying_params(self):
        """Reject a model whose forward() mutates its own parameter in place."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.param = Parameter(torch.tensor([2.0]))

            def forward(self, x):
                y = x * x
                # In-place update of the parameter on every call.
                self.param.data.add_(1.0)
                return y

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_dynamic_model_structure(self):
        """Reject a model that alternates between x * x and x + x on successive calls."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.iters = 0

            def forward(self, x):
                # Even-numbered calls square, odd-numbered calls double.
                if self.iters % 2 == 0:
                    r = x * x
                else:
                    r = x + x
                self.iters += 1
                return r

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    @unittest.skip("Indexing is broken by #3725")
    def test_embedded_constant_difference(self):
        """Reject a model whose index constant alternates between calls."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.iters = 0

            def forward(self, x):
                # Selects row 0 on even calls and row 1 on odd calls.
                r = x[self.iters % 2]
                self.iters += 1
                return r

        x = torch.tensor([[1, 2], [3, 4]])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_explicit_test_args(self):
        """Reject a data-dependent branch when checked against explicit ``test_args``."""

        class MyModel(Module):
            def forward(self, x):
                if x.data.sum() == 1.0:
                    return x + x
                else:
                    return x * x

        # x sums to 8, so it takes the x * x branch. y sums to 1, so eager
        # execution on y takes the y + y branch. A model built from x should
        # therefore disagree with eager results on y.
        x = torch.tensor([[6, 2]])
        y = torch.tensor([[2, -1]])
        self.assertVerifyExpectFail(MyModel(), x, backend, test_args=[(y,)])
if __name__ == "__main__":
    # Executed directly (not imported): hand off to the shared test runner.
    run_tests()