[Datumaro] Add masks to tfrecord format #1156

Merged Feb 20, 2020 (26 commits)

Commits
bf6580b
Employ transforms and item wrapper
zhiltsov-max Feb 7, 2020
baeaf86
Add image class and tests
zhiltsov-max Feb 12, 2020
f79768c
Add image info support to formats
zhiltsov-max Feb 12, 2020
fc94473
Fix cli
zhiltsov-max Feb 12, 2020
c3c1602
Fix merge and voc converter
zhiltsov-max Feb 12, 2020
ff57202
Update remote images extractor
zhiltsov-max Feb 12, 2020
82c3a56
Codacy
zhiltsov-max Feb 13, 2020
986ce38
Remove item name, require path in Image
zhiltsov-max Feb 17, 2020
a7824de
Merge images of dataset items
zhiltsov-max Feb 17, 2020
5a8677f
Update tests
zhiltsov-max Feb 17, 2020
d198fe9
Add image dir converter
zhiltsov-max Feb 17, 2020
a7f6198
Update Datumaro format
zhiltsov-max Feb 17, 2020
8ada47e
Update COCO format with image info
zhiltsov-max Feb 17, 2020
7d47ac9
Update CVAT format with image info
zhiltsov-max Feb 17, 2020
c11ee19
Update TFrecord format with image info
zhiltsov-max Feb 17, 2020
0a93b80
Update VOC format with image info
zhiltsov-max Feb 17, 2020
f9e5c8c
Update YOLO format with image info
zhiltsov-max Feb 17, 2020
bef057b
Update dataset manager bindings with image info
zhiltsov-max Feb 17, 2020
e00e5a9
Add image name to id transform
zhiltsov-max Feb 17, 2020
addd22e
Fix coco export
zhiltsov-max Feb 18, 2020
81e52d3
Add masks support for tfrecord
zhiltsov-max Feb 18, 2020
2cbd714
Refactor coco
zhiltsov-max Feb 18, 2020
7d58f4d
Fix comparison
zhiltsov-max Feb 20, 2020
3b6fbf8
Remove dead code
zhiltsov-max Feb 20, 2020
5e0dc37
Extract common code for instances
zhiltsov-max Feb 20, 2020
f0322c8
Merge branch 'develop' into zm/dm-add-masks-to-tfrecord
zhiltsov-max Feb 20, 2020
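
In short, this series threads image info (size) through the format plugins and teaches the TF Detection API (tfrecord) converter to export per-instance masks behind a new save_masks option. A minimal usage sketch of the converter as changed below, assuming Datumaro is installed and that `extractor` is some already-loaded dataset implementing the Extractor interface (the variable is a placeholder, not part of this diff):

    # Hedged sketch: the class name, constructor arguments and call signature come from
    # the converter diff in this PR; `extractor` is a hypothetical, already-built dataset.
    from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter

    converter = TfDetectionApiConverter(
        save_images=True,  # embed encoded images in the records
        save_masks=True,   # new in this PR: embed per-instance PNG masks
    )
    converter(extractor, save_dir='export/tf_detection_api')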
58 changes: 10 additions & 48 deletions datumaro/datumaro/plugins/coco_format/converter.py
@@ -14,12 +14,13 @@

from datumaro.components.converter import Converter
from datumaro.components.extractor import (DEFAULT_SUBSET_NAME,
AnnotationType, Points, Mask
AnnotationType, Points
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.util import find
from datumaro.util.image import save_image
import datumaro.util.mask_tools as mask_tools
import datumaro.util.annotation_tools as anno_tools

from .format import CocoTask, CocoPath

@@ -194,7 +195,7 @@ def crop_segments(cls, instances, img_width, img_height):
if inst[1]:
inst[1] = sum(new_segments, [])
else:
mask = cls.merge_masks(new_segments)
mask = mask_tools.merge_masks(new_segments)
inst[2] = mask_tools.mask_to_rle(mask)

return instances
@@ -205,8 +206,8 @@ def find_instance_parts(self, group, img_width, img_height):
masks = [a for a in group if a.type == AnnotationType.mask]

anns = boxes + polygons + masks
leader = self.find_group_leader(anns)
bbox = self.compute_bbox(anns)
leader = anno_tools.find_group_leader(anns)
bbox = anno_tools.compute_bbox(anns)
mask = None
polygons = [p.points for p in polygons]

@@ -228,68 +229,29 @@ def find_instance_parts(self, group, img_width, img_height):
if masks:
if mask is not None:
masks += [mask]
mask = self.merge_masks(masks)
mask = mask_tools.merge_masks([m.image for m in masks])

if mask is not None:
mask = mask_tools.mask_to_rle(mask)
polygons = []
else:
if masks:
mask = self.merge_masks(masks)
mask = mask_tools.merge_masks([m.image for m in masks])
polygons += mask_tools.mask_to_polygons(mask)
mask = None

return [leader, polygons, mask, bbox]

@staticmethod
def find_group_leader(group):
return max(group, key=lambda x: x.get_area())

@staticmethod
def merge_masks(masks):
if not masks:
return None

def get_mask(m):
if isinstance(m, Mask):
return m.image
else:
return m

binary_mask = get_mask(masks[0])
for m in masks[1:]:
binary_mask |= get_mask(m)

return binary_mask

@staticmethod
def compute_bbox(annotations):
boxes = [ann.get_bbox() for ann in annotations]
x0 = min((b[0] for b in boxes), default=0)
y0 = min((b[1] for b in boxes), default=0)
x1 = max((b[0] + b[2] for b in boxes), default=0)
y1 = max((b[1] + b[3] for b in boxes), default=0)
return [x0, y0, x1 - x0, y1 - y0]

@staticmethod
def find_instance_anns(annotations):
return [a for a in annotations
if a.type in { AnnotationType.bbox, AnnotationType.polygon } or \
a.type == AnnotationType.mask and a.label is not None
if a.type in { AnnotationType.bbox,
AnnotationType.polygon, AnnotationType.mask }
]

@classmethod
def find_instances(cls, annotations):
instance_anns = cls.find_instance_anns(annotations)

ann_groups = []
for g_id, group in groupby(instance_anns, lambda a: a.group):
if not g_id:
ann_groups.extend(([a] for a in group))
else:
ann_groups.append(list(group))

return ann_groups
return anno_tools.find_instances(cls.find_instance_anns(annotations))

def save_annotations(self, item):
instances = self.find_instances(item.annotations)
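
The COCO converter's private helpers (find_group_leader, merge_masks, compute_bbox, find_instances) move into the shared util modules datumaro.util.mask_tools and datumaro.util.annotation_tools, so the TFRecord converter below can reuse them. A small numpy-only sketch of the two key helpers, mirroring the removed static methods above rather than the actual library code:

    import numpy as np

    def merge_masks(masks):
        # OR-combine binary mask arrays into one, as in the removed
        # CocoConverter.merge_masks (minus the Mask-object unwrapping).
        if not masks:
            return None
        merged = masks[0]
        for m in masks[1:]:
            merged = merged | m
        return merged

    def compute_bbox(annotations):
        # Enclosing [x, y, w, h] over all annotation bboxes, mirroring the
        # removed CocoConverter.compute_bbox.
        boxes = [ann.get_bbox() for ann in annotations]
        x0 = min((b[0] for b in boxes), default=0)
        y0 = min((b[1] for b in boxes), default=0)
        x1 = max((b[0] + b[2] for b in boxes), default=0)
        y1 = max((b[1] + b[3] for b in boxes), default=0)
        return [x0, y0, x1 - x0, y1 - y0]

    a = np.array([[1, 0], [0, 0]], dtype=bool)
    b = np.array([[0, 0], [0, 1]], dtype=bool)
    print(merge_masks([a, b]))  # [[ True False] [False  True]]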
227 changes: 132 additions & 95 deletions datumaro/datumaro/plugins/tf_detection_api_format/converter.py
@@ -16,115 +16,64 @@
from datumaro.components.converter import Converter
from datumaro.components.cli_plugin import CliPlugin
from datumaro.util.image import encode_image
from datumaro.util.mask_tools import merge_masks
from datumaro.util.annotation_tools import (compute_bbox,
find_group_leader, find_instances)
from datumaro.util.tf_util import import_tf as _import_tf

from .format import DetectionApiPath
tf = _import_tf()


# we need it to filter out non-ASCII characters, otherwise training will crash
# filter out non-ASCII characters, otherwise training will crash
_printable = set(string.printable)
def _make_printable(s):
return ''.join(filter(lambda x: x in _printable, s))

def _make_tf_example(item, get_label_id, get_label, save_images=False):
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))


features = {
'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
'image/filename': bytes_feature(
('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
}

if not item.has_image:
raise Exception("Failed to export dataset item '%s': "
"item has no image info" % item.id)
height, width = item.image.size

features.update({
'image/height': int64_feature(height),
'image/width': int64_feature(width),
})

features.update({
'image/encoded': bytes_feature(b''),
'image/format': bytes_feature(b'')
})
if save_images:
if item.has_image and item.image.has_data:
fmt = DetectionApiPath.IMAGE_FORMAT
buffer = encode_image(item.image.data, DetectionApiPath.IMAGE_EXT)

features.update({
'image/encoded': bytes_feature(buffer),
'image/format': bytes_feature(fmt.encode('utf-8')),
})
else:
log.warning("Item '%s' has no image" % item.id)

xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box)
ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box)
classes_text = [] # List of string class name of bounding box (1 per box)
classes = [] # List of integer class id of bounding box (1 per box)

boxes = [ann for ann in item.annotations if ann.type is AnnotationType.bbox]
for box in boxes:
box_label = _make_printable(get_label(box.label))

xmins.append(box.points[0] / width)
xmaxs.append(box.points[2] / width)
ymins.append(box.points[1] / height)
ymaxs.append(box.points[3] / height)
classes_text.append(box_label.encode('utf-8'))
classes.append(get_label_id(box.label))

if boxes:
features.update({
'image/object/bbox/xmin': float_list_feature(xmins),
'image/object/bbox/xmax': float_list_feature(xmaxs),
'image/object/bbox/ymin': float_list_feature(ymins),
'image/object/bbox/ymax': float_list_feature(ymaxs),
'image/object/class/text': bytes_list_feature(classes_text),
'image/object/class/label': int64_list_feature(classes),
})
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

tf_example = tf.train.Example(
features=tf.train.Features(feature=features))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

return tf_example
def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))

class TfDetectionApiConverter(Converter, CliPlugin):
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('--save-images', action='store_true',
help="Save images (default: %(default)s)")
parser.add_argument('--save-masks', action='store_true',
help="Include instance masks (default: %(default)s)")
return parser

def __init__(self, save_images=False):
def __init__(self, save_images=False, save_masks=False):
super().__init__()

self._save_images = save_images
self._save_masks = save_masks

def __call__(self, extractor, save_dir):
os.makedirs(save_dir, exist_ok=True)

label_categories = extractor.categories().get(AnnotationType.label,
LabelCategories())
get_label = lambda label_id: label_categories.items[label_id].name \
if label_id is not None else ''
label_ids = OrderedDict((label.name, 1 + idx)
for idx, label in enumerate(label_categories.items))
map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)
self._get_label = get_label
self._get_label_id = map_label_id

subsets = extractor.subsets()
if len(subsets) == 0:
subsets = [ None ]
@@ -136,14 +85,6 @@ def __call__(self, extractor, save_dir):
subset_name = DEFAULT_SUBSET_NAME
subset = extractor

label_categories = subset.categories().get(AnnotationType.label,
LabelCategories())
get_label = lambda label_id: label_categories.items[label_id].name \
if label_id is not None else ''
label_ids = OrderedDict((label.name, 1 + idx)
for idx, label in enumerate(label_categories.items))
map_label_id = lambda label_id: label_ids.get(get_label(label_id), 0)

labelmap_path = osp.join(save_dir, DetectionApiPath.LABELMAP_FILE)
with codecs.open(labelmap_path, 'w', encoding='utf8') as f:
for label, idx in label_ids.items():
@@ -157,10 +98,106 @@ def __call__(self, extractor, save_dir):
anno_path = osp.join(save_dir, '%s.tfrecord' % (subset_name))
with tf.io.TFRecordWriter(anno_path) as writer:
for item in subset:
tf_example = _make_tf_example(
item,
get_label=get_label,
get_label_id=map_label_id,
save_images=self._save_images,
)
tf_example = self._make_tf_example(item)
writer.write(tf_example.SerializeToString())

@staticmethod
def _find_instances(annotations):
return find_instances(a for a in annotations
if a.type in { AnnotationType.bbox, AnnotationType.mask })

def _find_instance_parts(self, group, img_width, img_height):
boxes = [a for a in group if a.type == AnnotationType.bbox]
masks = [a for a in group if a.type == AnnotationType.mask]

anns = boxes + masks
leader = find_group_leader(anns)
bbox = compute_bbox(anns)

mask = None
if self._save_masks:
mask = merge_masks([m.image for m in masks])

return [leader, mask, bbox]

def _export_instances(self, instances, width, height):
xmins = [] # List of normalized left x coordinates of bounding boxes (1 per box)
xmaxs = [] # List of normalized right x coordinates of bounding boxes (1 per box)
ymins = [] # List of normalized top y coordinates of bounding boxes (1 per box)
ymaxs = [] # List of normalized bottom y coordinates of bounding boxes (1 per box)
classes_text = [] # List of class names of bounding boxes (1 per box)
classes = [] # List of class ids of bounding boxes (1 per box)
masks = [] # List of PNG-encoded instance masks (1 per box)

for leader, mask, box in instances:
label = _make_printable(self._get_label(leader.label))
classes_text.append(label.encode('utf-8'))
classes.append(self._get_label_id(leader.label))

xmins.append(box[0] / width)
xmaxs.append((box[0] + box[2]) / width)
ymins.append(box[1] / height)
ymaxs.append((box[1] + box[3]) / height)

if self._save_masks:
if mask is not None:
mask = encode_image(mask, '.png')
else:
mask = b''
masks.append(mask)

result = {}
if classes:
result = {
'image/object/bbox/xmin': float_list_feature(xmins),
'image/object/bbox/xmax': float_list_feature(xmaxs),
'image/object/bbox/ymin': float_list_feature(ymins),
'image/object/bbox/ymax': float_list_feature(ymaxs),
'image/object/class/text': bytes_list_feature(classes_text),
'image/object/class/label': int64_list_feature(classes),
}
if masks:
result['image/object/mask'] = bytes_list_feature(masks)
return result

def _make_tf_example(self, item):
features = {
'image/source_id': bytes_feature(str(item.id).encode('utf-8')),
'image/filename': bytes_feature(
('%s%s' % (item.id, DetectionApiPath.IMAGE_EXT)).encode('utf-8')),
}

if not item.has_image:
raise Exception("Failed to export dataset item '%s': "
"item has no image info" % item.id)
height, width = item.image.size

features.update({
'image/height': int64_feature(height),
'image/width': int64_feature(width),
})

features.update({
'image/encoded': bytes_feature(b''),
'image/format': bytes_feature(b'')
})
if self._save_images:
if item.has_image and item.image.has_data:
fmt = DetectionApiPath.IMAGE_FORMAT
buffer = encode_image(item.image.data, DetectionApiPath.IMAGE_EXT)

features.update({
'image/encoded': bytes_feature(buffer),
'image/format': bytes_feature(fmt.encode('utf-8')),
})
else:
log.warning("Item '%s' has no image" % item.id)

instances = self._find_instances(item.annotations)
instances = [self._find_instance_parts(i, width, height) for i in instances]
features.update(self._export_instances(instances, width, height))

tf_example = tf.train.Example(
features=tf.train.Features(feature=features))

return tf_example
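
With save_masks enabled, each record gains an 'image/object/mask' feature holding one PNG-encoded instance mask per box (empty bytes when an instance has no mask). A hedged read-back sketch using standard TensorFlow APIs; the output path and 'default.tfrecord' filename are illustrative (the converter writes '<subset>.tfrecord' into save_dir), everything else follows the feature layout written above:

    import tensorflow as tf

    # Feature spec matching what _make_tf_example / _export_instances write.
    feature_spec = {
        'image/source_id': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
        'image/object/mask': tf.io.VarLenFeature(tf.string),
    }

    dataset = tf.data.TFRecordDataset('export/tf_detection_api/default.tfrecord')
    for raw in dataset.take(1):
        example = tf.io.parse_single_example(raw, feature_spec)
        masks = tf.sparse.to_dense(example['image/object/mask'], default_value=b'')
        for png in masks.numpy():
            if png:  # empty bytes mean "no mask for this instance"
                mask = tf.io.decode_png(png)
                print('instance mask shape:', mask.shape)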