-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
75 lines (60 loc) · 2.44 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import os
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (grayscale, RGB) tensor pairs per image.

    Image paths are read from the first column of a CSV file. Each item is
    loaded from disk, converted to RGB, optionally passed through a
    user-supplied transform pipeline, and returned both as a 1-channel
    grayscale tensor and as an RGB tensor.
    """

    def __init__(self, path_file, n_samples=None, random_seed=None, transform=None):
        """
        Args:
            path_file : (string or file-like)
                Path to (or buffer containing) a CSV file whose first column
                lists the image file paths.
            n_samples : (int), optional (default=None)
                Number of rows to randomly sample from the CSV.
                If None, uses all rows in file order.
            random_seed : (int), optional (default=None)
                Random state for reproducible subsampling; only used when
                n_samples is given. If None, sampling is not seeded.
            transform : (list or callable), optional (default=None)
                Transforms applied to the PIL image before tensor conversion.
                A list/tuple is wrapped in transforms.Compose at load time;
                any other callable is applied as-is. The transforms should
                operate on PIL images, since ToTensor is applied afterwards.
        """
        data = pd.read_csv(path_file)
        if n_samples is not None:
            # random_state=None means "unseeded", so one call covers both cases.
            data = data.sample(n=n_samples, random_state=random_seed)
        self.data = data
        # Store a private copy of a transform list so later mutation of the
        # caller's list cannot silently change this dataset's pipeline.
        if isinstance(transform, (list, tuple)):
            transform = list(transform)
        self.transform = transform

    def __getitem__(self, idx):
        """
        Args:
            idx : (int or 0-d tensor)
                Row position in self.data of the image to load.

        Returns:
            (gray_tensor, rgb_tensor): the single-channel grayscale tensor and
            the RGB tensor, both produced by transforms.ToTensor from the same
            transformed image.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]

        # Open the file once and close it promptly; convert() reads the pixel
        # data into a new in-memory image before the handle is released, and
        # normalizes palette/RGBA/L inputs to 3-channel RGB.
        with Image.open(img_name) as img:
            rgb_image = img.convert("RGB")

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                # Build a fresh Compose; never append to self.transform.
                pipeline = transforms.Compose(list(self.transform))
            else:
                pipeline = self.transform
            rgb_image = pipeline(rgb_image)

        # Derive the grayscale input from the already-transformed RGB image so
        # the pair stays pixel-aligned even when the pipeline contains random
        # transforms (which would otherwise be sampled twice).
        gray_image = transforms.Grayscale(num_output_channels=1)(rgb_image)

        to_tensor = transforms.ToTensor()
        return to_tensor(gray_image), to_tensor(rgb_image)

    def __len__(self):
        """Return the number of image rows in the dataset."""
        return len(self.data)