This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Use numpy.arange in RandomSampler (#10768)
Significant speedup for large datasets:

In [2]: %timeit current_sample(1529*8192)
12.3 s ± 721 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [3]: %timeit np_sample(1529*8192)
641 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
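
The commit message does not show the two functions it times. The following is a minimal sketch, assuming `current_sample` and `np_sample` simply wrap the old and new bodies of `RandomSampler.__iter__` from this change. Their real definitions are not part of the commit, and the size used here is deliberately much smaller than `1529*8192`.

```python
import random
import timeit

import numpy as np


def current_sample(length):
    """Assumed old approach: Python list shuffled with the stdlib `random`."""
    indices = list(range(length))
    random.shuffle(indices)
    return iter(indices)


def np_sample(length):
    """Assumed new approach: NumPy array shuffled in place by NumPy."""
    indices = np.arange(length)
    np.random.shuffle(indices)
    return iter(indices)


if __name__ == "__main__":
    n = 100_000  # small illustrative size, not the one from the commit message
    repeats = 5
    for fn in (current_sample, np_sample):
        seconds = timeit.timeit(lambda: fn(n), number=repeats) / repeats
        print(f"{fn.__name__}: {seconds * 1e3:.2f} ms per call")
```

Absolute timings depend on the machine and on the Python and NumPy versions, so the figures printed by this sketch will not match the ones quoted above.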
leezu authored and piiswrong committed May 2, 2018
1 parent 1822dac commit 23934cf
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/gluon/data/sampler.py
@@ -20,7 +20,7 @@
 """Dataset sampler."""
 __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler']

-import random
+import numpy as np

 class Sampler(object):
     """Base class for samplers.
@@ -65,8 +65,8 @@ def __init__(self, length):
         self._length = length

     def __iter__(self):
-        indices = list(range(self._length))
-        random.shuffle(indices)
+        indices = np.arange(self._length)
+        np.random.shuffle(indices)
         return iter(indices)

     def __len__(self):
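
Two side effects of this change are worth checking in calling code. `np.random.shuffle` draws from NumPy's global random state, so the sampling order is now controlled by `np.random.seed` rather than `random.seed`. Iterating over the array also yields NumPy integer scalars instead of Python `int` objects. The sketch below uses a hypothetical stand-in class, `RandomSamplerSketch`, so that it runs without MXNet. The `__iter__` body is copied from the diff, while the `__len__` body is an assumption because the diff does not show it.

```python
import numpy as np


class RandomSamplerSketch:
    """Hypothetical stand-in for RandomSampler after this commit."""

    def __init__(self, length):
        self._length = length

    def __iter__(self):
        # Body taken from the diff above.
        indices = np.arange(self._length)
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        # Assumed body; the diff only shows the signature.
        return self._length


# Seeding NumPy's global state makes the shuffled order repeatable.
np.random.seed(0)
first = [int(i) for i in RandomSamplerSketch(10)]
np.random.seed(0)
second = [int(i) for i in RandomSamplerSketch(10)]
print(first == second)

# Every index still appears exactly once.
print(sorted(first) == list(range(10)))

# Elements are NumPy integer scalars, not built-in ints.
element = next(iter(RandomSamplerSketch(3)))
print(isinstance(element, np.integer), isinstance(element, int))
```

Code that needs built-in integers, for example for JSON serialization, can convert each index with `int(i)` as the sketch does.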
