This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

switch pytest to unittest
justusschock committed Mar 21, 2019
1 parent b663d98 commit a709809
Showing 17 changed files with 896 additions and 850 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
@@ -23,6 +23,7 @@ before_install:
   # pip uninstall -y tensorflow-gpu;
   # pip install tensorflow==1.12.0;
   # fi
+  - pip install coverage
   - pip install codecov
   - pip install -r docs/requirements.txt

@@ -31,7 +32,7 @@ install:

 # command to run tests
 script:
-  - pytest
+  - coverage run -m unittest
   - codecov
   - if [[ "$TRAVIS_PYTHON_VERSION" == "3.7" ]]; then
       cd ./docs;
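For reference, the updated CI step can be reproduced locally. Below is a minimal sketch using the coverage Python API instead of the CLI; the "tests" directory name is taken from the file paths in this commit, everything else is a generic illustration and not part of the commit itself.

# local_coverage_run.py -- hypothetical helper, not part of this commit.
# Mirrors the CI step `coverage run -m unittest` via the coverage API.
import unittest

import coverage

cov = coverage.Coverage()
cov.start()

# discover all unittest cases under tests/ (path assumed from the diff)
suite = unittest.defaultTestLoader.discover("tests")
unittest.TextTestRunner(verbosity=2).run(suite)

cov.stop()
cov.save()
cov.report()  # in CI, codecov uploads the saved .coverage data afterwards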
1 change: 0 additions & 1 deletion requirements.txt
@@ -3,7 +3,6 @@ scikit-image>=0.14.0
 scikit-learn>=0.20.0
 jupyter>=1.0.0
 flake8
-pytest-cov
 autopep8
 ipython
 joblib
2 changes: 1 addition & 1 deletion setup.py
@@ -58,7 +58,7 @@ def find_version(file):
     long_description_content_type='text/markdown',
     license=license,
     install_requires=requirements,
-    tests_require=["pytest-cov"],
+    tests_require=["coverage"],
     python_requires=">=3.5",
     extras_require={
         "full": requirements_extra_full,
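A note on this change: setuptools only installs tests_require packages for the (now legacy) `python setup.py test` command, so swapping pytest-cov for coverage keeps that path consistent with the CI change. A stripped-down sketch of how the field sits in a setup() call; name and version here are placeholders, not delira's actual metadata:

from setuptools import setup

setup(
    name="example-package",      # placeholder metadata
    version="0.1",
    install_requires=["numpy"],  # runtime deps stay in install_requires
    tests_require=["coverage"],  # pulled in only by `setup.py test`
    python_requires=">=3.5",
)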
36 changes: 23 additions & 13 deletions tests/data_loading/test_data_loader.py
@@ -1,20 +1,30 @@
 from delira.data_loading import BaseDataLoader, SequentialSampler
 
 import numpy as np
 from . import DummyDataset
+import unittest
 
 
-def test_data_loader():
-    np.random.seed(1)
-    dset = DummyDataset(600, [0.5, 0.3, 0.2])
-    sampler = SequentialSampler.from_dataset(dset)
-    loader = BaseDataLoader(dset, batch_size=16, sampler=sampler)
-
-    assert isinstance(loader.generate_train_batch(), dict)
-    for key, val in loader.generate_train_batch().items():
-        assert len(val) == 16
-    assert "label" in loader.generate_train_batch()
-    assert "data" in loader.generate_train_batch()
-    assert len(set([_tmp for _tmp in loader.generate_train_batch()["label"]])) == 1
+class DataLoaderTest(unittest.TestCase):
+
+    def test_data_loader(self):
+        np.random.seed(1)
+        dset = DummyDataset(600, [0.5, 0.3, 0.2])
+        sampler = SequentialSampler.from_dataset(dset)
+        loader = BaseDataLoader(dset, batch_size=16, sampler=sampler)
+
+        self.assertIsInstance(loader.generate_train_batch(), dict)
+
+        for key, val in loader.generate_train_batch().items():
+            self.assertEqual(len(val), 16)
+
+        self.assertIn("label", loader.generate_train_batch())
+        self.assertIn("data", loader.generate_train_batch())
+
+        self.assertEqual(
+            len(set([_tmp for _tmp in loader.generate_train_batch()["label"]])),
+            1)
 
 
 if __name__ == '__main__':
-    test_data_loader()
+    unittest.main()
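The conversion pattern is the same across all test files in this commit. A condensed, self-contained sketch of the pytest-style assert to unittest.TestCase mapping, with dummy data for illustration only:

import unittest


class MappingSketch(unittest.TestCase):

    def test_assertion_mapping(self):
        batch = {"data": [0.0] * 16, "label": [1] * 16}

        # assert isinstance(batch, dict)   -> assertIsInstance
        self.assertIsInstance(batch, dict)

        # assert "data" in batch           -> assertIn
        self.assertIn("data", batch)

        # assert len(batch["data"]) == 16  -> assertEqual
        self.assertEqual(len(batch["data"]), 16)

        # assert batch["label"]            -> assertTrue (truthiness)
        self.assertTrue(batch["label"])


if __name__ == "__main__":
    unittest.main()

The TestCase methods report the compared values on failure, which bare asserts only do under pytest's assertion rewriting.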
42 changes: 24 additions & 18 deletions tests/data_loading/test_data_manager.py
@@ -5,30 +5,36 @@
 from delira.data_loading import BaseDataLoader, SequentialSampler
 from batchgenerators.dataloading import MultiThreadedAugmenter
 
-def test_base_datamanager():
+import unittest
 
-    batch_size = 16
 
-    np.random.seed(1)
-    dset = DummyDataset(600, [0.5, 0.3, 0.2])
+class DataManagerTest(unittest.TestCase):
 
-    manager = BaseDataManager(dset, batch_size, n_process_augmentation=1,
-                              transforms=None)
+    def test_base_datamanager(self):
 
-    assert isinstance(manager.get_batchgen(), MultiThreadedAugmenter)
+        batch_size = 16
 
-    # create batch manually
-    data, labels = [], []
-    for i in range(batch_size):
-        data.append(dset[i]["data"])
-        labels.append(dset[i]["label"])
+        np.random.seed(1)
+        dset = DummyDataset(600, [0.5, 0.3, 0.2])
 
-    batch_dict = {"data": np.asarray(data), "label": np.asarray(labels)}
+        manager = BaseDataManager(dset, batch_size, n_process_augmentation=1,
+                                  transforms=None)
 
-    for key, val in next(manager.get_batchgen()).items():
-        assert (val == batch_dict[key]).all()
-    for key, val in next(manager.get_batchgen()).items():
-        assert len(val) == batch_size
+        self.assertIsInstance(manager.get_batchgen(), MultiThreadedAugmenter)
+
+        # create batch manually
+        data, labels = [], []
+        for i in range(batch_size):
+            data.append(dset[i]["data"])
+            labels.append(dset[i]["label"])
+
+        batch_dict = {"data": np.asarray(data), "label": np.asarray(labels)}
+
+        for key, val in next(manager.get_batchgen()).items():
+            self.assertTrue((val == batch_dict[key]).all())
+
+        for key, val in next(manager.get_batchgen()).items():
+            self.assertEqual(len(val), batch_size)
 
 
 if __name__ == '__main__':
-    test_base_datamanager()
+    unittest.main()
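One judgment call in the hunk above: `self.assertTrue((val == batch_dict[key]).all())` collapses the array comparison to a single bool, so a failure reports only False. If clearer output is wanted, numpy's own test helper is a drop-in alternative; a sketch, not what the commit uses:

import numpy as np

a = np.arange(4)
b = np.arange(4)

# equivalent check to assertTrue((a == b).all()), but on failure it
# raises AssertionError listing the mismatching elements
np.testing.assert_array_equal(a, b)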
113 changes: 61 additions & 52 deletions tests/data_loading/test_dataset.py
@@ -2,69 +2,78 @@
 
 from delira.data_loading import ConcatDataset, BaseCacheDataset
 
-def test_data_subset_concat():
-    def load_dummy_sample(path, label_load_fct):
-        """
-        Returns dummy data, independent of path or label_load_fct
-        Parameters
-        ----------
-        path
-        label_load_fct
-        Returns
-        -------
-        : dict
-            dict with data and label
-        """
-        return {'data': np.random.rand(1, 256, 256),
-                'label': np.random.randint(2)}
-
-    class DummyCacheDataset(BaseCacheDataset):
-        def __init__(self, num: int, label_load_fct, *args, **kwargs):
-            """
-            Generates random samples with _make_dataset
-            Parameters
-            ----------
-            num : int
-                number of random samples
-            args :
-                passed to BaseCacheDataset
-            kwargs :
-                passed to BaseCacheDataset
-            """
-            self.label_load_fct = label_load_fct
-            super().__init__(data_path=num, *args, **kwargs)
-
-        def _make_dataset(self, path):
-            data = []
-            for i in range(path):
-                data.append(self._load_fn(i, self.label_load_fct))
-            return data
-
-    dset_a = DummyCacheDataset(500, None, load_fn=load_dummy_sample,
-                               img_extensions=[], gt_extensions=[])
-    dset_b = DummyCacheDataset(700, None, load_fn=load_dummy_sample,
-                               img_extensions=[], gt_extensions=[])
-
-    # test concatenating
-    concat_dataset = ConcatDataset(dset_a, dset_b)
-    assert len(concat_dataset) == (len(dset_a) + len(dset_b))
-
-    assert concat_dataset[0]
-
-    # test slicing:
-    half_len_a = len(dset_a) // 2
-    half_len_b = len(dset_b) // 2
-    assert len(dset_a.get_subset(range(half_len_a))) == half_len_a
-    assert len(dset_b.get_subset(range(half_len_b))) == half_len_b
-
-    sliced_concat_set = concat_dataset.get_subset(
-        range(half_len_a + half_len_b))
-    assert len(sliced_concat_set) == (half_len_a + half_len_b)
-
-    # check if entries are valid
-    assert sliced_concat_set[0]
+import unittest
+
+
+class DataSubsetConcatTest(unittest.TestCase):
+
+    def test_data_subset_concat(self):
+
+        def load_dummy_sample(path, label_load_fct):
+            """
+            Returns dummy data, independent of path or label_load_fct
+            Parameters
+            ----------
+            path
+            label_load_fct
+            Returns
+            -------
+            : dict
+                dict with data and label
+            """
+            return {'data': np.random.rand(1, 256, 256),
+                    'label': np.random.randint(2)}
+
+        class DummyCacheDataset(BaseCacheDataset):
+            def __init__(self, num: int, label_load_fct, *args, **kwargs):
+                """
+                Generates random samples with _make_dataset
+                Parameters
+                ----------
+                num : int
+                    number of random samples
+                args :
+                    passed to BaseCacheDataset
+                kwargs :
+                    passed to BaseCacheDataset
+                """
+                self.label_load_fct = label_load_fct
+                super().__init__(data_path=num, *args, **kwargs)
+
+            def _make_dataset(self, path):
+                data = []
+                for i in range(path):
+                    data.append(self._load_fn(i, self.label_load_fct))
+                return data
+
+        dset_a = DummyCacheDataset(500, None, load_fn=load_dummy_sample,
+                                   img_extensions=[], gt_extensions=[])
+        dset_b = DummyCacheDataset(700, None, load_fn=load_dummy_sample,
+                                   img_extensions=[], gt_extensions=[])
+
+        # test concatenating
+        concat_dataset = ConcatDataset(dset_a, dset_b)
+
+        self.assertEqual(len(concat_dataset), len(dset_a) + len(dset_b))
+
+        self.assertTrue(concat_dataset[0])
+
+        # test slicing:
+        half_len_a = len(dset_a) // 2
+        half_len_b = len(dset_b) // 2
+
+        self.assertEqual(len(dset_a.get_subset(range(half_len_a))), half_len_a)
+        self.assertEqual(len(dset_b.get_subset(range(half_len_b))), half_len_b)
+
+        sliced_concat_set = concat_dataset.get_subset(
+            range(half_len_a + half_len_b))
+
+        self.assertEqual(len(sliced_concat_set), half_len_a + half_len_b)
+
+        # check if entries are valid
+        self.assertTrue(sliced_concat_set[0])
 
 
 if __name__ == "__main__":
-    test_data_subset_concat()
+    unittest.main()
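With the suite on unittest, individual modules or classes can also be selected by dotted name rather than running everything. A usage sketch; the module path is assumed from the file layout shown above:

import unittest

# load only the concat/subset tests instead of discovering the whole suite;
# equivalent CLI: python -m unittest tests.data_loading.test_dataset.DataSubsetConcatTest
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.data_loading.test_dataset.DataSubsetConcatTest")
unittest.TextTestRunner(verbosity=2).run(suite)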