Skip to content

Commit

Permalink
Fix validator definitions (#303)
Browse files Browse the repository at this point in the history
* update changelog

* Fixes in validator definitions

* Update validator cli
  • Loading branch information
Maxim Zhiltsov authored Jun 19, 2021
1 parent 5cd4abb commit 9eb5246
Showing 5 changed files with 85 additions and 139 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Subformat importers for VOC and COCO (<https://github.com/openvinotoolkit/datumaro/pull/281>)
- Support for KITTI dataset segmentation and detection format (<https://github.com/openvinotoolkit/datumaro/pull/282>)
- Updated YOLO format user manual (<https://github.com/openvinotoolkit/datumaro/pull/295>)
- A base class for dataset validation plugins (<https://github.com/openvinotoolkit/datumaro/pull/299>)

### Changed
-
17 changes: 13 additions & 4 deletions datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
from datumaro.components.project import \
PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.project import Environment, Project
from datumaro.components.validator import Validator, TaskType
from datumaro.components.validator import TaskType
from datumaro.util import error_rollback

from ...util import (CliException, MultilineFormatter, add_subparser,
@@ -794,17 +794,26 @@ def print_extractor_info(extractor, indent=''):
return 0

def build_validate_parser(parser_ctor=argparse.ArgumentParser):
def _parse_task_type(s):
    """
    Converts a command-line string into a canonical task type name.

    Intended to be used as an argparse ``type=`` callback. The lookup
    is case-insensitive: the input is lowercased before being matched
    against the TaskType member names.

    Args:
        s (str): the task type as typed by the user

    Raises:
        argparse.ArgumentTypeError: if 's' is not a TaskType member name

    Returns:
        The matching TaskType member name.
    """
    try:
        return TaskType[s.lower()].name
    except KeyError:
        # Only a failed enum lookup means "unknown task type". A bare
        # 'except' would also hide KeyboardInterrupt and real bugs.
        raise argparse.ArgumentTypeError("Unknown task type %s. Expected "
            "one of: %s" % (s, ', '.join(t.name for t in TaskType))) from None


parser = parser_ctor(help="Validate project",
description="""
Validates project based on specified task type and stores
results like statistics, reports and summary in JSON file.
""",
formatter_class=MultilineFormatter)

parser.add_argument('-t', '--task_type', choices=[task_type.name for task_type in TaskType],
help="Task type for validation")
parser.add_argument('-t', '--task_type', type=_parse_task_type,
help="Task type for validation, one of %s" % \
', '.join(t.name for t in TaskType))
parser.add_argument('-s', '--subset', dest='subset_name', default=None,
help="Subset to validate (default: None)")
help="Subset to validate (default: whole dataset)")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the project to validate (default: current dir)")
parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None,
52 changes: 45 additions & 7 deletions datumaro/components/validator.py
Original file line number Diff line number Diff line change
@@ -19,17 +19,55 @@ class TaskType(Enum):
segmentation = auto()


class IValidator:
class Validator:
def validate(self, dataset: IDataset) -> Dict:
raise NotImplementedError()
"""
Returns the validation results of a dataset based on task type.
Args:
dataset (IDataset): Dataset to be validated
class Validator(IValidator):
def validate(self, dataset: IDataset) -> Dict:
raise NotImplementedError()
Raises:
TypeError: if the dataset is not an IDataset instance
Returns:
validation_results (dict):
Dict with validation statistics, reports and summary.
"""

validation_results = {}
if not isinstance(dataset, IDataset):
raise TypeError("Invalid dataset type '%s'" % type(dataset))

# generate statistics
stats = self.compute_statistics(dataset)
validation_results['statistics'] = stats

# generate validation reports and summary
reports = self.generate_reports(stats)
reports = list(map(lambda r: r.to_dict(), reports))

summary = {
'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
}

validation_results['validation_reports'] = reports
validation_results['summary'] = summary

return validation_results

def compute_statistics(self, dataset: IDataset) -> Dict:
raise NotImplementedError()
"""
Computes statistics of the dataset based on task type.
Args:
dataset (IDataset): a dataset to be validated
Returns:
stats (dict): A dict object containing statistics of the dataset.
"""
raise NotImplementedError("Must be implemented in a subclass")

def generate_reports(self, stats: Dict) -> List[Dict]:
raise NotImplementedError()
raise NotImplementedError("Must be implemented in a subclass")
151 changes: 24 additions & 127 deletions datumaro/plugins/validators.py
Original file line number Diff line number Diff line change
@@ -3,16 +3,11 @@
# SPDX-License-Identifier: MIT

from copy import deepcopy
from typing import Dict, List

import json
import logging as log

import numpy as np

from datumaro.components.validator import (Severity, TaskType, Validator)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import IDataset
from datumaro.components.errors import (MissingLabelCategories,
MissingAnnotation, MultiLabelAnnotations, MissingAttribute,
UndefinedLabel, UndefinedAttribute, LabelDefinedButNotFound,
@@ -25,7 +20,7 @@
from datumaro.util import parse_str_enum_value


class _TaskValidator(Validator):
class _TaskValidator(Validator, CliPlugin):
# statistics templates
numerical_stat_template = {
'items_far_from_mean': {},
@@ -48,17 +43,28 @@ class _TaskValidator(Validator):
----------
task_type : str or TaskType
task type (ie. classification, detection, segmentation)
Methods
-------
validate(dataset):
Validate annotations based on task type.
compute_statistics(dataset):
Computes various statistics of the dataset based on task type.
generate_reports(stats):
Abstract method that must be implemented in a subclass.
"""

@classmethod
def build_cmdline_parser(cls, **kwargs):
    """
    Builds the command-line parser with the threshold options
    accepted by the task validators.

    The parser is created by the parent class implementation and is
    extended with the following options: few_samples_thr,
    imbalance_ratio_thr, far_from_mean_thr, dominance_ratio_thr
    and topk_bins.

    Args:
        **kwargs: forwarded unchanged to the parent implementation

    Returns:
        The parser with the validator options added.
    """
    parser = super().build_cmdline_parser(**kwargs)

    # NOTE: adjacent string literals are joined without any separator,
    # so each continued line must end with an explicit space.
    parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
        help="Threshold for giving a warning for minimum number of "
            "samples per class")
    parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
        help="Threshold for giving data imbalance warning; "
            "IR(imbalance ratio) = majority/minority")
    parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
        help="Threshold for giving a warning that data is far from mean; "
            "A constant used to define mean +/- k * standard deviation;")
    parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
        help="Threshold for giving a warning for bounding box imbalance; "
            "Dominance_ratio = ratio of Top-k bin to total in histogram;")
    # argparse applies %-formatting to help strings, so a literal
    # percent sign must be escaped as '%%' or '--help' fails.
    parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
        help="Ratio of bins with the highest number of data "
            "to total bins in the histogram; [0, 1]; 0.1 = 10%%;")
    return parser

def __init__(self, task_type, few_samples_thr=None,
imbalance_ratio_thr=None, far_from_mean_thr=None,
dominance_ratio_thr=None, topk_bins=None):
@@ -102,41 +108,6 @@ def __init__(self, task_type, few_samples_thr=None,
self.dominance_thr = dominance_ratio_thr
self.topk_bins_ratio = topk_bins

def validate(self, dataset: IDataset):
"""
Returns the validation results of a dataset based on task type.
Args:
dataset (IDataset): Dataset to be validated
task_type (str or TaskType): Type of the task
(classification, detection, segmentation)
Raises:
ValueError
Returns:
validation_results (dict):
Dict with validation statistics, reports and summary.
"""
validation_results = {}
if not isinstance(dataset, IDataset):
raise TypeError("Invalid dataset type '%s'" % type(dataset))

# generate statistics
stats = self.compute_statistics(dataset)
validation_results['statistics'] = stats

# generate validation reports and summary
reports = self.generate_reports(stats)
reports = list(map(lambda r: r.to_dict(), reports))

summary = {
'errors': sum(map(lambda r: r['severity'] == 'error', reports)),
'warnings': sum(map(lambda r: r['severity'] == 'warning', reports))
}

validation_results['validation_reports'] = reports
validation_results['summary'] = summary

return validation_results

def _compute_common_statistics(self, dataset):
defined_attr_template = {
'items_missing_attribute': [],
@@ -285,20 +256,6 @@ def _far_from_mean(val, mean, stdev):
item_key, {})
far_from_mean[ann.id] = val

def compute_statistics(self, dataset: IDataset):
"""
Computes statistics of the dataset based on task type.
Parameters
----------
dataset : IDataset object
Returns
-------
stats (dict): A dict object containing statistics of the dataset.
"""
return NotImplementedError

def _check_missing_label_categories(self, stats):
validation_reports = []

@@ -578,36 +535,14 @@ def _check_far_from_attr_mean(self, label_name, attr_name, attr_stats):

return validation_reports

def generate_reports(self, stats: Dict) -> List[Dict]:
raise NotImplementedError('Should be implemented in a subclass.')

def _generate_validation_report(self, error, *args, **kwargs):
return [error(*args, **kwargs)]


class ClassificationValidator(_TaskValidator, CliPlugin):
class ClassificationValidator(_TaskValidator):
"""
A specific validator class for classification task.
"""
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
help="Threshold for giving a warning for minimum number of"
"samples per class")
parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
help="Threshold for giving data imbalance warning;"
"IR(imbalance ratio) = majority/minority")
parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
help="Threshold for giving a warning that data is far from mean;"
"A constant used to define mean +/- k * standard deviation;")
parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
help="Threshold for giving a warning for bounding box imbalance;"
"Dominace_ratio = ratio of Top-k bin to total in histogram;")
parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
help="Ratio of bins with the highest number of data"
"to total bins in the histogram; [0, 1]; 0.1 = 10%;")
return parser

def __init__(self, few_samples_thr, imbalance_ratio_thr,
far_from_mean_thr, dominance_ratio_thr, topk_bins):
@@ -709,29 +644,10 @@ def generate_reports(self, stats):
return reports


class DetectionValidator(_TaskValidator, CliPlugin):
class DetectionValidator(_TaskValidator):
"""
A specific validator class for detection task.
"""
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
help="Threshold for giving a warning for minimum number of"
"samples per class")
parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
help="Threshold for giving data imbalance warning;"
"IR(imbalance ratio) = majority/minority")
parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
help="Threshold for giving a warning that data is far from mean;"
"A constant used to define mean +/- k * standard deviation;")
parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
help="Threshold for giving a warning for bounding box imbalance;"
"Dominace_ratio = ratio of Top-k bin to total in histogram;")
parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
help="Ratio of bins with the highest number of data"
"to total bins in the histogram; [0, 1]; 0.1 = 10%;")
return parser

def __init__(self, few_samples_thr, imbalance_ratio_thr,
far_from_mean_thr, dominance_ratio_thr, topk_bins):
@@ -1014,29 +930,10 @@ def generate_reports(self, stats):
return reports


class SegmentationValidator(_TaskValidator, CliPlugin):
class SegmentationValidator(_TaskValidator):
"""
A specific validator class for (instance) segmentation task.
"""
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-fs', '--few_samples_thr', default=1, type=int,
help="Threshold for giving a warning for minimum number of"
"samples per class")
parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int,
help="Threshold for giving data imbalance warning;"
"IR(imbalance ratio) = majority/minority")
parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float,
help="Threshold for giving a warning that data is far from mean;"
"A constant used to define mean +/- k * standard deviation;")
parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float,
help="Threshold for giving a warning for bounding box imbalance;"
"Dominace_ratio = ratio of Top-k bin to total in histogram;")
parser.add_argument('-k', '--topk_bins', default=0.1, type=float,
help="Ratio of bins with the highest number of data"
"to total bins in the histogram; [0, 1]; 0.1 = 10%;")
return parser

def __init__(self, few_samples_thr, imbalance_ratio_thr,
far_from_mean_thr, dominance_ratio_thr, topk_bins):
3 changes: 2 additions & 1 deletion tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,8 @@
FarFromAttrMean, OnlyOneAttributeValue)
from datumaro.components.extractor import Bbox, Label, Mask, Polygon
from datumaro.components.validator import TaskType
from datumaro.plugins.validators import (_TaskValidator, ClassificationValidator, DetectionValidator, SegmentationValidator)
from datumaro.plugins.validators import (_TaskValidator,
ClassificationValidator, DetectionValidator, SegmentationValidator)
from .requirements import Requirements, mark_requirement


0 comments on commit 9eb5246

Please sign in to comment.