diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 2923f7dc4e1e..09f0a698e999 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -156,35 +156,31 @@ def convert_attrs(label, cvat_attrs): for tag_obj in cvat_anno.tags: anno_group = tag_obj.group - if isinstance(anno_group, int): - anno_group = anno_group anno_label = map_label(tag_obj.label) anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes) - anno = datumaro.LabelObject(label=anno_label, + anno = datumaro.Label(label=anno_label, attributes=anno_attr, group=anno_group) item_anno.append(anno) for shape_obj in cvat_anno.labeled_shapes: anno_group = shape_obj.group - if isinstance(anno_group, int): - anno_group = anno_group anno_label = map_label(shape_obj.label) anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes) anno_points = shape_obj.points if shape_obj.type == ShapeType.POINTS: - anno = datumaro.PointsObject(anno_points, + anno = datumaro.Points(anno_points, label=anno_label, attributes=anno_attr, group=anno_group) elif shape_obj.type == ShapeType.POLYLINE: - anno = datumaro.PolyLineObject(anno_points, + anno = datumaro.PolyLine(anno_points, label=anno_label, attributes=anno_attr, group=anno_group) elif shape_obj.type == ShapeType.POLYGON: - anno = datumaro.PolygonObject(anno_points, + anno = datumaro.Polygon(anno_points, label=anno_label, attributes=anno_attr, group=anno_group) elif shape_obj.type == ShapeType.RECTANGLE: x0, y0, x1, y1 = anno_points - anno = datumaro.BboxObject(x0, y0, x1 - x0, y1 - y0, + anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0, label=anno_label, attributes=anno_attr, group=anno_group) else: raise Exception("Unknown shape type '%s'" % (shape_obj.type)) diff --git a/cvat/apps/dataset_manager/export_templates/extractors/cvat_rest_api_task_images.py b/cvat/apps/dataset_manager/export_templates/plugins/cvat_rest_api_task_images.py similarity index 96% rename from cvat/apps/dataset_manager/export_templates/extractors/cvat_rest_api_task_images.py rename to cvat/apps/dataset_manager/export_templates/plugins/cvat_rest_api_task_images.py index f6d5da6bfcc3..c7cf8fbbd220 100644 --- a/cvat/apps/dataset_manager/export_templates/extractors/cvat_rest_api_task_images.py +++ b/cvat/apps/dataset_manager/export_templates/plugins/cvat_rest_api_task_images.py @@ -28,7 +28,7 @@ 'server_port': 80 }, schema=CONFIG_SCHEMA, mutable=False) -class cvat_rest_api_task_images(datumaro.Extractor): +class cvat_rest_api_task_images(datumaro.SourceExtractor): def _image_local_path(self, item_id): task_id = self._config.task_id return osp.join(self._cache_dir, @@ -53,7 +53,7 @@ def _connect(self): session = None try: - print("Enter credentials for '%s:%s':" % \ + print("Enter credentials for '%s:%s' to read task data:" % \ (self._config.server_host, self._config.server_port)) username = input('User: ') password = getpass.getpass() diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index 7c361a6133f6..db2b37c0a3bb 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -15,13 +15,13 @@ import django_rq from cvat.apps.engine.log import slogger -from cvat.apps.engine.models import Task, ShapeType +from cvat.apps.engine.models import Task from .util import current_function_name, make_zip_archive _CVAT_ROOT_DIR = __file__[:__file__.rfind('cvat/')] _DATUMARO_REPO_PATH = osp.join(_CVAT_ROOT_DIR, 'datumaro') sys.path.append(_DATUMARO_REPO_PATH) -from datumaro.components.project 
import Project +from datumaro.components.project import Project, Environment import datumaro.components.extractor as datumaro from .bindings import CvatImagesDirExtractor, CvatTaskExtractor @@ -132,83 +132,7 @@ def _generate_categories(self): return categories def put_annotations(self, annotations): - patch = {} - - categories = self._dataset.categories() - label_cat = categories[datumaro.AnnotationType.label] - - label_map = {} - attr_map = {} - db_labels = self._db_task.label_set.all() - for db_label in db_labels: - label_map[db_label.id] = label_cat.find(db_label.name) - - db_attributes = db_label.attributespec_set.all() - for db_attr in db_attributes: - attr_map[(db_label.id, db_attr.id)] = db_attr.name - map_label = lambda label_db_id: label_map[label_db_id] - map_attr = lambda label_db_id, attr_db_id: \ - attr_map[(label_db_id, attr_db_id)] - - for tag_obj in annotations['tags']: - item_id = str(tag_obj['frame']) - item_anno = patch.get(item_id, []) - - anno_group = tag_obj['group'] - if isinstance(anno_group, int): - anno_group = [anno_group] - anno_label = map_label(tag_obj['label_id']) - anno_attr = {} - for attr in tag_obj['attributes']: - attr_name = map_attr(tag_obj['label_id'], attr['id']) - anno_attr[attr_name] = attr['value'] - - anno = datumaro.LabelObject(label=anno_label, - attributes=anno_attr, group=anno_group) - item_anno.append(anno) - - patch[item_id] = item_anno - - for shape_obj in annotations['shapes']: - item_id = str(shape_obj['frame']) - item_anno = patch.get(item_id, []) - - anno_group = shape_obj['group'] - if isinstance(anno_group, int): - anno_group = [anno_group] - anno_label = map_label(shape_obj['label_id']) - anno_attr = {} - for attr in shape_obj['attributes']: - attr_name = map_attr(shape_obj['label_id'], attr['id']) - anno_attr[attr_name] = attr['value'] - - anno_points = shape_obj['points'] - if shape_obj['type'] == ShapeType.POINTS: - anno = datumaro.PointsObject(anno_points, - label=anno_label, attributes=anno_attr, group=anno_group) - elif shape_obj['type'] == ShapeType.POLYLINE: - anno = datumaro.PolyLineObject(anno_points, - label=anno_label, attributes=anno_attr, group=anno_group) - elif shape_obj['type'] == ShapeType.POLYGON: - anno = datumaro.PolygonObject(anno_points, - label=anno_label, attributes=anno_attr, group=anno_group) - elif shape_obj['type'] == ShapeType.RECTANGLE: - x0, y0, x1, y1 = anno_points - anno = datumaro.BboxObject(x0, y0, x1 - x0, y1 - y0, - label=anno_label, attributes=anno_attr, group=anno_group) - else: - raise Exception("Unknown shape type '%s'" % (shape_obj['type'])) - - item_anno.append(anno) - - patch[item_id] = item_anno - - # TODO: support track annotations - - patch = [datumaro.DatasetItem(id=id_, annotations=anno) \ - for id_, ann in patch.items()] - - self._dataset.update(patch) + raise NotImplementedError() def save(self, save_dir=None, save_images=False): if self._dataset is not None: @@ -296,10 +220,10 @@ def _remote_export(self, save_dir, server_url=None): osp.join(templates_dir, 'README.md'), osp.join(target_dir, 'README.md')) - templates_dir = osp.join(templates_dir, 'extractors') + templates_dir = osp.join(templates_dir, 'plugins') target_dir = osp.join(target_dir, exported_project.config.env_dir, - exported_project.env.config.extractors_dir) + exported_project.config.plugins_dir) os.makedirs(target_dir, exist_ok=True) shutil.copyfile( osp.join(templates_dir, _TASK_IMAGES_REMOTE_EXTRACTOR + '.py'), @@ -409,9 +333,9 @@ def clear_export_cache(task_id, file_path, file_ctime): ] def get_export_formats(): - 
from datumaro.components import converters + converters = Environment().converters - available_formats = set(name for name, _ in converters.items) + available_formats = set(converters.items) available_formats.add(EXPORT_FORMAT_DATUMARO_PROJECT) public_formats = [] diff --git a/datumaro/datumaro/cli/__main__.py b/datumaro/datumaro/cli/__main__.py index 0ed611d050dd..b68de2267409 100644 --- a/datumaro/datumaro/cli/__main__.py +++ b/datumaro/datumaro/cli/__main__.py @@ -5,6 +5,8 @@ import argparse import logging as log +import logging.handlers +import os import sys from . import contexts, commands @@ -81,15 +83,70 @@ def make_parser(): return parser -def set_up_logger(args): - log.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', - level=args.loglevel) +class _LogManager: + _LOGLEVEL_ENV_NAME = '_DATUMARO_INIT_LOGLEVEL' + _BUFFER_SIZE = 1000 + _root = None + _init_handler = None + _default_handler = None + + @classmethod + def init_basic_logger(cls): + base_loglevel = os.getenv(cls._LOGLEVEL_ENV_NAME, 'info') + base_loglevel = loglevel(base_loglevel) + root = log.getLogger() + root.setLevel(base_loglevel) + + # NOTE: defer use of this handler until the logger + # is properly initialized, but keep logging enabled before then. + # Messages emitted during initialization are stored and printed + # afterwards, if necessary. + default_handler = log.StreamHandler() + default_handler.setFormatter( + log.Formatter('%(asctime)s %(levelname)s: %(message)s')) + + init_handler = logging.handlers.MemoryHandler(cls._BUFFER_SIZE, + target=default_handler) + root.addHandler(init_handler) + + cls._root = root + cls._init_handler = init_handler + cls._default_handler = default_handler + + @classmethod + def set_up_logger(cls, level): + log.getLogger().setLevel(level) + + if cls._init_handler: + # NOTE: Handlers do not filter by level themselves, even though a + # level can be set for a handler: levels are only checked by Logger. + # Handler filters, however, are applied by the handler itself.
+ class LevelFilter: + def __init__(self, level): + super().__init__() + self.level = level + + def filter(self, record): + return record.levelno >= self.level + filt = LevelFilter(level) + cls._default_handler.addFilter(filt) + + cls._root.removeHandler(cls._init_handler) + cls._init_handler.close() + del cls._init_handler + cls._init_handler = None + + cls._default_handler.removeFilter(filt) + + cls._root.addHandler(cls._default_handler) def main(args=None): + _LogManager.init_basic_logger() + parser = make_parser() args = parser.parse_args(args) - set_up_logger(args) + _LogManager.set_up_logger(args.loglevel) if 'command' not in args: parser.print_help() diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index 0ba03461a2c9..4b2ede1c639b 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -10,10 +10,11 @@ import os.path as osp import shutil -from datumaro.components.project import Project +from datumaro.components.project import Project, Environment from datumaro.components.comparator import Comparator from datumaro.components.dataset_filter import DatasetItemEncoder from datumaro.components.extractor import AnnotationType +from datumaro.components.cli_plugin import CliPlugin from .diff import DiffVisualizer from ...util import add_subparser, CliException, MultilineFormatter from ...util.project import make_project_path, load_project, \ @@ -75,8 +76,7 @@ def create_command(args): return 0 def build_import_parser(parser_ctor=argparse.ArgumentParser): - import datumaro.components.importers as importers_module - builtin_importers = [name for name, cls in importers_module.items] + builtins = sorted(Environment().importers.items) parser = parser_ctor(help="Create project from existing dataset", description=""" @@ -104,7 +104,7 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): /.datumaro/extractors and /.datumaro/importers.|n |n - List of supported dataset formats: %s|n + List of builtin dataset formats: %s|n |n Examples:|n - Create a project from VOC dataset in the current directory:|n @@ -112,7 +112,7 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): |n - Create a project from COCO dataset in other directory:|n |s|simport -f coco -i path/to/coco -o path/I/like/ - """ % ', '.join(builtin_importers), + """ % ', '.join(builtins), formatter_class=MultilineFormatter) parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', @@ -129,8 +129,8 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): help="Path to import project from") parser.add_argument('-f', '--format', required=True, help="Source project format") - # parser.add_argument('extra_args', nargs=argparse.REMAINDER, - # help="Additional arguments for importer (pass '-- -h' for help)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, + help="Additional arguments for importer (pass '-- -h' for help)") parser.set_defaults(command=import_command) return parser @@ -155,11 +155,21 @@ def import_command(args): if project_name is None: project_name = osp.basename(project_dir) + extra_args = {} + try: + env = Environment() + importer = env.importers.get(args.format) + if hasattr(importer, 'from_cmdline'): + extra_args = importer.from_cmdline(args.extra_args) + except KeyError: + raise CliException("Importer for format '%s' is not found" % \ + args.format) + log.info("Importing project from '%s' as '%s'" % \ (args.source, args.format)) source = 
osp.abspath(args.source) - project = Project.import_from(source, args.format) + project = importer(source, **extra_args) project.config.project_name = project_name project.config.project_dir = project_dir @@ -217,8 +227,7 @@ def list_options(cls): return [m.name.replace('_', '+') for m in cls] def build_export_parser(parser_ctor=argparse.ArgumentParser): - import datumaro.components.converters as converters_module - builtin_converters = [name for name, cls in converters_module.items] + builtins = sorted(Environment().converters.items) parser = parser_ctor(help="Export project", description=""" @@ -237,7 +246,7 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): To do this, you need to put a Converter definition script to /.datumaro/converters.|n |n - List of supported dataset formats: %s|n + List of builtin dataset formats: %s|n |n Examples:|n - Export project as a VOC-like dataset, include images:|n @@ -245,7 +254,7 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): |n - Export project as a COCO-like dataset in other directory:|n |s|sexport -f coco -o path/I/like/ - """ % ', '.join(builtin_converters), + """ % ', '.join(builtins), formatter_class=MultilineFormatter) parser.add_argument('-e', '--filter', default=None, @@ -282,8 +291,11 @@ def export_command(args): dst_dir = osp.abspath(dst_dir) + extra_args = {} try: - converter = project.env.make_converter(args.format, - cmdline_args=args.extra_args) + converter = project.env.converters.get(args.format) + if hasattr(converter, 'from_cmdline'): + extra_args = converter.from_cmdline(args.extra_args) + converter = converter(**extra_args) except KeyError: raise CliException("Converter for format '%s' is not found" % \ args.format) @@ -494,8 +505,7 @@ def diff_command(args): second_project.config.project_name) ) dst_dir = osp.abspath(dst_dir) - if dst_dir: - log.info("Saving diff to '%s'" % dst_dir) + log.info("Saving diff to '%s'" % dst_dir) visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator, output_format=args.format) @@ -506,13 +516,19 @@ def diff_command(args): return 0 def build_transform_parser(parser_ctor=argparse.ArgumentParser): + builtins = sorted(Environment().transforms.items) + parser = parser_ctor(help="Transform project", description=""" Applies some operation to dataset items in the project - and produces a new project.
- - [NOT IMPLEMENTED YET] - """, + and produces a new project.|n + |n + Builtin transforms: %s|n + |n + Examples:|n + - Convert instance polygons to masks:|n + |s|stransform -t polygons_to_masks + """ % ', '.join(builtins), formatter_class=MultilineFormatter) parser.add_argument('-t', '--transform', required=True, @@ -523,30 +539,46 @@ def build_transform_parser(parser_ctor=argparse.ArgumentParser): help="Overwrite existing files in the save directory") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, + help="Additional arguments for transformation (pass '-- -h' for help)") parser.set_defaults(command=transform_command) return parser def transform_command(args): - raise NotImplementedError("Not implemented yet.") + project = load_project(args.project_dir) - # project = load_project(args.project_dir) + dst_dir = args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to force creation)" % dst_dir) + else: + dst_dir = generate_next_dir_name('%s-transform' % \ + project.config.project_name) + dst_dir = osp.abspath(dst_dir) - # dst_dir = args.dst_dir - # if dst_dir: - # if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): - # raise CliException("Directory '%s' already exists " - # "(pass --overwrite to force creation)" % dst_dir) - # dst_dir = osp.abspath(args.dst_dir) + extra_args = {} + try: + transform = project.env.transforms.get(args.transform) + if hasattr(transform, 'from_cmdline'): + extra_args = transform.from_cmdline(args.extra_args) + except KeyError: + raise CliException("Transform '%s' is not found" % args.transform) + + log.info("Loading the project...") + dataset = project.make_dataset() - # project.make_dataset().transform_project( - # method=args.transform, - # save_dir=dst_dir - # ) + log.info("Transforming the project...") + dataset.transform_project( + method=transform, + save_dir=dst_dir, + **extra_args + ) - # log.info("Transform results saved to '%s'" % dst_dir) + log.info("Transform results have been saved to '%s'" % dst_dir) - # return 0 + return 0 def build_info_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Get project info", diff --git a/datumaro/datumaro/cli/contexts/source/__init__.py b/datumaro/datumaro/cli/contexts/source/__init__.py index b20be3de3f86..94734265e7d1 100644 --- a/datumaro/datumaro/cli/contexts/source/__init__.py +++ b/datumaro/datumaro/cli/contexts/source/__init__.py @@ -9,13 +9,13 @@ import os.path as osp import shutil +from datumaro.components.project import Environment from ...util import add_subparser, CliException, MultilineFormatter from ...util.project import load_project def build_add_parser(parser_ctor=argparse.ArgumentParser): - import datumaro.components.extractors as extractors_module - extractors_list = [name for name, cls in extractors_module.items] + builtins = sorted(Environment().extractors.items) base_parser = argparse.ArgumentParser(add_help=False) base_parser.add_argument('-n', '--name', default=None, @@ -53,7 +53,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): To do this, you need to put an Extractor definition script to /.datumaro/extractors.|n |n - List of supported source formats: %s|n + List of builtin source formats: %s|n |n Examples:|n - Add a local directory with VOC-like
dataset:|n @@ -61,7 +61,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): - Add a local file with CVAT annotations, call it 'mysource'|n |s|s|s|sto the project somewhere else:|n |s|sadd path path/to/cvat.xml -f cvat -n mysource -p somewhere/else/ - """ % ('%(prog)s SOURCE_TYPE --help', ', '.join(extractors_list)), + """ % ('%(prog)s SOURCE_TYPE --help', ', '.join(builtins)), formatter_class=MultilineFormatter, add_help=False) parser.set_defaults(command=add_command) diff --git a/datumaro/datumaro/components/cli_plugin.py b/datumaro/datumaro/components/cli_plugin.py new file mode 100644 index 000000000000..08a7f3834cc9 --- /dev/null +++ b/datumaro/datumaro/components/cli_plugin.py @@ -0,0 +1,56 @@ + +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from datumaro.cli.util import MultilineFormatter + + +class CliPlugin: + @staticmethod + def _get_name(cls): + return getattr(cls, 'NAME', + remove_plugin_type(to_snake_case(cls.__name__))) + + @staticmethod + def _get_doc(cls): + return getattr(cls, '__doc__', "") + + @classmethod + def build_cmdline_parser(cls, **kwargs): + args = { + 'prog': cls._get_name(cls), + 'description': cls._get_doc(cls), + 'formatter_class': MultilineFormatter, + } + args.update(kwargs) + + return argparse.ArgumentParser(**args) + + @classmethod + def from_cmdline(cls, args=None): + if args and args[0] == '--': + args = args[1:] + parser = cls.build_cmdline_parser() + args = parser.parse_args(args) + return vars(args) + +def remove_plugin_type(s): + for t in {'transform', 'extractor', 'converter', 'launcher', 'importer'}: + s = s.replace('_' + t, '') + return s + +def to_snake_case(s): + if not s: + return '' + + name = [s[0].lower()] + for char in s[1:]: + if char.isalpha() and char.isupper(): + name.append('_') + name.append(char.lower()) + else: + name.append(char) + return ''.join(name) \ No newline at end of file diff --git a/datumaro/datumaro/components/config_model.py b/datumaro/datumaro/components/config_model.py index d21d3393b0da..9bce725ebd7b 100644 --- a/datumaro/datumaro/components/config_model.py +++ b/datumaro/datumaro/components/config_model.py @@ -30,40 +30,22 @@ def __init__(self, config=None): super().__init__(config, schema=MODEL_SCHEMA) -ENV_SCHEMA = _SchemaBuilder() \ - .add('models_dir', str) \ - .add('importers_dir', str) \ - .add('launchers_dir', str) \ - .add('converters_dir', str) \ - .add('extractors_dir', str) \ - \ - .add('models', lambda: _DefaultConfig( - lambda v=None: Model(v))) \ - .build() - -ENV_DEFAULT_CONFIG = Config({ - 'models_dir': 'models', - 'importers_dir': 'importers', - 'launchers_dir': 'launchers', - 'converters_dir': 'converters', - 'extractors_dir': 'extractors', -}, mutable=False, schema=ENV_SCHEMA) - - PROJECT_SCHEMA = _SchemaBuilder() \ .add('project_name', str) \ .add('format_version', int) \ \ - .add('sources_dir', str) \ - .add('dataset_dir', str) \ - .add('build_dir', str) \ .add('subsets', list) \ .add('sources', lambda: _DefaultConfig( lambda v=None: Source(v))) \ + .add('models', lambda: _DefaultConfig( + lambda v=None: Model(v))) \ \ + .add('models_dir', str, internal=True) \ + .add('plugins_dir', str, internal=True) \ + .add('sources_dir', str, internal=True) \ + .add('dataset_dir', str, internal=True) \ .add('project_filename', str, internal=True) \ .add('project_dir', str, internal=True) \ - .add('env_filename', str, internal=True) \ .add('env_dir', str, internal=True) \ .build() @@ -73,10 +55,10 @@ def __init__(self, config=None): 
'sources_dir': 'sources', 'dataset_dir': 'dataset', - 'build_dir': 'build', + 'models_dir': 'models', + 'plugins_dir': 'plugins', 'project_filename': 'config.yaml', 'project_dir': '', - 'env_filename': 'datumaro.yaml', 'env_dir': '.datumaro', }, mutable=False, schema=PROJECT_SCHEMA) diff --git a/datumaro/datumaro/components/converters/__init__.py b/datumaro/datumaro/components/converters/__init__.py deleted file mode 100644 index 0991ed29543a..000000000000 --- a/datumaro/datumaro/components/converters/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from datumaro.components.converters.datumaro import DatumaroConverter - -from datumaro.components.converters.coco import ( - CocoConverter, - CocoImageInfoConverter, - CocoCaptionsConverter, - CocoInstancesConverter, - CocoPersonKeypointsConverter, - CocoLabelsConverter, -) - -from datumaro.components.converters.voc import ( - VocConverter, - VocClassificationConverter, - VocDetectionConverter, - VocLayoutConverter, - VocActionConverter, - VocSegmentationConverter, -) - -from datumaro.components.converters.yolo import YoloConverter -from datumaro.components.converters.tfrecord import DetectionApiConverter -from datumaro.components.converters.cvat import CvatConverter - - -items = [ - ('datumaro', DatumaroConverter), - - ('coco', CocoConverter), - ('coco_images', CocoImageInfoConverter), - ('coco_captions', CocoCaptionsConverter), - ('coco_instances', CocoInstancesConverter), - ('coco_person_kp', CocoPersonKeypointsConverter), - ('coco_labels', CocoLabelsConverter), - - ('voc', VocConverter), - ('voc_cls', VocClassificationConverter), - ('voc_det', VocDetectionConverter), - ('voc_segm', VocSegmentationConverter), - ('voc_action', VocActionConverter), - ('voc_layout', VocLayoutConverter), - - ('yolo', YoloConverter), - - ('tf_detection_api', DetectionApiConverter), - - ('cvat', CvatConverter), -] diff --git a/datumaro/datumaro/components/dataset_filter.py b/datumaro/datumaro/components/dataset_filter.py index 73c7ce812ad5..5037331f07ae 100644 --- a/datumaro/datumaro/components/dataset_filter.py +++ b/datumaro/datumaro/components/dataset_filter.py @@ -6,8 +6,7 @@ from lxml import etree as ET # NOTE: lxml has proper XPath implementation from datumaro.components.extractor import (DatasetItem, Extractor, Annotation, AnnotationType, - LabelObject, MaskObject, PointsObject, PolygonObject, - PolyLineObject, BboxObject, CaptionObject, + Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, ) @@ -100,7 +99,7 @@ def encode_bbox_object(cls, obj, categories): ET.SubElement(ann_elem, 'y').text = str(obj.y) ET.SubElement(ann_elem, 'w').text = str(obj.w) ET.SubElement(ann_elem, 'h').text = str(obj.h) - ET.SubElement(ann_elem, 'area').text = str(obj.area()) + ET.SubElement(ann_elem, 'area').text = str(obj.get_area()) return ann_elem @@ -191,19 +190,19 @@ def encode_caption_object(cls, obj): @classmethod def encode_annotation(cls, o, categories=None): - if isinstance(o, LabelObject): + if isinstance(o, Label): return cls.encode_label_object(o, categories) - if isinstance(o, MaskObject): + if isinstance(o, Mask): return cls.encode_mask_object(o, categories) - if isinstance(o, BboxObject): + if isinstance(o, Bbox): return cls.encode_bbox_object(o, categories) - if isinstance(o, PointsObject): + if isinstance(o, Points): return cls.encode_points_object(o, categories) - if isinstance(o, PolyLineObject): + if isinstance(o, PolyLine): return cls.encode_polyline_object(o, categories) - if 
isinstance(o, PolygonObject): + if isinstance(o, Polygon): return cls.encode_polygon_object(o, categories) - if isinstance(o, CaptionObject): + if isinstance(o, Caption): return cls.encode_caption_object(o) raise NotImplementedError("Unexpected annotation object passed: %s" % o) diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index afc221ac0c34..b6d7be0cb2c1 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -35,7 +35,9 @@ def __init__(self, id=None, type=None, attributes=None, group=None): attributes = dict(attributes) self.attributes = attributes - if group is not None: + if group is None: + group = 0 + else: group = int(group) self.group = group # pylint: enable=redefined-builtin @@ -95,6 +97,8 @@ def add(self, name, parent=None, attributes=None): attributes = set(attributes) for attr in attributes: assert isinstance(attr, str) + if parent is None: + parent = '' index = len(self.items) self.items.append(self.Category(name, parent, attributes)) @@ -112,12 +116,15 @@ def __eq__(self, other): return \ (self.items == other.items) -class LabelObject(Annotation): +class Label(Annotation): # pylint: disable=redefined-builtin def __init__(self, label=None, id=None, attributes=None, group=None): super().__init__(id=id, type=AnnotationType.label, attributes=attributes, group=group) + + if label is not None: + label = int(label) self.label = label # pylint: enable=redefined-builtin @@ -157,54 +164,60 @@ def __eq__(self, other): return False return True -class MaskObject(Annotation): +class Mask(Annotation): # pylint: disable=redefined-builtin def __init__(self, image=None, label=None, z_order=None, id=None, attributes=None, group=None): - super().__init__(id=id, type=AnnotationType.mask, - attributes=attributes, group=group) + super().__init__(type=AnnotationType.mask, + id=id, attributes=attributes, group=group) + self._image = image + + if label is not None: + label = int(label) self._label = label if z_order is None: z_order = 0 + else: + z_order = int(z_order) self._z_order = z_order # pylint: enable=redefined-builtin - @property - def label(self): - return self._label - @property def image(self): if callable(self._image): return self._image() return self._image - def painted_data(self, colormap): - raise NotImplementedError() + @property + def label(self): + return self._label - def area(self): - if self._label is None: - raise NotImplementedError() - return np.count_nonzero(self.image) + @property + def z_order(self): + return self._z_order - def extract(self, class_id): - raise NotImplementedError() + def as_class_mask(self, label_id=None): + from datumaro.util.mask_tools import make_index_mask + if label_id is None: + label_id = self.label + return make_index_mask(self.image, label_id) + + def as_instance_mask(self, instance_id): + from datumaro.util.mask_tools import make_index_mask + return make_index_mask(self.image, instance_id) + + def get_area(self): + return np.count_nonzero(self.image) def get_bbox(self): - if self._label is None: - raise NotImplementedError() - image = self.image - cols = np.any(image, axis=0) - rows = np.any(image, axis=1) - x0, x1 = np.where(cols)[0][[0, -1]] - y0, y1 = np.where(rows)[0][[0, -1]] - return [x0, y0, x1 - x0, y1 - y0] + from datumaro.util.mask_tools import find_mask_bbox + return find_mask_bbox(self.image) - @property - def z_order(self): - return self._z_order + def paint(self, colormap): + from datumaro.util.mask_tools import paint_mask + 
return paint_mask(self.as_class_mask(), colormap) def __eq__(self, other): if not super().__eq__(other): @@ -215,7 +228,7 @@ def __eq__(self, other): (self.image is not None and other.image is not None and \ np.array_equal(self.image, other.image)) -class RleMask(MaskObject): +class RleMask(Mask): # pylint: disable=redefined-builtin def __init__(self, rle=None, label=None, z_order=None, id=None, attributes=None, group=None): @@ -231,11 +244,11 @@ def _lazy_decode(rle): from pycocotools import mask as mask_utils return lambda: mask_utils.decode(rle).astype(np.bool) - def area(self): + def get_area(self): from pycocotools import mask as mask_utils return mask_utils.area(self._rle) - def bbox(self): + def get_bbox(self): from pycocotools import mask as mask_utils return mask_utils.toBbox(self._rle) @@ -248,6 +261,73 @@ def __eq__(self, other): return super().__eq__(other) return self._rle == other._rle +class CompiledMask: + @staticmethod + def from_instance_masks(instance_masks, + instance_ids=None, instance_labels=None): + from datumaro.util.mask_tools import merge_masks + + if instance_ids is not None: + assert len(instance_ids) == len(instance_masks) + else: + instance_ids = [1 + i for i in range(len(instance_masks))] + + if instance_labels is not None: + assert len(instance_labels) == len(instance_masks) + else: + instance_labels = [None] * len(instance_masks) + + instance_masks = sorted(instance_masks, key=lambda m: m.z_order) + + instance_mask = [m.as_instance_mask(id) for m, id in + zip(instance_masks, instance_ids)] + instance_mask = merge_masks(instance_mask) + + cls_mask = [m.as_class_mask(c) for m, c in + zip(instance_masks, instance_labels)] + cls_mask = merge_masks(cls_mask) + return __class__(class_mask=cls_mask, instance_mask=instance_mask) + + def __init__(self, class_mask=None, instance_mask=None): + self._class_mask = class_mask + self._instance_mask = instance_mask + + @staticmethod + def _get_image(image): + if callable(image): + return image() + return image + + @property + def class_mask(self): + return self._get_image(self._class_mask) + + @property + def instance_mask(self): + return self._get_image(self._instance_mask) + + @property + def instance_count(self): + return int(self.instance_mask.max()) + + def get_instance_labels(self, class_count=None): + if class_count is None: + class_count = np.max(self.class_mask) + 1 + + m = self.class_mask * class_count + self.instance_mask + m = m.astype(int) + keys = np.unique(m) + instance_labels = {k % class_count: k // class_count + for k in keys if k % class_count != 0 + } + return instance_labels + + def extract(self, instance_id): + return self.instance_mask == instance_id + + def lazy_extract(self, instance_id): + return lambda: self.extract(instance_id) + def compute_iou(bbox_a, bbox_b): aX, aY, aW, aH = bbox_a bX, bY, bW, bH = bbox_b @@ -266,29 +346,43 @@ def compute_iou(bbox_a, bbox_b): return intersection / max(1.0, union) -class ShapeObject(Annotation): +class _Shape(Annotation): # pylint: disable=redefined-builtin def __init__(self, type, points=None, label=None, z_order=None, id=None, attributes=None, group=None): super().__init__(id=id, type=type, attributes=attributes, group=group) - self.points = points - self.label = label + self._points = points + + if label is not None: + label = int(label) + self._label = label if z_order is None: z_order = 0 + else: + z_order = int(z_order) self._z_order = z_order # pylint: enable=redefined-builtin - def area(self): - raise NotImplementedError() + @property + def 
points(self): + return self._points + + @property + def label(self): + return self._label - def get_polygon(self): + @property + def z_order(self): + return self._z_order + + def get_area(self): raise NotImplementedError() def get_bbox(self): - points = self.get_points() - if not self.points: + points = self.points + if not points: return None xs = [p for p in points[0::2]] @@ -299,22 +393,15 @@ def get_bbox(self): y1 = max(ys) return [x0, y0, x1 - x0, y1 - y0] - def get_points(self): - return self.points - - @property - def z_order(self): - return self._z_order - def __eq__(self, other): if not super().__eq__(other): return False return \ - (self.points == other.points) and \ + (np.array_equal(self.points, other.points)) and \ (self.z_order == other.z_order) and \ (self.label == other.label) -class PolyLineObject(ShapeObject): +class PolyLine(_Shape): # pylint: disable=redefined-builtin def __init__(self, points=None, label=None, z_order=None, id=None, attributes=None, group=None): @@ -323,35 +410,34 @@ def __init__(self, points=None, label=None, z_order=None, id=id, attributes=attributes, group=group) # pylint: enable=redefined-builtin - def get_polygon(self): - return self.get_points() + def as_polygon(self): + return self.points[:] - def area(self): + def get_area(self): return 0 -class PolygonObject(ShapeObject): +class Polygon(_Shape): # pylint: disable=redefined-builtin def __init__(self, points=None, z_order=None, label=None, id=None, attributes=None, group=None): if points is not None: + # keep the message on the single line to produce + # informative output assert len(points) % 2 == 0 and 3 <= len(points) // 2, "Wrong polygon points: %s" % points super().__init__(type=AnnotationType.polygon, points=points, label=label, z_order=z_order, id=id, attributes=attributes, group=group) # pylint: enable=redefined-builtin - def get_polygon(self): - return self.get_points() - - def area(self): + def get_area(self): import pycocotools.mask as mask_utils _, _, w, h = self.get_bbox() - rle = mask_utils.frPyObjects([self.get_points()], h, w) + rle = mask_utils.frPyObjects([self.points], h, w) area = mask_utils.area(rle)[0] return area -class BboxObject(ShapeObject): +class Bbox(_Shape): # pylint: disable=redefined-builtin def __init__(self, x=0, y=0, w=0, h=0, label=None, z_order=None, id=None, attributes=None, group=None): @@ -376,13 +462,13 @@ def w(self): def h(self): return self.points[3] - self.points[1] - def area(self): + def get_area(self): return self.w * self.h def get_bbox(self): return [self.x, self.y, self.w, self.h] - def get_polygon(self): + def as_polygon(self): x, y, w, h = self.get_bbox() return [ x, y, @@ -417,7 +503,7 @@ def __eq__(self, other): return \ (self.items == other.items) -class PointsObject(ShapeObject): +class Points(_Shape): Visibility = Enum('Visibility', [ ('absent', 0), ('hidden', 1), @@ -447,7 +533,7 @@ def __init__(self, points=None, visibility=None, label=None, z_order=None, self.visibility = visibility # pylint: enable=redefined-builtin - def area(self): + def get_area(self): return 0 def get_bbox(self): @@ -467,7 +553,7 @@ def __eq__(self, other): return \ (self.visibility == other.visibility) -class CaptionObject(Annotation): +class Caption(Annotation): # pylint: disable=redefined-builtin def __init__(self, caption=None, id=None, attributes=None, group=None): @@ -476,6 +562,8 @@ def __init__(self, caption=None, if caption is None: caption = '' + else: + caption = str(caption) self.caption = caption # pylint: enable=redefined-builtin @@ -642,5 +730,39 @@ 
def select(self, pred): return DatasetIteratorWrapper( _DatasetFilter(self, pred), self.categories(), self.subsets()) +DEFAULT_SUBSET_NAME = 'default' + + +class SourceExtractor(Extractor): + pass + +class Importer: + def __call__(self, path, **extra_params): + raise NotImplementedError() + +class Transform(Extractor): + @classmethod + def wrap_item(cls, item, **kwargs): + expected_args = {'id', 'annotations', 'subset', 'path', 'image'} + for k in expected_args: + if k not in kwargs: + if k == 'image' and item.has_image: + kwargs[k] = lambda: item.image + else: + kwargs[k] = getattr(item, k) + return DatasetItem(**kwargs) + + def __init__(self, extractor): + super().__init__() + + self._extractor = extractor + + def __iter__(self): + for item in self._extractor: + yield self.transform_item(item) + + def categories(self): + return self._extractor.categories() -DEFAULT_SUBSET_NAME = 'default' \ No newline at end of file + def transform_item(self, item): + raise NotImplementedError() \ No newline at end of file diff --git a/datumaro/datumaro/components/extractors/__init__.py b/datumaro/datumaro/components/extractors/__init__.py deleted file mode 100644 index 0b7a19475a60..000000000000 --- a/datumaro/datumaro/components/extractors/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from datumaro.components.extractors.datumaro import DatumaroExtractor - -from datumaro.components.extractors.coco import ( - CocoImageInfoExtractor, - CocoCaptionsExtractor, - CocoInstancesExtractor, - CocoLabelsExtractor, - CocoPersonKeypointsExtractor, -) - -from datumaro.components.extractors.voc import ( - VocClassificationExtractor, - VocDetectionExtractor, - VocSegmentationExtractor, - VocLayoutExtractor, - VocActionExtractor, - VocComp_1_2_Extractor, - VocComp_3_4_Extractor, - VocComp_5_6_Extractor, - VocComp_7_8_Extractor, - VocComp_9_10_Extractor, -) - -from datumaro.components.extractors.yolo import YoloExtractor -from datumaro.components.extractors.tfrecord import DetectionApiExtractor -from datumaro.components.extractors.cvat import CvatExtractor -from datumaro.components.extractors.image_dir import ImageDirExtractor - -items = [ - ('datumaro', DatumaroExtractor), - - ('coco_images', CocoImageInfoExtractor), - ('coco_captions', CocoCaptionsExtractor), - ('coco_instances', CocoInstancesExtractor), - ('coco_person_kp', CocoPersonKeypointsExtractor), - ('coco_labels', CocoLabelsExtractor), - - ('voc_cls', VocClassificationExtractor), - ('voc_det', VocDetectionExtractor), - ('voc_segm', VocSegmentationExtractor), - ('voc_layout', VocLayoutExtractor), - ('voc_action', VocActionExtractor), - - ('voc_comp_1_2', VocComp_1_2_Extractor), - ('voc_comp_3_4', VocComp_3_4_Extractor), - ('voc_comp_5_6', VocComp_5_6_Extractor), - ('voc_comp_7_8', VocComp_7_8_Extractor), - ('voc_comp_9_10', VocComp_9_10_Extractor), - - ('yolo', YoloExtractor), - - ('tf_detection_api', DetectionApiExtractor), - - ('cvat', CvatExtractor), - - ('image_dir', ImageDirExtractor), -] \ No newline at end of file diff --git a/datumaro/datumaro/components/formats/__init__.py b/datumaro/datumaro/components/formats/__init__.py deleted file mode 100644 index a9773073830c..000000000000 --- a/datumaro/datumaro/components/formats/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - diff --git a/datumaro/datumaro/components/importers/__init__.py b/datumaro/datumaro/components/importers/__init__.py deleted file mode 
100644 index cc009dbf47d7..000000000000 --- a/datumaro/datumaro/components/importers/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from datumaro.components.importers.datumaro import DatumaroImporter -from datumaro.components.importers.coco import CocoImporter -from datumaro.components.importers.voc import VocImporter, VocResultsImporter -from datumaro.components.importers.tfrecord import DetectionApiImporter -from datumaro.components.importers.yolo import YoloImporter -from datumaro.components.importers.cvat import CvatImporter -from datumaro.components.importers.image_dir import ImageDirImporter - - -items = [ - ('datumaro', DatumaroImporter), - - ('coco', CocoImporter), - - ('voc', VocImporter), - ('voc_results', VocResultsImporter), - - ('yolo', YoloImporter), - - ('tf_detection_api', DetectionApiImporter), - - ('cvat', CvatImporter), - - ('image_dir', ImageDirImporter), -] \ No newline at end of file diff --git a/datumaro/datumaro/components/importers/image_dir.py b/datumaro/datumaro/components/importers/image_dir.py deleted file mode 100644 index ef2cdd43e5a3..000000000000 --- a/datumaro/datumaro/components/importers/image_dir.py +++ /dev/null @@ -1,26 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -import os.path as osp - - -class ImageDirImporter: - EXTRACTOR_NAME = 'image_dir' - - def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - - if not osp.isdir(path): - raise Exception("Can't find a directory at '%s'" % path) - - source_name = osp.basename(osp.normpath(path)) - project.add_source(source_name, { - 'url': source_name, - 'format': self.EXTRACTOR_NAME, - 'options': dict(extra_params), - }) - - return project diff --git a/datumaro/datumaro/components/launchers/__init__.py b/datumaro/datumaro/components/launchers/__init__.py deleted file mode 100644 index 8d613a2ac53a..000000000000 --- a/datumaro/datumaro/components/launchers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -items = [ -] - -try: - from datumaro.components.launchers.openvino import OpenVinoLauncher - items.append(('openvino', OpenVinoLauncher)) -except ImportError: - pass diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index 34acf41f5243..e328903305d4 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -4,9 +4,11 @@ # SPDX-License-Identifier: MIT from collections import OrderedDict, defaultdict +from functools import reduce import git +from glob import glob import importlib -from functools import reduce +import inspect import logging as log import os import os.path as osp @@ -20,16 +22,16 @@ XPathDatasetFilter, XPathAnnotationsFilter -def import_foreign_module(name, path): +def import_foreign_module(name, path, package=None): module = None default_path = sys.path.copy() try: sys.path = [ osp.abspath(path), ] + default_path sys.modules.pop(name, None) # remove from cache - module = importlib.import_module(name) + module = importlib.import_module(name, package=package) sys.modules.pop(name) # remove from cache - except ImportError as e: - log.warn("Failed to import module '%s': %s" % (name, e)) + except Exception: + raise finally: sys.path = default_path return module @@ -81,19 +83,21 @@ def load(self, config): for name, source in config.sources.items(): 
self.register(name, source) - -class ModuleRegistry(Registry): +class PluginRegistry(Registry): def __init__(self, config=None, builtin=None, local=None): super().__init__(config) + from datumaro.components.cli_plugin import CliPlugin + if builtin is not None: - for k, v in builtin: + for v in builtin: + k = CliPlugin._get_name(v) self.register(k, v) if local is not None: - for k, v in local: + for v in local: + k = CliPlugin._get_name(v) self.register(k, v) - class GitWrapper: def __init__(self, config=None): self.repo = None @@ -135,103 +139,134 @@ def load_project_as_dataset(url): raise NotImplementedError() class Environment: + _builtin_plugins = None PROJECT_EXTRACTOR_NAME = 'project' def __init__(self, config=None): config = Config(config, fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - env_dir = osp.join(config.project_dir, config.env_dir) - env_config_path = osp.join(env_dir, config.env_filename) - env_config = Config(fallback=ENV_DEFAULT_CONFIG, schema=ENV_SCHEMA) - if osp.isfile(env_config_path): - env_config.update(Config.parse(env_config_path)) - - self.config = env_config - - self.models = ModelRegistry(env_config) + self.models = ModelRegistry(config) self.sources = SourceRegistry(config) - import datumaro.components.importers as builtin_importers - builtin_importers = builtin_importers.items - custom_importers = self._get_custom_module_items( - env_dir, env_config.importers_dir) - self.importers = ModuleRegistry(config, - builtin=builtin_importers, local=custom_importers) - - import datumaro.components.extractors as builtin_extractors - builtin_extractors = builtin_extractors.items - custom_extractors = self._get_custom_module_items( - env_dir, env_config.extractors_dir) - self.extractors = ModuleRegistry(config, - builtin=builtin_extractors, local=custom_extractors) + self.git = GitWrapper(config) + + env_dir = osp.join(config.project_dir, config.env_dir) + builtin = self._load_builtin_plugins() + custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir)) + select = lambda seq, t: [e for e in seq if issubclass(e, t)] + from datumaro.components.extractor import Transform + from datumaro.components.extractor import SourceExtractor + from datumaro.components.extractor import Importer + from datumaro.components.converter import Converter + from datumaro.components.launcher import Launcher + self.extractors = PluginRegistry( + builtin=select(builtin, SourceExtractor), + local=select(custom, SourceExtractor) + ) self.extractors.register(self.PROJECT_EXTRACTOR_NAME, load_project_as_dataset) - import datumaro.components.launchers as builtin_launchers - builtin_launchers = builtin_launchers.items - custom_launchers = self._get_custom_module_items( - env_dir, env_config.launchers_dir) - self.launchers = ModuleRegistry(config, - builtin=builtin_launchers, local=custom_launchers) - - import datumaro.components.converters as builtin_converters - builtin_converters = builtin_converters.items - custom_converters = self._get_custom_module_items( - env_dir, env_config.converters_dir) - if custom_converters is not None: - custom_converters = custom_converters.items - self.converters = ModuleRegistry(config, - builtin=builtin_converters, local=custom_converters) - - self.statistics = ModuleRegistry(config) - self.visualizers = ModuleRegistry(config) - self.git = GitWrapper(config) + self.importers = PluginRegistry( + builtin=select(builtin, Importer), + local=select(custom, Importer) + ) + self.launchers = PluginRegistry( + builtin=select(builtin, Launcher), + 
local=select(custom, Launcher) + ) + self.converters = PluginRegistry( + builtin=select(builtin, Converter), + local=select(custom, Converter) + ) + self.transforms = PluginRegistry( + builtin=select(builtin, Transform), + local=select(custom, Transform) + ) - def _get_custom_module_items(self, module_dir, module_name): - items = None + @staticmethod + def _find_plugins(plugins_dir): + plugins = [] + if not osp.exists(plugins_dir): + return plugins + + for plugin_name in os.listdir(plugins_dir): + p = osp.join(plugins_dir, plugin_name) + if osp.isfile(p) and p.endswith('.py'): + plugins.append((plugins_dir, plugin_name, None)) + elif osp.isdir(p): + plugins += [(plugins_dir, + osp.splitext(plugin_name)[0] + '.' + osp.basename(p), + osp.splitext(plugin_name)[0] + ) + for p in glob(osp.join(p, '*.py'))] + return plugins - module = None - if osp.exists(osp.join(module_dir, module_name)): - module = import_foreign_module(module_name, module_dir) - if module is not None: - if hasattr(module, 'items'): - items = module.items - else: - items = self._find_custom_module_items( - osp.join(module_dir, module_name)) + @classmethod + def _import_module(cls, module_dir, module_name, types, package=None): + module = import_foreign_module(osp.splitext(module_name)[0], module_dir, + package=package) - return items + exports = [] + if hasattr(module, 'exports'): + exports = module.exports + else: + for symbol in dir(module): + if symbol.startswith('_'): + continue + exports.append(getattr(module, symbol)) - @staticmethod - def _find_custom_module_items(module_dir): - files = [p for p in os.listdir(module_dir) - if p.endswith('.py') and p != '__init__.py'] - - all_items = [] - for f in files: - name = osp.splitext(f)[0] - module = import_foreign_module(name, module_dir) - - items = [] - if hasattr(module, 'items'): - items = module.items - else: - if hasattr(module, name): - items = [ (name, getattr(module, name)) ] - else: - log.warn("Failed to import custom module '%s'." - " Custom module is expected to provide 'items' " - "list or have an item matching its file name." - " Skipping this module." % \ - (module_dir + '.' 
+ name)) + exports = [s for s in exports + if inspect.isclass(s) and issubclass(s, types) and not s in types] - all_items.extend(items) + return exports + + @classmethod + def _load_plugins(cls, plugins_dir, types): + types = tuple(types) + + plugins = cls._find_plugins(plugins_dir) + + all_exports = [] + for module_dir, module_name, package in plugins: + try: + exports = cls._import_module(module_dir, module_name, types, + package) + except ImportError as e: + log.debug("Failed to import module '%s': %s" % (module_name, e)) + continue + + log.debug("Imported the following symbols from %s: %s" % \ + ( + module_name, + ', '.join(s.__name__ for s in exports) + ) + ) + all_exports.extend(exports) + + return all_exports - return all_items + @classmethod + def _load_builtin_plugins(cls): + if not cls._builtin_plugins: + plugins_dir = osp.join( + __file__[: __file__.rfind(osp.join('datumaro', 'components'))], + osp.join('datumaro', 'plugins') + ) + assert osp.isdir(plugins_dir), plugins_dir + cls._builtin_plugins = cls._load_plugins2(plugins_dir) + return cls._builtin_plugins + + @classmethod + def _load_plugins2(cls, plugins_dir): + from datumaro.components.extractor import Transform + from datumaro.components.extractor import SourceExtractor + from datumaro.components.extractor import Importer + from datumaro.components.converter import Converter + from datumaro.components.launcher import Launcher + types = [SourceExtractor, Converter, Importer, Launcher, Transform] - def save(self, path): - self.config.dump(path) + return cls._load_plugins(plugins_dir, types) def make_extractor(self, name, *args, **kwargs): return self.extractors.get(name)(*args, **kwargs) @@ -246,11 +281,9 @@ def make_converter(self, name, *args, **kwargs): return self.converters.get(name)(*args, **kwargs) def register_model(self, name, model): - self.config.models[name] = model self.models.register(name, model) def unregister_model(self, name): - self.config.models.remove(name) self.models.unregister(name) @@ -648,15 +681,22 @@ def _save_branch_project(self, extractor, save_dir=None): dst_dataset.save(save_dir=save_dir, merge=True) - def transform_project(self, method, *args, save_dir=None, **kwargs): + def transform_project(self, method, save_dir=None, **method_kwargs): # NOTE: probably this function should be in the ViewModel layer - transformed = self.transform(method, *args, **kwargs) + if isinstance(method, str): + method = self.env.make_transform(method) + + transformed = self.transform(method, **method_kwargs) self._save_branch_project(transformed, save_dir=save_dir) - def apply_model(self, model_name, save_dir=None): + def apply_model(self, model, save_dir=None, batch_size=1): # NOTE: probably this function should be in the ViewModel layer - launcher = self._project.make_executable_model(model_name) - self.transform_project(InferenceWrapper, launcher, save_dir=save_dir) + launcher = model + if isinstance(model, str): + launcher = self._project.make_executable_model(model) + + self.transform_project(InferenceWrapper, launcher=launcher, + save_dir=save_dir, batch_size=batch_size) def export_project(self, save_dir, converter, filter_expr=None, filter_annotations=False, remove_empty=False): @@ -698,12 +737,8 @@ def save(self, save_dir=None): if save_dir is None: assert config.project_dir save_dir = osp.abspath(config.project_dir) + os.makedirs(save_dir, exist_ok=True) config_path = osp.join(save_dir, config.project_filename) - - env_dir = osp.join(save_dir, config.env_dir) - os.makedirs(env_dir, exist_ok=True) -
self.env.save(osp.join(env_dir, config.env_filename)) - config.dump(config_path) @staticmethod @@ -757,6 +792,7 @@ def add_model(self, name, value=Model()): if isinstance(value, (dict, Config)): value = Model(value) self.env.register_model(name, value) + self.config.models[name] = value def get_model(self, name): try: @@ -765,6 +801,7 @@ def get_model(self, name): raise KeyError("Model '%s' is not found" % name) def remove_model(self, name): + self.config.models.remove(name) self.env.unregister_model(name) def make_executable_model(self, name): @@ -785,7 +822,7 @@ def make_source_project(self, name): def local_model_dir(self, model_name): return osp.join( - self.config.env_dir, self.env.config.models_dir, model_name) + self.config.env_dir, self.config.models_dir, model_name) def local_source_dir(self, source_name): return osp.join(self.config.sources_dir, source_name) diff --git a/datumaro/datumaro/components/converters/coco.py b/datumaro/datumaro/plugins/coco_format/converter.py similarity index 90% rename from datumaro/datumaro/components/converters/coco.py rename to datumaro/datumaro/plugins/coco_format/converter.py index f2017a19dd74..39605aa4612f 100644 --- a/datumaro/datumaro/components/converters/coco.py +++ b/datumaro/datumaro/plugins/coco_format/converter.py @@ -13,14 +13,16 @@ import pycocotools.mask as mask_utils from datumaro.components.converter import Converter -from datumaro.components.extractor import ( - DEFAULT_SUBSET_NAME, AnnotationType, PointsObject, MaskObject +from datumaro.components.extractor import (DEFAULT_SUBSET_NAME, + AnnotationType, Points, Mask ) -from datumaro.components.formats.coco import CocoTask, CocoPath +from datumaro.components.cli_plugin import CliPlugin from datumaro.util import find from datumaro.util.image import save_image import datumaro.util.mask_tools as mask_tools +from .format import CocoTask, CocoPath + def _cast(value, type_conv, default=None): if value is None: @@ -202,7 +204,7 @@ def find_instance_parts(self, group, img_width, img_height): leader = self.find_group_leader(anns) bbox = self.compute_bbox(anns) mask = None - polygons = [p.get_polygon() for p in polygons] + polygons = [p.points for p in polygons] if self._context._segmentation_mode == SegmentationMode.guess: use_masks = True == leader.attributes.get('is_crowd', @@ -237,7 +239,7 @@ def find_instance_parts(self, group, img_width, img_height): @staticmethod def find_group_leader(group): - return max(group, key=lambda x: x.area()) + return max(group, key=lambda x: x.get_area()) @staticmethod def merge_masks(masks): @@ -245,7 +247,7 @@ def merge_masks(masks): return None def get_mask(m): - if isinstance(m, MaskObject): + if isinstance(m, Mask): return m.image else: return m @@ -278,7 +280,7 @@ def find_instances(cls, annotations): ann_groups = [] for g_id, group in groupby(instance_anns, lambda a: a.group): - if g_id is None: + if not g_id: ann_groups.extend(([a] for a in group)) else: ann_groups.append(list(group)) @@ -395,7 +397,7 @@ def find_solitary_points(cls, annotations): solitary_points = [] for g_id, group in groupby(annotations, lambda a: a.group): - if g_id is not None and not cls.find_instance_anns(group): + if g_id and not cls.find_instance_anns(group): group = [a for a in group if a.type == AnnotationType.points] solitary_points.extend(group) @@ -404,7 +406,7 @@ def find_solitary_points(cls, annotations): @staticmethod def convert_points_object(ann): keypoints = [] - points = ann.get_points() + points = ann.points visibility = ann.visibility for index in range(0, 
len(points), 2): kp = points[index : index + 2] @@ -412,7 +414,7 @@ def convert_points_object(ann): keypoints.extend([*kp, state]) num_annotated = len([v for v in visibility \ - if v != PointsObject.Visibility.absent]) + if v != Points.Visibility.absent]) return { 'keypoints': keypoints, @@ -543,8 +545,11 @@ def convert(self): filename = '' if item.has_image: filename = str(item.id) + CocoPath.IMAGE_EXT - if self._save_images: + if self._save_images: + if item.has_image: self.save_image(item, filename) + else: + log.debug("Item '%s' has no image" % item.id) for task_conv in task_converters.values(): task_conv.save_image_info(item, filename) task_conv.save_annotations(item) @@ -554,46 +559,29 @@ def convert(self): task_conv.write(osp.join(self._ann_dir, '%s_%s.json' % (task.name, subset_name))) -class CocoConverter(Converter): - def __init__(self, - tasks=None, save_images=False, segmentation_mode=None, - crop_covered=False, - cmdline_args=None): - super().__init__() - - self._options = { - 'tasks': tasks, - 'save_images': save_images, - 'segmentation_mode': segmentation_mode, - 'crop_covered': crop_covered, - } - - if cmdline_args is not None: - self._options.update(self._parse_cmdline(cmdline_args)) - +class CocoConverter(Converter, CliPlugin): @staticmethod def _split_tasks_string(s): return [CocoTask[i.strip()] for i in s.split(',')] @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='coco') - + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--save-images', action='store_true', help="Save images (default: %(default)s)") parser.add_argument('--segmentation-mode', choices=[m.name for m in SegmentationMode], default=SegmentationMode.guess.name, - help="Save mode for instance segmentation: " - "- '{sm.guess.name}': guess the mode for each instance, " - "use 'is_crowd' attribute as hint; " - "- '{sm.polygons.name}': save polygons, " - "merge and convert masks, prefer polygons; " - "- '{sm.mask.name}': save masks, " - "merge and convert polygons, prefer masks; " - "(default: %(default)s)".format(sm=SegmentationMode)) + help=""" + Save mode for instance segmentation:|n + - '{sm.guess.name}': guess the mode for each instance,|n + |s|suse 'is_crowd' attribute as hint|n + - '{sm.polygons.name}': save polygons,|n + |s|smerge and convert masks, prefer polygons|n + - '{sm.mask.name}': save masks,|n + |s|smerge and convert polygons, prefer masks|n + Default: %(default)s. 
+ """.format(sm=SegmentationMode)) parser.add_argument('--crop-covered', action='store_true', help="Crop covered segments so that background objects' " "segmentation was more accurate (default: %(default)s)") @@ -601,24 +589,40 @@ def build_cmdline_parser(cls, parser=None): default=None, help="COCO task filter, comma-separated list of {%s} " "(default: all)" % ', '.join([t.name for t in CocoTask])) - return parser + def __init__(self, + tasks=None, save_images=False, segmentation_mode=None, + crop_covered=False): + super().__init__() + + self._options = { + 'tasks': tasks, + 'save_images': save_images, + 'segmentation_mode': segmentation_mode, + 'crop_covered': crop_covered, + } + def __call__(self, extractor, save_dir): converter = _Converter(extractor, save_dir, **self._options) converter.convert() -def CocoInstancesConverter(**kwargs): - return CocoConverter(CocoTask.instances, **kwargs) +class CocoInstancesConverter(CocoConverter): + def __init__(self, **kwargs): + super().__init__(CocoTask.instances, **kwargs) -def CocoImageInfoConverter(**kwargs): - return CocoConverter(CocoTask.image_info, **kwargs) +class CocoImageInfoConverter(CocoConverter): + def __init__(self, **kwargs): + super().__init__(CocoTask.image_info, **kwargs) -def CocoPersonKeypointsConverter(**kwargs): - return CocoConverter(CocoTask.person_keypoints, **kwargs) +class CocoPersonKeypointsConverter(CocoConverter): + def __init__(self, **kwargs): + super().__init__(CocoTask.person_keypoints, **kwargs) -def CocoCaptionsConverter(**kwargs): - return CocoConverter(CocoTask.captions, **kwargs) +class CocoCaptionsConverter(CocoConverter): + def __init__(self, **kwargs): + super().__init__(CocoTask.captions, **kwargs) -def CocoLabelsConverter(**kwargs): - return CocoConverter(CocoTask.labels, **kwargs) \ No newline at end of file +class CocoLabelsConverter(CocoConverter): + def __init__(self, **kwargs): + super().__init__(CocoTask.labels, **kwargs) diff --git a/datumaro/datumaro/components/extractors/coco.py b/datumaro/datumaro/plugins/coco_format/extractor.py similarity index 77% rename from datumaro/datumaro/components/extractors/coco.py rename to datumaro/datumaro/plugins/coco_format/extractor.py index 05404f21e6b6..dca7d8532a5b 100644 --- a/datumaro/datumaro/components/extractors/coco.py +++ b/datumaro/datumaro/plugins/coco_format/extractor.py @@ -4,22 +4,23 @@ # SPDX-License-Identifier: MIT from collections import OrderedDict +import logging as log import os.path as osp from pycocotools.coco import COCO import pycocotools.mask as mask_utils -from datumaro.components.extractor import (Extractor, DatasetItem, - DEFAULT_SUBSET_NAME, AnnotationType, - LabelObject, RleMask, PointsObject, PolygonObject, - BboxObject, CaptionObject, +from datumaro.components.extractor import (SourceExtractor, + DEFAULT_SUBSET_NAME, DatasetItem, + AnnotationType, Label, RleMask, Points, Polygon, Bbox, Caption, LabelCategories, PointsCategories ) -from datumaro.components.formats.coco import CocoTask, CocoPath from datumaro.util.image import lazy_image +from .format import CocoTask, CocoPath -class CocoExtractor(Extractor): + +class _CocoExtractor(SourceExtractor): def __init__(self, path, task, merge_instance_polygons=False): super().__init__() @@ -156,7 +157,7 @@ def _load_annotations(self, ann, image_info=None): points = [p for i, p in enumerate(keypoints) if i % 3 != 2] visibility = keypoints[2::3] parsed_annotations.append( - PointsObject(points, visibility, label=label_id, + Points(points, visibility, label=label_id, id=ann_id, 
attributes=attributes, group=group) ) @@ -165,14 +166,14 @@ def _load_annotations(self, ann, image_info=None): rle = None if isinstance(segmentation, list): - # polygon - a single object can consist of multiple parts - for polygon_points in segmentation: - parsed_annotations.append(PolygonObject( - points=polygon_points, label=label_id, - id=ann_id, attributes=attributes, group=group - )) - - if self._merge_instance_polygons: + if not self._merge_instance_polygons: + # polygon - a single object can consist of multiple parts + for polygon_points in segmentation: + parsed_annotations.append(Polygon( + points=polygon_points, label=label_id, + id=ann_id, attributes=attributes, group=group + )) + else: # merge all parts into a single mask RLE img_h = image_info['height'] img_w = image_info['width'] @@ -180,8 +181,19 @@ def _load_annotations(self, ann, image_info=None): rle = mask_utils.merge(rles) elif isinstance(segmentation['counts'], list): # uncompressed RLE - img_h, img_w = segmentation['size'] - rle = mask_utils.frPyObjects([segmentation], img_h, img_w)[0] + img_h = image_info['height'] + img_w = image_info['width'] + mask_h, mask_w = segmentation['size'] + if img_h == mask_h and img_w == mask_w: + rle = mask_utils.frPyObjects( + [segmentation], mask_h, mask_w)[0] + else: + log.warning("item #%s: mask #%s " + "does not match image size: %s vs. %s. " + "Skipping this annotation.", + image_info['id'], ann_id, + (mask_h, mask_w), (img_h, img_w) + ) else: # compressed RLE rle = segmentation @@ -190,21 +202,21 @@ def _load_annotations(self, ann, image_info=None): parsed_annotations.append(RleMask(rle=rle, label=label_id, id=ann_id, attributes=attributes, group=group )) - - parsed_annotations.append( - BboxObject(x, y, w, h, label=label_id, - id=ann_id, attributes=attributes, group=group) - ) + else: + parsed_annotations.append( + Bbox(x, y, w, h, label=label_id, + id=ann_id, attributes=attributes, group=group) + ) elif self._task is CocoTask.labels: label_id = self._get_label_id(ann) parsed_annotations.append( - LabelObject(label=label_id, + Label(label=label_id, id=ann_id, attributes=attributes, group=group) ) elif self._task is CocoTask.captions: caption = ann['caption'] parsed_annotations.append( - CaptionObject(caption, + Caption(caption, id=ann_id, attributes=attributes, group=group) ) else: @@ -222,23 +234,22 @@ def _find_image(self, file_name): if osp.exists(image_path): return lazy_image(image_path) -class CocoImageInfoExtractor(CocoExtractor): +class CocoImageInfoExtractor(_CocoExtractor): def __init__(self, path, **kwargs): super().__init__(path, task=CocoTask.image_info, **kwargs) -class CocoCaptionsExtractor(CocoExtractor): +class CocoCaptionsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): super().__init__(path, task=CocoTask.captions, **kwargs) -class CocoInstancesExtractor(CocoExtractor): +class CocoInstancesExtractor(_CocoExtractor): def __init__(self, path, **kwargs): super().__init__(path, task=CocoTask.instances, **kwargs) -class CocoPersonKeypointsExtractor(CocoExtractor): +class CocoPersonKeypointsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): - super().__init__(path, task=CocoTask.person_keypoints, - **kwargs) + super().__init__(path, task=CocoTask.person_keypoints, **kwargs) -class CocoLabelsExtractor(CocoExtractor): +class CocoLabelsExtractor(_CocoExtractor): def __init__(self, path, **kwargs): super().__init__(path, task=CocoTask.labels, **kwargs) \ No newline at end of file diff --git a/datumaro/datumaro/components/formats/coco.py 
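The new size check above exists because pycocotools assumes an uncompressed RLE was encoded at the image's own size. A minimal sketch of the accepted path:

```python
import pycocotools.mask as mask_utils

# Toy uncompressed RLE: column-major run lengths, starting with a zero run
segmentation = {'counts': [6, 1, 2], 'size': [3, 3]}

img_h, img_w = 3, 3                      # from the COCO image record
mask_h, mask_w = segmentation['size']
if (img_h, img_w) == (mask_h, mask_w):
    rle = mask_utils.frPyObjects([segmentation], mask_h, mask_w)[0]
    print(mask_utils.decode(rle).shape)  # (3, 3)
else:
    pass                                 # the extractor logs and skips instead
```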
b/datumaro/datumaro/plugins/coco_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/coco.py rename to datumaro/datumaro/plugins/coco_format/format.py diff --git a/datumaro/datumaro/components/importers/coco.py b/datumaro/datumaro/plugins/coco_format/importer.py similarity index 90% rename from datumaro/datumaro/components/importers/coco.py rename to datumaro/datumaro/plugins/coco_format/importer.py index 9e3d38e611e0..bb129d7aed5a 100644 --- a/datumaro/datumaro/components/importers/coco.py +++ b/datumaro/datumaro/plugins/coco_format/importer.py @@ -8,16 +8,18 @@ import logging as log import os.path as osp -from datumaro.components.formats.coco import CocoTask, CocoPath +from datumaro.components.extractor import Importer +from .format import CocoTask, CocoPath -class CocoImporter: + +class CocoImporter(Importer): _COCO_EXTRACTORS = { CocoTask.instances: 'coco_instances', - CocoTask.person_keypoints: 'coco_person_kp', + CocoTask.person_keypoints: 'coco_person_keypoints', CocoTask.captions: 'coco_captions', CocoTask.labels: 'coco_labels', - CocoTask.image_info: 'coco_images', + CocoTask.image_info: 'coco_image_info', } def __call__(self, path, **extra_params): diff --git a/datumaro/datumaro/components/converters/cvat.py b/datumaro/datumaro/plugins/cvat_format/converter.py similarity index 92% rename from datumaro/datumaro/components/converters/cvat.py rename to datumaro/datumaro/plugins/cvat_format/converter.py index 475bc0b997ed..948916fee658 100644 --- a/datumaro/datumaro/components/converters/cvat.py +++ b/datumaro/datumaro/plugins/cvat_format/converter.py @@ -4,15 +4,18 @@ # SPDX-License-Identifier: MIT from collections import OrderedDict +import logging as log import os import os.path as osp from xml.sax.saxutils import XMLGenerator +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.converter import Converter from datumaro.components.extractor import DEFAULT_SUBSET_NAME, AnnotationType -from datumaro.components.formats.cvat import CvatPath from datumaro.util.image import save_image +from .format import CvatPath + def _cast(value, type_conv, default=None): if value is None: @@ -156,7 +159,10 @@ def write(self): for item in self._extractor: if self._context._save_images: - self._save_image(item) + if item.has_image: + self._save_image(item) + else: + log.debug("Item '%s' has no image" % item.id) self._write_item(item) self._writer.close_root() @@ -235,13 +241,12 @@ def _write_shape(self, shape): ("occluded", str(int(shape.attributes.get('occluded', False)))), ]) - points = shape.get_points() if shape.type == AnnotationType.bbox: shape_data.update(OrderedDict([ - ("xtl", "{:.2f}".format(points[0])), - ("ytl", "{:.2f}".format(points[1])), - ("xbr", "{:.2f}".format(points[2])), - ("ybr", "{:.2f}".format(points[3])) + ("xtl", "{:.2f}".format(shape.points[0])), + ("ytl", "{:.2f}".format(shape.points[1])), + ("xbr", "{:.2f}".format(shape.points[2])), + ("ybr", "{:.2f}".format(shape.points[3])) ])) else: shape_data.update(OrderedDict([ @@ -249,12 +254,12 @@ def _write_shape(self, shape): ','.join(( "{:.2f}".format(x), "{:.2f}".format(y) - )) for x, y in pairwise(points)) + )) for x, y in pairwise(shape.points)) )), ])) shape_data['z_order'] = str(int(shape.attributes.get('z_order', 0))) - if shape.group is not None: + if shape.group: shape_data['group_id'] = str(shape.group) if shape.type == AnnotationType.bbox: @@ -320,28 +325,21 @@ def convert(self): writer = _SubsetWriter(f, subset_name, subset, self) writer.write() -class 
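The polygon/polyline branch above renders shape.points as 'x,y;x,y;...'. Assuming datumaro.util's pairwise is the usual two-at-a-time iterator (its body is not in this diff), the output looks like:

```python
def pairwise(iterable):
    it = iter(iterable)
    return zip(it, it)  # (x0, y0), (x1, y1), ...

points = [10.0, 20.5, 30.0, 40.25]
print(';'.join(
    ','.join(("{:.2f}".format(x), "{:.2f}".format(y)))
    for x, y in pairwise(points)))
# 10.00,20.50;30.00,40.25
```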
CvatConverter(Converter): - def __init__(self, save_images=False, cmdline_args=None): +class CvatConverter(Converter, CliPlugin): + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('--save-images', action='store_true', + help="Save images (default: %(default)s)") + return parser + + def __init__(self, save_images=False): super().__init__() self._options = { 'save_images': save_images, } - if cmdline_args is not None: - self._options.update(self._parse_cmdline(cmdline_args)) - - @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='cvat') - - parser.add_argument('--save-images', action='store_true', - help="Save images (default: %(default)s)") - - return parser - def __call__(self, extractor, save_dir): converter = _Converter(extractor, save_dir, **self._options) converter.convert() \ No newline at end of file diff --git a/datumaro/datumaro/components/extractors/cvat.py b/datumaro/datumaro/plugins/cvat_format/extractor.py similarity index 95% rename from datumaro/datumaro/components/extractors/cvat.py rename to datumaro/datumaro/plugins/cvat_format/extractor.py index e3c869c46b1a..00b9acb6a425 100644 --- a/datumaro/datumaro/components/extractors/cvat.py +++ b/datumaro/datumaro/plugins/cvat_format/extractor.py @@ -7,16 +7,17 @@ import os.path as osp import xml.etree as ET -from datumaro.components.extractor import (Extractor, DatasetItem, - DEFAULT_SUBSET_NAME, AnnotationType, - PointsObject, PolygonObject, PolyLineObject, BboxObject, +from datumaro.components.extractor import (SourceExtractor, + DEFAULT_SUBSET_NAME, DatasetItem, + AnnotationType, Points, Polygon, PolyLine, Bbox, LabelCategories ) -from datumaro.components.formats.cvat import CvatPath from datumaro.util.image import lazy_image +from .format import CvatPath -class CvatExtractor(Extractor): + +class CvatExtractor(SourceExtractor): _SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points') def __init__(self, path): @@ -242,8 +243,6 @@ def _parse_ann(cls, ann, categories): attributes['keyframe'] = ann.get('keyframe', False) group = ann.get('group') - if group == 0: - group = None label = ann.get('label') label_id = categories[AnnotationType.label].find(label)[0] @@ -251,21 +250,21 @@ def _parse_ann(cls, ann, categories): points = ann.get('points', []) if ann_type == 'polyline': - return PolyLineObject(points, label=label_id, + return PolyLine(points, label=label_id, id=ann_id, attributes=attributes, group=group) elif ann_type == 'polygon': - return PolygonObject(points, label=label_id, + return Polygon(points, label=label_id, id=ann_id, attributes=attributes, group=group) elif ann_type == 'points': - return PointsObject(points, label=label_id, + return Points(points, label=label_id, id=ann_id, attributes=attributes, group=group) elif ann_type == 'box': x, y = points[0], points[1] w, h = points[2] - x, points[3] - y - return BboxObject(x, y, w, h, label=label_id, + return Bbox(x, y, w, h, label=label_id, id=ann_id, attributes=attributes, group=group) else: diff --git a/datumaro/datumaro/components/formats/cvat.py b/datumaro/datumaro/plugins/cvat_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/cvat.py rename to datumaro/datumaro/plugins/cvat_format/format.py diff --git a/datumaro/datumaro/components/importers/cvat.py b/datumaro/datumaro/plugins/cvat_format/importer.py similarity index 91% rename from
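The 'box' branch above converts CVAT's corner-point representation into the (x, y, w, h) that Bbox expects; a quick check of that arithmetic:

```python
points = [10.0, 20.0, 50.0, 80.0]  # CVAT box: xtl, ytl, xbr, ybr
x, y = points[0], points[1]
w, h = points[2] - x, points[3] - y
assert (x, y, w, h) == (10.0, 20.0, 40.0, 60.0)
```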
datumaro/datumaro/components/importers/cvat.py rename to datumaro/datumaro/plugins/cvat_format/importer.py index 6f831a7b90fa..a81b5cb38cec 100644 --- a/datumaro/datumaro/components/importers/cvat.py +++ b/datumaro/datumaro/plugins/cvat_format/importer.py @@ -7,10 +7,12 @@ import logging as log import os.path as osp -from datumaro.components.formats.cvat import CvatPath +from datumaro.components.extractor import Importer +from .format import CvatPath -class CvatImporter: + +class CvatImporter(Importer): EXTRACTOR_NAME = 'cvat' def __call__(self, path, **extra_params): diff --git a/datumaro/datumaro/components/converters/datumaro.py b/datumaro/datumaro/plugins/datumaro_format/converter.py similarity index 88% rename from datumaro/datumaro/components/converters/datumaro.py rename to datumaro/datumaro/plugins/datumaro_format/converter.py index 635817d43eb1..32f31e4acd3b 100644 --- a/datumaro/datumaro/components/converters/datumaro.py +++ b/datumaro/datumaro/plugins/datumaro_format/converter.py @@ -12,12 +12,13 @@ from datumaro.components.converter import Converter from datumaro.components.extractor import ( DEFAULT_SUBSET_NAME, Annotation, - LabelObject, MaskObject, PointsObject, PolygonObject, - PolyLineObject, BboxObject, CaptionObject, + Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, LabelCategories, MaskCategories, PointsCategories ) -from datumaro.components.formats.datumaro import DatumaroPath from datumaro.util.image import save_image +from datumaro.components.cli_plugin import CliPlugin + +from .format import DatumaroPath def _cast(value, type_conv, default=None): @@ -60,19 +61,19 @@ def write_item(self, item): self.items.append(item_desc) for ann in item.annotations: - if isinstance(ann, LabelObject): + if isinstance(ann, Label): converted_ann = self._convert_label_object(ann) - elif isinstance(ann, MaskObject): + elif isinstance(ann, Mask): converted_ann = self._convert_mask_object(ann) - elif isinstance(ann, PointsObject): + elif isinstance(ann, Points): converted_ann = self._convert_points_object(ann) - elif isinstance(ann, PolyLineObject): + elif isinstance(ann, PolyLine): converted_ann = self._convert_polyline_object(ann) - elif isinstance(ann, PolygonObject): + elif isinstance(ann, Polygon): converted_ann = self._convert_polygon_object(ann) - elif isinstance(ann, BboxObject): + elif isinstance(ann, Bbox): converted_ann = self._convert_bbox_object(ann) - elif isinstance(ann, CaptionObject): + elif isinstance(ann, Caption): converted_ann = self._convert_caption_object(ann) else: raise NotImplementedError() @@ -101,7 +102,7 @@ def _convert_annotation(self, obj): 'id': _cast(obj.id, int), 'type': _cast(obj.type.name, str), 'attributes': obj.attributes, - 'group': _cast(obj.group, int, None), + 'group': _cast(obj.group, int, 0), } return ann_json @@ -148,7 +149,7 @@ def _convert_polyline_object(self, obj): converted.update({ 'label_id': _cast(obj.label, int), - 'points': [float(p) for p in obj.get_points()], + 'points': [float(p) for p in obj.points], }) return converted @@ -157,7 +158,7 @@ def _convert_polygon_object(self, obj): converted.update({ 'label_id': _cast(obj.label, int), - 'points': [float(p) for p in obj.get_points()], + 'points': [float(p) for p in obj.points], }) return converted @@ -272,28 +273,20 @@ def _save_image(self, item): str(item.id) + DatumaroPath.IMAGE_EXT) save_image(image_path, image) -class DatumaroConverter(Converter): - def __init__(self, save_images=False, cmdline_args=None): +class DatumaroConverter(Converter, CliPlugin): + @classmethod + def 
build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('--save-images', action='store_true', + help="Save images (default: %(default)s)") + return parser + + def __init__(self, save_images=False): super().__init__() self._options = { 'save_images': save_images, } - if cmdline_args is not None: - self._options.update(self._parse_cmdline(cmdline_args)) - - @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='datumaro') - - parser.add_argument('--save-images', action='store_true', - help="Save images (default: %(default)s)") - - return parser - def __call__(self, extractor, save_dir): converter = _Converter(extractor, save_dir, **self._options) - converter.convert() \ No newline at end of file + converter.convert() diff --git a/datumaro/datumaro/components/extractors/datumaro.py b/datumaro/datumaro/plugins/datumaro_format/extractor.py similarity index 87% rename from datumaro/datumaro/components/extractors/datumaro.py rename to datumaro/datumaro/plugins/datumaro_format/extractor.py index 8917b8b99e4f..04b9af0b9482 100644 --- a/datumaro/datumaro/components/extractors/datumaro.py +++ b/datumaro/datumaro/plugins/datumaro_format/extractor.py @@ -7,18 +7,18 @@ import logging as log import os.path as osp -from datumaro.components.extractor import (Extractor, DatasetItem, - DEFAULT_SUBSET_NAME, AnnotationType, - LabelObject, MaskObject, PointsObject, PolygonObject, - PolyLineObject, BboxObject, CaptionObject, +from datumaro.components.extractor import (SourceExtractor, + DEFAULT_SUBSET_NAME, DatasetItem, + AnnotationType, Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, LabelCategories, MaskCategories, PointsCategories ) -from datumaro.components.formats.datumaro import DatumaroPath from datumaro.util.image import lazy_image from datumaro.util.mask_tools import lazy_mask +from .format import DatumaroPath -class DatumaroExtractor(Extractor): + +class DatumaroExtractor(SourceExtractor): def __init__(self, path): super().__init__() @@ -120,7 +120,7 @@ def _load_annotations(self, item): if ann_type == AnnotationType.label: label_id = ann.get('label_id') - loaded.append(LabelObject(label=label_id, + loaded.append(Label(label=label_id, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.mask: @@ -137,36 +137,36 @@ def _load_annotations(self, item): log.warn("Not found mask image file '%s', skipped."
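Elsewhere in this converter the group field now serializes through _cast(obj.group, int, 0), so 'ungrouped' is written as 0 rather than null. The _cast body is not shown in these hunks; a plausible sketch consistent with its uses here:

```python
def _cast(value, type_conv, default=None):
    # assumed behavior: None and unconvertible values fall back to default
    if value is None:
        return default
    try:
        return type_conv(value)
    except Exception:
        return default

assert _cast(None, int, 0) == 0  # missing group -> 0 ("ungrouped")
assert _cast('7', int, 0) == 7
```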
% \ mask_path) - loaded.append(MaskObject(label=label_id, image=mask, + loaded.append(Mask(label=label_id, image=mask, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.polyline: label_id = ann.get('label_id') points = ann.get('points') - loaded.append(PolyLineObject(points, label=label_id, + loaded.append(PolyLine(points, label=label_id, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.polygon: label_id = ann.get('label_id') points = ann.get('points') - loaded.append(PolygonObject(points, label=label_id, + loaded.append(Polygon(points, label=label_id, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.bbox: label_id = ann.get('label_id') x, y, w, h = ann.get('bbox') - loaded.append(BboxObject(x, y, w, h, label=label_id, + loaded.append(Bbox(x, y, w, h, label=label_id, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.points: label_id = ann.get('label_id') points = ann.get('points') - loaded.append(PointsObject(points, label=label_id, + loaded.append(Points(points, label=label_id, id=ann_id, attributes=attributes, group=group)) elif ann_type == AnnotationType.caption: caption = ann.get('caption') - loaded.append(CaptionObject(caption, + loaded.append(Caption(caption, id=ann_id, attributes=attributes, group=group)) else: diff --git a/datumaro/datumaro/components/formats/datumaro.py b/datumaro/datumaro/plugins/datumaro_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/datumaro.py rename to datumaro/datumaro/plugins/datumaro_format/format.py diff --git a/datumaro/datumaro/components/importers/datumaro.py b/datumaro/datumaro/plugins/datumaro_format/importer.py similarity index 91% rename from datumaro/datumaro/components/importers/datumaro.py rename to datumaro/datumaro/plugins/datumaro_format/importer.py index 828208d8d204..0184ef9040f4 100644 --- a/datumaro/datumaro/components/importers/datumaro.py +++ b/datumaro/datumaro/plugins/datumaro_format/importer.py @@ -7,10 +7,12 @@ import logging as log import os.path as osp -from datumaro.components.formats.datumaro import DatumaroPath +from datumaro.components.extractor import Importer +from .format import DatumaroPath -class DatumaroImporter: + +class DatumaroImporter(Importer): EXTRACTOR_NAME = 'datumaro' def __call__(self, path, **extra_params): diff --git a/datumaro/datumaro/components/extractors/image_dir.py b/datumaro/datumaro/plugins/image_dir.py similarity index 63% rename from datumaro/datumaro/components/extractors/image_dir.py rename to datumaro/datumaro/plugins/image_dir.py index 561fa9d8a80d..c2e0e687bce7 100644 --- a/datumaro/datumaro/components/extractors/image_dir.py +++ b/datumaro/datumaro/plugins/image_dir.py @@ -1,5 +1,5 @@ -# Copyright (C) 2018 Intel Corporation +# Copyright (C) 2019 Intel Corporation # # SPDX-License-Identifier: MIT @@ -7,11 +7,31 @@ import os import os.path as osp -from datumaro.components.extractor import DatasetItem, Extractor +from datumaro.components.extractor import DatasetItem, SourceExtractor, Importer from datumaro.util.image import lazy_image -class ImageDirExtractor(Extractor): +class ImageDirImporter(Importer): + EXTRACTOR_NAME = 'image_dir' + + def __call__(self, path, **extra_params): + from datumaro.components.project import Project # cyclic import + project = Project() + + if not osp.isdir(path): + raise Exception("Can't find a directory at '%s'" % path) + + source_name = osp.basename(osp.normpath(path)) + project.add_source(source_name, 
{ + 'url': source_name, + 'format': self.EXTRACTOR_NAME, + 'options': dict(extra_params), + }) + + return project + + +class ImageDirExtractor(SourceExtractor): _SUPPORTED_FORMATS = ['.png', '.jpg'] def __init__(self, url): diff --git a/datumaro/datumaro/components/launchers/openvino.py b/datumaro/datumaro/plugins/openvino_launcher.py similarity index 100% rename from datumaro/datumaro/components/launchers/openvino.py rename to datumaro/datumaro/plugins/openvino_launcher.py diff --git a/datumaro/datumaro/components/converters/tfrecord.py b/datumaro/datumaro/plugins/tf_detection_api_format/converter.py similarity index 85% rename from datumaro/datumaro/components/converters/tfrecord.py rename to datumaro/datumaro/plugins/tf_detection_api_format/converter.py index 72b9c95cc22a..7c240626a50c 100644 --- a/datumaro/datumaro/components/converters/tfrecord.py +++ b/datumaro/datumaro/plugins/tf_detection_api_format/converter.py @@ -5,16 +5,19 @@ import codecs from collections import OrderedDict +import logging as log import os import os.path as osp import string from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME from datumaro.components.converter import Converter -from datumaro.components.formats.tfrecord import DetectionApiPath +from datumaro.components.cli_plugin import CliPlugin from datumaro.util.image import encode_image from datumaro.util.tf_util import import_tf as _import_tf +from .format import DetectionApiPath + # we need it to filter out non-ASCII characters, otherwise training will crash _printable = set(string.printable) @@ -56,14 +59,17 @@ def float_list_feature(value): 'image/width': int64_feature(width), }) - if save_images and item.has_image: - fmt = DetectionApiPath.IMAGE_FORMAT - buffer = encode_image(item.image, DetectionApiPath.IMAGE_EXT) + if save_images: + if item.has_image: + fmt = DetectionApiPath.IMAGE_FORMAT + buffer = encode_image(item.image, DetectionApiPath.IMAGE_EXT) - features.update({ - 'image/encoded': bytes_feature(buffer), - 'image/format': bytes_feature(fmt.encode('utf-8')), - }) + features.update({ + 'image/encoded': bytes_feature(buffer), + 'image/format': bytes_feature(fmt.encode('utf-8')), + }) + else: + log.debug("Item '%s' has no image" % item.id) xmins = [] # List of normalized left x coordinates in bounding box (1 per box) xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box) @@ -98,29 +104,19 @@ def float_list_feature(value): return tf_example -class DetectionApiConverter(Converter): - def __init__(self, save_images=False, cmdline_args=None): - super().__init__() - - self._save_images = save_images - - if cmdline_args is not None: - options = self._parse_cmdline(cmdline_args) - for k, v in options.items(): - if hasattr(self, '_' + str(k)): - setattr(self, '_' + str(k), v) - +class TfDetectionApiConverter(Converter, CliPlugin): @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='tf_detection_api') - + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--save-images', action='store_true', help="Save images (default: %(default)s)") - return parser + def __init__(self, save_images=False): + super().__init__() + + self._save_images = save_images + def __call__(self, extractor, save_dir): tf = _import_tf() diff --git a/datumaro/datumaro/components/extractors/tfrecord.py b/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py similarity index 78% rename 
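The int64_feature/bytes_feature helpers referenced above follow the standard tf.train.Example pattern; for context, a sketch of that pattern (not this module's exact code):

```python
import tensorflow as tf

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

features = tf.train.Features(feature={
    'image/width': int64_feature(640),
    'image/format': bytes_feature('jpeg'.encode('utf-8')),
})
example = tf.train.Example(features=features)
print(len(example.SerializeToString()) > 0)  # True
```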
from datumaro/datumaro/components/extractors/tfrecord.py rename to datumaro/datumaro/plugins/tf_detection_api_format/extractor.py index 46b78b63ba12..3592072c2f9e 100644 --- a/datumaro/datumaro/components/extractors/tfrecord.py +++ b/datumaro/datumaro/plugins/tf_detection_api_format/extractor.py @@ -8,60 +8,68 @@ import os.path as osp import re -from datumaro.components.extractor import AnnotationType, DEFAULT_SUBSET_NAME, \ - LabelCategories, BboxObject, DatasetItem, Extractor -from datumaro.components.formats.tfrecord import DetectionApiPath +from datumaro.components.extractor import (SourceExtractor, + DEFAULT_SUBSET_NAME, DatasetItem, + AnnotationType, Bbox, LabelCategories +) from datumaro.util.image import lazy_image, decode_image from datumaro.util.tf_util import import_tf as _import_tf +from .format import DetectionApiPath + def clamp(value, _min, _max): return max(min(_max, value), _min) -class DetectionApiExtractor(Extractor): - class Subset(Extractor): - def __init__(self, name, parent): - super().__init__() - self._name = name - self._parent = parent - self.items = OrderedDict() - - def __iter__(self): - for item in self.items.values(): - yield item - - def __len__(self): - return len(self.items) - - def categories(self): - return self._parent.categories() - - def __init__(self, path, images_dir=None): +class TfDetectionApiExtractor(SourceExtractor): + def __init__(self, path): super().__init__() + assert osp.isfile(path) + images_dir = '' root_dir = osp.dirname(osp.abspath(path)) if osp.basename(root_dir) == DetectionApiPath.ANNOTATIONS_DIR: root_dir = osp.dirname(root_dir) images_dir = osp.join(root_dir, DetectionApiPath.IMAGES_DIR) if not osp.isdir(images_dir): - images_dir = None - self._images_dir = images_dir - - self._subsets = {} + images_dir = '' subset_name = osp.splitext(osp.basename(path))[0] if subset_name == DEFAULT_SUBSET_NAME: subset_name = None - subset = DetectionApiExtractor.Subset(subset_name, self) + self._subset_name = subset_name + items, labels = self._parse_tfrecord_file(path, subset_name, images_dir) - subset.items = items - self._subsets[subset_name] = subset + self._items = items + self._categories = self._load_categories(labels) + + def categories(self): + return self._categories + + def __iter__(self): + for item in self._items: + yield item + + def __len__(self): + return len(self._items) + + def subsets(self): + if self._subset_name: + return [self._subset_name] + return None + def get_subset(self, name): + if name != self._subset_name: + return None + return self + + @staticmethod + def _load_categories(labels): label_categories = LabelCategories() labels = sorted(labels.items(), key=lambda item: item[1]) for label, _ in labels: label_categories.add(label) - self._categories = { + return { AnnotationType.label: label_categories } @@ -114,7 +122,7 @@ def _parse_tfrecord_file(cls, filepath, subset_name, images_dir): for label, id in cls._parse_labelmap(labelmap_text).items() }) - dataset_items = OrderedDict() + dataset_items = [] for record in dataset: parsed_record = tf.io.parse_single_example(record, features) @@ -163,7 +171,7 @@ def _parse_tfrecord_file(cls, filepath, subset_name, images_dir): y = clamp(shape[2] * frame_height, 0, frame_height) w = clamp(shape[3] * frame_width, 0, frame_width) - x h = clamp(shape[4] * frame_height, 0, frame_height) - y - annotations.append(BboxObject(x, y, w, h, + annotations.append(Bbox(x, y, w, h, label=dataset_labels.get(label, None), id=index )) @@ -175,32 +183,7 @@ def _parse_tfrecord_file(cls, filepath, 
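The clamp arithmetic above turns the TF Detection API's normalized corner coordinates into absolute (x, y, w, h). A worked example (the tuple layout mirrors the indices used above; treating index 0 as the label is an assumption):

```python
def clamp(value, _min, _max):
    return max(min(_max, value), _min)

shape = ('person', 0.25, 0.10, 0.75, 0.90)  # (label, xmin, ymin, xmax, ymax)
frame_width, frame_height = 640, 480

x = clamp(shape[1] * frame_width, 0, frame_width)        # 160.0
y = clamp(shape[2] * frame_height, 0, frame_height)      # 48.0
w = clamp(shape[3] * frame_width, 0, frame_width) - x    # 320.0
h = clamp(shape[4] * frame_height, 0, frame_height) - y  # 384.0
```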
subset_name, images_dir): if osp.exists(image_path): image = lazy_image(image_path) - dataset_items[item_id] = DatasetItem(id=item_id, subset=subset_name, - image=image, annotations=annotations) + dataset_items.append(DatasetItem(id=item_id, subset=subset_name, + image=image, annotations=annotations)) return dataset_items, dataset_labels - - def categories(self): - return self._categories - - def __iter__(self): - for subset in self._subsets.values(): - for item in subset: - yield item - - def __len__(self): - length = 0 - for subset in self._subsets.values(): - length += len(subset) - return length - - def subsets(self): - return list(self._subsets) - - def get_subset(self, name): - return self._subsets[name] - - def get(self, item_id, subset=None, path=None): - if path is not None: - return None - return self.get_subset(subset).items.get(item_id, None) \ No newline at end of file diff --git a/datumaro/datumaro/components/formats/tfrecord.py b/datumaro/datumaro/plugins/tf_detection_api_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/tfrecord.py rename to datumaro/datumaro/plugins/tf_detection_api_format/format.py diff --git a/datumaro/datumaro/components/importers/tfrecord.py b/datumaro/datumaro/plugins/tf_detection_api_format/importer.py similarity index 92% rename from datumaro/datumaro/components/importers/tfrecord.py rename to datumaro/datumaro/plugins/tf_detection_api_format/importer.py index 368c3d0fa9b3..3000c8881635 100644 --- a/datumaro/datumaro/components/importers/tfrecord.py +++ b/datumaro/datumaro/plugins/tf_detection_api_format/importer.py @@ -7,8 +7,10 @@ import logging as log import os.path as osp +from datumaro.components.extractor import Importer -class DetectionApiImporter: + +class TfDetectionApiImporter(Importer): EXTRACTOR_NAME = 'tf_detection_api' def __call__(self, path, **extra_params): diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py new file mode 100644 index 000000000000..7449d4fdeb14 --- /dev/null +++ b/datumaro/datumaro/plugins/transforms.py @@ -0,0 +1,255 @@ + +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from itertools import groupby +import logging as log + +import pycocotools.mask as mask_utils + +from datumaro.components.extractor import (Transform, AnnotationType, + Mask, RleMask, Polygon) +from datumaro.components.cli_plugin import CliPlugin +import datumaro.util.mask_tools as mask_tools + + +class CropCoveredSegments(Transform, CliPlugin): + def transform_item(self, item): + annotations = [] + segments = [] + for ann in item.annotations: + if ann.type in {AnnotationType.polygon, AnnotationType.mask}: + segments.append(ann) + else: + annotations.append(ann) + if not segments: + return item + + if not item.has_image: + raise Exception("Image info is required for this transform") + h, w = item.image.shape[:2] + segments = self.crop_segments(segments, w, h) + + annotations += segments + return self.wrap_item(item, annotations=annotations) + + @classmethod + def crop_segments(cls, segment_anns, img_width, img_height): + segment_anns = sorted(segment_anns, key=lambda x: x.z_order) + + segments = [] + for s in segment_anns: + if s.type == AnnotationType.polygon: + segments.append(s.points) + elif s.type == AnnotationType.mask: + if isinstance(s, RleMask): + rle = s._rle + else: + rle = mask_tools.mask_to_rle(s.image) + segments.append(rle) + + segments = mask_tools.crop_covered_segments( + segments, img_width, img_height) + + new_anns = [] + 
for ann, new_segment in zip(segment_anns, segments): + fields = {'z_order': ann.z_order, 'label': ann.label, + 'id': ann.id, 'group': ann.group, 'attributes': ann.attributes + } + if ann.type == AnnotationType.polygon: + if fields['group'] is None: + fields['group'] = cls._make_group_id( + segment_anns + new_anns, fields['id']) + for polygon in new_segment: + new_anns.append(Polygon(points=polygon, **fields)) + else: + rle = mask_tools.mask_to_rle(new_segment) + rle = mask_utils.frPyObjects(rle, *rle['size']) + new_anns.append(RleMask(rle=rle, **fields)) + + return new_anns + + @staticmethod + def _make_group_id(anns, ann_id): + if ann_id: + return ann_id + max_gid = max(anns, default=0, key=lambda x: x.group) + return max_gid + 1 + +class MergeInstanceSegments(Transform, CliPlugin): + """ + Replaces instance masks and, optionally, polygons with a single mask. + """ + + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('--include-polygons', action='store_true', + help="Include polygons") + return parser + + def __init__(self, extractor, include_polygons=False): + super().__init__(extractor) + + self._include_polygons = include_polygons + + def transform_item(self, item): + annotations = [] + segments = [] + for ann in item.annotations: + if ann.type in {AnnotationType.polygon, AnnotationType.mask}: + segments.append(ann) + else: + annotations.append(ann) + if not segments: + return item + + if not item.has_image: + raise Exception("Image info is required for this transform") + h, w = item.image.shape[:2] + instances = self.find_instances(segments) + segments = [self.merge_segments(i, w, h, self._include_polygons) + for i in instances] + segments = sum(segments, []) + + annotations += segments + return self.wrap_item(item, annotations=annotations) + + @classmethod + def merge_segments(cls, instance, img_width, img_height, + include_polygons=False): + polygons = [a for a in instance if a.type == AnnotationType.polygon] + masks = [a for a in instance if a.type == AnnotationType.mask] + if not polygons and not masks: + return [] + + leader = cls.find_group_leader(polygons + masks) + instance = [] + + # Build the resulting mask + mask = None + + if include_polygons and polygons: + polygons = [p.points for p in polygons] + mask = mask_tools.rles_to_mask(polygons, img_width, img_height) + else: + instance += polygons # keep unused polygons + + if masks: + if mask is not None: + masks += [mask] + mask = cls.merge_masks(masks) + + if mask is None: + return instance + + mask = mask_tools.mask_to_rle(mask) + mask = mask_utils.frPyObjects(mask, *mask['size']) + instance.append( + RleMask(rle=mask, label=leader.label, z_order=leader.z_order, + id=leader.id, attributes=leader.attributes, group=leader.group + ) + ) + return instance + + @staticmethod + def find_group_leader(group): + return max(group, key=lambda x: x.get_area()) + + @staticmethod + def merge_masks(masks): + if not masks: + return None + + def get_mask(m): + if isinstance(m, Mask): + return m.image + else: + return m + + binary_mask = get_mask(masks[0]) + for m in masks[1:]: + binary_mask |= get_mask(m) + + return binary_mask + + @staticmethod + def find_instances(annotations): + segment_anns = (a for a in annotations + if a.type in {AnnotationType.polygon, AnnotationType.mask} + ) + + ann_groups = [] + for g_id, group in groupby(segment_anns, lambda a: a.group): + if g_id is None: + ann_groups.extend(([a] for a in group)) + else: + 
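merge_masks above reduces an instance's binary masks with an element-wise OR; a toy illustration:

```python
import numpy as np

a = np.array([[1, 0], [0, 0]], dtype=bool)
b = np.array([[0, 0], [0, 1]], dtype=bool)

merged = a.copy()
merged |= b                # as in the loop over masks[1:] above
print(merged.astype(int))  # [[1 0]
                           #  [0 1]]
```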
ann_groups.append(list(group)) + + return ann_groups + +class PolygonsToMasks(Transform, CliPlugin): + def transform_item(self, item): + annotations = [] + for ann in item.annotations: + if ann.type == AnnotationType.polygon: + if not item.has_image: + raise Exception("Image info is required for this transform") + h, w = item.image.shape[:2] + annotations.append(self.convert_polygon(ann, h, w)) + else: + annotations.append(ann) + + return self.wrap_item(item, annotations=annotations) + + @staticmethod + def convert_polygon(polygon, img_h, img_w): + rle = mask_utils.frPyObjects([polygon.points], img_h, img_w)[0] + + return RleMask(rle=rle, label=polygon.label, z_order=polygon.z_order, + id=polygon.id, attributes=polygon.attributes, group=polygon.group) + +class MasksToPolygons(Transform, CliPlugin): + def transform_item(self, item): + annotations = [] + for ann in item.annotations: + if ann.type == AnnotationType.mask: + polygons = self.convert_mask(ann) + if not polygons: + log.debug("[%s]: item %s: " + "Mask conversion to polygons resulted in too " + "small polygons, which were discarded" % \ + (self.NAME, item.id)) + annotations.extend(polygons) + else: + annotations.append(ann) + + return self.wrap_item(item, annotations=annotations) + + @staticmethod + def convert_mask(mask): + polygons = mask_tools.mask_to_polygons(mask.image) + + return [ + Polygon(points=p, label=mask.label, z_order=mask.z_order, + id=mask.id, attributes=mask.attributes, group=mask.group) + for p in polygons + ] + +class Reindex(Transform, CliPlugin): + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-s', '--start', type=int, default=1, + help="Start value for item ids") + return parser + + def __init__(self, extractor, start=1): + super().__init__(extractor) + + self._start = start + + def __iter__(self): + for i, item in enumerate(self._extractor): + yield self.wrap_item(item, id=i + self._start) diff --git a/datumaro/datumaro/components/converters/voc.py b/datumaro/datumaro/plugins/voc_format/converter.py similarity index 78% rename from datumaro/datumaro/components/converters/voc.py rename to datumaro/datumaro/plugins/voc_format/converter.py index 814e44f2de1e..ccc4fe5bc0e5 100644 --- a/datumaro/datumaro/components/converters/voc.py +++ b/datumaro/datumaro/plugins/voc_format/converter.py @@ -11,16 +11,18 @@ import os import os.path as osp +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.converter import Converter from datumaro.components.extractor import (DEFAULT_SUBSET_NAME, AnnotationType, - LabelCategories + LabelCategories, CompiledMask, ) -from datumaro.components.formats.voc import (VocTask, VocPath, +from datumaro.util.image import save_image +from datumaro.util.mask_tools import paint_mask, remap_mask + +from .format import (VocTask, VocPath, VocInstColormap, VocPose, parse_label_map, make_voc_label_map, make_voc_categories, write_label_map ) -from datumaro.util.image import save_image -from datumaro.util.mask_tools import paint_mask, remap_mask def _convert_attr(name, attributes, type_conv, default=None, warn=True): @@ -131,13 +133,15 @@ def save_subsets(self): segm_list = OrderedDict() for item in subset: - item_id = str(item.id) + log.debug("Converting item '%s'", item.id) + if self._save_images: - data = item.image - if data is not None: + if item.has_image: save_image(osp.join(self._images_dir, - str(item_id) + VocPath.IMAGE_EXT), - data) + item.id + VocPath.IMAGE_EXT), + item.image) + 
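PolygonsToMasks above delegates rasterization to pycocotools, via the same frPyObjects call used in convert_polygon; a minimal sketch:

```python
import pycocotools.mask as mask_utils

points = [0, 0, 4, 0, 4, 4, 0, 4]  # flat polygon: a small square
img_h, img_w = 10, 10
rle = mask_utils.frPyObjects([points], img_h, img_w)[0]
mask = mask_utils.decode(rle)
print(mask.shape, int(mask.sum()))  # (10, 10) and the square's pixel area
```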
else: + log.debug("Item '%s' has no image" % item.id) labels = [] bboxes = [] @@ -152,13 +156,13 @@ def save_subsets(self): if len(bboxes) != 0: root_elem = ET.Element('annotation') - if '_' in item_id: - folder = item_id[ : item_id.find('_')] + if '_' in item.id: + folder = item.id[ : item.id.find('_')] else: folder = '' ET.SubElement(root_elem, 'folder').text = folder ET.SubElement(root_elem, 'filename').text = \ - item_id + VocPath.IMAGE_EXT + item.id + VocPath.IMAGE_EXT source_elem = ET.SubElement(root_elem, 'source') ET.SubElement(source_elem, 'database').text = 'Unknown' @@ -198,17 +202,25 @@ def save_subsets(self): obj_label = self.get_label(obj.label) ET.SubElement(obj_elem, 'name').text = obj_label - pose = _convert_attr('pose', attr, lambda v: VocPose[v], - VocPose.Unspecified) - ET.SubElement(obj_elem, 'pose').text = pose.name + if 'pose' in attr: + pose = _convert_attr('pose', attr, + lambda v: VocPose[v], VocPose.Unspecified) + ET.SubElement(obj_elem, 'pose').text = pose.name - truncated = _convert_attr('truncated', attr, int, 0) - ET.SubElement(obj_elem, 'truncated').text = \ - '%d' % truncated + if 'truncated' in attr: + truncated = _convert_attr('truncated', attr, int, 0) + ET.SubElement(obj_elem, 'truncated').text = \ + '%d' % truncated - difficult = _convert_attr('difficult', attr, int, 0) - ET.SubElement(obj_elem, 'difficult').text = \ - '%d' % difficult + if 'difficult' in attr: + difficult = _convert_attr('difficult', attr, int, 0) + ET.SubElement(obj_elem, 'difficult').text = \ + '%d' % difficult + + if 'occluded' in attr: + occluded = _convert_attr('occluded', attr, int, 0) + ET.SubElement(obj_elem, 'occluded').text = \ + '%d' % occluded bbox = obj.get_bbox() if bbox is not None: @@ -226,12 +238,14 @@ def save_subsets(self): label_actions = self._get_actions(obj_label) actions_elem = ET.Element('actions') for action in label_actions: - presented = _convert_attr(action, attr, - lambda v: int(v == True), 0) - ET.SubElement(actions_elem, action).text = \ - '%d' % presented - - objects_with_actions[new_obj_id][action] = presented + present = 0 + if action in attr: + present = _convert_attr(action, attr, + lambda v: int(v == True), 0) + ET.SubElement(actions_elem, action).text = \ + '%d' % present + + objects_with_actions[new_obj_id][action] = present if len(actions_elem) != 0: obj_elem.append(actions_elem) @@ -239,41 +253,44 @@ def save_subsets(self): VocTask.detection, VocTask.person_layout, VocTask.action_classification]): - with open(osp.join(self._ann_dir, item_id + '.xml'), 'w') as f: + with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f: f.write(ET.tostring(root_elem, encoding='unicode', pretty_print=True)) - clsdet_list[item_id] = True - layout_list[item_id] = objects_with_parts - action_list[item_id] = objects_with_actions + clsdet_list[item.id] = True + layout_list[item.id] = objects_with_parts + action_list[item.id] = objects_with_actions - for label_obj in labels: - label = self.get_label(label_obj.label) + for label_ann in labels: + label = self.get_label(label_ann.label) if not self._is_label(label): continue - class_list = class_lists.get(item_id, set()) - class_list.add(label_obj.label) - class_lists[item_id] = class_list + class_list = class_lists.get(item.id, set()) + class_list.add(label_ann.label) + class_lists[item.id] = class_list - clsdet_list[item_id] = True + clsdet_list[item.id] = True - for mask_obj in masks: - if mask_obj.attributes.get('class') == True: - self.save_segm(osp.join(self._segm_dir, - item_id + VocPath.SEGM_EXT), - 
mask_obj) - if mask_obj.attributes.get('instances') == True: - self.save_segm(osp.join(self._inst_dir, - item_id + VocPath.SEGM_EXT), - mask_obj, VocInstColormap) + if masks: + compiled_mask = CompiledMask.from_instance_masks(masks, + instance_labels=[self._label_id_mapping(m.label) + for m in masks]) - segm_list[item_id] = True + self.save_segm( + osp.join(self._segm_dir, item.id + VocPath.SEGM_EXT), + compiled_mask.class_mask) + self.save_segm( + osp.join(self._inst_dir, item.id + VocPath.SEGM_EXT), + compiled_mask.instance_mask, + colormap=VocInstColormap) + + segm_list[item.id] = True if len(item.annotations) == 0: - clsdet_list[item_id] = None - layout_list[item_id] = None - action_list[item_id] = None - segm_list[item_id] = None + clsdet_list[item.id] = None + layout_list[item.id] = None + action_list[item.id] = None + segm_list[item.id] = None if set(self._tasks) & set([None, VocTask.classification, @@ -361,14 +378,12 @@ def save_layout_lists(self, subset_name, layout_list): else: f.write('%s\n' % (item)) - def save_segm(self, path, annotation, colormap=None): - data = annotation.image + def save_segm(self, path, mask, colormap=None): if self._apply_colormap: if colormap is None: colormap = self._categories[AnnotationType.mask].colormap - data = self._remap_mask(data) - data = paint_mask(data, colormap) - save_image(path, data) + mask = paint_mask(mask, colormap) + save_image(path, mask) def save_label_map(self): path = osp.join(self._save_dir, VocPath.LABELMAP_FILE) @@ -468,30 +483,25 @@ def _make_label_id_map(self): if void_labels: log.warning("The following labels are remapped to background: %s" % ', '.join(void_labels)) + log.debug("Saving segmentations with the following label mapping: \n%s" % + '\n'.join(["#%s '%s' -> #%s '%s'" % + ( + src_id, src_label, id_mapping[src_id], + self._categories[AnnotationType.label] \ + .items[id_mapping[src_id]].name + ) + for src_id, src_label in source_labels.items() + ]) + ) def map_id(src_id): - return id_mapping[src_id] + return id_mapping.get(src_id, 0) return map_id def _remap_mask(self, mask): return remap_mask(mask, self._label_id_mapping) -class VocConverter(Converter): - def __init__(self, - tasks=None, save_images=False, apply_colormap=False, label_map=None, - cmdline_args=None): - super().__init__() - - self._options = { - 'tasks': tasks, - 'save_images': save_images, - 'apply_colormap': apply_colormap, - 'label_map': label_map, - } - - if cmdline_args is not None: - self._options.update(self._parse_cmdline(cmdline_args)) - +class VocConverter(Converter, CliPlugin): @staticmethod def _split_tasks_string(s): return [VocTask[i.strip()] for i in s.split(',')] @@ -503,10 +513,8 @@ def _get_labelmap(s): return LabelmapType[s].name @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='voc') + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--save-images', action='store_true', help="Save images (default: %(default)s)") @@ -523,21 +531,37 @@ def build_cmdline_parser(cls, parser=None): return parser + def __init__(self, tasks=None, save_images=False, + apply_colormap=False, label_map=None): + super().__init__() + + self._options = { + 'tasks': tasks, + 'save_images': save_images, + 'apply_colormap': apply_colormap, + 'label_map': label_map, + } + def __call__(self, extractor, save_dir): converter = _Converter(extractor, save_dir, **self._options) converter.convert() -def 
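The map_id change above makes source labels that are missing from the mapping fall back to background (index 0) instead of raising KeyError; a toy check (the mapping values are made up):

```python
id_mapping = {1: 15, 2: 0}  # hypothetical: source label id -> VOC label id

def map_id(src_id):
    return id_mapping.get(src_id, 0)

assert map_id(1) == 15
assert map_id(2) == 0   # explicitly remapped to background
assert map_id(99) == 0  # unknown labels now also go to background
```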
VocClassificationConverter(**kwargs): - return VocConverter(VocTask.classification, **kwargs) +class VocClassificationConverter(VocConverter): + def __init__(self, **kwargs): + super().__init__(VocTask.classification, **kwargs) -def VocDetectionConverter(**kwargs): - return VocConverter(VocTask.detection, **kwargs) +class VocDetectionConverter(VocConverter): + def __init__(self, **kwargs): + super().__init__(VocTask.detection, **kwargs) -def VocLayoutConverter(**kwargs): - return VocConverter(VocTask.person_layout, **kwargs) +class VocLayoutConverter(VocConverter): + def __init__(self, **kwargs): + super().__init__(VocTask.person_layout, **kwargs) -def VocActionConverter(**kwargs): - return VocConverter(VocTask.action_classification, **kwargs) +class VocActionConverter(VocConverter): + def __init__(self, **kwargs): + super().__init__(VocTask.action_classification, **kwargs) -def VocSegmentationConverter(**kwargs): - return VocConverter(VocTask.segmentation, **kwargs) +class VocSegmentationConverter(VocConverter): + def __init__(self, **kwargs): + super().__init__(VocTask.segmentation, **kwargs) diff --git a/datumaro/datumaro/components/extractors/voc.py b/datumaro/datumaro/plugins/voc_format/extractor.py similarity index 86% rename from datumaro/datumaro/components/extractors/voc.py rename to datumaro/datumaro/plugins/voc_format/extractor.py index 086649f583f5..cb8b5e4eb4f8 100644 --- a/datumaro/datumaro/components/extractors/voc.py +++ b/datumaro/datumaro/plugins/voc_format/extractor.py @@ -8,20 +8,22 @@ import os.path as osp from xml.etree import ElementTree as ET -from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, LabelObject, MaskObject, BboxObject, -) -from datumaro.components.formats.voc import ( - VocTask, VocPath, VocInstColormap, parse_label_map, make_voc_categories +from datumaro.components.extractor import (SourceExtractor, Extractor, + DEFAULT_SUBSET_NAME, DatasetItem, + AnnotationType, Label, Mask, Bbox, CompiledMask ) from datumaro.util import dir_items from datumaro.util.image import lazy_image from datumaro.util.mask_tools import lazy_mask, invert_colormap +from .format import ( + VocTask, VocPath, VocInstColormap, parse_label_map, make_voc_categories +) + _inverse_inst_colormap = invert_colormap(VocInstColormap) -class VocExtractor(Extractor): +class VocExtractor(SourceExtractor): class Subset(Extractor): def __init__(self, name, parent): super().__init__() @@ -45,20 +47,25 @@ def _load_subsets(self, subsets_dir): subsets = {} for subset_name in subset_names: + subset_file_name = subset_name + if subset_name == DEFAULT_SUBSET_NAME: + subset_name = None subset = __class__.Subset(subset_name, self) - with open(osp.join(subsets_dir, subset_name + '.txt'), 'r') as f: + with open(osp.join(subsets_dir, subset_file_name + '.txt'), 'r') as f: subset.items = [line.split()[0] for line in f] subsets[subset_name] = subset return subsets def _load_cls_annotations(self, subsets_dir, subset_names): + subset_file_names = [n if n else DEFAULT_SUBSET_NAME + for n in subset_names] dir_files = dir_items(subsets_dir, '.txt', truncate_ext=True) label_annotations = defaultdict(list) label_anno_files = [s for s in dir_files \ - if '_' in s and s[s.rfind('_') + 1:] in subset_names] + if '_' in s and s[s.rfind('_') + 1:] in subset_file_names] for ann_filename in label_anno_files: with open(osp.join(subsets_dir, ann_filename + '.txt'), 'r') as f: label = ann_filename[:ann_filename.rfind('_')] @@ -139,43 +146,77 @@ def _get_label_id(self, label): assert label_id is 
not None return label_id - def _get_annotations(self, item): + @staticmethod + def _lazy_extract_mask(mask, c): + return lambda: mask == c + + def _get_annotations(self, item_id): item_annotations = [] if self._task is VocTask.segmentation: + class_mask = None segm_path = osp.join(self._path, VocPath.SEGMENTATION_DIR, - item + VocPath.SEGM_EXT) + item_id + VocPath.SEGM_EXT) if osp.isfile(segm_path): inverse_cls_colormap = \ self._categories[AnnotationType.mask].inverse_colormap - item_annotations.append(MaskObject( - image=lazy_mask(segm_path, inverse_cls_colormap), - attributes={ 'class': True } - )) + class_mask = lazy_mask(segm_path, inverse_cls_colormap) + instances_mask = None inst_path = osp.join(self._path, VocPath.INSTANCES_DIR, - item + VocPath.SEGM_EXT) + item_id + VocPath.SEGM_EXT) if osp.isfile(inst_path): - item_annotations.append(MaskObject( - image=lazy_mask(inst_path, _inverse_inst_colormap), - attributes={ 'instances': True } - )) + instances_mask = lazy_mask(inst_path, _inverse_inst_colormap) + + if instances_mask is not None: + compiled_mask = CompiledMask(class_mask, instances_mask) + + if class_mask is not None: + label_cat = self._categories[AnnotationType.label] + instance_labels = compiled_mask.get_instance_labels( + class_count=len(label_cat.items)) + else: + instance_labels = {i: None + for i in range(compiled_mask.instance_count)} + + for instance_id, label_id in instance_labels.items(): + image = compiled_mask.lazy_extract(instance_id) + + attributes = dict() + if label_id is not None: + actions = {a: False + for a in label_cat.items[label_id].attributes + } + attributes.update(actions) + + item_annotations.append(Mask( + image=image, label=label_id, + attributes=attributes, group=instance_id + )) + elif class_mask is not None: + log.warn("item '%s': has only class segmentation, " + "instance masks will not be available" % item_id) + classes = class_mask.image.unique() + for label_id in classes: + image = self._lazy_extract_mask(class_mask, label_id) + item_annotations.append(Mask(image=image, label=label_id)) cls_annotations = self._annotations.get(VocTask.classification) if cls_annotations is not None and \ self._task is VocTask.classification: - item_labels = cls_annotations.get(item) + item_labels = cls_annotations.get(item_id) if item_labels is not None: for label_id in item_labels: - item_annotations.append(LabelObject(label_id)) + item_annotations.append(Label(label_id)) det_annotations = self._annotations.get(VocTask.detection) if det_annotations is not None: - det_annotations = det_annotations.get(item) + det_annotations = det_annotations.get(item_id) if det_annotations is not None: root_elem = ET.fromstring(det_annotations) for obj_id, object_elem in enumerate(root_elem.findall('object')): + obj_id += 1 attributes = {} group = None @@ -225,24 +266,22 @@ def _get_annotations(self, item): for part_elem in object_elem.findall('part'): part = part_elem.find('name').text part_label_id = self._get_label_id(part) - bbox = self._parse_bbox(part_elem) + part_bbox = self._parse_bbox(part_elem) group = obj_id if self._task is not VocTask.person_layout: break - if bbox is None: + if part_bbox is None: continue - item_annotations.append(BboxObject( - *bbox, label=part_label_id, - group=obj_id)) + item_annotations.append(Bbox(*part_bbox, label=part_label_id, + group=group)) - if self._task is VocTask.person_layout and group is None: + if self._task is VocTask.person_layout and not group: continue if self._task is VocTask.action_classification and not actions: 
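_lazy_extract_mask above exists to bind the class id early; a bare lambda written inline in the loop would capture the loop variable by reference (Python's late-binding closures), so every mask would end up using the last class id. A small demonstration:

```python
import numpy as np

def _lazy_extract_mask(mask, c):
    return lambda: mask == c

class_mask = np.array([[0, 1], [2, 1]])
extract = [_lazy_extract_mask(class_mask, c) for c in (1, 2)]
print(extract[0]().astype(int))  # [[0 1] [0 1]]
print(extract[1]().astype(int))  # [[0 0] [1 0]]
```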
continue - item_annotations.append(BboxObject( - *obj_bbox, label=obj_label_id, + item_annotations.append(Bbox(*obj_bbox, label=obj_label_id, attributes=attributes, id=obj_id, group=group)) return item_annotations @@ -482,7 +521,7 @@ def _get_annotations(self, item, subset_name): if cls_ann is not None: for desc in cls_ann: label_id, conf = desc - annotations.append(LabelObject( + annotations.append(Label( int(label_id), attributes={ 'score': float(conf) } )) @@ -508,7 +547,7 @@ def _get_annotations(self, item, subset_name): if det_ann is not None: for desc in det_ann: label_id, conf, left, top, right, bottom = desc - annotations.append(BboxObject( + annotations.append(Bbox( x=float(left), y=float(top), w=float(right) - float(left), h=float(bottom) - float(top), label=int(label_id), @@ -560,7 +599,7 @@ def _get_annotations(self, item, subset_name): if cls_image_path and osp.isfile(cls_image_path): inverse_cls_colormap = \ self._categories[AnnotationType.mask].inverse_colormap - annotations.append(MaskObject( + annotations.append(Mask( image=lazy_mask(cls_image_path, inverse_cls_colormap), attributes={ 'class': True } )) @@ -568,7 +607,7 @@ def _get_annotations(self, item, subset_name): inst_ann = self._annotations[subset_name] inst_image_path = inst_ann.get(item) if inst_image_path and osp.isfile(inst_image_path): - annotations.append(MaskObject( + annotations.append(Mask( image=lazy_mask(inst_image_path, _inverse_inst_colormap), attributes={ 'instances': True } )) @@ -641,7 +680,7 @@ def _get_annotations(self, item, subset_name): for part in parts: label_id, bbox = part - annotations.append(BboxObject( + annotations.append(Bbox( *bbox, label=label_id, attributes=attributes)) @@ -672,7 +711,7 @@ def _get_annotations(self, item, subset_name): if action_ann is not None: for desc in action_ann: action_id, obj_id, conf = desc - annotations.append(LabelObject( + annotations.append(Label( action_id, attributes={ 'score': conf, diff --git a/datumaro/datumaro/components/formats/voc.py b/datumaro/datumaro/plugins/voc_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/voc.py rename to datumaro/datumaro/plugins/voc_format/format.py diff --git a/datumaro/datumaro/components/importers/voc.py b/datumaro/datumaro/plugins/voc_format/importer.py similarity index 88% rename from datumaro/datumaro/components/importers/voc.py rename to datumaro/datumaro/plugins/voc_format/importer.py index bc0409df805f..f3d7c5ef3b97 100644 --- a/datumaro/datumaro/components/importers/voc.py +++ b/datumaro/datumaro/plugins/voc_format/importer.py @@ -6,15 +6,17 @@ import os import os.path as osp -from datumaro.components.formats.voc import VocTask, VocPath +from datumaro.components.extractor import Importer from datumaro.util import find +from .format import VocTask, VocPath -class VocImporter: + +class VocImporter(Importer): _TASKS = [ - (VocTask.classification, 'voc_cls', 'Main'), - (VocTask.detection, 'voc_det', 'Main'), - (VocTask.segmentation, 'voc_segm', 'Segmentation'), + (VocTask.classification, 'voc_classification', 'Main'), + (VocTask.detection, 'voc_detection', 'Main'), + (VocTask.segmentation, 'voc_segmentation', 'Segmentation'), (VocTask.person_layout, 'voc_layout', 'Layout'), (VocTask.action_classification, 'voc_action', 'Action'), ] diff --git a/datumaro/datumaro/components/converters/yolo.py b/datumaro/datumaro/plugins/yolo_format/converter.py similarity index 84% rename from datumaro/datumaro/components/converters/yolo.py rename to 
datumaro/datumaro/plugins/yolo_format/converter.py index a25c7b04d9ac..d50c2a0aab21 100644 --- a/datumaro/datumaro/components/converters/yolo.py +++ b/datumaro/datumaro/plugins/yolo_format/converter.py @@ -10,9 +10,11 @@ from datumaro.components.converter import Converter from datumaro.components.extractor import AnnotationType -from datumaro.components.formats.yolo import YoloPath +from datumaro.components.cli_plugin import CliPlugin from datumaro.util.image import save_image +from .format import YoloPath + def _make_yolo_bbox(img_size, box): # https://github.com/pjreddie/darknet/blob/master/scripts/voc_label.py @@ -24,30 +26,20 @@ def _make_yolo_bbox(img_size, box): h = (box[3] - box[1]) / img_size[1] return x, y, w, h -class YoloConverter(Converter): +class YoloConverter(Converter, CliPlugin): # https://github.com/AlexeyAB/darknet#how-to-train-to-detect-your-custom-objects - def __init__(self, save_images=False, cmdline_args=None): - super().__init__() - self._save_images = save_images - - if cmdline_args is not None: - options = self._parse_cmdline(cmdline_args) - for k, v in options.items(): - if hasattr(self, '_' + str(k)): - setattr(self, '_' + str(k), v) - @classmethod - def build_cmdline_parser(cls, parser=None): - import argparse - if not parser: - parser = argparse.ArgumentParser(prog='yolo') - + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--save-images', action='store_true', help="Save images (default: %(default)s)") - return parser + def __init__(self, save_images=False): + super().__init__() + self._save_images = save_images + def __call__(self, extractor, save_dir): os.makedirs(save_dir, exist_ok=True) @@ -88,9 +80,10 @@ def __call__(self, extractor, save_dir): osp.basename(subset_dir), image_name) if self._save_images: - image_path = osp.join(subset_dir, image_name) - if not osp.exists(image_path): - save_image(image_path, item.image) + if item.has_image: + save_image(osp.join(subset_dir, image_name), item.image) + else: + log.debug("Item '%s' has no images" % item.id) height, width = item.image.shape[:2] diff --git a/datumaro/datumaro/components/extractors/yolo.py b/datumaro/datumaro/plugins/yolo_format/extractor.py similarity index 91% rename from datumaro/datumaro/components/extractors/yolo.py rename to datumaro/datumaro/plugins/yolo_format/extractor.py index 81ead7b2baed..cb52327717c1 100644 --- a/datumaro/datumaro/components/extractors/yolo.py +++ b/datumaro/datumaro/plugins/yolo_format/extractor.py @@ -7,14 +7,15 @@ import os.path as osp import re -from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, BboxObject, LabelCategories +from datumaro.components.extractor import (SourceExtractor, Extractor, + DatasetItem, AnnotationType, Bbox, LabelCategories ) -from datumaro.components.formats.yolo import YoloPath from datumaro.util.image import lazy_image +from .format import YoloPath -class YoloExtractor(Extractor): + +class YoloExtractor(SourceExtractor): class Subset(Extractor): def __init__(self, name, parent): super().__init__() @@ -124,9 +125,9 @@ def _parse_annotations(anno_path, image_width, image_height): h = float(h) x = float(xc) - w * 0.5 y = float(yc) - h * 0.5 - annotations.append(BboxObject( - x * image_width, y * image_height, - w * image_width, h * image_height, + annotations.append(Bbox( + round(x * image_width, 1), round(y * image_height, 1), + round(w * image_width, 1), round(h * image_height, 1), label=label_id )) return annotations @@ -137,7 +138,7 @@ def 
_load_categories(names_path): with open(names_path, 'r') as f: for label in f: - label_categories.add(label) + label_categories.add(label.strip()) return label_categories diff --git a/datumaro/datumaro/components/formats/yolo.py b/datumaro/datumaro/plugins/yolo_format/format.py similarity index 100% rename from datumaro/datumaro/components/formats/yolo.py rename to datumaro/datumaro/plugins/yolo_format/format.py diff --git a/datumaro/datumaro/components/importers/yolo.py b/datumaro/datumaro/plugins/yolo_format/importer.py similarity index 91% rename from datumaro/datumaro/components/importers/yolo.py rename to datumaro/datumaro/plugins/yolo_format/importer.py index df8f739626a1..26532d091286 100644 --- a/datumaro/datumaro/components/importers/yolo.py +++ b/datumaro/datumaro/plugins/yolo_format/importer.py @@ -7,8 +7,10 @@ import logging as log import os.path as osp +from datumaro.components.extractor import Importer -class YoloImporter: + +class YoloImporter(Importer): def __call__(self, path, **extra_params): from datumaro.components.project import Project # cyclic import project = Project() diff --git a/datumaro/datumaro/util/image.py b/datumaro/datumaro/util/image.py index 46ef267c3d2a..395c6f5080d5 100644 --- a/datumaro/datumaro/util/image.py +++ b/datumaro/datumaro/util/image.py @@ -136,7 +136,7 @@ def __init__(self, path, loader=load_image, cache=None): def __call__(self): image = None - image_id = id(self) # path is not necessary hashable or a file path + image_id = hash(self) # path is not necessarily hashable or a file path cache = self._get_cache() if cache is not None: @@ -155,3 +155,6 @@ def _get_cache(self): elif cache == False: return None return cache + + def __hash__(self): + return hash((id(self), self.path, self.loader)) \ No newline at end of file diff --git a/datumaro/datumaro/util/mask_tools.py b/datumaro/datumaro/util/mask_tools.py index ca8de5ac5bd3..847392406399 100644 --- a/datumaro/datumaro/util/mask_tools.py +++ b/datumaro/datumaro/util/mask_tools.py @@ -93,9 +93,16 @@ def remap_mask(mask, map_fn): return np.array([map_fn(c) for c in range(256)], dtype=np.uint8)[mask] +def make_index_mask(binary_mask, index): + return np.choose(binary_mask, np.array([0, index], dtype=np.uint8)) + +def make_binary_mask(mask): + return mask != 0 # boolean mask of the non-background pixels + def load_mask(path, inverse_colormap=None): mask = load_image(path) + mask = mask.astype(np.uint8) if inverse_colormap is not None: if len(mask.shape) == 3 and mask.shape[2] != 1: mask = unpaint_mask(mask, inverse_colormap) @@ -250,3 +257,23 @@ def rles_to_mask(rles, width, height): rles = mask_utils.merge(rles) mask = mask_utils.decode(rles) return mask + +def find_mask_bbox(mask): + cols = np.any(mask, axis=0) + rows = np.any(mask, axis=1) + x0, x1 = np.where(cols)[0][[0, -1]] + y0, y1 = np.where(rows)[0][[0, -1]] + return [x0, y0, x1 - x0, y1 - y0] + +def merge_masks(masks): + """ + Merges masks into one; mask order is responsible for z order.
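+ + For example, merge_masks([a, b]) is equivalent to np.where(b != 0, b, a): + masks later in the list take precedence.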
+ """ + if not masks: + return None + + merged_mask = masks[0] + for m in masks[1:]: + merged_mask = np.where(m != 0, m, merged_mask) + + return merged_mask \ No newline at end of file diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index e855fad077e7..1d6395feef94 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -9,6 +9,9 @@ import shutil import tempfile +from datumaro.components.extractor import AnnotationType +from datumaro.util import find + def current_function_name(depth=1): return inspect.getouterframes(inspect.currentframe())[depth].function @@ -20,7 +23,7 @@ def __init__(self, path, is_dir=False, ignore_errors=False): self.ignore_errors = ignore_errors def __enter__(self): - return self + return self.path # pylint: disable=redefined-builtin def __exit__(self, type=None, value=None, traceback=None): @@ -51,4 +54,46 @@ def item_to_str(item): 'ann[%s]: %s' % (i, ann_to_str(a)) for i, a in enumerate(item.annotations) ] - ) \ No newline at end of file + ) + +def compare_categories(test, expected, actual): + test.assertEqual( + sorted(expected, key=lambda t: t.value), + sorted(actual, key=lambda t: t.value) + ) + + if AnnotationType.label in expected: + test.assertEqual( + expected[AnnotationType.label].items, + actual[AnnotationType.label].items, + ) + if AnnotationType.mask in expected: + test.assertEqual( + expected[AnnotationType.mask].colormap, + actual[AnnotationType.mask].colormap, + ) + if AnnotationType.points in expected: + test.assertEqual( + expected[AnnotationType.points].items, + actual[AnnotationType.points].items, + ) + +def compare_datasets(test, expected, actual): + compare_categories(test, expected.categories(), actual.categories()) + + test.assertEqual(sorted(expected.subsets()), sorted(actual.subsets())) + test.assertEqual(len(expected), len(actual)) + for item_a in expected: + item_b = find(actual, lambda x: x.id == item_a.id) + test.assertFalse(item_b is None, item_a.id) + test.assertEqual(len(item_a.annotations), len(item_b.annotations)) + for ann_a in item_a.annotations: + # We might find few corresponding items, so check them all + ann_b_matches = [x for x in item_b.annotations + if x.id == ann_a.id and \ + x.type == ann_a.type and x.group == ann_a.group] + test.assertFalse(len(ann_b_matches) == 0, 'ann id: %s' % ann_a.id) + + ann_b = find(ann_b_matches, lambda x: x == ann_a) + test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a)) + item_b.annotations.remove(ann_b) # avoid repeats \ No newline at end of file diff --git a/datumaro/docs/user_manual.md b/datumaro/docs/user_manual.md index 0b61c11ac54f..65284fa1166c 100644 --- a/datumaro/docs/user_manual.md +++ b/datumaro/docs/user_manual.md @@ -470,7 +470,7 @@ def process_outputs(inputs, outputs): y = max(int(det[4] * input_height), 0) w = min(int(det[5] * input_width - x), input_width) h = min(int(det[6] * input_height - y), input_height) - image_results.append(BboxObject(x, y, w, h, + image_results.append(Bbox(x, y, w, h, label=label, attributes={'score': conf} )) results.append(image_results[:max_det]) diff --git a/datumaro/tests/test_RISE.py b/datumaro/tests/test_RISE.py index a7560f1ff0e3..7dab28dd0c37 100644 --- a/datumaro/tests/test_RISE.py +++ b/datumaro/tests/test_RISE.py @@ -4,7 +4,7 @@ from unittest import TestCase -from datumaro.components.extractor import LabelObject, BboxObject +from datumaro.components.extractor import Label, Bbox from datumaro.components.launcher import Launcher from 
datumaro.components.algorithms.rise import RISE @@ -32,7 +32,7 @@ def _process(self, image): other_conf = (1.0 - cls_conf) / (self.class_count - 1) return [ - LabelObject(i, attributes={ + Label(i, attributes={ 'score': cls_conf if cls == i else other_conf }) \ for i in range(self.class_count) ] @@ -94,7 +94,7 @@ def _process(self, image): if roi.threshold < roi_sum / roi_base_sum: cls = roi.label detections.append( - BboxObject(roi.x, roi.y, roi.w, roi.h, + Bbox(roi.x, roi.y, roi.w, roi.h, label=cls, attributes={'score': cls_conf}) ) @@ -108,7 +108,7 @@ def _process(self, image): box = [roi.x, roi.y, roi.w, roi.h] offset = (np.random.rand(4) - 0.5) * self.pixel_jitter detections.append( - BboxObject(*(box + offset), + Bbox(*(box + offset), label=cls, attributes={'score': cls_conf}) ) @@ -189,7 +189,7 @@ def DISABLED_test_roi_nms(): detections = [] for i, roi in enumerate(rois): detections.append( - BboxObject(roi.x, roi.y, roi.w, roi.h, + Bbox(roi.x, roi.y, roi.w, roi.h, label=roi.label, attributes={'score': roi.conf}) ) @@ -199,7 +199,7 @@ def DISABLED_test_roi_nms(): box = [roi.x, roi.y, roi.w, roi.h] offset = (np.random.rand(4) - 0.5) * pixel_jitter detections.append( - BboxObject(*(box + offset), + Bbox(*(box + offset), label=cls, attributes={'score': cls_conf}) ) diff --git a/datumaro/tests/test_coco_format.py b/datumaro/tests/test_coco_format.py index e32303b6d66b..9dd64f878fdd 100644 --- a/datumaro/tests/test_coco_format.py +++ b/datumaro/tests/test_coco_format.py @@ -2,17 +2,15 @@ import numpy as np import os import os.path as osp -from PIL import Image from unittest import TestCase from datumaro.components.project import Project from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, LabelObject, MaskObject, PointsObject, PolygonObject, - BboxObject, CaptionObject, + AnnotationType, Label, Mask, Points, Polygon, Bbox, Caption, LabelCategories, PointsCategories ) -from datumaro.components.converters.coco import ( +from datumaro.plugins.coco_format.converter import ( CocoConverter, CocoImageInfoConverter, CocoCaptionsConverter, @@ -20,8 +18,10 @@ CocoPersonKeypointsConverter, CocoLabelsConverter, ) +from datumaro.plugins.coco_format.importer import CocoImporter +from datumaro.util.image import save_image from datumaro.util import find -from datumaro.util.test_utils import TestDir +from datumaro.util.test_utils import TestDir, compare_datasets class CocoImporterTest(TestCase): @@ -59,8 +59,8 @@ def generate_annotation(): }) annotation['images'].append({ "id": 1, - "width": 10, - "height": 5, + "width": 5, + "height": 10, "file_name": '000000000001.jpg', "license": 0, "flickr_url": '', @@ -101,8 +101,7 @@ def COCO_dataset_generate(self, path): os.makedirs(ann_dir) image = np.ones((10, 5, 3), dtype=np.uint8) - image = Image.fromarray(image).convert('RGB') - image.save(osp.join(img_dir, '000000000001.jpg')) + save_image(osp.join(img_dir, '000000000001.jpg'), image) annotation = self.generate_annotation() @@ -110,9 +109,9 @@ def COCO_dataset_generate(self, path): json.dump(annotation, outfile) def test_can_import(self): - with TestDir() as temp_dir: - self.COCO_dataset_generate(temp_dir.path) - project = Project.import_from(temp_dir.path, 'coco') + with TestDir() as test_dir: + self.COCO_dataset_generate(test_dir) + project = Project.import_from(test_dir, 'coco') dataset = project.make_dataset() self.assertListEqual(['val'], sorted(dataset.subsets())) @@ -121,7 +120,7 @@ def test_can_import(self): item = next(iter(dataset)) self.assertTrue(item.has_image) 
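# the dummy image is np.ones((10, 5, 3)), so for this all-ones array the element sum must equal the element count: np.sum(item.image) == np.prod(item.image.shape) == 150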
self.assertEqual(np.sum(item.image), np.prod(item.image.shape)) - self.assertEqual(4, len(item.annotations)) + self.assertEqual(2, len(item.annotations)) ann_1 = find(item.annotations, lambda x: x.id == 1) ann_1_poly = find(item.annotations, lambda x: \ @@ -137,38 +136,17 @@ def test_can_import(self): class CocoConverterTest(TestCase): def _test_save_and_load(self, source_dataset, converter, test_dir, - importer_params=None, target_dataset=None): - converter(source_dataset, test_dir.path) - - if not importer_params: - importer_params = {} - project = Project.import_from(test_dir.path, 'coco', - **importer_params) - parsed_dataset = project.make_dataset() - - if target_dataset is not None: - source_dataset = target_dataset - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - - self.assertEqual(len(source_dataset), len(parsed_dataset)) - - for item_a in source_dataset: - item_b = find(parsed_dataset, lambda x: x.id == item_a.id) - self.assertFalse(item_b is None) - self.assertEqual(len(item_a.annotations), len(item_b.annotations)) - for ann_a in item_a.annotations: - # We might find few corresponding items, so check them all - ann_b_matches = [x for x in item_b.annotations - if x.id == ann_a.id and \ - x.type == ann_a.type and x.group == ann_a.group] - self.assertFalse(len(ann_b_matches) == 0, 'aid: %s' % ann_a.id) - - ann_b = find(ann_b_matches, lambda x: x == ann_a) - self.assertEqual(ann_a, ann_b, 'aid: %s' % ann_a.id) - item_b.annotations.remove(ann_b) # avoid repeats + target_dataset=None, importer_args=None): + converter(source_dataset, test_dir) + + if importer_args is None: + importer_args = {} + parsed_dataset = CocoImporter()(test_dir, **importer_args).make_dataset() + + if target_dataset is None: + target_dataset = source_dataset + + compare_datasets(self, expected=target_dataset, actual=parsed_dataset) def test_can_save_and_load_captions(self): class TestExtractor(Extractor): @@ -176,17 +154,17 @@ def __iter__(self): return iter([ DatasetItem(id=1, subset='train', annotations=[ - CaptionObject('hello', id=1, group=1), - CaptionObject('world', id=2, group=2), + Caption('hello', id=1, group=1), + Caption('world', id=2, group=2), ]), DatasetItem(id=2, subset='train', annotations=[ - CaptionObject('test', id=3, group=3), + Caption('test', id=3, group=3), ]), DatasetItem(id=3, subset='val', annotations=[ - CaptionObject('word', id=1, group=1), + Caption('word', id=1, group=1), ] ), ]) @@ -207,39 +185,76 @@ def __iter__(self): DatasetItem(id=1, subset='train', image=np.ones((4, 4, 3)), annotations=[ # Bbox + single polygon - BboxObject(0, 1, 2, 2, + Bbox(0, 1, 2, 2, label=2, group=1, id=1, attributes={ 'is_crowd': False }), - PolygonObject([0, 1, 2, 1, 2, 3, 0, 3], + Polygon([0, 1, 2, 1, 2, 3, 0, 3], attributes={ 'is_crowd': False }, label=2, group=1, id=1), ]), DatasetItem(id=2, subset='train', image=np.ones((4, 4, 3)), annotations=[ # Mask + bbox - MaskObject(np.array([ + Mask(np.array([ [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 1, 1], [0, 0, 0, 0]], - ), + ), attributes={ 'is_crowd': True }, label=4, group=3, id=3), - BboxObject(1, 0, 2, 2, label=4, group=3, id=3, + Bbox(1, 0, 2, 2, label=4, group=3, id=3, attributes={ 'is_crowd': True }), ]), DatasetItem(id=3, subset='val', image=np.ones((4, 4, 3)), annotations=[ # Bbox + mask - BboxObject(0, 1, 2, 2, label=4, group=3, id=3, + Bbox(0, 1, 2, 2, label=4, group=3, id=3, attributes={ 'is_crowd': True }), - MaskObject(np.array([ + Mask(np.array([ + [0, 0, 0, 0], + [1, 1, 1, 0], + [1, 1, 0, 0], + 
[0, 0, 0, 0]], + ), + attributes={ 'is_crowd': True }, + label=4, group=3, id=3), + ]), + ]) + + def categories(self): + return categories + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='train', image=np.ones((4, 4, 3)), + annotations=[ + Polygon([0, 1, 2, 1, 2, 3, 0, 3], + attributes={ 'is_crowd': False }, + label=2, group=1, id=1), + ]), + DatasetItem(id=2, subset='train', image=np.ones((4, 4, 3)), + annotations=[ + Mask(np.array([ + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 1, 1], + [0, 0, 0, 0]], + ), + attributes={ 'is_crowd': True }, + label=4, group=3, id=3), + ]), + + DatasetItem(id=3, subset='val', image=np.ones((4, 4, 3)), + annotations=[ + Mask(np.array([ [0, 0, 0, 0], [1, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]], - ), + ), attributes={ 'is_crowd': True }, label=4, group=3, id=3), ]), @@ -250,7 +265,8 @@ def categories(self): with TestDir() as test_dir: self._test_save_and_load(TestExtractor(), - CocoInstancesConverter(), test_dir) + CocoInstancesConverter(), test_dir, + target_dataset=DstExtractor()) def test_can_merge_polygons_on_loading(self): label_categories = LabelCategories() @@ -263,9 +279,9 @@ def __iter__(self): return iter([ DatasetItem(id=1, image=np.zeros((6, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=3, id=4, group=4), - PolygonObject([5, 0, 9, 0, 5, 5], + Polygon([5, 0, 9, 0, 5, 5], label=3, id=4, group=4), ] ), @@ -278,16 +294,7 @@ class TargetExtractor(TestExtractor): def __iter__(self): items = list(super().__iter__()) items[0]._annotations = [ - BboxObject(0, 0, 9, 5, - label=3, id=4, group=4, - attributes={ 'is_crowd': False }), - PolygonObject([0, 0, 4, 0, 4, 4], - label=3, id=4, group=4, - attributes={ 'is_crowd': False }), - PolygonObject([5, 0, 9, 0, 5, 5], - label=3, id=4, group=4, - attributes={ 'is_crowd': False }), - MaskObject(np.array([ + Mask(np.array([ [0, 1, 1, 1, 0, 1, 1, 1, 1, 0], [0, 0, 1, 1, 0, 1, 1, 1, 0, 0], [0, 0, 0, 1, 0, 1, 1, 0, 0, 0], @@ -305,7 +312,7 @@ def __iter__(self): with TestDir() as test_dir: self._test_save_and_load(TestExtractor(), CocoInstancesConverter(), test_dir, - importer_params={'merge_instance_polygons': True}, + importer_args={'merge_instance_polygons': True}, target_dataset=TargetExtractor()) def test_can_crop_covered_segments(self): @@ -315,56 +322,47 @@ def test_can_crop_covered_segments(self): class SrcTestExtractor(Extractor): def __iter__(self): - items = [ + return iter([ DatasetItem(id=1, image=np.zeros((5, 5, 3)), annotations=[ - MaskObject(np.array([ + Mask(np.array([ [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [1, 1, 0, 1, 1], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], - ), + ), label=2, id=1, z_order=0), - PolygonObject([1, 1, 4, 1, 4, 4, 1, 4], + Polygon([1, 1, 4, 1, 4, 4, 1, 4], label=1, id=2, z_order=1), ] ), - ] - return iter(items) + ]) def categories(self): return { AnnotationType.label: label_categories } class DstTestExtractor(Extractor): def __iter__(self): - items = [ + return iter([ DatasetItem(id=1, image=np.zeros((5, 5, 3)), annotations=[ - BboxObject(0, 0, 4, 4, - label=2, id=1, group=1, - attributes={ 'is_crowd': True }), - MaskObject(np.array([ + Mask(np.array([ [0, 0, 1, 1, 1], [0, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]], - ), + ), attributes={ 'is_crowd': True }, label=2, id=1, group=1), - BboxObject(1, 1, 3, 3, - label=1, id=2, group=2, - attributes={ 'is_crowd': False }), - PolygonObject([1, 1, 4, 1, 4, 4, 1, 4], + Polygon([1, 1, 4, 1, 4, 4, 1, 4], label=1, id=2, group=2, 
attributes={ 'is_crowd': False }), - # NOTE: Why it's 4 in COCOapi?.. ] ), - ] - return iter(items) + ]) def categories(self): return { AnnotationType.label: label_categories } @@ -384,9 +382,9 @@ def __iter__(self): return iter([ DatasetItem(id=1, image=np.zeros((6, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=3, id=4, group=4), - PolygonObject([5, 0, 9, 0, 5, 5], + Polygon([5, 0, 9, 0, 5, 5], label=3, id=4, group=4), ] ), @@ -400,9 +398,7 @@ def __iter__(self): return iter([ DatasetItem(id=1, image=np.zeros((6, 10, 3)), annotations=[ - BboxObject(0, 0, 9, 5, label=3, id=4, group=4, - attributes={ 'is_crowd': True }), - MaskObject(np.array([ + Mask(np.array([ [0, 1, 1, 1, 0, 1, 1, 1, 1, 0], [0, 0, 1, 1, 0, 1, 1, 1, 0, 0], [0, 0, 0, 1, 0, 1, 1, 0, 0, 0], @@ -411,7 +407,7 @@ def __iter__(self): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # only internal fragment (without the border), # but not everywhere... - ), + ), attributes={ 'is_crowd': True }, label=3, id=4, group=4), ] @@ -433,22 +429,20 @@ def test_can_convert_masks_to_polygons(self): class SrcTestExtractor(Extractor): def __iter__(self): - items = [ + return iter([ DatasetItem(id=1, image=np.zeros((5, 10, 3)), annotations=[ - MaskObject(np.array([ - [0, 1, 1, 1, 0, 1, 1, 1, 1, 0], - [0, 0, 1, 1, 0, 1, 1, 1, 0, 0], - [0, 0, 0, 1, 0, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - ), + Mask(np.array([ + [0, 1, 1, 1, 0, 1, 1, 1, 1, 0], + [0, 0, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), label=3, id=4, group=4), ] ), - ] - return iter(items) + ]) def categories(self): return { AnnotationType.label: label_categories } @@ -458,13 +452,11 @@ def __iter__(self): return iter([ DatasetItem(id=1, image=np.zeros((5, 10, 3)), annotations=[ - BboxObject(1, 0, 7, 3, label=3, id=4, group=4, - attributes={ 'is_crowd': False }), - PolygonObject( + Polygon( [3.0, 2.5, 1.0, 0.0, 3.5, 0.0, 3.0, 2.5], label=3, id=4, group=4, attributes={ 'is_crowd': False }), - PolygonObject( + Polygon( [5.0, 3.5, 4.5, 0.0, 8.0, 0.0, 5.0, 3.5], label=3, id=4, group=4, attributes={ 'is_crowd': False }), @@ -504,17 +496,17 @@ def __iter__(self): return iter([ DatasetItem(id=1, subset='train', annotations=[ - LabelObject(4, id=1, group=1), - LabelObject(9, id=2, group=2), + Label(4, id=1, group=1), + Label(9, id=2, group=2), ]), DatasetItem(id=2, subset='train', annotations=[ - LabelObject(4, id=4, group=4), + Label(4, id=4, group=4), ]), DatasetItem(id=3, subset='val', annotations=[ - LabelObject(2, id=1, group=1), + Label(2, id=1, group=1), ]), ]) @@ -547,25 +539,25 @@ def __iter__(self): DatasetItem(id=1, subset='train', image=np.zeros((5, 5, 3)), annotations=[ # Full instance annotations: polygon + keypoints - PointsObject([0, 0, 0, 2, 4, 1], [0, 1, 2], + Points([0, 0, 0, 2, 4, 1], [0, 1, 2], label=3, group=1, id=1), - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=3, group=1, id=1), # Full instance annotations: bbox + keypoints - PointsObject([1, 2, 3, 4, 2, 3], group=2, id=2), - BboxObject(1, 2, 2, 2, group=2, id=2), + Points([1, 2, 3, 4, 2, 3], group=2, id=2), + Bbox(1, 2, 2, 2, group=2, id=2), ]), DatasetItem(id=2, subset='train', annotations=[ # Solitary keypoints - PointsObject([1, 2, 0, 2, 4, 1], label=5, id=3), + Points([1, 2, 0, 2, 4, 1], label=5, id=3), ]), DatasetItem(id=3, subset='val', annotations=[ # Solitary keypoints with no label - PointsObject([0, 0, 1, 2, 3, 
4], [0, 1, 2], id=3), + Points([0, 0, 1, 2, 3, 4], [0, 1, 2], id=3), ]), ]) @@ -577,48 +569,36 @@ def __iter__(self): return iter([ DatasetItem(id=1, subset='train', image=np.zeros((5, 5, 3)), annotations=[ - PointsObject([0, 0, 0, 2, 4, 1], [0, 1, 2], - label=3, group=1, id=1, - attributes={'is_crowd': False}), - PolygonObject([0, 0, 4, 0, 4, 4], + Points([0, 0, 0, 2, 4, 1], [0, 1, 2], label=3, group=1, id=1, attributes={'is_crowd': False}), - BboxObject(0, 0, 4, 4, + Polygon([0, 0, 4, 0, 4, 4], label=3, group=1, id=1, attributes={'is_crowd': False}), - PointsObject([1, 2, 3, 4, 2, 3], - group=2, id=2, - attributes={'is_crowd': False}), - PolygonObject([1, 2, 3, 2, 3, 4, 1, 4], + Points([1, 2, 3, 4, 2, 3], group=2, id=2, attributes={'is_crowd': False}), - BboxObject(1, 2, 2, 2, + Polygon([1, 2, 3, 2, 3, 4, 1, 4], group=2, id=2, attributes={'is_crowd': False}), ]), DatasetItem(id=2, subset='train', annotations=[ - PointsObject([1, 2, 0, 2, 4, 1], + Points([1, 2, 0, 2, 4, 1], label=5, group=3, id=3, attributes={'is_crowd': False}), - PolygonObject([0, 1, 4, 1, 4, 2, 0, 2], - label=5, group=3, id=3, - attributes={'is_crowd': False}), - BboxObject(0, 1, 4, 1, + Polygon([0, 1, 4, 1, 4, 2, 0, 2], label=5, group=3, id=3, attributes={'is_crowd': False}), ]), DatasetItem(id=3, subset='val', annotations=[ - PointsObject([0, 0, 1, 2, 3, 4], [0, 1, 2], - group=3, id=3, - attributes={'is_crowd': False}), - PolygonObject([1, 2, 3, 2, 3, 4, 1, 4], + Points([0, 0, 1, 2, 3, 4], [0, 1, 2], group=3, id=3, attributes={'is_crowd': False}), - BboxObject(1, 2, 2, 2, + Polygon([1, 2, 3, 2, 3, 4, 1, 4], group=3, id=3, attributes={'is_crowd': False}), ]), @@ -634,11 +614,11 @@ class TestExtractor(Extractor): def __iter__(self): return iter([ DatasetItem(id=1, annotations=[ - LabelObject(2, id=1, group=1), + Label(2, id=1, group=1), ]), DatasetItem(id=2, annotations=[ - LabelObject(3, id=2, group=2), + Label(3, id=2, group=2), ]), ]) diff --git a/datumaro/tests/test_command_targets.py b/datumaro/tests/test_command_targets.py index d029b92397a8..e9f4167fc965 100644 --- a/datumaro/tests/test_command_targets.py +++ b/datumaro/tests/test_command_targets.py @@ -21,7 +21,7 @@ def test_image_false_when_no_file(self): def test_image_false_when_false(self): with TestDir() as test_dir: - path = osp.join(test_dir.path, 'test.jpg') + path = osp.join(test_dir, 'test.jpg') with open(path, 'w+') as f: f.write('qwerty123') @@ -33,7 +33,7 @@ def test_image_false_when_false(self): def test_image_true_when_true(self): with TestDir() as test_dir: - path = osp.join(test_dir.path, 'test.jpg') + path = osp.join(test_dir, 'test.jpg') image = np.random.random_sample([10, 10, 3]) cv2.imwrite(path, image) @@ -60,7 +60,7 @@ def test_project_false_when_no_name(self): def test_project_true_when_project_file(self): with TestDir() as test_dir: - path = osp.join(test_dir.path, 'test.jpg') + path = osp.join(test_dir, 'test.jpg') Project().save(path) target = ProjectTarget() @@ -91,9 +91,9 @@ def test_project_false_when_not_project_name(self): self.assertFalse(status) - def test_project_true_when_not_project_file(self): + def test_project_false_when_not_project_file(self): with TestDir() as test_dir: - path = osp.join(test_dir.path, 'test.jpg') + path = osp.join(test_dir, 'test.jpg') with open(path, 'w+') as f: f.write('wqererw') diff --git a/datumaro/tests/test_cvat_format.py b/datumaro/tests/test_cvat_format.py index 8a4c95ad4cbe..dc4a2dc43972 100644 --- a/datumaro/tests/test_cvat_format.py +++ b/datumaro/tests/test_cvat_format.py @@ -6,22 +6,21 
@@ from unittest import TestCase from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, PointsObject, PolygonObject, PolyLineObject, BboxObject, + AnnotationType, Points, Polygon, PolyLine, Bbox, LabelCategories, ) -from datumaro.components.importers.cvat import CvatImporter -from datumaro.components.converters.cvat import CvatConverter -from datumaro.components.project import Project -import datumaro.components.formats.cvat as Cvat +from datumaro.plugins.cvat_format.importer import CvatImporter +from datumaro.plugins.cvat_format.converter import CvatConverter +from datumaro.plugins.cvat_format.format import CvatPath from datumaro.util.image import save_image -from datumaro.util.test_utils import TestDir, item_to_str +from datumaro.util.test_utils import TestDir, compare_datasets class CvatExtractorTest(TestCase): @staticmethod def generate_dummy_cvat(path): - images_dir = osp.join(path, Cvat.CvatPath.IMAGES_DIR) - anno_dir = osp.join(path, Cvat.CvatPath.ANNOTATIONS_DIR) + images_dir = osp.join(path, CvatPath.IMAGES_DIR) + anno_dir = osp.join(path, CvatPath.ANNOTATIONS_DIR) os.makedirs(images_dir) os.makedirs(anno_dir) @@ -103,80 +102,55 @@ def test_can_load(self): class TestExtractor(Extractor): def __iter__(self): return iter([ - DatasetItem(id=1, subset='train', image=np.ones((8, 8, 3)), + DatasetItem(id=0, subset='train', image=np.ones((8, 8, 3)), annotations=[ - BboxObject(0, 2, 4, 2, label=0, + Bbox(0, 2, 4, 2, label=0, attributes={ 'occluded': True, 'z_order': 1, 'a1': True, 'a2': 'v3' }), - PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], + PolyLine([1, 2, 3, 4, 5, 6, 7, 8], attributes={'occluded': False, 'z_order': 0}), ]), - DatasetItem(id=2, subset='train', image=np.ones((10, 10, 3)), + DatasetItem(id=1, subset='train', image=np.ones((10, 10, 3)), annotations=[ - PolygonObject([1, 2, 3, 4, 6, 5], + Polygon([1, 2, 3, 4, 6, 5], attributes={'occluded': False, 'z_order': 1}), - PointsObject([1, 2, 3, 4, 5, 6], label=1, + Points([1, 2, 3, 4, 5, 6], label=1, attributes={'occluded': False, 'z_order': 2}), ]), ]) def categories(self): label_categories = LabelCategories() - for i in range(10): - label_categories.add('label_' + str(i)) + label_categories.add('label1', attributes={'a1', 'a2'}) + label_categories.add('label2') return { AnnotationType.label: label_categories, } with TestDir() as test_dir: - self.generate_dummy_cvat(test_dir.path) + self.generate_dummy_cvat(test_dir) source_dataset = TestExtractor() - parsed_dataset = CvatImporter()(test_dir.path).make_dataset() + parsed_dataset = CvatImporter()(test_dir).make_dataset() - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - self.assertEqual(len(source_dataset), len(parsed_dataset)) - for subset_name in source_dataset.subsets(): - source_subset = source_dataset.get_subset(subset_name) - parsed_subset = parsed_dataset.get_subset(subset_name) - for item_a, item_b in zip(source_subset, parsed_subset): - self.assertEqual(len(item_a.annotations), len(item_b.annotations)) - for ann_a, ann_b in zip(item_a.annotations, item_b.annotations): - self.assertEqual(ann_a, ann_b) + compare_datasets(self, source_dataset, parsed_dataset) class CvatConverterTest(TestCase): def _test_save_and_load(self, source_dataset, converter, test_dir, - importer_params=None, target_dataset=None): - converter(source_dataset, test_dir.path) - - if not importer_params: - importer_params = {} - project = Project.import_from(test_dir.path, 'cvat', **importer_params) - parsed_dataset = 
project.make_dataset() - - if target_dataset is not None: - source_dataset = target_dataset - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - - self.assertEqual(len(source_dataset), len(parsed_dataset)) - - for subset_name in source_dataset.subsets(): - source_subset = source_dataset.get_subset(subset_name) - parsed_subset = parsed_dataset.get_subset(subset_name) - self.assertEqual(len(source_subset), len(parsed_subset)) - for idx, (item_a, item_b) in enumerate( - zip(source_subset, parsed_subset)): - self.assertEqual(item_a, item_b, '%s:\n%s\nvs.\n%s\n' % \ - (idx, item_to_str(item_a), item_to_str(item_b))) + target_dataset=None, importer_args=None): + converter(source_dataset, test_dir) + + if importer_args is None: + importer_args = {} + parsed_dataset = CvatImporter()(test_dir, **importer_args).make_dataset() + + if target_dataset is None: + target_dataset = source_dataset + + compare_datasets(self, expected=target_dataset, actual=parsed_dataset) def test_can_save_and_load(self): label_categories = LabelCategories() @@ -190,32 +164,32 @@ def __iter__(self): return iter([ DatasetItem(id=0, subset='s1', image=np.zeros((5, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=1, group=4, attributes={ 'occluded': True }), - PolygonObject([5, 0, 9, 0, 5, 5], + Polygon([5, 0, 9, 0, 5, 5], label=2, group=4, attributes={ 'unknown': 'bar' }), - PointsObject([1, 1, 3, 2, 2, 3], + Points([1, 1, 3, 2, 2, 3], label=2, attributes={ 'a1': 'x', 'a2': 42 }), ] ), DatasetItem(id=1, subset='s1', annotations=[ - PolyLineObject([0, 0, 4, 0, 4, 4], + PolyLine([0, 0, 4, 0, 4, 4], label=3, id=4, group=4), - BboxObject(5, 0, 1, 9, + Bbox(5, 0, 1, 9, label=3, id=4, group=4), ] ), DatasetItem(id=2, subset='s2', image=np.ones((5, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=3, group=4, attributes={ 'z_order': 1, 'occluded': False }), - PolyLineObject([5, 0, 9, 0, 5, 5]), # will be skipped as no label + PolyLine([5, 0, 9, 0, 5, 5]), # will be skipped as no label ] ), ]) @@ -228,13 +202,13 @@ def __iter__(self): return iter([ DatasetItem(id=0, subset='s1', image=np.zeros((5, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=1, group=4, attributes={ 'z_order': 0, 'occluded': True }), - PolygonObject([5, 0, 9, 0, 5, 5], + Polygon([5, 0, 9, 0, 5, 5], label=2, group=4, attributes={ 'z_order': 0, 'occluded': False }), - PointsObject([1, 1, 3, 2, 2, 3], + Points([1, 1, 3, 2, 2, 3], label=2, attributes={ 'z_order': 0, 'occluded': False, 'a1': 'x', 'a2': 42 }), @@ -242,10 +216,10 @@ def __iter__(self): ), DatasetItem(id=1, subset='s1', annotations=[ - PolyLineObject([0, 0, 4, 0, 4, 4], + PolyLine([0, 0, 4, 0, 4, 4], label=3, group=4, attributes={ 'z_order': 0, 'occluded': False }), - BboxObject(5, 0, 1, 9, + Bbox(5, 0, 1, 9, label=3, group=4, attributes={ 'z_order': 0, 'occluded': False }), ] @@ -253,7 +227,7 @@ def __iter__(self): DatasetItem(id=2, subset='s2', image=np.ones((5, 10, 3)), annotations=[ - PolygonObject([0, 0, 4, 0, 4, 4], + Polygon([0, 0, 4, 0, 4, 4], label=3, group=4, attributes={ 'z_order': 1, 'occluded': False }), ] diff --git a/datumaro/tests/test_datumaro_format.py b/datumaro/tests/test_datumaro_format.py index 77b1b1c07e9b..4a168a6deaa9 100644 --- a/datumaro/tests/test_datumaro_format.py +++ b/datumaro/tests/test_datumaro_format.py @@ -4,11 +4,11 @@ from datumaro.components.project import Project from 
datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, LabelObject, MaskObject, PointsObject, PolygonObject, - PolyLineObject, BboxObject, CaptionObject, + AnnotationType, Label, Mask, Points, Polygon, + PolyLine, Bbox, Caption, LabelCategories, MaskCategories, PointsCategories ) -from datumaro.components.converters.datumaro import DatumaroConverter +from datumaro.plugins.datumaro_format.converter import DatumaroConverter from datumaro.util.test_utils import TestDir, item_to_str from datumaro.util.mask_tools import generate_colormap @@ -19,30 +19,30 @@ def __iter__(self): return iter([ DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), annotations=[ - CaptionObject('hello', id=1), - CaptionObject('world', id=2, group=5), - LabelObject(2, id=3, attributes={ + Caption('hello', id=1), + Caption('world', id=2, group=5), + Label(2, id=3, attributes={ 'x': 1, 'y': '2', }), - BboxObject(1, 2, 3, 4, label=4, id=4, attributes={ + Bbox(1, 2, 3, 4, label=4, id=4, attributes={ 'score': 1.0, }), - BboxObject(5, 6, 7, 8, id=5, group=5), - PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5), - MaskObject(label=3, id=5, image=np.ones((2, 3))), + Bbox(5, 6, 7, 8, id=5, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, id=5), + Mask(label=3, id=5, image=np.ones((2, 3))), ]), DatasetItem(id=21, subset='train', annotations=[ - CaptionObject('test'), - LabelObject(2), - BboxObject(1, 2, 3, 4, 5, id=42, group=42) + Caption('test'), + Label(2), + Bbox(1, 2, 3, 4, 5, id=42, group=42) ]), DatasetItem(id=2, subset='val', annotations=[ - PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11), - PolygonObject([1, 2, 3, 4, 5, 6, 7, 8], id=12), + PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11), + Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12), ]), DatasetItem(id=42, subset='test'), @@ -74,9 +74,9 @@ def test_can_save_and_load(self): source_dataset = self.TestExtractor() converter = DatumaroConverter(save_images=True) - converter(source_dataset, test_dir.path) + converter(source_dataset, test_dir) - project = Project.import_from(test_dir.path, 'datumaro') + project = Project.import_from(test_dir, 'datumaro') parsed_dataset = project.make_dataset() self.assertListEqual( diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 5f0655f1228f..9ad9c1de6fdf 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,6 +1,6 @@ from unittest import TestCase -from datumaro.components.extractor import DatasetItem, LabelObject, BboxObject +from datumaro.components.extractor import DatasetItem, Label, Bbox from datumaro.components.comparator import Comparator @@ -8,7 +8,7 @@ class DiffTest(TestCase): def test_no_bbox_diff_with_same_item(self): detections = 3 anns = [ - BboxObject(i * 10, 10, 10, 10, label=i, + Bbox(i * 10, 10, 10, 10, label=i, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) ] @@ -38,12 +38,12 @@ def test_can_find_bbox_with_wrong_label(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - BboxObject(i * 10, 10, 10, 10, label=i, + Bbox(i * 10, 10, 10, 10, label=i, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) ]) item2 = DatasetItem(id=2, annotations=[ - BboxObject(i * 10, 10, 10, 10, label=(i + 1) % class_count, + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) ]) @@ -72,12 +72,12 @@ def test_can_find_missing_boxes(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - 
BboxObject(i * 10, 10, 10, 10, label=i, + Bbox(i * 10, 10, 10, 10, label=i, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) if i % 2 == 0 ]) item2 = DatasetItem(id=2, annotations=[ - BboxObject(i * 10, 10, 10, 10, label=(i + 1) % class_count, + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) if i % 2 == 1 ]) @@ -102,7 +102,7 @@ def test_can_find_missing_boxes(self): def test_no_label_diff_with_same_item(self): detections = 3 anns = [ - LabelObject(i, attributes={'score': (1.0 + i) / detections}) \ + Label(i, attributes={'score': (1.0 + i) / detections}) \ for i in range(detections) ] item = DatasetItem(id=1, annotations=anns) @@ -121,14 +121,14 @@ def test_no_label_diff_with_same_item(self): def test_can_find_wrong_label(self): item1 = DatasetItem(id=1, annotations=[ - LabelObject(0), - LabelObject(1), - LabelObject(2), + Label(0), + Label(1), + Label(2), ]) item2 = DatasetItem(id=2, annotations=[ - LabelObject(2), - LabelObject(3), - LabelObject(4), + Label(2), + Label(3), + Label(4), ]) conf_thresh = 0.5 diff --git a/datumaro/tests/test_image.py b/datumaro/tests/test_image.py index 1e1ed5c75c0d..9614fcc90d7b 100644 --- a/datumaro/tests/test_image.py +++ b/datumaro/tests/test_image.py @@ -23,7 +23,7 @@ def test_save_and_load_backends(self): src_image = np.random.randint(0, 255 + 1, (2, 4)) else: src_image = np.random.randint(0, 255 + 1, (2, 4, c)) - path = osp.join(test_dir.path, 'img.png') # lossless + path = osp.join(test_dir, 'img.png') # lossless image_module._IMAGE_BACKEND = save_backend image_module.save_image(path, src_image) diff --git a/datumaro/tests/test_image_dir_format.py b/datumaro/tests/test_image_dir_format.py index 27568d55822f..6b382c212cb1 100644 --- a/datumaro/tests/test_image_dir_format.py +++ b/datumaro/tests/test_image_dir_format.py @@ -5,7 +5,7 @@ from datumaro.components.project import Project from datumaro.components.extractor import Extractor, DatasetItem -from datumaro.util.test_utils import TestDir +from datumaro.util.test_utils import TestDir, compare_datasets from datumaro.util.image import save_image @@ -22,27 +22,9 @@ def test_can_load(self): source_dataset = self.TestExtractor() for item in source_dataset: - save_image(osp.join(test_dir.path, '%s.jpg' % item.id), - item.image) + save_image(osp.join(test_dir, '%s.jpg' % item.id), item.image) - project = Project.import_from(test_dir.path, 'image_dir') + project = Project.import_from(test_dir, 'image_dir') parsed_dataset = project.make_dataset() - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - - self.assertEqual(len(source_dataset), len(parsed_dataset)) - - for subset_name in source_dataset.subsets(): - source_subset = source_dataset.get_subset(subset_name) - parsed_subset = parsed_dataset.get_subset(subset_name) - self.assertEqual(len(source_subset), len(parsed_subset)) - for idx, (item_a, item_b) in enumerate( - zip(source_subset, parsed_subset)): - self.assertEqual(item_a, item_b, str(idx)) - - self.assertEqual( - source_dataset.categories(), - parsed_dataset.categories()) \ No newline at end of file + compare_datasets(self, source_dataset, parsed_dataset) diff --git a/datumaro/tests/test_images.py b/datumaro/tests/test_images.py index 8c05d61404e8..e3f12a3df89c 100644 --- a/datumaro/tests/test_images.py +++ b/datumaro/tests/test_images.py @@ -15,7 +15,7 @@ def test_cache_works(self): image = np.ones((100, 100, 3), dtype=np.uint8) image = 
Image.fromarray(image).convert('RGB') - image_path = osp.join(test_dir.path, 'image.jpg') + image_path = osp.join(test_dir, 'image.jpg') image.save(image_path) caching_loader = lazy_image(image_path, cache=None) diff --git a/datumaro/tests/test_masks.py b/datumaro/tests/test_masks.py index d019f254ff41..1619f1db1adf 100644 --- a/datumaro/tests/test_masks.py +++ b/datumaro/tests/test_masks.py @@ -120,5 +120,19 @@ def test_can_remap_mask(self): actual = mask_tools.remap_mask(src, remap_fn) + self.assertTrue(np.array_equal(expected, actual), + '%s\nvs.\n%s' % (expected, actual)) + + def test_can_merge_masks(self): + masks = [ + np.array([0, 2, 4, 0, 0, 1]), + np.array([0, 1, 1, 0, 2, 0]), + np.array([0, 0, 2, 3, 0, 0]), + ] + expected = \ + np.array([0, 1, 2, 3, 2, 1]) + + actual = mask_tools.merge_masks(masks) + self.assertTrue(np.array_equal(expected, actual), '%s\nvs.\n%s' % (expected, actual)) \ No newline at end of file diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 93a2aad484fa..8c4106687d72 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -9,8 +9,7 @@ from datumaro.components.launcher import Launcher, InferenceWrapper from datumaro.components.converter import Converter from datumaro.components.extractor import (Extractor, DatasetItem, - LabelObject, MaskObject, PointsObject, PolygonObject, - PolyLineObject, BboxObject, CaptionObject, + Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, ) from datumaro.components.config import Config, DefaultConfig, SchemaBuilder from datumaro.components.dataset_filter import \ @@ -26,7 +25,7 @@ def test_project_generate(self): }) with TestDir() as test_dir: - project_path = test_dir.path + project_path = test_dir Project.generate(project_path, src_config) self.assertTrue(osp.isdir(project_path)) @@ -80,9 +79,9 @@ def test_added_source_can_be_dumped(self): project.add_source(source_name, origin) with TestDir() as test_dir: - project.save(test_dir.path) + project.save(test_dir) - loaded = Project.load(test_dir.path) + loaded = Project.load(test_dir) loaded = loaded.get_source(source_name) self.assertEqual(origin, loaded) @@ -114,19 +113,19 @@ def test_can_dump_added_model(self): project.add_model(model_name, saved) with TestDir() as test_dir: - project.save(test_dir.path) + project.save(test_dir) - loaded = Project.load(test_dir.path) + loaded = Project.load(test_dir) loaded = loaded.get_model(model_name) self.assertEqual(saved, loaded) def test_can_have_project_source(self): with TestDir() as test_dir: - Project.generate(test_dir.path) + Project.generate(test_dir) project2 = Project() project2.add_source('project1', { - 'url': test_dir.path, + 'url': test_dir, }) dataset = project2.make_dataset() @@ -141,7 +140,7 @@ def __iter__(self): class TestLauncher(Launcher): def launch(self, inputs): for i, inp in enumerate(inputs): - yield [ LabelObject(attributes={'idx': i, 'data': inp}) ] + yield [ Label(attributes={'idx': i, 'data': inp}) ] model_name = 'model' launcher_name = 'custom_launcher' @@ -167,12 +166,12 @@ class TestExtractorSrc(Extractor): def __iter__(self): for i in range(2): yield DatasetItem(id=i, subset='train', image=i, - annotations=[ LabelObject(i) ]) + annotations=[ Label(i) ]) class TestLauncher(Launcher): def launch(self, inputs): for inp in inputs: - yield [ LabelObject(inp) ] + yield [ Label(inp) ] class TestConverter(Converter): def __call__(self, extractor, save_dir): @@ -194,7 +193,7 @@ def __iter__(self): label = int(f.readline().strip()) assert subset == 
'train' yield DatasetItem(id=index, subset=subset, - annotations=[ LabelObject(label) ]) + annotations=[ Label(label) ]) model_name = 'model' launcher_name = 'custom_launcher' @@ -208,10 +207,10 @@ def __iter__(self): project.add_source('source', { 'format': extractor_name }) with TestDir() as test_dir: - project.make_dataset().apply_model(model_name=model_name, - save_dir=test_dir.path) + project.make_dataset().apply_model(model=model_name, + save_dir=test_dir) - result = Project.load(test_dir.path) + result = Project.load(test_dir) result.env.extractors.register(extractor_name, TestExtractorDst) it = iter(result.make_dataset()) item1 = next(it) @@ -266,9 +265,9 @@ def test_can_save_and_load_own_dataset(self): src_dataset = src_project.make_dataset() item = DatasetItem(id=1) src_dataset.put(item) - src_dataset.save(test_dir.path) + src_dataset.save(test_dir) - loaded_project = Project.load(test_dir.path) + loaded_project = Project.load(test_dir) loaded_dataset = loaded_project.make_dataset() self.assertEqual(list(src_dataset), list(loaded_dataset)) @@ -285,12 +284,12 @@ def test_project_own_dataset_can_be_modified(self): def test_project_compound_child_can_be_modified_recursively(self): with TestDir() as test_dir: child1 = Project({ - 'project_dir': osp.join(test_dir.path, 'child1'), + 'project_dir': osp.join(test_dir, 'child1'), }) child1.save() child2 = Project({ - 'project_dir': osp.join(test_dir.path, 'child2'), + 'project_dir': osp.join(test_dir, 'child2'), }) child2.save() @@ -316,15 +315,15 @@ def test_project_can_merge_item_annotations(self): class TestExtractor1(Extractor): def __iter__(self): yield DatasetItem(id=1, subset='train', annotations=[ - LabelObject(2, id=3), - LabelObject(3, attributes={ 'x': 1 }), + Label(2, id=3), + Label(3, attributes={ 'x': 1 }), ]) class TestExtractor2(Extractor): def __iter__(self): yield DatasetItem(id=1, subset='train', annotations=[ - LabelObject(3, attributes={ 'x': 1 }), - LabelObject(4, id=4), + Label(3, attributes={ 'x': 1 }), + Label(4, id=4), ]) project = Project() @@ -346,17 +345,17 @@ def test_item_representations(): item = DatasetItem(id=1, subset='subset', path=['a', 'b'], image=np.ones((5, 4, 3)), annotations=[ - LabelObject(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2), - CaptionObject('hello', id=1), - CaptionObject('world', group=5), - LabelObject(2, id=3, attributes={ 'x': 1, 'y': '2' }), - BboxObject(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }), - BboxObject(5, 6, 7, 8, id=5, group=5), - PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5), - MaskObject(id=5, image=np.ones((3, 2))), - MaskObject(label=3, id=5, image=np.ones((2, 3))), - PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11), - PolygonObject([1, 2, 3, 4, 5, 6, 7, 8]), + Label(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2), + Caption('hello', id=1), + Caption('world', group=5), + Label(2, id=3, attributes={ 'x': 1, 'y': '2' }), + Bbox(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }), + Bbox(5, 6, 7, 8, id=5, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, id=5), + Mask(id=5, image=np.ones((3, 2))), + Mask(label=3, id=5, image=np.ones((2, 3))), + PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11), + Polygon([1, 2, 3, 4, 5, 6, 7, 8]), ] ) @@ -381,12 +380,12 @@ def __iter__(self): return iter([ DatasetItem(id=0), DatasetItem(id=1, annotations=[ - LabelObject(0), - LabelObject(1), + Label(0), + Label(1), ]), DatasetItem(id=2, annotations=[ - LabelObject(0), - LabelObject(2), + Label(0), + Label(2), ]), ]) @@ -395,10 +394,10 @@ def __iter__(self): return iter([ 
DatasetItem(id=0), DatasetItem(id=1, annotations=[ - LabelObject(0), + Label(0), ]), DatasetItem(id=2, annotations=[ - LabelObject(0), + Label(0), ]), ]) @@ -415,12 +414,12 @@ def __iter__(self): return iter([ DatasetItem(id=0), DatasetItem(id=1, annotations=[ - LabelObject(0), - LabelObject(1), + Label(0), + Label(1), ]), DatasetItem(id=2, annotations=[ - LabelObject(0), - LabelObject(2), + Label(0), + Label(2), ]), ]) @@ -428,7 +427,7 @@ class DstTestExtractor(Extractor): def __iter__(self): return iter([ DatasetItem(id=2, annotations=[ - LabelObject(2), + Label(2), ]), ]) diff --git a/datumaro/tests/test_tfrecord_format.py b/datumaro/tests/test_tfrecord_format.py index 2ea3b223647c..664359d0bd01 100644 --- a/datumaro/tests/test_tfrecord_format.py +++ b/datumaro/tests/test_tfrecord_format.py @@ -2,43 +2,29 @@ from unittest import TestCase -from datumaro.components.project import Project from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, BboxObject, LabelCategories + AnnotationType, Bbox, LabelCategories ) -from datumaro.components.extractors.tfrecord import DetectionApiExtractor -from datumaro.components.converters.tfrecord import DetectionApiConverter -from datumaro.util import find -from datumaro.util.test_utils import TestDir +from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter +from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor +from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter +from datumaro.util.test_utils import TestDir, compare_datasets class TfrecordConverterTest(TestCase): - def _test_can_save_and_load(self, source_dataset, converter, test_dir, - importer_params=None): - converter(source_dataset, test_dir.path) - - if not importer_params: - importer_params = {} - project = Project.import_from(test_dir.path, 'tf_detection_api', - **importer_params) - parsed_dataset = project.make_dataset() - - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - - self.assertEqual(len(source_dataset), len(parsed_dataset)) - - for item_a in source_dataset: - item_b = find(parsed_dataset, lambda x: x.id == item_a.id) - self.assertFalse(item_b is None) - self.assertEqual(len(item_a.annotations), len(item_b.annotations)) - for ann_a in item_a.annotations: - ann_b = find(item_b.annotations, lambda x: \ - x.id == ann_a.id and \ - x.type == ann_a.type and x.group == ann_a.group) - self.assertEqual(ann_a, ann_b, 'id: ' + str(ann_a.id)) + def _test_save_and_load(self, source_dataset, converter, test_dir, + target_dataset=None, importer_args=None): + converter(source_dataset, test_dir) + + if importer_args is None: + importer_args = {} + parsed_dataset = TfDetectionApiImporter()(test_dir, **importer_args) \ + .make_dataset() + + if target_dataset is None: + target_dataset = source_dataset + + compare_datasets(self, expected=target_dataset, actual=parsed_dataset) def test_can_save_bboxes(self): class TestExtractor(Extractor): @@ -47,16 +33,16 @@ def __iter__(self): DatasetItem(id=1, subset='train', image=np.ones((16, 16, 3)), annotations=[ - BboxObject(0, 4, 4, 8, label=2, id=0), - BboxObject(0, 4, 4, 4, label=3, id=1), - BboxObject(2, 4, 4, 4, id=2), + Bbox(0, 4, 4, 8, label=2, id=0), + Bbox(0, 4, 4, 4, label=3, id=1), + Bbox(2, 4, 4, 4, id=2), ] ), DatasetItem(id=2, subset='val', image=np.ones((8, 8, 3)), annotations=[ - BboxObject(1, 2, 4, 2, label=3, id=0), + Bbox(1, 2, 4, 2, label=3, id=0), ] ), @@ -74,8 +60,8 @@ def 
categories(self): } with TestDir() as test_dir: - self._test_can_save_and_load( - TestExtractor(), DetectionApiConverter(save_images=True), + self._test_save_and_load( + TestExtractor(), TfDetectionApiConverter(save_images=True), test_dir) def test_can_save_dataset_with_no_subsets(self): @@ -85,15 +71,15 @@ def __iter__(self): DatasetItem(id=1, image=np.ones((16, 16, 3)), annotations=[ - BboxObject(2, 1, 4, 4, label=2, id=0), - BboxObject(4, 2, 8, 4, label=3, id=1), + Bbox(2, 1, 4, 4, label=2, id=0), + Bbox(4, 2, 8, 4, label=3, id=1), ] ), DatasetItem(id=2, image=np.ones((8, 8, 3)) * 2, annotations=[ - BboxObject(4, 4, 4, 4, label=3, id=0), + Bbox(4, 4, 4, 4, label=3, id=0), ] ), @@ -111,8 +97,8 @@ def categories(self): } with TestDir() as test_dir: - self._test_can_save_and_load( - TestExtractor(), DetectionApiConverter(save_images=True), + self._test_save_and_load( + TestExtractor(), TfDetectionApiConverter(save_images=True), test_dir) def test_labelmap_parsing(self): @@ -137,6 +123,6 @@ def test_labelmap_parsing(self): 'qw3': 6, 'qw4': 7, } - parsed = DetectionApiExtractor._parse_labelmap(text) + parsed = TfDetectionApiExtractor._parse_labelmap(text) self.assertEqual(expected, parsed) diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py new file mode 100644 index 000000000000..e5f0600af3d4 --- /dev/null +++ b/datumaro/tests/test_transforms.py @@ -0,0 +1,188 @@ +import numpy as np + +from unittest import TestCase + +from datumaro.components.extractor import (Extractor, DatasetItem, + Mask, Polygon +) +from datumaro.util.test_utils import compare_datasets +import datumaro.plugins.transforms as transforms + + +class TransformsTest(TestCase): + def test_reindex(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=10), + DatasetItem(id=10, subset='train'), + DatasetItem(id='a', subset='val'), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=5), + DatasetItem(id=6, subset='train'), + DatasetItem(id=7, subset='val'), + ]) + + actual = transforms.Reindex(SrcExtractor(), start=5) + compare_datasets(self, DstExtractor(), actual) + + def test_mask_to_polygons(self): + class SrcExtractor(Extractor): + def __iter__(self): + items = [ + DatasetItem(id=1, image=np.zeros((5, 10, 3)), + annotations=[ + Mask(np.array([ + [0, 1, 1, 1, 0, 1, 1, 1, 1, 0], + [0, 0, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), + ), + ] + ), + ] + return iter(items) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 10, 3)), + annotations=[ + Polygon([3.0, 2.5, 1.0, 0.0, 3.5, 0.0, 3.0, 2.5]), + Polygon([5.0, 3.5, 4.5, 0.0, 8.0, 0.0, 5.0, 3.5]), + ] + ), + ]) + + actual = transforms.MasksToPolygons(SrcExtractor()) + compare_datasets(self, DstExtractor(), actual) + + def test_polygons_to_masks(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 10, 3)), + annotations=[ + Polygon([0, 0, 4, 0, 4, 4]), + Polygon([5, 0, 9, 0, 5, 5]), + ] + ), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 10, 3)), + annotations=[ + Mask(np.array([ + [0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), + ), + Mask(np.array([ + [0, 1, 1, 
1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), + ), + ] + ), + ]) + + actual = transforms.PolygonsToMasks(SrcExtractor()) + compare_datasets(self, DstExtractor(), actual) + + def test_crop_covered_segments(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + # The mask is partially covered by the polygon + Mask(np.array([ + [0, 0, 1, 1, 1], + [0, 0, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0]], + ), + z_order=0), + Polygon([1, 1, 4, 1, 4, 4, 1, 4], + z_order=1), + ] + ), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + Mask(np.array([ + [0, 0, 1, 1, 1], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [1, 1, 1, 0, 0]], + ), + z_order=0), + Polygon([1, 1, 4, 1, 4, 4, 1, 4], + z_order=1), + ] + ), + ]) + + actual = transforms.CropCoveredSegments(SrcExtractor()) + compare_datasets(self, DstExtractor(), actual) + + def test_merge_instance_segments(self): + class SrcExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + Mask(np.array([ + [0, 0, 1, 1, 1], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [1, 1, 1, 0, 0]], + ), + z_order=0), + Polygon([1, 1, 4, 1, 4, 4, 1, 4], + z_order=1), + ] + ), + ]) + + class DstExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=1, image=np.zeros((5, 5, 3)), + annotations=[ + Mask(np.array([ + [0, 0, 1, 1, 1], + [0, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 0], + [1, 1, 1, 0, 0]], + ), + z_order=0), + ] + ), + ]) + + actual = transforms.MergeInstanceSegments(SrcExtractor(), + include_polygons=True) + compare_datasets(self, DstExtractor(), actual) \ No newline at end of file diff --git a/datumaro/tests/test_voc_format.py b/datumaro/tests/test_voc_format.py index de58ce40ce4f..d64e467b8a0b 100644 --- a/datumaro/tests/test_voc_format.py +++ b/datumaro/tests/test_voc_format.py @@ -8,17 +8,17 @@ from unittest import TestCase from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, BboxObject, LabelCategories, + AnnotationType, Label, Bbox, Mask, LabelCategories, ) -import datumaro.components.formats.voc as VOC -from datumaro.components.extractors.voc import ( +import datumaro.plugins.voc_format.format as VOC +from datumaro.plugins.voc_format.extractor import ( VocClassificationExtractor, VocDetectionExtractor, VocSegmentationExtractor, VocLayoutExtractor, VocActionExtractor, ) -from datumaro.components.converters.voc import ( +from datumaro.plugins.voc_format.converter import ( VocConverter, VocClassificationConverter, VocDetectionConverter, @@ -26,10 +26,9 @@ VocActionConverter, VocSegmentationConverter, ) -from datumaro.components.importers.voc import VocImporter +from datumaro.plugins.voc_format.importer import VocImporter from datumaro.components.project import Project -from datumaro.util import find -from datumaro.util.test_utils import TestDir +from datumaro.util.test_utils import TestDir, compare_datasets class VocTest(TestCase): @@ -121,7 +120,7 @@ def generate_dummy_voc(path): ET.SubElement(root_elem, 'segmented').text = '1' obj1_elem = ET.SubElement(root_elem, 'object') - ET.SubElement(obj1_elem, 'name').text = VOC.VocLabel(1).name + ET.SubElement(obj1_elem, 'name').text = 'cat' 
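+ # class names are written as plain strings here; extractors resolve them + # through the label categories, e.g. + # label_id, _ = categories[AnnotationType.label].find('cat')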
ET.SubElement(obj1_elem, 'pose').text = VOC.VocPose(1).name ET.SubElement(obj1_elem, 'truncated').text = '1' ET.SubElement(obj1_elem, 'difficult').text = '0' @@ -132,7 +131,7 @@ def generate_dummy_voc(path): ET.SubElement(obj1bb_elem, 'ymax').text = '4' obj2_elem = ET.SubElement(root_elem, 'object') - ET.SubElement(obj2_elem, 'name').text = VOC.VocLabel.person.name + ET.SubElement(obj2_elem, 'name').text = 'person' obj2bb_elem = ET.SubElement(obj2_elem, 'bndbox') ET.SubElement(obj2bb_elem, 'xmin').text = '4' ET.SubElement(obj2bb_elem, 'ymin').text = '5' @@ -157,9 +156,10 @@ def generate_dummy_voc(path): subset = subsets[subset_name] for item in subset: cv2.imwrite(osp.join(segm_dir, item + '.png'), - np.ones([10, 20, 3]) * VOC.VocColormap[2]) + np.tile(VOC.VocColormap[2][::-1], (5, 10, 1)) + ) cv2.imwrite(osp.join(inst_dir, item + '.png'), - np.ones([10, 20, 3]) * VOC.VocColormap[2]) + np.tile(1, (5, 10, 1))) # Test images subset_name = 'test' @@ -170,338 +170,440 @@ def generate_dummy_voc(path): return subsets -class VocExtractorTest(TestCase): - def test_can_load_voc_cls(self): - with TestDir() as test_dir: - generated_subsets = generate_dummy_voc(test_dir.path) - - extractor = VocClassificationExtractor(test_dir.path) +class TestExtractorBase(Extractor): + _categories = VOC.make_voc_categories() - self.assertEqual(len(generated_subsets), len(extractor.subsets())) + def _label(self, voc_label): + return self.categories()[AnnotationType.label].find(voc_label)[0] - subset_name = 'train' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) + def categories(self): + return self._categories - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - count = 0 - for label in VOC.VocLabel: - if label.value % 2 == 1: - count += 1 - ann = find(item.annotations, - lambda x: x.type == AnnotationType.label and \ - get_label(extractor, x.label) == label.name) - self.assertFalse(ann is None) - self.assertEqual(count, len(item.annotations)) - - subset_name = 'test' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) +class VocExtractorTest(TestCase): + def test_can_load_voc_cls(self): + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id='2007_000001', subset='train', + annotations=[ + Label(self._label(l.name)) + for l in VOC.VocLabel if l.value % 2 == 1 + ] + ), + + DatasetItem(id='2007_000002', subset='test') + ]) - self.assertEqual(0, len(item.annotations)) + with TestDir() as test_dir: + generate_dummy_voc(test_dir) + parsed_dataset = VocClassificationExtractor(test_dir) + compare_datasets(self, DstExtractor(), parsed_dataset) def test_can_load_voc_det(self): - with TestDir() as test_dir: - generated_subsets = generate_dummy_voc(test_dir.path) - - extractor = VocDetectionExtractor(test_dir.path) - - self.assertEqual(len(generated_subsets), len(extractor.subsets())) - - subset_name = 'train' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - 
self.assertFalse(item is None) - - obj1 = find(item.annotations, - lambda x: x.type == AnnotationType.bbox and \ - get_label(extractor, x.label) == VOC.VocLabel(1).name) - self.assertFalse(obj1 is None) - self.assertListEqual([1, 2, 2, 2], obj1.get_bbox()) - self.assertDictEqual( - { - 'pose': VOC.VocPose(1).name, - 'truncated': True, - 'occluded': False, - 'difficult': False, - }, - obj1.attributes) - - obj2 = find(item.annotations, - lambda x: x.type == AnnotationType.bbox and \ - get_label(extractor, x.label) == VOC.VocLabel.person.name) - self.assertFalse(obj2 is None) - self.assertListEqual([4, 5, 2, 2], obj2.get_bbox()) - - self.assertEqual(2, len(item.annotations)) - - subset_name = 'test' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - self.assertEqual(0, len(item.annotations)) + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id='2007_000001', subset='train', + annotations=[ + Bbox(1, 2, 2, 2, label=self._label('cat'), + attributes={ + 'pose': VOC.VocPose(1).name, + 'truncated': True, + 'difficult': False, + 'occluded': False, + }, + id=1, + ), + Bbox(4, 5, 2, 2, label=self._label('person'), + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + **{ + a.name: a.value % 2 == 1 + for a in VOC.VocAction + } + }, + id=2, group=2, + # TODO: Actions and group should be excluded + # as soon as correct merge is implemented + ), + ] + ), + + DatasetItem(id='2007_000002', subset='test') + ]) - def test_can_load_voc_segm(self): with TestDir() as test_dir: - generated_subsets = generate_dummy_voc(test_dir.path) - - extractor = VocSegmentationExtractor(test_dir.path) - - self.assertEqual(len(generated_subsets), len(extractor.subsets())) - - subset_name = 'train' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - cls_mask = find(item.annotations, - lambda x: x.type == AnnotationType.mask and \ - x.attributes.get('class') == True) - self.assertFalse(cls_mask is None) - self.assertFalse(cls_mask.image is None) + generate_dummy_voc(test_dir) + parsed_dataset = VocDetectionExtractor(test_dir) + compare_datasets(self, DstExtractor(), parsed_dataset) - inst_mask = find(item.annotations, - lambda x: x.type == AnnotationType.mask and \ - x.attributes.get('instances') == True) - self.assertFalse(inst_mask is None) - self.assertFalse(inst_mask.image is None) - - self.assertEqual(2, len(item.annotations)) - - subset_name = 'test' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - self.assertEqual(0, len(item.annotations)) + def test_can_load_voc_segm(self): + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id='2007_000001', subset='train', + annotations=[ + Mask(image=np.ones([5, 10]), + label=self._label(VOC.VocLabel(2).name), + group=1, + ), + ] + ), + + DatasetItem(id='2007_000002', 
subset='test') + ]) - def test_can_load_voc_layout(self): with TestDir() as test_dir: - generated_subsets = generate_dummy_voc(test_dir.path) - - extractor = VocLayoutExtractor(test_dir.path) - - self.assertEqual(len(generated_subsets), len(extractor.subsets())) - - subset_name = 'train' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - obj2 = find(item.annotations, - lambda x: x.type == AnnotationType.bbox and \ - get_label(extractor, x.label) == VOC.VocLabel.person.name) - self.assertFalse(obj2 is None) - self.assertListEqual([4, 5, 2, 2], obj2.get_bbox()) - - obj2head = find(item.annotations, - lambda x: x.type == AnnotationType.bbox and \ - get_label(extractor, x.label) == VOC.VocBodyPart(1).name) - self.assertTrue(obj2.id == obj2head.group) - self.assertListEqual([5.5, 6, 2, 2], obj2head.get_bbox()) + generate_dummy_voc(test_dir) + parsed_dataset = VocSegmentationExtractor(test_dir) + compare_datasets(self, DstExtractor(), parsed_dataset) - self.assertEqual(2, len(item.annotations)) - - subset_name = 'test' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) - - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) - - self.assertEqual(0, len(item.annotations)) + def test_can_load_voc_layout(self): + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id='2007_000001', subset='train', + annotations=[ + Bbox(4, 5, 2, 2, label=self._label('person'), + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + **{ + a.name: a.value % 2 == 1 + for a in VOC.VocAction + } + }, + id=2, group=2, + # TODO: Actions should be excluded + # as soon as correct merge is implemented + ), + Bbox(5.5, 6, 2, 2, label=self._label( + VOC.VocBodyPart(1).name), + group=2 + ) + ] + ), + + DatasetItem(id='2007_000002', subset='test') + ]) - def test_can_load_voc_action(self): with TestDir() as test_dir: - generated_subsets = generate_dummy_voc(test_dir.path) - - extractor = VocActionExtractor(test_dir.path) - - self.assertEqual(len(generated_subsets), len(extractor.subsets())) - - subset_name = 'train' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) + generate_dummy_voc(test_dir) + parsed_dataset = VocLayoutExtractor(test_dir) + compare_datasets(self, DstExtractor(), parsed_dataset) - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) + def test_can_load_voc_action(self): + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id='2007_000001', subset='train', + annotations=[ + Bbox(4, 5, 2, 2, label=self._label('person'), + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + **{ + a.name: a.value % 2 == 1 + for a in VOC.VocAction + } + # TODO: group should be excluded + # as soon as correct merge is implemented + }, + id=2, group=2, + ), + ] + ), + + DatasetItem(id='2007_000002', subset='test') + ]) - obj2 = find(item.annotations, - lambda x: x.type == AnnotationType.bbox and \ - get_label(extractor, 
x.label) == VOC.VocLabel.person.name) - self.assertFalse(obj2 is None) - self.assertListEqual([4, 5, 2, 2], obj2.get_bbox()) + with TestDir() as test_dir: + generate_dummy_voc(test_dir) + parsed_dataset = VocActionExtractor(test_dir) + compare_datasets(self, DstExtractor(), parsed_dataset) - for action in VOC.VocAction: - attr = obj2.attributes[action.name] - self.assertEqual(attr, action.value % 2) +class VocConverterTest(TestCase): + def _test_save_and_load(self, source_dataset, converter, test_dir, + target_dataset=None, importer_args=None): + converter(source_dataset, test_dir) - subset_name = 'test' - generated_subset = generated_subsets[subset_name] - for id_ in generated_subset: - parsed_subset = extractor.get_subset(subset_name) - self.assertEqual(len(generated_subset), len(parsed_subset)) + if importer_args is None: + importer_args = {} + parsed_dataset = VocImporter()(test_dir, **importer_args).make_dataset() - item = find(parsed_subset, lambda x: x.id == id_) - self.assertFalse(item is None) + if target_dataset is None: + target_dataset = source_dataset - self.assertEqual(0, len(item.annotations)) + compare_datasets(self, expected=target_dataset, actual=parsed_dataset) -class VocConverterTest(TestCase): - def _test_can_save_voc(self, src_extractor, converter, test_dir, - target_extractor=None): - converter(src_extractor, test_dir) + def test_can_save_voc_cls(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=0, subset='a', annotations=[ + Label(1), + Label(2), + Label(3), + ]), - result_extractor = VocImporter()(test_dir).make_dataset() - if target_extractor is None: - target_extractor = src_extractor + DatasetItem(id=1, subset='b', annotations=[ + Label(4), + ]), + ]) - if AnnotationType.label in target_extractor.categories(): - self.assertEqual( - target_extractor.categories()[AnnotationType.label].items, - result_extractor.categories()[AnnotationType.label].items) - if AnnotationType.mask in target_extractor.categories(): - self.assertEqual( - target_extractor.categories()[AnnotationType.mask].colormap, - result_extractor.categories()[AnnotationType.mask].colormap) + with TestDir() as test_dir: + self._test_save_and_load(TestExtractor(), + VocClassificationConverter(label_map='voc'), test_dir) - self.assertEqual(len(target_extractor), len(result_extractor)) - for item_a, item_b in zip(target_extractor, result_extractor): - self.assertEqual(item_a.id, item_b.id) - self.assertEqual(len(item_a.annotations), len(item_b.annotations)) - for ann_a, ann_b in zip(item_a.annotations, item_b.annotations): - self.assertEqual(ann_a.type, ann_b.type) + def test_can_save_voc_det(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Bbox(2, 3, 4, 5, label=2, + attributes={ 'occluded': True } + ), + Bbox(2, 3, 4, 5, label=3, + attributes={ 'truncated': True }, + ), + ]), - def _test_can_save_voc_dummy(self, extractor_type, converter, test_dir): - dummy_dir = osp.join(test_dir, 'dummy') - generate_dummy_voc(dummy_dir) - gen_extractor = extractor_type(dummy_dir) + DatasetItem(id=2, subset='b', annotations=[ + Bbox(5, 4, 6, 5, label=3, + attributes={ 'difficult': True }, + ), + ]), + ]) - self._test_can_save_voc(gen_extractor, converter, - osp.join(test_dir, 'converted')) + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Bbox(2, 3, 4, 5, label=2, id=1, + attributes={ + 'truncated': False, 
+ 'difficult': False, + 'occluded': True, + } + ), + Bbox(2, 3, 4, 5, label=3, id=2, + attributes={ + 'truncated': True, + 'difficult': False, + 'occluded': False, + }, + ), + ]), - def test_can_save_voc_cls(self): - with TestDir() as test_dir: - self._test_can_save_voc_dummy( - VocClassificationExtractor, VocClassificationConverter(label_map='voc'), - test_dir.path) + DatasetItem(id=2, subset='b', annotations=[ + Bbox(5, 4, 6, 5, label=3, id=1, + attributes={ + 'truncated': False, + 'difficult': True, + 'occluded': False, + }, + ), + ]), + ]) - def test_can_save_voc_det(self): with TestDir() as test_dir: - self._test_can_save_voc_dummy( - VocDetectionExtractor, VocDetectionConverter(label_map='voc'), - test_dir.path) + self._test_save_and_load(TestExtractor(), + VocDetectionConverter(label_map='voc'), test_dir, + target_dataset=DstExtractor()) def test_can_save_voc_segm(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + # overlapping masks, the first should be truncated + # the second and third are different instances + Mask(image=np.array([[0, 1, 1, 1, 0]]), label=4, + z_order=1), + Mask(image=np.array([[1, 1, 0, 0, 0]]), label=3, + z_order=2), + Mask(image=np.array([[0, 0, 0, 1, 0]]), label=3, + z_order=2), + ]), + ]) + + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Mask(image=np.array([[0, 0, 1, 0, 0]]), label=4, + group=1), + Mask(image=np.array([[1, 1, 0, 0, 0]]), label=3, + group=2), + Mask(image=np.array([[0, 0, 0, 1, 0]]), label=3, + group=3), + ]), + ]) + with TestDir() as test_dir: - self._test_can_save_voc_dummy( - VocSegmentationExtractor, VocSegmentationConverter(label_map='voc'), - test_dir.path) + self._test_save_and_load(TestExtractor(), + VocSegmentationConverter(label_map='voc'), test_dir, + target_dataset=DstExtractor()) def test_can_save_voc_layout(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Bbox(2, 3, 4, 5, label=2, id=1, group=1, + attributes={ + 'pose': VOC.VocPose(1).name, + 'truncated': True, + 'difficult': False, + 'occluded': False, + } + ), + Bbox(2, 3, 1, 1, label=self._label( + VOC.VocBodyPart(1).name), group=1), + Bbox(5, 4, 3, 2, label=self._label( + VOC.VocBodyPart(2).name), group=1), + ]), + ]) + with TestDir() as test_dir: - self._test_can_save_voc_dummy( - VocLayoutExtractor, VocLayoutConverter(label_map='voc'), - test_dir.path) + self._test_save_and_load(TestExtractor(), + VocLayoutConverter(label_map='voc'), test_dir) def test_can_save_voc_action(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Bbox(2, 3, 4, 5, label=2, + attributes={ + 'truncated': True, + VOC.VocAction(1).name: True, + VOC.VocAction(2).name: True, + } + ), + Bbox(5, 4, 3, 2, label=self._label('person'), + attributes={ + 'truncated': True, + VOC.VocAction(1).name: True, + VOC.VocAction(2).name: True, + } + ), + ]), + ]) + + class DstExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', annotations=[ + Bbox(2, 3, 4, 5, label=2, id=1, + attributes={ + 'truncated': True, + 'difficult': False, + 'occluded': False, + # no attributes here in the label categories + } + ), + Bbox(5, 4, 3, 2, label=self._label('person'), id=2, + attributes={ + 'truncated': True, + 'difficult': False, + 'occluded': False, + 
VOC.VocAction(1).name: True, + VOC.VocAction(2).name: True, + **{ + a.name: False for a in VOC.VocAction + if a.value not in {1, 2} + } + } + ), + ]), + ]) + with TestDir() as test_dir: - self._test_can_save_voc_dummy( - VocActionExtractor, VocActionConverter(label_map='voc'), - test_dir.path) + self._test_save_and_load(TestExtractor(), + VocActionConverter(label_map='voc'), test_dir, + target_dataset=DstExtractor()) def test_can_save_dataset_with_no_subsets(self): - class TestExtractor(Extractor): + class TestExtractor(TestExtractorBase): def __iter__(self): return iter([ DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=2, id=1), - BboxObject(2, 3, 4, 5, label=3, id=2), + Label(2), + Label(3), ]), DatasetItem(id=2, annotations=[ - BboxObject(5, 4, 6, 5, label=3, id=1), + Label(3), ]), ]) - def categories(self): - return VOC.make_voc_categories() + with TestDir() as test_dir: + self._test_save_and_load(TestExtractor(), + VocConverter(label_map='voc'), test_dir) + + def test_can_save_dataset_with_images(self): + class TestExtractor(TestExtractorBase): + def __iter__(self): + return iter([ + DatasetItem(id=1, subset='a', image=np.ones([4, 5, 3])), + DatasetItem(id=2, subset='a', image=np.ones([5, 4, 3])), + + DatasetItem(id=3, subset='b', image=np.ones([2, 6, 3])), + ]) with TestDir() as test_dir: - self._test_can_save_voc(TestExtractor(), VocConverter(label_map='voc'), - test_dir.path) + self._test_save_and_load(TestExtractor(), + VocConverter(label_map='voc', save_images=True), test_dir) def test_dataset_with_voc_labelmap(self): - class SrcExtractor(Extractor): + class SrcExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=0, id=1), - BboxObject(1, 2, 3, 4, label=1, id=2), - ]) + Bbox(2, 3, 4, 5, label=self._label('cat'), id=1), + Bbox(1, 2, 3, 4, label=self._label('non_voc_label'), id=2), + ]) def categories(self): label_cat = LabelCategories() - label_cat.add(VOC.VocLabel(1).name) + label_cat.add(VOC.VocLabel.cat.name) label_cat.add('non_voc_label') return { AnnotationType.label: label_cat, } - class DstExtractor(Extractor): + class DstExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=0, id=1), - ]) + # drop non voc label + Bbox(2, 3, 4, 5, label=self._label('cat'), id=1, + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + } + ), + ]) def categories(self): return VOC.make_voc_categories() with TestDir() as test_dir: - self._test_can_save_voc( + self._test_save_and_load( SrcExtractor(), VocConverter(label_map='voc'), - test_dir.path, target_extractor=DstExtractor()) + test_dir, target_dataset=DstExtractor()) def test_dataset_with_guessed_labelmap(self): - class SrcExtractor(Extractor): + class SrcExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=0, id=1), - BboxObject(1, 2, 3, 4, label=1, id=2), - ]) + Bbox(2, 3, 4, 5, label=0, id=1), + Bbox(1, 2, 3, 4, label=1, id=2), + ]) def categories(self): label_cat = LabelCategories() @@ -511,14 +613,25 @@ def categories(self): AnnotationType.label: label_cat, } - class DstExtractor(Extractor): + class DstExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=0, id=1), - BboxObject(1, 2, 3, 4, - label=self.categories()[AnnotationType.label] \ - .find('non_voc_label')[0], id=2), - ]) + Bbox(2, 3, 4, 5, label=self._label(VOC.VocLabel(1).name), 
id=1, + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + } + ), + Bbox(1, 2, 3, 4, + label=self._label('non_voc_label'), id=2, + attributes={ + 'truncated': False, + 'difficult': False, + 'occluded': False, + } + ), + ]) def categories(self): label_map = VOC.make_voc_label_map() @@ -528,20 +641,20 @@ def categories(self): return VOC.make_voc_categories(label_map) with TestDir() as test_dir: - self._test_can_save_voc( + self._test_save_and_load( SrcExtractor(), VocConverter(label_map='guess'), - test_dir.path, target_extractor=DstExtractor()) + test_dir, target_dataset=DstExtractor()) def test_dataset_with_fixed_labelmap(self): - class SrcExtractor(Extractor): + class SrcExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(2, 3, 4, 5, label=0, id=1), - BboxObject(1, 2, 3, 4, label=1, id=2, group=2, - attributes={'act1': True}), - BboxObject(2, 3, 4, 5, label=2, id=3, group=2), - BboxObject(2, 3, 4, 6, label=3, id=4, group=2), - ]) + Bbox(2, 3, 4, 5, label=0, id=1), + Bbox(1, 2, 3, 4, label=1, id=2, group=2, + attributes={'act1': True}), + Bbox(2, 3, 4, 5, label=2, id=3, group=2), + Bbox(2, 3, 4, 6, label=3, id=4, group=2), + ]) def categories(self): label_cat = LabelCategories() @@ -557,30 +670,36 @@ def categories(self): 'label': [None, ['label_part1', 'label_part2'], ['act1', 'act2']] } - class DstExtractor(Extractor): + class DstExtractor(TestExtractorBase): def __iter__(self): yield DatasetItem(id=1, annotations=[ - BboxObject(1, 2, 3, 4, label=0, id=2, group=2, - attributes={'act1': True, 'act2': False}), - BboxObject(2, 3, 4, 5, label=1, id=3, group=2), - BboxObject(2, 3, 4, 6, label=2, id=4, group=2), - ]) + Bbox(1, 2, 3, 4, label=0, id=1, group=1, + attributes={ + 'act1': True, + 'act2': False, + 'truncated': False, + 'difficult': False, + 'occluded': False, + } + ), + Bbox(2, 3, 4, 5, label=1, group=1), + Bbox(2, 3, 4, 6, label=2, group=1), + ]) def categories(self): return VOC.make_voc_categories(label_map) with TestDir() as test_dir: - self._test_can_save_voc( + self._test_save_and_load( SrcExtractor(), VocConverter(label_map=label_map), - test_dir.path, target_extractor=DstExtractor()) + test_dir, target_dataset=DstExtractor()) -class VocImporterTest(TestCase): +class VocImportTest(TestCase): def test_can_import(self): with TestDir() as test_dir: - dummy_dir = osp.join(test_dir.path, 'dummy') - subsets = generate_dummy_voc(dummy_dir) + subsets = generate_dummy_voc(test_dir) - dataset = Project.import_from(dummy_dir, 'voc').make_dataset() + dataset = Project.import_from(test_dir, 'voc').make_dataset() self.assertEqual(len(VOC.VocTask), len(dataset.sources)) self.assertEqual(set(subsets), set(dataset.subsets())) @@ -594,7 +713,7 @@ def test_can_write_and_parse_labelmap(self): src_label_map['qq'] = [None, ['part1', 'part2'], ['act1', 'act2']] with TestDir() as test_dir: - file_path = osp.join(test_dir.path, 'test.txt') + file_path = osp.join(test_dir, 'test.txt') VOC.write_label_map(file_path, src_label_map) dst_label_map = VOC.parse_label_map(file_path) diff --git a/datumaro/tests/test_yolo_format.py b/datumaro/tests/test_yolo_format.py index 6b24ba5d927d..9ce972a206fa 100644 --- a/datumaro/tests/test_yolo_format.py +++ b/datumaro/tests/test_yolo_format.py @@ -3,11 +3,11 @@ from unittest import TestCase from datumaro.components.extractor import (Extractor, DatasetItem, - AnnotationType, BboxObject, LabelCategories, + AnnotationType, Bbox, LabelCategories, ) -from datumaro.components.importers.yolo 
import YoloImporter -from datumaro.components.converters.yolo import YoloConverter -from datumaro.util.test_utils import TestDir +from datumaro.plugins.yolo_format.importer import YoloImporter +from datumaro.plugins.yolo_format.converter import YoloConverter +from datumaro.util.test_utils import TestDir, compare_datasets class YoloFormatTest(TestCase): @@ -17,22 +17,22 @@ def __iter__(self): return iter([ DatasetItem(id=1, subset='train', image=np.ones((8, 8, 3)), annotations=[ - BboxObject(0, 2, 4, 2, label=2), - BboxObject(0, 1, 2, 3, label=4), + Bbox(0, 2, 4, 2, label=2), + Bbox(0, 1, 2, 3, label=4), ]), DatasetItem(id=2, subset='train', image=np.ones((10, 10, 3)), annotations=[ - BboxObject(0, 2, 4, 2, label=2), - BboxObject(3, 3, 2, 3, label=4), - BboxObject(2, 1, 2, 3, label=4), + Bbox(0, 2, 4, 2, label=2), + Bbox(3, 3, 2, 3, label=4), + Bbox(2, 1, 2, 3, label=4), ]), DatasetItem(id=3, subset='valid', image=np.ones((8, 8, 3)), annotations=[ - BboxObject(0, 1, 5, 2, label=2), - BboxObject(0, 2, 3, 2, label=5), - BboxObject(0, 2, 4, 2, label=6), - BboxObject(0, 7, 3, 2, label=7), + Bbox(0, 1, 5, 2, label=2), + Bbox(0, 2, 3, 2, label=5), + Bbox(0, 2, 4, 2, label=6), + Bbox(0, 7, 3, 2, label=7), ]), ]) @@ -47,22 +47,7 @@ def categories(self): with TestDir() as test_dir: source_dataset = TestExtractor() - YoloConverter(save_images=True)(source_dataset, test_dir.path) - parsed_dataset = YoloImporter()(test_dir.path).make_dataset() + YoloConverter(save_images=True)(source_dataset, test_dir) + parsed_dataset = YoloImporter()(test_dir).make_dataset() - self.assertListEqual( - sorted(source_dataset.subsets()), - sorted(parsed_dataset.subsets()), - ) - self.assertEqual(len(source_dataset), len(parsed_dataset)) - for subset_name in source_dataset.subsets(): - source_subset = source_dataset.get_subset(subset_name) - parsed_subset = parsed_dataset.get_subset(subset_name) - for item_a, item_b in zip(source_subset, parsed_subset): - self.assertEqual(len(item_a.annotations), len(item_b.annotations)) - for ann_a, ann_b in zip(item_a.annotations, item_b.annotations): - self.assertEqual(ann_a.type, ann_b.type) - self.assertAlmostEqual(ann_a.x, ann_b.x) - self.assertAlmostEqual(ann_a.y, ann_b.y) - self.assertAlmostEqual(ann_a.w, ann_b.w) - self.assertAlmostEqual(ann_a.h, ann_b.h) \ No newline at end of file + compare_datasets(self, source_dataset, parsed_dataset) \ No newline at end of file
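
The manual comparison loop removed above from test_yolo_format.py (and its per-format cousins in the other test modules) is what the new compare_datasets() helper from datumaro.util.test_utils replaces throughout this diff. The helper's implementation is not part of the patch; the sketch below reconstructs its likely shape from the removed assertions — the name of the `test` parameter and the exact per-annotation equality semantics are assumptions, not the library's actual code.

def compare_datasets(test, expected, actual):
    # A minimal sketch, mirroring the hand-rolled checks deleted above.
    # Subset names and dataset sizes must match.
    test.assertListEqual(
        sorted(expected.subsets()), sorted(actual.subsets()))
    test.assertEqual(len(expected), len(actual))

    for subset_name in expected.subsets():
        expected_subset = expected.get_subset(subset_name)
        actual_subset = actual.get_subset(subset_name)

        # Items are compared pairwise, in iteration order.
        for item_e, item_a in zip(expected_subset, actual_subset):
            test.assertEqual(item_e.id, item_a.id)
            test.assertEqual(
                len(item_e.annotations), len(item_a.annotations))
            for ann_e, ann_a in zip(item_e.annotations, item_a.annotations):
                # The removed YOLO checks compared only the type and the
                # bbox geometry; a general-purpose helper presumably
                # compares full annotation equality instead (assumed here).
                test.assertEqual(ann_e.type, ann_a.type)
                test.assertEqual(ann_e, ann_a)

Usage matches the updated tests, e.g. compare_datasets(self, source_dataset, parsed_dataset), where self is the running TestCase.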