change sparse example from regression to classification (apache#165)
eric-haibin-lin authored Aug 15, 2017
1 parent 54f698b commit 6fa078e
Showing 1 changed file with 17 additions and 27 deletions.
@@ -22,22 +22,22 @@
import argparse
import os

-parser = argparse.ArgumentParser(description="Run sparse linear regression " \
+parser = argparse.ArgumentParser(description="Run sparse linear classification " \
                                              "with distributed kvstore",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--profiler', type=int, default=0,
                    help='whether to use profiler')
parser.add_argument('--num-epoch', type=int, default=1,
                    help='number of epochs to train')
-parser.add_argument('--batch-size', type=int, default=512,
+parser.add_argument('--batch-size', type=int, default=8192,
                    help='number of examples per batch')
parser.add_argument('--num-batch', type=int, default=99999999,
                    help='number of batches per epoch')
parser.add_argument('--dummy-iter', type=int, default=0,
                    help='whether to use dummy iterator to exclude io cost')
parser.add_argument('--kvstore', type=str, default='dist_sync',
                    help='what kvstore to use [local, dist_sync, etc]')
-parser.add_argument('--log-level', type=str, default='debug',
+parser.add_argument('--log-level', type=str, default='DEBUG',
                    help='logging level [debug, info, error]')
parser.add_argument('--dataset', type=str, default='avazu',
                    help='what test dataset to use')
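
Side note on the --log-level default changing from 'debug' to 'DEBUG': Python's logging module accepts level names as strings, but only upper-case ones, so the old default would raise a ValueError if handed straight to logging.basicConfig. A minimal illustration of that behavior (the format string here is illustrative, not taken from this script):

import logging

# 'debug' is rejected by logging's level check; 'DEBUG' is accepted
logging.basicConfig(level='DEBUG', format='%(asctime)-15s %(message)s')
logging.debug('visible at DEBUG level')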
@@ -78,18 +78,18 @@ def next(self):

datasets = { 'kdda' : kdda, 'avazu' : avazu }

-def regression_model(feature_dim):
-    initializer = mx.initializer.Normal()
+def linear_model(feature_dim):
    x = mx.symbol.Variable("data", stype='csr')
    norm_init = mx.initializer.Normal(sigma=0.01)
-    v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
-    embed = mx.symbol.dot(x, v)
+    weight = mx.symbol.Variable("weight", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
+    bias = mx.symbol.Variable("bias", shape=(1,), init=norm_init)
+    dot = mx.symbol.dot(x, weight)
+    pred = mx.symbol.broadcast_add(dot, bias)
    y = mx.symbol.Variable("softmax_label")
-    model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out")
+    model = mx.symbol.SoftmaxOutput(data=pred, label=y, name="out")
    return model

if __name__ == '__main__':
-
    # arg parser
    args = parser.parse_args()
    num_epoch = args.num_epoch
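
For orientation, the rewritten symbol above can be built and inspected on its own. A minimal sketch, assuming an MXNet build with sparse storage (stype) support; the feature_dim value is arbitrary:

import mxnet as mx

# linear_model as defined above in this file
net = linear_model(feature_dim=100)
# the argument list should include the csr 'data' input, the row_sparse
# 'weight', the dense 'bias', and the 'softmax_label' used by SoftmaxOutput
print(net.list_arguments())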
@@ -138,7 +138,7 @@ def regression_model(feature_dim):
        train_data = DummyIter(train_data)

    # model
-    model = regression_model(feature_dim)
+    model = linear_model(feature_dim)

    # module
    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
@@ -148,11 +148,10 @@ def regression_model(feature_dim):
                           learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
    mod.init_optimizer(optimizer=sgd, kvstore=kv)
    # use accuracy as the metric
-    metric = mx.metric.create('MSE')
+    metric = mx.metric.create('Accuracy')

    # start profiler
    if profiler:
-        import random
        name = 'profile_output_' + str(num_worker) + '.json'
        mx.profiler.profiler_set_config(mode='all', filename=name)
        mx.profiler.profiler_set_state('run')
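
The metric switch from 'MSE' to 'Accuracy' matches the move from a regression to a classification head. A standalone sketch of the metric API with made-up values:

import mxnet as mx

metric = mx.metric.create('Accuracy')
labels = [mx.nd.array([0, 1, 1])]
# one row of per-class scores per example; argmax gives the predicted class
preds = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])]
metric.update(labels, preds)
print(metric.get())  # ('accuracy', 0.666...): two of three predictions match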
@@ -162,31 +161,22 @@
    data_iter = iter(train_data)
    for epoch in range(num_epoch):
        nbatch = 0
-        end_of_batch = False
        data_iter.reset()
        metric.reset()
-        next_batch = next(data_iter)
-        while not end_of_batch:
+        for batch in data_iter:
            nbatch += 1
-            batch = next_batch
-            # TODO(haibin) remove extra copy after Jun's change
-            row_ids = batch.data[0].indices.copyto(mx.cpu())
+            row_ids = batch.data[0].indices
            # pull sparse weight
-            index = mod._exec_group.param_names.index('v')
-            kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
+            index = mod._exec_group.param_names.index('weight')
+            kv.row_sparse_pull('weight', mod._exec_group.param_arrays[index],
                               priority=-index, row_ids=[row_ids])
            mod.forward_backward(batch)
            # update parameters
            mod.update()
-            try:
-                # pre fetch next batch
-                next_batch = next(data_iter)
-                if nbatch == num_batch:
-                    raise StopIteration
-            except StopIteration:
-                end_of_batch = True
            # accumulate prediction accuracy
            mod.update_metric(metric, batch.label)
+            if nbatch == num_batch:
+                break
        logging.info('epoch %d, %s' % (epoch, metric.get()))
    if profiler:
        mx.profiler.profiler_set_state('stop')
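
The kv.row_sparse_pull call in the loop above fetches only the weight rows referenced by the current batch instead of the full (feature_dim, 1) parameter, which is the point of the row_sparse storage type. A standalone sketch of the primitive, assuming the MXNet 1.x KVStore API (string keys, out= and row_ids= keywords); the names here are illustrative:

import mxnet as mx

kv = mx.kv.create('local')
shape = (10, 2)
# register a row_sparse weight with the kvstore
kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

# pull only rows 0 and 3; all other rows of `out` remain absent
out = mx.nd.zeros(shape, stype='row_sparse')
kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([0, 3], dtype='int64'))
print(out.indices.asnumpy())  # [0 3]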
