From 7406fe89e59991679f63c6f73ecbd8682ec5b10f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 26 Feb 2021 20:40:32 +0300 Subject: [PATCH] Fix diff cli, add confusion matrices for polygons and masks (#117) * Fix diff cli, add confusion matrices for polygons and masks - Fixed diff invocation problem - Added test - Added confusion matrices for polygons and masks - Added support for mismatching classes - Fixed ProjectDataset.Subset * Update changelog --- CHANGELOG.md | 2 + datumaro/cli/contexts/project/__init__.py | 23 +- datumaro/cli/contexts/project/diff.py | 260 +++++++++++++--------- datumaro/components/extractor.py | 13 ++ datumaro/components/operations.py | 38 ++-- datumaro/components/project.py | 31 ++- tests/cli/__init__.py | 0 tests/cli/test_diff.py | 123 ++++++++++ 8 files changed, 334 insertions(+), 156 deletions(-) create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/test_diff.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bdede04c0e9f..2d3f7ea337ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Laziness, source caching, tracking of changes and partial updating for `Dataset` () - `Market-1501` dataset format () - `LFW` dataset format () +- Support of polygons' and masks' confusion matrices and mismatching classes in `diff` command () ### Changed - OpenVINO model launcher is updated for OpenVINO r2021.1 () @@ -25,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - High memory consumption and low performance of mask import/export, #53 () - Masks, covered by class 0 (background), should be exported with holes inside () +- `diff` command invocation problem with missing class methods () ### Security - diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index 3dbaeb6aa89e..26c97bd3504f 100644 --- a/datumaro/cli/contexts/project/__init__.py 
+++ b/datumaro/cli/contexts/project/__init__.py @@ -18,11 +18,12 @@ from datumaro.components.project import \ PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project +from datumaro.util import error_rollback from ...util import (CliException, MultilineFormatter, add_subparser, make_file_name) from ...util.project import generate_next_file_name, load_project -from .diff import DiffVisualizer +from .diff import DatasetDiffVisualizer def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -506,8 +507,8 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, help="Directory to save comparison results (default: do not save)") parser.add_argument('-v', '--visualizer', - default=DiffVisualizer.DEFAULT_FORMAT, - choices=[f.name for f in DiffVisualizer.Format], + default=DatasetDiffVisualizer.DEFAULT_FORMAT.name, + choices=[f.name for f in DatasetDiffVisualizer.OutputFormat], help="Output format (default: %(default)s)") parser.add_argument('--iou-thresh', default=0.5, type=float, help="IoU match threshold for detections (default: %(default)s)") @@ -521,6 +522,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): return parser +@error_rollback('on_error', implicit=True) def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) @@ -540,17 +542,14 @@ def diff_command(args): dst_dir = osp.abspath(dst_dir) log.info("Saving diff to '%s'" % dst_dir) - dst_dir_existed = osp.exists(dst_dir) - try: - visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.visualizer) - visualizer.save_dataset_diff( + if not osp.exists(dst_dir): + on_error.do(shutil.rmtree, dst_dir, ignore_errors=True) + + with DatasetDiffVisualizer(save_dir=dst_dir, comparator=comparator, + output_format=args.visualizer) as visualizer: + visualizer.save( 
first_project.make_dataset(), second_project.make_dataset()) - except BaseException: - if not dst_dir_existed and osp.isdir(dst_dir): - shutil.rmtree(dst_dir, ignore_errors=True) - raise return 0 diff --git a/datumaro/cli/contexts/project/diff.py b/datumaro/cli/contexts/project/diff.py index 7f638bbb0a65..52eeddc64572 100644 --- a/datumaro/cli/contexts/project/diff.py +++ b/datumaro/cli/contexts/project/diff.py @@ -4,11 +4,15 @@ # SPDX-License-Identifier: MIT from collections import Counter +from itertools import zip_longest from enum import Enum -import numpy as np +import logging as log import os import os.path as osp +import cv2 +import numpy as np + _formats = ['simple'] import warnings @@ -17,74 +21,84 @@ import tensorboardX as tb _formats.append('tensorboard') -from datumaro.components.extractor import AnnotationType +from datumaro.components.dataset import IDataset +from datumaro.components.extractor import AnnotationType, LabelCategories from datumaro.util.image import save_image -Format = Enum('Formats', _formats) +OutputFormat = Enum('Formats', _formats) -class DiffVisualizer: - Format = Format - DEFAULT_FORMAT = Format.simple +class DatasetDiffVisualizer: + OutputFormat = OutputFormat + DEFAULT_FORMAT = OutputFormat.simple _UNMATCHED_LABEL = -1 def __init__(self, comparator, save_dir, output_format=DEFAULT_FORMAT): - self.comparator = comparator + self.cmp = comparator if isinstance(output_format, str): - output_format = Format[output_format] - assert output_format in Format + output_format = OutputFormat[output_format] + assert output_format in OutputFormat self.output_format = output_format self.save_dir = save_dir - if output_format is Format.tensorboard: + + def __enter__(self): + os.makedirs(self.save_dir, exist_ok=True) + + if self.output_format is OutputFormat.tensorboard: logdir = osp.join(self.save_dir, 'logs', 'diff') self.file_writer = tb.SummaryWriter(logdir) - if output_format is Format.simple: + elif self.output_format is 
OutputFormat.simple: self.label_diff_writer = None - self.categories = {} + self._a_classes = {} + self._b_classes = {} self.label_confusion_matrix = Counter() self.bbox_confusion_matrix = Counter() + self.polygon_confusion_matrix = Counter() + self.mask_confusion_matrix = Counter() - def save_dataset_diff(self, extractor_a, extractor_b): - if self.save_dir: - os.makedirs(self.save_dir, exist_ok=True) + return self - if len(extractor_a) != len(extractor_b): - print("Datasets have different lengths: %s vs %s" % \ - (len(extractor_a), len(extractor_b))) - - self.categories = {} + def __exit__(self, *args, **kwargs): + if self.output_format is OutputFormat.tensorboard: + self.file_writer.flush() + self.file_writer.close() + elif self.output_format is OutputFormat.simple: + if self.label_diff_writer: + self.label_diff_writer.flush() + self.label_diff_writer.close() - label_mismatch = self.comparator. \ - compare_dataset_labels(extractor_a, extractor_b) - if label_mismatch is None: - print("Datasets have no label information") - elif len(label_mismatch) != 0: + def save(self, a: IDataset, b: IDataset): + if len(a) != len(b): + print("Datasets have different lengths: %s vs %s" % \ + (len(a), len(b))) + + a_classes = a.categories().get(AnnotationType.label, LabelCategories()) + b_classes = b.categories().get(AnnotationType.label, LabelCategories()) + class_mismatch = [(idx, a_cls, b_cls) + for idx, (a_cls, b_cls) in enumerate( + zip_longest(a_classes, b_classes)) + if getattr(a_cls, 'name', None) != getattr(b_cls, 'name', None) + ] + if class_mismatch: print("Datasets have mismatching labels:") - for a_label, b_label in label_mismatch: - if a_label is None: - print(" > %s" % b_label.name) - elif b_label is None: - print(" < %s" % a_label.name) + for idx, a_class, b_class in class_mismatch: + if a_class and b_class: + print(" #%s: %s != %s" % (idx, a_class.name, b_class.name)) + elif a_class: + print(" #%s: > %s" % (idx, a_class.name)) else: - print(" %s != %s" % 
(a_label.name, b_label.name)) - else: - self.categories.update(extractor_a.categories()) - self.categories.update(extractor_b.categories()) + print(" #%s: < %s" % (idx, b_class.name)) + self._a_classes = a.categories().get(AnnotationType.label) + self._b_classes = b.categories().get(AnnotationType.label) - self.label_confusion_matrix = Counter() - self.bbox_confusion_matrix = Counter() - - if self.output_format is Format.tensorboard: - self.file_writer.reopen() - - ids_a = set((item.id, item.subset) for item in extractor_a) - ids_b = set((item.id, item.subset) for item in extractor_b) + ids_a = set((item.id, item.subset) for item in a) + ids_b = set((item.id, item.subset) for item in b) ids = ids_a & ids_b if len(ids) != len(ids_a): @@ -95,32 +109,36 @@ def save_dataset_diff(self, extractor_a, extractor_b): print(ids_b - ids) for item_id, item_subset in ids: - item_a = extractor_a.get(item_id, item_subset) - item_b = extractor_a.get(item_id, item_subset) + item_a = a.get(item_id, item_subset) + item_b = b.get(item_id, item_subset) - label_diff = self.comparator.compare_item_labels(item_a, item_b) + label_diff = self.cmp.match_labels(item_a, item_b) self.update_label_confusion(label_diff) - bbox_diff = self.comparator.compare_item_bboxes(item_a, item_b) + bbox_diff = self.cmp.match_boxes(item_a, item_b) self.update_bbox_confusion(bbox_diff) + polygon_diff = self.cmp.match_polygons(item_a, item_b) + self.update_polygon_confusion(polygon_diff) + + mask_diff = self.cmp.match_masks(item_a, item_b) + self.update_mask_confusion(mask_diff) + self.save_item_label_diff(item_a, item_b, label_diff) self.save_item_bbox_diff(item_a, item_b, bbox_diff) if len(self.label_confusion_matrix) != 0: self.save_conf_matrix(self.label_confusion_matrix, - 'labels_confusion.png') + 'label_confusion.png') if len(self.bbox_confusion_matrix) != 0: self.save_conf_matrix(self.bbox_confusion_matrix, 'bbox_confusion.png') - - if self.output_format is Format.tensorboard: - self.file_writer.flush() 
- self.file_writer.close() - elif self.output_format is Format.simple: - if self.label_diff_writer: - self.label_diff_writer.flush() - self.label_diff_writer.close() + if len(self.polygon_confusion_matrix) != 0: + self.save_conf_matrix(self.polygon_confusion_matrix, + 'polygon_confusion.png') + if len(self.mask_confusion_matrix) != 0: + self.save_conf_matrix(self.mask_confusion_matrix, + 'mask_confusion.png') def update_label_confusion(self, label_diff): matches, a_unmatched, b_unmatched = label_diff @@ -131,23 +149,31 @@ def update_label_confusion(self, label_diff): for b_label in b_unmatched: self.label_confusion_matrix[(self._UNMATCHED_LABEL, b_label)] += 1 - def update_bbox_confusion(self, bbox_diff): - matches, mispred, a_unmatched, b_unmatched = bbox_diff - for a_bbox, b_bbox in matches: - self.bbox_confusion_matrix[(a_bbox.label, b_bbox.label)] += 1 - for a_bbox, b_bbox in mispred: - self.bbox_confusion_matrix[(a_bbox.label, b_bbox.label)] += 1 - for a_bbox in a_unmatched: - self.bbox_confusion_matrix[(a_bbox.label, self._UNMATCHED_LABEL)] += 1 - for b_bbox in b_unmatched: - self.bbox_confusion_matrix[(self._UNMATCHED_LABEL, b_bbox.label)] += 1 + @classmethod + def _update_segment_confusion(cls, matrix, diff): + matches, mispred, a_unmatched, b_unmatched = diff + for a_segm, b_segm in matches: + matrix[(a_segm.label, b_segm.label)] += 1 + for a_segm, b_segm in mispred: + matrix[(a_segm.label, b_segm.label)] += 1 + for a_segm in a_unmatched: + matrix[(a_segm.label, cls._UNMATCHED_LABEL)] += 1 + for b_segm in b_unmatched: + matrix[(cls._UNMATCHED_LABEL, b_segm.label)] += 1 + + def update_bbox_confusion(self, diff): + self._update_segment_confusion(self.bbox_confusion_matrix, diff) + + def update_polygon_confusion(self, diff): + self._update_segment_confusion(self.polygon_confusion_matrix, diff) + + def update_mask_confusion(self, diff): + self._update_segment_confusion(self.mask_confusion_matrix, diff) @classmethod def draw_text_with_background(cls, frame, 
text, origin, font=None, scale=1.0, color=(0, 0, 0), thickness=1, bgcolor=(1, 1, 1)): - import cv2 - if not font: font = cv2.FONT_HERSHEY_SIMPLEX @@ -162,8 +188,6 @@ def draw_text_with_background(cls, frame, text, origin, return text_size, baseline def draw_detection_roi(self, frame, x, y, w, h, label, conf, color): - import cv2 - cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2) text = '%s %.2f%%' % (label, 100.0 * conf) @@ -175,17 +199,22 @@ def draw_detection_roi(self, frame, x, y, w, h, label, conf, color): np.array([x, y]) - line_height * 0.5, font, scale=text_scale, color=[255 - c for c in color]) - def get_label(self, label_id): - cat = self.categories.get(AnnotationType.label) + def get_a_label(self, label_id): + return self._get_label(self._a_classes, label_id) + + def get_b_label(self, label_id): + return self._get_label(self._b_classes, label_id) + + @staticmethod + def _get_label(cat: LabelCategories, label_id): if cat is None: return str(label_id) - return cat.items[label_id].name + return cat[label_id].name - def draw_bbox(self, img, shape, color): + def draw_bbox(self, img, shape, label, color): x, y, w, h = shape.get_bbox() self.draw_detection_roi(img, int(x), int(y), int(w), int(h), - self.get_label(shape.label), shape.attributes.get('score', 1), - color) + label, shape.attributes.get('score', 1), color) def get_label_diff_file(self): if self.label_diff_writer is None: @@ -197,43 +226,51 @@ def save_item_label_diff(self, item_a, item_b, diff): _, a_unmatched, b_unmatched = diff if 0 < len(a_unmatched) + len(b_unmatched): - if self.output_format is Format.simple: + if self.output_format is OutputFormat.simple: f = self.get_label_diff_file() f.write(item_a.id + '\n') for a_label in a_unmatched: - f.write(' >%s\n' % self.get_label(a_label)) + f.write(' >%s\n' % self.get_a_label(a_label)) for b_label in b_unmatched: - f.write(' <%s\n' % self.get_label(b_label)) - elif self.output_format is Format.tensorboard: + f.write(' <%s\n' % 
self.get_b_label(b_label)) + elif self.output_format is OutputFormat.tensorboard: tag = item_a.id for a_label in a_unmatched: self.file_writer.add_text(tag, - '>%s\n' % self.get_label(a_label)) + '>%s\n' % self.get_a_label(a_label)) for b_label in b_unmatched: self.file_writer.add_text(tag, - '<%s\n' % self.get_label(b_label)) + '<%s\n' % self.get_b_label(b_label)) def save_item_bbox_diff(self, item_a, item_b, diff): _, mispred, a_unmatched, b_unmatched = diff if 0 < len(a_unmatched) + len(b_unmatched) + len(mispred): + if not item_a.has_image or not item_a.image.has_data: + log.warning("Item %s: item has no image data, " + "it will be skipped" % (item_a.id)) + return img_a = item_a.image.data.copy() img_b = img_a.copy() for a_bbox, b_bbox in mispred: - self.draw_bbox(img_a, a_bbox, (0, 255, 0)) - self.draw_bbox(img_b, b_bbox, (0, 0, 255)) + self.draw_bbox(img_a, a_bbox, self.get_a_label(a_bbox.label), + (0, 255, 0)) + self.draw_bbox(img_b, b_bbox, self.get_b_label(b_bbox.label), + (0, 0, 255)) for a_bbox in a_unmatched: - self.draw_bbox(img_a, a_bbox, (255, 255, 0)) + self.draw_bbox(img_a, a_bbox, self.get_a_label(a_bbox.label), + (255, 255, 0)) for b_bbox in b_unmatched: - self.draw_bbox(img_b, b_bbox, (255, 255, 0)) + self.draw_bbox(img_b, b_bbox, self.get_b_label(b_bbox.label), + (255, 255, 0)) img = np.hstack([img_a, img_b]) - path = osp.join(self.save_dir, item_a.id) + path = osp.join(self.save_dir, item_a.subset, item_a.id) - if self.output_format is Format.simple: + if self.output_format is OutputFormat.simple: save_image(path + '.png', img, create_dir=True) - elif self.output_format is Format.tensorboard: + elif self.output_format is OutputFormat.tensorboard: self.save_as_tensorboard(img, path) def save_as_tensorboard(self, img, name): @@ -245,29 +282,36 @@ def save_as_tensorboard(self, img, name): def save_conf_matrix(self, conf_matrix, filename): import matplotlib.pyplot as plt - classes = None - label_categories = 
self.categories.get(AnnotationType.label) - if label_categories is not None: - classes = { id: c.name for id, c in enumerate(label_categories.items) } - if classes is None: - classes = { c: 'label_%s' % c for c, _ in conf_matrix } - classes[self._UNMATCHED_LABEL] = 'unmatched' - - class_idx = { id: i for i, id in enumerate(classes.keys()) } - matrix = np.zeros((len(classes), len(classes)), dtype=int) + def _get_class_map(label_categories): + classes = None + if label_categories is not None: + classes = { id: c.name + for id, c in enumerate(label_categories.items) } + if classes is None: + classes = { c: 'label_%s' % c for c, _ in conf_matrix } + classes[self._UNMATCHED_LABEL] = 'unmatched' + classes[None] = 'no_class' + return classes + a_classes = _get_class_map(self._a_classes) + b_classes = _get_class_map(self._b_classes) + + a_class_idx = { id: i for i, id in enumerate(a_classes) } + b_class_idx = { id: i for i, id in enumerate(b_classes) } + matrix = np.zeros((len(a_classes), len(b_classes)), dtype=int) for idx_pair in conf_matrix: - index = (class_idx[idx_pair[0]], class_idx[idx_pair[1]]) + index = (a_class_idx[idx_pair[0]], b_class_idx[idx_pair[1]]) matrix[index] = conf_matrix[idx_pair] - labels = [label for id, label in classes.items()] + a_labels = [label for id, label in a_classes.items()] + b_labels = [label for id, label in b_classes.items()] fig = plt.figure() fig.add_subplot(111) table = plt.table( cellText=matrix, - colLabels=labels, - rowLabels=labels, - loc ='center') + rowLabels=a_labels, + colLabels=b_labels, + loc='center') table.auto_set_font_size(False) table.set_fontsize(8) table.scale(3, 3) @@ -278,13 +322,13 @@ def save_conf_matrix(self, conf_matrix, filename): plt.gca().spines[pos].set_visible(False) for idx_pair in conf_matrix: - i = class_idx[idx_pair[0]] - j = class_idx[idx_pair[1]] + i = a_class_idx[idx_pair[0]] + j = b_class_idx[idx_pair[1]] if conf_matrix[idx_pair] != 0: - if i != j: - table._cells[(i + 1, 
j)].set_facecolor('#FF0000') - else: + if a_classes[idx_pair[0]] == b_classes[idx_pair[1]]: table._cells[(i + 1, j)].set_facecolor('#00FF00') + else: + table._cells[(i + 1, j)].set_facecolor('#FF0000') plt.savefig(osp.join(self.save_dir, filename), bbox_inches='tight', pad_inches=0.05) diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index f0a7a2615411..4b3f17ab9b6e 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -98,6 +98,7 @@ def _reindex(self): self._indices = indices def add(self, name: str, parent: str = None, attributes: dict = None): + assert name assert name not in self._indices, name index = len(self.items) @@ -114,6 +115,9 @@ def find(self, name: str): def __getitem__(self, idx): return self.items[idx] + def __contains__(self, idx): + return 0 <= idx and idx < len(self.items) + def __len__(self): return len(self.items) @@ -127,6 +131,11 @@ class Label(Annotation): @attrs(eq=False) class MaskCategories(Categories): + @classmethod + def make_default(cls, size=256): + from datumaro.util.mask_tools import generate_colormap + return cls(generate_colormap(size)) + colormap = attrib(factory=dict, validator=default_if_none(dict)) _inverse_colormap = attrib(default=None, validator=attr.validators.optional(dict)) @@ -626,6 +635,10 @@ def __iter__(self): def __len__(self): return len(self._items) + def get(self, id, subset=None): #pylint: disable=redefined-builtin + assert subset == self._subset, '%s != %s' % (subset, self._subset) + return super().get(id, subset or self._subset) + class Importer: @classmethod def detect(cls, path): diff --git a/datumaro/components/operations.py b/datumaro/components/operations.py index 56305ba3087b..13847838004c 100644 --- a/datumaro/components/operations.py +++ b/datumaro/components/operations.py @@ -15,7 +15,8 @@ from datumaro.components.cli_plugin import CliPlugin from datumaro.util import find, filter_dict -from datumaro.components.extractor import 
(AnnotationType, Bbox, Label, +from datumaro.components.extractor import (AnnotationType, Bbox, + CategoriesInfo, Label, LabelCategories, PointsCategories, MaskCategories) from datumaro.components.errors import (DatumaroError, FailedAttrVotingError, FailedLabelVotingError, MismatchingImageInfoError, NoMatchingAnnError, @@ -1170,29 +1171,6 @@ def get_label(ann): class DistanceComparator: iou_threshold = attrib(converter=float, default=0.5) - @staticmethod - def match_datasets(a, b): - a_items = set((item.id, item.subset) for item in a) - b_items = set((item.id, item.subset) for item in b) - - matches = a_items & b_items - a_unmatched = a_items - b_items - b_unmatched = b_items - a_items - return matches, a_unmatched, b_unmatched - - @staticmethod - def match_classes(a, b): - a_label_cat = a.categories().get(AnnotationType.label, LabelCategories()) - b_label_cat = b.categories().get(AnnotationType.label, LabelCategories()) - - a_labels = set(c.name for c in a_label_cat) - b_labels = set(c.name for c in b_label_cat) - - matches = a_labels & b_labels - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - return matches, a_unmatched, b_unmatched - def match_annotations(self, item_a, item_b): return { t: self._match_ann_type(t, item_a, item_b) for t in AnnotationType @@ -1319,6 +1297,18 @@ def _default_hash(item): unique.setdefault(h, set()).add((item.id, item.subset)) return unique +def match_classes(a: CategoriesInfo, b: CategoriesInfo): + a_label_cat = a.get(AnnotationType.label, LabelCategories()) + b_label_cat = b.get(AnnotationType.label, LabelCategories()) + + a_labels = set(c.name for c in a_label_cat) + b_labels = set(c.name for c in b_label_cat) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + @attrs class ExactComparator: match_images = attrib(kw_only=True, type=bool, default=False) diff --git a/datumaro/components/project.py 
b/datumaro/components/project.py index 22639db54ab9..49d9bd850feb 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT -from collections import defaultdict, OrderedDict +from collections import OrderedDict import logging as log import os import os.path as osp @@ -16,15 +16,17 @@ XPathDatasetFilter) from datumaro.components.environment import Environment from datumaro.components.errors import DatumaroError -from datumaro.components.extractor import Extractor +from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Extractor from datumaro.components.launcher import ModelTransform from datumaro.components.operations import ExactMerge class ProjectDataset(IDataset): class Subset(Extractor): - def __init__(self, parent): + def __init__(self, parent, name): + super().__init__(subsets=[name]) self.parent = parent + self.name = name or DEFAULT_SUBSET_NAME self.items = OrderedDict() def __iter__(self): @@ -36,6 +38,11 @@ def __len__(self): def categories(self): return self.parent.categories() + def get(self, id, subset=None): #pylint: disable=redefined-builtin + subset = subset or self.name + assert subset == self.name, '%s != %s' % (subset, self.name) + return super().get(id, subset) + def __init__(self, project): super().__init__() @@ -70,11 +77,13 @@ def __init__(self, project): self._categories = categories # merge items - subsets = defaultdict(lambda: self.Subset(self)) + subsets = {} for source_name, source in self._sources.items(): log.debug("Loading '%s' source contents..." % source_name) for item in source: - existing_item = subsets[item.subset].items.get(item.id) + existing_item = subsets.setdefault( + item.subset, self.Subset(self, item.subset)). 
\ + items.get(item.id) if existing_item is not None: path = existing_item.path if item.path != path: @@ -96,18 +105,16 @@ def __init__(self, project): if own_source is not None: log.debug("Loading own dataset...") for item in own_source: - existing_item = subsets[item.subset].items.get(item.id) + existing_item = subsets.setdefault( + item.subset, self.Subset(self, item.subset)). \ + items.get(item.id) if existing_item is not None: item = item.wrap(path=None, image=ExactMerge.merge_images(existing_item, item)) subsets[item.subset].items[item.id] = item - # TODO: implement subset remapping when needed - subsets_filter = config.subsets - if len(subsets_filter) != 0: - subsets = { k: v for k, v in subsets.items() if k in subsets_filter} - self._subsets = dict(subsets) + self._subsets = subsets self._length = None @@ -154,7 +161,7 @@ def put(self, item, id=None, subset=None, \ item = item.wrap(path=path) if subset not in self._subsets: - self._subsets[subset] = self.Subset(self) + self._subsets[subset] = self.Subset(self, subset) self._subsets[subset].items[id] = item self._length = None diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/cli/test_diff.py b/tests/cli/test_diff.py new file mode 100644 index 000000000000..591b117119f8 --- /dev/null +++ b/tests/cli/test_diff.py @@ -0,0 +1,123 @@ +from unittest import TestCase + +import os +import os.path as osp + +import numpy as np + +from datumaro.cli.contexts.project.diff import DatasetDiffVisualizer +from datumaro.components.operations import DistanceComparator +from datumaro.components.project import Dataset +from datumaro.components.extractor import (DatasetItem, + AnnotationType, Label, Mask, Points, Polygon, + PolyLine, Bbox, Caption, + LabelCategories, MaskCategories, PointsCategories +) +from datumaro.util.image import Image +from datumaro.util.test_utils import TestDir + + +class DiffTest(TestCase): + def 
test_can_compare_projects(self): # just a smoke test + label_categories1 = LabelCategories.from_iterable(['x', 'a', 'b', 'y']) + mask_categories1 = MaskCategories.make_default(len(label_categories1)) + + point_categories1 = PointsCategories() + for index, _ in enumerate(label_categories1.items): + point_categories1.add(index, ['cat1', 'cat2'], joints=[[0, 1]]) + + dataset1 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Caption('hello', id=1), + Caption('world', id=2, group=5), + Label(2, id=3, attributes={ + 'x': 1, + 'y': '2', + }), + Bbox(1, 2, 3, 4, label=0, id=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, id=5, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, id=5, z_order=4), + Mask(label=3, id=5, z_order=2, image=np.ones((2, 3))), + ]), + DatasetItem(id=21, subset='train', + annotations=[ + Caption('test'), + Label(2), + Bbox(1, 2, 3, 4, label=2, id=42, group=42) + ]), + + DatasetItem(id=2, subset='val', + annotations=[ + PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11, z_order=1), + Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4), + ]), + + DatasetItem(id=42, subset='test', + attributes={'a1': 5, 'a2': '42'}), + + DatasetItem(id=42), + DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))), + ], categories={ + AnnotationType.label: label_categories1, + AnnotationType.mask: mask_categories1, + AnnotationType.points: point_categories1, + }) + + + label_categories2 = LabelCategories.from_iterable(['a', 'b', 'x', 'y']) + mask_categories2 = MaskCategories.make_default(len(label_categories2)) + + point_categories2 = PointsCategories() + for index, _ in enumerate(label_categories2.items): + point_categories2.add(index, ['cat1', 'cat2'], joints=[[0, 1]]) + + dataset2 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Caption('hello', id=1), + Caption('world', id=2, group=5), + Label(2, id=3, attributes={ + 'x': 1, + 'y': 
'2', + }), + Bbox(1, 2, 3, 4, label=1, id=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, id=5, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, id=5, z_order=4), + Mask(label=3, id=5, z_order=2, image=np.ones((2, 3))), + ]), + DatasetItem(id=21, subset='train', + annotations=[ + Caption('test'), + Label(2), + Bbox(1, 2, 3, 4, label=3, id=42, group=42) + ]), + + DatasetItem(id=2, subset='val', + annotations=[ + PolyLine([1, 2, 3, 4, 5, 6, 7, 8], id=11, z_order=1), + Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4), + ]), + + DatasetItem(id=42, subset='test', + attributes={'a1': 5, 'a2': '42'}), + + DatasetItem(id=42), + DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))), + ], categories={ + AnnotationType.label: label_categories2, + AnnotationType.mask: mask_categories2, + AnnotationType.points: point_categories2, + }) + + with TestDir() as test_dir: + with DatasetDiffVisualizer(save_dir=test_dir, + comparator=DistanceComparator(iou_threshold=0.8), + ) as visualizer: + visualizer.save(dataset1, dataset2) + + self.assertNotEqual(0, os.listdir(osp.join(test_dir))) \ No newline at end of file