Skip to content

Commit

Permalink
make AspectRatioGrouped dataset re-iterable
Browse files Browse the repository at this point in the history
Summary:
* Re-iterate AspectRatioGrouped dataset may cause memory leak. This PR fixes the issue. This will close facebookresearch#3847.

Pull Request resolved: facebookresearch#3849

Reviewed By: wat3rBro

Differential Revision: D33392659

Pulled By: zhanghang1989

fbshipit-source-id: 3f81438ea45e2f35f5818a319044543f66e0e8db
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Jan 7, 2022
1 parent bb96d0b commit 424cfae
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
5 changes: 4 additions & 1 deletion detectron2/data/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,5 +237,8 @@ def __iter__(self):
bucket = self._buckets[bucket_id]
bucket.append(d)
if len(bucket) == self.batch_size:
yield bucket[:]
data = bucket[:]
# Clear bucket first, because code after yield is not
# guaranteed to execute
del bucket[:]
yield data
18 changes: 18 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.data.common import AspectRatioGroupedDataset
from detectron2.data.samplers import InferenceSampler, TrainingSampler


Expand Down Expand Up @@ -73,6 +74,23 @@ def test_pickleability(self):
self.assertEqual(ds[0], 2)


class TestAspectRatioGrouping(unittest.TestCase):
def test_reiter_leak(self):
data = [(1, 0), (0, 1), (1, 0), (0, 1)]
data = [{"width": a, "height": b} for (a, b) in data]
batchsize = 2
dataset = AspectRatioGroupedDataset(data, batchsize)

for _ in range(5):
for idx, __ in enumerate(dataset):
if idx == 1:
# manually break, so the iterator does not stop by itself
break
# check that bucket sizes are valid
for bucket in dataset._buckets:
self.assertLess(len(bucket), batchsize)


class TestDataLoader(unittest.TestCase):
def _get_kwargs(self):
# get kwargs of build_detection_train_loader
Expand Down

0 comments on commit 424cfae

Please sign in to comment.