Skip to content

Commit

Permalink
Use numpy.arange in RandomSampler (apache#10768)
Browse files Browse the repository at this point in the history
Significant speedup for large datasets:

In [2]: %timeit current_sample(1529*8192)
12.3 s ± 721 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [3]: %timeit np_sample(1529*8192)
641 ms ± 6.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Loading branch information
leezu authored and zheng-da committed Jun 28, 2018
1 parent b00a445 commit 2487f2c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/gluon/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""Dataset sampler."""
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler']

import random
import numpy as np

class Sampler(object):
"""Base class for samplers.
Expand Down Expand Up @@ -65,8 +65,8 @@ def __init__(self, length):
self._length = length

def __iter__(self):
indices = list(range(self._length))
random.shuffle(indices)
indices = np.arange(self._length)
np.random.shuffle(indices)
return iter(indices)

def __len__(self):
Expand Down

0 comments on commit 2487f2c

Please sign in to comment.