This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

expose group2ctx to module #8539

Merged (6 commits) on Nov 10, 2017
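This PR threads the `group2ctx` option of `Symbol.simple_bind` through `Module`, `BucketingModule`, and `DataParallelExecutorGroup` as a new `group2ctxs` constructor argument (one dict mapping `ctx_group` names to contexts per device context). A minimal sketch of the underlying mechanism being exposed, using CPU contexts so it runs anywhere:

```python
import mxnet as mx

# Tag sub-graphs with ctx_group attributes...
with mx.AttrScope(ctx_group='dev1'):
    a = mx.symbol.Variable('a') * 2
with mx.AttrScope(ctx_group='dev2'):
    b = mx.symbol.Variable('b')
c = a + b

# ...then assign each group to a device at bind time. Before this PR, this
# option was only reachable through simple_bind, not the Module API.
exe = c.simple_bind(ctx=mx.cpu(0),
                    group2ctx={'dev1': mx.cpu(1), 'dev2': mx.cpu(2)},
                    a=(2, 5), b=(2, 5))
exe.forward(a=mx.nd.ones((2, 5)), b=mx.nd.ones((2, 5)))
print(exe.outputs[0].asnumpy())  # 2 * 1 + 1 == 3 everywhere
```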
Changes from 2 commits

python/mxnet/module/bucketing_module.py (9 changes: 6 additions & 3 deletions)
@@ -52,10 +52,12 @@ class BucketingModule(BaseModule):
     state_names : list of str
         States are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by set_states()
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                  context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(BucketingModule, self).__init__(logger=logger)

         assert default_bucket_key is not None
@@ -77,6 +79,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
         self._state_names = state_names
         self._context = context
         self._work_load_list = work_load_list
+        self._group2ctxs = group2ctxs

         self._buckets = {}
         self._curr_module = None
@@ -319,7 +322,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
-                        state_names=self._state_names)
+                        state_names=self._state_names, group2ctxs=self._group2ctxs)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                     force_rebind=False, shared_module=None, grad_req=grad_req)
         self._curr_module = module
@@ -349,7 +352,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names)
+                            state_names=self._state_names, group2ctxs=self._group2ctxs)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
                         force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
python/mxnet/module/executor_group.py (12 changes: 10 additions & 2 deletions)
@@ -139,17 +139,23 @@ class DataParallelExecutorGroup(object):
         Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
         (default to 'write').
         Can be specified globally (str) or for each argument (list, dict).
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
                  for_training, inputs_need_grad, shared_group=None, logger=logging,
-                 fixed_param_names=None, grad_req='write', state_names=None):
+                 fixed_param_names=None, grad_req='write', state_names=None, group2ctxs=None):
         self.param_names = param_names
         self.arg_names = symbol.list_arguments()
         self.aux_names = symbol.list_auxiliary_states()

         self.symbol = symbol
         self.contexts = contexts
         self.workload = workload
+        if group2ctxs is None:
+            group2ctxs = [None] * len(self.contexts)
+        assert len(group2ctxs) == len(self.contexts)
+        self.group2ctxs = group2ctxs

         self.for_training = for_training
         self.inputs_need_grad = inputs_need_grad
@@ -597,9 +603,11 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
         if label_shapes is not None:
             input_types.update({x.name: x.dtype for x in label_shapes})

+        group2ctx = self.group2ctxs[i]
+
         executor = self.symbol.simple_bind(ctx=context, grad_req=self.grad_req,
                                            type_dict=input_types, shared_arg_names=self.param_names,
-                                           shared_exec=shared_exec,
+                                           shared_exec=shared_exec, group2ctx=group2ctx,
                                            shared_buffer=shared_data_arrays, **input_shapes)
         self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
         return executor
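In `_bind_ith_exec` above, `group2ctxs[i]` pairs with `contexts[i]`, so each data-parallel device gets its own group mapping. A standalone sketch of the normalization done in `__init__`, with an illustrative helper name that is not part of the PR:

```python
def normalize_group2ctxs(group2ctxs, contexts):
    """Return one ctx_group -> context dict (or None) per device context."""
    if group2ctxs is None:
        group2ctxs = [None] * len(contexts)
    # Exactly one mapping is required per data-parallel device context.
    assert len(group2ctxs) == len(contexts)
    return group2ctxs
```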
python/mxnet/module/module.py (11 changes: 9 additions & 2 deletions)
@@ -59,10 +59,12 @@ class Module(BaseModule):
     state_names : list of str
         states are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
Member commented:

    Hmmm. Is Module::bind a better place to put the group2ctxs arg instead of
    the constructor, since it's only used during binding?

@ZiyueHuang (Member Author) commented on Nov 8, 2017:

    I think it's better for group2ctxs to be in the same place as contexts and
    symbol, i.e., in the constructor. Otherwise, if a user uses mod.fit, they
    would have to pass group2ctxs to fit, which seems kind of weird.
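A sketch of the workflow the author describes, where `group2ctxs` is set once in the constructor so that `mod.fit` needs no extra argument. The layer size, 4-feature input, and synthetic iterator are illustrative, not from the PR:

```python
import mxnet as mx
import numpy as np

# Split the symbol into two context groups (same pattern as the tests below).
with mx.AttrScope(ctx_group='dev1'):
    data = mx.symbol.Variable('data')
    fc = mx.symbol.FullyConnected(data, num_hidden=10, name='fc')
with mx.AttrScope(ctx_group='dev2'):
    out = mx.symbol.SoftmaxOutput(fc, name='softmax')

# group2ctxs sits next to `context` in the constructor, so fit() is unchanged.
mod = mx.mod.Module(out, context=[mx.cpu(0)],
                    group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])

train_iter = mx.io.NDArrayIter(np.random.rand(20, 4).astype('float32'),
                               np.random.randint(0, 10, (20,)).astype('float32'),
                               batch_size=5)
mod.fit(train_iter, num_epoch=1)
```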

         super(Module, self).__init__(logger=logger)

         if isinstance(context, ctx.Context):
@@ -73,6 +75,11 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list

+        if group2ctxs is None:
Member commented:

    This seems redundant since you already checked None in DataExecutorGroup?

@ZiyueHuang (Member Author) commented:

    Some users may use DataExecutorGroup directly. And there are such cases in
    test_shared_exec_group in test_module.py.

Member commented:

    Is the same check in module not necessary then?

@ZiyueHuang (Member Author) commented:

    Right, removed.

+            group2ctxs = [None] * len(self._context)
+        assert len(group2ctxs) == len(self._context)
+        self._group2ctxs = group2ctxs
+
         self._symbol = symbol

         data_names = list(data_names) if data_names is not None else []
@@ -413,7 +420,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     grad_req=grad_req,
+                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                      state_names=self._state_names)
         self._total_exec_bytes = self._exec_group._total_exec_bytes
         if shared_module is not None:
tests/python/unittest/test_module.py (57 changes: 57 additions & 0 deletions)
@@ -70,6 +70,63 @@ def test_module_input_grads():
     assert np.all(c_grad == 3), c_grad


+def test_module_ctx_group():
+    with mx.AttrScope(ctx_group='dev1'):
+        a = mx.symbol.Variable('a')
+        a = a * 2
+    with mx.AttrScope(ctx_group='dev2'):
+        b = mx.symbol.Variable('b')
+        c = a + b
+    shape = (2, 5)
+    mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+                         group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
+    mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod1.init_params()
+    mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod1.backward([mx.nd.ones(shape)])
+    mod1_input_grads = mod1.get_input_grads()
+
+    mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+    mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod2.init_params()
+    mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod2.backward([mx.nd.ones(shape)])
+    mod2_input_grads = mod2.get_input_grads()
+
+    assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+    assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())


+def test_bucket_module_ctx_group():
Member commented:

    What does this test against? Anything to assert?

@ZiyueHuang (Member Author) commented on Nov 8, 2017:

    Just to verify that the change in the bucketing module is OK, i.e. that the
    symbol can successfully be bound with group2ctxs. Should I remove this test
    or assert self.binded?

Member commented:

    Yeah, just add assert mod.binded.
+    num_hidden = 10
+    batch_size = 5
+    def sym_gen(seq_len):
+        with mx.AttrScope(ctx_group='dev1'):
+            data = mx.symbol.Variable('data')
+            weight = mx.symbol.Variable('dev1_weight')
+            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev1_fc_%d' % i, num_hidden=num_hidden)
+        with mx.AttrScope(ctx_group='dev2'):
+            label = mx.symbol.Variable('label')
+            weight = mx.symbol.Variable('dev2_weight')
+            bias = mx.symbol.Variable('dev2_bias')
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev2_fc_%d' % i, num_hidden=num_hidden)
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+
+        return sym, ('data',), ('label',)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+                                 group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
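Per the thread above, the agreed-upon assertion would presumably be appended at the end of this test in a later commit of the PR; a sketch of that one-line check:

```python
assert mod.binded  # BaseModule sets self.binded once bind() succeeds
```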


 def test_module_layout():
     sym = mx.sym.Variable('data')
     sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')