diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 54eb0fd94a22..c395199e8ea4 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -30,7 +30,7 @@ from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises import unittest -def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req): +def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4): dshape = (N, T, I) data = mx.sym.Variable('data') @@ -53,18 +53,18 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req): # check inference mod1.forward(batch, is_train=False) mod2.forward(batch, is_train=False) - assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4) + assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol) # check training mod1.forward(batch, is_train=True) mod2.forward(batch, is_train=True) - assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4) + assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol) dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape) mod1.backward(out_grads=[dy]) mod2.backward(out_grads=[dy]) if grad_req != 'null': - assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4) + assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=rtol, atol=atol) else: assert(mod1.get_input_grads()[0] == None) assert(mod2.get_input_grads()[0] == None) @@ -195,9 +195,8 @@ def test_rnnrelu_sym(): check_rnn_consistency(fused, stack, T, N, I, H, 'add') check_rnn_consistency(fused, stack, T, N, I, H, 'null') - -@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11410") @with_seed() +@assert_raises_cudnn_disabled() def test_rnnrelu_bidirectional(): T, N, I, H = 5, 20, 200, 200 @@ -214,9 +213,9 @@ def test_rnnrelu_bidirectional(): mx.rnn.RNNCell(H, activation='relu', prefix='r1_'), output_prefix='bi_rnnrelu_1_')) - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2) + check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2) @with_seed() def test_lstm_dropout():