diff --git a/datumaro/datumaro/cli/project/__init__.py b/datumaro/datumaro/cli/project/__init__.py index cd728ef477f6..234f89d72f3f 100644 --- a/datumaro/datumaro/cli/project/__init__.py +++ b/datumaro/datumaro/cli/project/__init__.py @@ -11,6 +11,7 @@ from datumaro.components.project import Project from datumaro.components.comparator import Comparator +from datumaro.components.dataset_filter import DatasetItemEncoder from .diff import DiffVisualizer from ..util.project import make_project_path, load_project @@ -131,7 +132,12 @@ def build_export_parser(parser): "'/item[image/width < image/height]'; " "extract images with large-area bboxes: " "'/item[annotation/type=\"bbox\" and annotation/area>2000]'" + "filter out irrelevant annotations from items: " + "'/item/annotation[label = \"person\"]'" ) + parser.add_argument('-a', '--filter-annotations', action='store_true', + help="Filter annotations instead of dataset " + "items (default: %(default)s)") parser.add_argument('-d', '--dest', dest='dst_dir', required=True, help="Directory to save output") parser.add_argument('-f', '--output-format', required=True, @@ -158,10 +164,11 @@ def export_command(args): dataset = project.make_dataset() log.info("Exporting the project...") - dataset.export( + dataset.export_project( save_dir=dst_dir, output_format=args.output_format, filter_expr=args.filter, + filter_annotations=args.filter_annotations, cmdline_args=args.extra_args) log.info("Project exported to '%s' as '%s'" % \ (dst_dir, args.output_format)) @@ -177,12 +184,21 @@ def build_docs_parser(parser): def build_extract_parser(parser): parser.add_argument('-e', '--filter', default=None, - help="Filter expression for dataset items. Examples: " + help="XML XPath filter expression for dataset items. Examples: " "extract images with width < height: " "'/item[image/width < image/height]'; " "extract images with large-area bboxes: " - "'/item[annotation/type=\"bbox\" and annotation/area>2000]'" + "'/item[annotation/type=\"bbox\" and annotation/area>2000]' " + "filter out irrelevant annotations from items: " + "'/item/annotation[label = \"person\"]'" ) + parser.add_argument('-a', '--filter-annotations', action='store_true', + help="Filter annotations instead of dataset " + "items (default: %(default)s)") + parser.add_argument('--remove-empty', action='store_true', + help="Remove an item if there are no annotations left after filtration") + parser.add_argument('--dry-run', action='store_true', + help="Print XML representations to be filtered and exit") parser.add_argument('-d', '--dest', dest='dst_dir', required=True, help="Output directory") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -193,9 +209,27 @@ def extract_command(args): project = load_project(args.project_dir) dst_dir = osp.abspath(args.dst_dir) - os.makedirs(dst_dir, exist_ok=False) + if not args.dry_run: + os.makedirs(dst_dir, exist_ok=False) + + dataset = project.make_dataset() + + kwargs = {} + if args.filter_annotations: + kwargs['remove_empty'] = args.remove_empty + + if args.dry_run: + dataset = dataset.extract(filter_expr=args.filter, + filter_annotations=args.filter_annotations, **kwargs) + for item in dataset: + encoded_item = DatasetItemEncoder.encode(item, dataset.categories()) + xml_item = DatasetItemEncoder.to_string(encoded_item) + print(xml_item) + return 0 + + dataset.extract_project(save_dir=dst_dir, filter_expr=args.filter, + filter_annotations=args.filter_annotations, **kwargs) - project.make_dataset().extract(filter_expr=args.filter, save_dir=dst_dir) log.info("Subproject extracted to '%s'" % (dst_dir)) return 0 @@ -279,7 +313,7 @@ def transform_command(args): dst_dir = osp.abspath(args.dst_dir) os.makedirs(dst_dir, exist_ok=False) - project.make_dataset().transform( + project.make_dataset().apply_model( save_dir=dst_dir, model_name=args.model_name) diff --git a/datumaro/datumaro/cli/source/__init__.py b/datumaro/datumaro/cli/source/__init__.py index 6c3f9f993c8b..8fa3364b3f87 100644 --- a/datumaro/datumaro/cli/source/__init__.py +++ b/datumaro/datumaro/cli/source/__init__.py @@ -188,6 +188,9 @@ def build_export_parser(parser): "extract images with large-area bboxes: " "'/item[annotation/type=\"bbox\" and annotation/area>2000]'" ) + parser.add_argument('-a', '--filter-annotations', action='store_true', + help="Filter annotations instead of dataset " + "items (default: %(default)s)") parser.add_argument('-d', '--dest', dest='dst_dir', required=True, help="Directory to save output") parser.add_argument('-f', '--output-format', required=True, @@ -215,10 +218,11 @@ def export_command(args): dataset = source_project.make_dataset() log.info("Exporting the project...") - dataset.export( + dataset.export_project( save_dir=dst_dir, output_format=args.output_format, filter_expr=args.filter, + filter_annotations=args.filter_annotations, cmdline_args=args.extra_args) log.info("Source '%s' exported to '%s' as '%s'" % \ (args.name, dst_dir, args.output_format)) diff --git a/datumaro/datumaro/components/config_model.py b/datumaro/datumaro/components/config_model.py index fe133cb626c6..d21d3393b0da 100644 --- a/datumaro/datumaro/components/config_model.py +++ b/datumaro/datumaro/components/config_model.py @@ -60,7 +60,6 @@ def __init__(self, config=None): .add('subsets', list) \ .add('sources', lambda: _DefaultConfig( lambda v=None: Source(v))) \ - .add('filter', str) \ \ .add('project_filename', str, internal=True) \ .add('project_dir', str, internal=True) \ diff --git a/datumaro/datumaro/components/converters/voc.py b/datumaro/datumaro/components/converters/voc.py index c296c351f3be..18c99783b5ec 100644 --- a/datumaro/datumaro/components/converters/voc.py +++ b/datumaro/datumaro/components/converters/voc.py @@ -462,7 +462,7 @@ def _make_label_id_map(self): void_labels = [src_label for src_id, src_label in source_labels.items() if src_label not in target_labels] if void_labels: - log.warn("The following labels are remapped to background: %s" % + log.warning("The following labels are remapped to background: %s" % ', '.join(void_labels)) def map_id(src_id): diff --git a/datumaro/datumaro/components/dataset_filter.py b/datumaro/datumaro/components/dataset_filter.py index 28339df098a7..a32b5df6f559 100644 --- a/datumaro/datumaro/components/dataset_filter.py +++ b/datumaro/datumaro/components/dataset_filter.py @@ -4,38 +4,27 @@ # SPDX-License-Identifier: MIT from lxml import etree as ET # NOTE: lxml has proper XPath implementation -from datumaro.components.extractor import (DatasetItem, Annotation, +from datumaro.components.extractor import (DatasetItem, Extractor, + Annotation, AnnotationType, LabelObject, MaskObject, PointsObject, PolygonObject, PolyLineObject, BboxObject, CaptionObject, ) -def _cast(value, type_conv, default=None): - if value is None: - return default - try: - return type_conv(value) - except Exception: - return default - class DatasetItemEncoder: - def encode_item(self, item): + @classmethod + def encode(cls, item, categories=None): item_elem = ET.Element('item') ET.SubElement(item_elem, 'id').text = str(item.id) ET.SubElement(item_elem, 'subset').text = str(item.subset) - - # Dataset wrapper-specific - ET.SubElement(item_elem, 'source').text = \ - str(getattr(item, 'source', None)) - ET.SubElement(item_elem, 'extractor').text = \ - str(getattr(item, 'extractor', None)) + ET.SubElement(item_elem, 'path').text = str('/'.join(item.path)) image = item.image if image is not None: - item_elem.append(self.encode_image(image)) + item_elem.append(cls.encode_image(image)) for ann in item.annotations: - item_elem.append(self.encode_object(ann)) + item_elem.append(cls.encode_annotation(ann, categories)) return item_elem @@ -52,7 +41,7 @@ def encode_image(cls, image): return image_elem @classmethod - def encode_annotation(cls, annotation): + def encode_annotation_base(cls, annotation): assert isinstance(annotation, Annotation) ann_elem = ET.Element('annotation') ET.SubElement(ann_elem, 'id').text = str(annotation.id) @@ -65,18 +54,31 @@ def encode_annotation(cls, annotation): return ann_elem + @staticmethod + def _get_label(label_id, categories): + label = '' + if categories is not None: + label_cat = categories.get(AnnotationType.label) + if label_cat is not None: + label = label_cat.items[label_id].name + return label + @classmethod - def encode_label_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + def encode_label_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) ET.SubElement(ann_elem, 'label_id').text = str(obj.label) return ann_elem @classmethod - def encode_mask_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + def encode_mask_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) ET.SubElement(ann_elem, 'label_id').text = str(obj.label) mask = obj.image @@ -86,9 +88,11 @@ def encode_mask_object(cls, obj): return ann_elem @classmethod - def encode_bbox_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + def encode_bbox_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) ET.SubElement(ann_elem, 'label_id').text = str(obj.label) ET.SubElement(ann_elem, 'x').text = str(obj.x) ET.SubElement(ann_elem, 'y').text = str(obj.y) @@ -99,9 +103,11 @@ def encode_bbox_object(cls, obj): return ann_elem @classmethod - def encode_points_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + def encode_points_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) ET.SubElement(ann_elem, 'label_id').text = str(obj.label) x, y, w, h = obj.get_bbox() @@ -113,20 +119,22 @@ def encode_points_object(cls, obj): ET.SubElement(bbox_elem, 'h').text = str(h) ET.SubElement(bbox_elem, 'area').text = str(area) - points = ann_elem.points + points = obj.points for i in range(0, len(points), 2): point_elem = ET.SubElement(ann_elem, 'point') - ET.SubElement(point_elem, 'x').text = str(points[i * 2]) - ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1]) + ET.SubElement(point_elem, 'x').text = str(points[i]) + ET.SubElement(point_elem, 'y').text = str(points[i + 1]) ET.SubElement(point_elem, 'visible').text = \ - str(ann_elem.visibility[i // 2].name) + str(obj.visibility[i // 2].name) return ann_elem @classmethod - def encode_polyline_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + def encode_polygon_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) ET.SubElement(ann_elem, 'label_id').text = str(obj.label) x, y, w, h = obj.get_bbox() @@ -138,57 +146,142 @@ def encode_polyline_object(cls, obj): ET.SubElement(bbox_elem, 'h').text = str(h) ET.SubElement(bbox_elem, 'area').text = str(area) - points = ann_elem.points + points = obj.points for i in range(0, len(points), 2): point_elem = ET.SubElement(ann_elem, 'point') - ET.SubElement(point_elem, 'x').text = str(points[i * 2]) - ET.SubElement(point_elem, 'y').text = str(points[i * 2 + 1]) + ET.SubElement(point_elem, 'x').text = str(points[i]) + ET.SubElement(point_elem, 'y').text = str(points[i + 1]) + + return ann_elem + + @classmethod + def encode_polyline_object(cls, obj, categories): + ann_elem = cls.encode_annotation_base(obj) + + ET.SubElement(ann_elem, 'label').text = \ + str(cls._get_label(obj.label, categories)) + ET.SubElement(ann_elem, 'label_id').text = str(obj.label) + + x, y, w, h = obj.get_bbox() + area = w * h + bbox_elem = ET.SubElement(ann_elem, 'bbox') + ET.SubElement(bbox_elem, 'x').text = str(x) + ET.SubElement(bbox_elem, 'y').text = str(y) + ET.SubElement(bbox_elem, 'w').text = str(w) + ET.SubElement(bbox_elem, 'h').text = str(h) + ET.SubElement(bbox_elem, 'area').text = str(area) + + points = obj.points + for i in range(0, len(points), 2): + point_elem = ET.SubElement(ann_elem, 'point') + ET.SubElement(point_elem, 'x').text = str(points[i]) + ET.SubElement(point_elem, 'y').text = str(points[i + 1]) return ann_elem @classmethod def encode_caption_object(cls, obj): - ann_elem = cls.encode_annotation(obj) + ann_elem = cls.encode_annotation_base(obj) ET.SubElement(ann_elem, 'caption').text = str(obj.caption) return ann_elem - def encode_object(self, o): + @classmethod + def encode_annotation(cls, o, categories=None): if isinstance(o, LabelObject): - return self.encode_label_object(o) + return cls.encode_label_object(o, categories) if isinstance(o, MaskObject): - return self.encode_mask_object(o) + return cls.encode_mask_object(o, categories) if isinstance(o, BboxObject): - return self.encode_bbox_object(o) + return cls.encode_bbox_object(o, categories) if isinstance(o, PointsObject): - return self.encode_points_object(o) + return cls.encode_points_object(o, categories) if isinstance(o, PolyLineObject): - return self.encode_polyline_object(o) + return cls.encode_polyline_object(o, categories) if isinstance(o, PolygonObject): - return self.encode_polygon_object(o) + return cls.encode_polygon_object(o, categories) if isinstance(o, CaptionObject): - return self.encode_caption_object(o) - if isinstance(o, Annotation): # keep after derived classes - return self.encode_annotation(o) + return cls.encode_caption_object(o) + raise NotImplementedError("Unexpected annotation object passed: %s" % o) + + @staticmethod + def to_string(encoded_item): + return ET.tostring(encoded_item, encoding='unicode', pretty_print=True) + +def XPathDatasetFilter(extractor, xpath=None): + if xpath is None: + return extractor + xpath = ET.XPath(xpath) + f = lambda item: bool(xpath( + DatasetItemEncoder.encode(item, extractor.categories()))) + return extractor.select(f) + +class XPathAnnotationsFilter(Extractor): # NOTE: essentially, a transform + class ItemWrapper(DatasetItem): + def __init__(self, item, annotations): + self._item = item + self._annotations = annotations + + @DatasetItem.id.getter + def id(self): + return self._item.id + + @DatasetItem.subset.getter + def subset(self): + return self._item.subset - if isinstance(o, DatasetItem): - return self.encode_item(o) + @DatasetItem.path.getter + def path(self): + return self._item.path - return None + @DatasetItem.annotations.getter + def annotations(self): + return self._annotations -class XPathDatasetFilter: - def __init__(self, filter_text=None): - self._filter = None - if filter_text is not None: - self._filter = ET.XPath(filter_text) - self._encoder = DatasetItemEncoder() + @DatasetItem.has_image.getter + def has_image(self): + return self._item.has_image - def __call__(self, item): - encoded_item = self._serialize_item(item) + @DatasetItem.image.getter + def image(self): + return self._item.image + + def __init__(self, extractor, xpath=None, remove_empty=False): + super().__init__() + self._extractor = extractor + + if xpath is not None: + xpath = ET.XPath(xpath) + self._filter = xpath + + self._remove_empty = remove_empty + + def __len__(self): + return len(self._extractor) + + def __iter__(self): + for item in self._extractor: + item = self._filter_item(item) + if item is not None: + yield item + + def subsets(self): + return self._extractor.subsets() + + def categories(self): + return self._extractor.categories() + + def _filter_item(self, item): if self._filter is None: - return True - return bool(self._filter(encoded_item)) + return item + encoded = DatasetItemEncoder.encode(item, self._extractor.categories()) + filtered = self._filter(encoded) + filtered = [elem for elem in filtered if elem.tag == 'annotation'] + + encoded = encoded.findall('annotation') + annotations = [item.annotations[encoded.index(e)] for e in filtered] - def _serialize_item(self, item): - return self._encoder.encode_item(item) \ No newline at end of file + if self._remove_empty and len(annotations) == 0: + return None + return self.ItemWrapper(item, annotations) \ No newline at end of file diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index 13e9708ac878..8c07cfe3b1f5 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -580,9 +580,9 @@ def __iter__(self): return filter(self.predicate, self.iterable) class _ExtractorBase(IExtractor): - def __init__(self, length=None): + def __init__(self, length=None, subsets=None): self._length = length - self._subsets = None + self._subsets = subsets def _init_cache(self): subsets = set() @@ -612,9 +612,12 @@ def get_subset(self, name): else: raise Exception("Unknown subset '%s' requested" % name) + def transform(self, method, *args, **kwargs): + return method(self, *args, **kwargs) + class DatasetIteratorWrapper(_ExtractorBase): - def __init__(self, iterable, categories): - super().__init__(length=None) + def __init__(self, iterable, categories, subsets=None): + super().__init__(length=None, subsets=subsets) self._iterable = iterable self._categories = categories @@ -626,7 +629,7 @@ def categories(self): def select(self, pred): return DatasetIteratorWrapper( - _DatasetFilter(self, pred), self.categories()) + _DatasetFilter(self, pred), self.categories(), self.subsets()) class Extractor(_ExtractorBase): def __init__(self, length=None): @@ -637,7 +640,7 @@ def categories(self): def select(self, pred): return DatasetIteratorWrapper( - _DatasetFilter(self, pred), self.categories()) + _DatasetFilter(self, pred), self.categories(), self.subsets()) DEFAULT_SUBSET_NAME = 'default' \ No newline at end of file diff --git a/datumaro/datumaro/components/launcher.py b/datumaro/datumaro/components/launcher.py index 0e11e8cf076e..3ac1e1fb6725 100644 --- a/datumaro/datumaro/components/launcher.py +++ b/datumaro/datumaro/components/launcher.py @@ -10,7 +10,7 @@ # pylint: disable=no-self-use class Launcher: - def __init__(self): + def __init__(self, model_dir=None): pass def launch(self, inputs): diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py index 6fc16c1533de..a1e9645d635a 100644 --- a/datumaro/datumaro/components/project.py +++ b/datumaro/datumaro/components/project.py @@ -14,9 +14,10 @@ from datumaro.components.config import Config, DEFAULT_FORMAT from datumaro.components.config_model import * -from datumaro.components.extractor import * -from datumaro.components.launcher import * -from datumaro.components.dataset_filter import XPathDatasetFilter +from datumaro.components.extractor import DatasetItem, Extractor +from datumaro.components.launcher import InferenceWrapper +from datumaro.components.dataset_filter import \ + XPathDatasetFilter, XPathAnnotationsFilter def import_foreign_module(name, path): @@ -305,7 +306,131 @@ def image(self): return self._image return self._item.image -class ProjectDataset(Extractor): +class Dataset(Extractor): + @classmethod + def from_extractors(cls, *sources): + # merge categories + # TODO: implement properly with merging and annotations remapping + categories = {} + for source in sources: + categories.update(source.categories()) + for source in sources: + for cat_type, source_cat in source.categories().items(): + assert categories[cat_type] == source_cat + dataset = Dataset(categories=categories) + + # merge items + subsets = defaultdict(lambda: Subset(dataset)) + for source in sources: + for item in source: + path = None # NOTE: merge everything into our own dataset + + existing_item = subsets[item.subset].items.get(item.id) + if existing_item is not None: + image = None + if existing_item.has_image: + # TODO: think of image comparison + image = cls._lazy_image(existing_item) + + item = DatasetItemWrapper(item=item, path=path, + image=image, annotations=self._merge_anno( + existing_item.annotations, item.annotations)) + else: + item = DatasetItemWrapper(item=item, path=path, + annotations=item.annotations) + + subsets[item.subset].items[item.id] = item + + self._subsets = dict(subsets) + + def __init__(self, categories=None): + super().__init__() + + self._subsets = {} + + if not categories: + categories = {} + self._categories = categories + + def __iter__(self): + for subset in self._subsets.values(): + for item in subset: + yield item + + def __len__(self): + if self._length is None: + self._length = reduce(lambda s, x: s + len(x), + self._subsets.values(), 0) + return self._length + + def get_subset(self, name): + return self._subsets[name] + + def subsets(self): + return list(self._subsets) + + def categories(self): + return self._categories + + def get(self, item_id, subset=None, path=None): + if path: + raise KeyError("Requested dataset item path is not found") + return self._subsets[subset].items[item_id] + + def put(self, item, item_id=None, subset=None, path=None): + if path: + raise KeyError("Requested dataset item path is not found") + + if item_id is None: + item_id = item.id + if subset is None: + subset = item.subset + + item = DatasetItemWrapper(item=item, path=None, + annotations=item.annotations) + if item.subset not in self._subsets: + self._subsets[item.subset] = Subset(self) + self._subsets[subset].items[item_id] = item + self._length = None + + return item + + def extract(self, filter_expr, filter_annotations=False, **kwargs): + if filter_annotations: + return self.transform(XPathAnnotationsFilter, filter_expr, **kwargs) + else: + return self.transform(XPathDatasetFilter, filter_expr, **kwargs) + + def update(self, items): + for item in items: + self.put(item) + return self + + def define_categories(self, categories): + assert not self._categories + self._categories = categories + + @staticmethod + def _lazy_image(item): + # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result + return lambda: item.image + + @staticmethod + def _merge_anno(a, b): + from itertools import chain + merged = [] + for item in chain(a, b): + found = False + for elem in merged: + if elem == item: + found = True + break + if not found: + merged.append(item) + + return merged + +class ProjectDataset(Dataset): def __init__(self, project): super().__init__() @@ -313,11 +438,6 @@ def __init__(self, project): config = self.config env = self.env - dataset_filter = None - if config.filter: - dataset_filter = XPathDatasetFilter(config.filter) - self._filter = dataset_filter - sources = {} for s_name, source in config.sources.items(): s_format = source.format @@ -335,7 +455,7 @@ def __init__(self, project): own_source = None own_source_dir = osp.join(config.project_dir, config.dataset_dir) - if osp.isdir(config.project_dir) and osp.isdir(own_source_dir): + if config.project_dir and osp.isdir(own_source_dir): log.disable(log.INFO) own_source = env.make_importer(DEFAULT_FORMAT)(own_source_dir) \ .make_dataset() @@ -358,9 +478,6 @@ def __init__(self, project): for source_name, source in self._sources.items(): log.debug("Loading '%s' source contents..." % source_name) for item in source: - if dataset_filter and not dataset_filter(item): - continue - existing_item = subsets[item.subset].items.get(item.id) if existing_item is not None: image = None @@ -370,14 +487,14 @@ def __init__(self, project): path = existing_item.path if item.path != path: - path = None + path = None # NOTE: move to our own dataset item = DatasetItemWrapper(item=item, path=path, image=image, annotations=self._merge_anno( existing_item.annotations, item.annotations)) else: s_config = config.sources[source_name] if s_config and \ - s_config.format != self.env.PROJECT_EXTRACTOR_NAME: + s_config.format != env.PROJECT_EXTRACTOR_NAME: # NOTE: consider imported sources as our own dataset path = None else: @@ -394,9 +511,6 @@ def __init__(self, project): if own_source is not None: log.debug("Loading own dataset...") for item in own_source: - if dataset_filter and not dataset_filter(item): - continue - if not item.has_image: existing_item = subsets[item.subset].items.get(item.id) if existing_item is not None: @@ -417,55 +531,9 @@ def __init__(self, project): self._length = None - @staticmethod - def _lazy_image(item): - # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result - return lambda: item.image - - @staticmethod - def _merge_anno(a, b): - from itertools import chain - merged = [] - for item in chain(a, b): - found = False - for elem in merged: - if elem == item: - found = True - break - if not found: - merged.append(item) - - return merged - def iterate_own(self): return self.select(lambda item: not item.path) - def __iter__(self): - for subset in self._subsets.values(): - for item in subset: - if self._filter and not self._filter(item): - continue - yield item - - def __len__(self): - if self._length is None: - self._length = reduce(lambda s, x: s + len(x), - self._subsets.values(), 0) - return self._length - - def get_subset(self, name): - return self._subsets[name] - - def subsets(self): - return list(self._subsets) - - def categories(self): - return self._categories - - def define_categories(self, categories): - assert not self._categories - self._categories = categories - def get(self, item_id, subset=None, path=None): if path: source = path[0] @@ -498,54 +566,6 @@ def put(self, item, item_id=None, subset=None, path=None): return item - def build(self, tasks=None): - pass - - def docs(self): - pass - - def transform(self, model_name, save_dir=None): - project = Project(self.config) - project.config.remove('sources') - - if save_dir is None: - save_dir = self.config.project_dir - project.config.project_dir = save_dir - - dataset = project.make_dataset() - launcher = self._project.make_executable_model(model_name) - inference = InferenceWrapper(self, launcher) - dataset.update(inference) - - dataset.save(merge=True) - - def export(self, save_dir, output_format, - filter_expr=None, **converter_kwargs): - save_dir = osp.abspath(save_dir) - os.makedirs(save_dir, exist_ok=True) - - dataset = self - if filter_expr: - dataset_filter = XPathDatasetFilter(filter_expr) - dataset = dataset.select(dataset_filter) - - converter = self.env.make_converter(output_format, **converter_kwargs) - converter(dataset, save_dir) - - def extract(self, save_dir, filter_expr=None): - project = Project(self.config) - if filter_expr: - XPathDatasetFilter(filter_expr) - project.set_filter(filter_expr) - project.save(save_dir) - - def update(self, items): - for item in items: - if self._filter and not self._filter(item): - continue - self.put(item) - return self - def save(self, save_dir=None, merge=False, recursive=True, save_images=False): if save_dir is None: @@ -600,6 +620,60 @@ def config(self): def sources(self): return self._sources + def _save_branch_project(self, extractor, save_dir=None): + # NOTE: probably this function should be in the ViewModel layer + save_dir = osp.abspath(save_dir) + if save_dir: + dst_project = Project() + else: + if not self.config.project_dir: + raise Exception("Either a save directory or a project " + "directory should be specified") + save_dir = self.config.project_dir + + dst_project = Project(Config(self.config)) + dst_project.config.remove('project_dir') + dst_project.config.remove('sources') + + dst_dataset = dst_project.make_dataset() + dst_dataset.define_categories(extractor.categories()) + dst_dataset.update(extractor) + + dst_dataset.save(save_dir=save_dir, merge=True) + + def transform_project(self, method, *args, save_dir=None, **kwargs): + # NOTE: probably this function should be in the ViewModel layer + transformed = self.transform(method, *args, **kwargs) + self._save_branch_project(transformed, save_dir=save_dir) + + def apply_model(self, model_name, save_dir=None): + # NOTE: probably this function should be in the ViewModel layer + launcher = self._project.make_executable_model(model_name) + self.transform_project(InferenceWrapper, launcher, save_dir=save_dir) + + def export_project(self, save_dir, output_format, + filter_expr=None, filter_annotations=False, **converter_kwargs): + # NOTE: probably this function should be in the ViewModel layer + save_dir = osp.abspath(save_dir) + os.makedirs(save_dir, exist_ok=True) + + dataset = self + if filter_expr: + dataset = dataset.extract(filter_expr, filter_annotations) + + converter = self.env.make_converter(output_format, **converter_kwargs) + converter(dataset, save_dir) + + def extract_project(self, filter_expr, filter_annotations=False, + save_dir=None, remove_empty=False): + # NOTE: probably this function should be in the ViewModel layer + filtered = self + if filter_expr: + filtered = self.extract(filter_expr, + filter_annotations=filter_annotations, + remove_empty=remove_empty) + self._save_branch_project(filtered, save_dir=save_dir) + class Project: @staticmethod def load(path): @@ -697,24 +771,10 @@ def make_source_project(self, name): config = Config(self.config) config.remove('sources') config.remove('subsets') - config.remove('filter') project = Project(config) project.add_source(name, source) return project - def get_filter(self): - if 'filter' in self.config: - return self.config.filter - return '' - - def set_filter(self, value=None): - if not value: - self.config.remove('filter') - else: - # check filter - XPathDatasetFilter(value) - self.config.filter = value - def local_model_dir(self, model_name): return osp.join( self.config.env_dir, self.env.config.models_dir, model_name) @@ -726,4 +786,4 @@ def local_source_dir(self, source_name): def load_project_as_dataset(url): # implement the function declared above return Project.load(url).make_dataset() -# pylint: enable=function-redefined \ No newline at end of file +# pylint: enable=function-redefined diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index a66668fdcd0d..c30a570cb59c 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -1,3 +1,4 @@ +import numpy as np import os import os.path as osp @@ -7,9 +8,13 @@ from datumaro.components.project import Source, Model from datumaro.components.launcher import Launcher, InferenceWrapper from datumaro.components.converter import Converter -from datumaro.components.extractor import Extractor, DatasetItem, LabelObject +from datumaro.components.extractor import (Extractor, DatasetItem, + LabelObject, MaskObject, PointsObject, PolygonObject, + PolyLineObject, BboxObject, CaptionObject, +) from datumaro.components.config import Config, DefaultConfig, SchemaBuilder -from datumaro.components.dataset_filter import XPathDatasetFilter +from datumaro.components.dataset_filter import \ + XPathDatasetFilter, XPathAnnotationsFilter, DatasetItemEncoder from datumaro.util.test_utils import TestDir @@ -129,18 +134,11 @@ def test_can_have_project_source(self): def test_can_batch_launch_custom_model(self): class TestExtractor(Extractor): - def __init__(self, url, n=0): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(5): yield DatasetItem(id=i, subset='train', image=i) class TestLauncher(Launcher): - def __init__(self, **kwargs): - pass - def launch(self, inputs): for i, inp in enumerate(inputs): yield [ LabelObject(attributes={'idx': i, 'data': inp}) ] @@ -152,7 +150,7 @@ def launch(self, inputs): project.env.launchers.register(launcher_name, TestLauncher) project.add_model(model_name, { 'launcher': launcher_name }) model = project.make_executable_model(model_name) - extractor = TestExtractor('', n=5) + extractor = TestExtractor() batch_size = 3 executor = InferenceWrapper(extractor, model, batch_size=batch_size) @@ -166,19 +164,12 @@ def launch(self, inputs): def test_can_do_transform_with_custom_model(self): class TestExtractorSrc(Extractor): - def __init__(self, url, n=2): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(2): yield DatasetItem(id=i, subset='train', image=i, annotations=[ LabelObject(i) ]) class TestLauncher(Launcher): - def __init__(self, **kwargs): - pass - def launch(self, inputs): for inp in inputs: yield [ LabelObject(inp) ] @@ -186,7 +177,7 @@ def launch(self, inputs): class TestConverter(Converter): def __call__(self, extractor, save_dir): for item in extractor: - with open(osp.join(save_dir, '%s.txt' % item.id), 'w+') as f: + with open(osp.join(save_dir, '%s.txt' % item.id), 'w') as f: f.write(str(item.subset) + '\n') f.write(str(item.annotations[0].label) + '\n') @@ -199,8 +190,8 @@ def __iter__(self): for path in self.items: with open(path, 'r') as f: index = osp.splitext(osp.basename(path))[0] - subset = f.readline()[:-1] - label = int(f.readline()[:-1]) + subset = f.readline().strip() + label = int(f.readline().strip()) assert subset == 'train' yield DatasetItem(id=index, subset=subset, annotations=[ LabelObject(label) ]) @@ -217,7 +208,8 @@ def __iter__(self): project.add_source('source', { 'format': extractor_name }) with TestDir() as test_dir: - project.make_dataset().transform(model_name, test_dir.path) + project.make_dataset().apply_model(model_name=model_name, + save_dir=test_dir.path) result = Project.load(test_dir.path) result.env.extractors.register(extractor_name, TestExtractorDst) @@ -255,21 +247,16 @@ def __iter__(self): def test_project_filter_can_be_applied(self): class TestExtractor(Extractor): - def __init__(self, url, n=10): - super().__init__(length=n) - self.n = n - def __iter__(self): - for i in range(self.n): + for i in range(10): yield DatasetItem(id=i, subset='train') e_type = 'type' project = Project() project.env.extractors.register(e_type, TestExtractor) project.add_source('source', { 'format': e_type }) - project.set_filter('/item[id < 5]') - dataset = project.make_dataset() + dataset = project.make_dataset().extract('/item[id < 5]') self.assertEqual(5, len(dataset)) @@ -326,30 +313,23 @@ def test_project_compound_child_can_be_modified_recursively(self): self.assertEqual(1, len(dataset.sources['child2'])) def test_project_can_merge_item_annotations(self): - class TestExtractor(Extractor): - def __init__(self, url, v=None): - super().__init__() - self.v = v - + class TestExtractor1(Extractor): def __iter__(self): - v1_item = DatasetItem(id=1, subset='train', annotations=[ + yield DatasetItem(id=1, subset='train', annotations=[ LabelObject(2, id=3), LabelObject(3, attributes={ 'x': 1 }), ]) - v2_item = DatasetItem(id=1, subset='train', annotations=[ + class TestExtractor2(Extractor): + def __iter__(self): + yield DatasetItem(id=1, subset='train', annotations=[ LabelObject(3, attributes={ 'x': 1 }), LabelObject(4, id=4), ]) - if self.v == 1: - yield v1_item - else: - yield v2_item - project = Project() - project.env.extractors.register('t1', lambda p: TestExtractor(p, v=1)) - project.env.extractors.register('t2', lambda p: TestExtractor(p, v=2)) + project.env.extractors.register('t1', TestExtractor1) + project.env.extractors.register('t2', TestExtractor2) project.add_source('source1', { 'format': 't1' }) project.add_source('source2', { 'format': 't2' }) @@ -361,23 +341,103 @@ def __iter__(self): self.assertEqual(3, len(item.annotations)) class DatasetFilterTest(TestCase): - class TestExtractor(Extractor): - def __init__(self, url, n=0): - super().__init__(length=n) - self.n = n - - def __iter__(self): - for i in range(self.n): - yield DatasetItem(id=i, subset='train') + @staticmethod + def test_item_representations(): + item = DatasetItem(id=1, subset='subset', path=['a', 'b'], + image=np.ones((5, 4, 3)), + annotations=[ + LabelObject(0, attributes={'a1': 1, 'a2': '2'}, id=1, group=2), + CaptionObject('hello', id=1), + CaptionObject('world', group=5), + LabelObject(2, id=3, attributes={ 'x': 1, 'y': '2' }), + BboxObject(1, 2, 3, 4, label=4, id=4, attributes={ 'a': 1.0 }), + BboxObject(5, 6, 7, 8, id=5, group=5), + PointsObject([1, 2, 2, 0, 1, 1], label=0, id=5), + MaskObject(label=3, id=5, image=np.ones((2, 3))), + PolyLineObject([1, 2, 3, 4, 5, 6, 7, 8], id=11), + PolygonObject([1, 2, 3, 4, 5, 6, 7, 8]), + ] + ) + + encoded = DatasetItemEncoder.encode(item) + DatasetItemEncoder.to_string(encoded) + + def test_item_filter_can_be_applied(self): + class TestExtractor(Extractor): + def __iter__(self): + for i in range(4): + yield DatasetItem(id=i, subset='train') - def test_xpathfilter_can_be_applied(self): - extractor = self.TestExtractor('', n=4) - dataset_filter = XPathDatasetFilter('/item[id > 1]') + extractor = TestExtractor() - filtered = extractor.select(dataset_filter) + filtered = XPathDatasetFilter(extractor, '/item[id > 1]') self.assertEqual(2, len(filtered)) + def test_annotations_filter_can_be_applied(self): + class SrcTestExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=0), + DatasetItem(id=1, annotations=[ + LabelObject(0), + LabelObject(1), + ]), + DatasetItem(id=2, annotations=[ + LabelObject(0), + LabelObject(2), + ]), + ]) + + class DstTestExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=0), + DatasetItem(id=1, annotations=[ + LabelObject(0), + ]), + DatasetItem(id=2, annotations=[ + LabelObject(0), + ]), + ]) + + extractor = SrcTestExtractor() + + filtered = XPathAnnotationsFilter(extractor, + '/item/annotation[label_id = 0]') + + self.assertListEqual(list(filtered), list(DstTestExtractor())) + + def test_annotations_filter_can_remove_empty_items(self): + class SrcTestExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=0), + DatasetItem(id=1, annotations=[ + LabelObject(0), + LabelObject(1), + ]), + DatasetItem(id=2, annotations=[ + LabelObject(0), + LabelObject(2), + ]), + ]) + + class DstTestExtractor(Extractor): + def __iter__(self): + return iter([ + DatasetItem(id=2, annotations=[ + LabelObject(2), + ]), + ]) + + extractor = SrcTestExtractor() + + filtered = XPathAnnotationsFilter(extractor, + '/item/annotation[label_id = 2]', remove_empty=True) + + self.assertListEqual(list(filtered), list(DstTestExtractor())) + class ConfigTest(TestCase): def test_can_produce_multilayer_config_from_dict(self): schema_low = SchemaBuilder() \ @@ -409,9 +469,6 @@ def test_can_produce_multilayer_config_from_dict(self): class ExtractorTest(TestCase): def test_custom_extractor_can_be_created(self): class CustomExtractor(Extractor): - def __init__(self, url): - super().__init__() - def __iter__(self): return iter([ DatasetItem(id=0, subset='train'),