Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Added dataset iterator and corresponding functionality to unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
mibaumgartner committed Apr 18, 2019
1 parent 8bf25f2 commit 70d6232
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
38 changes: 38 additions & 0 deletions delira/data_loading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ def __len__(self):
"""
return len(self.data)

def __iter__(self):
"""
Return an iterator for the dataset
Returns
-------
object
a single sample
"""
return _DatasetIter(self)

def get_sample_from_index(self, index):
"""
Returns the data sample for a given index
Expand Down Expand Up @@ -173,6 +184,33 @@ def train_test_split(self, *args, **kwargs):
return self.get_subset(train_idxs), self.get_subset(test_idxs)


class _DatasetIter(object):
"""
Iterator for dataset
"""
def __init__(self, dset):
"""
Parameters
----------
dset: :class: `AbstractDataset`
the dataset which should be iterated
"""
self._dset = dset
self._curr_index = 0

def __iter__(self):
return self

def __next__(self):
if self._curr_index > len(self._dset):
raise StopIteration

sample = self._dset[self._curr_index]
self._curr_index += 1
return sample


class BlankDataset(AbstractDataset):
"""
Blank Dataset loading the data, which has been passed
Expand Down
30 changes: 30 additions & 0 deletions tests/data_loading/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ def load_mul_sample(path):
except:
raise AssertionError('Dataset access failed.')

try:
j = 0
for i in dataset:
assert 'data' in i
assert 'label' in i
j += 1
assert j == len(dataset)
except:
raise AssertionError('Dataset iteration failed.')

# test extend cache dataset
dataset = BaseExtendCacheDataset(paths, load_mul_sample)
assert len(dataset) == 40
Expand All @@ -108,6 +118,16 @@ def load_mul_sample(path):
except:
raise AssertionError('Dataset access failed.')

try:
j = 0
for i in dataset:
assert 'data' in i
assert 'label' in i
j += 1
assert j == len(dataset)
except:
raise AssertionError('Dataset iteration failed.')


def test_lazy_dataset():
# test lazy dataset
Expand All @@ -121,6 +141,16 @@ def test_lazy_dataset():
except:
raise AssertionError('Dataset access failed.')

try:
j = 0
for i in dataset:
assert 'data' in i
assert 'label' in i
j += 1
assert j == len(dataset)
except:
raise AssertionError('Dataset iteration failed.')


def test_load_sample():
def load_dummy_label(path):
Expand Down

0 comments on commit 70d6232

Please sign in to comment.