Skip to content

Commit

Permalink
fix batchnorm (apache#18377)
Browse files Browse the repository at this point in the history
Update basic_layers.py

fix

fix

Update basic_layers.py

fix bug
  • Loading branch information
sxjscience authored and yijunc committed Jun 9, 2020
1 parent aaaf192 commit f2869fa
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 21 deletions.
14 changes: 9 additions & 5 deletions python/mxnet/gluon/contrib/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,15 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
gamma_initializer='ones', running_mean_initializer='zeros',
running_variance_initializer='ones', **kwargs):
fuse_relu = False
super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
fuse_relu, beta_initializer, gamma_initializer,
running_mean_initializer, running_variance_initializer,
in_channels, **kwargs)
super(SyncBatchNorm, self).__init__(
axis=1, momentum=momentum, epsilon=epsilon,
center=center, scale=scale,
use_global_stats=use_global_stats,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
running_mean_initializer=running_mean_initializer,
running_variance_initializer=running_variance_initializer,
in_channels=in_channels, **kwargs)
num_devices = self._get_num_devices() if num_devices is None else num_devices
self._kwargs = {'eps': epsilon, 'momentum': momentum,
'fix_gamma': not scale, 'use_global_stats': use_global_stats,
Expand Down
36 changes: 20 additions & 16 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ class BatchNorm(_BatchNorm):
If True, use global moving statistics instead of local batch-norm. This will force
change batch-norm into a scale shift operator.
If False, use local batch-norm.
fuse_relu: False
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
Expand All @@ -446,17 +445,20 @@ class BatchNorm(_BatchNorm):
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
use_global_stats=False, fuse_relu=False,
use_global_stats=False,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs):
assert not fuse_relu, "Please use BatchNormReLU with Relu fusion"
super(BatchNorm, self).__init__(
axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
use_global_stats=False, fuse_relu=False,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs)
axis=axis, momentum=momentum, epsilon=epsilon, center=center,
scale=scale,
use_global_stats=use_global_stats, fuse_relu=False,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
running_mean_initializer=running_mean_initializer,
running_variance_initializer=running_variance_initializer,
in_channels=in_channels, **kwargs)


class BatchNormReLU(_BatchNorm):
"""Batch normalization layer (Ioffe and Szegedy, 2014).
Expand Down Expand Up @@ -486,7 +488,6 @@ class BatchNormReLU(_BatchNorm):
If True, use global moving statistics instead of local batch-norm. This will force
change batch-norm into a scale shift operator.
If False, use local batch-norm.
fuse_relu: True
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
Expand All @@ -508,17 +509,20 @@ class BatchNormReLU(_BatchNorm):
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
use_global_stats=False, fuse_relu=True,
use_global_stats=False,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs):
assert fuse_relu, "Please use BatchNorm w/o Relu fusion"
super(BatchNormReLU, self).__init__(
axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
use_global_stats=False, fuse_relu=True,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs)
axis=axis, momentum=momentum, epsilon=epsilon,
center=center, scale=scale,
use_global_stats=use_global_stats, fuse_relu=True,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
running_mean_initializer=running_mean_initializer,
running_variance_initializer=running_variance_initializer,
in_channels=in_channels, **kwargs)


class Embedding(HybridBlock):
r"""Turns non-negative integers (indexes/tokens) into dense vectors
Expand Down

0 comments on commit f2869fa

Please sign in to comment.