-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
39 lines (33 loc) · 1.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
from torchvision import datasets
def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=False):
    """Build a DataLoader over the images of a single class sub-directory.

    The whole ``path`` tree is indexed with ``datasets.ImageFolder`` and the
    resulting dataset is then filtered down to the samples whose class is
    ``subfolder``. Labels are NOT remapped: every kept sample still carries
    ``class_to_idx[subfolder]`` as its label.

    Args:
        path: Root directory handed to ``ImageFolder``.
        subfolder: Name of the class (sub-directory) whose images are kept.
        transform: Transform handed to ``ImageFolder``.
        batch_size: Batch size for the returned ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the filtered dataset.

    Raises:
        KeyError: If ``subfolder`` is not one of the classes ImageFolder found.
    """
    dset = datasets.ImageFolder(path, transform)
    target_idx = dset.class_to_idx[subfolder]

    # One linear pass instead of repeated ``del`` on a list (O(n^2)).
    kept = [sample for sample in dset.imgs if sample[1] == target_idx]
    # Slice-assign so the existing list object is mutated in place. In
    # torchvision's ImageFolder, ``imgs`` is the same list as ``samples``,
    # which ``__len__``/``__getitem__`` read. Rebinding would desync them.
    dset.imgs[:] = kept
    # torchvision versions that keep a separate ``targets`` list would
    # otherwise still report labels for the removed samples.
    if hasattr(dset, "targets"):
        dset.targets = [label for _, label in kept]

    return torch.utils.data.DataLoader(
        dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
def print_network(net):
    """Print ``net`` and then the total element count of its parameters.

    Every tensor yielded by ``net.parameters()`` is counted, regardless of
    ``requires_grad``.
    """
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def initialize_weights(net):
    """Re-initialise conv, transposed-conv, linear and 2-D batch-norm layers in place.

    - ``Conv2d`` / ``ConvTranspose2d`` / ``Linear``: weight drawn from
      N(0, 0.02), bias set to 0.
    - ``BatchNorm2d``: weight set to 1, bias set to 0.

    Any other module type is left untouched. Parameters that do not exist
    (layers built with ``bias=False``, ``BatchNorm2d(affine=False)``) are
    skipped rather than raising ``AttributeError`` on ``None``.

    Args:
        net: ``nn.Module`` whose modules (itself included) are visited via
            ``net.modules()``.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            # bias is None when the layer was created with bias=False.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was created with affine=False.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)