Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datumaro] Dataset format auto detection #1242

Merged
merged 3 commits into from
Mar 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions datumaro/datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser):
help="Overwrite existing files in the save directory")
parser.add_argument('-i', '--input-path', required=True, dest='source',
help="Path to import project from")
parser.add_argument('-f', '--format', required=True,
help="Source project format")
parser.add_argument('-f', '--format',
help="Source project format. Will try to detect, if not specified.")
parser.add_argument('extra_args', nargs=argparse.REMAINDER,
help="Additional arguments for importer (pass '-- -h' for help)")
parser.set_defaults(command=import_command)
Expand Down Expand Up @@ -164,22 +164,53 @@ def import_command(args):
if project_name is None:
project_name = osp.basename(project_dir)

try:
env = Environment()
importer = env.make_importer(args.format)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

extra_args = {}
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
env = Environment()
log.info("Importing project from '%s'" % args.source)

if not args.format:
if args.extra_args:
raise CliException("Extra args can not be used without format")

log.info("Trying to detect dataset format...")

matches = []
for format_name in env.importers.items:
log.debug("Checking '%s' format...", format_name)
importer = env.make_importer(format_name)
try:
match = importer.detect(args.source)
if match:
log.debug("format matched")
matches.append((format_name, importer))
except NotImplementedError:
log.debug("Format '%s' does not support auto detection.",
format_name)

if len(matches) == 0:
log.error("Failed to detect dataset format automatically. "
"Try to specify format with '-f/--format' parameter.")
return 1
elif len(matches) != 1:
log.error("Multiple formats match the dataset: %s. "
"Try to specify format with '-f/--format' parameter.",
', '.join(m[0] for m in matches))
return 2

format_name, importer = matches[0]
args.format = format_name
else:
try:
importer = env.make_importer(args.format)
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

log.info("Importing project from '%s' as '%s'" % \
(args.source, args.format))
log.info("Importing project as '%s'" % args.format)

source = osp.abspath(args.source)
project = importer(source, **extra_args)
project = importer(source, **locals().get('extra_args', {}))
project.config.project_name = project_name
project.config.project_dir = project_dir

Expand Down
4 changes: 4 additions & 0 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,10 @@ class SourceExtractor(Extractor):
pass

class Importer:
@classmethod
def detect(cls, path):
# Format auto-detection hook: subclasses override this to report whether
# the dataset at `path` looks like their format (truthy result = match).
# The base implementation only raises NotImplementedError, which the CLI
# import command catches and logs as "does not support auto detection".
raise NotImplementedError()

def __call__(self, path, **extra_params):
# Import entry point, to be overridden by concrete importers; the
# subclasses visible in this change build and return a Project for `path`.
# The base implementation only raises NotImplementedError.
raise NotImplementedError()

Expand Down
8 changes: 7 additions & 1 deletion datumaro/datumaro/plugins/coco_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os.path as osp

from datumaro.components.extractor import Importer
from datumaro.util.log_utils import logging_disabled

from .format import CocoTask, CocoPath

Expand All @@ -22,6 +23,11 @@ class CocoImporter(Importer):
CocoTask.image_info: 'coco_image_info',
}

@classmethod
def detect(cls, path):
# Auto-detection: `path` counts as a COCO dataset when find_subsets()
# returns a non-empty result for it.
# The probe runs inside logging_disabled(log.WARN); presumably this
# silences log records up to WARN level that find_subsets() may emit for
# non-COCO input -- TODO confirm against datumaro.util.log_utils.
with logging_disabled(log.WARN):
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()
Expand Down Expand Up @@ -53,7 +59,7 @@ def find_subsets(path):

if osp.basename(osp.normpath(path)) != CocoPath.ANNOTATIONS_DIR:
path = osp.join(path, CocoPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*_*.json'))
subset_paths += glob(osp.join(path, '*_*.json'))

subsets = defaultdict(dict)
for subset_path in subset_paths:
Expand Down
25 changes: 17 additions & 8 deletions datumaro/datumaro/plugins/cvat_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@
class CvatImporter(Importer):
EXTRACTOR_NAME = 'cvat'

@classmethod
def detect(cls, path):
# Auto-detection: `path` counts as a CVAT dataset when find_subsets()
# finds at least one candidate .xml file. __call__ relies on the same
# lookup and fails when it comes back empty.
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.xml') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.xml'))

if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.xml'))
subset_paths = self.find_subsets(path)

if len(subset_paths) == 0:
raise Exception("Failed to find 'cvat' dataset at '%s'" % path)
Expand All @@ -46,3 +43,15 @@ def __call__(self, path, **extra_params):
})

return project

@staticmethod
def find_subsets(path):
# Collect the candidate CVAT annotation files (.xml) for `path` and return
# them as a list of paths; an empty list means nothing was found.
# Shared by detect() and __call__().
if path.endswith('.xml') and osp.isfile(path):
# `path` is itself an existing .xml file: it is the only candidate.
subset_paths = [path]
else:
# Otherwise treat `path` as a directory and take the .xml files
# located directly inside it.
subset_paths = glob(osp.join(path, '*.xml'))

# Additionally search the CvatPath.ANNOTATIONS_DIR subdirectory, unless
# `path` already names that directory.
# NOTE(review): indentation is not visible in this rendering, so it is
# unclear whether this check and the following glob belong to the `else`
# branch and whether the glob is conditional -- confirm in the real file.
if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.xml'))
return subset_paths
26 changes: 17 additions & 9 deletions datumaro/datumaro/plugins/datumaro_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
class DatumaroImporter(Importer):
EXTRACTOR_NAME = 'datumaro'

@classmethod
def detect(cls, path):
# Auto-detection: `path` counts as a Datumaro dataset when find_subsets()
# finds at least one candidate .json file. __call__ relies on the same
# lookup and fails when it comes back empty.
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.json') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.json'))

if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.json'))

subset_paths = self.find_subsets(path)
if len(subset_paths) == 0:
raise Exception("Failed to find 'datumaro' dataset at '%s'" % path)

Expand All @@ -46,3 +42,15 @@ def __call__(self, path, **extra_params):
})

return project

@staticmethod
def find_subsets(path):
# Collect the candidate Datumaro annotation files (.json) for `path` and
# return them as a list of paths; an empty list means nothing was found.
# Shared by detect() and __call__().
# NOTE(review): any .json file matches here, so other JSON-based formats
# may be detected as well; the CLI reports an error when several formats
# match -- confirm whether a stricter check is wanted.
if path.endswith('.json') and osp.isfile(path):
# `path` is itself an existing .json file: it is the only candidate.
subset_paths = [path]
else:
# Otherwise treat `path` as a directory and take the .json files
# located directly inside it.
subset_paths = glob(osp.join(path, '*.json'))

# Additionally search the DatumaroPath.ANNOTATIONS_DIR subdirectory,
# unless `path` already names that directory.
# NOTE(review): indentation is not visible in this rendering, so it is
# unclear whether this check and the following glob belong to the `else`
# branch and whether the glob is conditional -- confirm in the real file.
if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.json'))
return subset_paths
17 changes: 12 additions & 5 deletions datumaro/datumaro/plugins/tf_detection_api_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
class TfDetectionApiImporter(Importer):
EXTRACTOR_NAME = 'tf_detection_api'

@classmethod
def detect(cls, path):
# Auto-detection: `path` counts as a TF Detection API dataset when
# find_subsets() finds at least one .tfrecord file. __call__ relies on the
# same lookup and fails when it comes back empty.
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.tfrecord') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.tfrecord'))

subset_paths = self.find_subsets(path)
if len(subset_paths) == 0:
raise Exception(
"Failed to find 'tf_detection_api' dataset at '%s'" % path)
Expand All @@ -42,3 +42,10 @@ def __call__(self, path, **extra_params):

return project

@staticmethod
def find_subsets(path):
# Collect the candidate .tfrecord files for `path` and return them as a
# list of paths; an empty list means nothing was found.
# Shared by detect() and __call__().
if path.endswith('.tfrecord') and osp.isfile(path):
# `path` is itself an existing .tfrecord file: it is the only candidate.
subset_paths = [path]
else:
# Otherwise treat `path` as a directory and take the .tfrecord files
# located directly inside it (no subdirectories are searched).
subset_paths = glob(osp.join(path, '*.tfrecord'))
return subset_paths
18 changes: 15 additions & 3 deletions datumaro/datumaro/plugins/voc_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ def save_subsets(self):
self.save_segm_lists(subset_name, segm_list)

def save_action_lists(self, subset_name, action_list):
if not action_list:
return

os.makedirs(self._action_subsets_dir, exist_ok=True)

ann_file = osp.join(self._action_subsets_dir, subset_name + '.txt')
Expand All @@ -343,11 +346,11 @@ def save_action_lists(self, subset_name, action_list):
(item, 1 + obj_id, 1 if presented else -1))

def save_class_lists(self, subset_name, class_lists):
os.makedirs(self._cls_subsets_dir, exist_ok=True)

if len(class_lists) == 0:
if not class_lists:
return

os.makedirs(self._cls_subsets_dir, exist_ok=True)

for label in self._label_map:
ann_file = osp.join(self._cls_subsets_dir,
'%s_%s.txt' % (label, subset_name))
Expand All @@ -361,6 +364,9 @@ def save_class_lists(self, subset_name, class_lists):
f.write('%s % d\n' % (item, 1 if presented else -1))

def save_clsdet_lists(self, subset_name, clsdet_list):
if not clsdet_list:
return

os.makedirs(self._cls_subsets_dir, exist_ok=True)

ann_file = osp.join(self._cls_subsets_dir, subset_name + '.txt')
Expand All @@ -369,6 +375,9 @@ def save_clsdet_lists(self, subset_name, clsdet_list):
f.write('%s\n' % item)

def save_segm_lists(self, subset_name, segm_list):
if not segm_list:
return

os.makedirs(self._segm_subsets_dir, exist_ok=True)

ann_file = osp.join(self._segm_subsets_dir, subset_name + '.txt')
Expand All @@ -377,6 +386,9 @@ def save_segm_lists(self, subset_name, segm_list):
f.write('%s\n' % item)

def save_layout_lists(self, subset_name, layout_list):
if not layout_list:
return

os.makedirs(self._layout_subsets_dir, exist_ok=True)

ann_file = osp.join(self._layout_subsets_dir, subset_name + '.txt')
Expand Down
Loading