-
Notifications
You must be signed in to change notification settings - Fork 0
/
kangaroo_dataset.py
132 lines (117 loc) · 3.67 KB
/
kangaroo_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# split into train and test set
from os import listdir
from xml.etree import ElementTree
from numpy import zeros
from numpy import asarray
from mrcnn.utils import Dataset
from matplotlib import pyplot
from mrcnn.visualize import display_instances
from mrcnn.utils import extract_bboxes
# class that defines and loads the kangaroo dataset
class KangarooDataset(Dataset):
    """Kangaroo images with bounding-box annotations, exposed as a ``Dataset``.

    Images are read from ``<dataset_dir>/images/`` and the matching XML
    annotations from ``<dataset_dir>/annots/``. Each annotated bounding box is
    turned into a filled rectangular mask.
    """

    def load_dataset(self, dataset_dir, is_train=True):
        """Register the ``kangaroo`` class and the images of one split.

        Args:
            dataset_dir: Directory that contains the ``images/`` and
                ``annots/`` sub-directories.
            is_train: If true, register the images whose numeric id is below
                150; otherwise register the images with id 150 and above.
        """
        # define one class
        self.add_class("dataset", 1, "kangaroo")
        # define data locations
        images_dir = dataset_dir + '/images/'
        annotations_dir = dataset_dir + '/annots/'
        # listdir() returns entries in arbitrary order; sort them so that the
        # internal image indices are the same on every run and every machine
        for filename in sorted(listdir(images_dir)):
            # the image id is the file name without its extension
            image_id = filename.rpartition('.')[0]
            # skip bad images
            if image_id in ['00090']:
                continue
            try:
                image_number = int(image_id)
            except ValueError:
                # not one of the numbered images (e.g. a hidden system file)
                continue
            # skip all images after 150 if we are building the train set
            if is_train and image_number >= 150:
                continue
            # skip all images before 150 if we are building the test/val set
            if not is_train and image_number < 150:
                continue
            img_path = images_dir + filename
            ann_path = annotations_dir + image_id + '.xml'
            # add to dataset
            self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

    def extract_boxes(self, filename):
        """Read the bounding boxes and image size from an annotation file.

        Args:
            filename: Path of the XML annotation file.

        Returns:
            A ``(boxes, width, height)`` tuple. ``boxes`` is a list with one
            ``[xmin, ymin, xmax, ymax]`` list of ints per ``bndbox`` element;
            ``width`` and ``height`` are the ints stored under ``size``.
        """
        # load and parse the file, then get the root of the document
        root = ElementTree.parse(filename).getroot()
        # extract each bounding box
        boxes = [
            [int(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            for box in root.findall('.//bndbox')
        ]
        # extract image dimensions
        width = int(root.find('.//size/width').text)
        height = int(root.find('.//size/height').text)
        return boxes, width, height

    def load_mask(self, image_id):
        """Build the masks for one image from its annotated bounding boxes.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            A ``(masks, class_ids)`` tuple. ``masks`` is a ``uint8`` array of
            shape ``[height, width, number_of_boxes]`` in which channel ``i``
            is 1 inside box ``i`` and 0 elsewhere. ``class_ids`` is an
            ``int32`` array holding the kangaroo class index once per box.
        """
        # get details of image and parse its annotation file
        info = self.image_info[image_id]
        boxes, w, h = self.extract_boxes(info['annotation'])
        # create one array for all masks, each on a different channel
        masks = zeros([h, w, len(boxes)], dtype='uint8')
        # fill the rectangle of every box: rows are y, columns are x
        for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
            masks[ymin:ymax, xmin:xmax, i] = 1
        # every box is a kangaroo; look the class index up only when needed
        class_ids = []
        if boxes:
            class_ids = [self.class_names.index('kangaroo')] * len(boxes)
        return masks, asarray(class_ids, dtype='int32')

    def image_reference(self, image_id):
        """Return the file path stored for the image at index ``image_id``."""
        info = self.image_info[image_id]
        return info['path']
# Build both splits, report their sizes, then display the first training
# image with its first mask overlaid.
# train set
train_set = KangarooDataset()
train_set.load_dataset('kangaroo', is_train=True)
train_set.prepare()
print('Train: %d' % len(train_set.image_ids))
# test/val set
test_set = KangarooDataset()
test_set.load_dataset('kangaroo', is_train=False)
test_set.prepare()
print('Test: %d' % len(test_set.image_ids))
# the train set prepared above is reused below; loading it a second time
# would only repeat the same directory scan
# load an image
image_id = 0
image = train_set.load_image(image_id)
print(image.shape)
# load image mask
mask, class_ids = train_set.load_mask(image_id)
print(mask.shape)
# plot image
pyplot.imshow(image)
# plot mask (an image without annotated boxes has no mask channel to show)
if mask.shape[2] > 0:
    pyplot.imshow(mask[:, :, 0], cmap='gray', alpha=0.5)
pyplot.show()
def show_many():
    """Show the first nine training images in a 3x3 grid with their masks.

    Reads from the module-level ``train_set``. Each panel draws the raw image
    first and then every mask channel on top of it as a translucent gray
    layer; the figure is displayed once all nine panels are drawn.
    """
    for panel in range(9):
        # panels are numbered from 1, filling the 3x3 grid row by row
        pyplot.subplot(3, 3, panel + 1)
        # raw pixel data goes at the bottom of the stack
        pyplot.imshow(train_set.load_image(panel))
        # one translucent overlay per mask channel
        masks, _ = train_set.load_mask(panel)
        for channel in range(masks.shape[2]):
            pyplot.imshow(masks[:, :, channel], cmap='gray', alpha=0.3)
    # show the figure
    pyplot.show()