diff --git a/digits/extensions/data/__init__.py b/digits/extensions/data/__init__.py index b8a56e911..e1af0c476 100644 --- a/digits/extensions/data/__init__.py +++ b/digits/extensions/data/__init__.py @@ -2,10 +2,12 @@ from __future__ import absolute_import from . import imageGradients +from . import objectDetection data_extensions = [ # set show=True if extension should be listed in known extensions {'class': imageGradients.DataIngestion, 'show': False}, + {'class': objectDetection.DataIngestion, 'show': True}, ] diff --git a/digits/extensions/data/objectDetection/__init__.py b/digits/extensions/data/objectDetection/__init__.py new file mode 100644 index 000000000..2ed387aad --- /dev/null +++ b/digits/extensions/data/objectDetection/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from .data import DataIngestion diff --git a/digits/extensions/data/objectDetection/data.py b/digits/extensions/data/objectDetection/data.py new file mode 100644 index 000000000..ab26b24d8 --- /dev/null +++ b/digits/extensions/data/objectDetection/data.py @@ -0,0 +1,241 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. 
+from __future__ import absolute_import
+
+import numpy as np
+import operator
+import os
+import PIL.Image
+import random
+
+import digits
+from digits.utils import subclass, override, constants
+from ..interface import DataIngestionInterface
+from .forms import DatasetForm
+from .utils import GroundTruth, GroundTruthObj
+from .utils import bbox_to_array, resize_bbox_list
+
+TEMPLATE = "template.html"
+
+
+@subclass
+class DataIngestion(DataIngestionInterface):
+    """
+    A data ingestion extension for an object detection dataset
+    """
+
+    def __init__(self, **kwargs):
+        super(DataIngestion, self).__init__(**kwargs)
+
+        # this instance is automatically populated with form field
+        # attributes by superclass constructor
+
+        if ((self.val_image_folder == '') ^ (self.val_label_folder == '')):
+            raise ValueError("You must specify either both val_image_folder and val_label_folder or none")
+
+        if ((self.resize_image_width is None) ^ (self.resize_image_height is None)):
+            raise ValueError("You must specify either both resize_image_width and resize_image_height or none")
+
+        # this will be set when we know the phase we are encoding
+        self.ground_truth = None
+
+    @override
+    def encode_entry(self, entry):
+        """
+        Return numpy.ndarray
+        """
+        image_filename = entry
+
+        # (1) image part
+
+        # load from file
+        img = digits.utils.image.load_image(image_filename)
+        if self.channel_conversion != 'none':
+            if img.mode != self.channel_conversion:
+                # convert to different image mode if necessary
+                img = img.convert(self.channel_conversion)
+
+        # pad
+        img = self.pad_image(img)
+
+        if self.resize_image_width is not None:
+            # resize
+            resize_ratio_x = float(self.resize_image_width) / self.padding_image_width
+            resize_ratio_y = float(self.resize_image_height) / self.padding_image_height
+            img = digits.utils.image.resize_image(
+                img,
+                self.resize_image_height,
+                self.resize_image_width)
+        else:
+            resize_ratio_x = 1
+            resize_ratio_y = 1
+        # convert to numpy array
+        img = np.array(img)
+
+        if img.ndim
== 2: + # grayscale + img = img[np.newaxis, :, :] + if img.dtype == 'uint16': + img = img.astype(float) + else: + if img.ndim != 3 or img.shape[2] != 3: + raise ValueError("Unsupported image shape: %s" % repr(img.shape)) + # HWC -> CHW + img = img.transpose(2, 0, 1) + + # (2) label part + + # make sure label exists + try: + label_id = int(os.path.splitext(os.path.basename(entry))[0]) + except: + raise ValueError("Unable to extract numerical id from file name %s" % entry) + + if not label_id in self.datasrc_annotation_dict: + raise ValueError("Label key %s not found in label folder" % label_id) + annotations = self.datasrc_annotation_dict[label_id] + + # collect bbox list into bboxList + bboxList = [] + + for bbox in annotations: + # retrieve all vars defining groundtruth, and interpret all + # serialized values as float32s: + np_bbox = np.array(bbox.gt_to_lmdb_format()) + bboxList.append(np_bbox) + + bboxList = sorted( + bboxList, + key=operator.itemgetter(GroundTruthObj.lmdb_format_length()-1) + ) + + bboxList.reverse() + + # adjust bboxes according to image cropping + bboxList = resize_bbox_list(bboxList, resize_ratio_x, resize_ratio_y) + + # return data + feature = img + label = np.asarray(bboxList) + + # LMDB compaction: now label (aka bbox) is the joint array + label = bbox_to_array( + label, + 0, + max_bboxes=self.max_bboxes, + bbox_width=GroundTruthObj.lmdb_format_length()) + + + return feature, label + + @staticmethod + @override + def get_category(): + return "Images" + + @staticmethod + @override + def get_id(): + return "image-object-detection" + + @staticmethod + @override + def get_dataset_form(): + return DatasetForm() + + @staticmethod + @override + def get_dataset_template(form): + """ + parameters: + - form: form returned by get_dataset_form(). 
This may be populated with values if the job was cloned + return: + - (template, context) tuple + template is a Jinja template to use for rendering dataset creation options + context is a dictionary of context variables to use for rendering the form + """ + extension_dir = os.path.dirname(os.path.abspath(__file__)) + template = open(os.path.join(extension_dir, TEMPLATE), "r").read() + context = {'form': form} + return (template, context) + + @staticmethod + @override + def get_title(): + return "Object Detection" + + @override + def itemize_entries(self, stage): + """ + return list of image file names to encode for specified stage + """ + if stage == constants.TEST_DB: + # don't retun anything for the test stage + return [] + elif stage == constants.TRAIN_DB: + # load ground truth + self.load_ground_truth(self.train_label_folder) + # get training image file names + return self.make_image_list(self.train_image_folder) + elif stage == constants.VAL_DB: + if self.val_image_folder != '': + # load ground truth + self.load_ground_truth(self.val_label_folder) + # get validation image file names + return self.make_image_list(self.val_image_folder) + else: + # no validation folder was specified + return [] + else: + raise ValueError("Unknown stage: %s" % stage) + + def load_ground_truth(self, folder): + """ + load ground truth from specified folder + """ + datasrc = GroundTruth(folder) + datasrc.load_gt_obj() + self.datasrc_annotation_dict = datasrc.objects_all + + scene_files = [] + for key in self.datasrc_annotation_dict: + scene_files.append(key) + + # determine largest label height: + self.max_bboxes = max([len(annotation) for annotation in self.datasrc_annotation_dict.values()]) + + def make_image_list(self, folder): + """ + find all supported images within specified folder and return list of file names + """ + image_files = [] + for dirpath, dirnames, filenames in os.walk(folder, followlinks=True): + for filename in filenames: + if 
filename.lower().endswith(digits.utils.image.SUPPORTED_EXTENSIONS):
+                    image_files.append('%s' % os.path.join(dirpath, filename))
+        if len(image_files) == 0:
+            raise ValueError("Unable to find supported images in %s" % folder)
+        # shuffle
+        random.shuffle(image_files)
+        return image_files
+
+    def pad_image(self, img):
+        """
+        pad a single image to the dimensions specified in form
+        """
+        src_width = img.size[0]
+        src_height = img.size[1]
+
+        if self.padding_image_width < src_width:
+            raise ValueError("Source image width %d is greater than padding width %d" % (src_width, self.padding_image_width))
+
+        if self.padding_image_height < src_height:
+            raise ValueError("Source image height %d is greater than padding height %d" % (src_height, self.padding_image_height))
+
+        padded_img = PIL.Image.new(
+            img.mode,
+            (self.padding_image_width, self.padding_image_height),
+            "black")
+
+        padded_img.paste(img, (0, 0))  # copy to top-left corner
+
+        return padded_img
diff --git a/digits/extensions/data/objectDetection/forms.py b/digits/extensions/data/objectDetection/forms.py
new file mode 100644
index 000000000..35b089209
--- /dev/null
+++ b/digits/extensions/data/objectDetection/forms.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import + +from flask.ext.wtf import Form +import os +from wtforms import validators + +from digits import utils +from digits.utils import subclass + + +@subclass +class DatasetForm(Form): + """ + A form used to create an image processing dataset + """ + + def validate_folder_path(form, field): + if not field.data: + pass + else: + # make sure the filesystem path exists + if not os.path.exists(field.data) or not os.path.isdir(field.data): + raise validators.ValidationError('Folder does not exist or is not reachable') + else: + return True + + train_image_folder = utils.forms.StringField( + u'Training image folder', + validators=[ + validators.DataRequired(), + validate_folder_path, + ], + tooltip="Indicate a folder of images to use for training" + ) + + train_label_folder = utils.forms.StringField( + u'Training label folder', + validators=[ + validators.DataRequired(), + validate_folder_path, + ], + tooltip="Indicate a folder of training labels" + ) + + val_image_folder = utils.forms.StringField( + u'Validation image folder', + validators=[ + validators.Optional(), + validate_folder_path, + ], + tooltip="Indicate a folder of images to use for training" + ) + + val_label_folder = utils.forms.StringField( + u'Validation label folder', + validators=[ + validators.Optional(), + validate_folder_path, + ], + tooltip="Indicate a folder of validation labels" + ) + + resize_image_width = utils.forms.IntegerField( + u'Resize Image Width', + validators=[ + validators.Optional(), + validators.NumberRange(min=1), + ], + tooltip="If specified, images will be resized to that dimension after padding" + ) + + resize_image_height = utils.forms.IntegerField( + u'Resize Image Height', + validators=[ + validators.Optional(), + validators.NumberRange(min=1), + ], + tooltip="If specified, images will be resized to that dimension after padding" + ) + + padding_image_width = utils.forms.IntegerField( + u'Padding Image Width', + default=1248, + validators=[ + 
validators.DataRequired(), + validators.NumberRange(min=1), + ], + tooltip="Images will be padded to that dimension" + ) + + padding_image_height = utils.forms.IntegerField( + u'Padding Image Height', + default=384, + validators=[ + validators.DataRequired(), + validators.NumberRange(min=1), + ], + tooltip="Images will be padded to that dimension" + ) + + channel_conversion = utils.forms.SelectField( + 'Channel conversion', + choices=[ + ('RGB', 'RGB'), + ('L', 'Grayscale'), + ('none', 'None'), + ], + default='RGB', + tooltip="Perform selected channel conversion." + ) diff --git a/digits/extensions/data/objectDetection/template.html b/digits/extensions/data/objectDetection/template.html new file mode 100644 index 000000000..1cd8ef288 --- /dev/null +++ b/digits/extensions/data/objectDetection/template.html @@ -0,0 +1,70 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + +{% from "helper.html" import print_flashes %} +{% from "helper.html" import print_errors %} +{% from "helper.html" import mark_errors %} + +

Object Detection Dataset Options

+ + + Images can be stored in any of the supported file formats ('.png','.jpg','.jpeg','.bmp','.ppm'). + + +
+ {{ form.train_image_folder.label }} + {{ form.train_image_folder.tooltip }} + {{ form.train_image_folder(class='form-control autocomplete_path', placeholder='folder')}} +
+ + + Label files are expected to have the .txt extension. + For example if an image file is named foo.png the corresponding label file should be foo.txt. + + +
+ {{ form.train_label_folder.label }} + {{ form.train_label_folder.tooltip }} + {{ form.train_label_folder(class='form-control autocomplete_path', placeholder='folder')}} +
+ +
+ {{ form.val_image_folder.label }} + {{ form.val_image_folder.tooltip }} + {{ form.val_image_folder(class='form-control autocomplete_path', placeholder='folder')}} +
+ +
+ {{ form.val_label_folder.label }} + {{ form.val_label_folder.tooltip }} + {{ form.val_label_folder(class='form-control autocomplete_path', placeholder='folder')}} +
+ +
+ {{ form.padding_image_width.label }} + {{ form.padding_image_width.tooltip }} + {{ form.padding_image_width(class='form-control')}} +
+ +
+ {{ form.padding_image_height.label }} + {{ form.padding_image_height.tooltip }} + {{ form.padding_image_height(class='form-control')}} +
+ +
+ {{ form.resize_image_width.label }} + {{ form.resize_image_width.tooltip }} + {{ form.resize_image_width(class='form-control')}} +
+ +
+ {{ form.resize_image_height.label }} + {{ form.resize_image_height.tooltip }} + {{ form.resize_image_height(class='form-control')}} +
+ +
+ {{ form.channel_conversion.label }} + {{ form.channel_conversion.tooltip }} + {{ form.channel_conversion(class='form-control')}} +
diff --git a/digits/extensions/data/objectDetection/utils.py b/digits/extensions/data/objectDetection/utils.py new file mode 100644 index 000000000..45802048e --- /dev/null +++ b/digits/extensions/data/objectDetection/utils.py @@ -0,0 +1,270 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. + +import csv +import numpy as np +import os + +class ObjectType: + + Dontcare, Car, Van, Truck, Bus, Pickup, VehicleWithTrailer, SpecialVehicle,\ + Person, Person_fa, Person_unsure, People, Cyclist, Tram, Person_Sitting,\ + Misc = range(16) + + def __init__(self): + pass + +class Bbox: + + def __init__(self, x_left=0, y_top=0, x_right=0, y_bottom=0): + self.xl = x_left + self.yt = y_top + self.xr = x_right + self.yb = y_bottom + + def area(self): + return (self.xr - self.xl) * (self.yb - self.yt) + + def width(self): + return self.xr - self.xl + + def height(self): + return self.yb - self.yt + + def get_array(self): + return [self.xl, self.yt, self.xr, self.yb] + +class GroundTruthObj: + + """ This class is the data ground-truth + + #Values Name Description + ---------------------------------------------------------------------------- + 1 type Describes the type of object: 'Car', 'Van', 'Truck', + 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', + 'Misc' or 'DontCare' + 1 truncated Float from 0 (non-truncated) to 1 (truncated), where + truncated refers to the object leaving image boundaries. + -1 corresponds to a don't care region. 
+ 1 occluded Integer (-1,0,1,2) indicating occlusion state: + -1 = unkown, 0 = fully visible, + 1 = partly occluded, 2 = largely occluded + 1 alpha Observation angle of object, ranging [-pi..pi] + 4 bbox 2D bounding box of object in the image (0-based index): + contains left, top, right, bottom pixel coordinates + 3 dimensions 3D object dimensions: height, width, length (in meters) + 3 location 3D object location x,y,z in camera coordinates (in meters) + 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi] + 1 score Only for results: Float, indicating confidence in + detection, needed for p/r curves, higher is better. + + Here, 'DontCare' labels denote regions in which objects have not been labeled, + for example because they have been too far away from the laser scanner. + """ + + + def __init__(self): + self.stype = '' + self.truncated = 0 + self.occlusion = 0 + self.angle = 0 + self.height = 0 + self.width = 0 + self.length = 0 + self.locx = 0 + self.locy = 0 + self.locz = 0 + self.roty = 0 + self.bbox = Bbox() + self.object = ObjectType.Dontcare + + @classmethod + def lmdb_format_length(cls): + """ + width of an LMDB datafield returned by the gt_to_lmdb_format function. + :return: + """ + return 16 + + def gt_to_lmdb_format(self): + """ + For storage of a bbox ground truth object into a float32 LMDB. + Sort-by attribute is always the last value in the array. 
+        """
+        result = [
+            # bbox in x,y,w,h format:
+            self.bbox.xl,
+            self.bbox.yt,
+            self.bbox.xr - self.bbox.xl,
+            self.bbox.yb - self.bbox.yt,
+            # alpha angle:
+            self.angle,
+            # class number:
+            self.object,
+            0,
+            # Y axis rotation:
+            self.roty,
+            # bounding box attributes:
+            self.truncated,
+            self.occlusion,
+            # object dimensions:
+            self.length,
+            self.width,
+            self.height,
+            self.locx,
+            self.locy,
+            # depth (sort-by attribute):
+            self.locz,
+        ]
+        assert(len(result) == self.lmdb_format_length())
+        return result
+
+    def set_type(self):
+        object_types = {
+            'bus': ObjectType.Bus,
+            'car': ObjectType.Car,
+            'cyclist': ObjectType.Cyclist,
+            'pedestrian': ObjectType.Person,
+            'people': ObjectType.People,
+            'person': ObjectType.Person,
+            'person_sitting': ObjectType.Person_Sitting,
+            'person-fa': ObjectType.Person_fa,
+            'person?': ObjectType.Person_unsure,
+            'pickup': ObjectType.Pickup,
+            'misc': ObjectType.Misc,
+            'special-vehicle': ObjectType.SpecialVehicle,
+            'tram': ObjectType.Tram,
+            'truck': ObjectType.Truck,
+            'van': ObjectType.Van,
+            'vehicle-with-trailer': ObjectType.VehicleWithTrailer
+        }
+        self.object = object_types.get(self.stype, ObjectType.Dontcare)
+
+
+class GroundTruth:
+
+    """ this class load the ground truth
+    """
+
+    def __init__(self, label_dir, label_ext='.txt', label_delimiter=' '):
+        self.label_dir = label_dir
+        self.label_ext = label_ext  # extension of label files
+        self.label_delimiter = label_delimiter  # space is used as delimiter in label files
+        self._objects_all = dict()  # positive bboxes across images
+
+    def update_objects_all(self, _key, _bboxes):
+        if _bboxes:
+            self._objects_all[_key] = _bboxes
+        else:
+            self._objects_all[_key] = None
+
+    def load_gt_obj(self):
+
+        """ load bbox ground truth from files either via the provided label directory or list of label files"""
+        files = os.listdir(self.label_dir)
+        files = filter(lambda x: x.endswith(self.label_ext), files)
+        if len(files) == 0:
+            raise ValueError('error: no label files found in %s' % self.label_dir)
+        for label_file in files:
+            objects_per_image = list()
+            with open( os.path.join(self.label_dir, label_file), 'rb') as flabel:
+                for row in csv.reader(flabel, delimiter=self.label_delimiter):
+
+                    # load data
+                    gt = GroundTruthObj()
+                    gt.stype = row[0].lower()
+                    gt.truncated = float(row[1])
+                    gt.occlusion = int(row[2])
+                    gt.angle = float(row[3])
+                    gt.bbox.xl = float(row[4])
+                    gt.bbox.yt = float(row[5])
+                    gt.bbox.xr = float(row[6])
+                    gt.bbox.yb = float(row[7])
+                    gt.height = float(row[8])
+                    gt.width = float(row[9])
+                    gt.length = float(row[10])
+                    gt.locx = float(row[11])
+                    gt.locy = float(row[12])
+                    gt.locz = float(row[13])
+                    gt.roty = float(row[14])
+                    gt.set_type()
+                    objects_per_image.append(gt)
+            key = int(os.path.splitext(label_file)[0])
+            self.update_objects_all(key, objects_per_image)
+
+    @property
+    def objects_all(self):
+        return self._objects_all
+
+# return the # of pixels remaining in a
+
+def pad_bbox(arr, max_bboxes=64, bbox_width=16):
+    if arr.shape[0] > max_bboxes:
+        raise ValueError(
+            'Too many bounding boxes (%d > %d)' % (arr.shape[0], max_bboxes)
+        )
+    # fill remainder with zeroes:
+    data = np.zeros((max_bboxes+1, bbox_width), dtype='float')
+    # number of bounding boxes:
+    data[0][0] = arr.shape[0]
+    # width of a bounding box:
+    data[0][1] = bbox_width
+    # bounding box data. Merge nothing if no bounding boxes exist.
+ if arr.shape[0] > 0: + data[1:1 + arr.shape[0]] = arr + + return data + + +def bbox_to_array(arr, label=0, max_bboxes=64, bbox_width=16): + """ + Converts a 1-dimensional bbox array to an image-like + 3-dimensional array CHW array + """ + arr = pad_bbox(arr, max_bboxes, bbox_width) + return arr[np.newaxis, :, :] + + +def bbox_overlap(abox, bbox): + # the abox box + x11 = abox[0] + y11 = abox[1] + x12 = abox[0] + abox[2] - 1 + y12 = abox[1] + abox[3] - 1 + + # the closer box + x21 = bbox[0] + y21 = bbox[1] + x22 = bbox[0] + bbox[2] - 1 + y22 = bbox[1] + bbox[3] - 1 + + overlap_box_x2 = min(x12, x22) + overlap_box_x1 = max(x11, x21) + overlap_box_y2 = min(y12, y22) + overlap_box_y1 = max(y11, y21) + + # make sure we preserve any non-bbox components + overlap_box = list(bbox) + overlap_box[0] = overlap_box_x1 + overlap_box[1] = overlap_box_y1 + overlap_box[2] = overlap_box_x2-overlap_box_x1+1 + overlap_box[3] = overlap_box_y2-overlap_box_y1+1 + + xoverlap = max(0, overlap_box_x2 - overlap_box_x1) + yoverlap = max(0, overlap_box_y2 - overlap_box_y1) + overlap_pix = xoverlap * yoverlap + + return overlap_pix, overlap_box + +def resize_bbox_list(bboxlist, rescale_x=1, rescale_y=1): + # this is expecting x1,y1,w,h: + bboxListNew = [] + for bbox in bboxlist: + abox = bbox + abox[0] *= rescale_x + abox[1] *= rescale_y + abox[2] *= rescale_x + abox[3] *= rescale_y + bboxListNew.append(abox) + return bboxListNew + + diff --git a/digits/extensions/view/__init__.py b/digits/extensions/view/__init__.py index a5ebb3573..82aa38b2b 100644 --- a/digits/extensions/view/__init__.py +++ b/digits/extensions/view/__init__.py @@ -1,10 +1,12 @@ # Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import +from . import boundingBox from . 
import rawData view_extensions = [ # set show=True if extension should be listed in known extensions + {'class': boundingBox.Visualization, 'show': True}, {'class': rawData.Visualization, 'show': True}, ] diff --git a/digits/extensions/view/boundingBox/__init__.py b/digits/extensions/view/boundingBox/__init__.py new file mode 100644 index 000000000..c0b2b8cf3 --- /dev/null +++ b/digits/extensions/view/boundingBox/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from .view import Visualization diff --git a/digits/extensions/view/boundingBox/config_template.html b/digits/extensions/view/boundingBox/config_template.html new file mode 100644 index 000000000..1712fb114 --- /dev/null +++ b/digits/extensions/view/boundingBox/config_template.html @@ -0,0 +1,22 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + +{% from "helper.html" import print_flashes %} +{% from "helper.html" import print_errors %} +{% from "helper.html" import mark_errors %} + +Draw a bounding box around a detected object. This expected network output is a nested list of list of box coordinates + +
+ {{ form.box_color.label }} + {{ form.box_color.tooltip }} + {{ form.box_color(class='form-control') }} +
+ +
+ {{ form.line_width.label }} + {{ form.line_width.tooltip }} + {{ form.line_width(class='form-control') }} +
+ + + diff --git a/digits/extensions/view/boundingBox/forms.py b/digits/extensions/view/boundingBox/forms.py new file mode 100644 index 000000000..69314dbfb --- /dev/null +++ b/digits/extensions/view/boundingBox/forms.py @@ -0,0 +1,34 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +from digits import utils +from digits.utils import subclass +from flask.ext.wtf import Form +import wtforms +from wtforms import validators + + +@subclass +class ConfigForm(Form): + """ + A form used to configure the drawing of bounding boxes + """ + + box_color = wtforms.SelectField( + 'Box color', + choices=[ + ('red', 'Red'), + ('green', 'Green'), + ('blue', 'Blue'), + ], + default='red', + ) + + line_width = utils.forms.IntegerField( + 'Line width', + validators=[ + validators.DataRequired(), + validators.NumberRange(min=1), + ], + default=2, + ) diff --git a/digits/extensions/view/boundingBox/summary_template.html b/digits/extensions/view/boundingBox/summary_template.html new file mode 100644 index 000000000..2d88d9cb1 --- /dev/null +++ b/digits/extensions/view/boundingBox/summary_template.html @@ -0,0 +1,3 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + +Found {{ bbox_count }} bounding box(es) in {{ image_count }} image(s). diff --git a/digits/extensions/view/boundingBox/view.py b/digits/extensions/view/boundingBox/view.py new file mode 100644 index 000000000..c4521a311 --- /dev/null +++ b/digits/extensions/view/boundingBox/view.py @@ -0,0 +1,141 @@ +# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. 
+from __future__ import absolute_import + +import os +import PIL.Image + +import digits +from digits.utils import subclass, override +from .forms import ConfigForm +from ..interface import VisualizationInterface + +CONFIG_TEMPLATE = "config_template.html" +SUMMARY_TEMPLATE = "summary_template.html" +VIEW_TEMPLATE = "view_template.html" + + +@subclass +class Visualization(VisualizationInterface): + """ + A visualization extension to display bounding boxes + """ + + def __init__(self, dataset, **kwargs): + # bounding box options + color = kwargs['box_color'] + if color == "red": + self.color = (255, 0, 0) + elif color == "green": + self.color = (0, 255, 0) + elif color == "blue": + self.color = (0, 0, 255) + else: + raise ValueError("unknown color: %s" % color) + self.line_width = int(kwargs['line_width']) + + # memorize view template for later use + extension_dir = os.path.dirname(os.path.abspath(__file__)) + self.view_template = open( + os.path.join(extension_dir, VIEW_TEMPLATE), "r").read() + + # stats + self.image_count = 0 + self.bbox_count = 0 + + @staticmethod + def get_config_form(): + return ConfigForm() + + @staticmethod + def get_config_template(form): + """ + parameters: + - form: form returned by get_config_form(). This may be populated + with values if the job was cloned + returns: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + extension_dir = os.path.dirname(os.path.abspath(__file__)) + template = open( + os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read() + context = {'form': form} + return (template, context) + + @staticmethod + def get_id(): + return 'image-bounding-boxes' + + @override + def get_summary_template(self): + """ + This returns a summary of the job. This method is called after all + entries have been processed. 
+ returns: + - (template, context) tuple + - template is a Jinja template to use for rendering the summary, or + None if there is no summary to display + - context is a dictionary of context variables to use for rendering + the form + """ + extension_dir = os.path.dirname(os.path.abspath(__file__)) + template = open( + os.path.join(extension_dir, SUMMARY_TEMPLATE), "r").read() + return template, {'image_count': self.image_count, 'bbox_count': self.bbox_count} + + @staticmethod + def get_title(): + return 'Bounding boxes' + + @override + def get_view_template(self, data): + """ + return: + - (template, context) tuple + - template is a Jinja template to use for rendering config options + - context is a dictionary of context variables to use for rendering + the form + """ + return self.view_template, {'image': data['image']} + + @override + def process_data( + self, + dataset, + input_data, + inference_data, + ground_truth=None): + """ + Process one inference output + Parameters: + - dataset: dataset used during training + - input_data: input to the network + - inference_data: network output + - ground_truth: Ground truth. Format is application specific. + None if absent. 
+ Returns: + - an object reprensenting the processed data + """ + # get source image + image = PIL.Image.fromarray(input_data).convert('RGB') + + self.image_count += 1 + + # create arrays in expected format + bboxes = [] + outputs = inference_data[inference_data.keys()[0]] + for output in outputs: + # last number is confidence + if output[-1] > 0: + box = ((output[0], output[1]), (output[2], output[3])) + bboxes.append(box) + self.bbox_count += 1 + digits.utils.image.add_bboxes_to_image( + image, + bboxes, + self.color, + self.line_width) + image_html = digits.utils.image.embed_image_html(image) + return {'image': image_html} diff --git a/digits/extensions/view/boundingBox/view_template.html b/digits/extensions/view/boundingBox/view_template.html new file mode 100644 index 000000000..483fe8bb0 --- /dev/null +++ b/digits/extensions/view/boundingBox/view_template.html @@ -0,0 +1,3 @@ +{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} + + diff --git a/digits/utils/image.py b/digits/utils/image.py index 38d0afa17..7e4b97d50 100644 --- a/digits/utils/image.py +++ b/digits/utils/image.py @@ -289,6 +289,42 @@ def embed_image_html(image): data = string_buf.getvalue().encode('base64').replace('\n', '') return 'data:image/%s;base64,%s' % (fmt, data) +def add_bboxes_to_image(image, bboxes, color='red', width=1): + """ + Draw rectangles on the image for the bounding boxes + Returns a PIL.Image + + Arguments: + image -- input image + bboxes -- bounding boxes in the [((l, t), (r, b)), ...] 
format + + Keyword arguments: + color -- color to draw the rectangles + width -- line width of the rectangles + + Example: + image = Image.open(filename) + add_bboxes_to_image(image, bboxes[filename], width=2, color='#FF7700') + image.show() + """ + def expanded_bbox(bbox, n): + """ + Grow the bounding box by n pixels + """ + l = min(bbox[0][0], bbox[1][0]) + r = max(bbox[0][0], bbox[1][0]) + t = min(bbox[0][1], bbox[1][1]) + b = max(bbox[0][1], bbox[1][1]) + return ((l - n, t - n), (r + n, b + n)) + + from PIL import Image, ImageDraw + draw = ImageDraw.Draw(image) + for bbox in bboxes: + for n in xrange(width): + draw.rectangle(expanded_bbox(bbox, n), outline = color) + + return image + def get_layer_vis_square(data, allow_heatmap = True, normalize = True, diff --git a/digits/utils/test_image.py b/digits/utils/test_image.py index dcff77c7b..46a103b8d 100644 --- a/digits/utils/test_image.py +++ b/digits/utils/test_image.py @@ -189,3 +189,12 @@ def args_to_str(self, args): shape=%s""" % args +class TestBBoxes(): + + def test_add_bboxes(self): + np_color = np.random.randint(0, 100, (10,10,3)).astype('uint8') + pil_color = PIL.Image.fromarray(np_color) + pil_color = image_utils.add_bboxes_to_image(pil_color, [((4, 4), (7, 7))], color='red') + pixelMap = pil_color.load() + assert pixelMap[4, 4] == (255, 0, 0) + assert pixelMap[7, 7] == (255, 0, 0)