Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add test case for oneDNN RNN
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych committed Nov 6, 2020
1 parent 87b66a9 commit b502aad
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../unittest/'))
from common import with_seed
import itertools


def test_mkldnn_model():
Expand Down Expand Up @@ -724,6 +725,36 @@ def check_elemwise_add_training(stype):
for stype in stypes:
check_elemwise_add_training(stype)


@with_seed()
def test_rnn():
SEQ_LENGTH = [2**10, 2**5]
STATE_SIZE = [1, 2]
BATCH_SIZE = [4]
INPUT_SIZE = [4]
def batch_check(seq_length, state_size, batch_size, input_size):
modes_params = [('rnn_relu', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)),
('rnn_tanh', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size),)),
('gru', mx.np.random.normal(0, 1, ((input_size + state_size + 2)*state_size*3),))
]
for m, p in modes_params:
data = mx.np.random.normal(0, 1, (seq_length, batch_size, input_size))
state = mx.np.random.normal(0, 1, (1, batch_size, state_size))
data.attach_grad()
state.attach_grad()

with mx.autograd.record():
y = mx.npx.rnn(data=data, parameters=p, mode=m, \
state=state, state_size=state_size, num_layers=1)
assert y.shape == (seq_length, batch_size, state_size)
assert type(y[0]).__name__ == 'ndarray'
y.backward()
assert state.shape == (1, batch_size, state_size)
assert type(state[0]).__name__ == 'ndarray'

for sl, ss, bs, in_s in itertools.product(SEQ_LENGTH, STATE_SIZE, BATCH_SIZE, INPUT_SIZE):
batch_check(sl, ss, bs, in_s)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit b502aad

Please sign in to comment.