
[MXNET-614] Adding Synchronized Batch Normalization #11502

Merged: 38 commits merged into apache:master on Jul 14, 2018

Conversation

@zhanghang1989 (Contributor) commented Jun 29, 2018

Description

Adding Synchronized Batch Normalization
Thanks @eric-haibin-lin for the great help!

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, the expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is backward incompatible, explain why it must be made.
  • Interesting edge cases to note here

@zhanghang1989 requested a review from szha as a code owner on June 29, 2018 23:43
@zhanghang1989 changed the title from "[MXNET-614] Adding Synchronized Batch Normalization" to "[MXNET-614] [WIP] Adding Synchronized Batch Normalization" on Jun 30, 2018
@zhanghang1989 (Contributor Author): Help wanted for passing the CI test!

@zhanghang1989 changed the title from "[MXNET-614] [WIP] Adding Synchronized Batch Normalization" to "[MXNET-614] [Help Wanted for CI Test] Adding Synchronized Batch Normalization" on Jun 30, 2018
@zhanghang1989 changed the title from "[MXNET-614] [Help Wanted for CI Test] Adding Synchronized Batch Normalization" to "[MXNET-614] Adding Synchronized Batch Normalization" on Jul 2, 2018
'ndev': num_devices, 'key': self.prefix}

def _get_num_devices(self):
# Caution: if not using all the GPUs, please manually set num_devices
Member: add the warning to the docstring rather than showing a comment here.
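
As a sketch of what that could look like (illustrative wording only, not the PR's actual docstring):

```python
def _get_num_devices(self):
    """Infer the number of devices to synchronize across.

    Caution: if training does not use all the GPUs on the machine,
    set ``num_devices`` manually when constructing the layer instead
    of relying on this helper.
    """
```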

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
# include <condition_variable>
Member: space between # and include?

template<class T>
class SharedND {
private:
int nDev;
Member: the convention for private member variables is xxx_.

Member: and CamelCase for functions, which is correct right now.

std::lock_guard<std::mutex> lock(mutex_);
auto it = registry_.find(key);
if (it != registry_.end()) return it->second;
T *newT = new T(ndev);
Member: the memory pointed to by these raw pointers is never released.

@zhanghang1989 (Contributor Author): Thanks @RogerChern! The comments on the destructor are really helpful.

@zhanghang1989 (Contributor Author): Finally passed the CI test. Please take a look and let me know if you have further comments. @zhreshold @eric-haibin-lin @piiswrong. Thanks!

Docs are deployed here http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-11502/31/api/python/gluon/contrib.html?highlight=syncbatchnorm#mxnet.gluon.contrib.nn.SyncBatchNorm.

@eric-haibin-lin (Member) left a review: some minor suggestions

_assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]),
                     _find_bn(bn2).running_var.data(ctx_list[0]))
input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
#print('input1.grad', input1.grad)
Member: Remove unused code.

Contributor Author: Yeah, will do. Thanks.

_assert_tensor_close(input1.grad, input2grad)

def test_sync_batchnorm():
    def get_num_devices():
Member: There's test_utils.list_gpus().

Contributor Author: That is slightly different; list_gpus() doesn't take CUDA_VISIBLE_DEVICES into account.
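
For illustration only (this helper is not part of the PR; the function name and fallback behavior are assumptions), the distinction is roughly:

```python
import os
import mxnet as mx

def visible_gpu_count():
    """Count GPUs this process may use, honoring CUDA_VISIBLE_DEVICES."""
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is not None:
        # CUDA_VISIBLE_DEVICES restricts the process to a subset of devices,
        # e.g. "0,2" exposes two GPUs even if the host has eight.
        return len([d for d in visible.split(',') if d.strip()])
    # list_gpus() enumerates every GPU on the host, regardless of the variable.
    return len(mx.test_utils.list_gpus())
```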

@@ -1909,6 +1909,91 @@ def test_context_num_gpus():
    # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
    assert mx.context.num_gpus() > 0

def _check_batchnorm_result(input, num_devices=1, cuda=False):
    from mxnet.gluon.utils import split_and_load
    def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3):
Member: will assert_almost_equal do?
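
A minimal sketch of that suggestion (the tensors below are made up; the tolerances mirror the 1e-3 values used by the test's helper):

```python
import mxnet as mx
from mxnet.test_utils import assert_almost_equal

# Two nearly identical tensors standing in for the running statistics
# compared by _assert_tensor_close in the test.
a = mx.nd.array([1.0, 2.0, 3.0])
b = mx.nd.array([1.0, 2.0, 3.0001])
assert_almost_equal(a.asnumpy(), b.asnumpy(), rtol=1e-3, atol=1e-3)
```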

}

~SharedND() {
mshadow::FreeSpace(&mean_);
Member: check for data_inited_ before freeing memory.

Contributor Author: I agree. Will make the changes. Thanks.

}
}

T* Retrieve(mshadow::Shape<1> shape, int index) {
Member: these member functions need documentation.

~GlobalShared() {
for (auto it = registry_.begin(); it != registry_.end(); it++) {
T *ptr = it->second;
delete ptr;
Member: again, you have to guarantee you are deleting valid pointers, since they are initialized in a public function rather than in the constructor.

Contributor Author: If not initialized, the map should be empty.

}
~GlobalSharedRank() {
for (auto it = registry_.begin(); it != registry_.end(); it++) {
T *ptr = it->second;
Member: same here.

Contributor Author: If not initialized, the hash map should be empty.

Member: OK, that should be fine.

mshadow::Shape2(5, mean.shape_[0]), s);
Tensor<xpu, 1> gmean = workspace[0];
Tensor<xpu, 1> gvar = workspace[1];
// Tensor<xpu, 1> tmp = workspace[2];
Member: remove unused code.

@zhreshold (Member): Comments added. The rest LGTM now.

@eric-haibin-lin merged commit 3ae4331 into apache:master on Jul 14, 2018
@eric-haibin-lin (Member): @indhub FYI

@miteshyh: The SyncBatchNorm class doesn't seem to be available in the mxnet-cu91 nightly. It is visible in the regular mxnet nightly. Are these changes fully merged?

@eric-haibin-lin (Member): @miteshyh mxnet-cu91 is the stable release. SyncBatchNorm will only appear in the nightly distribution installed via --pre.

@szha (Member) commented Jul 17, 2018: @miteshyh would you be able to update to cu92? I heard from @bhavinthaker that NVIDIA discontinued support for cu91, so we intend to do the same.

@miteshyh: Thanks @szha, I downgraded to cu90, as cu92 doesn't have clean support on my hardware yet, and it works.

However, while training ADE20K with GluonCV I get "socket.error: [Errno 111] Connection refused" after a number of iterations (at iteration 551); I have raised a separate issue for this. It happens both with and without SyncBatchNorm.

dmlc/gluon-cv#215

XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* sync batch norm
* global rank and barrier
* lint
* cpplint
* pylint
* doc
* add ref
* customized barrier
* cpplint
* get rid of pthread
* address comments
* warning
* pylint
* gpu unitest
* gpu 0
* mv to cpu test
* Revert "mv to cpu test" (this reverts commit 24543c9)
* ndev = 2
* debuging
* sum prod
* lint
* contrib, ngpu
* code style
* code style
* forward backward
* test
* cpu test
* fix deconstruction
* doc indent
* doc
* doc
* address comments
* typo
* asnumpy

@jianchao-li commented Sep 24, 2018:

> Setting Rank and Barrier in forward and backward as separate variables won't resolve the deadlock issue. I suggest instead we postfix their key parameter with "forward" and "backward".

Hello, @RogerChern. I also ran into a deadlock issue while training PSPNet on gluon-cv. For the "key parameter" you mentioned above, do you mean the one in this line? Could you please share more details about the fix? Thank you.
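
As illustration only (the prefix value and variable names below are hypothetical; this is one reading of the quoted suggestion, not code from this PR), suffixing the key would give the two barriers distinct identities:

```python
# The layer currently passes a single synchronization key
# ('key': self.prefix, as in the snippet near the top of this thread).
# Suffixing it yields separate keys for the forward and backward barriers.
prefix = 'syncbatchnorm0_'          # hypothetical layer prefix
forward_key = prefix + 'forward'    # key for the forward-pass barrier
backward_key = prefix + 'backward'  # key for the backward-pass barrier
```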

@zhanghang1989 (Contributor Author): Please set ndev to the number of GPUs used. In GluonCV, pass the parameter --ngpus 4 if you are using 4 GPUs.
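
In Gluon terms, this means the layer's num_devices should match the number of contexts the batch is split across. A minimal sketch, assuming 4 GPUs and an illustrative channel count:

```python
import mxnet as mx
from mxnet.gluon.contrib.nn import SyncBatchNorm

# num_devices must equal the number of contexts used for training;
# a mismatch leaves the cross-device barrier waiting and can hang.
ctx_list = [mx.gpu(i) for i in range(4)]
layer = SyncBatchNorm(in_channels=64, num_devices=len(ctx_list))
```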

@jianchao-li commented Sep 24, 2018: Hello, @zhanghang1989. Thank you for your reply. I will try it tomorrow morning and share the result with you.

Update

Hello, @zhanghang1989. I am not quite sure whether you were suggesting that I explicitly set --ngpus 4. I have only 4 GPUs on the machine, and the default value of ngpus is len(mx.test_utils.list_gpus()), which returned 4 in my case. The output of print(args) also confirmed this.

@pengwangucla: Hi Hang, I used your sync_bn implementation with the MXNet symbol API. However, it reduced the performance of my network. I wonder whether you have ever tried sync_bn with the symbol API rather than Gluon. Thanks.

@zhanghang1989 (Contributor Author): Asked here: #8458 (comment)

@ngunauj commented Jul 26, 2019: How to use it?

@zhanghang1989 (Contributor Author):

> How to use it?

#11502 (comment)
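
For readers who cannot follow the link, a minimal usage sketch (the network, channel sizes, and batch shape below are illustrative assumptions, not code from this PR):

```python
import mxnet as mx
from mxnet import autograd
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import SyncBatchNorm
from mxnet.gluon.utils import split_and_load

num_devices = 2                                   # assumed GPU count
ctx_list = [mx.gpu(i) for i in range(num_devices)]

# Use SyncBatchNorm in place of nn.BatchNorm so batch statistics are
# synchronized across all devices on every forward pass.
net = nn.HybridSequential()
net.add(nn.Conv2D(16, kernel_size=3, padding=1),
        SyncBatchNorm(num_devices=num_devices),
        nn.Activation('relu'))
net.initialize(ctx=ctx_list)

data = mx.nd.random.uniform(shape=(8, 3, 32, 32))
# Split the batch across devices, as with ordinary data-parallel training.
parts = split_and_load(data, ctx_list)
with autograd.record():
    outputs = [net(x) for x in parts]
```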
