Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

expose group2ctx to module #8539

Merged
merged 6 commits into from
Nov 10, 2017
Merged

expose group2ctx to module #8539

merged 6 commits into from
Nov 10, 2017

Conversation

ZiyueHuang
Copy link
Member

Description

Expose group2ctx to module. Will be helpful for model parallel, e.g., some sparse operators on cpu and other dense operators on gpu.

As a feature requested in #8168.

cc @eric-haibin-lin for review. Also thanks for your valuable suggestions!

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • add unittest

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Intersting edge cases to note here

@eric-haibin-lin eric-haibin-lin self-assigned this Nov 8, 2017
assert np.all(mod1_input_grads[1].asnumpy() == mod2_input_grads[1].asnumpy())


def test_bucket_module_ctx_group():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this test against? Anything to assert?

Copy link
Member Author

@ZiyueHuang ZiyueHuang Nov 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to verify that the change in bucket module is OK, i.e. the symbol can successfully be binded to group2ctxs. Should I remove this test or assert self.binded?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah just add assert mod.binded

@@ -73,6 +75,11 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
assert len(work_load_list) == len(self._context)
self._work_load_list = work_load_list

if group2ctxs is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems redundant since you already checked None in DataExecutorGroup?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some user may use DataExecutorGroup directly. And there are such cases in test_shared_exec_group in test_module.py.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the same check in module not necessary then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, removed.

"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None):
fixed_param_names=None, state_names=None, group2ctxs=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. Is Module::bind a better place to put the group2ctxs arg instead of the constructor, since it's only used during binding?

Copy link
Member Author

@ZiyueHuang ZiyueHuang Nov 8, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better for group2ctxs to be in the same place with contextsand symbol, i.e, in the constructor. Otherwise if an user use mod.fit, he should pass group2ctxs to fit, I think this way seems kind of weird.

@ZiyueHuang
Copy link
Member Author

@eric-haibin-lin CI passed. Anything remaining to address?

@piiswrong piiswrong merged commit c9bde1b into apache:master Nov 10, 2017
@eric-haibin-lin
Copy link
Member

I think accepted arguments should be in the form of:

  1. {'dev1':mx.cpu(), 'dev2':mx.gpu(0)}
  2. {'dev1':mx.cpu(), 'dev2':[mx.gpu(0)]}
  3. {'dev1':mx.cpu(), 'dev2':[mx.gpu(0), mx.gpu(1)]},
  4. {'dev1':[mx.cpu(), mx.cpu()], 'dev2':[mx.gpu(0), mx.gpu(1)]},

1 and 2 are equivalent. 3 and 4 are equivalent. For the case of 3, we transform g2x['dev1'] to [mx.cpu()] * 2 internally.

@ZiyueHuang
Copy link
Member Author

@eric-haibin-lin Got it. Working on it now.

eric-haibin-lin pushed a commit to eric-haibin-lin/mxnet that referenced this pull request Dec 3, 2017
* expose group2ctx to module

* Update test_module.py

* address comments

* update
@ZiyueHuang ZiyueHuang deleted the g2c branch January 30, 2018 11:33
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* expose group2ctx to module

* Update test_module.py

* address comments

* update
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants