Skip to content

Commit

Permalink
a user friendly way to use g2c in module and an example of g2c (apach…
Browse files Browse the repository at this point in the history
…e#8632)

* a user friendly way to use g2c in module

* also support g2c to be list

* update

* update test

* g2c example

* Update matrix_factorization_model_parallel.py

* address comments

* update

* update

* remove fc

* debug g2c

* Revert "debug g2c"

This reverts commit caabdc5.

* update

* move g2c example to another folder

* update

* readme
  • Loading branch information
ZiyueHuang authored and zhreshold committed Dec 14, 2017
1 parent af748b9 commit dfaa43b
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 33 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
56 changes: 56 additions & 0 deletions example/model-parallel/matrix_factorization/get_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import mxnet as mx


def get_movielens_data(prefix):
    """Fetch and unpack the MovieLens 10M archive unless it is already present.

    Parameters
    ----------
    prefix : str
        Base name of the archive. The existence of ``<prefix>.zip`` decides
        whether anything is done; when it is missing the archive is downloaded
        from files.grouplens.org, unzipped, and ``split_ratings.sh`` is run
        inside ``ml-10M100K``.

    Notes
    -----
    All work is delegated to external commands through ``os.system``; their
    exit statuses are not inspected.
    """
    if os.path.exists("%s.zip" % prefix):
        # Archive already on disk: nothing to download or unpack.
        return

    print("Dataset MovieLens 10M not present. Downloading now ...")
    shell_commands = (
        "wget http://files.grouplens.org/datasets/movielens/%s.zip" % prefix,
        "unzip %s.zip" % prefix,
        "cd ml-10M100K; sh split_ratings.sh; cd -;",
    )
    for command in shell_commands:
        os.system(command)

def get_movielens_iter(filename, batch_size):
    """Parse a MovieLens ratings file and wrap it in a prefetching data iterator.

    Not particularly fast: the text file is read line by line in Python.

    Parameters
    ----------
    filename : str
        Path of a ratings file whose records hold four ``::``-separated
        fields; the first three are taken as user, item and score. Lines
        with any other number of fields are skipped.
    batch_size : int
        Batch size handed to ``mx.io.NDArrayIter``.

    Returns
    -------
    mx.io.PrefetchingIter
        A single shuffled iterator with data ``user`` / ``item`` and label
        ``score``.
    """
    print("Preparing data iterators for " + filename + " ... ")
    users, items, scores = [], [], []
    with open(filename, 'r') as ratings_file:
        for record in ratings_file:
            fields = record.strip().split('::')
            if len(fields) == 4:
                users.append(fields[0])
                items.append(fields[1])
                scores.append(fields[2])

    # The collected fields are still strings here; they are handed to
    # mx.nd.array as-is. Only the user ids request an explicit int32 dtype.
    user_nd = mx.nd.array(users, dtype='int32')
    item_nd = mx.nd.array(items)
    score_nd = mx.nd.array(scores)

    # Wrap the arrays in a shuffled batch iterator, then in a prefetcher.
    batches = mx.io.NDArrayIter(data={'user': user_nd, 'item': item_nd},
                                label={'score': score_nd},
                                batch_size=batch_size, shuffle=True)
    return mx.io.PrefetchingIter(batches)
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

def matrix_fact_model_parallel_net(factor_size, num_hidden, max_user, max_item):
    """Build the matrix-factorization symbol, split over two context groups.

    Parameters
    ----------
    factor_size : int
        ``output_dim`` of the user and item embeddings.
    num_hidden : int
        ``num_hidden`` of the fully connected layer applied to each embedding.
    max_user : int
        ``input_dim`` of the user embedding.
    max_item : int
        ``input_dim`` of the item embedding.

    Returns
    -------
    mx.symbol.Symbol
        ``LinearRegressionOutput`` over the user/item inner product, with the
        ``score`` variable as label.
    """
    # Symbols created in this scope carry ctx_group='dev1'; they are bound to
    # whatever context 'dev1' maps to in group2ctxs.
    with mx.AttrScope(ctx_group='dev1'):
        # inputs
        user_ids = mx.symbol.Variable('user')
        item_ids = mx.symbol.Variable('item')
        # user feature lookup
        user_table = mx.symbol.Variable('user_weight')
        user_vec = mx.symbol.Embedding(data=user_ids, weight=user_table,
                                       input_dim=max_user, output_dim=factor_size)
        # item feature lookup
        item_table = mx.symbol.Variable('item_weight')
        item_vec = mx.symbol.Embedding(data=item_ids, weight=item_table,
                                       input_dim=max_item, output_dim=factor_size)

    # Symbols created in this scope carry ctx_group='dev2'; they are bound to
    # whatever context 'dev2' maps to in group2ctxs.
    with mx.AttrScope(ctx_group='dev2'):
        # non-linear transformation of user features
        user_act = mx.symbol.Activation(data=user_vec, act_type='relu')
        user_fc_w = mx.symbol.Variable('fc_user_weight')
        user_fc_b = mx.symbol.Variable('fc_user_bias')
        user_out = mx.symbol.FullyConnected(data=user_act, weight=user_fc_w,
                                            bias=user_fc_b, num_hidden=num_hidden)
        # non-linear transformation of item features
        item_act = mx.symbol.Activation(data=item_vec, act_type='relu')
        item_fc_w = mx.symbol.Variable('fc_item_weight')
        item_fc_b = mx.symbol.Variable('fc_item_bias')
        item_out = mx.symbol.FullyConnected(data=item_act, weight=item_fc_w,
                                            bias=item_fc_b, num_hidden=num_hidden)
        # predict by the inner product: elementwise product, then sum over axis 1
        product = user_out * item_out
        summed = mx.symbol.sum(data=product, axis=1)
        flat = mx.symbol.Flatten(data=summed)
        # label
        target = mx.symbol.Variable('score')
        # loss layer
        output = mx.symbol.LinearRegressionOutput(data=flat, label=target)
    return output
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import time
import mxnet as mx
import numpy as np
from get_data import get_movielens_iter, get_movielens_data
from matrix_fact_parallel_model import matrix_fact_model_parallel_net


logging.basicConfig(level=logging.DEBUG)

parser = argparse.ArgumentParser(
    description="Run model parallel version of matrix factorization",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-epoch', default=3, type=int,
                    help='number of epochs to train')
parser.add_argument('--batch-size', default=256, type=int,
                    help='number of examples per batch')
parser.add_argument('--print-every', default=100, type=int,
                    help='logging interval')
parser.add_argument('--factor-size', default=128, type=int,
                    help="the factor size of the embedding operation")
parser.add_argument('--num-gpus', default=2, type=int,
                    help="number of gpus to use")

# Dataset archive name, split files, and the embedding input sizes.
MOVIELENS = {
    'dataset': 'ml-10m',
    'train': './ml-10M100K/r1.train',
    'val': './ml-10M100K/r1.test',
    'max_user': 71569,
    'max_movie': 65135,
}

if __name__ == '__main__':
    # NOTE(review): logging.basicConfig is documented to do nothing once the
    # root logger already has handlers, and the module-level call above has
    # configured it by now -- this level/format presumably never takes
    # effect. Confirm which of the two configurations is intended.
    logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')

    # parse the command line
    args = parser.parse_args()
    logging.info(args)

    # prepare dataset and iterators
    get_movielens_data(MOVIELENS['dataset'])
    train_iter = get_movielens_iter(MOVIELENS['train'], args.batch_size)
    val_iter = get_movielens_iter(MOVIELENS['val'], args.batch_size)

    # construct the model; the factor size is also used as the hidden size
    net = matrix_fact_model_parallel_net(args.factor_size, args.factor_size,
                                         MOVIELENS['max_user'], MOVIELENS['max_movie'])

    # construct the module, mapping each ctx_group attribute to its context(s):
    # 'dev1' goes to the cpu, 'dev2' to one gpu per executor
    gpu_contexts = [mx.gpu(i) for i in range(args.num_gpus)]
    mod = mx.module.Module(symbol=net,
                           context=[mx.cpu()] * args.num_gpus,
                           data_names=['user', 'item'],
                           label_names=['score'],
                           group2ctxs={'dev1': mx.cpu(), 'dev2': gpu_contexts})

    # start training: SGD with momentum, Xavier initialization, MSE as the
    # metric, and a Speedometer callback for throughput logging
    mod.fit(train_iter,
            val_iter,
            eval_metric=mx.metric.create(['MSE']),
            num_epoch=args.num_epoch,
            optimizer='sgd',
            # the parameters for the optimizer constructor
            optimizer_params={
                'learning_rate': 0.1,
                'wd': 1e-4,
                'momentum': 0.9,
                'rescale_grad': 1.0/args.batch_size},
            # the initializer used to initialize the parameters
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            batch_end_callback=mx.callback.Speedometer(args.batch_size, args.print_every))
6 changes: 6 additions & 0 deletions example/model-parallel/matrix_factorization/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Model Parallel Matrix Factorization
==============

This example demonstrates the basic usage of `group2ctxs` in `Module`, which allows one part of the model to be trained on the CPU and the other part on GPUs.

- `python matrix_factorization_model_parallel.py --num-gpus 2`
3 changes: 2 additions & 1 deletion python/mxnet/module/bucketing_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class BucketingModule(BaseModule):
state_names : list of str
States are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by set_states()
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
compression_params : dict
Specifies type of gradient compression and additional arguments depending
Expand Down
37 changes: 32 additions & 5 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,35 @@ def _merge_multi_context(outputs, major_axis):
rets.append(tensors[0])
return rets

def _prepare_group2ctxs(group2ctxs, ctx_len):
"""Prepare the group2contexts, will duplicate the context
if some ctx_group map to only one context.
"""
if group2ctxs is None:
return [None] * ctx_len
elif isinstance(group2ctxs, list):
assert(len(group2ctxs) == ctx_len), "length of group2ctxs\
should be %d" % ctx_len
return group2ctxs
elif isinstance(group2ctxs, dict):
ret = [{}] * ctx_len
for k, v in group2ctxs.items():
ctxs = None
if isinstance(v, ctx.Context):
ctxs = [v] * ctx_len
else:
if len(v) == 1:
ctxs = v * ctx_len
else:
assert(len(v) == ctx_len), "length of group2ctxs[%s]\
should be %d or 1" % (k, ctx_len)
ctxs = v
for i in range(ctx_len):
ret[i][k] = ctxs[i]
return ret
else:
assert(False), "group2ctxs should be list of dict of str to context,\
or dict of str to context or list of context"

class DataParallelExecutorGroup(object):
"""A group of executors that lives on a group of devices.
Expand Down Expand Up @@ -139,7 +168,8 @@ class DataParallelExecutorGroup(object):
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
Expand All @@ -152,10 +182,7 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_
self.symbol = symbol
self.contexts = contexts
self.workload = workload
if group2ctxs is None:
group2ctxs = [None] * len(self.contexts)
assert len(group2ctxs) == len(self.contexts)
self.group2ctxs = group2ctxs
self.group2ctxs = _prepare_group2ctxs(group2ctxs, len(contexts))

self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class Module(BaseModule):
state_names : list of str
states are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by `set_states()`.
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
compression_params : dict
Specifies type of gradient compression and additional arguments depending
Expand Down
61 changes: 35 additions & 26 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,40 @@ def test_module_input_grads():


def test_module_ctx_group():
with mx.AttrScope(ctx_group='dev1'):
a = mx.symbol.Variable('a')
a = a * 2
with mx.AttrScope(ctx_group='dev2'):
b = mx.symbol.Variable('b')
c = a + b
shape = (2, 5)
mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod1.init_params()
mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod1.backward([mx.nd.ones(shape)])
mod1_input_grads = mod1.get_input_grads()

mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod2.init_params()
mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod2.backward([mx.nd.ones(shape)])
mod2_input_grads = mod2.get_input_grads()

assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())

def check_module_ctx_group(ctxs, group2ctxs):
with mx.AttrScope(ctx_group='dev1'):
a = mx.symbol.Variable('a')
a = a * 2
with mx.AttrScope(ctx_group='dev2'):
b = mx.symbol.Variable('b')
c = a + b
shape = (2, 5)
mod1 = mx.mod.Module(c, context=ctxs, data_names=['a', 'b'], label_names=None,
group2ctxs=group2ctxs)
mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod1.init_params()
mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod1.backward([mx.nd.ones(shape)])
mod1_input_grads = mod1.get_input_grads()

mod2 = mx.mod.Module(c, context=ctxs, data_names=['a', 'b'], label_names=None)
mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod2.init_params()
mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod2.backward([mx.nd.ones(shape)])
mod2_input_grads = mod2.get_input_grads()

assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())

check_module_ctx_group([mx.cpu(0)], {'dev1': mx.cpu(1), 'dev2': mx.cpu(2)})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)],
[{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)}, {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}])
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1': mx.cpu(2), 'dev2': mx.cpu(3)})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1': mx.cpu(2), 'dev2': [mx.cpu(3)]})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1':mx.cpu(2), 'dev2':[mx.cpu(3), mx.cpu(3)]})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)],
{'dev1':[mx.cpu(2), mx.cpu(2)], 'dev2':[mx.cpu(3), mx.cpu(3)]})

def test_bucket_module_ctx_group():
num_hidden = 10
Expand All @@ -121,7 +130,7 @@ def sym_gen(seq_len):
return sym, ('data',), ('label',)

mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
label_shapes=[['label', (batch_size,)]],
for_training=True, inputs_need_grad=True)
Expand Down

0 comments on commit dfaa43b

Please sign in to comment.