diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index f3c7ecbddc05..dd6cafb277f0 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
     state_names : list of str
         States are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by set_states()
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                  context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(BucketingModule, self).__init__(logger=logger)
 
         assert default_bucket_key is not None
@@ -77,6 +79,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
         self._state_names = state_names
         self._context = context
         self._work_load_list = work_load_list
+        self._group2ctxs = group2ctxs
 
         self._buckets = {}
         self._curr_module = None
@@ -319,7 +322,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
-                        state_names=self._state_names)
+                        state_names=self._state_names, group2ctxs=self._group2ctxs)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                     force_rebind=False, shared_module=None, grad_req=grad_req)
         self._curr_module = module
@@ -349,7 +352,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
                         force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 0f3c079f8fcb..ea7651b65d93 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -139,10 +139,12 @@ class DataParallelExecutorGroup(object):
         Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
         (default to 'write').
         Can be specified globally (str) or for each argument (list, dict).
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                  for_training, inputs_need_grad, shared_group=None, logger=logging,
-                 fixed_param_names=None, grad_req='write', state_names=None):
+                 fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
         self.param_names = param_names
         self.arg_names = symbol.list_arguments()
         self.aux_names = symbol.list_auxiliary_states()
@@ -150,6 +152,10 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_
         self.symbol = symbol
         self.contexts = contexts
         self.workload = workload
+        if group2ctxs is None:
+            group2ctxs = [None] * len(self.contexts)
+        assert len(group2ctxs) == len(self.contexts)
+        self.group2ctxs = group2ctxs
 
         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
         if label_shapes is not None:
             input_types.update({x.name: x.dtype for x in label_shapes})
 
+        group2ctx = self.group2ctxs[i]
+
         executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                            type_dict=input_types, shared_arg_names=self.param_names,
-                                           shared_exec=shared_exec,
+                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                            shared_buffer=shared_data_arrays, **input_shapes)
         self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
         return executor
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 4c20a6fed542..8301330313ae 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -59,10 +59,12 @@ class Module(BaseModule):
     state_names : list of str
         states are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(Module, self).__init__(logger=logger)
 
         if isinstance(context, ctx.Context):
@@ -73,6 +75,8 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list
 
+        self._group2ctxs = group2ctxs
+
         self._symbol = symbol
 
         data_names = list(data_names) if data_names is not None else []
@@ -413,7 +417,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     grad_req=grad_req,
+                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                      state_names=self._state_names)
         self._total_exec_bytes = self._exec_group._total_exec_bytes
         if shared_module is not None:
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 180d2ee05242..722ba9885c81 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -70,6 +70,63 @@ def test_module_input_grads():
     assert np.all(c_grad == 3), c_grad
 
+
+def test_module_ctx_group():
+    with mx.AttrScope(ctx_group='dev1'):
+        a = mx.symbol.Variable('a')
+        a = a * 2
+    with mx.AttrScope(ctx_group='dev2'):
+        b = mx.symbol.Variable('b')
+        c = a + b
+    shape = (2, 5)
+    mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+                         group2ctxs=[{'dev1':mx.cpu(1),'dev2':mx.cpu(2)}])
+    mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod1.init_params()
+    mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod1.backward([mx.nd.ones(shape)])
+    mod1_input_grads = mod1.get_input_grads()
+
+    mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+    mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod2.init_params()
+    mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod2.backward([mx.nd.ones(shape)])
+    mod2_input_grads = mod2.get_input_grads()
+
+    assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+    assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())
+
+
+def test_bucket_module_ctx_group():
+    num_hidden = 10
+    batch_size = 5
+    def sym_gen(seq_len):
+        with mx.AttrScope(ctx_group='dev1'):
+            data = mx.symbol.Variable('data')
+            weight = mx.symbol.Variable('dev1_weight')
+            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev1_fc_%d' % i, num_hidden=num_hidden)
+        with mx.AttrScope(ctx_group='dev2'):
+            label = mx.symbol.Variable('label')
+            weight = mx.symbol.Variable('dev2_weight')
+            bias = mx.symbol.Variable('dev2_bias')
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev2_fc_%d' % i, num_hidden=num_hidden)
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+
+        return sym, ('data',), ('label',)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+                                 group2ctxs=[{'dev1':mx.cpu(1), 'dev2':mx.cpu(2)}])
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
+    assert(mod.binded)
+
 
 def test_module_layout():
     sym = mx.sym.Variable('data')
     sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
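
Usage note (not part of the patch): `DataParallelExecutorGroup` asserts `len(group2ctxs) == len(self.contexts)`, so a data-parallel module needs one `ctx_group`-to-device mapping per entry in `context`, and the i-th mapping is forwarded to the i-th executor's `simple_bind(group2ctx=...)`. A minimal sketch, with CPU devices standing in for real GPUs and otherwise arbitrary names:

    import mxnet as mx

    # Tag parts of the graph with ctx_group attributes.
    with mx.AttrScope(ctx_group='dev1'):
        a = mx.symbol.Variable('a') * 2
    with mx.AttrScope(ctx_group='dev2'):
        c = a + mx.symbol.Variable('b')

    # Two data-parallel contexts, so group2ctxs carries two mappings;
    # operators tagged 'dev1'/'dev2' are placed on the mapped devices.
    mod = mx.mod.Module(c, context=[mx.cpu(0), mx.cpu(1)],
                        data_names=['a', 'b'], label_names=None,
                        group2ctxs=[{'dev1': mx.cpu(2), 'dev2': mx.cpu(3)},
                                    {'dev1': mx.cpu(4), 'dev2': mx.cpu(5)}])
    mod.bind(data_shapes=[('a', (4, 5)), ('b', (4, 5))])
    mod.init_params()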