From b502aad21dfec06c2effb107bfe1a45e7dfed5d3 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 2 Nov 2020 18:31:31 +0800 Subject: [PATCH] Add test case for oneDNN RNN --- tests/python/mkl/test_mkldnn.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index cf2ca13c161a..2fafc7821b5e 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -31,6 +31,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../unittest/')) from common import with_seed +import itertools def test_mkldnn_model(): @@ -724,6 +725,36 @@ def check_elemwise_add_training(stype): for stype in stypes: check_elemwise_add_training(stype) + +@with_seed() +def test_rnn(): + SEQ_LENGTH = [2**10, 2**5] + STATE_SIZE = [1, 2] + BATCH_SIZE = [4] + INPUT_SIZE = [4] + def batch_check(seq_length, state_size, batch_size, input_size): + modes_params = [('rnn_relu', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)), + ('rnn_tanh', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)), + ('gru', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size*3),)) + ] + for m, p in modes_params: + data = mx.np.random.normal(0, 1, (seq_length, batch_size, input_size)) + state = mx.np.random.normal(0, 1, (1, batch_size, state_size)) + data.attach_grad() + state.attach_grad() + + with mx.autograd.record(): + y = mx.npx.rnn(data=data, parameters=p, mode=m, \ + state=state, state_size=state_size, num_layers=1) + assert y.shape == (seq_length, batch_size, state_size) + assert type(y[0]).__name__ == 'ndarray' + y.backward() + assert state.shape == (1, batch_size, state_size) + assert type(state[0]).__name__ == 'ndarray' + + for sl, ss, bs, in_s in itertools.product(SEQ_LENGTH, STATE_SIZE, BATCH_SIZE, INPUT_SIZE): + batch_check(sl, ss, bs, in_s) + if __name__ == '__main__': import nose nose.runmodule()