Skip to content

Commit

Permalink
【Hackathon 5th No.24】Add SubsetRandomSampler -part (PaddlePaddle#57726)
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll authored Nov 8, 2023
1 parent ed28804 commit 9e2fae9
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Sampler,
SequenceSampler,
Subset,
SubsetRandomSampler,
TensorDataset,
WeightedRandomSampler,
get_worker_info,
Expand All @@ -48,4 +49,5 @@
'WeightedRandomSampler',
'random_split',
'Subset',
'SubsetRandomSampler',
]
1 change: 1 addition & 0 deletions python/paddle/io/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@
from .sampler import SequenceSampler
from .sampler import RandomSampler
from .sampler import WeightedRandomSampler
from .sampler import SubsetRandomSampler
43 changes: 43 additions & 0 deletions python/paddle/io/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from ...framework import core
from ...tensor import randperm


class Sampler:
Expand Down Expand Up @@ -340,3 +341,45 @@ def __iter__(self):
def __len__(self):
mul = np.prod(self.weights.shape) // self.weights.shape[-1]
return self.num_samples * mul


class SubsetRandomSampler(Sampler):
r"""
Randomly sample elements from a given list of indices, without replacement.
Args:
indices (sequence): a sequence of indices
Examples:
.. code-block:: python
>>> import paddle
>>> from paddle.io import SubsetRandomSampler
>>> paddle.seed(2023)
>>> sampler = SubsetRandomSampler(indices=[1, 3, 5, 7, 9])
>>> for index in sampler:
... print(index)
9
3
7
5
1
"""

def __init__(self, indices):
if len(indices) == 0:
raise ValueError(
"The length of `indices` in SubsetRandomSampler should be greater than 0."
)
self.indices = indices

def __iter__(self):
for i in randperm(len(self.indices)):
yield self.indices[i]

def __len__(self) -> int:
return len(self.indices)
24 changes: 24 additions & 0 deletions test/legacy_test/test_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
Expand All @@ -22,6 +23,7 @@
RandomSampler,
Sampler,
SequenceSampler,
SubsetRandomSampler,
WeightedRandomSampler,
)

Expand Down Expand Up @@ -110,6 +112,28 @@ def test_with_generator_num_samples(self):
assert tuple(sorted(rets)) == tuple(range(0, 50))


class TestSubsetRandomSampler(unittest.TestCase):
def test_main(self):
indices = list(range(100))
random.shuffle(indices)
indices = indices[:30]
sampler = SubsetRandomSampler(indices)
assert len(sampler) == len(indices)

hints = {i: 0 for i in indices}
for index in iter(sampler):
hints[index] += 1
for h in hints.values():
assert h == 1

def test_raise(self):
try:
sampler = SubsetRandomSampler([])
self.assertTrue(False)
except ValueError:
self.assertTrue(True)


class TestBatchSampler(unittest.TestCase):
def setUp(self):
self.num_samples = 1000
Expand Down

0 comments on commit 9e2fae9

Please sign in to comment.