diff --git a/delira/data_loading/sampler/__init__.py b/delira/data_loading/sampler/__init__.py
index 2e4a48d7..f049f68c 100644
--- a/delira/data_loading/sampler/__init__.py
+++ b/delira/data_loading/sampler/__init__.py
@@ -3,7 +3,8 @@
     PrevalenceSequentialSampler, StoppingPrevalenceSequentialSampler
 from .random_sampler import RandomSampler, PrevalenceRandomSampler, \
     StoppingPrevalenceRandomSampler
-from .weighted_sampler import WeightedRandomSampler
+from .weighted_sampler import WeightedRandomSampler, \
+    WeightedPrevalenceRandomSampler
 from .lambda_sampler import LambdaSampler
 
 __all__ = [
diff --git a/delira/data_loading/sampler/random_sampler.py b/delira/data_loading/sampler/random_sampler.py
index 38e4ea5b..9f439410 100644
--- a/delira/data_loading/sampler/random_sampler.py
+++ b/delira/data_loading/sampler/random_sampler.py
@@ -116,6 +116,20 @@ def __init__(self, indices, shuffle_batch=True):
 
     @classmethod
     def from_dataset(cls, dataset: AbstractDataset, **kwargs):
+        """
+        Classmethod to initialize the sampler from a given dataset
+
+        Parameters
+        ----------
+        dataset : AbstractDataset
+            the given dataset
+
+        Returns
+        -------
+        AbstractSampler
+            The initialized sampler
+
+        """
         indices = range(len(dataset))
         labels = [dataset[idx]['label'] for idx in indices]
         return cls(labels, **kwargs)
@@ -225,6 +239,20 @@ def __init__(self, indices, shuffle_batch=True):
 
     @classmethod
     def from_dataset(cls, dataset: AbstractDataset, **kwargs):
+        """
+        Classmethod to initialize the sampler from a given dataset
+
+        Parameters
+        ----------
+        dataset : AbstractDataset
+            the given dataset
+
+        Returns
+        -------
+        AbstractSampler
+            The initialized sampler
+
+        """
         indices = range(len(dataset))
         labels = [dataset[idx]['label'] for idx in indices]
         return cls(labels, **kwargs)
diff --git a/delira/data_loading/sampler/weighted_sampler.py b/delira/data_loading/sampler/weighted_sampler.py
index 2c25511b..3ed2aaab 100644
--- a/delira/data_loading/sampler/weighted_sampler.py
+++ b/delira/data_loading/sampler/weighted_sampler.py
@@ -2,6 +2,7 @@
 from .abstract_sampler import AbstractSampler
 
 from numpy.random import choice
+import numpy as np
 
 
 class WeightedRandomSampler(AbstractSampler):
@@ -25,7 +26,7 @@ def __init__(self, indices, weights=None):
 
         super().__init__()
         self._indices = list(range(len(indices)))
-        self._weights = weight
+        self._weights = weights
         self._global_index = 0
 
     @classmethod
@@ -45,9 +46,7 @@ def from_dataset(cls, dataset: AbstractDataset, **kwargs):
             The initialzed sampler
 
         """
-
-        indices = list(range(len(dataset)))
-        labels = [d['label'] for d in dataset.data]
+        labels = [d['label'] for d in dataset]
         return cls(labels, **kwargs)
 
     def _get_indices(self, n_indices):
@@ -68,8 +67,6 @@ def _get_indices(self, n_indices):
         Raises
         ------
         StopIteration
            If maximal number of samples is reached
-        TypeError
-            if weights and cum_weights are specified at the same time
         ValueError
             if weights or cum_weights don't match the population
@@ -94,3 +91,31 @@ def __len__(self):
 
     def __len__(self):
         return len(self._indices)
+
+
+class WeightedPrevalenceRandomSampler(WeightedRandomSampler):
+    def __init__(self, indices):
+        """
+        Implements random per-class sampling and ensures uniform sampling
+        of all classes
+
+        Parameters
+        ----------
+        indices : array-like
+            list of classes each sample belongs to. List index corresponds to
+            data index and the value at a certain index indicates the
+            corresponding class
+        """
+        labels = np.asarray(indices)
+        classes, classes_count = np.unique(labels, return_counts=True)
+
+        # inverse class frequencies: every class receives the same total
+        # weight and all per-sample weights sum to 1
+        class_probs = 1.0 / (classes_count * len(classes))
+
+        # assign each sample the probability of its class
+        weights = np.zeros(labels.shape[0])
+        for i, c in enumerate(classes):
+            weights[labels == c] = class_probs[i]
+
+        super().__init__(indices, weights=weights)
diff --git a/tests/data_loading/test_sampler.py b/tests/data_loading/test_sampler.py
index 9fe21cb0..2ca3f222 100644
--- a/tests/data_loading/test_sampler.py
+++ b/tests/data_loading/test_sampler.py
@@ -5,11 +5,13 @@
     SequentialSampler, \
     StoppingPrevalenceRandomSampler, \
     StoppingPrevalenceSequentialSampler, \
-    WeightedRandomSampler
+    WeightedRandomSampler, \
+    WeightedPrevalenceRandomSampler
 import numpy as np
 
 from . import DummyDataset
 
+
 def test_lambda_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -26,6 +28,7 @@ def sampling_fn_b(index_list, n_indices):
     assert sampler_a(15) == list(range(15))
     assert sampler_b(15) == list(range(len(dset) - 15, len(dset)))
 
+
 def test_prevalence_random_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -49,6 +52,7 @@ def test_prevalence_random_sampler():
 
     assert len(sampler(5)) == 5
 
+
 def test_prevalence_sequential_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -59,6 +63,7 @@ def test_prevalence_sequential_sampler():
 
     assert len(sampler(5)) == 5
 
+
 def test_random_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -70,6 +75,7 @@ def test_random_sampler():
     # checks if labels are all the same (should not happen if random sampled)
     assert len(set([dset[_idx]["label"] for _idx in sampler(301)])) > 1
 
+
 def test_sequential_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -84,6 +90,7 @@ def test_sequential_sampler():
     # labels
     assert len(set([dset[_idx]["label"] for _idx in sampler(101)])) == 2
 
+
 def test_stopping_prevalence_random_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -100,6 +107,7 @@ def test_stopping_prevalence_random_sampler():
     except StopIteration:
         assert True
 
+
 def test_stopping_prevalence_sequential_sampler():
     np.random.seed(1)
     dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -117,6 +125,42 @@ def test_stopping_prevalence_sequential_sampler():
     except StopIteration:
         assert True
 
+
+def test_weighted_sampler():
+    np.random.seed(1)
+    dset = DummyDataset(600, [0.5, 0.3, 0.2])
+
+    sampler = WeightedRandomSampler.from_dataset(dset)
+
+    # requested batch sizes must be honoured
+    for batch_len in [1, 2, 3]:
+        assert len(sampler(batch_len)) == batch_len
+
+    # a large random sample should cover more than one label
+    assert len(set([dset[_idx]["label"] for _idx in sampler(301)])) > 1
+
+    assert len(sampler(5)) == 5
+
+
+def test_weighted_prevalence_sampler():
+    np.random.seed(1)
+    dset = DummyDataset(600, [0.5, 0.3, 0.2])
+
+    sampler = WeightedPrevalenceRandomSampler.from_dataset(dset)
+
+    # requested batch sizes must be honoured
+    for batch_len in [1, 2, 3]:
+        assert len(sampler(batch_len)) == batch_len
+
+    # class frequencies should be roughly uniform despite the
+    # imbalanced dataset
+    labels = [dset[_idx]["label"] for _idx in sampler(300)]
+    _, counts = np.unique(labels, return_counts=True)
+    assert counts.min() > 0.5 * counts.max()
+
+    assert len(sampler(5)) == 5
+
+
 if __name__ == '__main__':
     test_lambda_sampler()
     test_prevalence_random_sampler()
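Usage sketch for the new sampler (illustrative, not part of the patch): the import path, the `from_dataset` classmethod and the call interface are taken from the diff above; `DummyDataset` is the test-suite helper and stands in for any `AbstractDataset` whose samples are dicts carrying a "label" entry.

    import numpy as np

    from delira.data_loading.sampler import WeightedPrevalenceRandomSampler
    # test helper; assumed importable here for illustration only
    from tests.data_loading import DummyDataset

    np.random.seed(1)

    # imbalanced dataset: three classes with prevalences 0.5 / 0.3 / 0.2
    dset = DummyDataset(600, [0.5, 0.3, 0.2])

    # build the sampler from the dataset's labels; inverse class-frequency
    # weights make every class equally likely to be drawn
    sampler = WeightedPrevalenceRandomSampler.from_dataset(dset)

    # draw 30 sample indices; their class distribution is roughly uniform
    batch_indices = sampler(30)
    batch_labels = [dset[idx]["label"] for idx in batch_indices]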