diff --git a/datumaro/cli/__main__.py b/datumaro/cli/__main__.py index 2ecf9f7a78..d8dcc4154b 100644 --- a/datumaro/cli/__main__.py +++ b/datumaro/cli/__main__.py @@ -56,22 +56,47 @@ def make_parser(): parser.add_argument('--version', action='version', version=VERSION) _LogManager._define_loglevel_option(parser) + parser.add_argument('--detached', action='store_true', + help=argparse.SUPPRESS) + # help="Work in VCS-detached mode. VCS operations will be unavailable.") known_contexts = [ - ('project', contexts.project, "Actions with project (deprecated)"), + ('project', contexts.project, "Actions with project"), + ('repo', contexts.repository, "Actions with VCS repositories"), + ('remote', contexts.remote, "Actions with data remotes"), ('source', contexts.source, "Actions with data sources"), ('model', contexts.model, "Actions with models"), ] known_commands = [ - ('create', commands.create, "Create project"), + ("Project modification:", None, ''), + ('create', commands.create, "Create empty project"), ('import', commands.import_, "Create project from existing dataset"), ('add', commands.add, "Add data source to project"), ('remove', commands.remove, "Remove data source from project"), + + ("", None, ''), + ("Project versioning:", None, ''), + ('check_updates', commands.check_updates, "Check remote repository for updates"), + ('fetch', commands.fetch, "Fetch updates from remote repository"), + ('pull', commands.pull, "Pull updates from remote repository"), + ('push', commands.push, "Push updates to remote repository"), + ('checkout', commands.checkout, "Switch to another branch or revision"), + ('commit', commands.commit, "Commit changes in tracked files"), + ('status', commands.status, "Show status information"), + ('refs', commands.refs, "List branches and revisions"), + ('tag', commands.tag, "Give name to revision"), + ('track', commands.track, "Start tracking a local file or directory"), + ('update', commands.update, "Change data source revision"), + + ("", None, ''), + ("Dataset and project operations:", None, ''), ('export', commands.export, "Export project in some format"), - ('filter', commands.filter, "Filter project"), - ('transform', commands.transform, "Transform project"), + ('filter', commands.filter, "Filter project items"), + ('transform', commands.transform, "Modify project items"), + ('apply', commands.apply, "Apply a few transforms to project"), + ('build', commands.build, "Build project"), ('merge', commands.merge, "Merge projects"), - ('convert', commands.convert, "Convert dataset into another format"), + ('convert', commands.convert, "Convert dataset between formats"), ('diff', commands.diff, "Compare projects with intersection"), ('ediff', commands.ediff, "Compare projects for equality"), ('stats', commands.stats, "Compute project statistics"), @@ -104,7 +129,8 @@ def make_parser(): subcommands = parser.add_subparsers(title=subcommands_desc, description="", help=argparse.SUPPRESS) for command_name, command, _ in known_contexts + known_commands: - add_subparser(subcommands, command_name, command.build_parser) + if command is not None: + add_subparser(subcommands, command_name, command.build_parser) return parser @@ -119,8 +145,15 @@ def main(args=None): parser.print_help() return 1 + if args.detached: + from datumaro.components.project import ProjectVcs + ProjectVcs.G_DETACHED = True + try: - return args.command(args) + retcode = args.command(args) + if retcode is None: + retcode = 0 + return retcode except CliException as e: log.error(e) return 1 diff --git 
a/datumaro/cli/commands/__init__.py b/datumaro/cli/commands/__init__.py index 9324f12252..e5ccf54d40 100644 --- a/datumaro/cli/commands/__init__.py +++ b/datumaro/cli/commands/__init__.py @@ -7,7 +7,8 @@ from . import ( create, add, remove, import_, explain, - export, merge, convert, transform, filter, + export, merge, convert, apply, transform, filter, build, update, diff, ediff, stats, + commit, fetch, pull, push, track, checkout, refs, status, check_updates, tag, info, validate ) diff --git a/datumaro/cli/commands/apply.py b/datumaro/cli/commands/apply.py new file mode 100644 index 0000000000..a8006f9173 --- /dev/null +++ b/datumaro/cli/commands/apply.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# pylint: disable=unused-import + +from ..contexts.project import build_apply_parser as build_parser \ No newline at end of file diff --git a/datumaro/cli/commands/build.py b/datumaro/cli/commands/build.py new file mode 100644 index 0000000000..a7b4522172 --- /dev/null +++ b/datumaro/cli/commands/build.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# pylint: disable=unused-import + +from ..contexts.project import build_build_parser as build_parser \ No newline at end of file diff --git a/datumaro/cli/commands/check_updates.py b/datumaro/cli/commands/check_updates.py new file mode 100644 index 0000000000..ca58c7e70d --- /dev/null +++ b/datumaro/cli/commands/check_updates.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('targets', nargs='*', + help="Names of sources and models") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=check_updates_command) + + return parser + +def check_updates_command(args): + project = load_project(args.project_dir) + + project.vcs.check_updates(targets=args.targets) + + return 0 diff --git a/datumaro/cli/commands/checkout.py b/datumaro/cli/commands/checkout.py new file mode 100644 index 0000000000..63d3abcdad --- /dev/null +++ b/datumaro/cli/commands/checkout.py @@ -0,0 +1,39 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # args can't be resolved automatically + parser.add_argument('rev', nargs='?', + help="Commit or tag (default: current)") + parser.add_argument('targets', nargs='*', + help="Names of sources, models, tracked files and dirs (default: all)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=checkout_command) + + return parser + +def checkout_command(args): + try: + pos = args._positionals.index('--') + has_sep = True + except ValueError: + pos = 1 + has_sep = False + args.rev = args._positionals[:pos] or [] + args.targets = args._positionals[pos + has_sep:] + + project = load_project(args.project_dir) + + project.vcs.checkout(rev=args.rev, targets=args.targets) + + return 0 diff --git a/datumaro/cli/commands/commit.py 
b/datumaro/cli/commands/commit.py new file mode 100644 index 0000000000..242aafc454 --- /dev/null +++ b/datumaro/cli/commands/commit.py @@ -0,0 +1,27 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('paths', nargs='*', + help="Files to include in the commit (default: all tracked)") + parser.add_argument('-m', '--message', required=True, help="Commit message") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=commit_command) + + return parser + +def commit_command(args): + project = load_project(args.project_dir) + + project.vcs.commit(args.paths, args.message) + + return 0 diff --git a/datumaro/cli/commands/create.py b/datumaro/cli/commands/create.py index 1396d5f9ed..293a2a392e 100644 --- a/datumaro/cli/commands/create.py +++ b/datumaro/cli/commands/create.py @@ -4,4 +4,64 @@ # pylint: disable=unused-import -from ..contexts.project import build_create_parser as build_parser \ No newline at end of file +import argparse +import logging as log +import os +import os.path as osp +import shutil + +from datumaro.components.project import \ + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG +from datumaro.components.project import Project + +from ..util import CliException, MultilineFormatter + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Create empty project", + description=""" + Create a new empty project.|n + |n + Examples:|n + - Create a project in the current directory:|n + |s|screate -n myproject|n + |n + - Create a project in other directory:|n + |s|screate -o path/I/like/ + """, + formatter_class=MultilineFormatter) + + parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', + help="Save directory for the new project (default: current dir") + parser.add_argument('-n', '--name', default=None, + help="Name of the new project (default: same as project dir)") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.set_defaults(command=create_command) + + return parser + +def create_command(args): + project_dir = osp.abspath(args.dst_dir) + + project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) + if osp.isdir(project_env_dir) and os.listdir(project_env_dir): + if args.overwrite: + shutil.rmtree(project_env_dir, ignore_errors=True) + else: + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % project_env_dir) + + project_name = args.name + if project_name is None: + project_name = osp.basename(project_dir) + + log.info("Creating project at '%s'" % project_dir) + + Project.generate(project_dir, { + 'project_name': project_name, + }) + + log.info("Project has been created at '%s'" % project_dir) + + return 0 diff --git a/datumaro/cli/commands/diff.py b/datumaro/cli/commands/diff.py index a50c8f0a4e..06be50db5e 100644 --- a/datumaro/cli/commands/diff.py +++ b/datumaro/cli/commands/diff.py @@ -2,6 +2,87 @@ # # SPDX-License-Identifier: MIT -# pylint: disable=unused-import +import argparse +import logging as log +import os +import os.path as osp +import shutil -from ..contexts.project import build_diff_parser as build_parser \ No newline at end of file +from datumaro.components.operations import DistanceComparator 
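# A minimal end-to-end sketch of the new create/commit flow introduced above
# (assumption only: Project.generate and project.vcs.commit behave as shown in
# this diff; the directory name and commit message are illustrative, not part
# of the patch).
import os.path as osp

from datumaro.components.project import Project
from datumaro.cli.util.project import load_project

project_dir = osp.abspath('my-project')
Project.generate(project_dir, {'project_name': 'my-project'})

project = load_project(project_dir)
project.vcs.commit([], "Initial commit")  # empty paths list == all tracked files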
+from datumaro.util import error_rollback + +from ..util import CliException, MultilineFormatter +from ..util.project import generate_next_file_name, load_project +from ..contexts.project.diff import DatasetDiffVisualizer + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects", + description=""" + Compares two projects, match annotations by distance.|n + |n + Examples:|n + - Compare two projects, match boxes if IoU > 0.7,|n + |s|s|s|sprint results to Tensorboard: + |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 + """, + formatter_class=MultilineFormatter) + + parser.add_argument('other_project_dir', + help="Directory of the second project to be compared") + parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, + help="Directory to save comparison results (default: do not save)") + parser.add_argument('-v', '--visualizer', + default=DatasetDiffVisualizer.DEFAULT_FORMAT.name, + choices=[f.name for f in DatasetDiffVisualizer.OutputFormat], + help="Output format (default: %(default)s)") + parser.add_argument('--iou-thresh', default=0.5, type=float, + help="IoU match threshold for detections (default: %(default)s)") + parser.add_argument('--conf-thresh', default=0.5, type=float, + help="Confidence threshold for detections (default: %(default)s)") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the first project to be compared (default: current dir)") + parser.set_defaults(command=diff_command) + + return parser + +@error_rollback('on_error', implicit=True) +def diff_command(args): + first_project = load_project(args.project_dir) + + try: + second_project = load_project(args.other_project_dir) + except FileNotFoundError: + if first_project.vcs.is_ref(args.other_project_dir): + raise NotImplementedError("It seems that you're trying to compare " + "different revisions of the project. 
" + "Comparisons between project revisions are not implemented yet.") + raise + + comparator = DistanceComparator(iou_threshold=args.iou_thresh) + + dst_dir = args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + else: + dst_dir = generate_next_file_name('%s-%s-diff' % ( + first_project.config.project_name, + second_project.config.project_name) + ) + dst_dir = osp.abspath(dst_dir) + log.info("Saving diff to '%s'" % dst_dir) + + if not osp.exists(dst_dir): + on_error.do(shutil.rmtree, dst_dir, ignore_errors=True) + + visualizer = DatasetDiffVisualizer(save_dir=dst_dir, + comparator=comparator, output_format=args.visualizer) + visualizer.save_dataset_diff( + first_project.make_dataset(), + second_project.make_dataset()) + + return 0 diff --git a/datumaro/cli/commands/ediff.py b/datumaro/cli/commands/ediff.py index ac5ba8c467..460ebf6579 100644 --- a/datumaro/cli/commands/ediff.py +++ b/datumaro/cli/commands/ediff.py @@ -2,6 +2,91 @@ # # SPDX-License-Identifier: MIT -# pylint: disable=unused-import +import argparse +import json +import logging as log -from ..contexts.project import build_ediff_parser as build_parser \ No newline at end of file +from datumaro.components.operations import ExactComparator + +from ..util import MultilineFormatter +from ..util.project import generate_next_file_name, load_project + + +_ediff_default_if = ['id', 'group'] # avoid https://bugs.python.org/issue16399 + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects for equality", + description=""" + Compares two projects for equality.|n + |n + Examples:|n + - Compare two projects, exclude annotation group |n + |s|s|sand the 'is_crowd' attribute from comparison:|n + |s|sediff other/project/ -if group -ia is_crowd + """, + formatter_class=MultilineFormatter) + + parser.add_argument('other_project_dir', + help="Directory of the second project to be compared") + parser.add_argument('-iia', '--ignore-item-attr', action='append', + help="Ignore item attribute (repeatable)") + parser.add_argument('-ia', '--ignore-attr', action='append', + help="Ignore annotation attribute (repeatable)") + parser.add_argument('-if', '--ignore-field', action='append', + help="Ignore annotation field (repeatable, default: %s)" % \ + _ediff_default_if) + parser.add_argument('--match-images', action='store_true', + help='Match dataset items by images instead of ids') + parser.add_argument('--all', action='store_true', + help="Include matches in the output") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the first project to be compared (default: current dir)") + parser.set_defaults(command=ediff_command) + + return parser + +def ediff_command(args): + first_project = load_project(args.project_dir) + + try: + second_project = load_project(args.other_project_dir) + except FileNotFoundError: + if first_project.vcs.is_ref(args.other_project_dir): + raise NotImplementedError("It seems that you're trying to compare " + "different revisions of the project. 
" + "Comparisons between project revisions are not implemented yet.") + raise + + if args.ignore_field: + args.ignore_field = _ediff_default_if + comparator = ExactComparator( + match_images=args.match_images, + ignored_fields=args.ignore_field, + ignored_attrs=args.ignore_attr, + ignored_item_attrs=args.ignore_item_attr) + matches, mismatches, a_extra, b_extra, errors = \ + comparator.compare_datasets( + first_project.make_dataset(), second_project.make_dataset()) + output = { + "mismatches": mismatches, + "a_extra_items": sorted(a_extra), + "b_extra_items": sorted(b_extra), + "errors": errors, + } + if args.all: + output["matches"] = matches + + output_file = generate_next_file_name('diff', ext='.json') + with open(output_file, 'w') as f: + json.dump(output, f, indent=4, sort_keys=True) + + print("Found:") + print("The first project has %s unmatched items" % len(a_extra)) + print("The second project has %s unmatched items" % len(b_extra)) + print("%s item conflicts" % len(errors)) + print("%s matching annotations" % len(matches)) + print("%s mismatching annotations" % len(mismatches)) + + log.info("Output has been saved to '%s'" % output_file) + + return 0 diff --git a/datumaro/cli/commands/explain.py b/datumaro/cli/commands/explain.py index 9c3e1d147a..c47ae011fb 100644 --- a/datumaro/cli/commands/explain.py +++ b/datumaro/cli/commands/explain.py @@ -152,7 +152,7 @@ def explain_command(args): for item in dataset: image = item.image.data if image is None: - log.warn( + log.warning( "Dataset item %s does not have image data. Skipping." % \ (item.id)) continue diff --git a/datumaro/cli/commands/fetch.py b/datumaro/cli/commands/fetch.py new file mode 100644 index 0000000000..6a53b0b3a3 --- /dev/null +++ b/datumaro/cli/commands/fetch.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('targets', nargs='*', + help="Names of sources, models, tracked files and dirs (default: all)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=fetch_command) + + return parser + +def fetch_command(args): + project = load_project(args.project_dir) + + project.vcs.fetch(targets=args.targets) + + return 0 diff --git a/datumaro/cli/commands/merge.py b/datumaro/cli/commands/merge.py index 2abb56e465..6b6c6d27c3 100644 --- a/datumaro/cli/commands/merge.py +++ b/datumaro/cli/commands/merge.py @@ -10,7 +10,7 @@ from datumaro.components.project import Project from datumaro.components.operations import IntersectMerge -from datumaro.components.errors import QualityError, MergeError +from datumaro.components.errors import DatasetQualityError, DatasetMergeError from ..util import at_least, MultilineFormatter, CliException from ..util.project import generate_next_file_name, load_project @@ -37,7 +37,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): def _group(s): return s.split(',') - parser.add_argument('project', nargs='+', action=at_least(2), + parser.add_argument('targets', nargs='*', help="Path to a project (repeatable)") parser.add_argument('-iou', '--iou-thresh', default=0.25, type=float, help="IoU match threshold for segments (default: %(default)s)") @@ -57,6 +57,8 @@ def _group(s): help="Output directory (default: current project's dir)") 
parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") parser.set_defaults(command=merge_command) return parser @@ -104,9 +106,9 @@ def save_merge_report(merger, path): all_errors = [] for e in merger.errors: - if isinstance(e, QualityError): + if isinstance(e, DatasetQualityError): item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 - elif isinstance(e, MergeError): + elif isinstance(e, DatasetMergeError): for s in e.sources: source_errors[s] = source_errors.get(s, 0) + 1 item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1 diff --git a/datumaro/cli/commands/pull.py b/datumaro/cli/commands/pull.py new file mode 100644 index 0000000000..5fce0d2e7b --- /dev/null +++ b/datumaro/cli/commands/pull.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('targets', nargs='*', + help="Names of sources, models, tracked files and dirs (default: all)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=pull_command) + + return parser + +def pull_command(args): + project = load_project(args.project_dir) + + project.vcs.pull(targets=args.targets) + + return 0 diff --git a/datumaro/cli/commands/push.py b/datumaro/cli/commands/push.py new file mode 100644 index 0000000000..55a1f86bd0 --- /dev/null +++ b/datumaro/cli/commands/push.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('targets', nargs='*', + help="Names of sources, models, tracked files and dirs (default: all)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=push_command) + + return parser + +def push_command(args): + project = load_project(args.project_dir) + + project.vcs.push(targets=args.targets) + + return 0 diff --git a/datumaro/cli/commands/refs.py b/datumaro/cli/commands/refs.py new file mode 100644 index 0000000000..97cae406e9 --- /dev/null +++ b/datumaro/cli/commands/refs.py @@ -0,0 +1,28 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=refs_command) + + return parser + +def refs_command(args): + project = load_project(args.project_dir) + + print('Branches:', ', '.join(project.vcs.refs)) + + tags = project.vcs.tags + if tags: + print('Tags:', ', '.join(tags)) + + return 0 diff --git a/datumaro/cli/commands/status.py b/datumaro/cli/commands/status.py new file mode 100644 index 0000000000..72ccb1626a --- /dev/null +++ b/datumaro/cli/commands/status.py @@ -0,0 +1,66 @@ +# Copyright (C) 2020 
Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import os.path as osp +from io import StringIO + +from datumaro.components.config import Config +from datumaro.components.config_model import PROJECT_SCHEMA + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=status_command) + + return parser + +def status_command(args): + project = load_project(args.project_dir) + + data_status = project.vcs.dvc.status() + for stage_name, stage_status in data_status.items(): + if stage_name.endswith('.dvc'): + stage_name = osp.splitext(osp.basename(stage_name))[0] + print(stage_status, stage_name) + + project_status = project.vcs.git.status() + config_path = osp.join(project.config.env_dir, + project.config.project_filename) + if config_path in project_status: + current_conf = Config.parse(config_path, schema=PROJECT_SCHEMA) + + prev_conf = project.vcs.git.show(config_path, rev='HEAD') + prev_conf = Config.parse(StringIO(prev_conf), schema=PROJECT_SCHEMA) + + + a_sources = set(prev_conf.sources) + b_sources = set(current_conf.sources) + + added = b_sources - a_sources + removed = a_sources - b_sources + modified = set(s for s in a_sources & b_sources + if (prev_conf.sources[s] != current_conf.sources[s]) or \ + (prev_conf.build_targets[s] != current_conf.build_targets[s]) + ) + + for s in a_sources | b_sources: + if s in added: + print('A', s) + if s in removed: + print('D', s) + if s in modified: + print('M', s) + + for path, path_status in project_status.items(): + if path.endswith('.dvc'): + path = osp.splitext(osp.basename(path))[0] + print(path_status, path) + + return 0 diff --git a/datumaro/cli/commands/tag.py b/datumaro/cli/commands/tag.py new file mode 100644 index 0000000000..3f29d6f80a --- /dev/null +++ b/datumaro/cli/commands/tag.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Give a name (tag) to the current revision") + + parser.add_argument('name', + help="Name (tag) for the current revision") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=tag_command) + + return parser + +def tag_command(args): + project = load_project(args.project_dir) + + project.vcs.tag(args.name) + + return 0 diff --git a/datumaro/cli/commands/track.py b/datumaro/cli/commands/track.py new file mode 100644 index 0000000000..c91a44c989 --- /dev/null +++ b/datumaro/cli/commands/track.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('paths', nargs='+', + help="Track files or directories") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=track_command) + + return parser + +def track_command(args): + project = load_project(args.project_dir) + + project.vcs.add(args.paths) + + return 0 diff --git 
a/datumaro/cli/commands/update.py b/datumaro/cli/commands/update.py new file mode 100644 index 0000000000..932f645d49 --- /dev/null +++ b/datumaro/cli/commands/update.py @@ -0,0 +1,7 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +# pylint: disable=unused-import + +from ..contexts.source import build_pull_parser as build_parser \ No newline at end of file diff --git a/datumaro/cli/contexts/__init__.py b/datumaro/cli/contexts/__init__.py index b903435527..0dd0b2395d 100644 --- a/datumaro/cli/contexts/__init__.py +++ b/datumaro/cli/contexts/__init__.py @@ -3,4 +3,4 @@ # # SPDX-License-Identifier: MIT -from . import project, source, model \ No newline at end of file +from . import project, source, model, remote, repository diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index dfb2dc5ba9..d738d0cd14 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -17,7 +17,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().launchers.items) + builtins = sorted(Environment().launchers) parser = parser_ctor(help="Add model to project", description=""" @@ -29,34 +29,42 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): """ % ', '.join(builtins), formatter_class=MultilineFormatter) - parser.add_argument('-l', '--launcher', required=True, - help="Model launcher") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, - help="Additional arguments for converter (pass '-- -h' for help)") - parser.add_argument('--copy', action='store_true', - help="Copy the model to the project") + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('url', nargs='?', help="URL to the model data") parser.add_argument('-n', '--name', default=None, help="Name of the model to be added (default: generate automatically)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite if exists") + parser.add_argument('-l', '--launcher', required=True, + help="Model launcher") + parser.add_argument('--no-check', action='store_true', + help="Skip model availability checking") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, + help="Additional arguments for converter (pass '-- -h' for help)") parser.set_defaults(command=add_command) return parser @error_rollback('on_error', implicit=True) def add_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + else: + pos = 1 + args.url = (args._positionals[:pos] or [''])[0] + args.extra_args = args._positionals[pos + has_sep:] + project = load_project(args.project_dir) - if args.name: - if not args.overwrite and args.name in project.config.models: - raise CliException("Model '%s' already exists " - "(pass --overwrite to overwrite)" % args.name) + name = args.name + if name: + if name in project.config.models: + raise CliException("Model '%s' already exists" % name) else: - args.name = generate_next_name( - project.config.models, 'model', '-', default=0) - assert args.name not in project.config.models, args.name + name = generate_next_name(list(project.models), + 'model', sep='-', default=0) try: launcher = project.env.launchers[args.launcher] @@ -64,38 +72,52 @@ def add_command(args): raise 
CliException("Launcher '%s' is not found" % args.launcher) cli_plugin = getattr(launcher, 'cli_plugin', launcher) - model_args = cli_plugin.parse_cmdline(args.extra_args) - - if args.copy: + model_args = {} + if args.extra_args: + model_args = cli_plugin.parse_cmdline(args.extra_args) + + if args.url and args.copy: + raise CliException("Can't specify both 'url' and 'copy' args, " + "'copy' is only applicable for local paths.") + elif args.copy: log.info("Copying model data") - model_dir = osp.join(project.config.project_dir, - project.local_model_dir(args.name)) + model_dir = project.models.model_dir(name) os.makedirs(model_dir, exist_ok=False) on_error.do(shutil.rmtree, model_dir, ignore_errors=True) try: cli_plugin.copy_model(model_dir, model_args) except (AttributeError, NotImplementedError): - log.error("Can't copy: copying is not available for '%s' models" % \ + log.error("Can't copy: copying is not available for '%s' models. " + "The model will be used as a local-only.", args.launcher) + model_dir = '' + else: + model_dir = args.url - log.info("Checking the model") - project.add_model(args.name, { + project.models.add(name, { + 'url': model_dir, 'launcher': args.launcher, 'options': model_args, }) - project.make_executable_model(args.name) + on_error.do(project.models.remove, name, force=True, keep_data=False, + ignore_errors=True) + + if not args.no_check: + log.info("Checking the model...") + project.models.make_executable_model(name) project.save() - log.info("Model '%s' with launcher '%s' has been added to project '%s'" % \ - (args.name, args.launcher, project.config.project_name)) + log.info("Model '%s' with launcher '%s' has been added to project", + name, args.launcher) return 0 def build_remove_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor() + parser = parser_ctor(help="Remove model from project", + description="Remove a model from a project") parser.add_argument('name', help="Name of the model to be removed") @@ -113,17 +135,47 @@ def remove_command(args): return 0 +def build_pull_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Update model revision", + description=""" + Update model revision.|n + |n + A specific revision can be required by the '--rev' parameter. + Otherwise, the latest remote version will be used. 
+ """) + + parser.add_argument('names', nargs='+', + help="Names of models to update") + parser.add_argument('--rev', + help="A revision to update the model to") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=pull_command) + + return parser + +def pull_command(args): + project = load_project(args.project_dir) + + project.models.pull(args.names, rev=args.rev) + project.save() + + return 0 + def build_run_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor() + parser = parser_ctor(help="Launches model inference", + description="Launches model inference on a project target.") + parser.add_argument('target', nargs='?', default='project', + help="Project target to launch inference on (default: project)") parser.add_argument('-o', '--output-dir', dest='dst_dir', - help="Directory to save output") + help="Directory to save output (default: auto-generated)") parser.add_argument('-m', '--model', dest='model_name', required=True, help="Model to apply to the project") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.add_argument('--overwrite', action='store_true', - help="Overwrite if exists") + help="Overwrite output dorectory if exists") parser.set_defaults(command=run_command) return parser @@ -140,7 +192,7 @@ def run_command(args): dst_dir = generate_next_file_name('%s-inference' % \ project.config.project_name) - project.make_dataset().apply_model( + project.make_dataset(args.target).run_model( save_dir=osp.abspath(dst_dir), model=args.model_name) @@ -179,6 +231,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): subparsers = parser.add_subparsers() add_subparser(subparsers, 'add', build_add_parser) add_subparser(subparsers, 'remove', build_remove_parser) + add_subparser(subparsers, 'pull', build_pull_parser) add_subparser(subparsers, 'run', build_run_parser) add_subparser(subparsers, 'info', build_info_parser) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index 64c4a28481..fe9f4bfb34 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -11,81 +11,24 @@ from enum import Enum from datumaro.components.dataset_filter import DatasetItemEncoder +from datumaro.components.errors import DatasetMergeError from datumaro.components.extractor import AnnotationType -from datumaro.components.operations import (DistanceComparator, - ExactComparator, compute_ann_statistics, compute_image_statistics) -from datumaro.components.project import \ - PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG -from datumaro.components.project import Environment, Project +from datumaro.components.operations import (compute_ann_statistics, + compute_image_statistics) +from datumaro.components.project import (Project, ProjectBuildTargets, + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG) +from datumaro.components.environment import Environment from datumaro.components.validator import validate_annotations, TaskType -from datumaro.util import error_rollback +from datumaro.util import str_to_bool, error_rollback from ...util import (CliException, MultilineFormatter, add_subparser, make_file_name) from ...util.project import generate_next_file_name, load_project -from .diff import DatasetDiffVisualizer -def build_create_parser(parser_ctor=argparse.ArgumentParser): - parser = 
parser_ctor(help="Create empty project", - description=""" - Create a new empty project.|n - |n - Examples:|n - - Create a project in the current directory:|n - |s|screate -n myproject|n - |n - - Create a project in other directory:|n - |s|screate -o path/I/like/ - """, - formatter_class=MultilineFormatter) - - parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', - help="Save directory for the new project (default: current dir") - parser.add_argument('-n', '--name', default=None, - help="Name of the new project (default: same as project dir)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite existing files in the save directory") - parser.set_defaults(command=create_command) - - return parser - -def create_command(args): - project_dir = osp.abspath(args.dst_dir) - - project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) - if osp.isdir(project_env_dir) and os.listdir(project_env_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % project_env_dir) - else: - shutil.rmtree(project_env_dir, ignore_errors=True) - - own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) - if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % own_dataset_dir) - else: - # NOTE: remove the dir to avoid using data from previous project - shutil.rmtree(own_dataset_dir) - - project_name = args.name - if project_name is None: - project_name = osp.basename(project_dir) - - log.info("Creating project at '%s'" % project_dir) - - Project.generate(project_dir, { - 'project_name': project_name, - }) - - log.info("Project has been created at '%s'" % project_dir) - - return 0 - def build_import_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().importers.items) + env = Environment() + builtins = sorted(set(env.extractors) | set(env.importers)) parser = parser_ctor(help="Create project from an existing dataset", description=""" @@ -128,9 +71,9 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): help="Directory to save the new project to (default: current dir)") parser.add_argument('-n', '--name', default=None, help="Name of the new project (default: same as project dir)") - parser.add_argument('--copy', action='store_true', - help="Copy the dataset instead of saving source links") - parser.add_argument('--skip-check', action='store_true', + parser.add_argument('--no-pull', action='store_true', + help="Do not download or copy dataset") + parser.add_argument('--no-check', action='store_true', help="Skip source checking") parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") @@ -144,6 +87,7 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser): return parser +@error_rollback('on_error', implicit=True) def import_command(args): project_dir = osp.abspath(args.dst_dir) @@ -155,15 +99,6 @@ def import_command(args): else: shutil.rmtree(project_env_dir, ignore_errors=True) - own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) - if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % own_dataset_dir) - else: - # NOTE: remove the dir to avoid using data from previous project - shutil.rmtree(own_dataset_dir) - project_name = args.name if 
project_name is None: project_name = osp.basename(project_dir) @@ -208,24 +143,34 @@ def import_command(args): log.info("Importing project as '%s'" % fmt) - project = Project.import_from(osp.abspath(args.source), fmt, **extra_args) - project.config.project_name = project_name - project.config.project_dir = project_dir + if not osp.isdir(project_dir): + on_error.do(shutil.rmtree, project_dir, ignore_errors=True) - if not args.skip_check or args.copy: - log.info("Checking the dataset...") - dataset = project.make_dataset() - if args.copy: - log.info("Cloning data...") - dataset.save(merge=True, save_images=True) - else: - project.save() + project = Project.generate(save_dir=project_dir, config={ + 'project_name': project_name + }) + + name = 'source' + project.sources.add(name, { + 'url': args.source, + 'format': args.format, + 'options': extra_args, + }) + + if not args.no_pull: + log.info("Pulling the source...") + project.sources.pull(name) + + if not (args.no_check or args.no_pull): + log.info("Checking the source...") + project.sources.make_dataset(name) + + project.save() log.info("Project has been created at '%s'" % project_dir) return 0 - class FilterModes(Enum): # primary items = 1 @@ -296,6 +241,10 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): """ % ', '.join(builtins), formatter_class=MultilineFormatter) + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', nargs='?', default='project', + help="Target to do export for (default: '%(default)s')") parser.add_argument('-e', '--filter', default=None, help="Filter expression for dataset items") parser.add_argument('--filter-mode', default=FilterModes.i.name, @@ -317,6 +266,15 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): return parser def export_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + else: + pos = 1 + args.target = (args._positionals[:pos] or \ + [ProjectBuildTargets.MAIN_TARGET])[0] + args.extra_args = args._positionals[pos + has_sep:] + project = load_project(args.project_dir) dst_dir = args.dst_dir @@ -334,21 +292,33 @@ def export_command(args): except KeyError: raise CliException("Converter for format '%s' is not found" % \ args.format) - extra_args = converter.parse_cmdline(args.extra_args) + extra_args = {} + if args.extra_args: + extra_args = converter.parse_cmdline(args.extra_args) - filter_args = FilterModes.make_filter_args(args.filter_mode) + if args.filter: + filter_args = FilterModes.make_filter_args(args.filter_mode) + filter_args['expr'] = args.filter log.info("Loading the project...") - dataset = project.make_dataset() - - log.info("Exporting the project...") + target = args.target if args.filter: - dataset = dataset.filter(args.filter, **filter_args) - converter = project.env.converters[args.format] - converter.convert(dataset, save_dir=dst_dir, **extra_args) + _, target = project.build_targets.add_filter_stage( + target, filter_args) + _, target = project.build_targets.add_convert_stage( + target, args.format, extra_args) - log.info("Project exported to '%s' as '%s'" % (dst_dir, args.format)) + status = project.vcs.status() + if status: # TODO: narrow only to the affected sources + raise CliException("Can't modify project " \ + "when there are uncommitted changes: %s" % status) + + log.info("Exporting...") + + project.build(target, out_dir=dst_dir) + + log.info("Results have been saved to '%s'" % dst_dir) 
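# A rough sketch of the staged export flow used by export_command above
# (assumption only: the build_targets API behaves as shown in this diff; the
# 'datumaro' format name and output directory are illustrative). Note that the
# command refuses to run when project.vcs.status() reports uncommitted changes.
from datumaro.cli.util.project import load_project

project = load_project('.')
_, target = project.build_targets.add_convert_stage('project', 'datumaro', {})
project.build(target, out_dir='export-datumaro')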
return 0 @@ -390,6 +360,10 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): """, formatter_class=MultilineFormatter) + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', default='project', nargs='?', + help="Project target to apply transform to (default: project)") parser.add_argument('-e', '--filter', default=None, help="XML XPath filter expression for dataset items") parser.add_argument('-m', '--mode', default=FilterModes.i.name, @@ -398,6 +372,10 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): (', '.join(FilterModes.list_options()) , '%(default)s')) parser.add_argument('--dry-run', action='store_true', help="Print XML representations to be filtered and exit") + parser.add_argument('--stage', type=str_to_bool, default=True, + help="Include this action as a project build step (default: %(default)s)") + parser.add_argument('--apply', type=str_to_bool, default=True, + help="Run this action immediately (default: %(default)s)") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, help="Output directory (default: update current project)") parser.add_argument('--overwrite', action='store_true', @@ -409,6 +387,15 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): return parser def filter_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + else: + pos = 1 + args.target = (args._positionals[:pos] or \ + [ProjectBuildTargets.MAIN_TARGET])[0] + args.extra_args = args._positionals[pos + has_sep:] + project = load_project(args.project_dir) if not args.dry_run: @@ -417,17 +404,19 @@ def filter_command(args): if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): raise CliException("Directory '%s' already exists " "(pass --overwrite to overwrite)" % dst_dir) - else: + elif args.target == project.build_targets.MAIN_TARGET: dst_dir = generate_next_file_name('%s-filter' % \ project.config.project_name) + else: + dst_dir = project.sources.work_dir(args.target) dst_dir = osp.abspath(dst_dir) - dataset = project.make_dataset() - filter_args = FilterModes.make_filter_args(args.mode) + filter_args['expr'] = args.filter if args.dry_run: - dataset = dataset.filter(expr=args.filter, **filter_args) + dataset = project.make_dataset(args.target) + dataset = dataset.filter(**filter_args) for item in dataset: encoded_item = DatasetItemEncoder.encode(item, dataset.categories()) xml_item = DatasetItemEncoder.to_string(encoded_item) @@ -437,10 +426,41 @@ def filter_command(args): if not args.filter: raise CliException("Expected a filter expression ('-e' argument)") - dataset.filter_project(save_dir=dst_dir, - filter_expr=args.filter, **filter_args) + if args.target == project.build_targets.MAIN_TARGET: + sources = [t for t in project.build_targets + if t != project.build_targets.MAIN_TARGET] + else: + sources = [args.target] + + for source in sources: + project.build_targets.add_filter_stage(source, filter_args) + + status = project.vcs.status() + if status: # TODO: narrow only to the affected sources + raise CliException("Can't modify project " \ + "when there are uncommitted changes: %s" % status) + + if args.apply: + log.info("Filtering...") - log.info("Subproject has been extracted to '%s'" % dst_dir) + if args.dst_dir: + project.build(args.target, out_dir=dst_dir) + + log.info("Results have been saved to '%s'" % dst_dir) + else: + for source in sources: + 
project.build(source) + project.sources[source].url = '' + + if not args.stage: + for source in sources: + project.build_targets.remove_stage(source, + project.build_targets[source].head.name) + + log.info("Finished") + + if args.stage: + project.save() return 0 @@ -491,137 +511,45 @@ def merge_command(args): return 0 -def build_diff_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Compare projects", +def build_apply_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Apply some operations to project", description=""" - Compares two projects, match annotations by distance.|n - |n - Examples:|n - - Compare two projects, match boxes if IoU > 0.7,|n - |s|s|s|sprint results to Tensorboard: - |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 + Applies several operations to a dataset + and produces a new dataset. """, formatter_class=MultilineFormatter) - parser.add_argument('other_project_dir', - help="Directory of the second project to be compared") + parser.add_argument('file', + help="Path to a file with a list of transforms and other actions") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Directory to save comparison results (default: do not save)") - parser.add_argument('-v', '--visualizer', - default=DatasetDiffVisualizer.DEFAULT_FORMAT.name, - choices=[f.name for f in DatasetDiffVisualizer.OutputFormat], - help="Output format (default: %(default)s)") - parser.add_argument('--iou-thresh', default=0.5, type=float, - help="IoU match threshold for detections (default: %(default)s)") - parser.add_argument('--conf-thresh', default=0.5, type=float, - help="Confidence threshold for detections (default: %(default)s)") + help="Directory to save output (default: current dir)") parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") + parser.add_argument('--build', action='store_true', + help="Consider this invocation a build step") parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the first project to be compared (default: current dir)") - parser.set_defaults(command=diff_command) + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=apply_command) return parser -@error_rollback('on_error', implicit=True) -def diff_command(args): - first_project = load_project(args.project_dir) - second_project = load_project(args.other_project_dir) - - comparator = DistanceComparator(iou_threshold=args.iou_thresh) +def apply_command(args): + project = load_project(args.project_dir) dst_dir = args.dst_dir if dst_dir: if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): raise CliException("Directory '%s' already exists " "(pass --overwrite to overwrite)" % dst_dir) - else: - dst_dir = generate_next_file_name('%s-%s-diff' % ( - first_project.config.project_name, - second_project.config.project_name) - ) + elif not args.build: + dst_dir = generate_next_file_name('%s-apply' % \ + project.config.project_name) dst_dir = osp.abspath(dst_dir) - log.info("Saving diff to '%s'" % dst_dir) - if not osp.exists(dst_dir): - on_error.do(shutil.rmtree, dst_dir, ignore_errors=True) + pipeline = project.build_targets.read_pipeline(args.file) + project.build_targets.run_pipeline(pipeline, out_dir=dst_dir) - with DatasetDiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.visualizer) as visualizer: - visualizer.save( - first_project.make_dataset(), - 
second_project.make_dataset()) - - return 0 - -_ediff_default_if = ['id', 'group'] # avoid https://bugs.python.org/issue16399 - -def build_ediff_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Compare projects for equality", - description=""" - Compares two projects for equality.|n - |n - Examples:|n - - Compare two projects, exclude annotation group |n - |s|s|sand the 'is_crowd' attribute from comparison:|n - |s|sediff other/project/ -if group -ia is_crowd - """, - formatter_class=MultilineFormatter) - - parser.add_argument('other_project_dir', - help="Directory of the second project to be compared") - parser.add_argument('-iia', '--ignore-item-attr', action='append', - help="Ignore item attribute (repeatable)") - parser.add_argument('-ia', '--ignore-attr', action='append', - help="Ignore annotation attribute (repeatable)") - parser.add_argument('-if', '--ignore-field', action='append', - help="Ignore annotation field (repeatable, default: %s)" % \ - _ediff_default_if) - parser.add_argument('--match-images', action='store_true', - help='Match dataset items by images instead of ids') - parser.add_argument('--all', action='store_true', - help="Include matches in the output") - parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the first project to be compared (default: current dir)") - parser.set_defaults(command=ediff_command) - - return parser - -def ediff_command(args): - first_project = load_project(args.project_dir) - second_project = load_project(args.other_project_dir) - - if args.ignore_field: - args.ignore_field = _ediff_default_if - comparator = ExactComparator( - match_images=args.match_images, - ignored_fields=args.ignore_field, - ignored_attrs=args.ignore_attr, - ignored_item_attrs=args.ignore_item_attr) - matches, mismatches, a_extra, b_extra, errors = \ - comparator.compare_datasets( - first_project.make_dataset(), second_project.make_dataset()) - output = { - "mismatches": mismatches, - "a_extra_items": sorted(a_extra), - "b_extra_items": sorted(b_extra), - "errors": errors, - } - if args.all: - output["matches"] = matches - - output_file = generate_next_file_name('diff', ext='.json') - with open(output_file, 'w') as f: - json.dump(output, f, indent=4, sort_keys=True) - - print("Found:") - print("The first project has %s unmatched items" % len(a_extra)) - print("The second project has %s unmatched items" % len(b_extra)) - print("%s item conflicts" % len(errors)) - print("%s matching annotations" % len(matches)) - print("%s mismatching annotations" % len(mismatches)) - - log.info("Output has been saved to '%s'" % output_file) + log.info("Results have been saved to '%s'" % dst_dir) return 0 @@ -641,6 +569,10 @@ def build_transform_parser(parser_ctor=argparse.ArgumentParser): """ % ', '.join(builtins), formatter_class=MultilineFormatter) + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', nargs='?', + help="Project target to apply transform to (default: all)") parser.add_argument('-t', '--transform', required=True, help="Transform to apply to the project") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, @@ -649,23 +581,46 @@ def build_transform_parser(parser_ctor=argparse.ArgumentParser): help="Overwrite existing files in the save directory") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current 
dir)") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, + parser.add_argument('--stage', type=str_to_bool, default=True, + help="Include this action as a project build step (default: %(default)s)") + parser.add_argument('--apply', type=str_to_bool, default=True, + help="Run this action immediately (default: %(default)s)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, help="Additional arguments for transformation (pass '-- -h' for help)") parser.set_defaults(command=transform_command) return parser def transform_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + else: + pos = 1 + args.target = (args._positionals[:pos] or \ + [ProjectBuildTargets.MAIN_TARGET])[0] + args.extra_args = args._positionals[pos + has_sep:] + project = load_project(args.project_dir) dst_dir = args.dst_dir + + if args.stage and args.target not in project.sources and \ + args.target != project.build_targets.MAIN_TARGET: + raise CliException("Adding a stage is only allowed for " + "source or project targets") + if dst_dir: if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): raise CliException("Directory '%s' already exists " "(pass --overwrite to overwrite)" % dst_dir) else: - dst_dir = generate_next_file_name('%s-%s' % \ - (project.config.project_name, make_file_name(args.transform))) + if args.target == project.build_targets.MAIN_TARGET: + dst_dir = generate_next_file_name('%s-%s' % \ + (project.config.project_name, make_file_name(args.transform))) + else: + dst_dir = project.sources.work_dir(args.target) + dst_dir = osp.abspath(dst_dir) try: @@ -677,17 +632,77 @@ def transform_command(args): if hasattr(transform, 'parse_cmdline'): extra_args = transform.parse_cmdline(args.extra_args) - log.info("Loading the project...") - dataset = project.make_dataset() + if args.target == project.build_targets.MAIN_TARGET: + sources = [t for t in project.build_targets + if t != project.build_targets.MAIN_TARGET] + else: + sources = [args.target] + + for source in sources: + project.build_targets.add_transform_stage(source, + args.transform, extra_args) - log.info("Transforming the project...") - dataset.transform_project( - method=transform, - save_dir=dst_dir, - **extra_args - ) + status = project.vcs.status() + if status: # TODO: narrow only to the affected sources + raise CliException("Can't modify project " \ + "when there are uncommitted changes: %s" % status) - log.info("Transform results have been saved to '%s'" % dst_dir) + if args.apply: + log.info("Transforming...") + + if args.dst_dir: + project.build(args.target, out_dir=dst_dir) + + log.info("Results have been saved to '%s'" % dst_dir) + else: + for source in sources: + project.build(source) + project.sources[source].url = '' + + if not args.stage: + for source in sources: + project.build_targets.remove_stage(source, + project.build_targets[source].head.name) + + log.info("Finished") + + project.save() + + return 0 + +def build_build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Build project", + description=""" + Pulls related sources and builds the target + """, + formatter_class=MultilineFormatter) + + parser.add_argument('target', default='project', nargs='?', + help="Project target to apply transform to (default: project)") + parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, + help="Directory to save output (default: current dir)") + parser.add_argument('-f', '--force', action='store_true', + 
help="Rebuild the target, even if it has no changes. " + "Ignore uncommitted changes.") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=build_command) + + return parser + +def build_command(args): + project = load_project(args.project_dir) + + status = project.vcs.status() + if not args.force and [s + for d in status.values() if 'changed outs' in s + for co in d.values() + for s in co.values() + ]: + raise CliException("Can't build project " \ + "when there are uncommitted changes: %s" % status) + + project.build(args.target, force=args.force, out_dir=args.dst_dir) return 0 @@ -699,6 +714,8 @@ def build_stats_parser(parser_ctor=argparse.ArgumentParser): """, formatter_class=MultilineFormatter) + parser.add_argument('target', default='project', nargs='?', + help="Project target (default: project)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.set_defaults(command=stats_command) @@ -725,6 +742,8 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): """, formatter_class=MultilineFormatter) + parser.add_argument('target', default='project', nargs='?', + help="Project target (default: project)") parser.add_argument('--all', action='store_true', help="Print all information") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -737,23 +756,33 @@ def info_command(args): project = load_project(args.project_dir) config = project.config env = project.env - dataset = project.make_dataset() + + try: + dataset = project.make_dataset() + except DatasetMergeError as e: + dataset = None + dataset_problem = "Can't merge project sources automatically: %s " \ + "Conflicting sources are: %s" % (e, ', '.join(e.sources)) print("Project:") print(" name:", config.project_name) print(" location:", config.project_dir) print("Plugins:") - print(" importers:", ', '.join(env.importers.items)) - print(" extractors:", ', '.join(env.extractors.items)) - print(" converters:", ', '.join(env.converters.items)) - print(" launchers:", ', '.join(env.launchers.items)) + print(" extractors:", ', '.join( + sorted(set(env.extractors) | set(env.importers)))) + print(" converters:", ', '.join(env.converters)) + print(" launchers:", ', '.join(env.launchers)) print("Sources:") for source_name, source in config.sources.items(): print(" source '%s':" % source_name) print(" format:", source.format) print(" url:", source.url) - print(" location:", project.local_source_dir(source_name)) + if source.remote: + print(" remote:", + "%(url)s (%(type)s)" % project.vcs.remotes[source.remote]) + print(" location:", project.sources.work_dir(source_name)) + print(" options:", source.options) def print_extractor_info(extractor, indent=''): print("%slength:" % indent, len(extractor)) @@ -775,15 +804,18 @@ def print_extractor_info(extractor, indent=''): len(cat.items) - count_threshold) print("%s labels:" % indent, labels) - print("Dataset:") - print_extractor_info(dataset, indent=" ") + if dataset is not None: + print("Dataset:") + print_extractor_info(dataset, indent=" ") - subsets = dataset.subsets() - print(" subsets:", ', '.join(subsets)) - for subset_name in subsets: - subset = dataset.get_subset(subset_name) - print(" subset '%s':" % subset_name) - print_extractor_info(subset, indent=" ") + subsets = dataset.subsets() + print(" subsets:", ', '.join(subsets)) + for subset_name in subsets: 
+ subset = dataset.get_subset(subset_name) + print(" subset '%s':" % subset_name) + print_extractor_info(subset, indent=" ") + else: + print("Merged dataset info is not available: ", dataset_problem) print("Models:") for model_name, model in config.models.items(): @@ -795,14 +827,16 @@ def print_extractor_info(extractor, indent=''): def build_validate_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Validate project", description=""" - Validates project based on specified task type and stores - results like statistics, reports and summary in JSON file. + Validates a project according to the task type and + reports summary in a JSON file. """, formatter_class=MultilineFormatter) parser.add_argument('task_type', choices=[task_type.name for task_type in TaskType], help="Task type for validation") + parser.add_argument('target', default='project', nargs='?', + help="Project build target to validate (default: project)") parser.add_argument('-s', '--subset', dest='subset_name', default=None, help="Subset to validate (default: None)") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -817,10 +851,10 @@ def validate_command(args): subset_name = args.subset_name dst_file_name = 'validation_results' - dataset = project.make_dataset() + dataset = project.make_dataset(args.target) if subset_name is not None: dataset = dataset.get_subset(subset_name) - dst_file_name += f'-{subset_name}' + dst_file_name += f'-{args.target}-{subset_name}' validation_results = validate_annotations(dataset, task_type) def _convert_tuple_keys_to_str(d): @@ -850,13 +884,10 @@ def build_parser(parser_ctor=argparse.ArgumentParser): formatter_class=MultilineFormatter) subparsers = parser.add_subparsers() - add_subparser(subparsers, 'create', build_create_parser) add_subparser(subparsers, 'import', build_import_parser) add_subparser(subparsers, 'export', build_export_parser) add_subparser(subparsers, 'filter', build_filter_parser) add_subparser(subparsers, 'merge', build_merge_parser) - add_subparser(subparsers, 'diff', build_diff_parser) - add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) add_subparser(subparsers, 'stats', build_stats_parser) diff --git a/datumaro/cli/contexts/remote.py b/datumaro/cli/contexts/remote.py new file mode 100644 index 0000000000..49ccf28616 --- /dev/null +++ b/datumaro/cli/contexts/remote.py @@ -0,0 +1,148 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import logging as log + +from ..util import CliException, MultilineFormatter, add_subparser +from ..util.project import load_project, generate_next_name + + +def build_add_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('-n', '--name', + help="Name of the new remote") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + + sp = parser.add_subparsers(dest='type') + + url_parser = sp.add_parser('url') + url_parser.add_argument('url', help="Path to the remote") + + git_parser = sp.add_parser('git') + git_parser.add_argument('url', help="Repository url") + git_parser.add_argument('--rev') + + parser.set_defaults(command=add_command) + + return parser + +def add_command(args): + project = load_project(args.project_dir) + + name = args.name + if not name: + name = generate_next_name(project.vcs.remotes, 
'remote', + sep='-', default=1) + config = { + 'url': args.url, + 'type': args.type, + } + if args.type == 'git': + config['options'] = {'rev': args.rev} + project.vcs.remotes.add(name, config) + project.save() + + log.info("Remote '%s' has been added to the project" % name) + + return 0 + +def build_remove_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Remove remote from project", + description="Remove a remote from project.") + + parser.add_argument('names', nargs='+', + help="Names of the remotes to be removed") + parser.add_argument('-f', '--force', action='store_true', + help="Ignore possible errors during removal") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=remove_command) + + return parser + +def remove_command(args): + project = load_project(args.project_dir) + + if not args.names: + raise CliException("Expected remote name") + + for name in args.names: + project.vcs.remotes.remove(name, force=args.force) + project.save() + + return 0 + +def build_default_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Set or display the default remote", + description="Set or display the default remote.") + + parser.add_argument('name', nargs='?', + help="Name of the remote to set as default") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=default_command) + + return parser + +def default_command(args): + project = load_project(args.project_dir) + + if not args.name: + default = project.vcs.remotes.get_default() + if default: + print(default) + else: + print("The default remote is not set.") + + else: + project.vcs.remotes.set_default(args.name) + project.save() + + return 0 + +def build_info_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor() + + parser.add_argument('name', nargs='?', + help="Remote name") + parser.add_argument('-v', '--verbose', action='store_true', + help="Show details") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=info_command) + + return parser + +def info_command(args): + project = load_project(args.project_dir) + + if args.name: + remote = project.vcs.remotes[args.name] + print(remote) + else: + for name, conf in project.vcs.remotes.items(): + print(name) + if args.verbose: + print(conf) + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(description=""" + Manipulate remote data sources of a project.|n + |n + By default, the project to be operated on is searched for + in the current directory. An additional '-p' argument can be + passed to specify project location. 
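# Rough illustration of what the 'add' sub-command above does with the project
# API; the remote name and URL are made-up example values:
from datumaro.cli.util.project import load_project

project = load_project('.')
project.vcs.remotes.add('data-server', {'url': 's3://bucket/datasets', 'type': 'url'})
project.save()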
+ """, + formatter_class=MultilineFormatter) + + subparsers = parser.add_subparsers() + add_subparser(subparsers, 'add', build_add_parser) + add_subparser(subparsers, 'remove', build_remove_parser) + add_subparser(subparsers, 'default', build_default_parser) + add_subparser(subparsers, 'info', build_info_parser) + + return parser diff --git a/datumaro/cli/contexts/repository.py b/datumaro/cli/contexts/repository.py new file mode 100644 index 0000000000..bc0886a1f4 --- /dev/null +++ b/datumaro/cli/contexts/repository.py @@ -0,0 +1,125 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import logging as log + +from ..util import MultilineFormatter, add_subparser +from ..util.project import load_project, generate_next_name + + +def build_add_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Add a repository link") + + parser.add_argument('url', help="Repository url") + parser.add_argument('-n', '--name', help="Name of the new remote") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + + parser.set_defaults(command=add_command) + + return parser + +def add_command(args): + project = load_project(args.project_dir) + + name = args.name + if not name: + name = generate_next_name(project.vcs.repositories, 'remote', + sep='-', default='1') + project.vcs.repositories.add(name, args.url) + project.save() + + log.info("Repository '%s' has been added to the project" % name) + + return 0 + +def build_remove_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Remove a repository link") + + parser.add_argument('name', + help="Name of the repository to be removed") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=remove_command) + + return parser + +def remove_command(args): + project = load_project(args.project_dir) + + project.vcs.repositories.remove(args.name) + project.save() + + return 0 + +def build_default_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Set or display the default repository") + + parser.add_argument('name', nargs='?', + help="Name of the repository to set as default") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=default_command) + + return parser + +def default_command(args): + project = load_project(args.project_dir) + + if not args.name: + default = project.vcs.repositories.get_default() + if default: + print(default) + else: + print("The default repository is not set.") + + else: + project.vcs.repositories.set_default(args.name) + project.save() + + return 0 + +def build_info_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Display repository info") + + parser.add_argument('name', nargs='?', + help="Remote name") + parser.add_argument('-v', '--verbose', action='store_true', + help="Show details") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=info_command) + + return parser + +def info_command(args): + project = load_project(args.project_dir) + + if args.name: + remote = project.vcs.repositories[args.name] + print(remote) + else: + for name, conf in 
project.vcs.repositories.items(): + print(name) + if args.verbose: + print(conf) + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(description=""" + Manipulate repositories of the project.|n + |n + By default, the project to be operated on is searched for + in the current directory. An additional '-p' argument can be + passed to specify project location. + """, + formatter_class=MultilineFormatter) + + subparsers = parser.add_subparsers() + add_subparser(subparsers, 'add', build_add_parser) + add_subparser(subparsers, 'remove', build_remove_parser) + add_subparser(subparsers, 'default', build_default_parser) + add_subparser(subparsers, 'info', build_info_parser) + + return parser diff --git a/datumaro/cli/contexts/source.py b/datumaro/cli/contexts/source.py index caea28446c..85061cb508 100644 --- a/datumaro/cli/contexts/source.py +++ b/datumaro/cli/contexts/source.py @@ -1,30 +1,20 @@ -# Copyright (C) 2019-2021 Intel Corporation +# Copyright (C) 2019-2020 Intel Corporation # # SPDX-License-Identifier: MIT import argparse import logging as log -import os -import os.path as osp -import shutil from datumaro.components.project import Environment -from ..util import add_subparser, CliException, MultilineFormatter -from ..util.project import load_project +from datumaro.util import error_rollback +from ..util import CliException, MultilineFormatter, add_subparser +from ..util.project import generate_next_name, load_project -def build_add_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().extractors.items) - base_parser = argparse.ArgumentParser(add_help=False) - base_parser.add_argument('-n', '--name', default=None, - help="Name of the new source") - base_parser.add_argument('-f', '--format', required=True, - help="Source dataset format") - base_parser.add_argument('--skip-check', action='store_true', - help="Skip source checking") - base_parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the project to operate on (default: current dir)") +def build_add_parser(parser_ctor=argparse.ArgumentParser): + env = Environment() + builtins = sorted(set(env.extractors) | set(env.importers)) parser = parser_ctor(help="Add data source to project", description=""" @@ -32,10 +22,7 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): - a dataset in a supported format (check 'formats' section below)|n - a Datumaro project|n |n - The source can be either a local directory or a remote - git repository. Each source type has its own parameters, which can - be checked by:|n - '%s'.|n + The source can be a local path or a remote link.|n |n Formats:|n Datasets come in a wide variety of formats. Each dataset @@ -50,140 +37,90 @@ def build_add_parser(parser_ctor=argparse.ArgumentParser): An Extractor produces a list of dataset items corresponding to the dataset. It is possible to add a custom Extractor. 
To do this, you need to put an Extractor - definition script to /.datumaro/extractors.|n + definition script to /.datumaro/plugins.|n |n List of builtin source formats: %s|n |n Examples:|n - Add a local directory with VOC-like dataset:|n - |s|sadd path path/to/voc -f voc_detection|n + |s|sadd path/to/voc -f voc|n - Add a local file with CVAT annotations, call it 'mysource'|n |s|s|s|sto the project somewhere else:|n - |s|sadd path path/to/cvat.xml -f cvat -n mysource -p somewhere/else/ - """ % ('%(prog)s SOURCE_TYPE --help', ', '.join(builtins)), - formatter_class=MultilineFormatter, - add_help=False) + |s|sadd path/to/cvat.xml -f cvat -n mysource -p somewhere/|n + - Add a remote link to a COCO-like dataset:|n + |s|sadd git://example.net/repo/path/to/coco/dir -f coco + """ % ', '.join(builtins), + formatter_class=MultilineFormatter) + parser.add_argument('url', + help="URL to the source dataset") + parser.add_argument('-n', '--name', + help="Name of the new source (default: generate automatically)") + parser.add_argument('-f', '--format', required=True, + help="Source dataset format") + parser.add_argument('--no-check', action='store_true', + help="Skip source correctness checking") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, + help="Additional arguments for extractor (pass '-- -h' for help)") parser.set_defaults(command=add_command) - sp = parser.add_subparsers(dest='source_type', metavar='SOURCE_TYPE', - help="The type of the data source " - "(call '%s SOURCE_TYPE --help' for more info)" % parser.prog) - - dir_parser = sp.add_parser('path', help="Add local path as source", - parents=[base_parser]) - dir_parser.add_argument('url', - help="Path to the source") - dir_parser.add_argument('--copy', action='store_true', - help="Copy the dataset instead of saving source links") - - repo_parser = sp.add_parser('git', help="Add git repository as source", - parents=[base_parser]) - repo_parser.add_argument('url', - help="URL of the source git repository") - repo_parser.add_argument('-b', '--branch', default='master', - help="Branch of the source repository (default: %(default)s)") - repo_parser.add_argument('--checkout', action='store_true', - help="Do branch checkout") - - # NOTE: add common parameters to the parent help output - # the other way could be to use parse_known_args() - display_parser = argparse.ArgumentParser( - parents=[base_parser, parser], - prog=parser.prog, usage="%(prog)s [-h] SOURCE_TYPE ...", - description=parser.description, formatter_class=MultilineFormatter) - class HelpAction(argparse._HelpAction): - def __call__(self, parser, namespace, values, option_string=None): - display_parser.print_help() - parser.exit() - - parser.add_argument('-h', '--help', action=HelpAction, - help='show this help message and exit') - - # TODO: needed distinction on how to add an extractor or a remote source - return parser +@error_rollback('on_error', implicit=True) def add_command(args): project = load_project(args.project_dir) - if args.source_type == 'git': - name = args.name - if name is None: - name = osp.splitext(osp.basename(args.url))[0] - - if project.env.git.has_submodule(name): - raise CliException("Git submodule '%s' already exists" % name) - - try: - project.get_source(name) - raise CliException("Source '%s' already exists" % name) - except KeyError: - pass - - rel_local_dir = project.local_source_dir(name) - local_dir = 
osp.join(project.config.project_dir, rel_local_dir) - url = args.url - project.env.git.create_submodule(name, local_dir, - url=url, branch=args.branch, no_checkout=not args.checkout) - elif args.source_type == 'path': - url = osp.abspath(args.url) - if not osp.exists(url): - raise CliException("Source path '%s' does not exist" % url) - - name = args.name - if name is None: - name = osp.splitext(osp.basename(url))[0] - - if project.env.git.has_submodule(name): - raise CliException("Git submodule '%s' already exists" % name) - - try: - project.get_source(name) + name = args.name + if name: + if name in project.sources: raise CliException("Source '%s' already exists" % name) - except KeyError: - pass - - rel_local_dir = project.local_source_dir(name) - local_dir = osp.join(project.config.project_dir, rel_local_dir) - - if args.copy: - log.info("Copying from '%s' to '%s'" % (url, local_dir)) - if osp.isdir(url): - # copytree requires destination dir not to exist - shutil.copytree(url, local_dir) - url = rel_local_dir - elif osp.isfile(url): - os.makedirs(local_dir) - shutil.copy2(url, local_dir) - url = osp.join(rel_local_dir, osp.basename(url)) - else: - raise Exception("Expected file or directory") - else: - os.makedirs(local_dir) - - project.add_source(name, { 'url': url, 'format': args.format }) + else: + name = generate_next_name(list(project.sources), + 'source', sep='-', default='1') + + fmt = args.format + if fmt in project.env.importers: + arg_parser = project.env.importers[fmt] + elif fmt in project.env.extractors: + arg_parser = project.env.extractors[fmt] + else: + raise CliException("Unknown format '%s'. A format can be added" + "by providing an Extractor and Importer plugins" % fmt) - if not args.skip_check: + extra_args = {} + if args.extra_args: + if hasattr(arg_parser, 'parse_cmdline'): + extra_args = arg_parser.parse_cmdline(args.extra_args) + else: + raise CliException("Format '%s' does not accept " + "extra parameters" % fmt) + + project.sources.add(name, { + 'url': args.url, + 'format': args.format, + 'options': extra_args, + }) + on_error.do(project.sources.remove, name, force=True, keep_data=False, + ignore_errors=True) + + if not args.no_check: log.info("Checking the source...") - try: - project.make_source_project(name).make_dataset() - except Exception: - shutil.rmtree(local_dir, ignore_errors=True) - raise + project.sources.make_dataset(name) project.save() - log.info("Source '%s' has been added to the project, location: '%s'" \ - % (name, rel_local_dir)) + log.info("Source '%s' with format '%s' has been added to the project", + name, args.format) return 0 def build_remove_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Remove source from project", - description="Remove a source from a project.") + description="Remove a source from a project") - parser.add_argument('-n', '--name', required=True, - help="Name of the source to be removed") + parser.add_argument('names', nargs='+', + help="Names of the sources to be removed") parser.add_argument('--force', action='store_true', help="Ignore possible errors during removal") parser.add_argument('--keep-data', action='store_true', @@ -197,37 +134,64 @@ def build_remove_parser(parser_ctor=argparse.ArgumentParser): def remove_command(args): project = load_project(args.project_dir) - name = args.name - if not name: + if not args.names: raise CliException("Expected source name") - try: - project.get_source(name) - except KeyError: - if not args.force: - raise CliException("Source '%s' does not exist" % 
name) - if project.env.git.has_submodule(name): - if args.force: - log.warning("Forcefully removing the '%s' source..." % name) + for name in args.names: + project.sources.remove(name, force=args.force, keep_data=args.keep_data) + project.save() - project.env.git.remove_submodule(name, force=args.force) + log.info("Sources '%s' have been removed from the project" % \ + ', '.join(args.names)) - source_dir = osp.join(project.config.project_dir, - project.local_source_dir(name)) - project.remove_source(name) - project.save() + return 0 - if not args.keep_data: - shutil.rmtree(source_dir, ignore_errors=True) +def build_pull_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Update source revision", + description=""" + Update source revision.|n + |n + To remove existing pipelines for the updated sources + (start them from scratch), use the '--restart' parameter.|n + |n + A specific revision can be required by the '--rev' parameter. + Otherwise, the latest remote version will be used. + """) + + parser.add_argument('names', nargs='+', + help="Names of sources to update") + parser.add_argument('--rev', + help="A revision to update the source to") + parser.add_argument('--restart', action='store_true', + help="Removes existing pipelines for these sources") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=pull_command) - log.info("Source '%s' has been removed from the project" % name) + return parser + +def pull_command(args): + project = load_project(args.project_dir) + + for source in args.names: + if source not in project.sources: + raise KeyError("Unknown source '%s'" % source) + + project.sources.pull(args.names, rev=args.rev) + for source in args.names: + if args.restart: + stages = project.build_targets[source].stages + stages[:] = stages[:1] + project.build_targets.build(source, reset=False, force=True) + + project.save() return 0 def build_info_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor() - parser.add_argument('-n', '--name', + parser.add_argument('name', nargs='?', help="Source name") parser.add_argument('-v', '--verbose', action='store_true', help="Show details") @@ -241,13 +205,13 @@ def info_command(args): project = load_project(args.project_dir) if args.name: - source = project.get_source(args.name) + source = project.sources[args.name] print(source) else: - for name, conf in project.config.sources.items(): + for name, conf in project.sources.items(): print(name) if args.verbose: - print(dict(conf)) + print(conf) def build_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(description=""" @@ -267,6 +231,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): subparsers = parser.add_subparsers() add_subparser(subparsers, 'add', build_add_parser) add_subparser(subparsers, 'remove', build_remove_parser) + add_subparser(subparsers, 'pull', build_pull_parser) add_subparser(subparsers, 'info', build_info_parser) return parser diff --git a/datumaro/cli/util/__init__.py b/datumaro/cli/util/__init__.py index 0a4357f700..fabaa7e0ae 100644 --- a/datumaro/cli/util/__init__.py +++ b/datumaro/cli/util/__init__.py @@ -7,6 +7,7 @@ import textwrap from datumaro.components.errors import DatumaroError +from datumaro.util.os_util import make_file_name # pylint: disable=unused-import class CliException(DatumaroError): pass @@ -60,17 +61,3 @@ def __call__(self, parser, args, values, option_string=None): def 
at_least(n): return required_count(n, 0) - -def make_file_name(s): - # adapted from - # https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify - """ - Normalizes string, converts to lowercase, removes non-alpha characters, - and converts spaces to hyphens. - """ - import unicodedata, re - s = unicodedata.normalize('NFKD', s).encode('ascii', 'ignore') - s = s.decode() - s = re.sub(r'[^\w\s-]', '', s).strip().lower() - s = re.sub(r'[-\s]+', '-', s) - return s \ No newline at end of file diff --git a/datumaro/cli/util/project.py b/datumaro/cli/util/project.py index e157ded5ea..3256267ab6 100644 --- a/datumaro/cli/util/project.py +++ b/datumaro/cli/util/project.py @@ -4,10 +4,9 @@ # SPDX-License-Identifier: MIT import os -import re from datumaro.components.project import Project -from datumaro.util import cast +from datumaro.util.os_util import generate_next_name def load_project(project_dir): @@ -22,18 +21,3 @@ def generate_next_file_name(basename, basedir='.', sep='.', ext=''): """ return generate_next_name(os.listdir(basedir), basename, sep, ext) - -def generate_next_name(names, basename, sep='.', suffix='', default=None): - pattern = re.compile(r'%s(?:%s(\d+))?%s' % \ - tuple(map(re.escape, [basename, sep, suffix]))) - matches = [match for match in (pattern.match(n) for n in names) if match] - - max_idx = max([cast(match[1], int, 0) for match in matches], default=None) - if max_idx is None: - if default is not None: - idx = sep + str(default) - else: - idx = '' - else: - idx = sep + str(max_idx + 1) - return basename + idx + suffix \ No newline at end of file diff --git a/datumaro/components/config.py b/datumaro/components/config.py index 72c461ae8f..2393a3316b 100644 --- a/datumaro/components/config.py +++ b/datumaro/components/config.py @@ -225,11 +225,15 @@ def yaml_representer(dumper, value): def dump(self, path): if isinstance(path, str): with open(path, 'w') as f: - yaml.dump(self, f) + yaml.safe_dump(self, f) else: - yaml.dump(self, path) + yaml.safe_dump(self, path) -yaml.add_multi_representer(Config, Config.yaml_representer) +yaml.add_multi_representer(Config, Config.yaml_representer, + Dumper=yaml.SafeDumper) +yaml.add_multi_representer(tuple, + lambda dumper, value: dumper.represent_data(list(value)), + Dumper=yaml.SafeDumper) class DictConfig(Config): @@ -238,8 +242,13 @@ def __init__(self, default=None): self.__dict__['_default'] = default def set(self, key, value): - if key not in self.keys(allow_fallback=False): - value = self._default(value) - return super().set(key, value) - else: - return super().set(key, value) + if self._default is not None: + schema_entry_instance = self._default(value) + if not isinstance(value, type(schema_entry_instance)): + if isinstance(value, dict) and \ + isinstance(schema_entry_instance, Config): + value = schema_entry_instance + else: + raise Exception("Can not set key '%s' - schema mismatch" % (key)) + + return super().set(key, value) diff --git a/datumaro/components/config_model.py b/datumaro/components/config_model.py index 49f85e9133..06b3d64813 100644 --- a/datumaro/components/config_model.py +++ b/datumaro/components/config_model.py @@ -7,11 +7,25 @@ DictConfig as _DictConfig, \ SchemaBuilder as _SchemaBuilder +from datumaro.util import find + + +REMOTE_SCHEMA = _SchemaBuilder() \ + .add('url', str) \ + .add('type', str) \ + .add('options', dict) \ + .build() + +class Remote(Config): + def __init__(self, config=None): + super().__init__(config, schema=REMOTE_SCHEMA) + SOURCE_SCHEMA = _SchemaBuilder() \ .add('url', str) 
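# Small sketch of the schema-checked configs introduced above: a Remote entry is
# validated against REMOTE_SCHEMA and can be dumped through the SafeDumper-based
# representer registered in config.py. The field values are placeholders.
from datumaro.components.config_model import Remote

remote = Remote({'url': 's3://bucket/datasets', 'type': 'url', 'options': {}})
print(remote.url, remote.type)
remote.dump('remote.yaml')   # uses yaml.safe_dump under the hood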
\ .add('format', str) \ .add('options', dict) \ + .add('remote', str) \ .build() class Source(Config): @@ -20,8 +34,10 @@ def __init__(self, config=None): MODEL_SCHEMA = _SchemaBuilder() \ + .add('url', str) \ .add('launcher', str) \ .add('options', dict) \ + .add('remote', str) \ .build() class Model(Config): @@ -29,20 +45,69 @@ def __init__(self, config=None): super().__init__(config, schema=MODEL_SCHEMA) +BUILDSTAGE_SCHEMA = _SchemaBuilder() \ + .add('name', str) \ + .add('type', str) \ + .add('kind', str) \ + .add('params', dict) \ + .build() + +class BuildStage(Config): + def __init__(self, config=None): + super().__init__(config, schema=BUILDSTAGE_SCHEMA) + +BUILDTARGET_SCHEMA = _SchemaBuilder() \ + .add('stages', list) \ + .add('parents', list) \ + .build() + +class BuildTarget(Config): + def __init__(self, config=None): + super().__init__(config, schema=BUILDTARGET_SCHEMA) + self.stages = [BuildStage(o) for o in self.stages] + + @property + def root(self): + return self.stages[0] + + @property + def head(self): + return self.stages[-1] + + def find_stage(self, stage): + if stage == 'root': + return self.root + elif stage == 'head': + return self.head + return find(self.stages, lambda x: x.name == stage or x == stage) + + def get_stage(self, stage): + res = self.find_stage(stage) + if res is None: + raise KeyError("Unknown stage '%s'" % stage) + return res + + PROJECT_SCHEMA = _SchemaBuilder() \ .add('project_name', str) \ .add('format_version', int) \ \ - .add('subsets', list) \ - .add('sources', lambda: _DictConfig( - lambda v=None: Source(v))) \ - .add('models', lambda: _DictConfig( - lambda v=None: Model(v))) \ + .add('default_repo', str) \ + .add('remotes', lambda: _DictConfig(lambda v=None: Remote(v))) \ + .add('sources', lambda: _DictConfig(lambda v=None: Source(v))) \ + .add('models', lambda: _DictConfig(lambda v=None: Model(v))) \ + .add('build_targets', lambda: _DictConfig(lambda v=None: BuildTarget(v))) \ \ .add('models_dir', str, internal=True) \ .add('plugins_dir', str, internal=True) \ .add('sources_dir', str, internal=True) \ .add('dataset_dir', str, internal=True) \ + .add('dvc_aux_dir', str, internal=True) \ + .add('pipelines_dir', str, internal=True) \ + .add('build_dir', str, internal=True) \ + .add('cache_dir', str, internal=True) \ + .add('revisions_dir', str, internal=True) \ + \ .add('project_filename', str, internal=True) \ .add('project_dir', str, internal=True) \ .add('env_dir', str, internal=True) \ @@ -50,13 +115,19 @@ def __init__(self, config=None): PROJECT_DEFAULT_CONFIG = Config({ 'project_name': 'undefined', - 'format_version': 1, + 'format_version': 2, 'sources_dir': 'sources', 'dataset_dir': 'dataset', 'models_dir': 'models', 'plugins_dir': 'plugins', + 'dvc_aux_dir': 'dvc_aux', + 'pipelines_dir': 'dvc_pipelines', + 'build_dir': 'build', + 'cache_dir': 'cache', + 'revisions_dir': 'revisions', + 'default_repo': 'origin', 'project_filename': 'config.yaml', 'project_dir': '', 'env_dir': '.datumaro', diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index 3cc16eb0e5..17cfcea56b 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -397,7 +397,7 @@ def flush_changes(self): class Dataset(IDataset): - _global_eager = False + _g_eager = False @classmethod def from_iterable(cls, iterable: Iterable[DatasetItem], @@ -452,8 +452,9 @@ def __init__(self, source: IDataset = None, self._format = DEFAULT_FORMAT self._source_path = None + self._options = {} - def define_categories(self, categories: Dict): + 
def define_categories(self, categories: CategoriesInfo): assert not self._data._categories and self._data._source is None self._data._categories = categories @@ -567,7 +568,7 @@ def is_cache_initialized(self) -> bool: @property def is_eager(self) -> bool: - return self.eager if self.eager is not None else self._global_eager + return self.eager if self.eager is not None else self._g_eager @property def is_bound(self) -> bool: @@ -626,8 +627,7 @@ def import_from(cls, path: str, format: str = None, env: Environment = None, if format in env.importers: importer = env.make_importer(format) with logging_disabled(log.INFO): - project = importer(path, **kwargs) - detected_sources = list(project.config.sources.values()) + detected_sources = importer(path, **kwargs) elif format in env.extractors: detected_sources = [{ 'url': path, 'format': format, 'options': kwargs @@ -678,10 +678,10 @@ def eager_mode(new_mode=True, dataset: Dataset = None): finally: dataset.eager = old_mode else: - old_mode = Dataset._global_eager + old_mode = Dataset._g_eager try: - Dataset._global_eager = new_mode + Dataset._g_eager = new_mode yield finally: - Dataset._global_eager = old_mode \ No newline at end of file + Dataset._g_eager = old_mode \ No newline at end of file diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py index c27131a841..916e9c78d2 100644 --- a/datumaro/components/environment.py +++ b/datumaro/components/environment.py @@ -4,14 +4,12 @@ from functools import partial from glob import glob -import git import inspect import logging as log import os import os.path as osp from datumaro.components.config import Config -from datumaro.components.config_model import Model, Source from datumaro.util.os_util import import_foreign_module @@ -46,28 +44,8 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.items - -class ModelRegistry(Registry): - def __init__(self, config=None): - super().__init__(config, item_type=Model) - - def load(self, config): - # TODO: list default dir, insert values - if 'models' in config: - for name, model in config.models.items(): - self.register(name, model) - - -class SourceRegistry(Registry): - def __init__(self, config=None): - super().__init__(config, item_type=Source) - - def load(self, config): - # TODO: list default dir, insert values - if 'sources' in config: - for name, source in config.sources.items(): - self.register(name, source) - + def __iter__(self): + return iter(self.items) class PluginRegistry(Registry): def __init__(self, config=None, builtin=None, local=None): @@ -85,47 +63,6 @@ def __init__(self, config=None, builtin=None, local=None): self.register(k, v) -class GitWrapper: - def __init__(self, config=None): - self.repo = None - - if config is not None and config.project_dir: - self.init(config.project_dir) - - @staticmethod - def _git_dir(base_path): - return osp.join(base_path, '.git') - - @classmethod - def spawn(cls, path): - spawn = not osp.isdir(cls._git_dir(path)) - repo = git.Repo.init(path=path) - if spawn: - repo.config_writer().set_value("user", "name", "User") \ - .set_value("user", "email", "user@nowhere.com") \ - .release() - # gitpython does not support init, use git directly - repo.git.init() - repo.git.commit('-m', 'Initial commit', '--allow-empty') - return repo - - def init(self, path): - self.repo = self.spawn(path) - return self.repo - - def is_initialized(self): - return self.repo is not None - - def create_submodule(self, name, dst_dir, **kwargs): - self.repo.create_submodule(name, 
dst_dir, **kwargs) - - def has_submodule(self, name): - return name in [submodule.name for submodule in self.repo.submodules] - - def remove_submodule(self, name, **kwargs): - return self.repo.submodule(name).remove(**kwargs) - - class Environment: _builtin_plugins = None PROJECT_EXTRACTOR_NAME = 'datumaro_project' @@ -136,22 +73,18 @@ def __init__(self, config=None): config = Config(config, fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - self.models = ModelRegistry(config) - self.sources = SourceRegistry(config) - - self.git = GitWrapper(config) - env_dir = osp.join(config.project_dir, config.env_dir) builtin = self._load_builtin_plugins() custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir)) select = lambda seq, t: [e for e in seq if issubclass(e, t)] from datumaro.components.converter import Converter from datumaro.components.extractor import (Importer, Extractor, - Transform) + SourceExtractor, Transform) from datumaro.components.launcher import Launcher self.extractors = PluginRegistry( - builtin=select(builtin, Extractor), - local=select(custom, Extractor) + builtin=[e for e in select(builtin, Extractor) + if e != SourceExtractor], + local=[e for e in select(custom, Extractor) if e != SourceExtractor] ) self.extractors.register(self.PROJECT_EXTRACTOR_NAME, load_project_as_dataset) @@ -284,12 +217,6 @@ def make_converter(self, name, *args, **kwargs): def make_transform(self, name, *args, **kwargs): return partial(self.transforms.get(name), *args, **kwargs) - def register_model(self, name, model): - self.models.register(name, model) - - def unregister_model(self, name): - self.models.unregister(name) - def is_format_known(self, name): return name in self.importers or name in self.extractors diff --git a/datumaro/components/errors.py b/datumaro/components/errors.py index 3d8da0629b..729729caf1 100644 --- a/datumaro/components/errors.py +++ b/datumaro/components/errors.py @@ -8,6 +8,22 @@ class DatumaroError(Exception): pass +class VcsError(DatumaroError): + pass + +@attrs +class SourceExistsError(VcsError): + name = attrib() + + def __str__(self): + return "Source %s already exists" % (self.name, ) + +class ReadonlyProjectError(VcsError): + pass + +class DetachedProjectError(VcsError): + pass + @attrs class DatasetError(DatumaroError): item_id = attrib() @@ -18,20 +34,11 @@ def __str__(self): return "Item %s is repeated in the source sequence." 
% (self.item_id, ) @attrs -class MismatchingImageInfoError(DatasetError): - a = attrib() - b = attrib() - - def __str__(self): - return "Item %s: mismatching image size info: %s vs %s" % \ - (self.item_id, self.a, self.b) - -@attrs -class QualityError(DatasetError): +class DatasetQualityError(DatasetError): pass @attrs -class AnnotationsTooCloseError(QualityError): +class AnnotationsTooCloseError(DatasetQualityError): a = attrib() b = attrib() distance = attrib() @@ -41,7 +48,7 @@ def __str__(self): (self.item_id, self.a, self.b, self.distance) @attrs -class WrongGroupError(QualityError): +class WrongGroupError(DatasetQualityError): found = attrib(converter=set) expected = attrib(converter=set) group = attrib(converter=list) @@ -52,11 +59,25 @@ def __str__(self): (self.item_id, self.found, self.expected, self.group) @attrs -class MergeError(DatasetError): +class DatasetMergeError(DatasetError): sources = attrib(converter=set) @attrs -class NoMatchingAnnError(MergeError): +class MismatchingImageInfoError(DatasetMergeError): + a = attrib() + b = attrib() + sources = attrib(converter=set, default=set()) + + def __str__(self): + return "Item %s: mismatching image size info: %s vs %s" % \ + (self.item_id, self.a, self.b) + +@attrs +class ConflictingCategoriesError(DatasetMergeError): + sources = attrib(converter=set, default=set()) + +@attrs +class NoMatchingAnnError(DatasetMergeError): ann = attrib() def __str__(self): @@ -65,13 +86,13 @@ def __str__(self): (self.item_id, self.sources, self.ann) @attrs -class NoMatchingItemError(MergeError): +class NoMatchingItemError(DatasetMergeError): def __str__(self): return "Item %s: can't find matching item in sources %s" % \ (self.item_id, self.sources) @attrs -class FailedLabelVotingError(MergeError): +class FailedLabelVotingError(DatasetMergeError): votes = attrib() ann = attrib(default=None) @@ -81,7 +102,7 @@ def __str__(self): self.votes, self.sources) @attrs -class FailedAttrVotingError(MergeError): +class FailedAttrVotingError(DatasetMergeError): attr = attrib() votes = attrib() ann = attrib() diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index b913dece13..6e04c056d4 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -561,7 +561,7 @@ def categories(self) -> CategoriesInfo: def get(self, id, subset=None) -> Optional[DatasetItem]: raise NotImplementedError() -class Extractor(IExtractor): +class ExtractorBase(IExtractor): def __init__(self, length=None, subsets=None): self._length = length self._subsets = subsets @@ -602,13 +602,9 @@ def transform(self, method, *args, **kwargs): return method(self, *args, **kwargs) def select(self, pred): - class _DatasetFilter(Extractor): - def __init__(self, _): - super().__init__() + class _DatasetFilter(Transform): def __iter__(_): return filter(pred, iter(self)) - def categories(_): - return self.categories() return self.transform(_DatasetFilter) @@ -622,6 +618,10 @@ def get(self, id, subset=None): #pylint: disable=redefined-builtin return item return None +class Extractor(ExtractorBase): + "A base class for user-defined and built-in extractors" + pass + class SourceExtractor(Extractor): def __init__(self, length=None, subset=None): self._subset = subset or DEFAULT_SUBSET_NAME @@ -653,22 +653,18 @@ def find_sources(cls, path) -> List[Dict]: raise NotImplementedError() def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - - sources = 
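# Sketch of how the renamed merge errors above carry the offending sources, so
# callers (e.g. the updated info_command) can report them; values are illustrative.
from datumaro.components.errors import DatasetMergeError

try:
    raise DatasetMergeError(item_id=('0001', 'train'), sources={'src1', 'src2'})
except DatasetMergeError as e:
    print("Can't merge project sources automatically:", ', '.join(e.sources))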
self.find_sources(osp.normpath(path)) - if len(sources) == 0: + found_sources = self.find_sources(osp.normpath(path)) + if len(found_sources) == 0: raise Exception("Failed to find dataset at '%s'" % path) - for desc in sources: + sources = [] + for desc in found_sources: params = dict(extra_params) params.update(desc.get('options', {})) desc['options'] = params + sources.append(desc) - source_name = osp.splitext(osp.basename(desc['url']))[0] - project.add_source(source_name, desc) - - return project + return sources @classmethod def _find_sources_recursive(cls, path, ext, extractor_name, @@ -686,7 +682,7 @@ def _find_sources_recursive(cls, path, ext, extractor_name, break return sources -class Transform(Extractor): +class Transform(ExtractorBase): @staticmethod def wrap_item(item, **kwargs): return item.wrap(**kwargs) @@ -711,7 +707,7 @@ def subsets(self): def __len__(self): assert self._length in {None, 'parent'} or isinstance(self._length, int) if self._length is None and \ - self.__iter__.__func__ == Transform.__iter__ \ + self.__iter__.__func__ == __class__.__iter__ \ or self._length == 'parent': self._length = len(self._extractor) return super().__len__() diff --git a/datumaro/components/operations.py b/datumaro/components/operations.py index 6cfdf8e0f4..638a7ee49e 100644 --- a/datumaro/components/operations.py +++ b/datumaro/components/operations.py @@ -18,8 +18,9 @@ from datumaro.components.extractor import (AnnotationType, Bbox, CategoriesInfo, Label, LabelCategories, PointsCategories, MaskCategories) -from datumaro.components.errors import (DatumaroError, FailedAttrVotingError, - FailedLabelVotingError, MismatchingImageInfoError, NoMatchingAnnError, +from datumaro.components.errors import (DatasetMergeError, FailedAttrVotingError, + FailedLabelVotingError, ConflictingCategoriesError, + MismatchingImageInfoError, NoMatchingAnnError, NoMatchingItemError, AnnotationsTooCloseError, WrongGroupError) from datumaro.components.dataset import Dataset, DatasetItemStorage from datumaro.util.attrs_util import ensure_cls, default_if_none @@ -52,16 +53,17 @@ def merge_annotations_equal(a, b): def merge_categories(sources): categories = {} - for source in sources: + for source_idx, source in enumerate(sources): for cat_type, source_cat in source.items(): existing_cat = categories.setdefault(cat_type, source_cat) if existing_cat != source_cat and len(source_cat) != 0: if len(existing_cat) == 0: categories[cat_type] = source_cat else: - raise DatumaroError( + raise ConflictingCategoriesError( "Merging of datasets with different categories is " - "only allowed in 'merge' command.") + "only allowed in 'merge' command.", + sources=list(range(source_idx))) return categories class MergingStrategy(CliPlugin): @@ -81,23 +83,22 @@ class ExactMerge: @classmethod def merge(cls, *sources): items = DatasetItemStorage() - for source in sources: + for source_idx, source in enumerate(sources): for item in source: existing_item = items.get(item.id, item.subset) if existing_item is not None: path = existing_item.path if item.path != path: path = None - item = cls.merge_items(existing_item, item, path=path) + try: + item = cls.merge_items(existing_item, item, path=path) + except DatasetMergeError as e: + e.sources = set(range(source_idx)) + raise e items.put(item) return items - @staticmethod - def _lazy_image(item): - # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result - return lambda: item.image - @classmethod def 
merge_items(cls, existing_item, current_item, path=None): return existing_item.wrap(path=path, @@ -1027,7 +1028,7 @@ def _extractor_stats(extractor): for item in extractor: if not (item.has_image and item.image.has_data): available = False - log.warn("Item %s has no image. Image stats won't be computed", + log.warning("Item %s has no image. Image stats won't be computed", item.id) break diff --git a/datumaro/components/project.py b/datumaro/components/project.py index 829fde46ba..40aed7fcdb 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -2,382 +2,1708 @@ # # SPDX-License-Identifier: MIT -from collections import OrderedDict +import json import logging as log +import networkx as nx import os import os.path as osp import shutil +import unittest.mock +import urllib.parse +import yaml +from contextlib import ExitStack +from enum import Enum +from functools import partial +from glob import glob +from typing import Dict, List, Optional, Tuple, Union +from ruamel.yaml import YAML from datumaro.components.config import Config -from datumaro.components.config_model import (Model, Source, - PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA) -from datumaro.components.dataset import (IDataset, Dataset, DEFAULT_FORMAT) -from datumaro.components.dataset_filter import (XPathAnnotationsFilter, - XPathDatasetFilter) +from datumaro.components.config_model import (PROJECT_DEFAULT_CONFIG, + PROJECT_SCHEMA, BuildStage, Remote) from datumaro.components.environment import Environment -from datumaro.components.errors import DatumaroError -from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Extractor -from datumaro.components.launcher import ModelTransform -from datumaro.components.operations import ExactMerge +from datumaro.components.errors import (DatasetMergeError, DatumaroError, + DetachedProjectError, ReadonlyProjectError, SourceExistsError, VcsError) +from datumaro.components.dataset import Dataset, DEFAULT_FORMAT +from datumaro.util import find, error_rollback, parse_str_enum_value, str_to_bool +from datumaro.util.os_util import make_file_name, generate_next_name +from datumaro.util.log_utils import logging_disabled, catch_logs -class ProjectDataset(IDataset): - class Subset(Extractor): - def __init__(self, parent, name): - super().__init__(subsets=[name]) - self.parent = parent - self.name = name or DEFAULT_SUBSET_NAME - self.items = OrderedDict() +class ProjectSourceDataset(Dataset): + @classmethod + def from_source(cls, project: 'Project', source: str): + config = project.sources[source] + + path = osp.join(project.sources.work_dir(source), config.url) + readonly = not path or not osp.exists(path) + if path and not osp.exists(path) and not config.remote: + # backward compatibility + path = osp.join(project.config.project_dir, config.url) + readonly = True + + dataset = cls.import_from(path, env=project.env, + format=config.format, **config.options) + dataset._project = project + dataset._config = config + dataset._readonly = readonly + dataset.name = source + return dataset + + def save(self, save_dir=None, **kwargs): + if save_dir is None: + if self.readonly: + raise ReadonlyProjectError("Can't update a read-only dataset") + super().save(save_dir, **kwargs) - def __iter__(self): - yield from self.items.values() + @property + def readonly(self): + return not self._readonly and self.is_bound and \ + self._project.vcs.writeable - def __len__(self): - return len(self.items) + @property + def _env(self): + return self._project.env - def categories(self): - return 
self.parent.categories() + @property + def config(self): + return self._config - def get(self, id, subset=None): #pylint: disable=redefined-builtin - subset = subset or self.name - assert subset == self.name, '%s != %s' % (subset, self.name) - return super().get(id, subset) + def run_model(self, model, batch_size=1): + if isinstance(model, str): + model = self._project.models.make_executable_model(model) + return super().run_model(model, batch_size=batch_size) - def __init__(self, project): - super().__init__() +MergeStrategy = Enum('MergeStrategy', ['ours', 'theirs', 'conflict']) + +class CrudProxy: + @property + def _data(self): + raise NotImplementedError() + + def __len__(self): + return len(self._data) + + def __getitem__(self, name): + return self._data[name] + + def get(self, name, default=None): + return self._data.get(name, default) + + def __iter__(self): + return iter(self._data.keys()) + + def items(self): + return iter(self._data.items()) + + def __contains__(self, name): + return name in self._data + +class ProjectRepositories(CrudProxy): + def __init__(self, project_vcs): + self._vcs = project_vcs + + def set_default(self, name): + if name not in self: + raise KeyError("Unknown repository name '%s'" % name) + self._vcs._project.config.default_repo = name + + def get_default(self): + return self._vcs._project.config.default_repo + + @CrudProxy._data.getter + def _data(self): + return self._vcs.git.list_remotes() + + def add(self, name, url): + self._vcs.git.add_remote(name, url) + + def remove(self, name): + self._vcs.git.remove_remote(name) + +class ProjectRemotes(CrudProxy): + SUPPORTED_PROTOCOLS = {'', 'remote', 's3', 'ssh', 'http', 'https'} + + def __init__(self, project_vcs): + self._vcs = project_vcs + + def fetch(self, name=None): + self._vcs.dvc.fetch_remote(name) + + def pull(self, name=None): + self._vcs.dvc.pull_remote(name) + + def push(self, name=None): + self._vcs.dvc.push_remote(name) + + def set_default(self, name): + self._vcs.dvc.set_default_remote(name) + + def get_default(self): + return self._vcs.dvc.get_default_remote() + + @CrudProxy._data.getter + def _data(self): + return self._vcs._project.config.remotes + + def add(self, name, value): + url_parts = self.validate_url(value['url']) + if not url_parts.scheme: + value['url'] = osp.abspath(value['url']) + if not value.get('type'): + value['type'] = 'url' + + if not isinstance(value, Remote): + value = Remote(value) + value = self._data.set(name, value) + + assert value.type in {'url', 'git', 'dvc'}, value.type + self._vcs.dvc.add_remote(name, value) + return value + + def remove(self, name, force=False): + try: + self._vcs.dvc.remove_remote(name) + except DvcWrapper.DvcError: + if not force: + raise + + @classmethod + def validate_url(cls, url): + url_parts = urllib.parse.urlsplit(url) + if url_parts.scheme not in cls.SUPPORTED_PROTOCOLS and \ + not osp.exists(url): + if url_parts.scheme == 'git': + raise ValueError("git sources should be added as remote links") + if url_parts.scheme == 'dvc': + raise ValueError("dvc sources should be added as remote links") + raise ValueError( + "Invalid remote '%s': scheme '%s' is not supported, the only" + "available are: %s" % + (url, url_parts.scheme, ', '.join(cls.SUPPORTED_PROTOCOLS)) + ) + if not (url_parts.hostname or url_parts.path): + raise ValueError("URL must not be empty, url: '%s'" % url) + return url_parts + +class _DataSourceBase(CrudProxy): + def __init__(self, project, config_field): self._project = project - self._env = project.env - config = 
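# The CrudProxy helpers above give dict-like, read-mostly access to project
# collections; a short usage sketch (project path and source name are hypothetical):
from datumaro.cli.util.project import load_project

project = load_project('path/to/project')
if 'src1' in project.sources:                      # CrudProxy.__contains__
    print(project.sources['src1'].format)          # __getitem__ -> Source config
for name, remote in project.vcs.remotes.items():   # CrudProxy.items()
    print(name, remote.url)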
self.config - env = self.env - - sources = {} - for s_name, source in config.sources.items(): - s_format = source.format or env.PROJECT_EXTRACTOR_NAME - - url = source.url - if not source.url: - url = osp.join(config.project_dir, config.sources_dir, s_name) - sources[s_name] = Dataset.import_from(url, - format=s_format, env=env, **source.options) - self._sources = sources - - own_source = None - own_source_dir = osp.join(config.project_dir, config.dataset_dir) - if config.project_dir and osp.isdir(own_source_dir): - own_source = Dataset.load(own_source_dir) - - # merge categories - # TODO: implement properly with merging and annotations remapping - categories = ExactMerge.merge_categories(s.categories() - for s in self._sources.values()) - # ovewrite with own categories - if own_source is not None and (not categories or len(own_source) != 0): - categories.update(own_source.categories()) - self._categories = categories - - # merge items - subsets = {} - for source_name, source in self._sources.items(): - log.debug("Loading '%s' source contents..." % source_name) - for item in source: - existing_item = subsets.setdefault( - item.subset, self.Subset(self, item.subset)). \ - items.get(item.id) - if existing_item is not None: - path = existing_item.path - if item.path != path: - path = None # NOTE: move to our own dataset - item = ExactMerge.merge_items(existing_item, item, path=path) - else: - s_config = config.sources[source_name] - if s_config and \ - s_config.format != env.PROJECT_EXTRACTOR_NAME: - # NOTE: consider imported sources as our own dataset - path = None - else: - path = [source_name] + (item.path or []) - item = item.wrap(path=path) + self._field = config_field - subsets[item.subset].items[item.id] = item + @CrudProxy._data.getter + def _data(self): + return self._project.config[self._field] - # override with our items, fallback to existing images - if own_source is not None: - log.debug("Loading own dataset...") - for item in own_source: - existing_item = subsets.setdefault( - item.subset, self.Subset(self, item.subset)). 
\ - items.get(item.id) - if existing_item is not None: - item = item.wrap(path=None, - image=ExactMerge.merge_images(existing_item, item)) + def pull(self, names=None, rev=None): + if not self._project.vcs.writeable: + raise ReadonlyProjectError("Can't pull in a read-only project") - subsets[item.subset].items[item.id] = item + if not names: + names = [] + elif isinstance(names, str): + names = [names] + else: + names = list(names) - self._subsets = subsets + for name in names: + if name and name not in self: + raise KeyError("Unknown source '%s'" % name) - self._length = None + if rev and len(names) != 1: + raise ValueError("A revision can only be specified for a " + "single source invocation") - def iterate_own(self): - return self.select(lambda item: not item.path) + self._project.vcs.dvc.update_imports( + [self.dvcfile_path(name) for name in names], rev=rev) - def __iter__(self): - for subset in self._subsets.values(): - yield from subset + @classmethod + def _validate_url(cls, url): + return ProjectRemotes.validate_url(url) - def get_subset(self, name): - return self._subsets[name] + @classmethod + def _make_remote_name(cls, name): + return name - def subsets(self): - return self._subsets + def work_dir(self, name: str) -> str: + return osp.join(self._project.config.project_dir, name) - def categories(self): - return self._categories + def cache_dir(self, name: str, rev: str) -> str: + return osp.join( + self._project.config.project_dir, + self._project.config.env_dir, + self._project.config.cache_dir, + name, + self._project.config.revisions_dir, + rev) - def __len__(self): - return sum(len(s) for s in self._subsets.values()) - - def get(self, id, subset=None, \ - path=None): #pylint: disable=redefined-builtin - if path: - source = path[0] - return self._sources[source].get(id=id, subset=subset) - return self._subsets.get(subset, {}).get(id) - - def put(self, item, id=None, subset=None, \ - path=None): #pylint: disable=redefined-builtin - if path is None: - path = item.path - - if path: - source = path[0] - # TODO: reverse remapping - self._sources[source].put(item, id=id, subset=subset) - - if id is None: - id = item.id - if subset is None: - subset = item.subset - - item = item.wrap(path=path) - if subset not in self._subsets: - self._subsets[subset] = self.Subset(self, subset) - self._subsets[subset].items[id] = item - self._length = None - - return item - - def save(self, save_dir=None, merge=False, recursive=True, - save_images=False): - if save_dir is None: - assert self.config.project_dir - save_dir = self.config.project_dir - project = self._project + def validate_name(self, name: str): + valid_filename = make_file_name(name) + if valid_filename != name: + raise ValueError("Source name contains " + "prohibited symbols: %s" % (set(name) - set(valid_filename)) ) + + if name.startswith('.'): + raise ValueError("Source name can't start with '.'") + + def dvcfile_path(self, name): + return self._project.vcs.dvc_filepath(name) + + @classmethod + def _fix_dvc_file(cls, source_path, dvc_path, dst_name): + with open(dvc_path, 'r+') as dvc_file: + yaml = YAML(typ='rt') + dvc_data = yaml.load(dvc_file) + dvc_data['wdir'] = osp.join( + dvc_data['wdir'], osp.basename(source_path)) + dvc_data['outs'][0]['path'] = dst_name + + dvc_file.seek(0) + yaml.dump(dvc_data, dvc_file) + dvc_file.truncate() + + def _ensure_in_dir(self, source_path, dvc_path, dst_name): + if not osp.isfile(source_path): + return + tmp_dir = osp.join(self._project.config.project_dir, + self._project.config.env_dir, 
'tmp') + os.makedirs(tmp_dir, exist_ok=True) + source_tmp = osp.join(tmp_dir, osp.basename(source_path)) + os.replace(source_path, source_tmp) + os.makedirs(source_path) + os.replace(source_tmp, osp.join(source_path, dst_name)) + + self._fix_dvc_file(source_path, dvc_path, dst_name) + + @error_rollback('on_error', implicit=True) + def add(self, name, value): + self.validate_name(name) + + if name in self: + raise SourceExistsError("Source '%s' already exists" % name) + + url = value.get('url', '') + + if self._project.vcs.writeable: + if url: + url_parts = self._validate_url(url) + + if not url: + # a generated source + remote_name = '' + path = url + elif url_parts.scheme == 'remote': + # add a source with existing remote + remote_name = url_parts.netloc + remote_conf = self._project.vcs.remotes[remote_name] + path = url_parts.path + if path == '/': # fix conflicts in remote interpretation + path = '' + url = remote_conf.url + path + else: + # add a source and a new remote + if not url_parts.scheme and not osp.exists(url): + raise FileNotFoundError( + "Can't find file or directory '%s'" % url) + + remote_name = self._make_remote_name(name) + if remote_name not in self._project.vcs.remotes: + on_error.do(self._project.vcs.remotes.remove, remote_name, + ignore_errors=True) + remote_conf = self._project.vcs.remotes.add(remote_name, { + 'url': url, + 'type': 'url', + }) + path = '' + + source_dir = self.work_dir(name) + + dvcfile = self.dvcfile_path(name) + if not osp.isfile(dvcfile): + on_error.do(os.remove, dvcfile, ignore_errors=True) + + if not remote_name: + pass + elif remote_conf.type == 'url': + self._project.vcs.dvc.import_url( + 'remote://%s%s' % (remote_name, path), + out=source_dir, dvc_path=dvcfile, download=True) + self._ensure_in_dir(source_dir, dvcfile, osp.basename(url)) + elif remote_conf.type in {'git', 'dvc'}: + self._project.vcs.dvc.import_repo(remote_conf.url, path=path, + out=source_dir, dvc_path=dvcfile, download=True) + self._ensure_in_dir(source_dir, dvcfile, osp.basename(url)) + else: + raise ValueError("Unknown remote type '%s'" % remote_conf.type) + + path = osp.basename(path) else: - merge = True + if not url or osp.exists(url): + # a local or a generated source + # in a read-only or in-memory project + remote_name = '' + path = url + else: + raise DetachedProjectError( + "Can only add an existing local, or generated " + "source to a detached project") - if merge: - project = Project(Config(self.config)) - project.config.remove('sources') + value['url'] = path + value['remote'] = remote_name + value = self._data.set(name, value) - save_dir = osp.abspath(save_dir) - dataset_save_dir = osp.join(save_dir, project.config.dataset_dir) + return value - converter_kwargs = { - 'save_images': save_images, - } + def remove(self, name, force=False, keep_data=True): + """Force - ignores errors and tries to wipe remaining data""" + + if name not in self._data and not force: + raise KeyError("Unknown source '%s'" % name) + + self._data.remove(name) + + if not self._project.vcs.writeable: + return - save_dir_existed = osp.exists(save_dir) + if force and not keep_data: + source_dir = self.work_dir(name) + if osp.isdir(source_dir): + shutil.rmtree(source_dir, ignore_errors=True) + + dvcfile = self.dvcfile_path(name) + if osp.isfile(dvcfile): + try: + self._project.vcs.dvc.remove(dvcfile, outs=not keep_data) + except DvcWrapper.DvcError: + if force: + os.remove(dvcfile) + else: + raise + + self._project.vcs.remotes.remove(name, force=force) + +class ProjectModels(_DataSourceBase): 
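A minimal usage sketch of the source API added above (not part of the patch), assuming Git and DVC are installed; the project path, source name and dataset path are illustrative:

from datumaro.components.project import Project

# Create and save a project; save() also initializes the Git and DVC repos.
project = Project.generate('./my-project')

# Register a local dataset as a source. Internally this validates the name,
# creates a DVC remote named after the source and imports the data into the
# source working directory.
project.sources.add('voc-2012', {
    'url': '/datasets/VOCdevkit/VOC2012',
    'format': 'voc',
})
project.save()

# Merge all sources into a single dataset via the build targets machinery.
dataset = project.make_dataset()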
+ def __init__(self, project): + super().__init__(project, 'models') + + def __getitem__(self, name): try: - os.makedirs(save_dir, exist_ok=True) - os.makedirs(dataset_save_dir, exist_ok=True) + return super().__getitem__(name) + except KeyError: + raise KeyError("Unknown model '%s'" % name) - if merge: - # merge and save the resulting dataset - self.env.converters.get(DEFAULT_FORMAT).convert( - self, dataset_save_dir, **converter_kwargs) - else: - if recursive: - # children items should already be updated - # so we just save them recursively - for source in self._sources.values(): - if isinstance(source, ProjectDataset): - source.save(**converter_kwargs) - - self.env.converters.get(DEFAULT_FORMAT).convert( - self.iterate_own(), dataset_save_dir, **converter_kwargs) - - project.save(save_dir) - except BaseException: - if not save_dir_existed and osp.isdir(save_dir): - shutil.rmtree(save_dir, ignore_errors=True) - raise + def work_dir(self, name): + return osp.join( + self._project.config.project_dir, + self._project.config.env_dir, + self._project.config.models_dir, name) - @property - def config(self): - return self._project.config + def make_executable_model(self, name): + model = self[name] + return self._project.env.make_launcher(model.launcher, + **model.options, model_dir=self.work_dir(name)) - @property - def env(self): - return self._project.env +class ProjectSources(_DataSourceBase): + def __init__(self, project): + super().__init__(project, 'sources') - @property - def sources(self): - return self._sources + def __getitem__(self, name): + try: + return super().__getitem__(name) + except KeyError: + raise KeyError("Unknown source '%s'" % name) + + def make_dataset(self, name, rev=None): + return ProjectSourceDataset.from_source(self._project, name) + + def validate_name(self, name): + super().validate_name(name) + + reserved_names = {'dataset', 'build', 'project'} + if name.lower() in reserved_names: + raise ValueError("Source name is reserved for internal use") + + def add(self, name, value): + value = super().add(name, value) + + self._project.build_targets.add_target(name) + + return value - def _save_branch_project(self, extractor, save_dir=None): - if not isinstance(extractor, Dataset): - extractor = Dataset.from_extractors( - extractor) # apply lazy transforms to avoid repeating traversals + def remove(self, name, force=False, keep_data=True): + self._project.build_targets.remove_target(name) - # NOTE: probably this function should be in the ViewModel layer - save_dir = osp.abspath(save_dir) - if save_dir: - dst_project = Project() + super().remove(name, force=force, keep_data=keep_data) + + +BuildStageType = Enum('BuildStageType', + ['source', 'project', 'transform', 'filter', 'convert', 'inference']) + +class ProjectBuildTargets(CrudProxy): + def __init__(self, project): + self._project = project + + @CrudProxy._data.getter + def _data(self): + data = self._project.config.build_targets + + if self.MAIN_TARGET not in data: + data[self.MAIN_TARGET] = { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.project.name, + }), + ] + } + + for source in self._project.sources: + if source not in data: + data[source] = { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.source.name, + }), + ] + } + + return data + + def __contains__(self, key): + if '.' 
in key: + target, stage = self._split_target_name(key) + return target in self._data and \ + self._data[target].find_stage(stage) is not None + return key in self._data + + def add_target(self, name): + return self._data.set(name, { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.source.name, + }), + ] + }) + + def add_stage(self, target, value, prev=None, + name=None) -> Tuple[BuildStage, str]: + target_name = target + target_stage_name = None + if '.' in target: + target_name, target_stage_name = self._split_target_name(target) + + if prev is None: + prev = target_stage_name + + target = self._data[target_name] + + if prev: + prev_stage = find(enumerate(target.stages), + lambda e: e[1].name == prev) + if prev_stage is None: + raise KeyError("Can't find stage '%s'" % prev) + prev_stage = prev_stage[0] else: - if not self.config.project_dir: - raise ValueError("Either a save directory or a project " - "directory should be specified") - save_dir = self.config.project_dir + prev_stage = len(target.stages) - 1 - dst_project = Project(Config(self.config)) - dst_project.config.remove('project_dir') - dst_project.config.remove('sources') - dst_project.config.project_name = osp.basename(save_dir) + name = value.get('name') or name + if not name: + name = generate_next_name((s.name for s in target.stages), + value['type'], sep='-') + else: + if target.find_stage(name): + raise VcsError("Stage '%s' already exists" % name) + value['name'] = name + + value = BuildStage(value) + assert BuildStageType[value.type] + target.stages.insert(prev_stage + 1, value) + return value, self._make_target_name(target_name, name) + + def remove_target(self, name): + assert name != self.MAIN_TARGET, "Can't remove the main target" + self._data.remove(name) + + def remove_stage(self, target, name): + assert name not in {self.BASE_STAGE}, "Can't remove a default stage" + + target = self._data[target] + idx = find(enumerate(target.stages), lambda e: e[1].name == name) + if idx is None: + raise KeyError("Can't find stage '%s'" % name) + target.stages.remove(idx) + + def add_transform_stage(self, target, transform, params=None, name=None): + if not transform in self._project.env.transforms: + raise KeyError("Unknown transform '%s'" % transform) + + return self.add_stage(target, { + 'type': BuildStageType.transform.name, + 'kind': transform, + 'params': params or {}, + }, name=name) + + def add_inference_stage(self, target, model, name=None): + if not model in self._project.config.models: + raise KeyError("Unknown model '%s'" % model) + + return self.add_stage(target, { + 'type': BuildStageType.inference.name, + 'kind': model, + }, name=name) + + def add_filter_stage(self, target, params=None, name=None): + return self.add_stage(target, { + 'type': BuildStageType.filter.name, + 'params': params or {}, + }, name=name) + + def add_convert_stage(self, target, format, \ + params=None, name=None): # pylint: disable=redefined-builtin + if not self._project.env.is_format_known(format): + raise KeyError("Unknown format '%s'" % format) + + return self.add_stage(target, { + 'type': BuildStageType.convert.name, + 'kind': format, + 'params': params or {}, + }, name=name) + + MAIN_TARGET = 'project' + BASE_STAGE = 'root' + def _get_build_graph(self): + graph = nx.DiGraph() + for target_name, target in self.items(): + if target_name == self.MAIN_TARGET: + # main target combines all the others + prev_stages = [self._make_target_name(n, t.head.name) + for n, t in self.items() if n != self.MAIN_TARGET] + 
else: + prev_stages = [self._make_target_name(t, self[t].head.name) + for t in target.parents] - dst_dataset = dst_project.make_dataset() - dst_dataset._categories = extractor.categories() - dst_dataset.update(extractor) + for stage in target.stages: + stage_name = self._make_target_name(target_name, stage['name']) + graph.add_node(stage_name, config=stage) + for prev_stage in prev_stages: + graph.add_edge(prev_stage, stage_name) + prev_stages = [stage_name] - dst_dataset.save(save_dir=save_dir, merge=True) + return graph - def transform(self, method, *args, **kwargs): - return method(self, *args, **kwargs) + @staticmethod + def _make_target_name(target, stage=None): + if stage: + return '%s.%s' % (target, stage) + return target - def filter(self, expr: str, filter_annotations: bool = False, - remove_empty: bool = False) -> Dataset: - if filter_annotations: - return self.transform(XPathAnnotationsFilter, expr, remove_empty) + @classmethod + def _split_target_name(cls, name): + if '.' in name: + target, stage = name.split('.', maxsplit=1) + if not target: + raise ValueError("Wrong target name '%s' - target name can't " + "be empty" % name) + if not stage: + raise ValueError("Wrong target name '%s' - expected " + "stage name after the separator" % name) else: - return self.transform(XPathDatasetFilter, expr) - - def update(self, other): - for item in other: - self.put(item) - return self - - def select(self, pred): - class _DatasetFilter(Extractor): - def __init__(self, _): - super().__init__() - def __iter__(_): - return filter(pred, iter(self)) - def categories(_): - return self.categories() - - return self.transform(_DatasetFilter) - - def export(self, save_dir: str, format, \ - **kwargs): #pylint: disable=redefined-builtin - dataset = Dataset.from_extractors(self, env=self.env) - dataset.export(save_dir, format, **kwargs) - - def define_categories(self, categories): - assert not self._categories - self._categories = categories - - def transform_project(self, method, save_dir=None, **method_kwargs): - # NOTE: probably this function should be in the ViewModel layer - if isinstance(method, str): - method = self.env.make_transform(method) - - transformed = self.transform(method, **method_kwargs) - self._save_branch_project(transformed, save_dir=save_dir) - - def apply_model(self, model, save_dir=None, batch_size=1): - # NOTE: probably this function should be in the ViewModel layer - if isinstance(model, str): - model = self._project.make_executable_model(model) - - self.transform_project(ModelTransform, launcher=model, - save_dir=save_dir, batch_size=batch_size) - - def export_project(self, save_dir, converter, - filter_expr=None, filter_annotations=False, remove_empty=False): - # NOTE: probably this function should be in the ViewModel layer - dataset = self - if filter_expr: - dataset = dataset.filter(filter_expr, - filter_annotations=filter_annotations, - remove_empty=remove_empty) - - save_dir = osp.abspath(save_dir) - save_dir_existed = osp.exists(save_dir) - try: - os.makedirs(save_dir, exist_ok=True) - converter(dataset, save_dir) - except BaseException: - if not save_dir_existed: - shutil.rmtree(save_dir) - raise - - def filter_project(self, filter_expr, filter_annotations=False, - save_dir=None, remove_empty=False): - # NOTE: probably this function should be in the ViewModel layer - dataset = self - if filter_expr: - dataset = dataset.filter(filter_expr, - filter_annotations=filter_annotations, - remove_empty=remove_empty) - self._save_branch_project(dataset, save_dir=save_dir) + 
target = name + stage = cls.BASE_STAGE + return target, stage + + def _get_target_subgraph(self, target): + if '.' not in target: + target = self._make_target_name(target, self[target].head.name) + + full_graph = self._get_build_graph() + + target_parents = set() + visited = set() + to_visit = {target} + while to_visit: + current = to_visit.pop() + visited.add(current) + for pred in full_graph.predecessors(current): + target_parents.add(pred) + if pred not in visited: + to_visit.add(pred) + + target_parents.add(target) + + return full_graph.subgraph(target_parents) + + def _get_target_config(self, name): + """Returns a target or stage description""" + target, stage = self._split_target_name(name) + target_config = self._data[target] + stage_config = target_config.get_stage(stage) + return stage_config + + def make_pipeline(self, target): + # a subgraph with all the target dependencies + target_subgraph = self._get_target_subgraph(target) + pipeline = [] + for node_name, node in target_subgraph.nodes.items(): + entry = { + 'name': node_name, + 'parents': list(target_subgraph.predecessors(node_name)), + 'config': dict(node['config']), + } + pipeline.append(entry) + return pipeline + + def generate_pipeline(self, target): + real_target = self._normalize_target(target) + + pipeline = self.make_pipeline(real_target) + path = osp.join(self._project.config.project_dir, + self._project.config.env_dir, self._project.config.pipelines_dir) + os.makedirs(path, exist_ok=True) + path = osp.join(path, make_file_name(target) + '.yml') + self.write_pipeline(pipeline, path) + + return path -class Project: @classmethod - def load(cls, path): - path = osp.abspath(path) - config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir, - PROJECT_DEFAULT_CONFIG.project_filename) - config = Config.parse(config_path) - config.project_dir = path - config.project_filename = osp.basename(config_path) - return Project(config) + def _read_pipeline_graph(cls, pipeline): + graph = nx.DiGraph() + for entry in pipeline: + target_name = entry['name'] + parents = entry['parents'] + target = BuildStage(entry['config']) + + graph.add_node(target_name, config=target) + for prev_stage in parents: + graph.add_edge(prev_stage, target_name) + + return graph + + def apply_pipeline(self, pipeline): + def _join_parent_datasets(force=True): + if 1 < len(parent_datasets) or force: + try: + dataset = Dataset.from_extractors(*parent_datasets, + env=self._project.env) + except DatasetMergeError as e: + e.sources = set( + getattr(parent_datasets[s], 'name') or str(s) + for s in e.sources) + raise e + else: + dataset = parent_datasets[0] + return dataset + + if len(pipeline) == 0: + raise Exception("Can't run empty pipeline") + + graph = self._read_pipeline_graph(pipeline) + + head = None + for node in graph.nodes: + if graph.out_degree(node) == 0: + assert head is None, "A pipeline can have only one " \ + "main target, but it has at least 2: %s, %s" % \ + (head, node) + head = node + assert head is not None, "A pipeline must have a finishing node" + + # Use DFS to traverse the graph and initialize nodes from roots to tops + to_visit = [head] + while to_visit: + current_name = to_visit.pop() + current = graph.nodes[current_name] + + assert current.get('dataset') is None + + parents_uninitialized = [] + parent_datasets = [] + for p_name in graph.predecessors(current_name): + parent = graph.nodes[p_name] + dataset = parent.get('dataset') + if dataset is None: + parents_uninitialized.append(p_name) + else: + parent_datasets.append(dataset) - def 
save(self, save_dir=None): - config = self.config + if parents_uninitialized: + to_visit.append(current_name) + to_visit.extend(parents_uninitialized) + continue - if save_dir is None: - assert config.project_dir - project_dir = config.project_dir + type_ = BuildStageType[current['config'].type] + params = current['config'].params + if type_ == BuildStageType.transform: + kind = current['config'].kind + try: + transform = self._project.env.transforms[kind] + except KeyError: + raise KeyError("Unknown transform '%s'" % kind) + + dataset = _join_parent_datasets() + dataset = dataset.transform(transform, **params) + + elif type_ == BuildStageType.filter: + dataset = _join_parent_datasets() + dataset = dataset.filter(**params) + + elif type_ == BuildStageType.inference: + kind = current['config'].kind + model = self._project.models.make_executable_model(kind) + + dataset = _join_parent_datasets() + dataset = dataset.run_model(model) + + elif type_ == BuildStageType.source: + assert len(parent_datasets) == 0, current_name + source, _ = self._split_target_name(current_name) + dataset = self._project.sources.make_dataset(source) + + elif type_ == BuildStageType.project: + dataset = _join_parent_datasets(force=True) + + elif type_ == BuildStageType.convert: + dataset = _join_parent_datasets() + + else: + raise NotImplementedError("Unknown stage type '%s'" % type_) + + current['dataset'] = dataset + + return graph, head + + @staticmethod + def write_pipeline(pipeline, path): + # force encoding and newline to produce same files on different OSes + # this should be used by DVC later, which checks file hashes + with open(path, 'w', encoding='utf-8', newline='') as f: + yaml.safe_dump(pipeline, f) + + @staticmethod + def read_pipeline(path): + with open(path) as f: + return yaml.safe_load(f) + + def make_dataset(self, target): + if len(self._data) == 1 and self.MAIN_TARGET in self._data: + raise DatumaroError("Can't create dataset from an empty project.") + + target = self._normalize_target(target) + + pipeline = self.make_pipeline(target) + graph, head = self.apply_pipeline(pipeline) + return graph.nodes[head]['dataset'] + + def _normalize_target(self, target): + if '.' 
not in target: + real_target = self._make_target_name(target, self[target].head.name) + else: + t, s = self._split_target_name(target) + assert self[t].get_stage(s), target + real_target = target + return real_target + + @classmethod + def pipeline_sources(cls, pipeline): + sources = set() + for item in pipeline: + if item['config']['type'] == BuildStageType.source.name: + s, _ = cls._split_target_name(item['name']) + sources.add(s) + return list(sources) + + def build(self, target, out_dir=None, force=False, reset=True): + def _rpath(p): + return osp.relpath(p, self._project.config.project_dir) + + def _source_dvc_path(source): + return _rpath(self._project.vcs.dvc_filepath(source)) + + def _reset_sources(sources): + for source in sources: + dvc_path = _source_dvc_path(source) + project_dir = self._project.config.project_dir + repo = self._project.vcs.dvc.repo + stage = repo.stage.load_file(osp.join(project_dir, dvc_path))[0] + try: + logs = None + with repo.lock, catch_logs('dvc') as logs: + stage.frozen = False + stage.run(force=True, no_commit=True) + except Exception: + if logs: + log.debug(logs.getvalue()) + raise + + def _restore_sources(sources): + if not self._project.vcs.has_commits() or not sources: + return + self._project.vcs.checkout(rev=None, targets=sources) + + _is_modified = partial(self._project.vcs.dvc.check_stage_status, + status='modified') + + + if not self._project.vcs.writeable: + raise VcsError("Can't build project in read-only or detached mode") + + if '.' in target: + raw_target, target_stage = self._split_target_name(target) + else: + raw_target = target + target_stage = None + + if raw_target not in self: + raise KeyError("Unknown target '%s'" % raw_target) + + if target_stage and target_stage != self[raw_target].head.name: + # build is not inplace, need to generate or ask output dir + inplace = False + else: + inplace = not out_dir + + if inplace: + if target == self.MAIN_TARGET: + out_dir = osp.join(self._project.config.project_dir, + self._project.config.build_dir) + elif target == raw_target: + out_dir = self._project.sources.data_dir(target) + + if not out_dir: + raise Exception("Output directory is not specified.") + + pipeline = self.make_pipeline(target) + related_sources = self.pipeline_sources(pipeline) + + if not force: + if inplace: + stages = [_source_dvc_path(s) for s in related_sources] + status = self._project.vcs.dvc.status(stages) + for stage, source in zip(stages, related_sources): + if _is_modified(status, stage): + raise VcsError("Can't build when there are " + "uncommitted changes in the source '%s'" % source) + elif osp.isdir(out_dir) and os.listdir(out_dir): + raise Exception("Can't build when output directory" + "is not empty") + + try: + if reset: + _reset_sources(related_sources) + + self.run_pipeline(pipeline, out_dir=out_dir) + + if raw_target != self.MAIN_TARGET: + related_sources.remove(raw_target) + + finally: + if reset: + _restore_sources(related_sources) + + def run_pipeline(self, pipeline, out_dir): + graph, head = self.apply_pipeline(pipeline) + head_node = graph.nodes[head] + raw_target, _ = self._split_target_name(head) + + dataset = head_node['dataset'] + dst_format = DEFAULT_FORMAT + options = {'save_images': True} + if raw_target in self._project.sources: + dst_format = self._project.sources[raw_target].format + elif head_node['config']['type'] == BuildStageType.convert.name: + dst_format = head_node['config'].kind + options.update(head_node['config'].params) + dataset.export(format=dst_format, save_dir=out_dir, 
**options) + + +class GitWrapper: + @staticmethod + def import_module(): + import git + return git + + try: + module = import_module.__func__() + except ImportError: + module = None + + def _git_dir(self): + return osp.join(self._project_dir, '.git') + + def __init__(self, project_dir, repo=None): + self._project_dir = project_dir + self.repo = repo + + if repo is None and \ + osp.isdir(project_dir) and osp.isdir(self._git_dir()): + self.repo = self.module.Repo(project_dir) + + @property + def initialized(self): + return self.repo is not None + + def init(self): + if self.initialized: + return + + repo = self.module.Repo.init(path=self._project_dir) + repo.config_writer() \ + .set_value("user", "name", "User") \ + .set_value("user", "email", "<>") \ + .release() + # gitpython does not support init, use git directly + repo.git.init() + + self.repo = repo + + @property + def refs(self) -> List[str]: + return [t.name for t in self.repo.refs] + + @property + def tags(self) -> List[str]: + return [t.name for t in self.repo.tags] + + def push(self, remote=None): + args = [remote] if remote else [] + remote = self.repo.remote(*args) + branch = self.repo.head.ref.name + if not self.repo.head.ref.tracking_branch(): + self.repo.git.push('--set-upstream', remote, branch) + else: + remote.push(branch) + + def pull(self, remote=None): + args = [remote] if remote else [] + return self.repo.remote(*args).pull() + + def check_updates(self, remote=None) -> List[str]: + args = [remote] if remote else [] + remote = self.repo.remote(*args) + prev_refs = {r.name: r.commit.hexsha for r in remote.refs} + remote.update() + new_refs = {r.name: r.commit.hexsha for r in remote.refs} + updated_refs = [(prev_refs.get(n), new_refs.get(n)) + for n, _ in (set(prev_refs.items()) ^ set(new_refs.items()))] + return updated_refs + + def fetch(self, remote=None): + args = [remote] if remote else [] + self.repo.remote(*args).fetch() + + def tag(self, name): + self.repo.create_tag(name) + + def checkout(self, ref=None, paths=None): + args = [] + if ref: + args.append(ref) + if paths: + args.append('--') + args.extend(paths) + self.repo.git.checkout(*args) + + def add(self, paths, all=False): # pylint: disable=redefined-builtin + if not all: + paths = [ + p2 for p in paths + for p2 in glob(osp.join(p, '**', '*'), recursive=True) + if osp.isdir(p) + ] + [ + p for p in paths if osp.isfile(p) + ] + self.repo.index.add(paths) else: - project_dir = save_dir + self.repo.git.add(all=True) + + def commit(self, message): + self.repo.index.commit(message) + + def status(self): + # R[everse] flag is needed for index to HEAD comparison + # to avoid inversed output in gitpython, which adds this flag + # git diff --cached HEAD [not not R] + diff = self.repo.index.diff(R=True) + return { + osp.relpath(d.a_rawpath.decode(), self._project_dir): d.change_type + for d in diff + } + + def list_remotes(self): + return { r.name: r.url for r in self.repo.remotes } - env_dir = osp.join(project_dir, config.env_dir) - save_dir = osp.abspath(env_dir) + def add_remote(self, name, url): + self.repo.create_remote(name, url) - project_dir_existed = osp.exists(project_dir) - env_dir_existed = osp.exists(env_dir) + def remove_remote(self, name): + self.repo.delete_remote(name) + + def is_ref(self, rev): try: - os.makedirs(save_dir, exist_ok=True) + self.repo.commit(rev) + return True + except (ValueError, self.module.exc.BadName): + return False + + def has_commits(self): + return self.is_ref('HEAD') + + IgnoreMode = Enum('IgnoreMode', ['rewrite', 'append', 
'remove']) + + def ignore(self, paths, filepath=None, mode=None): + repo_root = self._project_dir + + def _make_ignored_path(path): + path = osp.join(repo_root, osp.normpath(path)) + assert path.startswith(repo_root), path + return osp.relpath(path, repo_root) + + IgnoreMode = self.IgnoreMode + mode = parse_str_enum_value(mode, IgnoreMode, IgnoreMode.append) + + if not filepath: + filepath = '.gitignore' + filepath = osp.abspath(osp.join(repo_root, filepath)) + assert filepath.startswith(repo_root), filepath + + paths = [_make_ignored_path(p) for p in paths] + + openmode = 'r+' + if not osp.isfile(filepath): + openmode = 'w+' # r+ cannot create, w+ truncates + with open(filepath, openmode) as f: + if mode in {IgnoreMode.append, IgnoreMode.remove}: + paths_to_write = set( + line.split('#', maxsplit=1)[0] \ + .split('/', maxsplit=1)[-1].strip() + for line in f + ) + f.seek(0) + else: + paths_to_write = set() - config_path = osp.join(save_dir, config.project_filename) - config.dump(config_path) - except BaseException: - if not env_dir_existed: - shutil.rmtree(save_dir, ignore_errors=True) - if not project_dir_existed: - shutil.rmtree(project_dir, ignore_errors=True) - raise + if mode in {IgnoreMode.append, IgnoreMode.rewrite}: + paths_to_write.update(paths) + elif mode == IgnoreMode.remove: + for p in paths: + paths_to_write.discard(p) + paths_to_write = sorted(p for p in paths_to_write if p) + f.write('# The file is autogenerated by Datumaro\n') + f.writelines('\n'.join(paths_to_write)) + f.truncate() + + def show(self, path, rev=None): + return self.repo.git.show('%s:%s' % (rev or '', path)) + +class DvcWrapper: @staticmethod - def generate(save_dir, config=None): - config = Config(config) - config.project_dir = save_dir - project = Project(config) - project.save(save_dir) - return project + def import_module(): + import dvc + import dvc.repo + import dvc.main + return dvc + + try: + module = import_module.__func__() + except ImportError: + module = None + + def _dvc_dir(self): + return osp.join(self._project_dir, '.dvc') + + class DvcError(Exception): + pass + + def __init__(self, project_dir): + self._project_dir = project_dir + self._repo = None + + if osp.isdir(project_dir) and osp.isdir(self._dvc_dir()): + with logging_disabled(): + self._repo = self.module.repo.Repo(project_dir) + + @property + def initialized(self): + return self._repo is not None + + @property + def repo(self): + self._repo = self.module.repo.Repo(self._project_dir) + return self._repo + + def init(self): + if self.initialized: + return + + with logging_disabled(): + self._repo = self.module.repo.Repo.init(self._project_dir) + + def push(self, targets=None, remote=None): + args = ['push'] + if remote: + args.append('--remote') + args.append(remote) + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def pull(self, targets=None, remote=None): + args = ['pull'] + if remote: + args.append('--remote') + args.append(remote) + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def check_updates(self, targets=None, remote=None): + args = ['fetch'] # no other way now? 
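+        # NOTE: DVC provides no dedicated "check for updates" command, so plain
+        # 'dvc fetch' is reused here; the fetched objects land in the local cache
+        # and are picked up by a later pull or checkout.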
+ if remote: + args.append('--remote') + args.append(remote) + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def fetch(self, targets=None, remote=None): + args = ['fetch'] + if remote: + args.append('--remote') + args.append(remote) + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def import_repo(self, url, path, out=None, dvc_path=None, rev=None, + download=True): + args = ['import'] + if dvc_path: + args.append('--file') + args.append(dvc_path) + os.makedirs(osp.dirname(dvc_path), exist_ok=True) + if rev: + args.append('--rev') + args.append(rev) + if out: + args.append('-o') + args.append(out) + if not download: + args.append('--no-exec') + args.append(url) + args.append(path) + self._exec(args) + + def import_url(self, url, out=None, dvc_path=None, download=True): + args = ['import-url'] + if dvc_path: + args.append('--file') + args.append(dvc_path) + os.makedirs(osp.dirname(dvc_path), exist_ok=True) + if not download: + args.append('--no-exec') + args.append(url) + if out: + args.append(out) + self._exec(args) + + def update_imports(self, targets=None, rev=None): + args = ['update'] + if rev: + args.append('--rev') + args.append(rev) + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def checkout(self, targets=None): + args = ['checkout'] + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def add(self, paths, dvc_path=None): + args = ['add'] + if dvc_path: + args.append('--file') + args.append(dvc_path) + os.makedirs(osp.dirname(dvc_path), exist_ok=True) + if paths: + if isinstance(paths, str): + args.append(paths) + else: + args.extend(paths) + self._exec(args) + + def remove(self, paths, outs=False): + args = ['remove'] + if outs: + args.append('--outs') + if paths: + if isinstance(paths, str): + args.append(paths) + else: + args.extend(paths) + self._exec(args) + + def commit(self, paths): + args = ['commit', '--recursive', '--force'] + if paths: + if isinstance(paths, str): + args.append(paths) + else: + args.extend(paths) + self._exec(args) + + def add_remote(self, name, config): + self._exec(['remote', 'add', name, config['url']]) + + def remove_remote(self, name): + self._exec(['remote', 'remove', name]) + + def list_remotes(self): + out = self._exec(['remote', 'list']) + return dict(line.split() for line in out.split('\n') if line) + + def get_default_remote(self): + out = self._exec(['remote', 'default']) + if out == 'No default remote set' or 1 < len(out.split()): + return None + return out + + def set_default_remote(self, name): + assert name and 1 == len(name.split()), "Invalid remote name '%s'" % name + self._exec(['remote', 'default', name]) + + def list_stages(self): + return set(s.addressing for s in self.repo.stages) + + def run(self, name, cmd, deps=None, outs=None, force=False): + args = ['run', '-n', name] + if force: + args.append('--force') + for d in deps: + args.append('-d') + args.append(d) + for o in outs: + args.append('--outs') + args.append(o) + args.extend(cmd) + self._exec(args, hide_output=False) + + def repro(self, targets=None, force=False, pull=False): + args = ['repro'] + if force: + args.append('--force') + if pull: + args.append('--pull') + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def 
status(self, targets=None): + args = ['status', '--show-json'] + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + out = self._exec(args).splitlines()[-1] + return json.loads(out) @staticmethod - def import_from(path, dataset_format=None, env=None, **format_options): + def check_stage_status(data, stage, status): + assert status in {'deleted', 'modified'} + return status in [s + for d in data.get(stage, []) if 'changed outs' in d + for co in d.values() + for s in co.values() + ] + + def _exec(self, args, hide_output=True, answer_on_input='y'): + contexts = ExitStack() + + args = ['--cd', self._project_dir] + args + contexts.callback(os.chdir, os.getcwd()) # restore cd after DVC + + if answer_on_input is not None: + def _input(*args): return answer_on_input + contexts.enter_context(unittest.mock.patch( + 'dvc.prompt.input', new=_input)) + + log.debug("Calling DVC main with args: %s", args) + + logs = contexts.enter_context(catch_logs('dvc')) + + with contexts: + retcode = self.module.main.main(args) + + logs = logs.getvalue() + if retcode != 0: + raise self.DvcError(logs) + if not hide_output: + print(logs) + return logs + +class ProjectVcs: + G_DETACHED = str_to_bool(os.getenv('DATUMARO_VCS_DETACHED', '0')) + + def __init__(self, project: 'Project', readonly: bool = False): + self._project = project + self.readonly = readonly + + self._git = None + self._dvc = None + if not self.G_DETACHED: + try: + GitWrapper.import_module() + DvcWrapper.import_module() + self._git = GitWrapper(project.config.project_dir) + self._dvc = DvcWrapper(project.config.project_dir) + except ImportError as e: + log.warning("Failed to init VCS for the project: %s", e) + else: + log.debug("Working in detached mode, " + "versioning commands won't be available") + + self._remotes = ProjectRemotes(self) + self._repos = ProjectRepositories(self) + + @property + def git(self) -> GitWrapper: + if not self._git: + message = "Git is not available. " + if not GitWrapper.module: + message += "Please, install the module with " \ + "'pip install gitpython'." + elif self.G_DETACHED: + message += "The project is in detached mode." + raise DetachedProjectError(message) + raise ImportError(message) + return self._git + + @property + def dvc(self) -> DvcWrapper: + if not self._dvc: + message = "DVC is not available. " + if not DvcWrapper.module: + message += "Please, install the module with " \ + "'pip install dvc'." + elif self.G_DETACHED: + message += "The project is in detached mode." 
+ raise DetachedProjectError(message) + raise ImportError(message) + return self._dvc + + @property + def available(self): + return self._git and self._dvc + + @property + def detached(self): + return self.G_DETACHED + + @property + def initialized(self): + return not self.detached and self.available and \ + self.git.initialized and self.dvc.initialized + + @property + def writeable(self): + return not self.detached and not self.readonly and self.initialized + + @property + def readable(self): + return not self.detached and self.initialized + + @property + def remotes(self) -> ProjectRemotes: + return self._remotes + + @property + def repositories(self) -> ProjectRepositories: + return self._repos + + @property + def refs(self) -> List[str]: + if self.detached: + return [] + return self.git.refs + + @property + def tags(self) -> List[str]: + if self.detached: + return [] + return self.git.tags + + def push(self, targets: Optional[List[str]] = None, + remote: Optional[str] = None, repository: Optional[str] = None): + """ + Pushes the local DVC cache to the remote storage. + Pushes local Git changes to the remote repository. + + If not provided, uses the default remote storage and repository. + """ + + if self.detached: + log.debug("The project is in detached mode, skipping push.") + return + + if not self.writeable: + raise ReadonlyProjectError("Can't push in a read-only repository") + + assert targets is None or isinstance(targets, (str, list)), targets + if targets is None: + targets = [] + elif isinstance(targets, str): + targets = [targets] + targets = targets or [] + for i, t in enumerate(targets): + if not osp.exists(t): + targets[i] = self.dvc_filepath(t) + + # order matters + self.dvc.push(targets, remote=remote) + self.git.push(remote=repository) + + def pull(self, targets: Union[None, str, List[str]] = None, + remote: Optional[str] = None, repository: Optional[str] = None): + """ + Pulls the local DVC cache data from the remote storage. + Pulls local Git changes to the remote repository. + + If not provided, uses the default remote storage and repository. 
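+
+        Example (illustrative names): project.vcs.pull(['my-source'], remote='my-storage')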
+ """ + + if self.detached: + log.debug("The project is in detached mode, skipping pull.") + return + + if not self.writeable: + raise ReadonlyProjectError("Can't pull in a read-only repository") + + assert targets is None or isinstance(targets, (str, list)), targets + if targets is None: + targets = [] + elif isinstance(targets, str): + targets = [targets] + targets = targets or [] + for i, t in enumerate(targets): + if not osp.exists(t): + targets[i] = self.dvc_filepath(t) + + # order matters + self.git.pull(remote=repository) + self.dvc.pull(targets, remote=remote) + + def check_updates(self, + targets: Union[None, str, List[str]] = None) -> List[str]: + if self.detached: + log.debug("The project is in detached mode, " + "skipping checking for updates.") + return + + if not self.writeable: + raise ReadonlyProjectError( + "Can't check for updates in a read-only repository") + + assert targets is None or isinstance(targets, (str, list)), targets + if targets is None: + targets = [] + elif isinstance(targets, str): + targets = [targets] + targets = targets or [] + for i, t in enumerate(targets): + if not osp.exists(t): + targets[i] = self.dvc_filepath(t) + + updated_refs = self.git.check_updates() + updated_remotes = self.remotes.check_updates(targets) + return updated_refs, updated_remotes + + def fetch(self, targets: Union[None, str, List[str]] = None): + if self.detached: + log.debug("The project is in detached mode, skipping fetch.") + return + + if not self.writeable: + raise ReadonlyProjectError("Can't fetch in a read-only repository") + + assert targets is None or isinstance(targets, (str, list)), targets + if targets is None: + targets = [] + elif isinstance(targets, str): + targets = [targets] + targets = targets or [] + for i, t in enumerate(targets): + if not osp.exists(t): + targets[i] = self.dvc_filepath(t) + + self.git.fetch() + self.dvc.fetch(targets) + + def tag(self, name: str): + if self.detached: + log.debug("The project is in detached mode, skipping tag.") + return + + if not self.writeable: + raise ReadonlyProjectError("Can't tag in a read-only repository") + + self.git.tag(name) + + def checkout(self, rev: Optional[str] = None, + targets: Union[None, str, List[str]] = None): + if self.detached: + log.debug("The project is in detached mode, skipping checkout.") + return + + if not self.writeable: + raise ReadonlyProjectError( + "Can't checkout in a read-only repository") + + assert targets is None or isinstance(targets, (str, list)), targets + if targets is None: + targets = [] + elif isinstance(targets, str): + targets = [targets] + targets = targets or [] + for i, t in enumerate(targets): + if not osp.exists(t): + targets[i] = self.dvc_filepath(t) + + # order matters + self.git.checkout(rev, targets) + self.dvc.checkout(targets) + + def add(self, paths: List[str]): + if self.detached: + log.debug("The project is in detached mode, skipping adding files.") + return + + if not self.writeable: + raise ReadonlyProjectError( + "Can't track files in a read-only repository") + + if not paths: + raise ValueError("Expected at least one file path to add") + for p in paths: + self.dvc.add(p, dvc_path=self.dvc_aux_path(osp.basename(p))) + self.ensure_gitignored() + + def commit(self, paths: Union[None, List[str]], message): + if self.detached: + log.debug("The project is in detached mode, skipping commit.") + return + + if not self.writeable: + raise ReadonlyProjectError("Can't commit in a read-only repository") + + # order matters + if not paths: + paths = glob( + 
osp.join(self._project.config.project_dir, '**', '*.dvc'), + recursive=True) + self.dvc.commit(paths) + self.ensure_gitignored() + + project_dir = self._project.config.project_dir + env_dir = self._project.config.env_dir + self.git.add([ + osp.join(project_dir, env_dir), + osp.join(project_dir, '.dvc', 'config'), + osp.join(project_dir, '.dvc', '.gitignore'), + osp.join(project_dir, '.gitignore'), + osp.join(project_dir, '.dvcignore'), + ] + list(self.git.status())) + self.git.commit(message) + + def init(self): + if self.detached: + log.debug("The project is in detached mode, skipping init.") + return + + if self.readonly: + raise ReadonlyProjectError("Can't init in a read-only repository") + + # order matters + self.git.init() + self.dvc.init() + os.makedirs(self.dvc_aux_dir(), exist_ok=True) + + def status(self) -> Dict: + if self.detached: + log.debug("The project is in detached mode, " + "skipping checking status.") + return {} + + # check status of files and remotes + uncomitted = {} + uncomitted.update(self.git.status()) + uncomitted.update(self.dvc.status()) + return uncomitted + + def ensure_gitignored(self, paths: Union[None, str, List[str]] = None): + if self.detached: + return + + if not self.writeable: + raise ReadonlyProjectError("Can't update a read-only repository") + + if paths is None: + paths = [self._project.sources.work_dir(source) + for source in self._project.sources] + \ + [self._project.config.build_dir] + self.git.ignore(paths, mode='append') + + def dvc_aux_dir(self) -> str: + return osp.join(self._project.config.project_dir, + self._project.config.env_dir, + self._project.config.dvc_aux_dir) + + def dvc_filepath(self, target: str) -> str: + return osp.join(self.dvc_aux_dir(), target + '.dvc') + + def is_ref(self, ref: str) -> bool: + if self.detached: + return False + + return self.git.is_ref(ref) + + def has_commits(self) -> bool: + if self.detached: + return False + + return self.git.has_commits() + +class Project: + @classmethod + def import_from(cls, path: str, dataset_format: Optional[str] = None, + env: Optional[Environment] = None, **format_options) -> 'Project': if env is None: env = Environment() @@ -393,98 +1719,167 @@ def import_from(path, dataset_format=None, env=None, **format_options): ', '.join(matches)) dataset_format = matches[0] elif not env.is_format_known(dataset_format): - raise KeyError("Unknown dataset format '%s'" % dataset_format) - - if dataset_format in env.importers: - project = env.make_importer(dataset_format)(path, **format_options) - elif dataset_format in env.extractors: - project = Project(env=env) - project.add_source('source', { - 'url': path, - 'format': dataset_format, - 'options': format_options, - }) - else: - raise DatumaroError("Unknown format '%s'. To make it " + raise KeyError("Unknown format '%s'. 
To make it " "available, add the corresponding Extractor implementation " "to the environment" % dataset_format) + + project = Project(env=env) + project.sources.add('source', { + 'url': path, + 'format': dataset_format, + 'options': format_options, + }) return project - def __init__(self, config=None, env=None): - self.config = Config(config, - fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) + @classmethod + def generate(cls, save_dir: str, + config: Optional[Config] = None) -> 'Project': + config = Config(config) + config.project_dir = save_dir + project = Project(config) + project.save(save_dir) + return project + + @classmethod + def load(cls, path: str) -> 'Project': + path = osp.abspath(path) + config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir, + PROJECT_DEFAULT_CONFIG.project_filename) + config = Config.parse(config_path) + config.project_dir = path + config.project_filename = osp.basename(config_path) + return Project(config) + + @error_rollback('on_error', implicit=True) + def save(self, save_dir: Union[None, str] = None): + config = self.config + if save_dir and config.project_dir and save_dir != config.project_dir: + raise NotImplementedError("Can't copy or resave project " + "to another directory.") + + config.project_dir = save_dir or config.project_dir + assert config.project_dir + project_dir = config.project_dir + save_dir = osp.join(project_dir, config.env_dir) + + if not osp.exists(project_dir): + on_error.do(shutil.rmtree, project_dir, ignore_errors=True) + if not osp.exists(save_dir): + on_error.do(shutil.rmtree, save_dir, ignore_errors=True) + os.makedirs(save_dir, exist_ok=True) + + config.dump(osp.join(save_dir, config.project_filename)) + + if self.vcs.detached: + return + + if not self.vcs.initialized and not self.vcs.readonly: + self._vcs = ProjectVcs(self) + self.vcs.init() + if self.vcs.writeable: + self.vcs.ensure_gitignored() + self.vcs.git.add([ + osp.join(project_dir, config.env_dir), + osp.join(project_dir, '.dvc', 'config'), + osp.join(project_dir, '.dvc', '.gitignore'), + osp.join(project_dir, '.gitignore'), + osp.join(project_dir, '.dvcignore'), + ]) + + def __init__(self, config: Optional[Config] = None, + env: Optional[Environment] = None): + self._config = self._read_config(config) if env is None: - env = Environment(self.config) + env = Environment(self._config) elif config is not None: raise ValueError("env can only be provided when no config provided") - self.env = env + self._env = env + self._vcs = ProjectVcs(self) + self._sources = ProjectSources(self) + self._models = ProjectModels(self) + self._build_targets = ProjectBuildTargets(self) - def make_dataset(self): - return ProjectDataset(self) + @property + def sources(self) -> ProjectSources: + return self._sources - def add_source(self, name, value=None): - if value is None or isinstance(value, (dict, Config)): - value = Source(value) - self.config.sources[name] = value - self.env.sources.register(name, value) + @property + def models(self) -> ProjectModels: + return self._models - def remove_source(self, name): - self.config.sources.remove(name) - self.env.sources.unregister(name) + @property + def build_targets(self) -> ProjectBuildTargets: + return self._build_targets - def get_source(self, name): - try: - return self.config.sources[name] - except KeyError: - raise KeyError("Source '%s' is not found" % name) + @property + def vcs(self) -> ProjectVcs: + return self._vcs - def get_subsets(self): - return self.config.subsets + @property + def config(self) -> Config: + return 
self._config - def set_subsets(self, value): - if not value: - self.config.remove('subsets') - else: - self.config.subsets = value + @property + def env(self) -> Environment: + return self._env - def add_model(self, name, value=None): - if value is None or isinstance(value, (dict, Config)): - value = Model(value) - self.env.register_model(name, value) - self.config.models[name] = value + def make_dataset(self, + target: Union[None, str, List[str]] = None) -> Dataset: + if target is None: + target = 'project' + return self.build_targets.make_dataset(target) - def get_model(self, name): - try: - return self.env.models.get(name) - except KeyError: - raise KeyError("Model '%s' is not found" % name) + def publish(self): + # build + tag + push? + raise NotImplementedError() - def remove_model(self, name): - self.config.models.remove(name) - self.env.unregister_model(name) + def build(self, target: Union[None, str, List[str]] = None, + force: bool = False, out_dir: Union[None, str] = None): + if target is None: + target = 'project' + return self.build_targets.build(target, force=force, out_dir=out_dir) - def make_executable_model(self, name): - model = self.get_model(name) - return self.env.make_launcher(model.launcher, - **model.options, model_dir=osp.join( - self.config.project_dir, self.local_model_dir(name))) + @classmethod + def _read_config_v1(cls, config): + config = Config(config) + config.remove('subsets') + config.remove('format_version') + + config = cls._read_config_v2(config) + if osp.isdir(osp.join(config.project_dir, config.dataset_dir)): + name = generate_next_name(list(config.sources), 'source', + sep='-', default='1') + config.sources[name] = { + 'url': config.dataset_dir, + 'format': DEFAULT_FORMAT, + } + return config - def make_source_project(self, name): - source = self.get_source(name) + @classmethod + def _read_config_v2(cls, config): + return Config(config, + fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - config = Config(self.config) - config.remove('sources') - config.remove('subsets') - project = Project(config) - project.add_source(name, source) - return project + @classmethod + def _read_config(cls, config): + if config: + version = config.get('format_version') + else: + version = None + if version == 1: + return cls._read_config_v1(config) + elif version in {None, 2}: + return cls._read_config_v2(config) + else: + raise ValueError("Unknown project config file format version '%s'. 
" + "The only known are: 1, 2" % version) - def local_model_dir(self, model_name): - return osp.join( - self.config.env_dir, self.config.models_dir, model_name) +def merge_projects(a, b, strategy: MergeStrategy = None): + raise NotImplementedError() + +def compare_projects(a, b, **options): + raise NotImplementedError() - def local_source_dir(self, source_name): - return osp.join(self.config.sources_dir, source_name) def load_project_as_dataset(url): return Project.load(url).make_dataset() diff --git a/datumaro/plugins/coco_format/converter.py b/datumaro/plugins/coco_format/converter.py index 0caf89de3b..f71665f63a 100644 --- a/datumaro/plugins/coco_format/converter.py +++ b/datumaro/plugins/coco_format/converter.py @@ -263,7 +263,7 @@ def save_annotations(self, item): return if not item.has_image: - log.warn("Item '%s': skipping writing instances " + log.warning("Item '%s': skipping writing instances " "since no image info available" % item.id) return h, w = item.image.size diff --git a/datumaro/plugins/coco_format/extractor.py b/datumaro/plugins/coco_format/extractor.py index 29b97f7e27..9d808bad77 100644 --- a/datumaro/plugins/coco_format/extractor.py +++ b/datumaro/plugins/coco_format/extractor.py @@ -16,6 +16,7 @@ LabelCategories, PointsCategories ) from datumaro.util.image import Image +from datumaro.util.os_util import suppress_output from .format import CocoTask, CocoPath @@ -55,7 +56,9 @@ def _make_subset_loader(path): dataset = json.load(f) coco_api.dataset = dataset - coco_api.createIndex() + + with suppress_output(): + coco_api.createIndex() return coco_api def _load_categories(self, loader): diff --git a/datumaro/plugins/coco_format/importer.py b/datumaro/plugins/coco_format/importer.py index f613143e15..806d57e55f 100644 --- a/datumaro/plugins/coco_format/importer.py +++ b/datumaro/plugins/coco_format/importer.py @@ -29,9 +29,6 @@ def detect(cls, path): return len(cls.find_sources(path)) != 0 def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - subsets = self.find_sources(path) if len(subsets) == 0: @@ -50,6 +47,7 @@ def __call__(self, path, **extra_params): "Only one type will be used: %s" \ % (", ".join(t.name for t in ann_types), selected_ann_type.name)) + sources = [] for ann_files in subsets.values(): for ann_type, ann_file in ann_files.items(): if ann_type in conflicting_types: @@ -59,14 +57,13 @@ def __call__(self, path, **extra_params): continue log.info("Found a dataset at '%s'" % ann_file) - source_name = osp.splitext(osp.basename(ann_file))[0] - project.add_source(source_name, { + sources.append({ 'url': ann_file, 'format': self._COCO_EXTRACTORS[ann_type], 'options': dict(extra_params), }) - return project + return sources @staticmethod def find_sources(path): @@ -85,7 +82,7 @@ def find_sources(path): try: ann_type = CocoTask[ann_type] except KeyError: - log.warn("Skipping '%s': unknown subset " + log.warning("Skipping '%s': unknown subset " "type '%s', the only known are: %s" % \ (subset_path, ann_type, ', '.join([e.name for e in CocoTask]))) diff --git a/datumaro/plugins/voc_format/importer.py b/datumaro/plugins/voc_format/importer.py index 7da323249b..6a5b4c99cd 100644 --- a/datumaro/plugins/voc_format/importer.py +++ b/datumaro/plugins/voc_format/importer.py @@ -29,49 +29,30 @@ def find_path(root_path, path, depth=4): class VocImporter(Importer): _TASKS = [ - (VocTask.classification, 'voc_classification', 'Main'), - (VocTask.detection, 'voc_detection', 'Main'), - 
(VocTask.segmentation, 'voc_segmentation', 'Segmentation'), - (VocTask.person_layout, 'voc_layout', 'Layout'), - (VocTask.action_classification, 'voc_action', 'Action'), + ('voc_classification', 'Main'), + ('voc_detection', 'Main'), + ('voc_segmentation', 'Segmentation'), + ('voc_layout', 'Layout'), + ('voc_action', 'Action'), ] - def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - - subset_paths = self.find_sources(path) - if len(subset_paths) == 0: - raise Exception("Failed to find 'voc' dataset at '%s'" % path) - - for task, extractor_type, subset_path in subset_paths: - project.add_source('%s-%s' % - (task.name, osp.splitext(osp.basename(subset_path))[0]), - { - 'url': subset_path, - 'format': extractor_type, - 'options': dict(extra_params), - }) - - return project - @classmethod def find_sources(cls, path): # find root path for the dataset root_path = path - for task, extractor_type, task_dir in cls._TASKS: + for extractor_type, task_dir in cls._TASKS: task_path = find_path(root_path, osp.join(VocPath.SUBSETS_DIR, task_dir)) if task_path: root_path = osp.dirname(osp.dirname(task_path)) break - subset_paths = [] - for task, extractor_type, task_dir in cls._TASKS: + subsets = [] + for extractor_type, task_dir in cls._TASKS: task_path = osp.join(root_path, VocPath.SUBSETS_DIR, task_dir) - if not osp.isdir(task_path): continue - task_subsets = [p for p in glob(osp.join(task_path, '*.txt')) - if '_' not in osp.basename(p)] - subset_paths += [(task, extractor_type, p) for p in task_subsets] - return subset_paths + + subsets += cls._find_sources_recursive( + task_path, '.txt', extractor_type, max_depth=0, + file_filter=lambda p: '_' not in osp.basename(p)) + return subsets diff --git a/datumaro/plugins/yolo_format/converter.py b/datumaro/plugins/yolo_format/converter.py index fb71b8f172..71f021f0e6 100644 --- a/datumaro/plugins/yolo_format/converter.py +++ b/datumaro/plugins/yolo_format/converter.py @@ -49,7 +49,7 @@ def apply(self): if not subset_name or subset_name == DEFAULT_SUBSET_NAME: subset_name = YoloPath.DEFAULT_SUBSET_NAME elif subset_name not in YoloPath.SUBSET_NAMES: - log.warn("Skipping subset export '%s'. " + log.warning("Skipping subset export '%s'. 
" "If specified, the only valid names are %s" % \ (subset_name, ', '.join( "'%s'" % s for s in YoloPath.SUBSET_NAMES))) diff --git a/datumaro/util/command_targets.py b/datumaro/util/command_targets.py index 50c854f271..d1b30350d3 100644 --- a/datumaro/util/command_targets.py +++ b/datumaro/util/command_targets.py @@ -35,12 +35,7 @@ def is_project(value, project=None): def is_source(value, project=None): if project is not None: - try: - project.get_source(value) - return True - except KeyError: - pass - + return value in project.sources return False def is_external_source(value): diff --git a/datumaro/util/log_utils.py b/datumaro/util/log_utils.py index 6c8d8421e7..f846c2a6de 100644 --- a/datumaro/util/log_utils.py +++ b/datumaro/util/log_utils.py @@ -4,8 +4,10 @@ # SPDX-License-Identifier: MIT from contextlib import contextmanager +from io import StringIO import logging + @contextmanager def logging_disabled(max_level=logging.CRITICAL): previous_level = logging.root.manager.disable @@ -13,4 +15,22 @@ def logging_disabled(max_level=logging.CRITICAL): try: yield finally: - logging.disable(previous_level) \ No newline at end of file + logging.disable(previous_level) + +@contextmanager +def catch_logs(logger=None): + logger = logging.getLogger(logger) + + old_propagate = logger.propagate + prev_handlers = logger.handlers + + stream = StringIO() + handler = logging.StreamHandler(stream) + logger.handlers = [handler] + logger.propagate = False + + try: + yield stream + finally: + logger.handlers = prev_handlers + logger.propagate = old_propagate \ No newline at end of file diff --git a/datumaro/util/os_util.py b/datumaro/util/os_util.py index 094329206a..48ba99d5c4 100644 --- a/datumaro/util/os_util.py +++ b/datumaro/util/os_util.py @@ -2,11 +2,17 @@ # # SPDX-License-Identifier: MIT +from contextlib import contextmanager +from io import StringIO import importlib import os import os.path as osp +import re import subprocess import sys +import unicodedata + +from . import cast DEFAULT_MAX_DEPTH = 10 @@ -47,6 +53,42 @@ def walk(path, max_depth=None): yield dirpath, dirnames, filenames +@contextmanager +def suppress_output(stdout=True, stderr=False): + with open(os.devnull, "w") as devnull: + if stdout: + old_stdout = sys.stdout + sys.stdout = devnull + + if stderr: + old_stderr = sys.stderr + sys.stderr = devnull + + try: + yield + finally: + if stdout: + sys.stdout = old_stdout + if stderr: + sys.stderr = old_stderr + +@contextmanager +def catch_output(): + stdout = StringIO() + stderr = StringIO() + + old_stdout = sys.stdout + sys.stdout = stdout + + old_stderr = sys.stderr + sys.stderr = stderr + + try: + yield stdout, stderr + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + def dir_items(path, ext, truncate_ext=False): items = [] for f in os.listdir(path): @@ -72,3 +114,31 @@ def split_path(path): parts.reverse() return parts + +def make_file_name(s): + # adapted from + # https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify + """ + Normalizes string, converts to lowercase, removes non-alpha characters, + and converts spaces to hyphens. 
+ """ + s = unicodedata.normalize('NFKD', s).encode('ascii', 'ignore') + s = s.decode() + s = re.sub(r'[^\w\s-]', '', s).strip().lower() + s = re.sub(r'[-\s]+', '-', s) + return s + +def generate_next_name(names, basename, sep='.', suffix='', default=None): + pattern = re.compile(r'%s(?:%s(\d+))?%s' % \ + tuple(map(re.escape, [basename, sep, suffix]))) + matches = [match for match in (pattern.match(n) for n in names) if match] + + max_idx = max([cast(match[1], int, 0) for match in matches], default=None) + if max_idx is None: + if default is not None: + idx = sep + str(default) + else: + idx = '' + else: + idx = sep + str(max_idx + 1) + return basename + idx + suffix diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py index 8c5cf05af2..128ac6ee1b 100644 --- a/datumaro/util/test_utils.py +++ b/datumaro/util/test_utils.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: MIT +from glob import glob import inspect import os import os.path as osp @@ -156,4 +157,17 @@ def test_save_and_load(test, source_dataset, converter, test_dir, importer, if not compare: compare = compare_datasets - compare(test, expected=target_dataset, actual=parsed_dataset, **kwargs) \ No newline at end of file + compare(test, expected=target_dataset, actual=parsed_dataset, **kwargs) + +def compare_dirs(test, a, b): + for a_path in glob(osp.join(a, '**', '*'), recursive=True): + rel_path = osp.relpath(a_path, a) + b_path = osp.join(b, rel_path) + if osp.isdir(a_path): + test.assertTrue(osp.isdir(b_path), rel_path) + continue + + test.assertTrue(osp.isfile(b_path), rel_path) + with open(a_path, 'rb') as a_file, \ + open(b_path, 'rb') as b_file: + test.assertEqual(a_file.read(), b_file.read(), rel_path) diff --git a/requirements.txt b/requirements.txt index 5cfc7dd4f2..5498e4a59c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ opencv-python-headless>=4.1.0.25 Pillow>=6.1.0 pycocotools>=2.0.0 PyYAML>=5.3.1 +ruamel.yaml>=0.16.5 scikit-image>=0.15.0 tensorboardX>=1.8 pandas>=1.1.5 diff --git a/setup.py b/setup.py index d1e5ff0152..a31a813846 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,6 @@ def get_requirements(): requirements = [ 'attrs>=19.3.0', 'defusedxml', - 'GitPython', 'lxml', 'matplotlib', 'numpy>=1.17.3', @@ -45,6 +44,7 @@ def get_requirements(): 'pycocotools; platform_system != "Windows"', 'pycocotools-windows; platform_system == "Windows"', 'PyYAML', + 'ruamel.yaml', 'scikit-image', 'tensorboardX', ] @@ -82,6 +82,7 @@ def get_requirements(): extras_require={ 'tf': ['tensorflow'], 'tf-gpu': ['tensorflow-gpu'], + 'vcs': ['GitPython', 'dvc'], }, entry_points={ 'console_scripts': [ diff --git a/tests/assets/compat/v0.1/project/.datumaro/config.yaml b/tests/assets/compat/v0.1/project/.datumaro/config.yaml new file mode 100644 index 0000000000..c78b58be81 --- /dev/null +++ b/tests/assets/compat/v0.1/project/.datumaro/config.yaml @@ -0,0 +1,4 @@ +format_version: 1 +models: {} +project_name: undefined +subsets: [] diff --git a/tests/assets/compat/v0.1/project/dataset/annotations/test.json b/tests/assets/compat/v0.1/project/dataset/annotations/test.json new file mode 100644 index 0000000000..009bf270a4 --- /dev/null +++ b/tests/assets/compat/v0.1/project/dataset/annotations/test.json @@ -0,0 +1 @@ +{"info": {}, "categories": {"label": {"labels": [{"name": "a", "parent": ""}, {"name": "b", "parent": ""}]}}, "items": [{"id": "1", "annotations": [{"id": 0, "type": "label", "attributes": {}, "group": 0, "label_id": 1}]}]} \ No newline at end of file diff --git 
a/tests/assets/compat/v0.1/project/dataset/annotations/train.json b/tests/assets/compat/v0.1/project/dataset/annotations/train.json new file mode 100644 index 0000000000..5229dfc971 --- /dev/null +++ b/tests/assets/compat/v0.1/project/dataset/annotations/train.json @@ -0,0 +1 @@ +{"info": {}, "categories": {"label": {"labels": [{"name": "a", "parent": ""}, {"name": "b", "parent": ""}]}}, "items": [{"id": "0", "annotations": [{"id": 0, "type": "label", "attributes": {}, "group": 0, "label_id": 0}]}]} \ No newline at end of file diff --git a/tests/cli/test_project.py b/tests/cli/test_project.py new file mode 100644 index 0000000000..3fcfb0057c --- /dev/null +++ b/tests/cli/test_project.py @@ -0,0 +1,115 @@ +import numpy as np +import os.path as osp +import shutil + +from unittest import TestCase, skipIf + +from datumaro.components.dataset import Dataset +from datumaro.components.extractor import Bbox, DatasetItem +from datumaro.cli.__main__ import main +from datumaro.util.test_utils import TestDir, compare_datasets + + +no_vcs_installed = False +try: + import git # pylint: disable=unused-import + import dvc # pylint: disable=unused-import +except ImportError: + no_vcs_installed = True + +def run(test, *args, expected_code=0): + test.assertEqual(expected_code, main(args), str(args)) + +class ProjectIntegrationScenarios(TestCase): + def test_can_convert_voc_as_coco(self): + voc_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'voc_dataset') + + with TestDir() as test_dir: + result_dir = osp.join(test_dir, 'coco_export') + + run(self, 'convert', + '-if', 'voc', '-i', voc_dir, + '-f', 'coco', '-o', result_dir, + '--', '--save-images') + + self.assertTrue(osp.isdir(result_dir)) + + @skipIf(no_vcs_installed, "No VCS modules (Git, DVC) installed") + def test_can_export_coco_as_voc(self): + # TODO: use subformats once importers are removed + coco_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'coco_dataset', 'coco_instances') + + with TestDir() as test_dir: + run(self, 'create', '-o', test_dir) + run(self, 'add', '-f', 'coco', '-p', test_dir, coco_dir) + + result_dir = osp.join(test_dir, 'voc_export') + run(self, 'export', '-f', 'voc', '-p', test_dir, '-o', result_dir, + '--', '--save-images') + + self.assertTrue(osp.isdir(result_dir)) + + @skipIf(no_vcs_installed, "No VCS modules (Git, DVC) installed") + def test_can_list_info(self): + # TODO: use subformats once importers are removed + coco_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'coco_dataset', 'coco_instances') + + with TestDir() as test_dir: + run(self, 'create', '-o', test_dir) + run(self, 'add', '-f', 'coco', '-p', test_dir, coco_dir) + + run(self, 'info', '-p', test_dir) + + @skipIf(no_vcs_installed, "No VCS modules (Git, DVC) installed") + def test_can_use_vcs(self): + with TestDir() as test_dir: + dataset_dir = osp.join(test_dir, 'dataset') + project_dir = osp.join(test_dir, 'proj') + result_dir = osp.join(project_dir, 'result') + + Dataset.from_iterable([ + DatasetItem(0, image=np.ones((1, 2, 3)), annotations=[ + Bbox(1, 1, 1, 1, label=0), + Bbox(2, 2, 2, 2, label=1), + ]) + ], categories=['a', 'b']).save(dataset_dir, save_images=True) + + run(self, 'create', '-o', project_dir) + run(self, 'commit', '-p', project_dir, '-m', 'Initial commit') + + run(self, 'add', '-p', project_dir, '-f', 'datumaro', dataset_dir) + run(self, 'commit', '-p', project_dir, '-m', 'Add data') + + run(self, 'transform', '-p', project_dir, + 
'-t', 'remap_labels', 'source-1', '--', '-l', 'b:cat') + run(self, 'commit', '-p', project_dir, '-m', 'Add transform') + + run(self, 'filter', '-p', project_dir, + '-e', '/item/annotation[label="cat"]', '-m', 'i+a', 'source-1') + run(self, 'commit', '-p', project_dir, '-m', 'Add filter') + + run(self, 'export', '-p', project_dir, '-f', 'coco', '-o', result_dir) + parsed = Dataset.import_from(result_dir, 'coco') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(0, image=np.ones((1, 2, 3)), + annotations=[ + Bbox(2, 2, 2, 2, label=1, + group=1, id=1, attributes={'is_crowd': False}), + ], attributes={ 'id': 1 }) + ], categories=['a', 'cat']), parsed, require_images=True) + + shutil.rmtree(result_dir, ignore_errors=True) + run(self, 'checkout', '-p', project_dir, 'HEAD~1') + run(self, 'export', '-p', project_dir, '-f', 'coco', '-o', result_dir) + parsed = Dataset.import_from(result_dir, 'coco') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(0, image=np.ones((1, 2, 3)), annotations=[ + Bbox(1, 1, 1, 1, label=0, + group=1, id=1, attributes={'is_crowd': False}), + Bbox(2, 2, 2, 2, label=1, + group=2, id=2, attributes={'is_crowd': False}), + ], attributes={ 'id': 1 }) + ], categories=['a', 'cat']), parsed, require_images=True) diff --git a/tests/test_command_targets.py b/tests/test_command_targets.py index 5b8a69f318..333a412f38 100644 --- a/tests/test_command_targets.py +++ b/tests/test_command_targets.py @@ -110,7 +110,7 @@ def test_source_false_when_no_project(self): def test_source_true_when_source_exists(self): source_name = 'qwerty' project = Project() - project.add_source(source_name) + project.sources.add(source_name, {'url': ''}) target = SourceTarget(project=project) status = target.test(source_name) @@ -120,7 +120,7 @@ def test_source_true_when_source_exists(self): def test_source_false_when_source_doesnt_exist(self): source_name = 'qwerty' project = Project() - project.add_source(source_name) + project.sources.add(source_name, {'url': ''}) target = SourceTarget(project=project) status = target.test(source_name + '123') diff --git a/tests/test_config.py b/tests/test_config.py index 32332b3545..bba0635e40 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,9 @@ +import os.path as osp + from unittest import TestCase from datumaro.components.config import Config, DictConfig, SchemaBuilder +from datumaro.util.test_utils import TestDir class ConfigTest(TestCase): @@ -30,3 +33,40 @@ def test_can_produce_multilayer_config_from_dict(self): }, schema=schema_top) self.assertEqual(value, source.container['elem'].desc.options['k']) + + def test_can_save_and_load(self): + with TestDir() as test_dir: + schema_low = SchemaBuilder() \ + .add('options', dict) \ + .build() + schema_mid = SchemaBuilder() \ + .add('desc', lambda: Config(schema=schema_low)) \ + .build() + schema_top = SchemaBuilder() \ + .add('container', lambda: DictConfig( + lambda v: Config(v, schema=schema_mid))) \ + .build() + + source = Config({ + 'container': { + 'elem': { + 'desc': { + 'options': { + 'k': (1, 2, 3), + 'd': 'asfd', + } + } + } + } + }, schema=schema_top) + p = osp.join(test_dir, 'f.yaml') + + source.dump(p) + + loaded = Config.parse(p, schema=schema_top) + + self.assertTrue(isinstance( + loaded.container['elem'].desc.options['k'], list)) + loaded.container['elem'].desc.options['k'] = \ + tuple(loaded.container['elem'].desc.options['k']) + self.assertEqual(source, loaded) \ No newline at end of file diff --git a/tests/test_cvat_format.py b/tests/test_cvat_format.py 
index 5b2c60e130..2bbbe9f0ac 100644 --- a/tests/test_cvat_format.py +++ b/tests/test_cvat_format.py @@ -4,7 +4,7 @@ import numpy as np from unittest import TestCase -from datumaro.components.project import Dataset +from datumaro.components.dataset import Dataset from datumaro.components.extractor import (DatasetItem, AnnotationType, Points, Polygon, PolyLine, Bbox, Label, LabelCategories, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 37f83c6f45..cedd4aff63 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -458,7 +458,7 @@ def __iter__(self): self.assertTrue(iter_called) - def test_can_chain_lazy_tranforms(self): + def test_can_chain_lazy_transforms(self): iter_called = False class TestExtractor(Extractor): def __iter__(self): @@ -608,6 +608,14 @@ def test_loader(): self.assertFalse(called) + def test_can_transform_labels(self): + result = Dataset.from_iterable([], categories=['a', 'b']) + + result.transform('remap_labels', {'a': 'c'}) + + compare_datasets(self, Dataset.from_iterable([], categories=['c', 'b']), + result) + class DatasetItemTest(TestCase): def test_ctor_requires_id(self): diff --git a/tests/test_project.py b/tests/test_project.py index b4ab7bbf58..8adae477ab 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,37 +1,34 @@ import numpy as np import os import os.path as osp +import shutil -from unittest import TestCase - -from datumaro.components.project import Project, Environment -from datumaro.components.config_model import Source, Model -from datumaro.components.launcher import Launcher, ModelTransform -from datumaro.components.extractor import (Extractor, DatasetItem, - Label, LabelCategories, AnnotationType) +from unittest import TestCase, skipIf, skip from datumaro.components.config import Config +from datumaro.components.config_model import Source, Model from datumaro.components.dataset import Dataset, DEFAULT_FORMAT -from datumaro.util.test_utils import TestDir, compare_datasets +from datumaro.components.extractor import (Bbox, Extractor, DatasetItem, + Label, LabelCategories, AnnotationType, Transform) +from datumaro.components.errors import VcsError +from datumaro.components.environment import Environment +from datumaro.components.launcher import Launcher, ModelTransform +from datumaro.components.project import (Project, BuildStageType, + GitWrapper, DvcWrapper) +from datumaro.util.test_utils import TestDir, compare_datasets, compare_dirs -class ProjectTest(TestCase): - def test_project_generate(self): - src_config = Config({ - 'project_name': 'test_project', - 'format_version': 1, - }) +class BaseProjectTest(TestCase): + def test_can_generate_project(self): + src_config = Config({ 'project_name': 'test_project' }) - with TestDir() as test_dir: - project_path = test_dir + with TestDir() as project_path: Project.generate(project_path, src_config) - self.assertTrue(osp.isdir(project_path)) - result_config = Project.load(project_path).config + + self.assertTrue(osp.isdir(project_path)) self.assertEqual( src_config.project_name, result_config.project_name) - self.assertEqual( - src_config.format_version, result_config.format_version) @staticmethod def test_default_ctor_is_ok(): @@ -41,133 +38,143 @@ def test_default_ctor_is_ok(): def test_empty_config_is_ok(): Project(Config()) - def test_add_source(self): - source_name = 'source' - origin = Source({ - 'url': 'path', - 'format': 'ext' - }) + def test_inmemory_project_is_not_initialized(self): project = Project() - project.add_source(source_name, origin) + 
self.assertFalse(project.vcs.detached) + self.assertFalse(project.vcs.readable) + self.assertFalse(project.vcs.writeable) + self.assertFalse(project.vcs.initialized) - added = project.get_source(source_name) - self.assertIsNotNone(added) - self.assertEqual(added, origin) + def test_can_add_existing_local_source(self): + # Reasons to exist: + # - Backward compatibility + # - In-memory and detached projects - def test_added_source_can_be_saved(self): - source_name = 'source' - origin = Source({ - 'url': 'path', - }) - project = Project() - project.add_source(source_name, origin) + with TestDir() as test_dir: + source_name = 'source' + origin = Source({ + 'url': test_dir, + 'format': 'fmt', + 'options': { 'a': 5, 'b': 'hello' } + }) + project = Project() + + project.sources.add(source_name, origin) - saved = project.config + added = project.sources[source_name] + self.assertEqual(added.url, origin.url) + self.assertEqual(added.format, origin.format) + self.assertEqual(added.options, origin.options) + + def test_cant_add_nonexisting_local_source(self): + project = Project() - self.assertEqual(origin, saved.sources[source_name]) + with self.assertRaisesRegex(Exception, 'Can only add an existing'): + project.sources.add('source', { 'url': '_p_a_t_h_' }) - def test_added_source_can_be_dumped(self): + def test_can_add_generated_source(self): source_name = 'source' origin = Source({ - 'url': 'path', + # no url + 'format': 'fmt', + 'options': { 'c': 5, 'd': 'hello' } }) project = Project() - project.add_source(source_name, origin) - - with TestDir() as test_dir: - project.save(test_dir) - - loaded = Project.load(test_dir) - loaded = loaded.get_source(source_name) - self.assertEqual(origin, loaded) - def test_can_import_with_custom_importer(self): - class TestImporter: - def __call__(self, path, subset=None): - return Project({ - 'project_filename': path, - 'subsets': [ subset ] - }) + project.sources.add(source_name, origin) - path = 'path' - importer_name = 'test_importer' + added = project.sources[source_name] + self.assertEqual(added.format, origin.format) + self.assertEqual(added.options, origin.options) - env = Environment() - env.importers.register(importer_name, TestImporter) - - project = Project.import_from(path, importer_name, env, - subset='train') + def test_can_make_dataset(self): + class CustomExtractor(Extractor): + def __iter__(self): + yield DatasetItem(42) - self.assertEqual(path, project.config.project_filename) - self.assertListEqual(['train'], project.config.subsets) + extractor_name = 'ext1' + project = Project() + project.env.extractors.register(extractor_name, CustomExtractor) + project.sources.add('src1', { 'format': extractor_name }) - def test_can_dump_added_model(self): - model_name = 'model' + dataset = project.make_dataset() - project = Project() - saved = Model({ 'launcher': 'name' }) - project.add_model(model_name, saved) + compare_datasets(self, CustomExtractor(), dataset) + def test_can_save_added_source(self): with TestDir() as test_dir: + project = Project() + project.sources.add('s', { 'format': 'fmt' }) + project.save(test_dir) loaded = Project.load(test_dir) - loaded = loaded.get_model(model_name) - self.assertEqual(saved, loaded) + self.assertEqual('fmt', loaded.sources['s'].format) - def test_can_have_project_source(self): - with TestDir() as test_dir: - Project.generate(test_dir) + def test_can_add_existing_local_model(self): + # Reasons to exist: + # - Backward compatibility + # - In-memory and detached projects - project2 = Project() - 
project2.add_source('project1', { + with TestDir() as test_dir: + source_name = 'source' + origin = Model({ 'url': test_dir, + 'launcher': 'test', + 'options': { 'a': 5, 'b': 'hello' } }) - dataset = project2.make_dataset() + project = Project() - self.assertTrue('project1' in dataset.sources) + project.models.add(source_name, origin) - def test_can_batch_launch_custom_model(self): - dataset = Dataset.from_iterable([ - DatasetItem(id=i, subset='train', image=np.array([i])) - for i in range(5) - ], categories=['label']) + added = project.models[source_name] + self.assertEqual(added.url, origin.url) + self.assertEqual(added.launcher, origin.launcher) + self.assertEqual(added.options, origin.options) - class TestLauncher(Launcher): - def launch(self, inputs): - for i, inp in enumerate(inputs): - yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ] + def test_cant_add_nonexisting_local_model(self): + project = Project() + + with self.assertRaisesRegex(Exception, 'Can only add an existing'): + project.models.add('m', { 'url': '_p_a_t_h_', 'launcher': 'test' }) + def test_can_add_generated_model(self): model_name = 'model' - launcher_name = 'custom_launcher' + origin = Model({ + # no url + 'launcher': 'test', + 'options': { 'c': 5, 'd': 'hello' } + }) + project = Project() + project.models.add(model_name, origin) + + added = project.models[model_name] + self.assertEqual(added.launcher, origin.launcher) + self.assertEqual(added.options, origin.options) + + def test_can_save_added_model(self): project = Project() - project.env.launchers.register(launcher_name, TestLauncher) - project.add_model(model_name, { 'launcher': launcher_name }) - model = project.make_executable_model(model_name) - batch_size = 3 - executor = ModelTransform(dataset, model, batch_size=batch_size) + saved = Model({ 'launcher': 'test' }) + project.models.add('model', saved) - for item in executor: - self.assertEqual(1, len(item.annotations)) - self.assertEqual(int(item.id) % batch_size, - item.annotations[0].attributes['idx']) - self.assertEqual(int(item.id), - item.annotations[0].attributes['data']) + with TestDir() as test_dir: + project.save(test_dir) - def test_can_do_transform_with_custom_model(self): - class TestExtractorSrc(Extractor): + loaded = Project.load(test_dir) + loaded = loaded.models['model'] + self.assertEqual(saved, loaded) + + def test_can_transform_source_with_model(self): + class TestExtractor(Extractor): def __iter__(self): - for i in range(2): - yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i, - annotations=[Label(i)]) + yield DatasetItem(0, image=np.ones([2, 2, 3]) * 0) + yield DatasetItem(1, image=np.ones([2, 2, 3]) * 1) def categories(self): - label_cat = LabelCategories() - label_cat.add('0') - label_cat.add('1') + label_cat = LabelCategories().from_iterable(['0', '1']) return { AnnotationType.label: label_cat } class TestLauncher(Launcher): @@ -175,233 +182,869 @@ def launch(self, inputs): for inp in inputs: yield [ Label(inp[0, 0, 0]) ] - class TestExtractorDst(Extractor): - def __init__(self, url): - super().__init__() - self.items = [osp.join(url, p) for p in sorted(os.listdir(url))] - - def __iter__(self): - for path in self.items: - with open(path, 'r') as f: - index = osp.splitext(osp.basename(path))[0] - label = int(f.readline().strip()) - yield DatasetItem(id=index, annotations=[Label(label)]) + expected = Dataset.from_iterable([ + DatasetItem(0, image=np.zeros([2, 2, 3]), annotations=[Label(0)]), + DatasetItem(1, image=np.ones([2, 2, 3]), annotations=[Label(1)]) + ], 
categories=['0', '1']) - model_name = 'model' launcher_name = 'custom_launcher' extractor_name = 'custom_extractor' project = Project() project.env.launchers.register(launcher_name, TestLauncher) - project.env.extractors.register(extractor_name, TestExtractorSrc) - project.add_model(model_name, { 'launcher': launcher_name }) - project.add_source('source', { 'format': extractor_name }) + project.env.extractors.register(extractor_name, TestExtractor) + project.models.add('model', { 'launcher': launcher_name }) + project.sources.add('source', { 'format': extractor_name }) + project.build_targets.add_inference_stage('source', 'model') - with TestDir() as test_dir: - project.make_dataset().apply_model(model=model_name, - save_dir=test_dir) + result = project.make_dataset() - result = Project.load(test_dir) - result.env.extractors.register(extractor_name, TestExtractorDst) - it = iter(result.make_dataset()) - item1 = next(it) - item2 = next(it) - self.assertEqual(0, item1.annotations[0].label) - self.assertEqual(1, item2.annotations[0].label) + compare_datasets(self, expected, result) - def test_source_datasets_can_be_merged(self): + def test_can_filter_source(self): class TestExtractor(Extractor): - def __init__(self, url, n=0, s=0): - super().__init__(length=n) - self.n = n - self.s = s - def __iter__(self): - for i in range(self.n): - yield DatasetItem(id=self.s + i, subset='train') - - e_name1 = 'e1' - e_name2 = 'e2' - n1 = 2 - n2 = 4 + yield DatasetItem(0) + yield DatasetItem(10) + yield DatasetItem(2) + yield DatasetItem(15) project = Project() - project.env.extractors.register(e_name1, lambda p: TestExtractor(p, n=n1)) - project.env.extractors.register(e_name2, lambda p: TestExtractor(p, n=n2, s=n1)) - project.add_source('source1', { 'format': e_name1 }) - project.add_source('source2', { 'format': e_name2 }) + project.env.extractors.register('f', TestExtractor) + project.sources.add('source', { 'format': 'f' }) + project.build_targets.add_filter_stage('source', { + 'expr': '/item[id < 5]' + }) dataset = project.make_dataset() - self.assertEqual(n1 + n2, len(dataset)) + self.assertEqual(2, len(dataset)) - def test_cant_merge_different_categories(self): - class TestExtractor1(Extractor): - def __iter__(self): - return iter([]) + def test_can_detect_and_import(self): + env = Environment() + env.importers.items = {DEFAULT_FORMAT: env.importers[DEFAULT_FORMAT]} + env.extractors.items = {DEFAULT_FORMAT: env.extractors[DEFAULT_FORMAT]} - def categories(self): - return { AnnotationType.label: - LabelCategories.from_iterable(['a', 'b']) } + source_dataset = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ Label(2) ]), + ], categories=['a', 'b', 'c']) - class TestExtractor2(Extractor): - def __iter__(self): - return iter([]) + with TestDir() as test_dir: + source_dataset.save(test_dir) - def categories(self): - return { AnnotationType.label: - LabelCategories.from_iterable(['b', 'a']) } + project = Project.import_from(test_dir, env=env) + imported_dataset = project.make_dataset() - e_name1 = 'e1' - e_name2 = 'e2' + self.assertEqual(next(iter(project.sources.items()))[1].format, + DEFAULT_FORMAT) + compare_datasets(self, source_dataset, imported_dataset) - project = Project() - project.env.extractors.register(e_name1, TestExtractor1) - project.env.extractors.register(e_name2, TestExtractor2) - project.add_source('source1', { 'format': e_name1 }) - project.add_source('source2', { 'format': e_name2 }) - with self.assertRaisesRegex(Exception, "different categories"): - project.make_dataset() 
+no_vcs_installed = False +try: + import git # pylint: disable=unused-import + import dvc # pylint: disable=unused-import +except ImportError: + no_vcs_installed = True - def test_project_filter_can_be_applied(self): - class TestExtractor(Extractor): - def __iter__(self): - for i in range(10): - yield DatasetItem(id=i, subset='train') +@skipIf(no_vcs_installed, "No VCS modules (Git, DVC) installed") +class AttachedProjectTest(TestCase): + def test_can_create(self): + with TestDir() as test_dir: + Project.generate(save_dir=test_dir) - e_type = 'type' - project = Project() - project.env.extractors.register(e_type, TestExtractor) - project.add_source('source', { 'format': e_type }) + Project.load(test_dir) + + self.assertTrue(osp.isdir(osp.join(test_dir, '.git'))) + self.assertTrue(osp.isdir(osp.join(test_dir, '.dvc'))) - dataset = project.make_dataset().filter('/item[id < 5]') + def test_can_add_source_by_url(self): + with TestDir() as test_dir: + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_base_url, + 'format': 'fmt', + }) + project.save() - self.assertEqual(5, len(dataset)) + source = project.sources['s1'] + self.assertEqual(source.url, '') + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'x', 'y.txt'))) + self.assertTrue(len(source.remote) != 0) + self.assertTrue(source.remote in project.vcs.remotes) - def test_can_save_and_load_own_dataset(self): + def test_can_add_source_with_existing_remote(self): with TestDir() as test_dir: - src_project = Project() - src_dataset = src_project.make_dataset() - item = DatasetItem(id=1) - src_dataset.put(item) - src_dataset.save(test_dir) + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.vcs.remotes.add('r1', { 'url': source_base_url }) + project.sources.add('s1', { + 'url': 'remote://r1/x/y.txt', + 'format': 'fmt' + }) + project.save() - loaded_project = Project.load(test_dir) - loaded_dataset = loaded_project.make_dataset() + source = project.sources['s1'] + remote = project.vcs.remotes[source.remote] + self.assertEqual(source.url, 'y.txt') + self.assertEqual(source.remote, 'r1') + self.assertEqual(remote.url, source_base_url) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) - self.assertEqual(list(src_dataset), list(loaded_dataset)) + def test_can_add_generated_source(self): + with TestDir() as test_dir: + source_name = 'source' + origin = Source({ + 'format': 'fmt', + 'options': { 'c': 5, 'd': 'hello' } + }) + project = Project.generate(save_dir=test_dir) - def test_project_own_dataset_can_be_modified(self): - project = Project() - dataset = project.make_dataset() + project.sources.add(source_name, origin) + project.save() - item = DatasetItem(id=1) - dataset.put(item) + added = project.sources[source_name] + self.assertEqual(added.format, origin.format) + self.assertEqual(added.options, origin.options) - self.assertEqual(item, next(iter(dataset))) + def test_can_add_git_source(self): + with TestDir() as test_dir: + git_repo_dir = osp.join(test_dir, 'git_repo') + 
os.makedirs(git_repo_dir) + GitWrapper.module.Repo.init(git_repo_dir, bare=True) + + git_client_dir = osp.join(test_dir, 'git_client') + os.makedirs(git_client_dir) + repo = GitWrapper.module.Repo.clone_from(git_repo_dir, git_client_dir) + source_dataset = Dataset.from_iterable([ + DatasetItem(1, image=np.ones((2, 4, 3)), annotations=[Label(1)]) + ], categories=['a', 'b']) + source_dataset.save(git_client_dir, save_images=True) + repo.git.add(all=True) + repo.index.commit("Initial commit") + repo.remote().push() + + project = Project.generate(save_dir=osp.join(test_dir, 'proj')) + project.vcs.remotes.add('r1', { + 'url': git_repo_dir, + 'type': 'git', + }) + project.sources.add('s1', { + 'url': 'remote://r1', + 'format': 'datumaro', + }) + project.save() - def test_project_compound_child_can_be_modified_recursively(self): + compare_datasets(self, source_dataset, + Dataset.load(project.sources.work_dir('s1'))) + + def test_can_add_dvc_source(self): with TestDir() as test_dir: - child1 = Project({ - 'project_dir': osp.join(test_dir, 'child1'), + dvc_repo_dir = osp.join(test_dir, 'dvc_repo') + os.makedirs(dvc_repo_dir, exist_ok=True) + git = GitWrapper(dvc_repo_dir) + git.init() + git.commit("Initial commit") + dvc = DvcWrapper(dvc_repo_dir) + dvc.init() + source_dataset = Dataset.from_iterable([ + DatasetItem(1, image=np.ones((2, 4, 3)), annotations=[Label(1)]) + ], categories=['a']) + source_dataset.save(osp.join(dvc_repo_dir, 'ds'), save_images=True) + dvc.add(osp.join(dvc_repo_dir, 'ds')) + dvc.commit(osp.join(dvc_repo_dir, 'ds')) + git.add([], all=True) + git.commit("First") + + project = Project.generate(save_dir=osp.join(test_dir, 'proj')) + project.vcs.remotes.add('r1', { + 'url': dvc_repo_dir, + 'type': 'dvc', + }) + project.sources.add('s1', { + 'url': 'remote://r1', + 'format': 'datumaro', }) - child1.save() + project.save() + + compare_datasets(self, source_dataset, + Dataset.load(project.sources.work_dir('s1'))) + + def test_cant_add_source_with_wrong_name(self): + with TestDir() as test_dir: + project = Project.generate(save_dir=test_dir) + + for name in ['dataset', 'project', 'build', '.any']: + with self.subTest(name=name), \ + self.assertRaisesRegex(ValueError, "Source name"): + project.sources.add(name, { 'format': 'fmt' }) + + def test_can_pull_dir_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x') + source_path = osp.join(source_url, 'y.txt') + os.makedirs(osp.dirname(source_path), exist_ok=True) + with open(source_path, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) + shutil.rmtree(project.sources.work_dir('s1')) + + project.sources.pull('s1') - child2 = Project({ - 'project_dir': osp.join(test_dir, 'child2'), + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) + + def test_can_pull_file_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) + shutil.rmtree(project.sources.work_dir('s1')) + + project.sources.pull('s1') + + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) + + def test_can_pull_source_with_existing_remote_rel_dir(self): + with TestDir() as test_dir: + source_base_url = osp.join(test_dir, 'test_repo') + 
source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + source_file_path2 = osp.join(source_base_url, 'x', 'z.txt') + with open(source_file_path2, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.vcs.remotes.add('r1', { 'url': source_base_url }) + project.sources.add('s1', { + 'url': 'remote://r1/x/', + 'format': 'fmt' }) - child2.save() + shutil.rmtree(project.sources.work_dir('s1')) - parent = Project() - parent.add_source('child1', { - 'url': child1.config.project_dir + project.sources.pull('s1') + + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'z.txt'))) + + def test_can_pull_source_with_existing_remote_rel_file(self): + with TestDir() as test_dir: + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + # another file in the remote directory, should not be copied + source_file_path2 = osp.join(source_base_url, 'x', 'z.txt') + with open(source_file_path2, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.vcs.remotes.add('r1', { 'url': source_base_url }) + project.sources.add('s1', { + 'url': 'remote://r1/x/y.txt', + 'format': 'fmt' }) - parent.add_source('child2', { - 'url': child2.config.project_dir + shutil.rmtree(project.sources.work_dir('s1')) + + project.sources.pull('s1') + + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) + self.assertFalse(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'z.txt'))) + + def test_can_pull_source_with_existing_remote_root_file(self): + with TestDir() as test_dir: + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.vcs.remotes.add('r1', { 'url': source_file_path }) + project.sources.add('s1', { + 'url': 'remote://r1', + 'format': 'fmt' }) - dataset = parent.make_dataset() + shutil.rmtree(project.sources.work_dir('s1')) - item1 = DatasetItem(id='ch1', path=['child1']) - item2 = DatasetItem(id='ch2', path=['child2']) - dataset.put(item1) - dataset.put(item2) + project.sources.pull('s1') - self.assertEqual(2, len(dataset)) - self.assertEqual(1, len(dataset.sources['child1'])) - self.assertEqual(1, len(dataset.sources['child2'])) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) - def test_project_can_merge_item_annotations(self): - class TestExtractor1(Extractor): - def __iter__(self): - yield DatasetItem(id=1, subset='train', annotations=[ - Label(2, id=3), - Label(3, attributes={ 'x': 1 }), - ]) + def test_can_pull_source_with_existing_remote_root_dir(self): + with TestDir() as test_dir: + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + source_file_path2 = osp.join(source_base_url, 'z.txt') + with open(source_file_path2, 'w') as f: + f.write('hello') + + project = 
Project.generate(save_dir=test_dir) + project.vcs.remotes.add('r1', { 'url': source_base_url }) + project.sources.add('s1', { + 'url': 'remote://r1', + 'format': 'fmt' + }) + shutil.rmtree(project.sources.work_dir('s1')) - class TestExtractor2(Extractor): - def __iter__(self): - yield DatasetItem(id=1, subset='train', annotations=[ - Label(3, attributes={ 'x': 1 }), - Label(4, id=4), - ]) + project.sources.pull('s1') - project = Project() - project.env.extractors.register('t1', TestExtractor1) - project.env.extractors.register('t2', TestExtractor2) - project.add_source('source1', { 'format': 't1' }) - project.add_source('source2', { 'format': 't2' }) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'y.txt'))) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), 'z.txt'))) - merged = project.make_dataset() + def test_can_remove_source_and_keep_data(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - self.assertEqual(1, len(merged)) + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) - item = next(iter(merged)) - self.assertEqual(3, len(item.annotations)) + project.sources.remove('s1', keep_data=True) - def test_can_detect_and_import(self): - env = Environment() - env.importers.items = {DEFAULT_FORMAT: env.importers[DEFAULT_FORMAT]} - env.extractors.items = {DEFAULT_FORMAT: env.extractors[DEFAULT_FORMAT]} + self.assertFalse('s1' in project.sources) + self.assertTrue(osp.isfile(osp.join( + project.sources.work_dir('s1'), osp.basename(source_url)))) - source_dataset = Dataset.from_iterable([ - DatasetItem(id=1, annotations=[ Label(2) ]), - ], categories=['a', 'b', 'c']) + def test_can_remove_source_and_wipe_data(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) + project.sources.remove('s1', keep_data=False) + + self.assertFalse('s1' in project.sources) + self.assertFalse(osp.isfile(osp.join( + project.sources.work_dir('s1'), osp.basename(source_url)))) + + def test_can_checkout_source_rev_cached(self): with TestDir() as test_dir: - source_dataset.save(test_dir) + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) + local_source_path = osp.join( + project.sources.work_dir('s1'), osp.basename(source_url)) + project.save() + project.vcs.commit(None, message="First commit") + + with open(local_source_path, 'w') as f: + f.write('world') + project.vcs.commit(None, message="Second commit") + + project.vcs.checkout('HEAD~1', ['s1']) + + self.assertTrue(osp.isfile(local_source_path)) + with open(local_source_path) as f: + self.assertEqual('hello', f.readline().strip()) + + @skip('Source data status checks are not implemented yet') + def test_can_checkout_source_rev_noncached(self): + # Can't detect automatically if there is no cached source version + # in DVC cache, or if checkout produced a mismatching version of data. 
+ # For example: + # a source was transformed without application + # - its stages changed, but files did not + # - it was committed, no changes in source data, + # so no updates in the DVC cache + # checkout produces an outdated version of the source. + # Resolution - source rebuilding + saving source hash in stage info. + raise NotImplementedError() + + def test_can_read_working_copy_of_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = Project.generate(save_dir=osp.join(test_dir, 'proj')) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.save() - project = Project.import_from(test_dir, env=env) - imported_dataset = project.make_dataset() + read_dataset = project.sources.make_dataset('s1') + + compare_datasets(self, source_dataset, read_dataset) + compare_dirs(self, source_url, project.sources.work_dir('s1')) + + def test_can_read_current_revision_of_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = Project.generate(save_dir=osp.join(test_dir, 'proj')) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.save() + + shutil.rmtree(project.sources.work_dir('s1')) + + read_dataset = project.sources.make_dataset('s1', rev='HEAD') + + compare_datasets(self, source_dataset, read_dataset) + self.assertFalse(osp.isdir(project.sources.work_dir('s1'))) + + def test_can_update_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { 'url': source_url }) + project.save() + project.vcs.commit(None, message="First commit") + + with open(source_url, 'w') as f: + f.write('world') + + project.sources.pull('s1') + + local_source_path = osp.join( + project.sources.work_dir('s1'), osp.basename(source_url)) + self.assertTrue(osp.isfile(local_source_path)) + with open(local_source_path) as f: + self.assertEqual('world', f.readline().strip()) + + def test_can_build_project(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.build_targets.add_filter_stage('s1', { 'expr': '/item' }) + + project.build() + + built_dataset = Dataset.load( + osp.join(test_dir, project.config.build_dir)) + compare_datasets(self, dataset, built_dataset) + + def test_cant_build_dirty_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = 
Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.save() + project.vcs.commit(None, "Added a source") + + os.unlink(osp.join(project.sources.work_dir('s1'), + 'annotations', 'default.json')) - self.assertEqual(next(iter(project.config.sources.values())).format, + with self.assertRaisesRegex(VcsError, "uncommitted changes"): + project.build() + + def test_can_make_dataset_from_project(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + + built_dataset = project.make_dataset() + + compare_datasets(self, dataset, built_dataset) + + def test_can_build_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + + project.build('s1') + + built_dataset = Dataset.load(project.sources.work_dir('s1')) + compare_datasets(self, dataset, built_dataset) + + def test_can_make_dataset_from_source(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.build_targets.add_filter_stage('s1', { 'expr': '/item' }) + + built_dataset = project.make_dataset('s1') + + compare_datasets(self, dataset, built_dataset) + + def test_can_add_stage_directly(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + + project.build_targets.add_stage('s1', { + 'type': BuildStageType.filter.name, + 'params': {'expr': '/item/annotation[label="b"]'}, + }, name='f1') + project.save() + + self.assertTrue('s1.f1' in project.build_targets) + + def test_can_add_filter_stage(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + + _, stage = project.build_targets.add_filter_stage('s1', + params={'expr': '/item/annotation[label="b"]'} + ) + project.save() + + self.assertTrue(stage in project.build_targets) + + def test_can_add_convert_stage(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + 
DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + + _, stage = project.build_targets.add_convert_stage('s1', DEFAULT_FORMAT) - compare_datasets(self, source_dataset, imported_dataset) + project.save() - def test_custom_extractor_can_be_created(self): - class CustomExtractor(Extractor): - def __iter__(self): - return iter([ - DatasetItem(id=0, subset='train'), - DatasetItem(id=1, subset='train'), - DatasetItem(id=2, subset='train'), + self.assertTrue(stage in project.build_targets) - DatasetItem(id=3, subset='test'), - DatasetItem(id=4, subset='test'), + def test_can_add_transform_stage(self): + class TestTransform(Transform): + def __init__(self, extractor, p1=None, p2=None): + super().__init__(extractor) + self.p1 = p1 + self.p2 = p2 - DatasetItem(id=1), - DatasetItem(id=2), - DatasetItem(id=3), - ]) + def transform_item(self, item): + return item + + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.env.transforms.register('tr', TestTransform) + + _, stage = project.build_targets.add_transform_stage('s1', + 'tr', params={'p1': 5, 'p2': ['1', 2, 3.5]} + ) + project.save() + + self.assertTrue(stage in project.build_targets) + + def test_can_build_stage(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.build_targets.add_stage('s1', { + 'type': BuildStageType.filter.name, + 'params': {'expr': '/item/annotation[label="b"]'}, + }, name='f1') + + project.build('s1.f1', out_dir=osp.join(test_dir, 'test_build')) + + built_dataset = Dataset.load(osp.join(test_dir, 'test_build')) + expected_dataset = Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + compare_datasets(self, expected_dataset, built_dataset) + + def test_can_make_dataset_from_stage(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = Project.generate(save_dir=test_dir) + project.sources.add('s1', { + 'url': source_url, + 'format': DEFAULT_FORMAT, + }) + project.build_targets.add_stage('s1', { + 'type': BuildStageType.filter.name, + 'params': {'expr': '/item/annotation[label="b"]'}, + }, name='f1') + + built_dataset = project.make_dataset('s1.f1') + + expected_dataset = Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + compare_datasets(self, expected_dataset, built_dataset) + + def test_can_commit_repo(self): + with TestDir() as test_dir: + project = Project.generate(save_dir=test_dir) + + project.vcs.commit(None, message="First 
commit") + + def test_can_checkout_repo(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo', 'x', 'y.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = Project.generate(save_dir=test_dir) + project.vcs.commit(None, message="First commit") + + project.sources.add('s1', { 'url': source_url }) + project.save() + project.vcs.commit(None, message="Second commit") + + project.vcs.checkout('HEAD~1') + + project = Project.load(test_dir) + self.assertFalse('s1' in project.sources) + + def test_can_push_repo(self): + with TestDir() as test_dir: + git_repo_dir = osp.join(test_dir, 'git_repo') + os.makedirs(git_repo_dir, exist_ok=True) + GitWrapper.module.Repo.init(git_repo_dir, bare=True) + + dvc_repo_dir = osp.join(test_dir, 'dvc_repo') + os.makedirs(dvc_repo_dir, exist_ok=True) + git = GitWrapper(dvc_repo_dir) + git.init() + dvc = DvcWrapper(dvc_repo_dir) + dvc.init() + + project = Project.generate(save_dir=osp.join(test_dir, 'proj')) + project.vcs.repositories.add('origin', git_repo_dir) + project.vcs.remotes.add('data', { + 'url': dvc_repo_dir, + 'type': 'dvc', + }) + project.vcs.remotes.set_default('data') + project.save() + project.vcs.commit(None, message="First commit") + + project.vcs.push() + + git = GitWrapper.module.Repo.init(git_repo_dir, bare=True) + self.assertEqual('First commit', next(git.iter_commits()).summary) + + def test_can_tag_repo(self): + with TestDir() as test_dir: + project = Project.generate(save_dir=test_dir) + + project.vcs.commit(None, message="First commit") + project.vcs.tag('r1') + + self.assertEqual(['r1'], project.vcs.tags) + + +class BackwardCompatibilityTests_v0_1(TestCase): + def test_can_load_old_project(self): + expected_dataset = Dataset.from_iterable([ + DatasetItem(0, subset='train', annotations=[Label(0)]), + DatasetItem(1, subset='test', annotations=[Label(1)]), + ], categories=['a', 'b']) + + project_dir = osp.join(osp.dirname(__file__), + 'assets', 'compat', 'v0.1', 'project') + + project = Project.load(project_dir) + loaded_dataset = project.make_dataset() + + compare_datasets(self, expected_dataset, loaded_dataset) + + @skip("Not actual") + def test_project_compound_child_can_be_modified_recursively(self): + with TestDir() as test_dir: + child1 = Project.generate(osp.join(test_dir, 'child1')) + child2 = Project.generate(osp.join(test_dir, 'child2')) + + parent = Project() + parent.sources.add('child1', { + 'url': child1.config.project_dir, + 'format': 'datumaro_project' + }) + parent.sources.add('child2', { + 'url': child2.config.project_dir, + 'format': 'datumaro_project' + }) + dataset = parent.make_dataset() + + item1 = DatasetItem(id='ch1', path=['child1']) + item2 = DatasetItem(id='ch2', path=['child2']) + dataset.put(item1) + dataset.put(item2) + + self.assertEqual(2, len(dataset)) + self.assertEqual(1, len(dataset.sources['child1'])) + self.assertEqual(1, len(dataset.sources['child2'])) + +class ModelsTest(TestCase): + def test_can_batch_launch_custom_model(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=i, subset='train', image=np.array([i])) + for i in range(5) + ], categories=['label']) + + class TestLauncher(Launcher): + def launch(self, inputs): + for i, inp in enumerate(inputs): + yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ] + + model_name = 'model' + launcher_name = 'custom_launcher' - extractor_name = 'ext1' project = Project() - project.env.extractors.register(extractor_name, CustomExtractor) - 
project.add_source('src1', { - 'url': 'path', - 'format': extractor_name, - }) + project.env.launchers.register(launcher_name, TestLauncher) + project.models.add(model_name, { 'launcher': launcher_name }) + model = project.models.make_executable_model(model_name) - dataset = project.make_dataset() + batch_size = 3 + executor = ModelTransform(dataset, model, batch_size=batch_size) - compare_datasets(self, CustomExtractor(), dataset) + for item in executor: + self.assertEqual(1, len(item.annotations)) + self.assertEqual(int(item.id) % batch_size, + item.annotations[0].attributes['idx']) + self.assertEqual(int(item.id), + item.annotations[0].attributes['data'])
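
For orientation, a minimal usage sketch of the reworked Project API exercised by the tests in this patch (sources, build stages, and the project-level VCS wrapper). This is not part of the patch itself: it assumes the optional VCS dependencies (GitPython, DVC) from the 'vcs' extra are installed, and all names and paths are illustrative only.

    import os.path as osp

    from datumaro.components.dataset import Dataset
    from datumaro.components.extractor import DatasetItem, Label
    from datumaro.components.project import Project
    from datumaro.util.test_utils import TestDir

    with TestDir() as test_dir:
        # Prepare a small dataset on disk to act as a local source (illustrative data).
        source_url = osp.join(test_dir, 'source')
        Dataset.from_iterable([
            DatasetItem(1, annotations=[Label(0)]),
        ], categories=['a', 'b']).save(source_url)

        # Create a VCS-backed project and attach the dataset as a source.
        project = Project.generate(save_dir=osp.join(test_dir, 'proj'))
        project.sources.add('s1', {'url': source_url, 'format': 'datumaro'})

        # Stage a filter on the source, save the config, and commit the changes.
        project.build_targets.add_filter_stage('s1', {'expr': '/item'})
        project.save()
        project.vcs.commit(None, message="Add source and filter stage")

        # Materialize the project into a plain Dataset.
        merged = project.make_dataset()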