diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index bc7c3ce19e09..3c48a74435b1 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -220,11 +220,15 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
                  center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
                  gamma_initializer='ones', running_mean_initializer='zeros',
                  running_variance_initializer='ones', **kwargs):
-        fuse_relu = False
-        super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
-                                            fuse_relu, beta_initializer, gamma_initializer,
-                                            running_mean_initializer, running_variance_initializer,
-                                            in_channels, **kwargs)
+        super(SyncBatchNorm, self).__init__(
+            axis=1, momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            use_global_stats=use_global_stats,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
         num_devices = self._get_num_devices() if num_devices is None else num_devices
         self._kwargs = {'eps': epsilon, 'momentum': momentum,
                         'fix_gamma': not scale, 'use_global_stats': use_global_stats,
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 70b0a71841f1..72230fe25795 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -410,7 +410,6 @@ class BatchNorm(_BatchNorm):
         If True, use global moving statistics instead of local batch-norm. This will
         force change batch-norm into a scale shift operator. If False, use local
         batch-norm.
-    fuse_relu: False
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -432,17 +431,20 @@ class BatchNorm(_BatchNorm):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False, fuse_relu=False,
+                 use_global_stats=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
-        assert not fuse_relu, "Please use BatchNormReLU with Relu fusion"
         super(BatchNorm, self).__init__(
-            axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-            use_global_stats=False, fuse_relu=False,
-            beta_initializer='zeros', gamma_initializer='ones',
-            running_mean_initializer='zeros', running_variance_initializer='ones',
-            in_channels=0, **kwargs)
+            axis=axis, momentum=momentum, epsilon=epsilon, center=center,
+            scale=scale,
+            use_global_stats=use_global_stats, fuse_relu=False,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
+
 
 class BatchNormReLU(_BatchNorm):
     """Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -472,7 +474,6 @@ class BatchNormReLU(_BatchNorm):
         If True, use global moving statistics instead of local batch-norm. This will
         force change batch-norm into a scale shift operator. If False, use local
         batch-norm.
-    fuse_relu: True
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -494,17 +495,20 @@ class BatchNormReLU(_BatchNorm):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False, fuse_relu=True,
+                 use_global_stats=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
-        assert fuse_relu, "Please use BatchNorm w/o Relu fusion"
         super(BatchNormReLU, self).__init__(
-            axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-            use_global_stats=False, fuse_relu=True,
-            beta_initializer='zeros', gamma_initializer='ones',
-            running_mean_initializer='zeros', running_variance_initializer='ones',
-            in_channels=0, **kwargs)
+            axis=axis, momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            use_global_stats=use_global_stats, fuse_relu=True,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
+
 
 class Embedding(HybridBlock):
     r"""Turns non-negative integers (indexes/tokens) into dense vectors
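A minimal sketch (not part of the patch) of what the keyword forwarding changes: before this fix, BatchNorm and BatchNormReLU re-passed hard-coded defaults to _BatchNorm, so user-supplied arguments such as momentum and epsilon were silently discarded; after the fix they reach the parent class. The _kwargs attribute name below is assumed from the Gluon BatchNorm implementation.

    import mxnet as mx
    from mxnet.gluon import nn

    # Construct BatchNorm with non-default hyperparameters.
    bn = nn.BatchNorm(axis=1, momentum=0.99, epsilon=1e-3, in_channels=4)
    bn.initialize()

    x = mx.nd.random.uniform(shape=(2, 4, 8, 8))
    y = bn(x)

    # With the patched constructor, the requested momentum is forwarded to
    # _BatchNorm instead of being reset to the default 0.9
    # (attribute name assumed from the Gluon implementation).
    print(bn._kwargs.get('momentum'))  # expected: 0.99 after this change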