diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 958dbf28155a..c30a570cb59c 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -134,18 +134,11 @@ def test_can_have_project_source(self): def test_can_batch_launch_custom_model(self): class TestExtractor(Extractor): - def __init__(self, url, n=0): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(5): yield DatasetItem(id=i, subset='train', image=i) class TestLauncher(Launcher): - def __init__(self, **kwargs): - pass - def launch(self, inputs): for i, inp in enumerate(inputs): yield [ LabelObject(attributes={'idx': i, 'data': inp}) ] @@ -157,7 +150,7 @@ def launch(self, inputs): project.env.launchers.register(launcher_name, TestLauncher) project.add_model(model_name, { 'launcher': launcher_name }) model = project.make_executable_model(model_name) - extractor = TestExtractor('', n=5) + extractor = TestExtractor() batch_size = 3 executor = InferenceWrapper(extractor, model, batch_size=batch_size) @@ -171,19 +164,12 @@ def launch(self, inputs): def test_can_do_transform_with_custom_model(self): class TestExtractorSrc(Extractor): - def __init__(self, url, n=2): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(2): yield DatasetItem(id=i, subset='train', image=i, annotations=[ LabelObject(i) ]) class TestLauncher(Launcher): - def __init__(self, **kwargs): - pass - def launch(self, inputs): for inp in inputs: yield [ LabelObject(inp) ] @@ -191,7 +177,7 @@ def launch(self, inputs): class TestConverter(Converter): def __call__(self, extractor, save_dir): for item in extractor: - with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f: + with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f: f.write(str(item.subset) + '\n') f.write(str(item.annotations[0].label) + '\n') @@ -204,8 +190,8 @@ def __iter__(self): for path in self.items: with open(path, 'r') as f: index = osp.splitext(osp.basename(path))[0] - subset = f.readline()[:-1] - label = int(f.readline()[:-1]) + subset = f.readline().strip() + label = int(f.readline().strip()) assert subset == 'train' yield DatasetItem(id=index, subset=subset, annotations=[ LabelObject(label) ]) @@ -261,12 +247,8 @@ def __iter__(self): def test_project_filter_can_be_applied(self): class TestExtractor(Extractor): - def __init__(self, url, n=10): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(10): yield DatasetItem(id=i, subset='train') e_type = 'type' @@ -331,30 +313,23 @@ def test_project_compound_child_can_be_modified_recursively(self): self.assertEqual(1, len(dataset.sources['child2'])) def test_project_can_merge_item_annotations(self): - class TestExtractor(Extractor): - def __init__(self, url, v=None): - super().__init__() - self.v = v - + class TestExtractor1(Extractor): def __iter__(self): - v1_item = DatasetItem(id=1, subset='train', annotations=[ + yield DatasetItem(id=1, subset='train', annotations=[ LabelObject(2, id=3), LabelObject(3, attributes={ 'x': 1 }), ]) - v2_item = DatasetItem(id=1, subset='train', annotations=[ + class TestExtractor2(Extractor): + def __iter__(self): + yield DatasetItem(id=1, subset='train', annotations=[ LabelObject(3, attributes={ 'x': 1 }), LabelObject(4, id=4), ]) - if self.v == 1: - yield v1_item - else: - yield v2_item - project = Project() - project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1)) - project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2)) + project.env.extractors.register('t1', TestExtractor1) + project.env.extractors.register('t2', TestExtractor2) project.add_source('source1', { 'format': 't1' }) project.add_source('source2', { 'format': 't2' }) @@ -494,9 +469,6 @@ def test_can_produce_multilayer_config_from_dict(self): class ExtractorTest(TestCase): def test_custom_extractor_can_be_created(self): class CustomExtractor(Extractor): - def __init__(self, url): - super().__init__() - def __iter__(self): return iter([ DatasetItem(id=0, subset='train'),