From 77e8fc1992d29a1a129fc8134b53fa120501954b Mon Sep 17 00:00:00 2001
From: Hao Li
Date: Mon, 13 Aug 2018 07:21:30 +0800
Subject: [PATCH] Fix precision issue of test case test_rnnrelu_bidirectional
 (#12099)

* adjust tolerance only for relu for fixing test case bug

* only adjust tolerance for test_rnnrelu_bidirectional and adjust back on
  test_rnnrelu_sym
---
 tests/python/unittest/test_operator.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 54eb0fd94a22..c395199e8ea4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -30,7 +30,7 @@ from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises
 import unittest
 
 
-def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
+def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
     dshape = (N, T, I)
     data = mx.sym.Variable('data')
 
@@ -53,18 +53,18 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
     # check inference
     mod1.forward(batch, is_train=False)
     mod2.forward(batch, is_train=False)
-    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)
 
     # check training
     mod1.forward(batch, is_train=True)
     mod2.forward(batch, is_train=True)
-    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)
 
     dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape)
     mod1.backward(out_grads=[dy])
     mod2.backward(out_grads=[dy])
     if grad_req != 'null':
-        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=rtol, atol=atol)
     else:
         assert(mod1.get_input_grads()[0] == None)
         assert(mod2.get_input_grads()[0] == None)
@@ -195,9 +195,8 @@ def test_rnnrelu_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'add')
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
-
-@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11410")
 @with_seed()
+@assert_raises_cudnn_disabled()
 def test_rnnrelu_bidirectional():
     T, N, I, H = 5, 20, 200, 200
 
@@ -214,9 +213,9 @@ def test_rnnrelu_bidirectional():
                 mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
                 output_prefix='bi_rnnrelu_1_'))
 
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2)
 
 @with_seed()
 def test_lstm_dropout():
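
Editorial note, not part of the patch: the rtol/atol arguments added to check_rnn_consistency are forwarded to numpy.testing.assert_allclose, whose element-wise criterion is |actual - desired| <= atol + rtol * |desired|. For outputs close to zero the relative term contributes almost nothing, which is presumably why the bidirectional RNN-ReLU comparison needs the looser atol=1e-2 while test_rnnrelu_sym keeps the default atol=1e-4. A minimal standalone sketch of the difference; the numbers are made up for illustration and are not taken from the test:

# Editorial illustration only -- not part of the patch above.
# numpy.testing.assert_allclose passes element-wise when
#   |actual - desired| <= atol + rtol * |desired|
# so for values near zero the absolute tolerance (atol) dominates.
import numpy as np
from numpy.testing import assert_allclose

desired = np.array([0.0, 100.0])
actual = desired + 5e-3   # small absolute drift, e.g. fused vs. stacked RNN outputs

# With the tighter atol=1e-4 the near-zero element is rejected:
# 5e-3 > 1e-4 + 1e-2 * 0.0
try:
    assert_allclose(actual, desired, rtol=1e-2, atol=1e-4)
except AssertionError:
    print("rejected with atol=1e-4")

# With the relaxed atol=1e-2 used by test_rnnrelu_bidirectional it passes:
# 5e-3 <= 1e-2 + 1e-2 * 0.0
assert_allclose(actual, desired, rtol=1e-2, atol=1e-2)
print("accepted with atol=1e-2")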