Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Jan 14, 2020
1 parent 9e9bb24 commit 57363b8
Showing 1 changed file with 14 additions and 42 deletions.
56 changes: 14 additions & 42 deletions datumaro/tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,11 @@ def test_can_have_project_source(self):

def test_can_batch_launch_custom_model(self):
class TestExtractor(Extractor):
def __init__(self, url, n=0):
super().__init__(length=n)
self.n = n

def __iter__(self):
for i in range(self.n):
for i in range(5):
yield DatasetItem(id=i, subset='train', image=i)

class TestLauncher(Launcher):
def __init__(self, **kwargs):
pass

def launch(self, inputs):
for i, inp in enumerate(inputs):
yield [ LabelObject(attributes={'idx': i, 'data': inp}) ]
Expand All @@ -157,7 +150,7 @@ def launch(self, inputs):
project.env.launchers.register(launcher_name, TestLauncher)
project.add_model(model_name, { 'launcher': launcher_name })
model = project.make_executable_model(model_name)
extractor = TestExtractor('', n=5)
extractor = TestExtractor()

batch_size = 3
executor = InferenceWrapper(extractor, model, batch_size=batch_size)
Expand All @@ -171,27 +164,20 @@ def launch(self, inputs):

def test_can_do_transform_with_custom_model(self):
class TestExtractorSrc(Extractor):
def __init__(self, url, n=2):
super().__init__(length=n)
self.n = n

def __iter__(self):
for i in range(self.n):
for i in range(2):
yield DatasetItem(id=i, subset='train', image=i,
annotations=[ LabelObject(i) ])

class TestLauncher(Launcher):
def __init__(self, **kwargs):
pass

def launch(self, inputs):
for inp in inputs:
yield [ LabelObject(inp) ]

class TestConverter(Converter):
def __call__(self, extractor, save_dir):
for item in extractor:
with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f:
with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f:
f.write(str(item.subset) + '\n')
f.write(str(item.annotations[0].label) + '\n')

Expand All @@ -204,8 +190,8 @@ def __iter__(self):
for path in self.items:
with open(path, 'r') as f:
index = osp.splitext(osp.basename(path))[0]
subset = f.readline()[:-1]
label = int(f.readline()[:-1])
subset = f.readline().strip()
label = int(f.readline().strip())
assert subset == 'train'
yield DatasetItem(id=index, subset=subset,
annotations=[ LabelObject(label) ])
Expand Down Expand Up @@ -261,12 +247,8 @@ def __iter__(self):

def test_project_filter_can_be_applied(self):
class TestExtractor(Extractor):
def __init__(self, url, n=10):
super().__init__(length=n)
self.n = n

def __iter__(self):
for i in range(self.n):
for i in range(10):
yield DatasetItem(id=i, subset='train')

e_type = 'type'
Expand Down Expand Up @@ -331,30 +313,23 @@ def test_project_compound_child_can_be_modified_recursively(self):
self.assertEqual(1, len(dataset.sources['child2']))

def test_project_can_merge_item_annotations(self):
class TestExtractor(Extractor):
def __init__(self, url, v=None):
super().__init__()
self.v = v

class TestExtractor1(Extractor):
def __iter__(self):
v1_item = DatasetItem(id=1, subset='train', annotations=[
yield DatasetItem(id=1, subset='train', annotations=[
LabelObject(2, id=3),
LabelObject(3, attributes={ 'x': 1 }),
])

v2_item = DatasetItem(id=1, subset='train', annotations=[
class TestExtractor2(Extractor):
def __iter__(self):
yield DatasetItem(id=1, subset='train', annotations=[
LabelObject(3, attributes={ 'x': 1 }),
LabelObject(4, id=4),
])

if self.v == 1:
yield v1_item
else:
yield v2_item

project = Project()
project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1))
project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2))
project.env.extractors.register('t1', TestExtractor1)
project.env.extractors.register('t2', TestExtractor2)
project.add_source('source1', { 'format': 't1' })
project.add_source('source2', { 'format': 't2' })

Expand Down Expand Up @@ -494,9 +469,6 @@ def test_can_produce_multilayer_config_from_dict(self):
class ExtractorTest(TestCase):
def test_custom_extractor_can_be_created(self):
class CustomExtractor(Extractor):
def __init__(self, url):
super().__init__()

def __iter__(self):
return iter([
DatasetItem(id=0, subset='train'),
Expand Down

0 comments on commit 57363b8

Please sign in to comment.