Skip to content

Commit

Permalink
Merge pull request #1 from Austriker/python-3-support
Browse files Browse the repository at this point in the history
Python 3 support
  • Loading branch information
Austriker committed May 18, 2016
2 parents 96dc9f1 + ff4172b commit 6b3c6a3
Show file tree
Hide file tree
Showing 37 changed files with 1,377 additions and 704 deletions.
7 changes: 3 additions & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
[submodule "caffe-fast-rcnn"]
path = caffe-fast-rcnn
url = https://github.com/rbgirshick/caffe-fast-rcnn.git
branch = fast-rcnn
[submodule "caffe"]
path = caffe
url = https://github.com/Austriker/caffe.git
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# *Faster* R-CNN Fork

#### Warning
This Fork is still a work in progress

This fork :
- Merged the [caffe-fast-rcnn](https://github.com/rbgirshick/caffe-fast-rcnn/tree/0dcd397b29507b8314e252e850518c5695efbb83) fork to the current Caffe version. The [PR](https://github.com/BVLC/caffe/pull/4163) is waiting to be Merged
- Add support to Python 3.4

### Disclaimer

The official Faster R-CNN code (written in MATLAB) is available [here](https://github.com/ShaoqingRen/faster_rcnn).
Expand Down Expand Up @@ -74,7 +83,7 @@ If you find Faster R-CNN useful in your research, please consider citing:
1. Clone the Faster R-CNN repository
```Shell
# Make sure to clone with --recursive
git clone --recursive https://github.com/rbgirshick/py-faster-rcnn.git
git clone --recursive https://github.com/Austriker/py-faster-rcnn.git
```

2. We'll call the directory that you cloned Faster R-CNN into `FRCN_ROOT`
Expand All @@ -85,7 +94,7 @@ If you find Faster R-CNN useful in your research, please consider citing:
```Shell
git submodule update --init --recursive
```
**Note 2:** The `caffe-fast-rcnn` submodule needs to be on the `faster-rcnn` branch (or equivalent detached state). This will happen automatically *if you followed step 1 instructions*.
**Note 2:** The `caffe` submodule needs to be on the `fast-rcnn` branch (or equivalent detached state). This will happen automatically *if you followed step 1 instructions*.

3. Build the Cython modules
```Shell
Expand All @@ -95,7 +104,7 @@ If you find Faster R-CNN useful in your research, please consider citing:

4. Build Caffe and pycaffe
```Shell
cd $FRCN_ROOT/caffe-fast-rcnn
cd $FRCN_ROOT/caffe
# Now follow the Caffe installation instructions here:
# http://caffe.berkeleyvision.org/installation.html
Expand Down
1 change: 1 addition & 0 deletions caffe
Submodule caffe added at c5f996
1 change: 0 additions & 1 deletion caffe-fast-rcnn
Submodule caffe-fast-rcnn deleted from 0dcd39
142 changes: 99 additions & 43 deletions lib/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as COCOmask


def _filter_crowd_proposals(roidb, crowd_thresh):
"""
Finds proposals that are inside crowd regions and marks them with
Expand All @@ -31,26 +32,31 @@ def _filter_crowd_proposals(roidb, crowd_thresh):
overlaps = entry['gt_overlaps'].toarray()
crowd_inds = np.where(overlaps.max(axis=1) == -1)[0]
non_gt_inds = np.where(entry['gt_classes'] == 0)[0]

if len(crowd_inds) == 0 or len(non_gt_inds) == 0:
continue

iscrowd = [int(True) for _ in xrange(len(crowd_inds))]
crowd_boxes = ds_utils.xyxy_to_xywh(entry['boxes'][crowd_inds, :])
non_gt_boxes = ds_utils.xyxy_to_xywh(entry['boxes'][non_gt_inds, :])
ious = COCOmask.iou(non_gt_boxes, crowd_boxes, iscrowd)
bad_inds = np.where(ious.max(axis=1) > crowd_thresh)[0]
overlaps[non_gt_inds[bad_inds], :] = -1
roidb[ix]['gt_overlaps'] = scipy.sparse.csr_matrix(overlaps)

return roidb


class coco(imdb):

def __init__(self, image_set, year):
imdb.__init__(self, 'coco_' + year + '_' + image_set)
# COCO specific config options
self.config = {'top_k' : 2000,
'use_salt' : True,
'cleanup' : True,
'crowd_thresh' : 0.7,
'min_size' : 2}
self.config = {'top_k': 2000,
'use_salt': True,
'cleanup': True,
'crowd_thresh': 0.7,
'min_size': 2}
# name, paths
self._year = year
self._image_set = image_set
Expand All @@ -71,20 +77,23 @@ def __init__(self, image_set, year):
# For example, minival2014 is a random 5000 image subset of val2014.
# This mapping tells us where the view's images and proposals come from.
self._view_map = {
'minival2014' : 'val2014', # 5k val2014 subset
'valminusminival2014' : 'val2014', # val2014 \setminus minival2014
'minival2014': 'val2014', # 5k val2014 subset
'valminusminival2014': 'val2014', # val2014 \setminus minival2014
}
coco_name = image_set + year # e.g., "val2014"
self._data_name = (self._view_map[coco_name]
if self._view_map.has_key(coco_name)
if coco_name in self._view_map
else coco_name)

# Dataset splits that have ground-truth annotations (test splits
# do not have gt annotations)
self._gt_splits = ('train', 'val', 'minival')

def _get_ann_file(self):
prefix = 'instances' if self._image_set.find('test') == -1 \
else 'image_info'
prefix = 'instances'
if self._image_set.find('test') == -1:
prefix = 'image_info'

return osp.join(self._data_path, 'annotations',
prefix + '_' + self._image_set + self._year + '.json')

Expand All @@ -98,6 +107,7 @@ def _load_image_set_index(self):
def _get_widths(self):
anns = self._COCO.loadImgs(self._image_index)
widths = [ann['width'] for ann in anns]

return widths

def image_path_at(self, i):
Expand All @@ -116,8 +126,11 @@ def image_path_from_index(self, index):
str(index).zfill(12) + '.jpg')
image_path = osp.join(self._data_path, 'images',
self._data_name, file_name)
assert osp.exists(image_path), \
'Path does not exist: {}'.format(image_path)
assert(
osp.exists(image_path),
'Path does not exist: {}'.format(image_path)
)

return image_path

def selective_search_roidb(self):
Expand All @@ -141,21 +154,30 @@ def _roidb_from_proposals(self, method):
if osp.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{:s} {:s} roidb loaded from {:s}'.format(self.name, method,
cache_file)

print('{:s} {:s} roidb loaded from {:s}'.format(
self.name,
method,
cache_file
)
)

return roidb

if self._image_set in self._gt_splits:
gt_roidb = self.gt_roidb()
method_roidb = self._load_proposals(method, gt_roidb)
roidb = imdb.merge_roidbs(gt_roidb, method_roidb)

# Make sure we don't use proposals that are contained in crowds
roidb = _filter_crowd_proposals(roidb, self.config['crowd_thresh'])
else:
roidb = self._load_proposals(method, None)
with open(cache_file, 'wb') as fid:
cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote {:s} roidb to {:s}'.format(method, cache_file)

print('wrote {:s} roidb to {:s}'.format(method, cache_file))

return roidb

def _load_proposals(self, method, gt_roidb):
Expand All @@ -175,22 +197,25 @@ def _load_proposals(self, method, gt_roidb):
'selective_search',
'edge_boxes_AR',
'edge_boxes_70']
assert method in valid_methods
assert(method in valid_methods)

print('Loading {} boxes'.format(method))

print 'Loading {} boxes'.format(method)
for i, index in enumerate(self._image_index):
if i % 1000 == 0:
print '{:d} / {:d}'.format(i + 1, len(self._image_index))
print('{:d} / {:d}'.format(i + 1, len(self._image_index)))

box_file = osp.join(
cfg.DATA_DIR, 'coco_proposals', method, 'mat',
self._get_box_file(index))

raw_data = sio.loadmat(box_file)['boxes']
boxes = np.maximum(raw_data - 1, 0).astype(np.uint16)

if method == 'MCG':
# Boxes from the MCG website are in (y1, x1, y2, x2) order
boxes = boxes[:, (1, 0, 3, 2)]

# Remove duplicate boxes and very small boxes and then take top k
keep = ds_utils.unique_boxes(boxes)
boxes = boxes[keep, :]
Expand All @@ -203,6 +228,7 @@ def _load_proposals(self, method, gt_roidb):
width = im_ann['width']
height = im_ann['height']
ds_utils.validate_boxes(boxes, width=width, height=height)

return self.create_roidb_from_box_list(box_list, gt_roidb)

def gt_roidb(self):
Expand All @@ -214,15 +240,19 @@ def gt_roidb(self):
if osp.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)

print('{} gt roidb loaded from {}'.format(self.name, cache_file))

return roidb

gt_roidb = [self._load_coco_annotation(index)
for index in self._image_index]

with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)

print('wrote gt roidb to {}'.format(cache_file))

return gt_roidb

def _load_coco_annotation(self, index):
Expand Down Expand Up @@ -266,6 +296,7 @@ def _load_coco_annotation(self, index):
boxes[ix, :] = obj['clean_bbox']
gt_classes[ix] = cls
seg_areas[ix] = obj['area']

if obj['iscrowd']:
# Set overlap to -1 for all classes for crowd objects
# so they will be excluded during training
Expand All @@ -275,49 +306,60 @@ def _load_coco_annotation(self, index):

ds_utils.validate_boxes(boxes, width=width, height=height)
overlaps = scipy.sparse.csr_matrix(overlaps)
return {'boxes' : boxes,

return {'boxes': boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False,
'seg_areas' : seg_areas}
'gt_overlaps': overlaps,
'flipped': False,
'seg_areas': seg_areas}

def _get_box_file(self, index):
# first 14 chars / first 22 chars / all chars + .mat
# COCO_val2014_0/COCO_val2014_000000447/COCO_val2014_000000447991.mat
file_name = ('COCO_' + self._data_name +
'_' + str(index).zfill(12) + '.mat')

return osp.join(file_name[:14], file_name[:22], file_name)

def _print_detection_eval_metrics(self, coco_eval):
IoU_lo_thresh = 0.5
IoU_hi_thresh = 0.95

def _get_thr_ind(coco_eval, thr):
ind = np.where((coco_eval.params.iouThrs > thr - 1e-5) &
(coco_eval.params.iouThrs < thr + 1e-5))[0][0]
iou_thr = coco_eval.params.iouThrs[ind]
assert np.isclose(iou_thr, thr)
assert(np.isclose(iou_thr, thr))

return ind

ind_lo = _get_thr_ind(coco_eval, IoU_lo_thresh)
ind_hi = _get_thr_ind(coco_eval, IoU_hi_thresh)
# precision has dims (iou, recall, cls, area range, max dets)
# area range index 0: all area ranges
# max dets index 2: 100 per image
precision = \
coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2]
precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2]
ap_default = np.mean(precision[precision > -1])
print ('~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] '
'~~~~').format(IoU_lo_thresh, IoU_hi_thresh)
print '{:.1f}'.format(100 * ap_default)

print(('~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] '
'~~~~').format(IoU_lo_thresh, IoU_hi_thresh))
print('{:.1f}'.format(100 * ap_default))

for cls_ind, cls in enumerate(self.classes):

if cls == '__background__':
continue
# minus 1 because of __background__
precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2]

precision = coco_eval.eval['precision'][
ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2
]

ap = np.mean(precision[precision > -1])
print '{:.1f}'.format(100 * ap)

print '~~~~ Summary metrics ~~~~'
print('{:.1f}'.format(100 * ap))

print('~~~~ Summary metrics ~~~~')
coco_eval.summarize()

def _do_detection_eval(self, res_file, output_dir):
Expand All @@ -329,9 +371,11 @@ def _do_detection_eval(self, res_file, output_dir):
coco_eval.accumulate()
self._print_detection_eval_metrics(coco_eval)
eval_file = osp.join(output_dir, 'detection_results.pkl')

with open(eval_file, 'wb') as fid:
cPickle.dump(coco_eval, fid, cPickle.HIGHEST_PROTOCOL)
print 'Wrote COCO eval results to: {}'.format(eval_file)

print('Wrote COCO eval results to: {}'.format(eval_file))

def _coco_results_one_category(self, boxes, cat_id):
results = []
Expand All @@ -345,10 +389,11 @@ def _coco_results_one_category(self, boxes, cat_id):
ws = dets[:, 2] - xs + 1
hs = dets[:, 3] - ys + 1
results.extend(
[{'image_id' : index,
'category_id' : cat_id,
'bbox' : [xs[k], ys[k], ws[k], hs[k]],
'score' : scores[k]} for k in xrange(dets.shape[0])])
[{'image_id': index,
'category_id': cat_id,
'bbox': [xs[k], ys[k], ws[k], hs[k]],
'score': scores[k]} for k in xrange(dets.shape[0])])

return results

def _write_coco_results_file(self, all_boxes, res_file):
Expand All @@ -358,14 +403,21 @@ def _write_coco_results_file(self, all_boxes, res_file):
# "score": 0.236}, ...]
results = []
for cls_ind, cls in enumerate(self.classes):

if cls == '__background__':
continue
print 'Collecting {} results ({:d}/{:d})'.format(cls, cls_ind,
self.num_classes - 1)

print('Collecting {} results ({:d}/{:d})'.format(
cls, cls_ind, self.num_classes - 1)
)

coco_cat_id = self._class_to_coco_cat_id[cls]
results.extend(self._coco_results_one_category(all_boxes[cls_ind],
coco_cat_id))
print 'Writing results json to {}'.format(res_file)
results.extend(self._coco_results_one_category(
all_boxes[cls_ind], coco_cat_id)
)

print('Writing results json to {}'.format(res_file))

with open(res_file, 'w') as fid:
json.dump(results, fid)

Expand All @@ -374,13 +426,17 @@ def evaluate_detections(self, all_boxes, output_dir):
self._image_set +
self._year +
'_results'))

if self.config['use_salt']:
res_file += '_{}'.format(str(uuid.uuid4()))

res_file += '.json'
self._write_coco_results_file(all_boxes, res_file)

# Only do evaluation on non-test sets
if self._image_set.find('test') == -1:
self._do_detection_eval(res_file, output_dir)

# Optionally cleanup results json file
if self.config['cleanup']:
os.remove(res_file)
Expand Down
Loading

0 comments on commit 6b3c6a3

Please sign in to comment.