diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py
index 8d9a1909f07ca..6a0f6e05a37d5 100755
--- a/python/paddle/io/__init__.py
+++ b/python/paddle/io/__init__.py
@@ -25,6 +25,7 @@
     Sampler,
     SequenceSampler,
     Subset,
+    SubsetRandomSampler,
     TensorDataset,
     WeightedRandomSampler,
     get_worker_info,
@@ -48,4 +49,5 @@
     'WeightedRandomSampler',
     'random_split',
     'Subset',
+    'SubsetRandomSampler',
 ]
diff --git a/python/paddle/io/dataloader/__init__.py b/python/paddle/io/dataloader/__init__.py
index bb65463f70afc..aff32fd70de49 100644
--- a/python/paddle/io/dataloader/__init__.py
+++ b/python/paddle/io/dataloader/__init__.py
@@ -29,3 +29,4 @@
 from .sampler import SequenceSampler
 from .sampler import RandomSampler
 from .sampler import WeightedRandomSampler
+from .sampler import SubsetRandomSampler
diff --git a/python/paddle/io/dataloader/sampler.py b/python/paddle/io/dataloader/sampler.py
index 44bc545f777cd..f6bb2e41b4b8f 100644
--- a/python/paddle/io/dataloader/sampler.py
+++ b/python/paddle/io/dataloader/sampler.py
@@ -15,6 +15,7 @@
 import numpy as np
 
 from ...framework import core
+from ...tensor import randperm
 
 
 class Sampler:
@@ -340,3 +341,53 @@ def __iter__(self):
     def __len__(self):
         mul = np.prod(self.weights.shape) // self.weights.shape[-1]
         return self.num_samples * mul
+
+
+class SubsetRandomSampler(Sampler):
+    r"""
+    Randomly sample elements from a given list of indices, without replacement.
+
+    Every pass over the sampler draws a fresh permutation of positions with
+    ``randperm``, so each entry of ``indices`` is yielded exactly once per
+    pass, in random order.
+
+    Args:
+        indices (sequence): A non-empty sequence of indices to sample from.
+
+    Raises:
+        ValueError: If ``indices`` is empty.
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import paddle
+            >>> from paddle.io import SubsetRandomSampler
+
+            >>> paddle.seed(2023)
+            >>> sampler = SubsetRandomSampler(indices=[1, 3, 5, 7, 9])
+
+            >>> for index in sampler:
+            ...     print(index)
+            9
+            3
+            7
+            5
+            1
+
+    """
+
+    def __init__(self, indices):
+        if len(indices) == 0:
+            raise ValueError(
+                "The length of `indices` in SubsetRandomSampler should be greater than 0."
+            )
+        self.indices = indices
+
+    def __iter__(self):
+        # Draw a new permutation of positions on every pass over the sampler.
+        for i in randperm(len(self.indices)):
+            yield self.indices[i]
+
+    def __len__(self) -> int:
+        return len(self.indices)
diff --git a/test/legacy_test/test_batch_sampler.py b/test/legacy_test/test_batch_sampler.py
index 72ea1577beb53..750a916b3b29a 100644
--- a/test/legacy_test/test_batch_sampler.py
+++ b/test/legacy_test/test_batch_sampler.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 import unittest
 
 import numpy as np
@@ -22,6 +23,7 @@
     RandomSampler,
     Sampler,
     SequenceSampler,
+    SubsetRandomSampler,
     WeightedRandomSampler,
 )
 
@@ -110,6 +112,20 @@ def test_with_generator_num_samples(self):
         assert tuple(sorted(rets)) == tuple(range(0, 50))
 
 
+class TestSubsetRandomSampler(unittest.TestCase):
+    def test_main(self):
+        indices = random.sample(range(100), 30)
+        sampler = SubsetRandomSampler(indices)
+        self.assertEqual(len(sampler), len(indices))
+
+        # One pass must visit every given index exactly once, in any order.
+        self.assertEqual(sorted(sampler), sorted(indices))
+
+    def test_raise(self):
+        with self.assertRaises(ValueError):
+            SubsetRandomSampler([])
+
+
 class TestBatchSampler(unittest.TestCase):
     def setUp(self):
         self.num_samples = 1000