diff --git a/digits/extensions/data/__init__.py b/digits/extensions/data/__init__.py
index b8a56e911..e1af0c476 100644
--- a/digits/extensions/data/__init__.py
+++ b/digits/extensions/data/__init__.py
@@ -2,10 +2,12 @@
from __future__ import absolute_import
from . import imageGradients
+from . import objectDetection
data_extensions = [
# set show=True if extension should be listed in known extensions
{'class': imageGradients.DataIngestion, 'show': False},
+ {'class': objectDetection.DataIngestion, 'show': True},
]
diff --git a/digits/extensions/data/objectDetection/__init__.py b/digits/extensions/data/objectDetection/__init__.py
new file mode 100644
index 000000000..2ed387aad
--- /dev/null
+++ b/digits/extensions/data/objectDetection/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+from .data import DataIngestion
diff --git a/digits/extensions/data/objectDetection/data.py b/digits/extensions/data/objectDetection/data.py
new file mode 100644
index 000000000..ab26b24d8
--- /dev/null
+++ b/digits/extensions/data/objectDetection/data.py
@@ -0,0 +1,241 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+import numpy as np
+import operator
+import os
+import PIL.Image
+import random
+
+import digits
+from digits.utils import subclass, override, constants
+from ..interface import DataIngestionInterface
+from .forms import DatasetForm
+from .utils import GroundTruth, GroundTruthObj
+from .utils import bbox_to_array, resize_bbox_list
+
+TEMPLATE = "template.html"
+
+
+@subclass
+class DataIngestion(DataIngestionInterface):
+ """
+ A data ingestion extension for an object detection dataset
+ """
+
+ def __init__(self, **kwargs):
+ super(DataIngestion, self).__init__(**kwargs)
+
+ # this instance is automatically populated with form field
+ # attributes by superclass constructor
+
+ if ((self.val_image_folder == '') ^ (self.val_label_folder == '')):
+ raise ValueError("You must specify either both val_image_folder and val_label_folder or none")
+
+ if ((self.resize_image_width is None) ^ (self.resize_image_height is None)):
+ raise ValueError("You must specify either both resize_image_width and resize_image_height or none")
+
+ # this will be set when we know the phase we are encoding
+ self.ground_truth = None
+
+ @override
+ def encode_entry(self, entry):
+ """
+ Return numpy.ndarray
+ """
+ image_filename = entry
+
+ # (1) image part
+
+ # load from file
+ img = digits.utils.image.load_image(image_filename)
+ if self.channel_conversion != 'none':
+ if img.mode != self.channel_conversion:
+ # convert to different image mode if necessary
+ img = img.convert(self.channel_conversion)
+
+ # pad
+ img = self.pad_image(img)
+
+ if self.resize_image_width is not None:
+ # resize
+ resize_ratio_x = float(self.resize_image_width) / self.padding_image_width
+ resize_ratio_y = float(self.resize_image_height) / self.padding_image_height
+ img = digits.utils.image.resize_image(
+ img,
+ self.resize_image_height,
+ self.resize_image_width)
+ else:
+ resize_ratio_x = 1
+ resize_ratio_y = 1
+ # convert to numpy array
+ img = np.array(img)
+
+ if img.ndim == 2:
+ # grayscale
+ img = img[np.newaxis, :, :]
+ if img.dtype == 'uint16':
+ img = img.astype(float)
+ else:
+ if img.ndim != 3 or img.shape[2] != 3:
+ raise ValueError("Unsupported image shape: %s" % repr(img.shape))
+ # HWC -> CHW
+ img = img.transpose(2, 0, 1)
+
+ # (2) label part
+
+ # make sure label exists
+ try:
+ label_id = int(os.path.splitext(os.path.basename(entry))[0])
+ except ValueError:
+ raise ValueError("Unable to extract numerical id from file name %s" % entry)
+
+ if label_id not in self.datasrc_annotation_dict:
+ raise ValueError("Label key %s not found in label folder" % label_id)
+ annotations = self.datasrc_annotation_dict[label_id]
+
+ # collect bbox list into bboxList
+ bboxList = []
+
+ for bbox in annotations:
+ # retrieve all vars defining groundtruth, and interpret all
+ # serialized values as float32s:
+ np_bbox = np.array(bbox.gt_to_lmdb_format())
+ bboxList.append(np_bbox)
+
+ bboxList = sorted(
+ bboxList,
+ key=operator.itemgetter(GroundTruthObj.lmdb_format_length()-1)
+ )
+
+ bboxList.reverse()
+
+ # adjust bboxes according to image cropping
+ bboxList = resize_bbox_list(bboxList, resize_ratio_x, resize_ratio_y)
+
+ # return data
+ feature = img
+ label = np.asarray(bboxList)
+
+ # LMDB compaction: now label (aka bbox) is the joint array
+ label = bbox_to_array(
+ label,
+ 0,
+ max_bboxes=self.max_bboxes,
+ bbox_width=GroundTruthObj.lmdb_format_length())
+
+
+ return feature, label
+
+ @staticmethod
+ @override
+ def get_category():
+ return "Images"
+
+ @staticmethod
+ @override
+ def get_id():
+ return "image-object-detection"
+
+ @staticmethod
+ @override
+ def get_dataset_form():
+ return DatasetForm()
+
+ @staticmethod
+ @override
+ def get_dataset_template(form):
+ """
+ parameters:
+ - form: form returned by get_dataset_form(). This may be populated with values if the job was cloned
+ return:
+ - (template, context) tuple
+ template is a Jinja template to use for rendering dataset creation options
+ context is a dictionary of context variables to use for rendering the form
+ """
+ extension_dir = os.path.dirname(os.path.abspath(__file__))
+ template = open(os.path.join(extension_dir, TEMPLATE), "r").read()
+ context = {'form': form}
+ return (template, context)
+
+ @staticmethod
+ @override
+ def get_title():
+ return "Object Detection"
+
+ @override
+ def itemize_entries(self, stage):
+ """
+ return list of image file names to encode for specified stage
+ """
+ if stage == constants.TEST_DB:
+ # don't return anything for the test stage
+ return []
+ elif stage == constants.TRAIN_DB:
+ # load ground truth
+ self.load_ground_truth(self.train_label_folder)
+ # get training image file names
+ return self.make_image_list(self.train_image_folder)
+ elif stage == constants.VAL_DB:
+ if self.val_image_folder != '':
+ # load ground truth
+ self.load_ground_truth(self.val_label_folder)
+ # get validation image file names
+ return self.make_image_list(self.val_image_folder)
+ else:
+ # no validation folder was specified
+ return []
+ else:
+ raise ValueError("Unknown stage: %s" % stage)
+
+ def load_ground_truth(self, folder):
+ """
+ load ground truth from specified folder
+ """
+ datasrc = GroundTruth(folder)
+ datasrc.load_gt_obj()
+ self.datasrc_annotation_dict = datasrc.objects_all
+
+ scene_files = []
+ for key in self.datasrc_annotation_dict:
+ scene_files.append(key)
+
+ # determine largest label height:
+ self.max_bboxes = max([len(annotation) for annotation in self.datasrc_annotation_dict.values()])
+
+ def make_image_list(self, folder):
+ """
+ find all supported images within specified folder and return list of file names
+ """
+ image_files = []
+ for dirpath, dirnames, filenames in os.walk(folder, followlinks=True):
+ for filename in filenames:
+ if filename.lower().endswith(digits.utils.image.SUPPORTED_EXTENSIONS):
+ image_files.append('%s' % os.path.join(folder, filename))
+ if len(image_files) == 0:
+ raise ValueError("Unable to find supported images in %s" % folder)
+ # shuffle
+ random.shuffle(image_files)
+ return image_files
+
+ def pad_image(self, img):
+ """
+ pad a single image to the dimensions specified in form
+ """
+ src_width = img.size[0]
+ src_height = img.size[1]
+
+ if self.padding_image_width < src_width:
+ raise ValueError("Source image width %d is greater than padding width %d" % (src_width, self.padding_image_width))
+
+ if self.padding_image_height < src_height:
+ raise ValueError("Source image height %d is greater than padding height %d" % (src_height, self.padding_image_height))
+
+ padded_img = PIL.Image.new(
+ img.mode,
+ (self.padding_image_width, self.padding_image_height),
+ "black")
+
+ padded_img.paste(img, (0, 0)) # copy to top-left corner
+
+ return padded_img
diff --git a/digits/extensions/data/objectDetection/forms.py b/digits/extensions/data/objectDetection/forms.py
new file mode 100644
index 000000000..35b089209
--- /dev/null
+++ b/digits/extensions/data/objectDetection/forms.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+from flask.ext.wtf import Form
+import os
+from wtforms import validators
+
+from digits import utils
+from digits.utils import subclass
+
+
+@subclass
+class DatasetForm(Form):
+ """
+ A form used to create an image processing dataset
+ """
+
+ def validate_folder_path(form, field):
+ if not field.data:
+ pass
+ else:
+ # make sure the filesystem path exists
+ if not os.path.exists(field.data) or not os.path.isdir(field.data):
+ raise validators.ValidationError('Folder does not exist or is not reachable')
+ else:
+ return True
+
+ train_image_folder = utils.forms.StringField(
+ u'Training image folder',
+ validators=[
+ validators.DataRequired(),
+ validate_folder_path,
+ ],
+ tooltip="Indicate a folder of images to use for training"
+ )
+
+ train_label_folder = utils.forms.StringField(
+ u'Training label folder',
+ validators=[
+ validators.DataRequired(),
+ validate_folder_path,
+ ],
+ tooltip="Indicate a folder of training labels"
+ )
+
+ val_image_folder = utils.forms.StringField(
+ u'Validation image folder',
+ validators=[
+ validators.Optional(),
+ validate_folder_path,
+ ],
+ tooltip="Indicate a folder of images to use for validation"
+ )
+
+ val_label_folder = utils.forms.StringField(
+ u'Validation label folder',
+ validators=[
+ validators.Optional(),
+ validate_folder_path,
+ ],
+ tooltip="Indicate a folder of validation labels"
+ )
+
+ resize_image_width = utils.forms.IntegerField(
+ u'Resize Image Width',
+ validators=[
+ validators.Optional(),
+ validators.NumberRange(min=1),
+ ],
+ tooltip="If specified, images will be resized to that dimension after padding"
+ )
+
+ resize_image_height = utils.forms.IntegerField(
+ u'Resize Image Height',
+ validators=[
+ validators.Optional(),
+ validators.NumberRange(min=1),
+ ],
+ tooltip="If specified, images will be resized to that dimension after padding"
+ )
+
+ padding_image_width = utils.forms.IntegerField(
+ u'Padding Image Width',
+ default=1248,
+ validators=[
+ validators.DataRequired(),
+ validators.NumberRange(min=1),
+ ],
+ tooltip="Images will be padded to that dimension"
+ )
+
+ padding_image_height = utils.forms.IntegerField(
+ u'Padding Image Height',
+ default=384,
+ validators=[
+ validators.DataRequired(),
+ validators.NumberRange(min=1),
+ ],
+ tooltip="Images will be padded to that dimension"
+ )
+
+ channel_conversion = utils.forms.SelectField(
+ 'Channel conversion',
+ choices=[
+ ('RGB', 'RGB'),
+ ('L', 'Grayscale'),
+ ('none', 'None'),
+ ],
+ default='RGB',
+ tooltip="Perform selected channel conversion."
+ )
diff --git a/digits/extensions/data/objectDetection/template.html b/digits/extensions/data/objectDetection/template.html
new file mode 100644
index 000000000..1cd8ef288
--- /dev/null
+++ b/digits/extensions/data/objectDetection/template.html
@@ -0,0 +1,70 @@
+{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}
+
+{% from "helper.html" import print_flashes %}
+{% from "helper.html" import print_errors %}
+{% from "helper.html" import mark_errors %}
+
+
Object Detection Dataset Options
+
+
+ Images can be stored in any of the supported file formats ('.png','.jpg','.jpeg','.bmp','.ppm').
+
+
+
+ {{ form.train_image_folder.label }}
+ {{ form.train_image_folder.tooltip }}
+ {{ form.train_image_folder(class='form-control autocomplete_path', placeholder='folder')}}
+
+
+
+ Label files are expected to have the .txt extension.
+ For example if an image file is named foo.png the corresponding label file should be foo.txt.
+
+
+
+ {{ form.train_label_folder.label }}
+ {{ form.train_label_folder.tooltip }}
+ {{ form.train_label_folder(class='form-control autocomplete_path', placeholder='folder')}}
+
+
+
+ {{ form.val_image_folder.label }}
+ {{ form.val_image_folder.tooltip }}
+ {{ form.val_image_folder(class='form-control autocomplete_path', placeholder='folder')}}
+
+
+
+ {{ form.val_label_folder.label }}
+ {{ form.val_label_folder.tooltip }}
+ {{ form.val_label_folder(class='form-control autocomplete_path', placeholder='folder')}}
+
+
+
+ {{ form.padding_image_width.label }}
+ {{ form.padding_image_width.tooltip }}
+ {{ form.padding_image_width(class='form-control')}}
+
+
+
+ {{ form.padding_image_height.label }}
+ {{ form.padding_image_height.tooltip }}
+ {{ form.padding_image_height(class='form-control')}}
+
+
+
+ {{ form.resize_image_width.label }}
+ {{ form.resize_image_width.tooltip }}
+ {{ form.resize_image_width(class='form-control')}}
+
+
+
+ {{ form.resize_image_height.label }}
+ {{ form.resize_image_height.tooltip }}
+ {{ form.resize_image_height(class='form-control')}}
+
+
+
+ {{ form.channel_conversion.label }}
+ {{ form.channel_conversion.tooltip }}
+ {{ form.channel_conversion(class='form-control')}}
+
diff --git a/digits/extensions/data/objectDetection/utils.py b/digits/extensions/data/objectDetection/utils.py
new file mode 100644
index 000000000..45802048e
--- /dev/null
+++ b/digits/extensions/data/objectDetection/utils.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+
+import csv
+import numpy as np
+import os
+
+class ObjectType:
+
+ Dontcare, Car, Van, Truck, Bus, Pickup, VehicleWithTrailer, SpecialVehicle,\
+ Person, Person_fa, Person_unsure, People, Cyclist, Tram, Person_Sitting,\
+ Misc = range(16)
+
+ def __init__(self):
+ pass
+
+class Bbox:
+
+ def __init__(self, x_left=0, y_top=0, x_right=0, y_bottom=0):
+ self.xl = x_left
+ self.yt = y_top
+ self.xr = x_right
+ self.yb = y_bottom
+
+ def area(self):
+ return (self.xr - self.xl) * (self.yb - self.yt)
+
+ def width(self):
+ return self.xr - self.xl
+
+ def height(self):
+ return self.yb - self.yt
+
+ def get_array(self):
+ return [self.xl, self.yt, self.xr, self.yb]
+
+class GroundTruthObj:
+
+ """ This class is the data ground-truth
+
+ #Values Name Description
+ ----------------------------------------------------------------------------
+ 1 type Describes the type of object: 'Car', 'Van', 'Truck',
+ 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
+ 'Misc' or 'DontCare'
+ 1 truncated Float from 0 (non-truncated) to 1 (truncated), where
+ truncated refers to the object leaving image boundaries.
+ -1 corresponds to a don't care region.
+ 1 occluded Integer (-1,0,1,2) indicating occlusion state:
+ -1 = unknown, 0 = fully visible,
+ 1 = partly occluded, 2 = largely occluded
+ 1 alpha Observation angle of object, ranging [-pi..pi]
+ 4 bbox 2D bounding box of object in the image (0-based index):
+ contains left, top, right, bottom pixel coordinates
+ 3 dimensions 3D object dimensions: height, width, length (in meters)
+ 3 location 3D object location x,y,z in camera coordinates (in meters)
+ 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi]
+ 1 score Only for results: Float, indicating confidence in
+ detection, needed for p/r curves, higher is better.
+
+ Here, 'DontCare' labels denote regions in which objects have not been labeled,
+ for example because they have been too far away from the laser scanner.
+ """
+
+
+ def __init__(self):
+ self.stype = ''
+ self.truncated = 0
+ self.occlusion = 0
+ self.angle = 0
+ self.height = 0
+ self.width = 0
+ self.length = 0
+ self.locx = 0
+ self.locy = 0
+ self.locz = 0
+ self.roty = 0
+ self.bbox = Bbox()
+ self.object = ObjectType.Dontcare
+
+ @classmethod
+ def lmdb_format_length(cls):
+ """
+ width of an LMDB datafield returned by the gt_to_lmdb_format function.
+ :return:
+ """
+ return 16
+
+ def gt_to_lmdb_format(self):
+ """
+ For storage of a bbox ground truth object into a float32 LMDB.
+ Sort-by attribute is always the last value in the array.
+ """
+ result = [
+ # bbox in x,y,w,h format:
+ self.bbox.xl,
+ self.bbox.yt,
+ self.bbox.xr - self.bbox.xl,
+ self.bbox.yb - self.bbox.yt,
+ # alpha angle:
+ self.angle,
+ # class number:
+ self.object,
+ 0,
+ # Y axis rotation:
+ self.roty,
+ # bounding box attributes:
+ self.truncated,
+ self.occlusion,
+ # object dimensions:
+ self.length,
+ self.width,
+ self.height,
+ self.locx,
+ self.locy,
+ # depth (sort-by attribute):
+ self.locz,
+ ]
+ assert(len(result) == self.lmdb_format_length())
+ return result
+
+ def set_type(self):
+ object_types = {
+ 'bus': ObjectType.Bus,
+ 'car': ObjectType.Car,
+ 'cyclist': ObjectType.Cyclist,
+ 'pedestrian': ObjectType.Person,
+ 'people': ObjectType.People,
+ 'person': ObjectType.Person,
+ 'person_sitting': ObjectType.Person_Sitting,
+ 'person-fa': ObjectType.Person_fa,
+ 'person?': ObjectType.Person_unsure,
+ 'pickup': ObjectType.Pickup,
+ 'misc': ObjectType.Misc,
+ 'special-vehicle': ObjectType.SpecialVehicle,
+ 'tram': ObjectType.Tram,
+ 'truck': ObjectType.Truck,
+ 'van': ObjectType.Van,
+ 'vehicle-with-trailer': ObjectType.VehicleWithTrailer
+ }
+ self.object = object_types.get(self.stype, ObjectType.Dontcare)
+
+
+class GroundTruth:
+
+ """ this class load the ground truth
+ """
+
+ def __init__(self, label_dir, label_ext='.txt', label_delimiter=' '):
+ self.label_dir = label_dir
+ self.label_ext = label_ext # extension of label files
+ self.label_delimiter = label_delimiter # space is used as delimiter in label files
+ self._objects_all = dict() # positive bboxes across images
+
+ def update_objects_all(self, _key, _bboxes):
+ if _bboxes:
+ self._objects_all[_key] = _bboxes
+ else:
+ self._objects_all[_key] = None
+
+ def load_gt_obj(self):
+
+ """ load bbox ground truth from files either via the provided label directory or list of label files"""
+ files = os.listdir(self.label_dir)
+ files = filter(lambda x: x.endswith(self.label_ext), files)
+ if len(files) == 0:
+ raise ValueError("error: no label files found in %s" % self.label_dir)
+ for label_file in files:
+ objects_per_image = list()
+ with open( os.path.join(self.label_dir, label_file), 'rb') as flabel:
+ for row in csv.reader(flabel, delimiter=self.label_delimiter):
+
+ # load data
+ gt = GroundTruthObj()
+ gt.stype = row[0].lower()
+ gt.truncated = float(row[1])
+ gt.occlusion = int(row[2])
+ gt.angle = float(row[3])
+ gt.bbox.xl = float(row[4])
+ gt.bbox.yt = float(row[5])
+ gt.bbox.xr = float(row[6])
+ gt.bbox.yb = float(row[7])
+ gt.height = float(row[8])
+ gt.width = float(row[9])
+ gt.length = float(row[10])
+ gt.locx = float(row[11])
+ gt.locy = float(row[12])
+ gt.locz = float(row[13])
+ gt.roty = float(row[14])
+ gt.set_type()
+ objects_per_image.append(gt)
+ key = int(os.path.splitext(label_file)[0])
+ self.update_objects_all(key, objects_per_image)
+
+ @property
+ def objects_all(self):
+ return self._objects_all
+
+# return the # of pixels remaining in a
+
+def pad_bbox(arr, max_bboxes=64, bbox_width=16):
+ if arr.shape[0] > max_bboxes:
+ raise ValueError(
+ 'Too many bounding boxes (%d > %d)' % (arr.shape[0], max_bboxes)
+ )
+ # fill remainder with zeroes:
+ data = np.zeros((max_bboxes+1, bbox_width), dtype='float')
+ # number of bounding boxes:
+ data[0][0] = arr.shape[0]
+ # width of a bounding box:
+ data[0][1] = bbox_width
+ # bounding box data. Merge nothing if no bounding boxes exist.
+ if arr.shape[0] > 0:
+ data[1:1 + arr.shape[0]] = arr
+
+ return data
+
+
+def bbox_to_array(arr, label=0, max_bboxes=64, bbox_width=16):
+ """
+ Converts a 1-dimensional bbox array to an image-like
+ 3-dimensional array CHW array
+ """
+ arr = pad_bbox(arr, max_bboxes, bbox_width)
+ return arr[np.newaxis, :, :]
+
+
+def bbox_overlap(abox, bbox):
+ # the abox box
+ x11 = abox[0]
+ y11 = abox[1]
+ x12 = abox[0] + abox[2] - 1
+ y12 = abox[1] + abox[3] - 1
+
+ # the closer box
+ x21 = bbox[0]
+ y21 = bbox[1]
+ x22 = bbox[0] + bbox[2] - 1
+ y22 = bbox[1] + bbox[3] - 1
+
+ overlap_box_x2 = min(x12, x22)
+ overlap_box_x1 = max(x11, x21)
+ overlap_box_y2 = min(y12, y22)
+ overlap_box_y1 = max(y11, y21)
+
+ # make sure we preserve any non-bbox components
+ overlap_box = list(bbox)
+ overlap_box[0] = overlap_box_x1
+ overlap_box[1] = overlap_box_y1
+ overlap_box[2] = overlap_box_x2-overlap_box_x1+1
+ overlap_box[3] = overlap_box_y2-overlap_box_y1+1
+
+ xoverlap = max(0, overlap_box_x2 - overlap_box_x1)
+ yoverlap = max(0, overlap_box_y2 - overlap_box_y1)
+ overlap_pix = xoverlap * yoverlap
+
+ return overlap_pix, overlap_box
+
+def resize_bbox_list(bboxlist, rescale_x=1, rescale_y=1):
+ # this is expecting x1,y1,w,h:
+ bboxListNew = []
+ for bbox in bboxlist:
+ abox = bbox
+ abox[0] *= rescale_x
+ abox[1] *= rescale_y
+ abox[2] *= rescale_x
+ abox[3] *= rescale_y
+ bboxListNew.append(abox)
+ return bboxListNew
+
+
diff --git a/digits/extensions/view/__init__.py b/digits/extensions/view/__init__.py
index a5ebb3573..82aa38b2b 100644
--- a/digits/extensions/view/__init__.py
+++ b/digits/extensions/view/__init__.py
@@ -1,10 +1,12 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
+from . import boundingBox
from . import rawData
view_extensions = [
# set show=True if extension should be listed in known extensions
+ {'class': boundingBox.Visualization, 'show': True},
{'class': rawData.Visualization, 'show': True},
]
diff --git a/digits/extensions/view/boundingBox/__init__.py b/digits/extensions/view/boundingBox/__init__.py
new file mode 100644
index 000000000..c0b2b8cf3
--- /dev/null
+++ b/digits/extensions/view/boundingBox/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+from .view import Visualization
diff --git a/digits/extensions/view/boundingBox/config_template.html b/digits/extensions/view/boundingBox/config_template.html
new file mode 100644
index 000000000..1712fb114
--- /dev/null
+++ b/digits/extensions/view/boundingBox/config_template.html
@@ -0,0 +1,22 @@
+{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}
+
+{% from "helper.html" import print_flashes %}
+{% from "helper.html" import print_errors %}
+{% from "helper.html" import mark_errors %}
+
+Draw a bounding box around a detected object. This expected network output is a nested list of list of box coordinates
+
+
+ {{ form.box_color.label }}
+ {{ form.box_color.tooltip }}
+ {{ form.box_color(class='form-control') }}
+
+
+
+ {{ form.line_width.label }}
+ {{ form.line_width.tooltip }}
+ {{ form.line_width(class='form-control') }}
+
+
+
+
diff --git a/digits/extensions/view/boundingBox/forms.py b/digits/extensions/view/boundingBox/forms.py
new file mode 100644
index 000000000..69314dbfb
--- /dev/null
+++ b/digits/extensions/view/boundingBox/forms.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+from digits import utils
+from digits.utils import subclass
+from flask.ext.wtf import Form
+import wtforms
+from wtforms import validators
+
+
+@subclass
+class ConfigForm(Form):
+ """
+ A form used to configure the drawing of bounding boxes
+ """
+
+ box_color = wtforms.SelectField(
+ 'Box color',
+ choices=[
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+ ],
+ default='red',
+ )
+
+ line_width = utils.forms.IntegerField(
+ 'Line width',
+ validators=[
+ validators.DataRequired(),
+ validators.NumberRange(min=1),
+ ],
+ default=2,
+ )
diff --git a/digits/extensions/view/boundingBox/summary_template.html b/digits/extensions/view/boundingBox/summary_template.html
new file mode 100644
index 000000000..2d88d9cb1
--- /dev/null
+++ b/digits/extensions/view/boundingBox/summary_template.html
@@ -0,0 +1,3 @@
+{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}
+
+Found {{ bbox_count }} bounding box(es) in {{ image_count }} image(s).
diff --git a/digits/extensions/view/boundingBox/view.py b/digits/extensions/view/boundingBox/view.py
new file mode 100644
index 000000000..c4521a311
--- /dev/null
+++ b/digits/extensions/view/boundingBox/view.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
+from __future__ import absolute_import
+
+import os
+import PIL.Image
+
+import digits
+from digits.utils import subclass, override
+from .forms import ConfigForm
+from ..interface import VisualizationInterface
+
+CONFIG_TEMPLATE = "config_template.html"
+SUMMARY_TEMPLATE = "summary_template.html"
+VIEW_TEMPLATE = "view_template.html"
+
+
+@subclass
+class Visualization(VisualizationInterface):
+ """
+ A visualization extension to display bounding boxes
+ """
+
+ def __init__(self, dataset, **kwargs):
+ # bounding box options
+ color = kwargs['box_color']
+ if color == "red":
+ self.color = (255, 0, 0)
+ elif color == "green":
+ self.color = (0, 255, 0)
+ elif color == "blue":
+ self.color = (0, 0, 255)
+ else:
+ raise ValueError("unknown color: %s" % color)
+ self.line_width = int(kwargs['line_width'])
+
+ # memorize view template for later use
+ extension_dir = os.path.dirname(os.path.abspath(__file__))
+ self.view_template = open(
+ os.path.join(extension_dir, VIEW_TEMPLATE), "r").read()
+
+ # stats
+ self.image_count = 0
+ self.bbox_count = 0
+
+ @staticmethod
+ def get_config_form():
+ return ConfigForm()
+
+ @staticmethod
+ def get_config_template(form):
+ """
+ parameters:
+ - form: form returned by get_config_form(). This may be populated
+ with values if the job was cloned
+ returns:
+ - (template, context) tuple
+ - template is a Jinja template to use for rendering config options
+ - context is a dictionary of context variables to use for rendering
+ the form
+ """
+ extension_dir = os.path.dirname(os.path.abspath(__file__))
+ template = open(
+ os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read()
+ context = {'form': form}
+ return (template, context)
+
+ @staticmethod
+ def get_id():
+ return 'image-bounding-boxes'
+
+ @override
+ def get_summary_template(self):
+ """
+ This returns a summary of the job. This method is called after all
+ entries have been processed.
+ returns:
+ - (template, context) tuple
+ - template is a Jinja template to use for rendering the summary, or
+ None if there is no summary to display
+ - context is a dictionary of context variables to use for rendering
+ the form
+ """
+ extension_dir = os.path.dirname(os.path.abspath(__file__))
+ template = open(
+ os.path.join(extension_dir, SUMMARY_TEMPLATE), "r").read()
+ return template, {'image_count': self.image_count, 'bbox_count': self.bbox_count}
+
+ @staticmethod
+ def get_title():
+ return 'Bounding boxes'
+
+ @override
+ def get_view_template(self, data):
+ """
+ return:
+ - (template, context) tuple
+ - template is a Jinja template to use for rendering config options
+ - context is a dictionary of context variables to use for rendering
+ the form
+ """
+ return self.view_template, {'image': data['image']}
+
+ @override
+ def process_data(
+ self,
+ dataset,
+ input_data,
+ inference_data,
+ ground_truth=None):
+ """
+ Process one inference output
+ Parameters:
+ - dataset: dataset used during training
+ - input_data: input to the network
+ - inference_data: network output
+ - ground_truth: Ground truth. Format is application specific.
+ None if absent.
+ Returns:
+ - an object representing the processed data
+ """
+ # get source image
+ image = PIL.Image.fromarray(input_data).convert('RGB')
+
+ self.image_count += 1
+
+ # create arrays in expected format
+ bboxes = []
+ outputs = inference_data[inference_data.keys()[0]]
+ for output in outputs:
+ # last number is confidence
+ if output[-1] > 0:
+ box = ((output[0], output[1]), (output[2], output[3]))
+ bboxes.append(box)
+ self.bbox_count += 1
+ digits.utils.image.add_bboxes_to_image(
+ image,
+ bboxes,
+ self.color,
+ self.line_width)
+ image_html = digits.utils.image.embed_image_html(image)
+ return {'image': image_html}
diff --git a/digits/extensions/view/boundingBox/view_template.html b/digits/extensions/view/boundingBox/view_template.html
new file mode 100644
index 000000000..483fe8bb0
--- /dev/null
+++ b/digits/extensions/view/boundingBox/view_template.html
@@ -0,0 +1,3 @@
+{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}
+
+
diff --git a/digits/utils/image.py b/digits/utils/image.py
index 38d0afa17..7e4b97d50 100644
--- a/digits/utils/image.py
+++ b/digits/utils/image.py
@@ -289,6 +289,42 @@ def embed_image_html(image):
data = string_buf.getvalue().encode('base64').replace('\n', '')
return 'data:image/%s;base64,%s' % (fmt, data)
+def add_bboxes_to_image(image, bboxes, color='red', width=1):
+ """
+ Draw rectangles on the image for the bounding boxes
+ Returns a PIL.Image
+
+ Arguments:
+ image -- input image
+ bboxes -- bounding boxes in the [((l, t), (r, b)), ...] format
+
+ Keyword arguments:
+ color -- color to draw the rectangles
+ width -- line width of the rectangles
+
+ Example:
+ image = Image.open(filename)
+ add_bboxes_to_image(image, bboxes[filename], width=2, color='#FF7700')
+ image.show()
+ """
+ def expanded_bbox(bbox, n):
+ """
+ Grow the bounding box by n pixels
+ """
+ l = min(bbox[0][0], bbox[1][0])
+ r = max(bbox[0][0], bbox[1][0])
+ t = min(bbox[0][1], bbox[1][1])
+ b = max(bbox[0][1], bbox[1][1])
+ return ((l - n, t - n), (r + n, b + n))
+
+ from PIL import Image, ImageDraw
+ draw = ImageDraw.Draw(image)
+ for bbox in bboxes:
+ for n in xrange(width):
+ draw.rectangle(expanded_bbox(bbox, n), outline = color)
+
+ return image
+
def get_layer_vis_square(data,
allow_heatmap = True,
normalize = True,
diff --git a/digits/utils/test_image.py b/digits/utils/test_image.py
index dcff77c7b..46a103b8d 100644
--- a/digits/utils/test_image.py
+++ b/digits/utils/test_image.py
@@ -189,3 +189,12 @@ def args_to_str(self, args):
shape=%s""" % args
+class TestBBoxes():
+
+ def test_add_bboxes(self):
+ np_color = np.random.randint(0, 100, (10,10,3)).astype('uint8')
+ pil_color = PIL.Image.fromarray(np_color)
+ pil_color = image_utils.add_bboxes_to_image(pil_color, [((4, 4), (7, 7))], color='red')
+ pixelMap = pil_color.load()
+ assert pixelMap[4, 4] == (255, 0, 0)
+ assert pixelMap[7, 7] == (255, 0, 0)