-
Notifications
You must be signed in to change notification settings - Fork 137
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add KITTI detection and segmentation formats (#282)
* Add KITTI detection and segmentation formats * Remove unused import * Add KITTI user manual Co-authored-by: Maxim Zhiltsov <[email protected]>
- Loading branch information
Zoya Maslova
and
Maxim Zhiltsov
authored
Jun 9, 2021
1 parent
54e21bf
commit c536b07
Showing
21 changed files
with
1,241 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
|
||
# Copyright (C) 2021 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import logging as log | ||
import os | ||
import os.path as osp | ||
from collections import OrderedDict | ||
from enum import Enum | ||
|
||
import numpy as np | ||
|
||
from datumaro.components.converter import Converter | ||
from datumaro.components.extractor import (AnnotationType, | ||
CompiledMask, LabelCategories) | ||
from datumaro.util import find, parse_str_enum_value, str_to_bool, cast | ||
from datumaro.util.annotation_util import make_label_id_mapping | ||
from datumaro.util.image import save_image | ||
from datumaro.util.mask_tools import paint_mask | ||
|
||
from .format import (KittiTask, KittiPath, KittiLabelMap, | ||
make_kitti_categories, make_kitti_detection_categories, | ||
parse_label_map, write_label_map, | ||
) | ||
|
||
# Built-in label map sources accepted in place of a label map file path:
# 'kitti' selects the default KITTI label map, 'source' takes labels (and
# colors, when available) from the input dataset.
LabelmapType = Enum('LabelmapType', ['kitti', 'source'])
|
||
class KittiConverter(Converter):
    """Write a dataset to disk in the KITTI segmentation/detection layout.

    For the segmentation task, every item with mask annotations produces
    three images: a (by default color-painted) class mask, a label-id class
    mask and an instance mask. For the detection task, every item with
    bbox annotations produces a text file with one 15-field line per box.
    """

    DEFAULT_IMAGE_EXT = KittiPath.IMAGE_EXT

    @staticmethod
    def _split_tasks_string(s):
        """Parse a comma-separated list of task names (argparse ``type``).

        Returns a list of ``KittiTask`` members. Raises
        ``argparse.ArgumentTypeError`` for an unknown name so that argparse
        reports a usage error instead of an unhandled ``KeyError``.
        """
        import argparse

        tasks = []
        for name in s.split(','):
            name = name.strip().lower()
            try:
                tasks.append(KittiTask[name])
            except KeyError:
                raise argparse.ArgumentTypeError(
                    "Unknown KITTI task '%s', expected one of: %s" % (
                        name, ', '.join(t.name for t in KittiTask))
                ) from None
        return tasks

    @staticmethod
    def _get_labelmap(s):
        """Validate a ``--label-map`` value (argparse ``type``).

        Returns ``s`` unchanged when it is an existing file, otherwise the
        canonical name of the matching ``LabelmapType`` member. Raises
        ``argparse.ArgumentTypeError`` when neither applies.
        """
        if osp.isfile(s):
            return s
        try:
            return LabelmapType[s.lower()].name
        except KeyError:
            import argparse
            raise argparse.ArgumentTypeError(
                "Unexpected label map '%s', expected a file path "
                "or one of: %s" % (
                    s, ', '.join(t.name for t in LabelmapType))
            ) from None

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base converter parser with the KITTI options."""
        parser = super().build_cmdline_parser(**kwargs)

        parser.add_argument('--apply-colormap', type=str_to_bool, default=True,
            help="Use colormap for class masks (default: %(default)s)")
        parser.add_argument('--label-map', type=cls._get_labelmap, default=None,
            help="Labelmap file path or one of %s" % \
                ', '.join(t.name for t in LabelmapType))
        parser.add_argument('--tasks', type=cls._split_tasks_string,
            help="KITTI task filter, comma-separated list of {%s} "
                "(default: all)" % ', '.join(t.name for t in KittiTask))
        return parser

    def __init__(self, extractor, save_dir,
            tasks=None, apply_colormap=True, allow_attributes=True,
            label_map=None, **kwargs):
        """
        Parameters:
            extractor, save_dir, **kwargs: forwarded to the base converter.
            tasks: None (all tasks), a single KittiTask, or a list/set of
                tasks (members or values accepted by parse_str_enum_value).
            apply_colormap: paint the class mask with the colormap.
            allow_attributes: accepted for interface compatibility; not
                read anywhere in this class.
            label_map: label map source; None means 'source'. See
                _load_categories() for the accepted values.
        """
        super().__init__(extractor, save_dir, **kwargs)

        assert tasks is None or isinstance(tasks, (KittiTask, list, set))
        if tasks is None:
            tasks = set(KittiTask)
        elif isinstance(tasks, KittiTask):
            tasks = {tasks}
        else:
            tasks = set(parse_str_enum_value(t, KittiTask) for t in tasks)
        self._tasks = tasks

        self._apply_colormap = apply_colormap

        if label_map is None:
            label_map = LabelmapType.source.name
        # Segmentation categories take precedence when both tasks are set.
        if KittiTask.segmentation in self._tasks:
            self._load_categories(label_map)
        elif KittiTask.detection in self._tasks:
            self._load_detection_categories()

    def apply(self):
        """Export every subset of the extractor into the save directory."""
        os.makedirs(self._save_dir, exist_ok=True)

        with_segmentation = KittiTask.segmentation in self._tasks
        with_detection = KittiTask.detection in self._tasks

        for subset_name, subset in self._extractor.subsets().items():
            if with_segmentation:
                os.makedirs(osp.join(self._save_dir, subset_name,
                    KittiPath.INSTANCES_DIR), exist_ok=True)

            for item in subset:
                if self._save_images:
                    self._save_image(item,
                        subdir=osp.join(subset_name, KittiPath.IMAGES_DIR))

                masks = [a for a in item.annotations
                    if a.type == AnnotationType.mask]
                if masks and with_segmentation:
                    self._save_item_masks(item, subset_name, masks)

                bboxes = [a for a in item.annotations
                    if a.type == AnnotationType.bbox]
                if bboxes and with_detection:
                    self._save_item_bboxes(item, subset_name, bboxes)

        if with_segmentation:
            self.save_label_map()

    def _save_item_masks(self, item, subset_name, masks):
        """Write the color, label-id and instance masks of one item."""
        mask_name = item.id + KittiPath.MASK_EXT

        compiled_class_mask = CompiledMask.from_instance_masks(masks,
            instance_labels=[self._label_id_mapping(m.label)
                for m in masks])

        color_mask_path = osp.join(subset_name,
            KittiPath.SEMANTIC_RGB_DIR, mask_name)
        self.save_mask(osp.join(self._save_dir, color_mask_path),
            compiled_class_mask.class_mask)

        labelids_mask_path = osp.join(subset_name,
            KittiPath.SEMANTIC_DIR, mask_name)
        self.save_mask(osp.join(self._save_dir, labelids_mask_path),
            compiled_class_mask.class_mask, apply_colormap=False,
            dtype=np.int32)

        # TODO: optimize second merging
        # Instance pixel value: label in the high bits, annotation id in
        # the low 8 bits.
        compiled_instance_mask = CompiledMask.from_instance_masks(masks,
            instance_labels=[(m.label << 8) + m.id for m in masks])
        inst_path = osp.join(subset_name,
            KittiPath.INSTANCES_DIR, mask_name)
        self.save_mask(osp.join(self._save_dir, inst_path),
            compiled_instance_mask.class_mask, apply_colormap=False,
            dtype=np.int32)

    def _save_item_bboxes(self, item, subset_name, bboxes):
        """Write the detection label file of one item, one line per bbox."""
        labels_file = osp.join(self._save_dir, subset_name,
            KittiPath.LABELS_DIR, '%s.txt' % item.id)
        os.makedirs(osp.dirname(labels_file), exist_ok=True)
        with open(labels_file, 'w', encoding='utf-8') as f:
            for bbox in bboxes:
                # 15 fields per line; the ones not filled below stay -1.
                label_line = [-1] * 15
                label_line[0] = self.get_label(bbox.label)
                label_line[1] = cast(bbox.attributes.get('truncated'),
                    float, KittiPath.DEFAULT_TRUNCATED)
                label_line[2] = cast(bbox.attributes.get('occluded'),
                    int, KittiPath.DEFAULT_OCCLUDED)
                # Fields 4..7 hold the box corners: x1, y1, x2, y2.
                x, y, w, h = bbox.get_bbox()
                label_line[4:8] = x, y, x + w, y + h

                f.write('%s\n' % ' '.join(str(v) for v in label_line))

    def get_label(self, label_id):
        """Return the name of the source label with the given index."""
        return self._extractor. \
            categories()[AnnotationType.label].items[label_id].name

    def save_label_map(self):
        """Write the active label map into the save directory."""
        path = osp.join(self._save_dir, KittiPath.LABELMAP_FILE)
        write_label_map(path, self._label_map)

    def _load_categories(self, label_map_source):
        """Build output categories and the source-to-output label mapping.

        label_map_source may be:
          - 'kitti': use the default KITTI label map;
          - 'source': use the input labels, with the input colormap when
            the input has mask categories (labels without a color are
            left out), otherwise with colors left as None;
          - a dict: used as the label map, sorted by key;
          - a path to an existing label map file.

        Raises:
            Exception: for any other value.
        """
        if label_map_source == LabelmapType.kitti.name:
            # use the default KITTI colormap
            label_map = KittiLabelMap

        elif label_map_source == LabelmapType.source.name and \
                AnnotationType.mask not in self._extractor.categories():
            # generate colormap for input labels
            labels = self._extractor.categories() \
                .get(AnnotationType.label, LabelCategories())
            label_map = OrderedDict((item.name, None)
                for item in labels.items)

        elif label_map_source == LabelmapType.source.name and \
                AnnotationType.mask in self._extractor.categories():
            # use source colormap
            labels = self._extractor.categories()[AnnotationType.label]
            colors = self._extractor.categories()[AnnotationType.mask]
            label_map = OrderedDict()
            for idx, item in enumerate(labels.items):
                color = colors.colormap.get(idx)
                if color is not None:
                    label_map[item.name] = color

        elif isinstance(label_map_source, dict):
            label_map = OrderedDict(
                sorted(label_map_source.items(), key=lambda e: e[0]))

        elif isinstance(label_map_source, str) and osp.isfile(label_map_source):
            label_map = parse_label_map(label_map_source)

        else:
            raise Exception("Wrong labelmap specified, "
                "expected one of %s or a file path" % \
                ', '.join(t.name for t in LabelmapType))

        self._categories = make_kitti_categories(label_map)
        self._label_map = label_map
        self._label_id_mapping = self._make_label_id_map()

    def _load_detection_categories(self):
        """Use the predefined KITTI detection categories."""
        self._categories = make_kitti_detection_categories()

    def _make_label_id_map(self):
        """Create the callable mapping source label ids to output ids.

        Logs a warning listing the source labels that have no counterpart
        among the output labels, and a debug dump of the whole mapping.
        """
        map_id, id_mapping, src_labels, dst_labels = make_label_id_mapping(
            self._extractor.categories().get(AnnotationType.label),
            self._categories[AnnotationType.label])

        void_labels = [src_label for _, src_label in src_labels.items()
            if src_label not in dst_labels]
        if void_labels:
            log.warning("The following labels are remapped to background: %s",
                ', '.join(void_labels))

        dst_items = self._categories[AnnotationType.label].items
        log.debug("Saving segmentations with the following label mapping: \n%s",
            '\n'.join("#%s '%s' -> #%s '%s'" % (
                    src_id, src_label, id_mapping[src_id],
                    dst_items[id_mapping[src_id]].name
                )
                for src_id, src_label in src_labels.items()
            )
        )

        return map_id

    def save_mask(self, path, mask, colormap=None, apply_colormap=True,
            dtype=np.uint8):
        """Save a mask image, optionally painting it with a colormap.

        The mask is painted only when both the converter-level setting and
        the apply_colormap argument are true. When colormap is None, the
        colormap of the output mask categories is used.
        """
        if self._apply_colormap and apply_colormap:
            if colormap is None:
                colormap = self._categories[AnnotationType.mask].colormap
            mask = paint_mask(mask, colormap)
        save_image(path, mask, create_dir=True, dtype=dtype)
|
||
class KittiSegmentationConverter(KittiConverter):
    """KITTI converter restricted to the segmentation task."""

    def __init__(self, *args, **kwargs):
        # Any caller-supplied 'tasks' keyword is overridden.
        super().__init__(*args,
            **dict(kwargs, tasks=KittiTask.segmentation))
|
||
class KittiDetectionConverter(KittiConverter):
    """KITTI converter restricted to the detection task."""

    def __init__(self, *args, **kwargs):
        # Any caller-supplied 'tasks' keyword is overridden.
        super().__init__(*args,
            **dict(kwargs, tasks=KittiTask.detection))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (C) 2021 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import os.path as osp | ||
|
||
import numpy as np | ||
|
||
from datumaro.components.extractor import (SourceExtractor, | ||
AnnotationType, DatasetItem, Mask, Bbox) | ||
from datumaro.util.image import load_image, find_images | ||
|
||
from .format import ( | ||
KittiTask, KittiPath, KittiLabelMap, parse_label_map, | ||
make_kitti_categories, make_kitti_detection_categories | ||
) | ||
|
||
class _KittiExtractor(SourceExtractor):
    """Read one KITTI subset directory for a single task.

    Items are discovered from the images directory; annotations are read
    from the instance masks (segmentation) or the label text files
    (detection) that share the image's relative name.
    """

    def __init__(self, path, task, subset=None):
        """
        Parameters:
            path: subset directory; must exist.
            task: the KittiTask whose annotations are loaded.
            subset: subset name; defaults to the directory name.
        """
        assert osp.isdir(path), path
        self._path = path
        self._task = task

        if not subset:
            subset = osp.splitext(osp.basename(path))[0]
        self._subset = subset
        super().__init__(subset=subset)

        # The label map file is looked up next to the subset directory.
        self._categories = self._load_categories(osp.dirname(self._path))
        self._items = list(self._load_items().values())

    def _load_categories(self, path):
        """Return the categories for the configured task."""
        if self._task == KittiTask.segmentation:
            return self._load_categories_segmentation(path)
        elif self._task == KittiTask.detection:
            return make_kitti_detection_categories()

    def _load_categories_segmentation(self, path):
        """Build categories from the label map file in ``path``, falling
        back to the default KITTI label map when the file is absent."""
        label_map_path = osp.join(path, KittiPath.LABELMAP_FILE)
        if osp.isfile(label_map_path):
            label_map = parse_label_map(label_map_path)
        else:
            label_map = KittiLabelMap
        self._labels = list(label_map)
        return make_kitti_categories(label_map)

    def _load_items(self):
        """Return a dict of sample id -> DatasetItem for every image."""
        items = {}

        image_dir = osp.join(self._path, KittiPath.IMAGES_DIR)
        for image_path in find_images(image_dir, recursive=True):
            image_name = osp.relpath(image_path, image_dir)
            sample_id = osp.splitext(image_name)[0]

            anns = []
            if self._task == KittiTask.segmentation:
                anns.extend(self._load_instance_masks(sample_id))
            if self._task == KittiTask.detection:
                anns.extend(self._load_bboxes(sample_id))

            items[sample_id] = DatasetItem(id=sample_id, subset=self._subset,
                image=image_path, annotations=anns)
        return items

    def _load_instance_masks(self, sample_id):
        """Return Mask annotations decoded from the item's instance mask,
        or an empty list when the mask file does not exist."""
        instances_path = osp.join(self._path, KittiPath.INSTANCES_DIR,
            sample_id + KittiPath.MASK_EXT)
        if not osp.isfile(instances_path):
            return []

        instances_mask = load_image(instances_path, dtype=np.int32)
        anns = []
        for segm_id in np.unique(instances_mask):
            # Work with plain ints so that labels and ids stored on the
            # annotations are not NumPy scalars.
            segm_id = int(segm_id)
            semantic_id = segm_id >> 8
            ann_id = segm_id % 256
            anns.append(Mask(
                image=self._lazy_extract_mask(instances_mask, segm_id),
                label=semantic_id, id=ann_id,
                attributes={ 'is_crowd': ann_id == 0 }))
        return anns

    def _load_bboxes(self, sample_id):
        """Return Bbox annotations parsed from the item's label file, or
        an empty list when the file does not exist.

        Blank lines are skipped; the bbox id is the line index in the file.

        Raises:
            Exception: when a line does not have 15 fields or names a
                label that is not in the categories.
        """
        labels_path = osp.join(self._path, KittiPath.LABELS_DIR,
            sample_id + '.txt')
        if not osp.isfile(labels_path):
            return []

        label_cat = self.categories()[AnnotationType.label]
        anns = []
        with open(labels_path, 'r', encoding='utf-8') as f:
            for line_idx, line in enumerate(f):
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 15:
                    raise Exception("Item %s: line %d: expected 15 fields, "
                        "got %d" % (sample_id, line_idx + 1, len(fields)))

                x1, y1 = float(fields[4]), float(fields[5])
                x2, y2 = float(fields[6]), float(fields[7])

                attributes = {
                    'truncated': float(fields[1]) != 0,
                    'occluded': int(fields[2]) != 0,
                }

                label_id = label_cat.find(fields[0])[0]
                if label_id is None:
                    raise Exception("Item %s: unknown label '%s'" % \
                        (sample_id, fields[0]))

                anns.append(
                    Bbox(x=x1, y=y1, w=x2-x1, h=y2-y1, id=line_idx,
                        attributes=attributes, label=label_id,
                    ))
        return anns

    @staticmethod
    def _lazy_extract_mask(mask, c):
        """Return a callable producing the boolean mask of pixels equal
        to ``c``; the comparison is deferred until the callable is used."""
        return lambda: mask == c
|
||
class KittiSegmentationExtractor(_KittiExtractor):
    """KITTI extractor that loads instance mask annotations."""

    def __init__(self, path):
        super().__init__(path, KittiTask.segmentation)
|
||
class KittiDetectionExtractor(_KittiExtractor):
    """KITTI extractor that loads bounding box annotations."""

    def __init__(self, path):
        super().__init__(path, KittiTask.detection)
Oops, something went wrong.