forked from xuebinqin/U-2-Net
-
Notifications
You must be signed in to change notification settings - Fork 16
/
data_loader.py
286 lines (217 loc) · 8.09 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
Helper code for loading and altering images in dataset
"""
import gc
import random
import imageio.v3 as iio
import numpy as np
import torchvision.transforms.functional as tf
from PIL import Image
from torch.utils.data import Dataset
class RandomCrop:
    """
    A class for performing random cropping of a given image.

    Treats white pixels as empty space, and strives to return as few of them as possible.
    """

    THRESHOLDS = [
        0.5,
        0.8,
        0.9,
        0.95,
        0.98,
        0.99,
    ]  # allowed fractions of white values, tried from strictest to most lenient
    start_threshold_index = 0  # class-wide default index into THRESHOLDS

    def __init__(self, output_size, index=None):
        """
        Initialize the RandomCrop transformer.

        Parameters:
        - output_size (int or tuple): The desired size of the cropped image.
          An int gives a square crop; a tuple is read as (height, width), the
          order in which it is passed to ``tf.crop``.
        - index (int or None): The starting threshold index into THRESHOLDS.
          When None (the default), the class attribute
          ``RandomCrop.start_threshold_index`` is used, so class-wide changes
          to it keep taking effect.
        """
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        # Only shadow the class attribute when a value is given explicitly, so
        # that attribute lookups otherwise fall back to the class-wide setting.
        if index is not None:
            self.start_threshold_index = index

    @staticmethod
    def _calculate_white_percentage(img):
        """
        Return the fraction (0..1) of entries in ``img`` that equal 255.

        For multi-channel arrays every channel value is counted separately.
        """
        white_pixels = np.sum(img == 255)
        total_pixels = img.size
        return white_pixels / total_pixels

    @staticmethod
    def _grid_cells(width, height, crop_h, crop_w):
        """Return (top, left) offsets of grid cells whose crop lies fully inside the image."""
        return [
            (top, left)
            for left in range(0, width, crop_w)
            for top in range(0, height, crop_h)
            if top + crop_h <= height and left + crop_w <= width
        ]

    def __call__(self, sample):
        """
        Apply the random crop to the input image + mask.

        The image is split into a grid of crop-sized cells which are visited in
        random order. For each threshold in THRESHOLDS (starting at
        ``start_threshold_index``) the first cell whose white fraction is at
        most that threshold is returned; the threshold is relaxed only when no
        cell qualifies.

        Parameters:
        - sample (dict): Dictionary containing an image and its mask.
          ``image.size`` is read as (width, height), as for a PIL image.

        Returns:
        - Dictionary containing the random crop of image and mask.

        Raises:
        - ValueError: if the image is smaller than the crop, or if no cell
          meets even the most lenient threshold.
        """
        image, label = sample["image"], sample["label"]
        w, h = image.size
        crop_h, crop_w = self.output_size

        # splitting image into grid for faster search
        cells = self._grid_cells(w, h, crop_h, crop_w)
        if not cells:
            raise ValueError(
                f"Image of size {w}x{h} is smaller than the crop size {crop_w}x{crop_h}"
            )
        random.shuffle(cells)

        # White fractions are cached so that relaxing the threshold does not
        # re-crop and re-scan cells that were already measured.
        white_fractions = {}
        for threshold in self.THRESHOLDS[self.start_threshold_index :]:
            for top, left in cells:
                if (top, left) not in white_fractions:
                    # tf.crop takes (top, left, height, width): top is the y offset.
                    cropped = tf.crop(image, top, left, crop_h, crop_w)
                    white_fractions[(top, left)] = self._calculate_white_percentage(
                        np.array(cropped)
                    )
                if white_fractions[(top, left)] <= threshold:
                    # The label is cropped only for the chosen cell.
                    return {
                        "image": tf.crop(image, top, left, crop_h, crop_w),
                        "label": tf.crop(label, top, left, crop_h, crop_w),
                    }
        raise ValueError("Fully white image is given :(")
class HorizontalFlip:
    """Mirror an image/mask pair left-to-right."""

    def __call__(self, sample):
        """
        Flip both the image and its mask horizontally.

        Parameters:
        - sample (dict): Must provide "image" and "label" entries.

        Returns:
        - A new dict holding only the mirrored "image" and "label".
        """
        return {key: tf.hflip(sample[key]) for key in ("image", "label")}
class VerticalFlip:
    """Mirror an image/mask pair top-to-bottom."""

    def __call__(self, sample):
        """
        Flip both the image and its mask vertically.

        Parameters:
        - sample (dict): Must provide "image" and "label" entries.

        Returns:
        - A new dict holding only the flipped "image" and "label".
        """
        return {key: tf.vflip(sample[key]) for key in ("image", "label")}
class Rotation:
    """Rotate an image/mask pair by one fixed angle."""

    def __init__(self, degrees):
        """
        Store the rotation angle.

        Parameters:
        - degrees (float): Angle handed unchanged to ``tf.rotate`` for both
          the image and the mask.
        """
        self.degrees = degrees

    def __call__(self, sample):
        """
        Rotate the image and its mask by the stored angle.

        Parameters:
        - sample (dict): Must provide "image" and "label" entries.

        Returns:
        - A new dict holding only the rotated "image" and "label".
        """
        angle = self.degrees
        rotated_image = tf.rotate(sample["image"], angle)
        rotated_label = tf.rotate(sample["label"], angle)
        return {"image": rotated_image, "label": rotated_label}
class Resize:
    """Resize an image/mask pair to a square of a fixed side length."""

    def __init__(self, size=1024):
        """
        Store the target side length.

        Parameters:
        - size (int): Side length of the square output for both image and mask.
        """
        self.size = size

    def __call__(self, sample):
        """
        Resize the image and its mask to ``size`` x ``size``.

        Parameters:
        - sample (dict): Must provide "image" and "label" entries.

        Returns:
        - A new dict holding only the resized "image" and "label".
        """
        side = self.size
        return {
            "image": tf.resize(sample["image"], [side, side]),
            "label": tf.resize(sample["label"], [side, side]),
        }
class ToTensorLab:
    """
    Convert an image/mask pair from PIL format to PyTorch tensors, optionally
    casting both to half precision.
    """

    def __init__(self, half_precision=None):
        """
        Initialize the ToTensorLab transformer.

        Parameters:
        - half_precision (bool or None): Whether to cast both tensors with
          ``.half()``. When None (the default), the flag is read from
          ``u2net_train.HALF_PRECISION`` on each call, so the transform can be
          used without importing the training script by passing a bool.
        """
        self.half_precision = half_precision

    def __call__(self, sample):
        """
        Convert the image and label from PIL format to PyTorch tensors.

        Parameters:
        - sample (dict): Dictionary containing an image and its mask.

        Returns:
        - Dictionary containing the image and mask as tensors.
        """
        half = self.half_precision
        if half is None:
            # NOTE(review): if u2net_train.py is the script being run (__main__),
            # this imports it a second time under the name "u2net_train" and
            # re-executes its top-level code — confirm training is guarded by
            # `if __name__ == "__main__"`, or pass half_precision explicitly.
            from u2net_train import HALF_PRECISION

            half = HALF_PRECISION
        image, label = sample["image"], sample["label"]
        # Convert to tensor
        image = tf.to_tensor(image)
        label = tf.to_tensor(label)
        if half:
            image, label = image.half(), label.half()
        return {"image": image, "label": label}
class SalObjDataset(Dataset):
    """
    Custom dataset class for salient object detection. This class helps in
    loading images, their corresponding masks, and applying the desired
    transformations before feeding them to the network.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        """
        Initialize the SalObjDataset.

        Parameters:
        - img_name_list (list): List of paths to the images.
        - lbl_name_list (list): List of paths to the corresponding masks,
          index-aligned with img_name_list.
        - transform (callable, optional): Optional transform applied to the
          {"image", "label"} sample dict.
        """
        self.img_name_list = img_name_list
        self.lbl_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.img_name_list)

    def __getitem__(self, idx):
        """
        Fetch an image and its corresponding label, apply any transformations
        if needed, and return them as a dictionary.

        Parameters:
        - idx (int): Index of the desired sample.

        Returns:
        - Dictionary with "image" and "label" entries (PIL images unless the
          transform converts them).
        """
        # Read the files and wrap them as PIL images for the existing transforms.
        # The raw arrays are not bound to locals, so reference counting releases
        # them as soon as nothing else holds them; a per-sample gc.collect()
        # (which only reclaims reference cycles) was pure overhead.
        image = Image.fromarray(iio.imread(self.img_name_list[idx]))
        label = Image.fromarray(iio.imread(self.lbl_name_list[idx]))
        # Masks are used as single-channel grayscale; skip convert() when the
        # mask already has that mode, since it would only make a copy.
        if label.mode != "L":
            label = label.convert("L")

        sample = {"image": image, "label": label}
        # Apply the transformations
        if self.transform:
            sample = self.transform(sample)
        return sample