Skip to content

Commit

Permalink
[Datumaro] Extract common extractor functionality (#1319)
Browse files Browse the repository at this point in the history
* Extract common extractor functionality

* Simplify coco extractor

* Fix tfrecord
  • Loading branch information
zhiltsov-max authored Mar 26, 2020
1 parent 3f4d6fc commit 6a4ccde
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 167 deletions.
17 changes: 16 additions & 1 deletion datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,22 @@ def select(self, pred):


class SourceExtractor(Extractor):
pass
def __init__(self, length=None, subset=None):
super().__init__(length=length)

if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset

def subsets(self):
if self._subset:
return [self._subset]
return None

def get_subset(self, name):
if name != self._subset:
return None
return self

class Importer:
@classmethod
Expand Down
65 changes: 25 additions & 40 deletions datumaro/datumaro/plugins/coco_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@

class _CocoExtractor(SourceExtractor):
def __init__(self, path, task, merge_instance_polygons=False):
super().__init__()

assert osp.isfile(path)
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
self._path = rootpath
assert osp.isfile(path), path

subset = osp.splitext(osp.basename(path))[0].rsplit('_', maxsplit=1)[1]
super().__init__(subset=subset)

rootpath = ''
if path.endswith(osp.join(CocoPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CocoPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, CocoPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, CocoPath.IMAGES_DIR)
if osp.isdir(osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)):
images_dir = osp.join(images_dir, subset or DEFAULT_SUBSET_NAME)
self._images_dir = images_dir
self._task = task

subset = osp.splitext(osp.basename(path))[0] \
.rsplit('_', maxsplit=1)[1]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset

self._merge_instance_polygons = merge_instance_polygons

loader = self._make_subset_loader(path)
Expand All @@ -51,16 +54,6 @@ def __iter__(self):
def __len__(self):
return len(self._items)

def subsets(self):
if self._subset:
return [self._subset]
return None

def get_subset(self, name):
if name != self._subset:
return None
return self

@staticmethod
def _make_subset_loader(path):
# COCO API has an 'unclosed file' warning
Expand Down Expand Up @@ -117,9 +110,7 @@ def _load_items(self, loader):

for img_id in loader.getImgIds():
image_info = loader.loadImgs(img_id)[0]
image_path = self._find_image(image_info['file_name'])
if not image_path:
image_path = image_info['file_name']
image_path = osp.join(self._images_dir, image_info['file_name'])
image_size = (image_info.get('height'), image_info.get('width'))
if all(image_size):
image_size = (int(image_size[0]), int(image_size[1]))
Expand Down Expand Up @@ -232,33 +223,27 @@ def _load_annotations(self, ann, image_info=None):

return parsed_annotations

def _find_image(self, file_name):
images_dir = osp.join(self._path, CocoPath.IMAGES_DIR)
search_paths = [
osp.join(images_dir, file_name),
osp.join(images_dir, self._subset or DEFAULT_SUBSET_NAME, file_name),
]
for image_path in search_paths:
if osp.exists(image_path):
return image_path
return None

class CocoImageInfoExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.image_info, **kwargs)
kwargs['task'] = CocoTask.image_info
super().__init__(path, **kwargs)

class CocoCaptionsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.captions, **kwargs)
kwargs['task'] = CocoTask.captions
super().__init__(path, **kwargs)

class CocoInstancesExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.instances, **kwargs)
kwargs['task'] = CocoTask.instances
super().__init__(path, **kwargs)

class CocoPersonKeypointsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.person_keypoints, **kwargs)
kwargs['task'] = CocoTask.person_keypoints
super().__init__(path, **kwargs)

class CocoLabelsExtractor(_CocoExtractor):
def __init__(self, path, **kwargs):
super().__init__(path, task=CocoTask.labels, **kwargs)
kwargs['task'] = CocoTask.labels
super().__init__(path, **kwargs)
32 changes: 5 additions & 27 deletions datumaro/datumaro/plugins/cvat_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import os.path as osp
from defusedxml import ElementTree

from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Points, Polygon, PolyLine, Bbox, Label,
LabelCategories
)
Expand All @@ -21,9 +20,7 @@ class CvatExtractor(SourceExtractor):
_SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points')

def __init__(self, path):
super().__init__()

assert osp.isfile(path)
assert osp.isfile(path), path
rootpath = ''
if path.endswith(osp.join(CvatPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(CvatPath.ANNOTATIONS_DIR, maxsplit=1)[0]
Expand All @@ -33,10 +30,7 @@ def __init__(self, path):
self._images_dir = images_dir
self._path = path

subset = osp.splitext(osp.basename(path))[0]
if subset == DEFAULT_SUBSET_NAME:
subset = None
self._subset = subset
super().__init__(subset=osp.splitext(osp.basename(path))[0])

items, categories = self._parse(path)
self._items = self._load_items(items)
Expand All @@ -52,16 +46,6 @@ def __iter__(self):
def __len__(self):
return len(self._items)

def subsets(self):
if self._subset:
return [self._subset]
return None

def get_subset(self, name):
if name != self._subset:
return None
return self

@classmethod
def _parse(cls, path):
context = ElementTree.iterparse(path, events=("start", "end"))
Expand Down Expand Up @@ -342,14 +326,8 @@ def _load_items(self, parsed):
def _find_image(self, file_name):
search_paths = []
if self._images_dir:
search_paths += [
osp.join(self._images_dir, file_name),
osp.join(self._images_dir, self._subset or DEFAULT_SUBSET_NAME,
file_name),
]
search_paths += [
osp.join(osp.dirname(self._path), file_name)
]
search_paths += [ osp.join(self._images_dir, file_name) ]
search_paths += [ osp.join(osp.dirname(self._path), file_name) ]
for image_path in search_paths:
if osp.isfile(image_path):
return image_path
Expand Down
35 changes: 12 additions & 23 deletions datumaro/datumaro/plugins/datumaro_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import json
import os.path as osp

from datumaro.components.extractor import (SourceExtractor,
DEFAULT_SUBSET_NAME, DatasetItem,
from datumaro.components.extractor import (SourceExtractor, DatasetItem,
AnnotationType, Label, RleMask, Points, Polygon, PolyLine, Bbox, Caption,
LabelCategories, MaskCategories, PointsCategories
)
Expand All @@ -18,16 +17,16 @@

class DatumaroExtractor(SourceExtractor):
def __init__(self, path):
super().__init__()
assert osp.isfile(path), path
rootpath = ''
if path.endswith(osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))):
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
images_dir = ''
if rootpath and osp.isdir(osp.join(rootpath, DatumaroPath.IMAGES_DIR)):
images_dir = osp.join(rootpath, DatumaroPath.IMAGES_DIR)
self._images_dir = images_dir

assert osp.isfile(path)
rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
self._path = rootpath

subset_name = osp.splitext(osp.basename(path))[0]
if subset_name == DEFAULT_SUBSET_NAME:
subset_name = None
self._subset_name = subset_name
super().__init__(subset=osp.splitext(osp.basename(path))[0])

with open(path, 'r') as f:
parsed_anns = json.load(f)
Expand All @@ -44,16 +43,6 @@ def __iter__(self):
def __len__(self):
return len(self._items)

def subsets(self):
if self._subset_name:
return [self._subset_name]
return None

def get_subset(self, name):
if name != self._subset_name:
return None
return self

@staticmethod
def _load_categories(parsed):
categories = {}
Expand Down Expand Up @@ -95,13 +84,13 @@ def _load_items(self, parsed):
image = None
image_info = item_desc.get('image', {})
if image_info:
image_path = osp.join(self._path, DatumaroPath.IMAGES_DIR,
image_path = osp.join(self._images_dir,
image_info.get('path', '')) # relative or absolute fits
image = Image(path=image_path, size=image_info.get('size'))

annotations = self._load_annotations(item_desc)

item = DatasetItem(id=item_id, subset=self._subset_name,
item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image)

items.append(item)
Expand Down
7 changes: 1 addition & 6 deletions datumaro/datumaro/plugins/image_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ImageDirExtractor(SourceExtractor):
def __init__(self, url):
super().__init__()

assert osp.isdir(url)
assert osp.isdir(url), url

items = []
for name in os.listdir(url):
Expand All @@ -52,18 +52,13 @@ def __init__(self, url):
items = OrderedDict(items)
self._items = items

self._subsets = None

def __iter__(self):
for item in self._items.values():
yield item

def __len__(self):
return len(self._items)

def subsets(self):
return self._subsets

def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()
Expand Down
18 changes: 2 additions & 16 deletions datumaro/datumaro/plugins/labelme_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@ class LabelMePath:

class LabelMeExtractor(SourceExtractor):
def __init__(self, path, subset_name=None):
super().__init__()

assert osp.isdir(path)
self._rootdir = path

self._subset = subset_name
assert osp.isdir(path), path
super().__init__(subset=subset_name)

items, categories = self._parse(path)
self._categories = categories
Expand All @@ -47,16 +43,6 @@ def __iter__(self):
def __len__(self):
return len(self._items)

def subsets(self):
if self._subset:
return [self._subset]
return None

def get_subset(self, name):
if name != self._subset:
return None
return self

def _parse(self, path):
categories = {
AnnotationType.label: LabelCategories(attributes={
Expand Down
14 changes: 0 additions & 14 deletions datumaro/datumaro/plugins/mot_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ def __init__(self, path, labels=None, occlusion_threshold=0, is_gt=None):
super().__init__()

assert osp.isfile(path)
self._path = path
seq_root = osp.dirname(osp.dirname(path))

self._image_dir = ''
if osp.isdir(osp.join(seq_root, MotPath.IMAGE_DIR)):
self._image_dir = osp.join(seq_root, MotPath.IMAGE_DIR)
Expand All @@ -91,8 +89,6 @@ def __init__(self, path, labels=None, occlusion_threshold=0, is_gt=None):
is_gt = True
self._is_gt = is_gt

self._subset = None

if labels is None:
if osp.isfile(osp.join(seq_root, MotPath.LABELS_FILE)):
labels = osp.join(seq_root, MotPath.LABELS_FILE)
Expand All @@ -117,16 +113,6 @@ def __iter__(self):
def __len__(self):
return len(self._items)

def subsets(self):
if self._subset:
return [self._subset]
return None

def get_subset(self, name):
if name != self._subset:
return None
return self

@staticmethod
def _parse_labels(path):
with open(path, encoding='utf-8') as labels_file:
Expand Down
Loading

0 comments on commit 6a4ccde

Please sign in to comment.