expose group2ctx to module #8539
Changes from 2 commits
Changes to the `Module` class:

```diff
@@ -59,10 +59,12 @@ class Module(BaseModule):
     state_names : list of str
         states are similar to data and label, but not provided by data iterator.
         Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None):
         super(Module, self).__init__(logger=logger)

         if isinstance(context, ctx.Context):

@@ -73,6 +75,11 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list

+        if group2ctxs is None:
```
Review thread on the added `if group2ctxs is None:` check:

- Reviewer: This seems redundant since you already checked […]
- Author: Some user may use […]
- Reviewer: Is the same check in module not necessary then?
- Author: Right, removed.
```diff
+            group2ctxs = [None] * len(self._context)
+        assert len(group2ctxs) == len(self._context)
+        self._group2ctxs = group2ctxs
+
         self._symbol = symbol

         data_names = list(data_names) if data_names is not None else []

@@ -413,7 +420,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
                                                      for_training, inputs_need_grad,
                                                      shared_group, logger=self.logger,
                                                      fixed_param_names=self._fixed_param_names,
-                                                     grad_req=grad_req,
+                                                     grad_req=grad_req, group2ctxs=self._group2ctxs,
                                                      state_names=self._state_names)
         self._total_exec_bytes = self._exec_group._total_exec_bytes
         if shared_module is not None:
```
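For context, a minimal usage sketch of the new parameter, mirroring the test added in this PR (the symbol, group names, and devices are illustrative, and it assumes an MXNet build that includes this change):

```python
import mxnet as mx

# Assign parts of the graph to named ctx_groups.
with mx.AttrScope(ctx_group='dev1'):
    a = mx.symbol.Variable('a') * 2
with mx.AttrScope(ctx_group='dev2'):
    b = mx.symbol.Variable('b')
    c = a + b

shape = (2, 5)
# One dict per entry in `context`, mapping each ctx_group name to a device.
mod = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
                    group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
mod.bind(data_shapes=[('a', shape), ('b', shape)])
mod.init_params()
mod.forward(mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=False)
print(mod.get_outputs()[0].asnumpy())
```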
Changes to the module unit tests:

```diff
@@ -70,6 +70,63 @@ def test_module_input_grads():
     assert np.all(c_grad == 3), c_grad


+def test_module_ctx_group():
+    with mx.AttrScope(ctx_group='dev1'):
+        a = mx.symbol.Variable('a')
+        a = a * 2
+    with mx.AttrScope(ctx_group='dev2'):
+        b = mx.symbol.Variable('b')
+        c = a + b
+    shape = (2, 5)
+    mod1 = mx.mod.Module(c, context=[mx.cpu(0)], data_names=['a', 'b'], label_names=None,
+                         group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
+    mod1.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod1.init_params()
+    mod1.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod1.backward([mx.nd.ones(shape)])
+    mod1_input_grads = mod1.get_input_grads()
+
+    mod2 = mx.mod.Module(c, data_names=['a', 'b'], label_names=None)
+    mod2.bind(data_shapes=[['a', shape], ['b', shape]], inputs_need_grad=True)
+    mod2.init_params()
+    mod2.forward(data_batch=mx.io.DataBatch(data=[mx.nd.ones(shape), mx.nd.ones(shape)]), is_train=True)
+    mod2.backward([mx.nd.ones(shape)])
+    mod2_input_grads = mod2.get_input_grads()
+
+    assert np.all(mod1_input_grads[0].asnumpy() == mod2_input_grads[0].asnumpy())
+    assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())
+
+
+def test_bucket_module_ctx_group():
```
Review thread on `test_bucket_module_ctx_group`:

- Reviewer: What does this test against? Anything to assert?
- Author: Just to verify that the change in bucket module is OK, i.e. the symbol can successfully be bound to […]
- Reviewer: Yeah, just add […]

(A hedged sketch of one possible assertion follows the test diff below.)
```diff
+    num_hidden = 10
+    batch_size = 5
+    def sym_gen(seq_len):
+        with mx.AttrScope(ctx_group='dev1'):
+            data = mx.symbol.Variable('data')
+            weight = mx.symbol.Variable('dev1_weight')
+            bias = mx.symbol.Variable('dev1_bias')
+            fc = data
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev1_fc_%d' % i, num_hidden=num_hidden)
+        with mx.AttrScope(ctx_group='dev2'):
+            label = mx.symbol.Variable('label')
+            weight = mx.symbol.Variable('dev2_weight')
+            bias = mx.symbol.Variable('dev2_bias')
+            for i in range(seq_len):
+                fc = mx.symbol.FullyConnected(data=fc, weight=weight, bias=bias,
+                                              name='dev2_fc_%d' % i, num_hidden=num_hidden)
+            sym = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
+
+        return sym, ('data',), ('label',)
+
+    mod = mx.mod.BucketingModule(sym_gen=sym_gen, default_bucket_key=10, context=[mx.cpu(0)],
+                                 group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
+    mod.bind(data_shapes=[['data', (batch_size, num_hidden)]],
+             label_shapes=[['label', (batch_size,)]],
+             for_training=True, inputs_need_grad=True)
+
+
 def test_module_layout():
     sym = mx.sym.Variable('data')
     sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC')
```
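The reviewer's last comment in the thread above is cut off; presumably it asks the test to check more than a successful `bind`. A hypothetical addition (not part of this diff) could run one batch through the default bucket and check the output shape:

```python
# Hypothetical extension of test_bucket_module_ctx_group (not in this PR's diff):
# run a forward pass on the default bucket and assert the output shape.
mod.init_params()
batch = mx.io.DataBatch(data=[mx.nd.ones((batch_size, num_hidden))],
                        label=[mx.nd.zeros((batch_size,))],
                        bucket_key=10,
                        provide_data=[('data', (batch_size, num_hidden))],
                        provide_label=[('label', (batch_size,))])
mod.forward(batch, is_train=True)
assert mod.get_outputs()[0].shape == (batch_size, num_hidden)
```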
Review discussion on the API design:

- Reviewer: Hmmm. Is `Module::bind` a better place to put the `group2ctxs` arg instead of the constructor, since it's only used during binding?
- Author: I think it's better for `group2ctxs` to be in the same place as `context` and `symbol`, i.e. in the constructor. Otherwise, if a user uses `mod.fit`, they would have to pass `group2ctxs` to `fit`, which seems kind of weird.
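To illustrate the author's point: with `group2ctxs` set once in the constructor, the usual `fit()` workflow needs no extra argument at bind or fit time. A hypothetical sketch (the network and `train_iter` are placeholders, not from this PR):

```python
import mxnet as mx

# A toy network split across two ctx_groups.
with mx.AttrScope(ctx_group='dev1'):
    data = mx.symbol.Variable('data')
    fc = mx.symbol.FullyConnected(data, num_hidden=10)
with mx.AttrScope(ctx_group='dev2'):
    net = mx.symbol.SoftmaxOutput(fc, name='softmax')

# group2ctxs travels with the module, next to `context` and the symbol ...
mod = mx.mod.Module(net, context=[mx.cpu(0)],
                    group2ctxs=[{'dev1': mx.cpu(1), 'dev2': mx.cpu(2)}])
# ... so fit() needs no group2ctxs argument; train_iter is a placeholder DataIter.
# mod.fit(train_iter, num_epoch=1)
```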