Skip to content

Commit

Permalink
Introduce Validator plugin type (cvat-ai#299)
Browse files Browse the repository at this point in the history
* Introduce Validator plugin type
  • Loading branch information
chuneuny-emily authored Jun 19, 2021
1 parent e8e305b commit 5cd4abb
Show file tree
Hide file tree
Showing 7 changed files with 1,349 additions and 1,290 deletions.
34 changes: 19 additions & 15 deletions datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from datumaro.components.project import \
PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG
from datumaro.components.project import Environment, Project
from datumaro.components.validator import validate_annotations, TaskType
from datumaro.components.validator import Validator, TaskType
from datumaro.util import error_rollback

from ...util import (CliException, MultilineFormatter, add_subparser,
Expand Down Expand Up @@ -801,8 +801,7 @@ def build_validate_parser(parser_ctor=argparse.ArgumentParser):
""",
formatter_class=MultilineFormatter)

parser.add_argument('task_type',
choices=[task_type.name for task_type in TaskType],
parser.add_argument('-t', '--task_type', choices=[task_type.name for task_type in TaskType],
help="Task type for validation")
parser.add_argument('-s', '--subset', dest='subset_name', default=None,
help="Subset to validate (default: None)")
Expand All @@ -816,19 +815,24 @@ def build_validate_parser(parser_ctor=argparse.ArgumentParser):

def validate_command(args):
project = load_project(args.project_dir)
task_type = args.task_type
subset_name = args.subset_name
dst_file_name = f'validation_results-{task_type}'
dst_file_name = f'report-{args.task_type}'

dataset = project.make_dataset()
if subset_name is not None:
dataset = dataset.get_subset(subset_name)
dst_file_name += f'-{subset_name}'
if args.subset_name is not None:
dataset = dataset.get_subset(args.subset_name)
dst_file_name += f'-{args.subset_name}'

try:
validator_type = project.env.validators[args.task_type]
except KeyError:
raise CliException("Validator type '%s' is not found" % args.task_type)

extra_args = {}
from datumaro.components.validator import _Validator
extra_args = _Validator.parse_cmdline(args.extra_args)
validation_results = validate_annotations(dataset, task_type, **extra_args)
if hasattr(validator_type, 'parse_cmdline'):
extra_args = validator_type.parse_cmdline(args.extra_args)

validator = validator_type(**extra_args)
report = validator.validate(dataset)

def numpy_encoder(obj):
if isinstance(obj, np.generic):
Expand All @@ -843,12 +847,12 @@ def _make_serializable(d):
if isinstance(val, dict):
_make_serializable(val)

_make_serializable(validation_results)
_make_serializable(report)

dst_file = generate_next_file_name(dst_file_name, ext='.json')
log.info("Writing project validation results to '%s'" % dst_file)
with open(dst_file, 'w') as f:
json.dump(validation_results, f, indent=4, sort_keys=True,
json.dump(report, f, indent=4, sort_keys=True,
default=numpy_encoder)

def build_parser(parser_ctor=argparse.ArgumentParser):
Expand All @@ -875,4 +879,4 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
add_subparser(subparsers, 'stats', build_stats_parser)
add_subparser(subparsers, 'validate', build_validate_parser)

return parser
return parser
2 changes: 1 addition & 1 deletion datumaro/components/cli_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ def parse_cmdline(cls, args=None):
return args

def remove_plugin_type(s):
for t in {'transform', 'extractor', 'converter', 'launcher', 'importer'}:
for t in {'transform', 'extractor', 'converter', 'launcher', 'importer', 'validator'}:
s = s.replace('_' + t, '')
return s
8 changes: 7 additions & 1 deletion datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(self, config=None):
from datumaro.components.extractor import (Importer, Extractor,
Transform)
from datumaro.components.launcher import Launcher
from datumaro.components.validator import Validator
self.extractors = PluginRegistry(
builtin=select(builtin, Extractor),
local=select(custom, Extractor)
Expand All @@ -172,6 +173,10 @@ def __init__(self, config=None):
builtin=select(builtin, Transform),
local=select(custom, Transform)
)
self.validators = PluginRegistry(
builtin=select(builtin, Validator),
local=select(custom, Validator)
)

@staticmethod
def _find_plugins(plugins_dir):
Expand Down Expand Up @@ -262,7 +267,8 @@ def _load_plugins2(cls, plugins_dir):
from datumaro.components.extractor import (Extractor, Importer,
Transform)
from datumaro.components.launcher import Launcher
types = [Extractor, Converter, Importer, Launcher, Transform]
from datumaro.components.validator import Validator
types = [Extractor, Converter, Importer, Launcher, Transform, Validator]

return cls._load_plugins(plugins_dir, types)

Expand Down
Loading

0 comments on commit 5cd4abb

Please sign in to comment.