diff --git a/example/sparse/linear_regression.py b/example/sparse/linear_classification.py
similarity index 84%
rename from example/sparse/linear_regression.py
rename to example/sparse/linear_classification.py
index e7040d4b03e5..567568c6eb80 100644
--- a/example/sparse/linear_regression.py
+++ b/example/sparse/linear_classification.py
@@ -22,14 +22,14 @@
 import argparse
 import os
 
-parser = argparse.ArgumentParser(description="Run sparse linear regression " \
+parser = argparse.ArgumentParser(description="Run sparse linear classification " \
                                              "with distributed kvstore",
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument('--profiler', type=int, default=0,
                     help='whether to use profiler')
 parser.add_argument('--num-epoch', type=int, default=1,
                     help='number of epochs to train')
-parser.add_argument('--batch-size', type=int, default=512,
+parser.add_argument('--batch-size', type=int, default=8192,
                     help='number of examples per batch')
 parser.add_argument('--num-batch', type=int, default=99999999,
                     help='number of batches per epoch')
@@ -37,7 +37,7 @@
                     help='whether to use dummy iterator to exclude io cost')
 parser.add_argument('--kvstore', type=str, default='dist_sync',
                     help='what kvstore to use [local, dist_sync, etc]')
-parser.add_argument('--log-level', type=str, default='debug',
+parser.add_argument('--log-level', type=str, default='DEBUG',
                     help='logging level [debug, info, error]')
 parser.add_argument('--dataset', type=str, default='avazu',
                     help='what test dataset to use')
@@ -78,18 +78,18 @@ def next(self):
 datasets = { 'kdda' : kdda, 'avazu' : avazu }
 
 
-def regression_model(feature_dim):
-    initializer = mx.initializer.Normal()
+def linear_model(feature_dim):
     x = mx.symbol.Variable("data", stype='csr')
     norm_init = mx.initializer.Normal(sigma=0.01)
-    v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
-    embed = mx.symbol.dot(x, v)
+    weight = mx.symbol.Variable("weight", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
+    bias = mx.symbol.Variable("bias", shape=(1,), init=norm_init)
+    dot = mx.symbol.dot(x, weight)
+    pred = mx.symbol.broadcast_add(dot, bias)
     y = mx.symbol.Variable("softmax_label")
-    model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out")
+    model = mx.symbol.SoftmaxOutput(data=pred, label=y, name="out")
     return model
 
 
 if __name__ == '__main__':
-    # arg parser
     args = parser.parse_args()
     num_epoch = args.num_epoch
@@ -138,7 +138,7 @@ def regression_model(feature_dim):
         train_data = DummyIter(train_data)
 
     # model
-    model = regression_model(feature_dim)
+    model = linear_model(feature_dim)
 
     # module
     mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
@@ -148,11 +148,10 @@ def regression_model(feature_dim):
                            learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
     mod.init_optimizer(optimizer=sgd, kvstore=kv)
     # use accuracy as the metric
-    metric = mx.metric.create('MSE')
+    metric = mx.metric.create('Accuracy')
 
     # start profiler
     if profiler:
-        import random
         name = 'profile_output_' + str(num_worker) + '.json'
         mx.profiler.profiler_set_config(mode='all', filename=name)
         mx.profiler.profiler_set_state('run')
@@ -162,31 +161,22 @@ def regression_model(feature_dim):
     data_iter = iter(train_data)
     for epoch in range(num_epoch):
         nbatch = 0
-        end_of_batch = False
         data_iter.reset()
         metric.reset()
-        next_batch = next(data_iter)
-        while not end_of_batch:
+        for batch in data_iter:
             nbatch += 1
-            batch = next_batch
-            # TODO(haibin) remove extra copy after Jun's change
-            row_ids = batch.data[0].indices.copyto(mx.cpu())
+            row_ids = batch.data[0].indices
             # pull sparse weight
-            index = mod._exec_group.param_names.index('v')
-            kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
+            index = mod._exec_group.param_names.index('weight')
+            kv.row_sparse_pull('weight', mod._exec_group.param_arrays[index],
                                priority=-index, row_ids=[row_ids])
             mod.forward_backward(batch)
             # update parameters
             mod.update()
-            try:
-                # pre fetch next batch
-                next_batch = next(data_iter)
-                if nbatch == num_batch:
-                    raise StopIteration
-            except StopIteration:
-                end_of_batch = True
             # accumulate prediction accuracy
             mod.update_metric(metric, batch.label)
+            if nbatch == num_batch:
+                break
         logging.info('epoch %d, %s' % (epoch, metric.get()))
     if profiler:
         mx.profiler.profiler_set_state('stop')
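A minimal sketch of the row_ids idea behind the change above, assuming a recent MXNet sparse API (NDArray.tostype and the CSRNDArray indices property): the column indices of a CSR batch are exactly the rows of the row_sparse weight that mx.symbol.dot(x, weight) reads, which is why batch.data[0].indices can be passed straight to kv.row_sparse_pull once the extra copyto is no longer needed.

    import mxnet as mx
    import numpy as np

    # Two examples over six features; only features 1 and 4 are active.
    feature_dim = 6
    dense = np.zeros((2, feature_dim))
    dense[0, 1] = 1.0
    dense[1, 4] = 2.0
    x = mx.nd.array(dense).tostype('csr')

    # The CSR column indices name the only weight rows dot(x, weight)
    # touches, so they are valid row_ids for a row_sparse pull.
    print(x.indices.asnumpy())  # -> [1 4]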