From 3e676fc2c88bec75e4463c8fa9b5532664d518c2 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 15 May 2020 10:00:56 -0700 Subject: [PATCH] Fix memory leaks in Gluon (#18328) Fix leak of ndarray objects in the frontend due to reference cycle. --- python/mxnet/gluon/block.py | 25 ++++++++------ tests/python/unittest/test_gluon.py | 38 ++++++++++++++++++++++ tests/python/unittest/test_thread_local.py | 5 +-- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 27428e3191b8..6d9ea9acb314 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -23,8 +23,10 @@ import threading import copy import warnings -import re +import weakref from collections import OrderedDict, defaultdict + +import re import numpy as np from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str @@ -48,7 +50,7 @@ class _BlockScope(object): _current = threading.local() def __init__(self, block): - self._block = block + self._block = weakref.ref(block) if block is not None else None self._counter = {} self._old_scope = None self._name_scope = None @@ -60,7 +62,8 @@ def create(prefix, params, hint): The profiler scope is to support the GPU memory profiler. """ current = getattr(_BlockScope._current, "value", None) - if current is None: + block = current._block() if current is not None else None + if current is None or block is None: if prefix is None: if not hasattr(_name.NameManager._current, "value"): _name.NameManager._current.value = _name.NameManager() @@ -79,29 +82,31 @@ def create(prefix, params, hint): prefix = '%s%d_'%(hint, count) current._counter[hint] = count + 1 if params is None: - parent = current._block.params + parent = block.params params = ParameterDict(parent.prefix+prefix, parent._shared) else: params = ParameterDict(params.prefix, params) # replace the trailing underscore with colon profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \ else prefix) + ":" - return current._block.prefix + prefix, params, \ - current._block._profiler_scope_name + profiler_scope_name + return block.prefix + prefix, params, \ + block._profiler_scope_name + profiler_scope_name def __enter__(self): - if self._block._empty_prefix: + block = self._block() + if block is None or block._empty_prefix: return self self._old_scope = getattr(_BlockScope._current, "value", None) _BlockScope._current.value = self - self._name_scope = _name.Prefix(self._block.prefix) + self._name_scope = _name.Prefix(block.prefix) self._name_scope.__enter__() - self._profiler_scope = _profiler.Scope(self._block._profiler_scope_name) + self._profiler_scope = _profiler.Scope(block._profiler_scope_name) self._profiler_scope.__enter__() return self def __exit__(self, ptype, value, trace): - if self._block._empty_prefix: + block = self._block() + if block is None or block._empty_prefix: return self._name_scope.__exit__(ptype, value, trace) self._name_scope = None diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 3bfe4f30ef4b..587be268deff 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -16,6 +16,7 @@ # under the License. import os +import gc import mxnet as mx from mxnet import gluon @@ -3229,3 +3230,40 @@ def hybrid_forward(self, F, x): mx.test_utils.assert_almost_equal(grad1, grad2) +def test_no_memory_leak_in_gluon(): + # Collect all other garbage prior to this test. Otherwise the test may fail + # due to unrelated memory leaks. + gc.collect() + + gc_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + net = mx.gluon.nn.Dense(10, in_units=10) + net.initialize() + del net + gc.collect() + gc.set_debug(gc_flags) # reset gc flags + + # Check for leaked NDArrays + seen = set() + def has_array(element): + try: + if element in seen: + return False + seen.add(element) + except TypeError: # unhashable + pass + + if isinstance(element, mx.nd._internal.NDArrayBase): + return True + elif hasattr(element, '__dict__'): + return any(has_array(x) for x in vars(element)) + elif isinstance(element, dict): + return any(has_array(x) for x in element.items()) + else: + try: + return any(has_array(x) for x in element) + except (TypeError, KeyError): + return False + + assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles' + del gc.garbage[:] diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index 7e875c8fb835..5423249b0ee6 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -125,8 +125,9 @@ def __init__(self, prefix): status = [False] event = threading.Event() def f(): - with block._BlockScope(dummy_block("spawned_")): - x= NameManager.current.get(None, "hello") + net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block + with block._BlockScope(net): + x = NameManager.current.get(None, "hello") event.wait() if x == "spawned_hello0": status[0] = True