Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix reference cycle in gluon.Trainer #18363

Merged
merged 3 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""

import logging
import gc
import os
import random

Expand Down Expand Up @@ -229,3 +230,50 @@ def doctest(doctest_namespace):
logging.warning('Unable to import numpy/mxnet. Skipping conftest.')
import doctest
doctest.ELLIPSIS_MARKER = '-etc-'


@pytest.fixture(scope='session')
def mxnet_module():
    """Session-scoped fixture that imports ``mxnet`` and hands back the module.

    The import happens inside the fixture body, so it is only executed when
    a test (or another fixture) actually requests ``mxnet_module``.
    """
    import mxnet as mxnet_package
    return mxnet_package


@pytest.fixture()
# TODO: fix all the remaining leaks, then use @pytest.fixture(autouse=True)
def check_leak_ndarray(mxnet_module):
    """Fail the requesting test if it leaves NDArrays behind in reference cycles.

    Setup collects already-pending garbage and switches the collector to
    ``gc.DEBUG_SAVEALL`` so that every unreachable object found afterwards is
    saved in ``gc.garbage`` instead of being freed.  After the test has run,
    a collection is forced, the previous debug flags are restored and the
    saved objects are searched for ``NDArrayBase`` instances.

    Parameters
    ----------
    mxnet_module : module
        The imported ``mxnet`` package (see the ``mxnet_module`` fixture).

    Raises
    ------
    AssertionError
        If an object saved in ``gc.garbage`` is, or refers to, an NDArray.
    """
    # Collect garbage prior to running the next test
    gc.collect()
    # Enable gc debug mode to check if the test leaks any arrays
    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)

    # Run the test
    yield

    # Check for leaked NDArrays
    gc.collect()
    gc.set_debug(gc_flags)  # reset gc flags

    ndarray_base = mxnet_module.nd._internal.NDArrayBase

    def has_array(roots):
        """Return True if an NDArray is reachable from any object in ``roots``."""
        seen_hashable = set()
        # id -> object. Holding the object keeps its id from being reused
        # while the traversal is still running.
        seen_unhashable = {}
        # Explicit work list instead of recursion: the object graphs in
        # gc.garbage are cyclic by construction and may be arbitrarily deep.
        pending = list(roots)
        while pending:
            element = pending.pop()
            try:
                if element in seen_hashable:
                    continue
                seen_hashable.add(element)
            except (TypeError, ValueError):
                # Unhashable (or hashing/comparison raised): fall back to
                # identity so self-referencing lists/dicts are visited once.
                if id(element) in seen_unhashable:
                    continue
                seen_unhashable[id(element)] = element

            if isinstance(element, ndarray_base):
                return True
            if hasattr(element, '__dict__'):
                # Inspect the attribute values, not the attribute names.
                pending.extend(vars(element).values())
            elif isinstance(element, dict):
                pending.extend(element.keys())
                pending.extend(element.values())
            else:
                try:
                    pending.extend(element)
                except (TypeError, KeyError):
                    # Not iterable: nothing further to inspect.
                    pass
        return False

    try:
        leaked = has_array(gc.garbage)
    finally:
        # Always empty gc.garbage, even when the scan itself fails; otherwise
        # the stale objects would be reported again for every later test
        # that uses this fixture.
        del gc.garbage[:]

    assert not leaked, 'Found leaked NDArrays due to reference cycles'
26 changes: 18 additions & 8 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from collections import OrderedDict, defaultdict
import warnings
import weakref
import numpy as np

from ..base import mx_real_t, MXNetError
Expand Down Expand Up @@ -201,12 +202,15 @@ def shape(self, new_shape):
def _set_trainer(self, trainer):
    """ Set the trainer this parameter is associated with.

    The trainer is stored as a weak reference (or ``None``), so the
    parameter does not keep the trainer alive.

    Parameters
    ----------
    trainer : Trainer or None
        The trainer to associate with this parameter; ``None`` unsets it.

    Raises
    ------
    RuntimeError
        If this parameter's storage type is not ``'default'`` and a different
        trainer that is still alive is already associated with it.
    """
    # Dereference the stored weak reference. The reference object itself is
    # always truthy, so a trainer that has already been garbage-collected
    # must be detected through the call result (None), not the attribute.
    current = self._trainer() if self._trainer is not None else None

    # trainer cannot be replaced for sparse params
    if self._stype != 'default' and current is not None and trainer \
            and current is not trainer:
        raise RuntimeError(
            "Failed to set the trainer for Parameter '%s' because it was already set. " \
            "More than one trainers for a %s Parameter is not supported." \
            %(self.name, self._stype))

    if trainer is not None:
        self._trainer = weakref.ref(trainer)
    else:
        self._trainer = trainer

def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
Expand Down Expand Up @@ -245,13 +249,14 @@ def _get_row_sparse(self, arr_list, ctx, row_id):
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
if not self._trainer:
trainer = self._trainer() if self._trainer else None
if not trainer:
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
results = self._check_and_get(arr_list, ctx)

# fetch row sparse params from the trainer
self._trainer._row_sparse_pull(self, results, row_id)
trainer._row_sparse_pull(self, results, row_id)
return results

def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'):
Expand Down Expand Up @@ -397,7 +402,11 @@ def _reduce(self):
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
trainer = self._trainer() if self._trainer else None
if not trainer:
raise RuntimeError("Cannot reduce row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
Expand Down Expand Up @@ -503,9 +512,10 @@ def set_data(self, data):
return

# if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
if self not in self._trainer._params_to_init:
self._trainer._reset_kvstore()
trainer = self._trainer() if self._trainer else None
if trainer and trainer._kv_initialized and trainer._update_on_kvstore:
if self not in trainer._params_to_init:
trainer._reset_kvstore()

for arr in self._check_and_get(self._data, list):
arr[:] = data
Expand Down
37 changes: 3 additions & 34 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_parameter_invalid_access():
assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)

@with_seed()
@pytest.mark.usefixtures("check_leak_ndarray")
def test_parameter_dict():
ctx = mx.cpu(1)
params0 = gluon.ParameterDict('net_')
Expand Down Expand Up @@ -3226,40 +3227,8 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

@pytest.mark.usefixtures("check_leak_ndarray")
def test_no_memory_leak_in_gluon():
    """Creating, initializing and dropping a Dense block must not leak NDArrays.

    The leak detection itself (gc debug flags, scan of ``gc.garbage``) is
    performed by the ``check_leak_ndarray`` fixture around this test body.
    """
    net = mx.gluon.nn.Dense(10, in_units=10)
    net.initialize()
    del net
2 changes: 1 addition & 1 deletion tests/python/unittest/test_gluon_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_multi_trainer():
x.initialize()
# test set trainer
trainer0 = gluon.Trainer([x], 'sgd')
assert(x._trainer is trainer0)
assert(x._trainer() is trainer0)
# test unset trainer
x._set_trainer(None)
assert(x._trainer is None)
Expand Down