utils.py
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn
import time
import os
import random
import pickle
from tqdm import tqdm
import torch
import segmentation_models_pytorch as smp
# This place contains a lot of stolen code.
def get_img(fname, folder="input/train_images_525/train_images_525", npy=False):
    if npy:
        return np.load(os.path.join(folder, fname + '.npy'))
    img = cv2.imread(os.path.join(folder, fname))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def get_mask(fname, mask_dir="input/mask"):
    return np.load(os.path.join(mask_dir, fname + '.npy'))
def mask_to_np(df, out):
    # Decode every mask in the dataframe and save it as an .npy file in `out`.
    ids = df.im_id.values
    for id in tqdm(ids):
        mask = make_mask(df, id)
        np.save(os.path.join(out, id), mask)
def seed_everything(seed=42):
    """
    42 is the answer to everything.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
def get_ids(train_ids_file='train_ids.pkl', valid_ids_file='valid_ids.pkl'):
    with open(train_ids_file, 'rb') as handle:
        train_ids = pickle.load(handle)
    with open(valid_ids_file, 'rb') as handle:
        valid_ids = pickle.load(handle)
    return train_ids, valid_ids
def post_process(probability, threshold, min_size, size=(350, 525)):
    """
    Post-process a single predicted mask: connected components with fewer
    pixels than `min_size` are discarded.
    """
    # don't remember where I saw it
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    predictions = np.zeros(size, np.float32)  # output size required by the competition
    num = 0
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num
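
# A minimal usage sketch (my addition, not from the original notebook): feed a
# synthetic probability map through post_process and check that the small blob
# is dropped. The threshold of 0.5 and min_size of 100 are illustrative values,
# not the tuned competition settings. Guarded so it only runs as a script.
if __name__ == "__main__":
    prob = np.zeros((350, 525), dtype=np.float32)
    prob[50:150, 50:150] = 0.9      # large blob: 10000 px, kept
    prob[300:305, 500:505] = 0.9    # tiny blob: 25 px, discarded
    pred, n_kept = post_process(prob, threshold=0.5, min_size=100)
    print(n_kept, pred.sum())       # expected: 1 10000.0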
def rle_decode(mask_rle: str = "", shape: tuple = (1400, 2100)):
    """Source: https://www.kaggle.com/artgor/segmentation-in-pytorch-using-convenient-tools"""
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order="F")
def make_mask(df: pd.DataFrame, image_name: str = 'img.jpg', shape: tuple = (1400, 2100)):
    """Source: https://www.kaggle.com/artgor/segmentation-in-pytorch-using-convenient-tools"""
    encoded_masks = df.loc[df['im_id'] == image_name, 'EncodedPixels']
    masks = np.zeros((shape[0], shape[1], 4), dtype=np.float32)
    for idx, label in enumerate(encoded_masks.values):
        if pd.notna(label):
            mask = rle_decode(label)
            masks[:, :, idx] = mask
    return masks
def mask2rle(img):
    '''
    Convert a mask to RLE.
    img: numpy array, 1 - mask, 0 - background
    Returns the run-length encoding as a string.
    '''
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
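
# Quick sanity check (my addition): mask2rle and rle_decode should round-trip on
# a toy mask at the competition's full resolution (1400 x 2100). Guarded so it
# only runs when the module is executed directly.
if __name__ == "__main__":
    toy = np.zeros((1400, 2100), dtype=np.uint8)
    toy[100:200, 300:400] = 1
    rle = mask2rle(toy)
    assert np.array_equal(rle_decode(rle, shape=toy.shape), toy)
    print("RLE round-trip OK, run pairs:", len(rle.split()) // 2)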
def get_model(encoder='resnet18', type='unet',
              encoder_weights='imagenet', classes=4):
    # My own simple wrapper around smp
    if type == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    else:
        raise ValueError(f"Unknown architecture: {type}")
    print(f"Training on {type} architecture with {encoder} encoder")
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)
    return model, preprocessing_fn
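
# Usage sketch (my addition): build the default U-Net/resnet18 pair and run a
# dummy batch through it. This downloads the ImageNet encoder weights on first
# use; the 320x512 input size is only an example (it just has to be divisible by 32).
if __name__ == "__main__":
    model, preprocessing_fn = get_model(encoder='resnet18', type='unet')
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 320, 512))
    print(out.shape)  # expected: torch.Size([1, 4, 320, 512])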
def visualize_with_raw(image, mask, original_image=None, original_mask=None,
                       raw_image=None, raw_mask=None):
    """
    Plot image and masks.
    If two pairs of images and masks are passed, show both.
    Source: https://www.kaggle.com/artgor/segmentation-in-pytorch-using-convenient-tools
    """
    fontsize = 14
    class_dict = {0: 'Fish', 1: 'Flower', 2: 'Gravel', 3: 'Sugar'}
    f, ax = plt.subplots(3, 5, figsize=(24, 12))

    ax[0, 0].imshow(original_image)
    ax[0, 0].set_title('Original image', fontsize=fontsize)
    for i in range(4):
        ax[0, i + 1].imshow(original_mask[:, :, i])
        ax[0, i + 1].set_title(f'Original mask {class_dict[i]}',
                               fontsize=fontsize)

    ax[1, 0].imshow(raw_image)
    ax[1, 0].set_title('Original image', fontsize=fontsize)
    for i in range(4):
        ax[1, i + 1].imshow(raw_mask[:, :, i])
        ax[1, i + 1].set_title(f'Raw predicted mask {class_dict[i]}',
                               fontsize=fontsize)

    ax[2, 0].imshow(image)
    ax[2, 0].set_title('Transformed image', fontsize=fontsize)
    for i in range(4):
        ax[2, i + 1].imshow(mask[:, :, i])
        ax[2, i + 1].set_title(f'Predicted mask with processing {class_dict[i]}',
                               fontsize=fontsize)
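
# Plotting sketch (my addition): visualize_with_raw expects HxWx3 images and
# HxWx4 masks; random arrays are enough to check the 3x5 figure layout.
if __name__ == "__main__":
    h, w = 350, 525
    rand_img = np.random.rand(h, w, 3)
    rand_mask = np.random.rand(h, w, 4)
    visualize_with_raw(rand_img, rand_mask,
                       original_image=rand_img, original_mask=rand_mask,
                       raw_image=rand_img, raw_mask=rand_mask)
    plt.show()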