-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy()) | ||
|
||
|
||
def test_bucket_module_ctx_group(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this test against? Anything to assert?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to verify that the change in bucket module is OK, i.e. the symbol can successfully be binded to group2ctxs
. Should I remove this test or assert self.binded
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah just add assert mod.binded
python/mxnet/module/module.py
Outdated
@@ -73,6 +75,11 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), | |||
assert len(work_load_list) == len(self._context) | |||
self._work_load_list = work_load_list | |||
|
|||
if group2ctxs is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems redundant since you already checked None
in DataExecutorGroup?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some user may use DataExecutorGroup
directly. And there are such cases in test_shared_exec_group
in test_module.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the same check in module not necessary then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, removed.
""" | ||
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), | ||
logger=logging, context=ctx.cpu(), work_load_list=None, | ||
fixed_param_names=None, state_names=None): | ||
fixed_param_names=None, state_names=None, group2ctxs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm. Is Module::bind a better place to put the group2ctxs
arg instead of the constructor, since it's only used during binding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better for group2ctxs
to be in the same place with contexts
and symbol
, i.e, in the constructor. Otherwise if an user use mod.fit
, he should pass group2ctxs
to fit
, I think this way seems kind of weird.
@eric-haibin-lin CI passed. Anything remaining to address? |
I think accepted arguments should be in the form of:
1 and 2 are equivalent. 3 and 4 are equivalent. For the case of 3, we transform |
@eric-haibin-lin Got it. Working on it now. |
* expose group2ctx to module * Update test_module.py * address comments * update
* expose group2ctx to module * Update test_module.py * address comments * update
Description
Expose group2ctx to module. Will be helpful for model parallel, e.g., some sparse operators on cpu and other dense operators on gpu.
As a feature requested in #8168.
cc @eric-haibin-lin for review. Also thanks for your valuable suggestions!
Checklist
Essentials
make lint
)Changes
Comments