diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 18063aa761e1..8282c93a6f6d 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -23,10 +23,8 @@ import threading import copy import warnings -import weakref -from collections import OrderedDict, defaultdict - import re +from collections import OrderedDict, defaultdict import numpy as np from ..base import mx_real_t, MXNetError @@ -48,7 +46,7 @@ class _BlockScope(object): _current = threading.local() def __init__(self, block): - self._block = weakref.ref(block) if block is not None else None + self._block = block self._counter = {} self._old_scope = None self._name_scope = None @@ -57,8 +55,7 @@ def __init__(self, block): def create(prefix, params, hint): """Creates prefix and params for new `Block`.""" current = getattr(_BlockScope._current, "value", None) - block = current._block() if current is not None else None - if current is None or block is None: + if current is None: if prefix is None: if not hasattr(_name.NameManager._current, "value"): _name.NameManager._current.value = _name.NameManager() @@ -74,25 +71,23 @@ def create(prefix, params, hint): prefix = '%s%d_'%(hint, count) current._counter[hint] = count + 1 if params is None: - parent = block.params + parent = current._block.params params = ParameterDict(parent.prefix+prefix, parent._shared) else: params = ParameterDict(params.prefix, params) - return block.prefix + prefix, params + return current._block.prefix+prefix, params def __enter__(self): - block = self._block() - if block is None or block._empty_prefix: + if self._block._empty_prefix: return self self._old_scope = getattr(_BlockScope._current, "value", None) _BlockScope._current.value = self - self._name_scope = _name.Prefix(block.prefix) + self._name_scope = _name.Prefix(self._block.prefix) self._name_scope.__enter__() return self def __exit__(self, ptype, value, trace): - block = self._block() - if block is None or block._empty_prefix: + if self._block._empty_prefix: return self._name_scope.__exit__(ptype, value, trace) self._name_scope = None diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 98b606d58394..42252d52be2b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -17,7 +17,6 @@ import os import tempfile -import gc import mxnet as mx from mxnet import gluon @@ -3230,44 +3229,6 @@ def hybrid_forward(self, F, x): mx.test_utils.assert_almost_equal(grad1, grad2) -def test_no_memory_leak_in_gluon(): - # Collect all other garbage prior to this test. Otherwise the test may fail - # due to unrelated memory leaks. - gc.collect() - - gc_flags = gc.get_debug() - gc.set_debug(gc.DEBUG_SAVEALL) - net = mx.gluon.nn.Dense(10, in_units=10) - net.initialize() - del net - gc.collect() - gc.set_debug(gc_flags) # reset gc flags - - # Check for leaked NDArrays - seen = set() - def has_array(element): - try: - if element in seen: - return False - seen.add(element) - except TypeError: # unhashable - pass - - if isinstance(element, mx.nd._internal.NDArrayBase): - return True - elif hasattr(element, '__dict__'): - return any(has_array(x) for x in vars(element)) - elif isinstance(element, dict): - return any(has_array(x) for x in element.items()) - else: - try: - return any(has_array(x) for x in element) - except (TypeError, KeyError): - return False - - assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles' - del gc.garbage[:] - if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index 50ecb064b04e..f0e3c660181b 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -124,9 +124,8 @@ def __init__(self, prefix): status = [False] event = threading.Event() def f(): - net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block - with block._BlockScope(net): - x = NameManager.current.get(None, "hello") + with block._BlockScope(dummy_block("spawned_")): + x= NameManager.current.get(None, "hello") event.wait() if x == "spawned_hello0": status[0] = True