From b73edd93dc9617666f557c07f77e3979c0a90d26 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Sat, 4 Nov 2017 16:30:25 +0800
Subject: [PATCH 1/4] expose group2ctx to module

---
 python/mxnet/module/bucketing_module.py |  9 ++--
 python/mxnet/module/executor_group.py   | 12 +++++-
 python/mxnet/module/module.py           | 11 ++++-
 tests/python/unittest/test_module.py    | 57 +++++++++++++++++++++++++
 4 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index f3c7ecbddc05..dd6cafb277f0 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
     state_names : list of str
        States are similar to data and label, but not provided by data iterator.
        Instead they are initialized to 0 and can be set by set_states()
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
    """
    def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                 context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
        super(BucketingModule, self).__init__(logger=logger)

        assert default_bucket_key is not None
@@ -77,6 +79,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
        self._state_names = state_names
        self._context = context
        self._work_load_list = work_load_list
+        self._group2ctxs = group2ctxs

        self._buckets = {}
        self._curr_module = None
@@ -319,7 +322,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
            module = Module(symbol, data_names, label_names, logger=self.logger,
                            context=self._context, work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
            module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                        force_rebind=False, shared_module=None, grad_req=grad_req)
            self._curr_module = module
@@ -349,7 +352,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
                            logger=self.logger, context=self._context,
                            work_load_list=self._work_load_list,
                            fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
            module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                        self._curr_module.inputs_need_grad,
                        force_rebind=False, shared_module=self._buckets[self._default_bucket_key])

diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 0f3c079f8fcb..ea7651b65d93 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -139,10 +139,12 @@ class DataParallelExecutorGroup(object):
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (default to 'write').
        Can be specified globally (str) or for each argument (list, dict).
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
""" def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names, for_training, inputs_need_grad, shared_group=None, logger=logging, - fixed_param_names=None, grad_req='write', state_names=None): + fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None): self.param_names = param_names self.arg_names = symbol.list_arguments() self.aux_names = symbol.list_auxiliary_states() @@ -150,6 +152,10 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_ self.symbol = symbol self.contexts = contexts self.workload = workload + if group2ctxs is None: + group2ctxs = [None] * len(self.contexts) + assert len(group2ctxs) == len(self.contexts) + self.group2ctxs = group2ctxs self.for_training = for_training self.inputs_need_grad = inputs_need_grad @@ -597,9 +603,11 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group): if label_shapes is not None: input_types.update({x.name: x.dtype for x in label_shapes}) + group2ctx = self.group2ctxs[i] + executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req, type_dict=input_types, shared_arg_names=self.param_names, - shared_exec=shared_exec, + shared_exec=shared_exec, group2ctx=group2ctx, shared_buffer=shared_data_arrays, **input_shapes) self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1]) return executor diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 4c20a6fed542..b90892530058 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -59,10 +59,12 @@ class Module(BaseModule): state_names : list of str states are similar to data and label, but not provided by data iterator. Instead they are initialized to 0 and can be set by `set_states()`. + group2ctxs : list of dict of str to context + Default is `None`. Mapping the `ctx_group` attribute to the context assignment. 
""" def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), logger=logging, context=ctx.cpu(), work_load_list=None, - fixed_param_names=None, state_names=None): + fixed_param_names=None, state_names=None, group2ctxs=None): super(Module, self).__init__(logger=logger) if isinstance(context, ctx.Context): @@ -73,6 +75,11 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), assert len(work_load_list) == len(self._context) self._work_load_list = work_load_list + if group2ctxs is None: + group2ctxs = [None] * len(self._context) + assert len(group2ctxs) == len(self._context) + self._group2ctxs = group2ctxs + self._symbol = symbol data_names = list(data_names) if data_names is not None else [] @@ -413,7 +420,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, for_training, inputs_need_grad, shared_group, logger=self.logger, fixed_param_names=self._fixed_param_names, - grad_req=grad_req, + grad_req=grad_req, group2ctxs=self._group2ctxs, state_names=self._state_names) self._total_exec_bytes = self._exec_group._total_exec_bytes if shared_module is not None: diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index d79657a85af5..6d32773f12cd 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -70,6 +70,63 @@ def test_module_input_grads(): assert np.all(c_grad == 3), c_grad +def test_module_ctx_group(): + with mx.AttrScope(ctx_group='dev1'): + a = mx.symbol.Variable('a') + a = a * 2 + with mx.AttrScope(ctx_group='dev2'): + b = mx.symbol.Variable('b') + c = a + b + shape = (2, 5) + mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None, + group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}]) + mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True) + mod1.init_params() + mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True) + mod1.backward([mx.nd.ones(shape)]) + mod1_input_grads = mod1.get_input_grads() + + mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None) + mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True) + mod2.init_params() + mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True) + mod2.backward([mx.nd.ones(shape)]) + mod2_input_grads = mod2.get_input_grads() + + assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy()) + assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy()) + + +def test_bucket_module_ctx_group(): + num_hidden = 10 + batch_size = 5 + def sym_gen(seq_len): + with mx.AttrScope(ctx_group='dev1'): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('dev1_weight') + bias = mx.symbol.Variable('dev1_bias') + for i in range(seq_len): + fc = mx.symbol.FullyConnected(data=data, weight=weight, bias=bias, + name='dev1_fc_%d' % i, num_hidden=num_hidden) + + with mx.AttrScope(ctx_group='dev2'): + label = mx.symbol.Variable('label') + weight = mx.symbol.Variable('dev2_weight') + bias = mx.symbol.Variable('dev2_bias') + for i in range(seq_len): + fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias, + name='dev2_fc_%d' % i, num_hidden=num_hidden) + sym = mx.sym.SoftmaxOutput(fc, label, name='softmax') + + return sym, ('data',), ('label',) + + mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)], + group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}]) + 
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
+
+
 def test_module_layout():
    sym = mx.sym.Variable('data')
    sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')

From 96e208f85b64e9b3440ccc78f0cd9978d7b59cc6 Mon Sep 17 00:00:00 2001
From: Ziyue Huang
Date: Sat, 4 Nov 2017 04:28:08 -0500
Subject: [PATCH 2/4] Update test_module.py

---
 tests/python/unittest/test_module.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 6d32773f12cd..1a9b25e7a9c6 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -105,10 +105,10 @@ def sym_gen(seq_len):
            data = mx.symbol.Variable('data')
            weight = mx.symbol.Variable('dev1_weight')
            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
            for i in range(seq_len):
-                fc = mx.symbol.FullyConnected(data=data, weight=weight, bias=bias,
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
                                              name='dev1_fc_%d' % i, num_hidden=num_hidden)
-
        with mx.AttrScope(ctx_group='dev2'):
            label = mx.symbol.Variable('label')
            weight = mx.symbol.Variable('dev2_weight')
@@ -116,7 +116,7 @@ def sym_gen(seq_len):
            for i in range(seq_len):
                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
                                              name='dev2_fc_%d' % i, num_hidden=num_hidden)
-            sym = mx.sym.SoftmaxOutput(fc, label, name='softmax')
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')

        return sym, ('data',), ('label',)

From 258948292ef91221fc17eaeddd3b2e78e4a38683 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Thu, 9 Nov 2017 22:34:59 +0800
Subject: [PATCH 3/4] address comments

---
 python/mxnet/module/module.py        | 4 +---
 tests/python/unittest/test_module.py | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index b90892530058..2ea0eacd21c9 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -75,9 +75,7 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
        assert len(work_load_list) == len(self._context)
        self._work_load_list = work_load_list

-        if group2ctxs is None:
-            group2ctxs = [None] * len(self._context)
-        assert len(group2ctxs) == len(self._context)
+        # length of group2ctxs should be equal with length of context
        self._group2ctxs = group2ctxs

        self._symbol = symbol

diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index ea0ecd136cd1..8f63d8c9dbcc 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -125,7 +125,7 @@ def sym_gen(seq_len):
    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
             label_shapes=[['label', (batch_size,)]],
             for_training=True, inputs_need_grad=True)
-
+    assert(mod.binded)

 def test_module_layout():
    sym = mx.sym.Variable('data')

From c059e7240ff73aefcf4da04fb5c66a85ef13e469 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Thu, 9 Nov 2017 22:50:21 +0800
Subject: [PATCH 4/4] update

---
 python/mxnet/module/module.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 2ea0eacd21c9..8301330313ae 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -75,7 +75,6 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
        assert len(work_load_list) == len(self._context)
        self._work_load_list = work_load_list

-        # length of group2ctxs should be equal with length of context
        self._group2ctxs = group2ctxs

        self._symbol = symbol
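
A minimal usage sketch of the `group2ctxs` argument added by this series, mirroring the new `test_module_ctx_group` unit test (illustration only, not part of the patches; the CPU contexts stand in for arbitrary devices):

    import mxnet as mx

    # Tag operators with ctx_group attributes while building the symbol.
    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        b = mx.symbol.Variable('b')
        c = a + b

    shape = (2, 5)
    # One dict per entry in `context`: ops tagged 'dev1' are placed on cpu(1),
    # ops tagged 'dev2' on cpu(2); untagged ops stay on the default cpu(0).
    mod = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
                        group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
    mod.bind(data_shapes=[('a', shape), ('b', shape)], inputs_need_grad=True)
    mod.init_params()
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
    mod.backward([mx.nd.ones(shape)])
    grad_a = mod.get_input_grads()[0]  # gradient w.r.t. 'a'; all 2s for c = 2*a + b

With several devices in `context` (data parallelism), `group2ctxs` should hold one such dict per device, which is what the `len(group2ctxs) == len(self.contexts)` assertion in `DataParallelExecutorGroup` enforces.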