Skip to content

Commit

Permalink
cpu sparse embedding op (apache#8460)
Browse files Browse the repository at this point in the history
* cpu embedding draft

* clean up

* fix omp thread call

* add sparse embedding example

* check bound with single thread

* add note

* add comments

* add operator note

* support rsp weight sharing for bucketing

* improve workload balance in take add grad rsp kernel

* use MSHADOW_CINLINE for cpu kernel

* review comments. add unit test for shared rsp weight

* remove indexing op-inl.h

* Trigger

* Trigger
  • Loading branch information
eric-haibin-lin authored and Olivier committed Nov 9, 2017
1 parent 7d91d1d commit 8892e6d
Show file tree
Hide file tree
Showing 15 changed files with 926 additions and 114 deletions.
60 changes: 59 additions & 1 deletion example/sparse/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,28 @@
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import os, gzip
import sys
import mxnet as mx

class DummyIter(mx.io.DataIter):
    """Iterator that serves one cached batch forever; used for speed testing.

    The first batch produced by ``real_iter`` is captured at construction
    time and handed back on every ``next()`` call, so iteration never stops
    on its own.
    """

    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        # Mirror the wrapped iterator's data/label descriptions and batch size.
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size

        # Pull exactly one batch; an exhausted source leaves ``the_batch`` unset.
        source = iter(real_iter)
        try:
            self.the_batch = next(source)
        except StopIteration:
            pass

    def __iter__(self):
        return self

    def next(self):
        # Always the same cached batch; StopIteration is never raised here.
        return self.the_batch

def get_libsvm_data(data_dir, data_name, url):
if not os.path.isdir(data_dir):
Expand All @@ -31,3 +50,42 @@ def get_libsvm_data(data_dir, data_name, url):
os.system("bzip2 -d %r" % data_name + ".bz2")
print("Dataset " + data_name + " is now present.")
os.chdir("..")

def get_movielens_data(prefix):
    """Fetch and unpack the MovieLens archive unless ``<prefix>.zip`` exists.

    When the archive is missing this shells out to ``wget`` and ``unzip`` and
    then runs ``split_ratings.sh`` inside ``ml-10M100K``. The exit status of
    the shell commands is not inspected.
    """
    archive = "%s.zip" % prefix
    if os.path.exists(archive):
        # Archive already on disk: nothing to download or unpack.
        return
    print("Dataset MovieLens 10M not present. Downloading now ...")
    commands = (
        "wget http://files.grouplens.org/datasets/movielens/%s.zip" % prefix,
        "unzip %s.zip" % prefix,
        "cd ml-10M100K; sh split_ratings.sh; cd -;",
    )
    for command in commands:
        os.system(command)

def get_movielens_iter(filename, batch_size, dummy_iter):
    """Parse a MovieLens ratings file and return a prefetching data iterator.

    Not particularly fast: the text file is parsed line by line in Python
    before being loaded into NDArrays.

    Parameters
    ----------
    filename : str
        Path to a ratings file whose fields are separated by ``::``. Only
        lines with exactly four fields are used; the first three are read as
        user, item and score.
    batch_size : int
        Batch size of the returned iterator.
    dummy_iter : bool
        If true, stop reading once more than ``batch_size * 10`` samples have
        been collected and wrap the iterator in ``DummyIter``.

    Returns
    -------
    mx.io.PrefetchingIter
        A single iterator with ``'user'`` and ``'item'`` as data and
        ``'score'`` as label.
    """
    print("Preparing data iterators for " + filename + " ... ")
    user = []
    item = []
    score = []
    # ``file()`` only exists in Python 2; ``open()`` works on both 2 and 3.
    with open(filename) as f:
        num_samples = 0
        for line in f:
            tks = line.strip().split('::')
            if len(tks) != 4:
                # skip lines that do not have exactly four fields
                continue
            num_samples += 1
            # fields stay as strings here and are handed to mx.nd.array below
            user.append(tks[0])
            item.append(tks[1])
            score.append(tks[2])
            if dummy_iter and num_samples > batch_size * 10:
                # a small sample is enough when the batch is going to be replayed
                break
    # convert to ndarrays
    user = mx.nd.array(user, dtype='int32')
    item = mx.nd.array(item)
    score = mx.nd.array(score)
    # prepare data iters
    data_train = {'user': user, 'item': item}
    label_train = {'score': score}
    iter_train = mx.io.NDArrayIter(data=data_train, label=label_train,
                                   batch_size=batch_size, shuffle=True)
    iter_train = DummyIter(iter_train) if dummy_iter else iter_train
    return mx.io.PrefetchingIter(iter_train)
51 changes: 51 additions & 0 deletions example/sparse/matrix_fact_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

def matrix_fact_net(factor_size, num_hidden, max_user, max_item, sparse_embed=True):
    """Build the matrix-factorization symbol ending in a linear-regression loss.

    The ``user`` and ``item`` inputs are embedded into ``factor_size``
    dimensions (``contrib.SparseEmbedding`` with ``row_sparse`` weight
    variables when ``sparse_embed`` is true, ``Embedding`` otherwise). Each
    embedding goes through a ReLU and a fully connected layer of
    ``num_hidden`` units; the two results are multiplied elementwise, summed
    over axis 1, flattened and regressed against ``score``.
    """
    sym = mx.symbol

    # inputs
    user_ids = sym.Variable('user')
    item_ids = sym.Variable('item')
    rating = sym.Variable('score')

    # feature lookups (user first, then item, in both branches)
    if not sparse_embed:
        user_vec = sym.Embedding(data=user_ids, input_dim=max_user,
                                 output_dim=factor_size)
        item_vec = sym.Embedding(data=item_ids, input_dim=max_item,
                                 output_dim=factor_size)
    else:
        user_table = sym.Variable('user_weight', stype='row_sparse')
        user_vec = sym.contrib.SparseEmbedding(data=user_ids, weight=user_table,
                                               input_dim=max_user,
                                               output_dim=factor_size)
        item_table = sym.Variable('item_weight', stype='row_sparse')
        item_vec = sym.contrib.SparseEmbedding(data=item_ids, weight=item_table,
                                               input_dim=max_item,
                                               output_dim=factor_size)

    # non-linear transformation of each side: ReLU followed by a dense layer
    user_vec = sym.FullyConnected(data=sym.Activation(data=user_vec, act_type='relu'),
                                  num_hidden=num_hidden)
    item_vec = sym.FullyConnected(data=sym.Activation(data=item_vec, act_type='relu'),
                                  num_hidden=num_hidden)

    # inner product: elementwise product followed by a sum over axis 1
    product = user_vec * item_vec
    summed = sym.sum(data=product, axis=1)
    flat = sym.Flatten(data=summed)

    # loss layer
    return sym.LinearRegressionOutput(data=flat, label=rating)
111 changes: 111 additions & 0 deletions example/sparse/matrix_factorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import time
import mxnet as mx
import numpy as np
from get_data import get_movielens_iter, get_movielens_data
from matrix_fact_model import matrix_fact_net
# Configure the root logger at DEBUG level as soon as the module is loaded.
logging.basicConfig(level=logging.DEBUG)

# Command-line options; the formatter class adds each default to the --help text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Run matrix factorization with sparse embedding")
parser.add_argument('--num-epoch', default=3, type=int,
                    help='number of epochs to train')
parser.add_argument('--batch-size', default=128, type=int,
                    help='number of examples per batch')
parser.add_argument('--print-every', default=100, type=int,
                    help='logging frequency')
parser.add_argument('--factor-size', default=128, type=int,
                    help="the factor size of the embedding operation")
parser.add_argument('--use-dense', help="use the dense embedding operator",
                    action='store_true')
parser.add_argument('--dummy-iter', help="use the dummy data iterator for speed test",
                    action='store_true')

# Dataset name, train/validation file paths and the id bounds used as embedding sizes.
MOVIELENS = dict(
    dataset='ml-10m',
    train='./ml-10M100K/r1.train',
    val='./ml-10M100K/r1.test',
    max_user=71569,
    max_movie=65135,
)

if __name__ == '__main__':
    # Script entry point: parse options, prepare the MovieLens iterators, build
    # the matrix-factorization module and train it with a hand-written loop.
    #
    # NOTE(review): logging.basicConfig(level=logging.DEBUG) is already called at
    # module level above. Per the logging docs, basicConfig does nothing once the
    # root logger has handlers, so this INFO level / format may never take
    # effect -- confirm which configuration is intended.
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.INFO, format=head)

    # arg parser
    args = parser.parse_args()
    logging.info(args)
    num_epoch = args.num_epoch
    batch_size = args.batch_size
    optimizer = 'sgd'
    # sparse embedding is the default; --use-dense switches it off
    use_sparse = not args.use_dense
    factor_size = args.factor_size
    dummy_iter = args.dummy_iter
    print_every = args.print_every

    # optimizer settings and device are hard-coded rather than exposed as flags
    momentum = 0.9
    ctx = mx.cpu(0)
    learning_rate = 0.1

    # prepare dataset and iterators
    max_user = MOVIELENS['max_user']
    max_movies = MOVIELENS['max_movie']
    get_movielens_data(MOVIELENS['dataset'])
    train_iter = get_movielens_iter(MOVIELENS['train'], batch_size, dummy_iter)
    val_iter = get_movielens_iter(MOVIELENS['val'], batch_size, dummy_iter)

    # construct the model
    # factor_size is passed twice: as the embedding size and as num_hidden
    net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse)
    # NOTE(review): `a` is never read in the visible code
    a = time.time()

    # initialize the module
    mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'],
                           label_names=['score'])
    # bind, then initialize parameters, then the optimizer -- in that order
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
    optim = mx.optimizer.create(optimizer, learning_rate=learning_rate, momentum=momentum,
                                wd=1e-4, rescale_grad=1.0/batch_size)
    mod.init_optimizer(optimizer=optim)
    # use MSE as the metric
    metric = mx.metric.create(['MSE'])
    speedometer = mx.callback.Speedometer(batch_size, print_every)
    logging.info('Training started ...')
    for epoch in range(num_epoch):
        # restart the batch counter and the training metric for every epoch
        nbatch = 0
        metric.reset()
        for batch in train_iter:
            nbatch += 1
            mod.forward_backward(batch)
            # update all parameters
            mod.update()
            # update training metric
            mod.update_metric(metric, batch.label)
            # hand the current epoch/batch/metric to the Speedometer callback
            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                       eval_metric=metric, locals=locals())
            speedometer(speedometer_param)
        # evaluate metric on validation dataset
        # presumably score is a list of (name, value) pairs, so score[0][1] is
        # the MSE value -- TODO confirm against Module.score
        score = mod.score(val_iter, ['MSE'])
        logging.info('epoch %d, eval MSE = %s ' % (epoch, score[0][1]))
        # reset the iterator for next pass of data
        train_iter.reset()
        val_iter.reset()
    logging.info('Training completed.')
11 changes: 9 additions & 2 deletions example/sparse/readme.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Example
===========
This folder contains examples using the sparse feature in MXNet.
This folder contains examples using the sparse feature in MXNet. They are for demonstration purposes only.

## Linear Classification
## Linear Classification Using Sparse Matrix Multiplication

The example demonstrates the basic usage of the sparse feature in MXNet to speed up computation. It utilizes the sparse data loader, sparse operators and a sparse gradient updater to train a linear model on the [Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset.

Expand All @@ -12,3 +12,10 @@ Notes on Distributed Training:

- For distributed training, please use the `../../tools/launch.py` script to launch a cluster.
- For example, to run two workers and two servers with one machine, run `../../tools/launch.py -n 2 --launcher=local python linear_classification.py --kvstore=dist_async`

## Matrix Factorization Using Sparse Embedding

The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted from @leopd's recommender examples.

- `python matrix_factorization.py`
- To compare the training speed with the (dense) Embedding operator, run `python matrix_factorization.py --use-dense`
Loading

0 comments on commit 8892e6d

Please sign in to comment.