From 2487f2c2544648136fe3f30050e425d9a1be28de Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Tue, 1 May 2018 23:51:47 -0700 Subject: [PATCH] Use numpy.arange in RandomSampler (#10768) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Significant speedup for large datasets: In [2]: %timeit current_sample(1529*8192) 12.3 s ± 721 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) In [3]: %timeit np_sample(1529*8192) 641 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) --- python/mxnet/gluon/data/sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py index 66d6cfb29797..2f827c83bc4d 100644 --- a/python/mxnet/gluon/data/sampler.py +++ b/python/mxnet/gluon/data/sampler.py @@ -20,7 +20,7 @@ """Dataset sampler.""" __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler'] -import random +import numpy as np class Sampler(object): """Base class for samplers. @@ -65,8 +65,8 @@ def __init__(self, length): self._length = length def __iter__(self): - indices = list(range(self._length)) - random.shuffle(indices) + indices = np.arange(self._length) + np.random.shuffle(indices) return iter(indices) def __len__(self):