Skip to content

Commit

Permalink
Fix gelu in PyTorch frontend, tighten numerical checks (#5763)
Browse files Browse the repository at this point in the history
Previously, the PyTorch frontend approximated gelu with fastgelu.
To provide a more faithful conversion, we implement gelu instead.

We also tighten the numerical comparisons between PyTorch and
TVM-from-PyTorch to 1e-5. The object detection models need an
increased tolerance of 1e-4 to pass.

I had to throw in a few fixes for missing conversions
(probably due to working with very new PyTorch).

I must admit the GoogLeNet/NasNet test didn't run on my machine,
probably due to problems at my end.
  • Loading branch information
t-vi authored Jun 11, 2020
1 parent e2fb503 commit 672e6cd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
20 changes: 12 additions & 8 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,10 @@ def _impl(inputs, input_types):
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)

dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
if inputs[2] is not None: # dtype given
dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
else:
dtype = data.type_annotation.dtype

return _op.full(_expr.const(fill_value), shape, dtype=dtype)
return _impl
Expand Down Expand Up @@ -567,14 +570,13 @@ def _impl(inputs, input_types):

def _gelu():
def _impl(inputs, input_types):
import math
data = inputs[0]

def _pow3(x):
return x * x * x
return _expr.const(0.5) * data * (_expr.const(1.0) +
_op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
(data + _expr.const(0.044715) * _pow3(data))))
# gelu is data * normcdf(data)
# normcdf expressed as erf because we don't currently have that intrinsic
# note that there is also a fastgelu variant approximating normcdf
# with tanh and third order polynomials, but this is "true" gelu
return data * (_expr.const(0.5) +
_op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5))
return _impl

def _selu():
Expand Down Expand Up @@ -1839,6 +1841,7 @@ def _get_convert_map(prelude):
"aten::Int" : _int(),
"prim::NumToTensor" : _numtotensor(),
"prim::ImplicitTensorToNum" : _tensortonum(),
"aten::ScalarImplicit" : _tensortonum(),
"aten::constant_pad_nd" : _pad("constant"),
"aten::reflection_pad1d" : _pad("reflect"),
"aten::reflection_pad2d" : _pad("reflect"),
Expand Down Expand Up @@ -1877,6 +1880,7 @@ def _get_convert_map(prelude):
"aten::floor" : _unary("floor"),
"aten::round" : _unary("round"),
"aten::isfinite" : _unary("isfinite"),
"aten::isinf" : _unary("isinf"),
"aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
"aten::detach" : _identity(),
Expand Down
27 changes: 14 additions & 13 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):

def verify_model(model_name, input_data=[],
custom_convert_map={},
ctx_list=ctx_list()):
ctx_list=ctx_list(),
rtol=1e-5, atol=1e-5):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if isinstance(model_name, str):
Expand Down Expand Up @@ -190,7 +191,7 @@ def verify_model(model_name, input_data=[],

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output,
rtol=1e-3, atol=1e-3)
rtol=rtol, atol=atol)

del model_name
del baseline_model
Expand Down Expand Up @@ -1216,35 +1217,35 @@ def test_conv3d_transpose():
# Model tests
def test_resnet18():
torch.set_grad_enabled(False)
verify_model("resnet18")
verify_model("resnet18", atol=1e-4, rtol=1e-4)

def test_squeezenet1_0():
torch.set_grad_enabled(False)
verify_model("squeezenet1_0")
verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4)

def test_squeezenet1_1():
torch.set_grad_enabled(False)
verify_model("squeezenet1_1")
verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4)

def test_densenet121():
torch.set_grad_enabled(False)
verify_model("densenet121")
verify_model("densenet121", atol=1e-4, rtol=1e-4)

def test_inception_v3():
torch.set_grad_enabled(False)
verify_model("inception_v3")
verify_model("inception_v3", atol=1e-4, rtol=1e-4)

def test_googlenet():
torch.set_grad_enabled(False)
verify_model("googlenet")
verify_model("googlenet", atol=1e-4, rtol=1e-4)

def test_mnasnet0_5():
torch.set_grad_enabled(False)
verify_model("mnasnet0_5")
verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4)

def test_mobilenet_v2():
torch.set_grad_enabled(False)
verify_model("mobilenet_v2")
verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4)

"""
#TODO: Fix VGG and AlexNet issues (probably due to pooling)
Expand Down Expand Up @@ -1305,19 +1306,19 @@ def forward(self, inp):

inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]

verify_model(SegmentationModelWrapper(fcn.eval()), inp)
verify_model(SegmentationModelWrapper(fcn.eval()), inp, atol=1e-4, rtol=1e-4)

# depthwise + dilated covolution not supported on x86
# see https://github.com/apache/incubator-tvm/issues/4962
cuda_ctx = ("cuda", tvm.gpu(0))
if cuda_ctx[1].exist:
verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx], atol=1e-4, rtol=1e-4)


def test_3d_models():
input_shape = (1, 3, 4, 56, 56)
resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval()
verify_model(resnet3d, [torch.rand(input_shape)])
verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)


def verify_script_model(pt_model, ishapes):
Expand Down

0 comments on commit 672e6cd

Please sign in to comment.