From 85442a0f1dc545730666b1e0eac1e7500481ea4b Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Thu, 28 Jun 2018 11:44:12 +0800
Subject: [PATCH 1/2] fix test case bug

---
 tests/python/unittest/test_operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e07a602b8c18..4523bcce5c55 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -62,7 +62,7 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
     mod1.backward(out_grads=[dy])
     mod2.backward(out_grads=[dy])
     if grad_req != 'null':
-        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-2)
     else:
         assert(mod1.get_input_grads()[0] == None)
         assert(mod2.get_input_grads()[0] == None)

From 6b895a186ad543df47818a6de08b6118879821c0 Mon Sep 17 00:00:00 2001
From: "Li, Hao H"
Date: Thu, 28 Jun 2018 14:25:56 +0800
Subject: [PATCH 2/2] adjust tolerance only for relu

---
 tests/python/unittest/test_operator.py | 56 +++++++++++++-------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 4523bcce5c55..c23e45b60786 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -28,7 +28,7 @@ from common import setup_module, with_seed, teardown
 import unittest

-def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
+def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol, atol):
     dshape = (N, T, I)
     data = mx.sym.Variable('data')

@@ -51,18 +51,18 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
     # check inference
     mod1.forward(batch, is_train=False)
     mod2.forward(batch, is_train=False)
-    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)

     # check training
     mod1.forward(batch, is_train=True)
     mod2.forward(batch, is_train=True)
-    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=rtol, atol=atol)

     dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape)
     mod1.backward(out_grads=[dy])
     mod2.backward(out_grads=[dy])
     if grad_req != 'null':
-        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-2)
+        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=rtol, atol=atol)
     else:
         assert(mod1.get_input_grads()[0] == None)
         assert(mod2.get_input_grads()[0] == None)
@@ -78,9 +78,9 @@ def test_lstm_sym():
     stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_lstm_bidirectional():
@@ -98,9 +98,9 @@ def test_lstm_bidirectional():
                 mx.rnn.LSTMCell(H, prefix='r1_'),
                 output_prefix='bi_lstm_1_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_gru_sym():
@@ -111,9 +111,9 @@ def test_gru_sym():
     stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
     stack.add(mx.rnn.GRUCell(H, prefix='l2_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_gru_bidirectional():
@@ -133,9 +133,9 @@ def test_gru_bidirectional():
                 mx.rnn.GRUCell(H, prefix='r1_'),
                 output_prefix='bi_gru_1_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_rnntanh_sym():
@@ -147,9 +147,9 @@ def test_rnntanh_sym():
     stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'))
     stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_rnntanh_bidirectional():
@@ -168,9 +168,9 @@ def test_rnntanh_bidirectional():
                 mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'),
                 output_prefix='bi_rnntanh_1_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-4)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-4)

 @with_seed()
 def test_rnnrelu_sym():
@@ -182,9 +182,9 @@ def test_rnnrelu_sym():
     stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
     stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-2)

 @with_seed()
 def test_rnnrelu_bidirectional():
@@ -203,9 +203,9 @@ def test_rnnrelu_bidirectional():
                 mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
                 output_prefix='bi_rnnrelu_1_'))

-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write', 1e-2, 1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add', 1e-2, 1e-2)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null', 1e-2, 1e-2)

 @with_seed()
 def test_lstm_dropout():