Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix reference leak in Block._children
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed May 21, 2020
1 parent 0210ce2 commit bb1deb1
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 99 deletions.
47 changes: 0 additions & 47 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,50 +230,3 @@ def doctest(doctest_namespace):
logging.warning('Unable to import numpy/mxnet. Skipping conftest.')
import doctest
doctest.ELLIPSIS_MARKER = '-etc-'


@pytest.fixture(scope='session')
def mxnet_module():
import mxnet
return mxnet


@pytest.fixture()
# @pytest.fixture(autouse=True) # Fix all the bugs and mark this autouse=True
def check_leak_ndarray(mxnet_module):
# Collect garbage prior to running the next test
gc.collect()
# Enable gc debug mode to check if the test leaks any arrays
gc_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)

# Run the test
yield

# Check for leaked NDArrays
gc.collect()
gc.set_debug(gc_flags) # reset gc flags

seen = set()
def has_array(element):
try:
if element in seen:
return False
seen.add(element)
except TypeError: # unhashable
pass

if isinstance(element, mxnet_module.nd._internal.NDArrayBase):
return True
elif hasattr(element, '__dict__'):
return any(has_array(x) for x in vars(element))
elif isinstance(element, dict):
return any(has_array(x) for x in element.items())
else:
try:
return any(has_array(x) for x in element)
except (TypeError, KeyError):
return False

assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
del gc.garbage[:]
20 changes: 10 additions & 10 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _find_unregistered_block_in_container(data):
return True
return False
elif isinstance(data, Block):
return not data in children
return not data in (c() for c in children)
else:
return False
for k, v in self.__dict__.items():
Expand Down Expand Up @@ -425,15 +425,15 @@ def collect_params(self, select=None):
pattern = re.compile(select)
ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
for cld in self._children.values():
ret.update(cld.collect_params(select=select))
ret.update(cld().collect_params(select=select))
return ret

def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
ret = {prefix + key : val for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
ret.update(child()._collect_params_with_prefix(prefix + name))
return ret

def save_parameters(self, filename, deduplicate=False):
Expand Down Expand Up @@ -601,7 +601,7 @@ def register_child(self, block, name=None):
attributes will be registered automatically."""
if name is None:
name = str(len(self._children))
self._children[name] = block
self._children[name] = weakref.ref(block)

def register_forward_pre_hook(self, hook):
r"""Registers a forward pre-hook on the block.
Expand Down Expand Up @@ -654,7 +654,7 @@ def apply(self, fn):
this block
"""
for cld in self._children.values():
cld.apply(fn)
cld().apply(fn)
fn(self)
return self

Expand All @@ -681,7 +681,7 @@ def hybridize(self, active=True, **kwargs):
""" Please refer description of HybridBlock hybridize().
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
cld().hybridize(active, **kwargs)

def cast(self, dtype):
"""Cast this Block to use another data type.
Expand All @@ -692,7 +692,7 @@ def cast(self, dtype):
The new data type.
"""
for child in self._children.values():
child.cast(dtype)
child().cast(dtype)
for _, param in self.params.items():
param.cast(dtype)

Expand Down Expand Up @@ -736,7 +736,7 @@ def register_op_hook(self, callback, monitor_all=False):
If True, monitor both input and output, otherwise monitor output only.
"""
for cld in self._children.values():
cld.register_op_hook(callback, monitor_all)
cld().register_op_hook(callback, monitor_all)

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
Expand Down Expand Up @@ -1314,8 +1314,8 @@ def c_callback(name, op_name, array):
self._callback = c_callback
self._monitor_all = monitor_all
for cld in self._children.values():
cld._callback = c_callback
cld._monitor_all = monitor_all
cld()._callback = c_callback
cld()._monitor_all = monitor_all

def __call__(self, x, *args):
if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/gluon/contrib/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, axis=-1, prefix=None, params=None):
def forward(self, x):
out = []
for block in self._children.values():
out.append(block(x))
out.append(block()(x))
out = nd.concat(*out, dim=self.axis)
return out

Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(self, axis=-1, prefix=None, params=None):
def hybrid_forward(self, F, x):
out = []
for block in self._children.values():
out.append(block(x))
out.append(block()(x))
out = F.concat(*out, dim=self.axis)
return out

Expand Down
28 changes: 15 additions & 13 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,17 @@ class Sequential(Block):
"""
def __init__(self, prefix=None, params=None):
super(Sequential, self).__init__(prefix=prefix, params=params)
self._layers = []

def add(self, *blocks):
"""Adds block on top of the stack."""
for block in blocks:
self._layers.append(block)
self.register_child(block)

def forward(self, x, *args):
for block in self._children.values():
x = block(x, *args)
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
Expand All @@ -64,20 +66,19 @@ def forward(self, x, *args):
def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join([' ({key}): {block}'.format(key=key,
block=_indent(block.__repr__(), 2))
block=_indent(block().__repr__(), 2))
for key, block in self._children.items()])
return s.format(name=self.__class__.__name__,
modstr=modstr)
return s.format(name=self.__class__.__name__, modstr=modstr)

def __getitem__(self, key):
layers = list(self._children.values())[key]
if isinstance(layers, list):
net = type(self)(prefix=self._prefix)
with net.name_scope():
net.add(*layers)
net.add(*(l() for l in layers))
return net
else:
return layers
return layers()

def __len__(self):
return len(self._children)
Expand All @@ -93,7 +94,7 @@ def hybridize(self, active=True, **kwargs):
**kwargs : string
Additional flags for hybridized operator.
"""
if self._children and all(isinstance(c, HybridBlock) for c in self._children.values()):
if self._children and all(isinstance(c(), HybridBlock) for c in self._children.values()):
warnings.warn(
"All children of this Sequential layer '%s' are HybridBlocks. Consider "
"using HybridSequential for the best performance."%self.prefix, stacklevel=2)
Expand All @@ -114,15 +115,17 @@ class HybridSequential(HybridBlock):
"""
def __init__(self, prefix=None, params=None):
super(HybridSequential, self).__init__(prefix=prefix, params=params)
self._layers = []

def add(self, *blocks):
"""Adds block on top of the stack."""
for block in blocks:
self._layers.append(block)
self.register_child(block)

def hybrid_forward(self, F, x, *args):
for block in self._children.values():
x = block(x, *args)
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
Expand All @@ -134,20 +137,19 @@ def hybrid_forward(self, F, x, *args):
def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join([' ({key}): {block}'.format(key=key,
block=_indent(block.__repr__(), 2))
block=_indent(block().__repr__(), 2))
for key, block in self._children.items()])
return s.format(name=self.__class__.__name__,
modstr=modstr)
return s.format(name=self.__class__.__name__, modstr=modstr)

def __getitem__(self, key):
layers = list(self._children.values())[key]
if isinstance(layers, list):
net = type(self)(prefix=self._prefix)
with net.name_scope():
net.add(*layers)
net.add(*(l() for l in layers))
return net
else:
return layers
return layers()

def __len__(self):
return len(self._children)
Expand Down
Loading

0 comments on commit bb1deb1

Please sign in to comment.