Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

a user friendly way to use g2c in module and an example of g2c #8632

Merged
merged 19 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions example/model-parallel/matrix_factorization/get_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import mxnet as mx


def get_movielens_data(prefix):
    """Download and split the MovieLens 10M dataset if not already present.

    Parameters
    ----------
    prefix : str
        Basename of the dataset archive (e.g. 'ml-10m');
        '<prefix>.zip' is fetched from the GroupLens server into the
        current working directory and unpacked to 'ml-10M100K/'.

    Raises
    ------
    RuntimeError
        If the download, unzip, or ratings-split shell command fails.
    """
    if not os.path.exists("%s.zip" % prefix):
        print("Dataset MovieLens 10M not present. Downloading now ...")
        # Shelling out mirrors the original example, but each step is now
        # checked: previously a failed wget was ignored and the later
        # unzip/split steps failed silently with a confusing error.
        commands = [
            "wget http://files.grouplens.org/datasets/movielens/%s.zip" % prefix,
            "unzip %s.zip" % prefix,
            "cd ml-10M100K; sh split_ratings.sh; cd -;",
        ]
        for cmd in commands:
            if os.system(cmd) != 0:
                raise RuntimeError("command failed: %s" % cmd)

def get_movielens_iter(filename, batch_size):
    """Parse a MovieLens ratings text file and build a data iterator.

    Each line is expected in the form 'user::item::score::timestamp';
    malformed lines (not exactly 4 fields) are skipped.  Note: despite
    what the original docstring claimed, this returns a SINGLE iterator
    over the given file - call it once per split (train / validation).

    Parameters
    ----------
    filename : str
        Path to the '::'-separated ratings file.
    batch_size : int
        Number of samples per batch of the returned iterator.

    Returns
    -------
    mx.io.PrefetchingIter
        A shuffled, prefetching iterator with data names 'user'/'item'
        and label name 'score'.
    """
    print("Preparing data iterators for " + filename + " ... ")
    user = []
    item = []
    score = []
    with open(filename, 'r') as f:
        for line in f:
            tks = line.strip().split('::')
            if len(tks) != 4:
                continue
            user.append(tks[0])
            item.append(tks[1])
            score.append(tks[2])
    # convert to ndarrays; user ids are kept integral, item/score use the
    # default float32 dtype as the original example did
    user = mx.nd.array(user, dtype='int32')
    item = mx.nd.array(item)
    score = mx.nd.array(score)
    # prepare data iters
    data_train = {'user': user, 'item': item}
    label_train = {'score': score}
    iter_train = mx.io.NDArrayIter(data=data_train, label=label_train,
                                   batch_size=batch_size, shuffle=True)
    return mx.io.PrefetchingIter(iter_train)
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

def matrix_fact_model_parallel_net(factor_size, num_hidden, max_user, max_item):
    """Build a matrix-factorization symbol split across two context groups.

    The embedding lookups are tagged with ctx_group='dev1' and the
    fully-connected transformations with ctx_group='dev2'; the actual
    devices are supplied later through the module's `group2ctxs` mapping.

    Parameters
    ----------
    factor_size : int
        Output dimension of the user/item embeddings.
    num_hidden : int
        Number of hidden units of the fully-connected layers.
    max_user : int
        Input dimension (number of user ids) of the user embedding.
    max_item : int
        Input dimension (number of item ids) of the item embedding.

    Returns
    -------
    Symbol
        A LinearRegressionOutput symbol predicting the label 'score'
        from the inner product of the transformed user/item factors.
    """
    # set ctx_group attribute to 'dev1' for the symbols created in this scope,
    # the symbols will be bound to the context that 'dev1' map to in group2ctxs
    with mx.AttrScope(ctx_group='dev1'):
        # input
        user = mx.symbol.Variable('user')
        item = mx.symbol.Variable('item')
        # user feature lookup
        user_weight = mx.symbol.Variable('user_weight')
        user = mx.symbol.Embedding(data=user, weight=user_weight,
                                   input_dim=max_user, output_dim=factor_size)
        # item feature lookup
        item_weight = mx.symbol.Variable('item_weight')
        item = mx.symbol.Embedding(data=item, weight=item_weight,
                                   input_dim=max_item, output_dim=factor_size)
    # set ctx_group attribute to 'dev2' for the symbols created in this scope,
    # the symbols will be bound to the context that 'dev2' map to in group2ctxs
    with mx.AttrScope(ctx_group='dev2'):
        # non-linear transformation of user features
        user = mx.symbol.Activation(data=user, act_type='relu')
        fc_user_weight = mx.symbol.Variable('fc_user_weight')
        fc_user_bias = mx.symbol.Variable('fc_user_bias')
        user = mx.symbol.FullyConnected(data=user, weight=fc_user_weight, bias=fc_user_bias, num_hidden=num_hidden)
        # non-linear transformation of ITEM features (the original comment
        # said "user" here - copy/paste slip; this branch transforms `item`)
        item = mx.symbol.Activation(data=item, act_type='relu')
        fc_item_weight = mx.symbol.Variable('fc_item_weight')
        fc_item_bias = mx.symbol.Variable('fc_item_bias')
        item = mx.symbol.FullyConnected(data=item, weight=fc_item_weight, bias=fc_item_bias, num_hidden=num_hidden)
    # predict by the inner product, which is elementwise product and then sum
    pred = user * item
    pred = mx.symbol.sum(data=pred, axis=1)
    pred = mx.symbol.Flatten(data=pred)
    # label
    score = mx.symbol.Variable('score')
    # loss layer
    pred = mx.symbol.LinearRegressionOutput(data=pred, label=score)
    return pred
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import time
import mxnet as mx
import numpy as np
from get_data import get_movielens_iter, get_movielens_data
from matrix_fact_parallel_model import matrix_fact_model_parallel_net


logging.basicConfig(level=logging.DEBUG)

parser = argparse.ArgumentParser(description="Run model parallel version of matrix factorization",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-epoch', type=int, default=3,
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=256,
                    help='number of examples per batch')
parser.add_argument('--print-every', type=int, default=100,
                    help='logging interval')
parser.add_argument('--factor-size', type=int, default=128,
                    help="the factor size of the embedding operation")
parser.add_argument('--num-gpus', type=int, default=2,
                    help="number of gpus to use")

# MovieLens 10M metadata: file locations produced by split_ratings.sh and the
# id cardinalities used to size the embedding input dimensions.
MOVIELENS = {
    'dataset': 'ml-10m',
    'train': './ml-10M100K/r1.train',
    'val': './ml-10M100K/r1.test',
    'max_user': 71569,
    'max_movie': 65135,
}

if __name__ == '__main__':
    # NOTE(review): the original called logging.basicConfig(level=INFO,
    # format=head) here, but basicConfig is a no-op once the root logger
    # already has a handler (configured at module import above), so the
    # dead call and its unused `head` format string were removed.

    # arg parser
    args = parser.parse_args()
    logging.info(args)
    num_epoch = args.num_epoch
    batch_size = args.batch_size
    optimizer = 'sgd'
    factor_size = args.factor_size
    print_every = args.print_every
    num_gpus = args.num_gpus

    momentum = 0.9
    learning_rate = 0.1

    # prepare dataset and iterators
    max_user = MOVIELENS['max_user']
    max_movies = MOVIELENS['max_movie']
    get_movielens_data(MOVIELENS['dataset'])
    train_iter = get_movielens_iter(MOVIELENS['train'], batch_size)
    val_iter = get_movielens_iter(MOVIELENS['val'], batch_size)

    # construct the model
    net = matrix_fact_model_parallel_net(factor_size, factor_size, max_user, max_movies)

    # construct the module
    # map the ctx_group attribute to the context assignment
    group2ctxs = {'dev1': mx.cpu(), 'dev2': [mx.gpu(i) for i in range(num_gpus)]}
    mod = mx.module.Module(symbol=net, context=[mx.cpu()] * num_gpus, data_names=['user', 'item'],
                           label_names=['score'], group2ctxs=group2ctxs)

    # the initializer used to initialize the parameters
    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)

    # the parameters for the optimizer constructor
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': 1e-4,
        'momentum': momentum,
        'rescale_grad': 1.0 / batch_size}

    # use MSE as the metric
    metric = mx.metric.create(['MSE'])

    speedometer = mx.callback.Speedometer(batch_size, print_every)

    # start training
    mod.fit(train_iter,
            val_iter,
            eval_metric=metric,
            num_epoch=num_epoch,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            initializer=initializer,
            batch_end_callback=speedometer)
6 changes: 6 additions & 0 deletions example/model-parallel/matrix_factorization/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Model Parallel Matrix Factorization
==============

The example demonstrates the basic usage of `group2ctxs` in `Module`, which allows one part of the model to be trained on CPU while the other part is trained on GPU.

- `python matrix_factorization_model_parallel.py --num-gpus 2`
3 changes: 2 additions & 1 deletion python/mxnet/module/bucketing_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class BucketingModule(BaseModule):
state_names : list of str
States are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by set_states()
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
compression_params : dict
Specifies type of gradient compression and additional arguments depending
Expand Down
37 changes: 32 additions & 5 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,35 @@ def _merge_multi_context(outputs, major_axis):
rets.append(tensors[0])
return rets

def _prepare_group2ctxs(group2ctxs, ctx_len):
"""Prepare the group2contexts, will duplicate the context
if some ctx_group map to only one context.
"""
if group2ctxs is None:
return [None] * ctx_len
elif isinstance(group2ctxs, list):
assert(len(group2ctxs) == ctx_len), "length of group2ctxs\
should be %d" % ctx_len
return group2ctxs
elif isinstance(group2ctxs, dict):
ret = [{}] * ctx_len
for k, v in group2ctxs.items():
ctxs = None
if isinstance(v, ctx.Context):
ctxs = [v] * ctx_len
else:
if len(v) == 1:
ctxs = v * ctx_len
else:
assert(len(v) == ctx_len), "length of group2ctxs[%s]\
should be %d or 1" % (k, ctx_len)
ctxs = v
for i in range(ctx_len):
ret[i][k] = ctxs[i]
return ret
else:
assert(False), "group2ctxs should be list of dict of str to context,\
or dict of str to context or list of context"

class DataParallelExecutorGroup(object):
"""A group of executors that lives on a group of devices.
Expand Down Expand Up @@ -139,7 +168,8 @@ class DataParallelExecutorGroup(object):
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
"""
def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
Expand All @@ -152,10 +182,7 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_
self.symbol = symbol
self.contexts = contexts
self.workload = workload
if group2ctxs is None:
group2ctxs = [None] * len(self.contexts)
assert len(group2ctxs) == len(self.contexts)
self.group2ctxs = group2ctxs
self.group2ctxs = _prepare_group2ctxs(group2ctxs, len(contexts))

self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class Module(BaseModule):
state_names : list of str
states are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by `set_states()`.
group2ctxs : list of dict of str to context
group2ctxs : dict of str to context or list of context,
or list of dict of str to context
Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
compression_params : dict
Specifies type of gradient compression and additional arguments depending
Expand Down
61 changes: 35 additions & 26 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,40 @@ def test_module_input_grads():


def test_module_ctx_group():
with mx.AttrScope(ctx_group='dev1'):
a = mx.symbol.Variable('a')
a = a * 2
with mx.AttrScope(ctx_group='dev2'):
b = mx.symbol.Variable('b')
c = a + b
shape = (2, 5)
mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod1.init_params()
mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod1.backward([mx.nd.ones(shape)])
mod1_input_grads = mod1.get_input_grads()

mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod2.init_params()
mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod2.backward([mx.nd.ones(shape)])
mod2_input_grads = mod2.get_input_grads()

assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())

def check_module_ctx_group(ctxs, group2ctxs):
with mx.AttrScope(ctx_group='dev1'):
a = mx.symbol.Variable('a')
a = a * 2
with mx.AttrScope(ctx_group='dev2'):
b = mx.symbol.Variable('b')
c = a + b
shape = (2, 5)
mod1 = mx.mod.Module(c, context=ctxs, data_names=['a', 'b'], label_names=None,
group2ctxs=group2ctxs)
mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod1.init_params()
mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod1.backward([mx.nd.ones(shape)])
mod1_input_grads = mod1.get_input_grads()

mod2 = mx.mod.Module(c, context=ctxs, data_names=['a', 'b'], label_names=None)
mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
mod2.init_params()
mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
mod2.backward([mx.nd.ones(shape)])
mod2_input_grads = mod2.get_input_grads()

assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())

check_module_ctx_group([mx.cpu(0)], {'dev1': mx.cpu(1), 'dev2': mx.cpu(2)})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)],
[{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)}, {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}])
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1': mx.cpu(2), 'dev2': mx.cpu(3)})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1': mx.cpu(2), 'dev2': [mx.cpu(3)]})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)], {'dev1':mx.cpu(2), 'dev2':[mx.cpu(3), mx.cpu(3)]})
check_module_ctx_group([mx.cpu(0), mx.cpu(1)],
{'dev1':[mx.cpu(2), mx.cpu(2)], 'dev2':[mx.cpu(3), mx.cpu(3)]})

def test_bucket_module_ctx_group():
num_hidden = 10
Expand All @@ -121,7 +130,7 @@ def sym_gen(seq_len):
return sym, ('data',), ('label',)

mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
label_shapes=[['label', (batch_size,)]],
for_training=True, inputs_need_grad=True)
Expand Down