
Commit b5a8fa1
add int64 to arange. add checkpointing example
eric-haibin-lin committed Oct 20, 2017
1 parent b8d7bd5 commit b5a8fa1
Showing 3 changed files with 10 additions and 0 deletions.
7 changes: 7 additions & 0 deletions example/sparse/linear_classification.py
@@ -96,6 +96,7 @@
# get the sparse weight parameter
weight_index = mod._exec_group.param_names.index('weight')
weight_param = mod._exec_group.param_arrays[weight_index]
all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
speedometer = mx.callback.Speedometer(batch_size, 100)

logging.info('Training started ...')
@@ -118,9 +119,15 @@
        speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                   eval_metric=metric, locals=locals())
        speedometer(speedometer_param)
    # pull all rows before making a checkpoint
    if kv:
        kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
                           priority=-weight_index)
    # evaluate metric on validation dataset
    score = mod.score(eval_data, ['nll_loss'])
    logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
    save_optimizer_states = 'dist' not in kv.type
    mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states)
    # reset the iterator for next pass of data
    data_iter.reset()
logging.info('Training completed.')
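
For readers adapting this checkpointing example: with a row_sparse weight served from a kvstore, the executor's local copy only holds the rows pulled for recent batches, so every row id is pulled back before a complete checkpoint can be written, and optimizer states are only saved for a non-distributed kvstore since a dist kvstore keeps them on the servers. A minimal sketch of the same steps as a standalone helper, assuming mod, kv, weight_param, weight_index and num_features are defined as in the script above (the helper name is hypothetical):

import mxnet as mx

def checkpoint_row_sparse(mod, kv, weight_param, weight_index, num_features,
                          epoch, prefix='checkpoint'):
    # the executor only caches recently pulled rows of the row_sparse weight,
    # so pull every row id from the kvstore before saving
    all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
    if kv:
        kv.row_sparse_pull('weight', weight_param, row_ids=[all_row_ids],
                           priority=-weight_index)
    # a distributed kvstore keeps optimizer states on the servers,
    # so only save them when training with a local (or no) kvstore
    save_states = kv is None or 'dist' not in kv.type
    mod.save_checkpoint(prefix, epoch, save_optimizer_states=save_states)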
1 change: 1 addition & 0 deletions src/operator/tensor/init_op.h
@@ -93,6 +93,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.describe("Target data type.");
}
};
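
With kInt64 registered here, int64 becomes a valid target dtype for arange from the Python frontend (useful for row ids, as in the example above); a minimal usage sketch:

import mxnet as mx
import numpy as np

row_ids = mx.nd.arange(0, 10, dtype='int64')
assert row_ids.dtype == np.int64
assert (row_ids.asnumpy() == np.arange(0, 10)).all()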
2 changes: 2 additions & 0 deletions tests/python/unittest/test_ndarray.py
@@ -734,6 +734,8 @@ def test_output():
    assert_almost_equal(out.asnumpy(), zeros.asnumpy())
    mx.nd.full(shape, 2, out=out)
    assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)
    arange_out = mx.nd.arange(0, 20, dtype='int64')
    assert_almost_equal(arange_out.asnumpy(), np.arange(0, 20))

def test_ndarray_fluent():
    has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
