This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit 433cccb

Added WeightedPrevalenceSampler for uniform sampling of all classes

mibaumgartner committed Mar 14, 2019
1 parent b663d98 commit 433cccb
Showing 4 changed files with 113 additions and 8 deletions.
3 changes: 2 additions & 1 deletion delira/data_loading/sampler/__init__.py
@@ -3,7 +3,8 @@
    PrevalenceSequentialSampler, StoppingPrevalenceSequentialSampler
from .random_sampler import RandomSampler, PrevalenceRandomSampler, \
    StoppingPrevalenceRandomSampler
-from .weighted_sampler import WeightedRandomSampler
+from .weighted_sampler import WeightedRandomSampler, \
+    WeightedPrevalenceRandomSampler
from .lambda_sampler import LambdaSampler

__all__ = [
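With this export in place, downstream code can pull the new sampler straight from the subpackage. A minimal sketch of the resulting import, assuming the package layout shown above (the label list passed to the constructor is illustrative):

    from delira.data_loading.sampler import WeightedPrevalenceRandomSampler

    sampler = WeightedPrevalenceRandomSampler([0, 0, 1, 2, 2, 2])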
28 changes: 28 additions & 0 deletions delira/data_loading/sampler/random_sampler.py
@@ -116,6 +116,20 @@ def __init__(self, indices, shuffle_batch=True):

+    @classmethod
+    def from_dataset(cls, dataset: AbstractDataset, **kwargs):
+        """
+        Classmethod to initialize the sampler from a given dataset
+
+        Parameters
+        ----------
+        dataset : AbstractDataset
+            the given dataset
+
+        Returns
+        -------
+        AbstractSampler
+            The initialized sampler
+        """
+        indices = range(len(dataset))
+        labels = [dataset[idx]['label'] for idx in indices]
+        return cls(labels, **kwargs)
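The classmethod above simply maps every dataset index to its label and hands the label list to the constructor. A standalone sketch of that logic, assuming dataset items are dicts carrying a 'label' key (the convention the DummyDataset in the tests below also follows):

    dataset_like = [{'label': 0}, {'label': 1}, {'label': 0}]
    labels = [dataset_like[idx]['label'] for idx in range(len(dataset_like))]
    # labels == [0, 1, 0]; the sampler is then constructed from this list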
@@ -225,6 +239,20 @@ def __init__(self, indices, shuffle_batch=True):

+    @classmethod
+    def from_dataset(cls, dataset: AbstractDataset, **kwargs):
+        """
+        Classmethod to initialize the sampler from a given dataset
+
+        Parameters
+        ----------
+        dataset : AbstractDataset
+            the given dataset
+
+        Returns
+        -------
+        AbstractSampler
+            The initialized sampler
+        """
+        indices = range(len(dataset))
+        labels = [dataset[idx]['label'] for idx in indices]
+        return cls(labels, **kwargs)
34 changes: 28 additions & 6 deletions delira/data_loading/sampler/weighted_sampler.py
@@ -2,6 +2,7 @@
from .abstract_sampler import AbstractSampler

from numpy.random import choice
+import numpy as np


class WeightedRandomSampler(AbstractSampler):
@@ -25,7 +26,7 @@ def __init__(self, indices, weights=None):
        super().__init__()

        self._indices = list(range(len(indices)))
-        self._weights = weight
+        self._weights = weights
        self._global_index = 0

    @classmethod
@@ -45,9 +46,7 @@ def from_dataset(cls, dataset: AbstractDataset, **kwargs):
            The initialized sampler
        """

-        indices = list(range(len(dataset)))
-        labels = [d['label'] for d in dataset.data]
+        labels = [d['label'] for d in dataset]
        return cls(labels, **kwargs)

    def _get_indices(self, n_indices):
Expand All @@ -68,8 +67,6 @@ def _get_indices(self, n_indices):
------
StopIteration
If maximal number of samples is reached
TypeError
if weights and cum_weights are specified at the same time
ValueError
if weights or cum_weights don't match the population
@@ -94,3 +91,28 @@ def _get_indices(self, n_indices):
    def __len__(self):
        return len(self._indices)
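That closes out the touched parts of WeightedRandomSampler. For orientation, a quick usage sketch of its from_dataset classmethod with the DummyDataset from the test suite (the 50/30/20 label split is illustrative):

    dset = DummyDataset(600, [0.5, 0.3, 0.2])
    sampler = WeightedRandomSampler.from_dataset(dset)
    batch_indices = sampler(5)  # samplers are callable and return index batches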


+class WeightedPrevalenceRandomSampler(WeightedRandomSampler):
+    def __init__(self, indices):
+        """
+        Implements random per-class sampling and ensures uniform sampling
+        of all classes
+
+        Parameters
+        ----------
+        indices : array-like
+            list of classes each sample belongs to. List index corresponds to
+            data index and the value at a certain index indicates the
+            corresponding class
+        """
+        labels = np.array(indices)
+        classes, classes_count = np.unique(labels, return_counts=True)
+
+        # weight every sample with the inverse frequency of its class
+        # (normalized by the number of classes) so the weights sum to 1
+        # and each class receives equal probability mass
+        weights = np.zeros(labels.shape[0], dtype=np.float64)
+        for i, c in enumerate(classes):
+            weights[labels == c] = 1.0 / (len(classes) * classes_count[i])
+
+        super().__init__(indices, weights=weights)
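To see why this yields uniform class sampling: each sample's weight is the inverse of its class count, scaled by the number of classes, so the weights of every class sum to the same value. A standalone sketch of the computation (the 300/180/120 label split mirrors the DummyDataset(600, [0.5, 0.3, 0.2]) used in the tests):

    import numpy as np

    labels = np.array([0] * 300 + [1] * 180 + [2] * 120)
    classes, counts = np.unique(labels, return_counts=True)

    weights = np.zeros(labels.shape[0], dtype=np.float64)
    for i, c in enumerate(classes):
        weights[labels == c] = 1.0 / (len(classes) * counts[i])

    for c in classes:
        print(c, weights[labels == c].sum())  # ~0.3333 for each class

Every class ends up with probability mass 1/3 regardless of its prevalence, which is the uniform sampling of all classes the commit message promises. In practice one goes through WeightedPrevalenceRandomSampler.from_dataset(dset) and draws batches with sampler(batch_size), as the new test below does.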
56 changes: 55 additions & 1 deletion tests/data_loading/test_sampler.py
@@ -5,11 +5,13 @@
    SequentialSampler, \
    StoppingPrevalenceRandomSampler, \
    StoppingPrevalenceSequentialSampler, \
-    WeightedRandomSampler
+    WeightedRandomSampler, \
+    WeightedPrevalenceRandomSampler

import numpy as np
from . import DummyDataset

+
def test_lambda_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -26,6 +28,7 @@ def sampling_fn_b(index_list, n_indices):
    assert sampler_a(15) == list(range(15))
    assert sampler_b(15) == list(range(len(dset) - 15, len(dset)))

+
def test_prevalence_random_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -49,6 +52,7 @@ def test_prevalence_random_sampler():

    assert len(sampler(5)) == 5

+
def test_prevalence_sequential_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -59,6 +63,7 @@ def test_prevalence_sequential_sampler():

    assert len(sampler(5)) == 5

+
def test_random_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -70,6 +75,7 @@ def test_random_sampler():
    # checks if labels are all the same (should not happen if random sampled)
    assert len(set([dset[_idx]["label"] for _idx in sampler(301)])) > 1

+
def test_sequential_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -84,6 +90,7 @@ def test_sequential_sampler():
    # labels
    assert len(set([dset[_idx]["label"] for _idx in sampler(101)])) == 2

+
def test_stopping_prevalence_random_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -100,6 +107,7 @@ def test_stopping_prevalence_random_sampler():
    except StopIteration:
        assert True

+
def test_stopping_prevalence_sequential_sampler():
    np.random.seed(1)
    dset = DummyDataset(600, [0.5, 0.3, 0.2])
@@ -117,6 +125,52 @@ def test_stopping_prevalence_sequential_sampler():
        assert True

+
+def test_weighted_sampler():
+    np.random.seed(1)
+    dset = DummyDataset(600, [0.5, 0.3, 0.2])
+
+    sampler = WeightedRandomSampler.from_dataset(dset)
+
+    for batch_len in [1, 2, 3]:
+
+        equal_batch = sampler(batch_len)
+
+        seen_labels = []
+        for idx in equal_batch:
+            curr_label = dset[idx]["label"]
+
+            if curr_label not in seen_labels:
+                seen_labels.append(curr_label)
+            else:
+                assert False, "Label already seen and labels must be unique. " \
+                              "Batch length: %d" % batch_len
+
+    assert len(sampler(5)) == 5
+
+
+def test_weighted_prevalence_sampler():
+    np.random.seed(1)
+    dset = DummyDataset(600, [0.5, 0.3, 0.2])
+
+    sampler = WeightedPrevalenceRandomSampler.from_dataset(dset)
+
+    for batch_len in [1, 2, 3]:
+
+        equal_batch = sampler(batch_len)
+
+        seen_labels = []
+        for idx in equal_batch:
+            curr_label = dset[idx]["label"]
+
+            if curr_label not in seen_labels:
+                seen_labels.append(curr_label)
+            else:
+                assert False, "Label already seen and labels must be unique. " \
+                              "Batch length: %d" % batch_len
+
+    assert len(sampler(5)) == 5
+
+
if __name__ == '__main__':
    test_lambda_sampler()
    test_prevalence_random_sampler()
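To exercise just the new tests locally, an invocation along these lines should work (a hypothetical command, assuming a standard pytest setup for the repository):

    pytest tests/data_loading/test_sampler.py -k weighted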
