-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutil.py
36 lines (32 loc) · 1.05 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
def get_dataloaders(
    data_dir, imsize, batch_size, eval_size, num_workers=1, *, generator=None
):
    r"""
    Create train and eval dataloaders from a directory of image data.

    The directory is loaded with ``torchvision.datasets.ImageFolder``. Each
    image is resized to ``imsize``, center-cropped to ``imsize``, converted to
    a tensor with ``ToTensor`` and normalized with mean 0.5 / std 0.5 on each
    of three channels (mapping ``ToTensor``'s [0, 1] range to [-1, 1]). The
    dataset is then randomly split into an eval subset of ``eval_size``
    samples and a train subset containing the remaining samples.

    Args:
        data_dir: Root directory passed to ``ImageFolder``.
        imsize: Size passed to both ``Resize`` and ``CenterCrop``.
        batch_size: Batch size used by both dataloaders.
        eval_size: Number of samples held out for evaluation. Must satisfy
            ``0 <= eval_size < len(dataset)`` so the train split is non-empty.
        num_workers: Number of worker processes for each dataloader.
        generator: Optional ``torch.Generator`` used for the random split, so
            the train/eval partition can be reproduced. When ``None`` the
            split uses PyTorch's default generator, as before.

    Returns:
        A ``(train_dataloader, eval_dataloader)`` tuple. Only the train
        dataloader shuffles its samples.

    Raises:
        ValueError: If ``eval_size`` is outside ``[0, len(dataset))``.
    """
    dataset = datasets.ImageFolder(
        root=data_dir,
        transform=transforms.Compose(
            [
                transforms.Resize(imsize),
                transforms.CenterCrop(imsize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    num_samples = len(dataset)
    # random_split only checks that the lengths sum to len(dataset), which
    # [eval_size, n - eval_size] always does; an out-of-range eval_size would
    # otherwise silently yield wrong subsets (negative slicing) or an empty
    # train split that fails later with an unrelated sampler error.
    if not 0 <= eval_size < num_samples:
        raise ValueError(
            f"eval_size must be in [0, {num_samples}) for a dataset of "
            f"{num_samples} samples, got {eval_size}"
        )

    # Only forward a generator when one is given so the default call matches
    # the original behavior exactly.
    split_kwargs = {} if generator is None else {"generator": generator}
    eval_dataset, train_dataset = torch.utils.data.random_split(
        dataset,
        [eval_size, num_samples - eval_size],
        **split_kwargs,
    )

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=batch_size, num_workers=num_workers
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    return train_dataloader, eval_dataloader