This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUG] The wrong gradient of Batch Norm when grad_req = add #18499

Closed
wkcn opened this issue Jun 5, 2020 · 4 comments · Fixed by #18500

@wkcn
Member

wkcn commented Jun 5, 2020

Description

Hi there, we found that the current implementation of the batch norm layer does not support grad_req = add. If grad_req is set to add, the gradient of the input data is not accumulated. In addition, the gradients of gamma and beta are mistakenly not assigned any value.

To Reproduce

import mxnet as mx
from mxnet.gluon import nn

N = 1
C = 3
H = W = 2
block = nn.BatchNorm() 
block.collect_params().initialize()
block.collect_params().setattr('grad_req', 'add')

x = mx.nd.arange(N*C*H*W).reshape((N, C, H, W))
x.attach_grad()
for i in range(3):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum() 
    loss.backward()
print(x.grad, block.gamma.grad(), block.beta.grad())

It prints the following output.
With mxnet-2.0.0b20200421 (installed by pip):

[[[[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]

  [[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]

  [[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]]]
<NDArray 1x3x2x2 @cpu(0)> 
[7.999936 7.999936 7.999936]
<NDArray 3 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)>

With MXNet 1.6 (installed by pip --pre):

[[[[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]

  [[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]

  [[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]]]
<NDArray 1x3x2x2 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)>

The correct result should be:

[[[[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]],

  [[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]],

  [[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]]]]

[23.9998, 23.9998, 23.9998]

[0., 0., 0.]

In each output, the three arrays are the gradients of the input data, gamma, and beta, respectively. The gradients computed by MXNet are wrong.
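As a sanity check on the expected values, here is a minimal pure-Python sketch of the batch norm backward pass for a single channel, accumulated over three backward passes. It is not MXNet's implementation. The helper names `bn_channel_grads` and `accumulate` are hypothetical, and the sketch assumes gamma = 1, beta = 0, epsilon = 1e-5, batch statistics (training mode), and that every gradient buffer, including the one for the input, accumulates.

```python
import math


def bn_channel_grads(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Gradients of loss = sum(y * y) for one channel of training-mode batch norm.

    xs holds all values of the channel, flattened over N, H and W.
    Returns (dx, dgamma, dbeta). Hypothetical helper, for illustration only.
    """
    m = len(xs)
    mean = sum(xs) / m
    var = sum((v - mean) ** 2 for v in xs) / m  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mean) * inv_std for v in xs]
    y = [gamma * h + beta for h in x_hat]

    dy = [2.0 * v for v in y]  # d(sum(y * y)) / dy
    dgamma = sum(d * h for d, h in zip(dy, x_hat))
    dbeta = sum(dy)

    # Standard batch norm input gradient, written in terms of x_hat.
    dx_hat = [d * gamma for d in dy]
    mean_dxh = sum(dx_hat) / m
    mean_dxh_xh = sum(d * h for d, h in zip(dx_hat, x_hat)) / m
    dx = [inv_std * (d - mean_dxh - h * mean_dxh_xh)
          for d, h in zip(dx_hat, x_hat)]
    return dx, dgamma, dbeta


def accumulate(xs, steps=3):
    """Add the gradients of `steps` identical backward passes together."""
    acc_dx = [0.0] * len(xs)
    acc_dgamma = 0.0
    acc_dbeta = 0.0
    for _ in range(steps):
        dx, dgamma, dbeta = bn_channel_grads(xs)
        acc_dx = [a + d for a, d in zip(acc_dx, dx)]
        acc_dgamma += dgamma
        acc_dbeta += dbeta
    return acc_dx, acc_dgamma, acc_dbeta


# Channel 0 of mx.nd.arange(12).reshape((1, 3, 2, 2)).
channel0 = [0.0, 1.0, 2.0, 3.0]
single_dx, single_dgamma, single_dbeta = bn_channel_grads(channel0)
acc_dx, acc_dgamma, acc_dbeta = accumulate(channel0)
print(single_dgamma, acc_dgamma, acc_dbeta)
print(acc_dx)
```

Under these assumptions, a single pass gives a gamma gradient of about 7.99994, which matches the mxnet-2.0.0b20200421 output and is what an overwritten (rather than accumulated) buffer would hold. Three accumulated passes give about 23.9998, matching the expected result. The accumulated input gradient comes out within a few percent of the expected values; the sketch runs in double precision, so small differences in these near-zero numbers are to be expected.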

Environment

mxnet-2.0.0b20200421 installed by pip
I could not run the latest MXNet 2.0 build (mxnet-2.0.0b20200516) on my laptop because libopenblas.so.0 was not found : (

----------Python Info----------
Version      : 3.8.3
Compiler     : GCC 10.1.0
Build        : ('default', 'May 17 2020 18:15:42')
Arch         : ('64bit', 'ELF')
------------Pip Info-----------
Version      : 20.0.2
Directory    : /usr/lib/python3.8/site-packages/pip
----------MXNet Info-----------
Version      : 2.0.0
Directory    : /usr/lib/python3.8/site-packages/mxnet
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform     : Linux-5.6.15-arch1-1-x86_64-with-glibc2.2.5
system       : Linux
node         : MiraiT
release      : 5.6.15-arch1-1
version      : #1 SMP PREEMPT Wed, 27 May 2020 23:42:26 +0000
----------Hardware Info----------
machine      : x86_64
processor    : 
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
@wkcn changed the title from "[BUG] The wrong gradient of Batch Norm when grad_req = True" to "[BUG] The wrong gradient of Batch Norm when grad_req = add" on Jun 5, 2020
@sxjscience added the v2.0 label on Jun 5, 2020
@sxjscience
Member

I'm tagging it as v2.0. @ciyongch How about also tagging it as v1.7? BN is a very basic layer and it's not acceptable to have a bug in BN.

@wkcn
Member Author

wkcn commented Jun 6, 2020

I'm trying to fix it.
https://github.com/wkcn/incubator-mxnet/tree/fix_bn_when_grad_addto

@ciyongch
Contributor

ciyongch commented Jun 6, 2020

Hi @sxjscience, I got the same result with 1.7 as with mxnet-2.0.0b20200421. As it's a functionality bug exposed before the final release, I do suggest including this fix in 1.7 as well.
@wkcn would you mind helping to backport the fix to 1.7.x and 1.x once it's merged to master? Thanks!

@wkcn
Member Author

wkcn commented Jun 6, 2020

@ciyongch I'm glad to do it : )
