Skip to content

Commit

Permalink
Merge pull request NVIDIA#735 from gheinrich/object-detection-workflow
Browse files Browse the repository at this point in the history
Object detection workflow
  • Loading branch information
lukeyeager committed May 25, 2016
2 parents d5483f0 + 2613082 commit 234f5c4
Show file tree
Hide file tree
Showing 15 changed files with 952 additions and 0 deletions.
2 changes: 2 additions & 0 deletions digits/extensions/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import absolute_import

from . import imageGradients
from . import objectDetection

# Data extensions known to this package.
# Each entry has two keys:
#   'class': the extension's DataIngestion class
#   'show':  set to True if the extension should be listed in known extensions
data_extensions = [
    {
        'class': imageGradients.DataIngestion,
        'show': False,
    },
    {
        'class': objectDetection.DataIngestion,
        'show': True,
    },
]


Expand Down
4 changes: 4 additions & 0 deletions digits/extensions/data/objectDetection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .data import DataIngestion
241 changes: 241 additions & 0 deletions digits/extensions/data/objectDetection/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import numpy as np
import operator
import os
import PIL.Image
import random

import digits
from digits.utils import subclass, override, constants
from ..interface import DataIngestionInterface
from .forms import DatasetForm
from .utils import GroundTruth, GroundTruthObj
from .utils import bbox_to_array, resize_bbox_list

# File name of the dataset-creation template; it is looked up in the same
# directory as this module (see DataIngestion.get_dataset_template).
TEMPLATE = "template.html"


@subclass
class DataIngestion(DataIngestionInterface):
    """
    A data ingestion extension for an object detection dataset

    Each image is optionally converted to another channel mode, padded
    to a fixed size, optionally resized, and paired with the bounding
    boxes found in the label folder under the image's numerical id
    (the image file name without its extension).
    """

    def __init__(self, **kwargs):
        """
        Validate the combination of optional form fields

        raises:
        - ValueError if only one of the two validation folders, or only
          one of the two resize dimensions, was specified
        """
        super(DataIngestion, self).__init__(**kwargs)

        # this instance is automatically populated with form field
        # attributes by superclass constructor

        # XOR: exactly one of the two was given, which is inconsistent
        if ((self.val_image_folder == '') ^ (self.val_label_folder == '')):
            raise ValueError("You must specify either both val_image_folder and val_label_folder or none")

        if ((self.resize_image_width is None) ^ (self.resize_image_height is None)):
            raise ValueError("You must specify either both resize_image_width and resize_image_height or none")

        # this will be set when we know the phase we are encoding
        self.ground_truth = None

    @override
    def encode_entry(self, entry):
        """
        Encode one image and its bounding boxes

        parameters:
        - entry: image file name, as returned by itemize_entries()
        return:
        - (feature, label) tuple
          feature is the image as a channel-first numpy.ndarray
          label is the array returned by bbox_to_array()
        raises:
        - ValueError if the image shape is unsupported, if the file name
          is not a number, or if there is no label for that number
        """
        image_filename = entry

        # (1) image part

        # load from file
        img = digits.utils.image.load_image(image_filename)
        if self.channel_conversion != 'none':
            if img.mode != self.channel_conversion:
                # convert to different image mode if necessary
                img = img.convert(self.channel_conversion)

        # pad
        img = self.pad_image(img)

        if self.resize_image_width is not None:
            # resize
            # NOTE: force a floating point division; both operands are
            # integers and a truncated ratio would corrupt the bounding
            # boxes that are rescaled below
            resize_ratio_x = float(self.resize_image_width) / self.padding_image_width
            resize_ratio_y = float(self.resize_image_height) / self.padding_image_height
            img = digits.utils.image.resize_image(
                img,
                self.resize_image_height,
                self.resize_image_width)
        else:
            resize_ratio_x = 1
            resize_ratio_y = 1
        # convert to numpy array
        img = np.array(img)

        if img.ndim == 2:
            # grayscale: add a leading channel axis
            img = img[np.newaxis, :, :]
            if img.dtype == 'uint16':
                img = img.astype(float)
        else:
            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError("Unsupported image shape: %s" % repr(img.shape))
            # HWC -> CHW
            img = img.transpose(2, 0, 1)

        # (2) label part

        # make sure label exists: the label key is the numerical file name
        try:
            label_id = int(os.path.splitext(os.path.basename(entry))[0])
        except ValueError:
            raise ValueError("Unable to extract numerical id from file name %s" % entry)

        if label_id not in self.datasrc_annotation_dict:
            raise ValueError("Label key %s not found in label folder" % label_id)
        annotations = self.datasrc_annotation_dict[label_id]

        # collect the serialized form of every annotation into bboxList
        bboxList = [np.array(bbox.gt_to_lmdb_format()) for bbox in annotations]

        # order by the last serialized field, largest first; this is kept
        # as sort + reverse (rather than reverse=True) so that entries
        # with equal keys keep the exact order they always had
        bboxList = sorted(
            bboxList,
            key=operator.itemgetter(GroundTruthObj.lmdb_format_length() - 1)
        )
        bboxList.reverse()

        # adjust bboxes according to image resizing
        bboxList = resize_bbox_list(bboxList, resize_ratio_x, resize_ratio_y)

        # return data
        feature = img
        label = np.asarray(bboxList)

        # LMDB compaction: now label (aka bbox) is the joint array
        label = bbox_to_array(
            label,
            0,
            max_bboxes=self.max_bboxes,
            bbox_width=GroundTruthObj.lmdb_format_length())

        return feature, label

    @staticmethod
    @override
    def get_category():
        """
        return the category this extension is listed under
        """
        return "Images"

    @staticmethod
    @override
    def get_id():
        """
        return the identifier of this extension
        """
        return "image-object-detection"

    @staticmethod
    @override
    def get_dataset_form():
        """
        return a new, empty dataset creation form
        """
        return DatasetForm()

    @staticmethod
    @override
    def get_dataset_template(form):
        """
        parameters:
        - form: form returned by get_dataset_form(). This may be populated with values if the job was cloned
        return:
        - (template, context) tuple
          template is a Jinja template to use for rendering dataset creation options
          context is a dictionary of context variables to use for rendering the form
        """
        extension_dir = os.path.dirname(os.path.abspath(__file__))
        # use a context manager so that the file handle is always released
        with open(os.path.join(extension_dir, TEMPLATE), "r") as infile:
            template = infile.read()
        context = {'form': form}
        return (template, context)

    @staticmethod
    @override
    def get_title():
        """
        return the title of this extension
        """
        return "Object Detection"

    @override
    def itemize_entries(self, stage):
        """
        return list of image file names to encode for specified stage

        The ground truth of the stage is loaded as a side effect, so
        that encode_entry() can look up the labels of each image.

        raises:
        - ValueError if the stage is unknown
        """
        if stage == constants.TEST_DB:
            # don't return anything for the test stage
            return []
        elif stage == constants.TRAIN_DB:
            # load ground truth
            self.load_ground_truth(self.train_label_folder)
            # get training image file names
            return self.make_image_list(self.train_image_folder)
        elif stage == constants.VAL_DB:
            if self.val_image_folder != '':
                # load ground truth
                self.load_ground_truth(self.val_label_folder)
                # get validation image file names
                return self.make_image_list(self.val_image_folder)
            else:
                # no validation folder was specified
                return []
        else:
            raise ValueError("Unknown stage: %s" % stage)

    def load_ground_truth(self, folder):
        """
        load ground truth from specified folder

        Sets self.datasrc_annotation_dict (annotations by label key) and
        self.max_bboxes (largest number of annotations under one key).

        raises:
        - ValueError if no ground truth was found in the folder
        """
        datasrc = GroundTruth(folder)
        datasrc.load_gt_obj()
        self.datasrc_annotation_dict = datasrc.objects_all

        if not self.datasrc_annotation_dict:
            # max() of an empty sequence would raise a ValueError too,
            # but with a message that does not point at the cause
            raise ValueError("Unable to find ground truth labels in %s" % folder)

        # determine largest label height:
        self.max_bboxes = max(
            len(annotation)
            for annotation in self.datasrc_annotation_dict.values())

    def make_image_list(self, folder):
        """
        find all supported images within specified folder and return list of file names

        Sub-folders are searched too (following symbolic links) and the
        list is returned in random order.

        raises:
        - ValueError if no supported image was found
        """
        image_files = []
        for dirpath, dirnames, filenames in os.walk(folder, followlinks=True):
            for filename in filenames:
                if filename.lower().endswith(digits.utils.image.SUPPORTED_EXTENSIONS):
                    # join with dirpath (not folder) so that images found
                    # in sub-folders get a path that actually exists
                    image_files.append(os.path.join(dirpath, filename))
        if len(image_files) == 0:
            raise ValueError("Unable to find supported images in %s" % folder)
        # shuffle
        random.shuffle(image_files)
        return image_files

    def pad_image(self, img):
        """
        pad a single image to the dimensions specified in form

        The source image is copied to the top-left corner of a black
        image of the padding dimensions.

        raises:
        - ValueError if the source image is larger than the padding
          dimensions
        """
        src_width = img.size[0]
        src_height = img.size[1]

        if self.padding_image_width < src_width:
            raise ValueError("Source image width %d is greater than padding width %d" % (src_width, self.padding_image_width))

        if self.padding_image_height < src_height:
            raise ValueError("Source image height %d is greater than padding height %d" % (src_height, self.padding_image_height))

        padded_img = PIL.Image.new(
            img.mode,
            (self.padding_image_width, self.padding_image_height),
            "black")

        padded_img.paste(img, (0, 0))  # copy to top-left corner

        return padded_img
111 changes: 111 additions & 0 deletions digits/extensions/data/objectDetection/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from flask.ext.wtf import Form
import os
from wtforms import validators

from digits import utils
from digits.utils import subclass


@subclass
class DatasetForm(Form):
    """
    A form used to create an object detection dataset
    """

    def validate_folder_path(form, field):
        """
        Field validator: a folder path, when given, must be an existing directory

        Empty values are accepted here; whether a field is mandatory is
        decided by the other validators attached to it.

        raises:
        - validators.ValidationError if the path is not an existing directory
        """
        if not field.data:
            return None
        # make sure the filesystem path exists (isdir is False for
        # paths that do not exist, so no separate existence check)
        if not os.path.isdir(field.data):
            raise validators.ValidationError('Folder does not exist or is not reachable')
        return True

    train_image_folder = utils.forms.StringField(
        u'Training image folder',
        validators=[
            validators.DataRequired(),
            validate_folder_path,
        ],
        tooltip="Indicate a folder of images to use for training"
    )

    train_label_folder = utils.forms.StringField(
        u'Training label folder',
        validators=[
            validators.DataRequired(),
            validate_folder_path,
        ],
        tooltip="Indicate a folder of training labels"
    )

    val_image_folder = utils.forms.StringField(
        u'Validation image folder',
        validators=[
            validators.Optional(),
            validate_folder_path,
        ],
        tooltip="Indicate a folder of images to use for validation"
    )

    val_label_folder = utils.forms.StringField(
        u'Validation label folder',
        validators=[
            validators.Optional(),
            validate_folder_path,
        ],
        tooltip="Indicate a folder of validation labels"
    )

    # the two resize dimensions are optional
    resize_image_width = utils.forms.IntegerField(
        u'Resize Image Width',
        validators=[
            validators.Optional(),
            validators.NumberRange(min=1),
        ],
        tooltip="If specified, images will be resized to that dimension after padding"
    )

    resize_image_height = utils.forms.IntegerField(
        u'Resize Image Height',
        validators=[
            validators.Optional(),
            validators.NumberRange(min=1),
        ],
        tooltip="If specified, images will be resized to that dimension after padding"
    )

    # the padding dimensions are mandatory
    padding_image_width = utils.forms.IntegerField(
        u'Padding Image Width',
        default=1248,
        validators=[
            validators.DataRequired(),
            validators.NumberRange(min=1),
        ],
        tooltip="Images will be padded to that dimension"
    )

    padding_image_height = utils.forms.IntegerField(
        u'Padding Image Height',
        default=384,
        validators=[
            validators.DataRequired(),
            validators.NumberRange(min=1),
        ],
        tooltip="Images will be padded to that dimension"
    )

    # each choice is a (value, label) pair
    channel_conversion = utils.forms.SelectField(
        'Channel conversion',
        choices=[
            ('RGB', 'RGB'),
            ('L', 'Grayscale'),
            ('none', 'None'),
        ],
        default='RGB',
        tooltip="Perform selected channel conversion."
    )
Loading

0 comments on commit 234f5c4

Please sign in to comment.