diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index eafe73651dbc..5910bf91578e 100755
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -268,6 +268,11 @@ def _init_default(self, name, _):
             '"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \
             'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name)
 
+    def __eq__(self, other):
+        if not isinstance(other, Initializer):
+            return NotImplemented
+        # pylint: disable=unidiomatic-typecheck
+        return type(self) is type(other) and self._kwargs == other._kwargs
 
 # pylint: disable=invalid-name
 _register = registry.get_register_func(Initializer, 'initializer')
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index da8dba7ce476..5d15b27fa7ea 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -3119,6 +3119,21 @@ def forward(self, x):
         shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
         block(mx.nd.ones(shape))
 
+def test_shared_parameters_with_non_default_initializer():
+    class MyBlock(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(MyBlock, self).__init__(**kwargs)
+
+            with self.name_scope():
+                self.param = self.params.get("param", shape=(1, ), init=mx.init.Constant(-10.0))
+
+    bl = MyBlock()
+    bl2 = MyBlock(params=bl.collect_params())
+    assert bl.param is bl2.param
+    bl3 = MyBlock()
+    assert bl.param is not bl3.param
+    assert bl.param.init == bl3.param.init
+
 @with_seed()
 def test_reqs_switching_training_inference():
     class Foo(gluon.HybridBlock):
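
Usage sketch (not part of the patch): a minimal illustration of the value-equality semantics that the new Initializer.__eq__ provides, assuming the stock mx.init.Constant and mx.init.Zero initializers that ship with MXNet.

# With the __eq__ added above, initializers compare by class and
# constructor kwargs rather than by object identity, so two separately
# constructed initializers with the same arguments compare equal.
import mxnet as mx

a = mx.init.Constant(-10.0)
b = mx.init.Constant(-10.0)

assert a == b                        # same type, same _kwargs -> equal
assert a != mx.init.Constant(5.0)    # same type, different _kwargs -> not equal
assert a != mx.init.Zero()           # different Initializer subclass -> not equal
assert a != "constant"               # non-Initializer operand: __eq__ returns
                                     # NotImplemented, so Python falls back to
                                     # the default comparison (not equal)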