From 9eb5246febb4b4ae10c321fae80413bb87fb1a7d Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 19 Jun 2021 09:49:21 +0300 Subject: [PATCH] Fix validator definitions (#303) * update changelog * Fixes in validator definitions * Update validator cli --- CHANGELOG.md | 1 + datumaro/cli/contexts/project/__init__.py | 17 ++- datumaro/components/validator.py | 52 +++++++- datumaro/plugins/validators.py | 151 ++++------------------ tests/test_validator.py | 3 +- 5 files changed, 85 insertions(+), 139 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32c27abf99..c17e398a02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Subformat importers for VOC and COCO () - Support for KITTI dataset segmentation and detection format () - Updated YOLO format user manual () +- A base class for dataset validation plugins () ### Changed - diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index e0c88520a1..43e76542a8 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -18,7 +18,7 @@ from datumaro.components.project import \ PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG from datumaro.components.project import Environment, Project -from datumaro.components.validator import Validator, TaskType +from datumaro.components.validator import TaskType from datumaro.util import error_rollback from ...util import (CliException, MultilineFormatter, add_subparser, @@ -794,6 +794,14 @@ def print_extractor_info(extractor, indent=''): return 0 def build_validate_parser(parser_ctor=argparse.ArgumentParser): + def _parse_task_type(s): + try: + return TaskType[s.lower()].name + except: + raise argparse.ArgumentTypeError("Unknown task type %s. Expected " + "one of: %s" % (s, ', '.join(t.name for t in TaskType))) + + parser = parser_ctor(help="Validate project", description=""" Validates project based on specified task type and stores @@ -801,10 +809,11 @@ def build_validate_parser(parser_ctor=argparse.ArgumentParser): """, formatter_class=MultilineFormatter) - parser.add_argument('-t', '--task_type', choices=[task_type.name for task_type in TaskType], - help="Task type for validation") + parser.add_argument('-t', '--task_type', type=_parse_task_type, + help="Task type for validation, one of %s" % \ + ', '.join(t.name for t in TaskType)) parser.add_argument('-s', '--subset', dest='subset_name', default=None, - help="Subset to validate (default: None)") + help="Subset to validate (default: whole dataset)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to validate (default: current dir)") parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, diff --git a/datumaro/components/validator.py b/datumaro/components/validator.py index b0db512a6e..fa8bc4441a 100644 --- a/datumaro/components/validator.py +++ b/datumaro/components/validator.py @@ -19,17 +19,55 @@ class TaskType(Enum): segmentation = auto() -class IValidator: +class Validator: def validate(self, dataset: IDataset) -> Dict: - raise NotImplementedError() + """ + Returns the validation results of a dataset based on task type. + Args: + dataset (IDataset): Dataset to be validated -class Validator(IValidator): - def validate(self, dataset: IDataset) -> Dict: - raise NotImplementedError() + Raises: + ValueError + + Returns: + validation_results (dict): + Dict with validation statistics, reports and summary. + """ + + validation_results = {} + if not isinstance(dataset, IDataset): + raise TypeError("Invalid dataset type '%s'" % type(dataset)) + + # generate statistics + stats = self.compute_statistics(dataset) + validation_results['statistics'] = stats + + # generate validation reports and summary + reports = self.generate_reports(stats) + reports = list(map(lambda r: r.to_dict(), reports)) + + summary = { + 'errors': sum(map(lambda r: r['severity'] == 'error', reports)), + 'warnings': sum(map(lambda r: r['severity'] == 'warning', reports)) + } + + validation_results['validation_reports'] = reports + validation_results['summary'] = summary + + return validation_results def compute_statistics(self, dataset: IDataset) -> Dict: - raise NotImplementedError() + """ + Computes statistics of the dataset based on task type. + + Args: + dataset (IDataset): a dataset to be validated + + Returns: + stats (dict): A dict object containing statistics of the dataset. + """ + raise NotImplementedError("Must be implemented in a subclass") def generate_reports(self, stats: Dict) -> List[Dict]: - raise NotImplementedError() + raise NotImplementedError("Must be implemented in a subclass") diff --git a/datumaro/plugins/validators.py b/datumaro/plugins/validators.py index ce171c8208..c5f936c029 100644 --- a/datumaro/plugins/validators.py +++ b/datumaro/plugins/validators.py @@ -3,16 +3,11 @@ # SPDX-License-Identifier: MIT from copy import deepcopy -from typing import Dict, List - -import json -import logging as log import numpy as np from datumaro.components.validator import (Severity, TaskType, Validator) from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.dataset import IDataset from datumaro.components.errors import (MissingLabelCategories, MissingAnnotation, MultiLabelAnnotations, MissingAttribute, UndefinedLabel, UndefinedAttribute, LabelDefinedButNotFound, @@ -25,7 +20,7 @@ from datumaro.util import parse_str_enum_value -class _TaskValidator(Validator): +class _TaskValidator(Validator, CliPlugin): # statistics templates numerical_stat_template = { 'items_far_from_mean': {}, @@ -48,17 +43,28 @@ class _TaskValidator(Validator): ---------- task_type : str or TaskType task type (ie. classification, detection, segmentation) - - Methods - ------- - validate(dataset): - Validate annotations based on task type. - compute_statistics(dataset): - Computes various statistics of the dataset based on task type. - generate_reports(stats): - Abstract method that must be implemented in a subclass. """ + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, + help="Threshold for giving a warning for minimum number of" + "samples per class") + parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, + help="Threshold for giving data imbalance warning;" + "IR(imbalance ratio) = majority/minority") + parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, + help="Threshold for giving a warning that data is far from mean;" + "A constant used to define mean +/- k * standard deviation;") + parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, + help="Threshold for giving a warning for bounding box imbalance;" + "Dominace_ratio = ratio of Top-k bin to total in histogram;") + parser.add_argument('-k', '--topk_bins', default=0.1, type=float, + help="Ratio of bins with the highest number of data" + "to total bins in the histogram; [0, 1]; 0.1 = 10%;") + return parser + def __init__(self, task_type, few_samples_thr=None, imbalance_ratio_thr=None, far_from_mean_thr=None, dominance_ratio_thr=None, topk_bins=None): @@ -102,41 +108,6 @@ def __init__(self, task_type, few_samples_thr=None, self.dominance_thr = dominance_ratio_thr self.topk_bins_ratio = topk_bins - def validate(self, dataset: IDataset): - """ - Returns the validation results of a dataset based on task type. - Args: - dataset (IDataset): Dataset to be validated - task_type (str or TaskType): Type of the task - (classification, detection, segmentation) - Raises: - ValueError - Returns: - validation_results (dict): - Dict with validation statistics, reports and summary. - """ - validation_results = {} - if not isinstance(dataset, IDataset): - raise TypeError("Invalid dataset type '%s'" % type(dataset)) - - # generate statistics - stats = self.compute_statistics(dataset) - validation_results['statistics'] = stats - - # generate validation reports and summary - reports = self.generate_reports(stats) - reports = list(map(lambda r: r.to_dict(), reports)) - - summary = { - 'errors': sum(map(lambda r: r['severity'] == 'error', reports)), - 'warnings': sum(map(lambda r: r['severity'] == 'warning', reports)) - } - - validation_results['validation_reports'] = reports - validation_results['summary'] = summary - - return validation_results - def _compute_common_statistics(self, dataset): defined_attr_template = { 'items_missing_attribute': [], @@ -285,20 +256,6 @@ def _far_from_mean(val, mean, stdev): item_key, {}) far_from_mean[ann.id] = val - def compute_statistics(self, dataset: IDataset): - """ - Computes statistics of the dataset based on task type. - - Parameters - ---------- - dataset : IDataset object - - Returns - ------- - stats (dict): A dict object containing statistics of the dataset. - """ - return NotImplementedError - def _check_missing_label_categories(self, stats): validation_reports = [] @@ -578,36 +535,14 @@ def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats): return validation_reports - def generate_reports(self, stats: Dict) -> List[Dict]: - raise NotImplementedError('Should be implemented in a subclass.') - def _generate_validation_report(self, error, *args, **kwargs): return [error(*args, **kwargs)] -class ClassificationValidator(_TaskValidator, CliPlugin): +class ClassificationValidator(_TaskValidator): """ A specific validator class for classification task. """ - @classmethod - def build_cmdline_parser(cls, **kwargs): - parser = super().build_cmdline_parser(**kwargs) - parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, - help="Threshold for giving a warning for minimum number of" - "samples per class") - parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, - help="Threshold for giving data imbalance warning;" - "IR(imbalance ratio) = majority/minority") - parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, - help="Threshold for giving a warning that data is far from mean;" - "A constant used to define mean +/- k * standard deviation;") - parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, - help="Threshold for giving a warning for bounding box imbalance;" - "Dominace_ratio = ratio of Top-k bin to total in histogram;") - parser.add_argument('-k', '--topk_bins', default=0.1, type=float, - help="Ratio of bins with the highest number of data" - "to total bins in the histogram; [0, 1]; 0.1 = 10%;") - return parser def __init__(self, few_samples_thr, imbalance_ratio_thr, far_from_mean_thr, dominance_ratio_thr, topk_bins): @@ -709,29 +644,10 @@ def generate_reports(self, stats): return reports -class DetectionValidator(_TaskValidator, CliPlugin): +class DetectionValidator(_TaskValidator): """ A specific validator class for detection task. """ - @classmethod - def build_cmdline_parser(cls, **kwargs): - parser = super().build_cmdline_parser(**kwargs) - parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, - help="Threshold for giving a warning for minimum number of" - "samples per class") - parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, - help="Threshold for giving data imbalance warning;" - "IR(imbalance ratio) = majority/minority") - parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, - help="Threshold for giving a warning that data is far from mean;" - "A constant used to define mean +/- k * standard deviation;") - parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, - help="Threshold for giving a warning for bounding box imbalance;" - "Dominace_ratio = ratio of Top-k bin to total in histogram;") - parser.add_argument('-k', '--topk_bins', default=0.1, type=float, - help="Ratio of bins with the highest number of data" - "to total bins in the histogram; [0, 1]; 0.1 = 10%;") - return parser def __init__(self, few_samples_thr, imbalance_ratio_thr, far_from_mean_thr, dominance_ratio_thr, topk_bins): @@ -1014,29 +930,10 @@ def generate_reports(self, stats): return reports -class SegmentationValidator(_TaskValidator, CliPlugin): +class SegmentationValidator(_TaskValidator): """ A specific validator class for (instance) segmentation task. """ - @classmethod - def build_cmdline_parser(cls, **kwargs): - parser = super().build_cmdline_parser(**kwargs) - parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, - help="Threshold for giving a warning for minimum number of" - "samples per class") - parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, - help="Threshold for giving data imbalance warning;" - "IR(imbalance ratio) = majority/minority") - parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, - help="Threshold for giving a warning that data is far from mean;" - "A constant used to define mean +/- k * standard deviation;") - parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, - help="Threshold for giving a warning for bounding box imbalance;" - "Dominace_ratio = ratio of Top-k bin to total in histogram;") - parser.add_argument('-k', '--topk_bins', default=0.1, type=float, - help="Ratio of bins with the highest number of data" - "to total bins in the histogram; [0, 1]; 0.1 = 10%;") - return parser def __init__(self, few_samples_thr, imbalance_ratio_thr, far_from_mean_thr, dominance_ratio_thr, topk_bins): diff --git a/tests/test_validator.py b/tests/test_validator.py index 7f855e4cb8..22d47795ef 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -18,7 +18,8 @@ FarFromAttrMean, OnlyOneAttributeValue) from datumaro.components.extractor import Bbox, Label, Mask, Polygon from datumaro.components.validator import TaskType -from datumaro.plugins.validators import (_TaskValidator, ClassificationValidator, DetectionValidator, SegmentationValidator) +from datumaro.plugins.validators import (_TaskValidator, + ClassificationValidator, DetectionValidator, SegmentationValidator) from .requirements import Requirements, mark_requirement