-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
38 lines (33 loc) · 1.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Labels: real (correct) samples = 1, fake samples = 0.
# The one-hot label vector is concatenated with the input data.
import os
import pickle
import random

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import model
import hoge
from config import *
# Set to True to resume a previously interrupted run from ./param/tmp.pickle.
interrupt_flag = False  # restart interrupted training

# Directories used by this script; hoge.make_dir is a project helper
# (presumably creates any that are missing -- confirm in hoge).
required_dirs = ["param", "result", "mnist"]
hoge.make_dir(required_dirs)

# MNIST training split converted to tensors; downloaded into ./mnist/ if absent.
mnist_data = MNIST('./mnist/', train=True, download=True,
                   transform=transforms.ToTensor())
# `mini_batch_num` and `epochs` are expected to come from `config` (star import).
dataloader = DataLoader(mnist_data, batch_size=mini_batch_num, shuffle=True)
print("\n")

# Named `gan` (not `model`) so the imported `model` module is not shadowed.
if interrupt_flag:  # restart training
    # Unpickling executes arbitrary code: only safe because this checkpoint is
    # a local file presumably written by this project -- never load external data.
    with open("./param/tmp.pickle", mode="rb") as f:
        init_epoch = pickle.load(f)
    # NOTE(review): if tmp.pickle stores the last *completed* epoch, the loop
    # below re-runs that epoch; confirm against GAN.save_tmp_weight.
    gan = model.GAN(dataloader, interrupting=True)
else:
    init_epoch = 1
    gan = model.GAN(dataloader)

# Train from init_epoch through `epochs` inclusive.
for epoch in range(init_epoch, epochs + 1):
    print("Epoch[%d/%d]:" % (epoch, epochs))
    gan.study(epoch)
    gan.evaluate()
    gan.save_tmp_weight(epoch)  # presumably a per-epoch checkpoint for resuming
    gan.eval_pic(epoch)
gan.output()