Skip to content

Commit

Permalink
update fm test (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored Jun 1, 2017
1 parent 965dfd7 commit cbfa792
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def mean_abs(x):
break
assert(mon_result_counts == [2, 2, 1, 6, 6, 4])

def test_fm_module():
def test_module_fm():
mx.random.seed(11)
rnd.seed(11)
def fm_model(k, feature_dim, storage_type='default'):
initializer = mx.initializer.Normal(sigma=0.01)
x = mx.symbol.Variable("data", storage_type=storage_type)
Expand Down Expand Up @@ -287,8 +289,9 @@ def fm_model(k, feature_dim, storage_type='default'):

num_batches = 8
batch_size = 25
scipy_data = scipy.sparse.rand(num_batches * batch_size, feature_dim,
density=0.5, format='csr')
import scipy.sparse as sp
scipy_data = sp.rand(num_batches * batch_size, feature_dim,
density=0.5, format='csr')
dns_label = mx.nd.ones((num_batches * batch_size,1))
csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices,
(num_batches * batch_size, feature_dim))
Expand All @@ -309,8 +312,7 @@ def fm_model(k, feature_dim, storage_type='default'):
# use accuracy as the metric
metric = mx.metric.create('MSE')
# train 5 epoch, i.e. going over the data iter one pass
# TODO(haibin) test with row_sparse instead
storage_type_dict = {'v' : 'default'}
storage_type_dict = {'v' : 'row_sparse'}

for epoch in range(10):
train_iter.reset()
Expand All @@ -320,7 +322,8 @@ def fm_model(k, feature_dim, storage_type='default'):
mod.update_metric(metric, batch.label) # accumulate prediction accuracy
mod.backward() # compute gradients
mod.update(storage_type_dict) # update parameters
print('Epoch %d, Training %s' % (epoch, metric.get()))
# print('Epoch %d, Training %s' % (epoch, metric.get()))
assert(metric.get()[1] < 0.2)

if __name__ == '__main__':
test_module_dtype()
Expand All @@ -331,4 +334,4 @@ def fm_model(k, feature_dim, storage_type='default'):
test_module_layout()
test_module_switch_bucket()
test_monitor()
test_fm_module()
test_module_fm()

0 comments on commit cbfa792

Please sign in to comment.