Skip to content

Commit

Permalink
[Datumaro] Dataset format auto detection (#1242)
Browse files Browse the repository at this point in the history
* Add dataset format detection

* Add auto format detection for import

* Split VOC extractor
  • Loading branch information
zhiltsov-max authored Mar 7, 2020
1 parent 24130cd commit 2ebca5b
Show file tree
Hide file tree
Showing 17 changed files with 572 additions and 857 deletions.
61 changes: 46 additions & 15 deletions datumaro/datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def build_import_parser(parser_ctor=argparse.ArgumentParser):
help="Overwrite existing files in the save directory")
parser.add_argument('-i', '--input-path', required=True, dest='source',
help="Path to import project from")
parser.add_argument('-f', '--format', required=True,
help="Source project format")
parser.add_argument('-f', '--format',
help="Source project format. Will try to detect, if not specified.")
parser.add_argument('extra_args', nargs=argparse.REMAINDER,
help="Additional arguments for importer (pass '-- -h' for help)")
parser.set_defaults(command=import_command)
Expand Down Expand Up @@ -164,22 +164,53 @@ def import_command(args):
if project_name is None:
project_name = osp.basename(project_dir)

try:
env = Environment()
importer = env.make_importer(args.format)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

extra_args = {}
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
env = Environment()
log.info("Importing project from '%s'" % args.source)

if not args.format:
if args.extra_args:
raise CliException("Extra args can not be used without format")

log.info("Trying to detect dataset format...")

matches = []
for format_name in env.importers.items:
log.debug("Checking '%s' format...", format_name)
importer = env.make_importer(format_name)
try:
match = importer.detect(args.source)
if match:
log.debug("format matched")
matches.append((format_name, importer))
except NotImplementedError:
log.debug("Format '%s' does not support auto detection.",
format_name)

if len(matches) == 0:
log.error("Failed to detect dataset format automatically. "
"Try to specify format with '-f/--format' parameter.")
return 1
elif len(matches) != 1:
log.error("Multiple formats match the dataset: %s. "
"Try to specify format with '-f/--format' parameter.",
', '.join(m[0] for m in matches))
return 2

format_name, importer = matches[0]
args.format = format_name
else:
try:
importer = env.make_importer(args.format)
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

log.info("Importing project from '%s' as '%s'" % \
(args.source, args.format))
log.info("Importing project as '%s'" % args.format)

source = osp.abspath(args.source)
project = importer(source, **extra_args)
project = importer(source, **locals().get('extra_args', {}))
project.config.project_name = project_name
project.config.project_dir = project_dir

Expand Down
4 changes: 4 additions & 0 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,10 @@ class SourceExtractor(Extractor):
pass

class Importer:
@classmethod
def detect(cls, path):
raise NotImplementedError()

def __call__(self, path, **extra_params):
raise NotImplementedError()

Expand Down
8 changes: 7 additions & 1 deletion datumaro/datumaro/plugins/coco_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os.path as osp

from datumaro.components.extractor import Importer
from datumaro.util.log_utils import logging_disabled

from .format import CocoTask, CocoPath

Expand All @@ -22,6 +23,11 @@ class CocoImporter(Importer):
CocoTask.image_info: 'coco_image_info',
}

@classmethod
def detect(cls, path):
with logging_disabled(log.WARN):
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()
Expand Down Expand Up @@ -53,7 +59,7 @@ def find_subsets(path):

if osp.basename(osp.normpath(path)) != CocoPath.ANNOTATIONS_DIR:
path = osp.join(path, CocoPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*_*.json'))
subset_paths += glob(osp.join(path, '*_*.json'))

subsets = defaultdict(dict)
for subset_path in subset_paths:
Expand Down
25 changes: 17 additions & 8 deletions datumaro/datumaro/plugins/cvat_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@
class CvatImporter(Importer):
EXTRACTOR_NAME = 'cvat'

@classmethod
def detect(cls, path):
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.xml') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.xml'))

if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.xml'))
subset_paths = self.find_subsets(path)

if len(subset_paths) == 0:
raise Exception("Failed to find 'cvat' dataset at '%s'" % path)
Expand All @@ -46,3 +43,15 @@ def __call__(self, path, **extra_params):
})

return project

@staticmethod
def find_subsets(path):
if path.endswith('.xml') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.xml'))

if osp.basename(osp.normpath(path)) != CvatPath.ANNOTATIONS_DIR:
path = osp.join(path, CvatPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.xml'))
return subset_paths
26 changes: 17 additions & 9 deletions datumaro/datumaro/plugins/datumaro_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
class DatumaroImporter(Importer):
EXTRACTOR_NAME = 'datumaro'

@classmethod
def detect(cls, path):
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.json') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.json'))

if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.json'))

subset_paths = self.find_subsets(path)
if len(subset_paths) == 0:
raise Exception("Failed to find 'datumaro' dataset at '%s'" % path)

Expand All @@ -46,3 +42,15 @@ def __call__(self, path, **extra_params):
})

return project

@staticmethod
def find_subsets(path):
if path.endswith('.json') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.json'))

if osp.basename(osp.normpath(path)) != DatumaroPath.ANNOTATIONS_DIR:
path = osp.join(path, DatumaroPath.ANNOTATIONS_DIR)
subset_paths += glob(osp.join(path, '*.json'))
return subset_paths
17 changes: 12 additions & 5 deletions datumaro/datumaro/plugins/tf_detection_api_format/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
class TfDetectionApiImporter(Importer):
EXTRACTOR_NAME = 'tf_detection_api'

@classmethod
def detect(cls, path):
return len(cls.find_subsets(path)) != 0

def __call__(self, path, **extra_params):
from datumaro.components.project import Project # cyclic import
project = Project()

if path.endswith('.tfrecord') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.tfrecord'))

subset_paths = self.find_subsets(path)
if len(subset_paths) == 0:
raise Exception(
"Failed to find 'tf_detection_api' dataset at '%s'" % path)
Expand All @@ -42,3 +42,10 @@ def __call__(self, path, **extra_params):

return project

@staticmethod
def find_subsets(path):
if path.endswith('.tfrecord') and osp.isfile(path):
subset_paths = [path]
else:
subset_paths = glob(osp.join(path, '*.tfrecord'))
return subset_paths
18 changes: 15 additions & 3 deletions datumaro/datumaro/plugins/voc_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ def save_subsets(self):
self.save_segm_lists(subset_name, segm_list)

def save_action_lists(self, subset_name, action_list):
if not action_list:
return

os.makedirs(self._action_subsets_dir, exist_ok=True)

ann_file = osp.join(self._action_subsets_dir, subset_name + '.txt')
Expand All @@ -342,11 +345,11 @@ def save_action_lists(self, subset_name, action_list):
(item, 1 + obj_id, 1 if presented else -1))

def save_class_lists(self, subset_name, class_lists):
os.makedirs(self._cls_subsets_dir, exist_ok=True)

if len(class_lists) == 0:
if not class_lists:
return

os.makedirs(self._cls_subsets_dir, exist_ok=True)

for label in self._label_map:
ann_file = osp.join(self._cls_subsets_dir,
'%s_%s.txt' % (label, subset_name))
Expand All @@ -360,6 +363,9 @@ def save_class_lists(self, subset_name, class_lists):
f.write('%s % d\n' % (item, 1 if presented else -1))

def save_clsdet_lists(self, subset_name, clsdet_list):
if not clsdet_list:
return

os.makedirs(self._cls_subsets_dir, exist_ok=True)

ann_file = osp.join(self._cls_subsets_dir, subset_name + '.txt')
Expand All @@ -368,6 +374,9 @@ def save_clsdet_lists(self, subset_name, clsdet_list):
f.write('%s\n' % item)

def save_segm_lists(self, subset_name, segm_list):
if not segm_list:
return

os.makedirs(self._segm_subsets_dir, exist_ok=True)

ann_file = osp.join(self._segm_subsets_dir, subset_name + '.txt')
Expand All @@ -376,6 +385,9 @@ def save_segm_lists(self, subset_name, segm_list):
f.write('%s\n' % item)

def save_layout_lists(self, subset_name, layout_list):
if not layout_list:
return

os.makedirs(self._layout_subsets_dir, exist_ok=True)

ann_file = osp.join(self._layout_subsets_dir, subset_name + '.txt')
Expand Down
Loading

0 comments on commit 2ebca5b

Please sign in to comment.