Skip to content

Commit

Permalink
Fix memory leaks in Gluon (apache#18328) (apache#18359)
Browse files Browse the repository at this point in the history
Fix leak of ndarray objects in the frontend due to reference cycle.

Backport of 3e676fc
  • Loading branch information
leezu authored and ChaiBapchya committed Aug 15, 2020
1 parent 7c9c9fc commit ee8c3c8
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
21 changes: 13 additions & 8 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import threading
import copy
import warnings
import re
import weakref
from collections import OrderedDict, defaultdict

import re
import numpy as np

from ..base import mx_real_t, MXNetError
Expand All @@ -46,7 +48,7 @@ class _BlockScope(object):
_current = threading.local()

def __init__(self, block):
self._block = block
self._block = weakref.ref(block) if block is not None else None
self._counter = {}
self._old_scope = None
self._name_scope = None
Expand All @@ -55,7 +57,8 @@ def __init__(self, block):
def create(prefix, params, hint):
"""Creates prefix and params for new `Block`."""
current = getattr(_BlockScope._current, "value", None)
if current is None:
block = current._block() if current is not None else None
if current is None or block is None:
if prefix is None:
if not hasattr(_name.NameManager._current, "value"):
_name.NameManager._current.value = _name.NameManager()
Expand All @@ -71,23 +74,25 @@ def create(prefix, params, hint):
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = current._block.params
parent = block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
return current._block.prefix+prefix, params
return block.prefix + prefix, params

def __enter__(self):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope = _name.Prefix(block.prefix)
self._name_scope.__enter__()
return self

def __exit__(self, ptype, value, trace):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import os
import tempfile
import gc

import mxnet as mx
from mxnet import gluon
Expand Down Expand Up @@ -3212,6 +3213,44 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

def test_no_memory_leak_in_gluon():
    """Check that building and dropping a Gluon block leaks no NDArrays.

    A ``Dense`` block is created, initialized and deleted while the garbage
    collector runs with ``gc.DEBUG_SAVEALL``, so every unreachable object it
    finds is appended to ``gc.garbage`` instead of being freed. The test then
    walks that garbage and fails if an ``NDArrayBase`` instance is reachable
    from it.

    Raises:
        AssertionError: if an NDArray is found among the collected garbage.
    """
    import types  # local import: only the traversal below needs it

    # Hashable objects already visited; membership uses ==/hash.
    seen = set()
    # Unhashable objects already visited, keyed by id(). The object itself is
    # stored as the value so it stays alive and its id cannot be reused by a
    # temporary created later in the walk.
    seen_unhashable = {}

    def first_visit(element):
        """Record ``element`` as visited; return False if it was seen before."""
        try:
            if element in seen:
                return False
            seen.add(element)
        except TypeError:  # unhashable
            if id(element) in seen_unhashable:
                return False
            seen_unhashable[id(element)] = element
        return True

    def has_array(root):
        """Return True if an NDArray is reachable from ``root``."""
        # Explicit work list instead of recursion, so long reference chains
        # cannot hit the interpreter's recursion limit.
        pending = [root]
        while pending:
            element = pending.pop()
            if not first_visit(element):
                continue
            if isinstance(element, mx.nd._internal.NDArrayBase):
                return True
            if isinstance(element, (type, types.ModuleType)):
                # Classes and modules are shared, long-lived state; walking
                # into them would scan live objects unrelated to this test.
                continue

            has_attrs = hasattr(element, '__dict__')
            if has_attrs:
                # Inspect the attribute *values*; iterating vars(element)
                # directly would only yield the attribute names.
                pending.extend(vars(element).values())

            if isinstance(element, dict):
                pending.extend(element.keys())
                pending.extend(element.values())
            elif isinstance(element, (list, tuple, set, frozenset)) or not has_attrs:
                # Built-in container subclasses may carry a __dict__ as well
                # (e.g. OrderedDict); their contents still have to be checked.
                # Other objects with a __dict__ are not iterated, because
                # iterating arbitrary objects can have side effects.
                try:
                    pending.extend(element)
                except (TypeError, KeyError):  # not iterable
                    pass
        return False

    # Collect all other garbage prior to this test. Otherwise the test may fail
    # due to unrelated memory leaks.
    gc.collect()

    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        try:
            net = mx.gluon.nn.Dense(10, in_units=10)
            net.initialize()
            del net
            gc.collect()
        finally:
            # Reset gc flags even if the block above raises, so that
            # DEBUG_SAVEALL does not stay enabled for later tests.
            gc.set_debug(gc_flags)

        # Check for leaked NDArrays
        leaked = any(has_array(x) for x in gc.garbage)
    finally:
        # Always release the saved garbage, including when the check fails.
        del gc.garbage[:]

    assert not leaked, 'Found leaked NDArrays due to reference cycles'

# Script entry point: when this file is executed directly (not imported),
# import nose and call its runmodule() on this module.
if __name__ == '__main__':
import nose
nose.runmodule()
5 changes: 3 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ def __init__(self, prefix):
status = [False]
event = threading.Event()
def f():
with block._BlockScope(dummy_block("spawned_")):
x= NameManager.current.get(None, "hello")
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
with block._BlockScope(net):
x = NameManager.current.get(None, "hello")
event.wait()
if x == "spawned_hello0":
status[0] = True
Expand Down

0 comments on commit ee8c3c8

Please sign in to comment.