# train.py
import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataset import DataSet
# Only custom_resnet50 is used below; custom_resnet18, custom_resnet34 and
# simple are alternative architectures available in arch.py.
from arch import custom_resnet50
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, help='learning rate')
parser.add_argument('-e', type=int, help='number of training epochs')
parser.add_argument('-b', type=int, help='batch size')
parser.add_argument('--ts', type=int, default=77, help='torch manual seed')
parser.add_argument('-n', type=str, help='model name (used for the checkpoint/TensorBoard directory)')
parser.add_argument('--itv', type=int, default=0, help='resume training from this epoch (0 trains from scratch)')
args = parser.parse_args()
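# Example invocations (argument values are illustrative, not taken from the original repo):
#   python train.py --lr 1e-3 -e 50 -b 64 -n resnet50_run
#   python train.py --lr 1e-3 -e 50 -b 64 -n resnet50_run --itv 10   # resume from epoch 10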
model_name = args.n
inter_epoch = args.itv
torch.manual_seed(args.ts)
data = pd.read_csv('train.csv', header=None)
data = data.to_numpy()
labels = data[:,0]
images = data[:,1:]
train_set = DataSet(images, labels)
train_loader = DataLoader(train_set, batch_size=args.b, shuffle=True)
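# Assumed data layout: column 0 of train.csv holds the integer class label and the
# remaining columns hold the flattened image pixels; any reshaping or normalisation
# is expected to happen inside DataSet (see dataset.py).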
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# TensorBoard logs and model checkpoints both go into a directory named after the model.
writer = SummaryWriter(model_name)
if not os.path.exists(model_name):
    os.makedirs(model_name)
# Start from a fresh model, or resume from the checkpoint saved at epoch --itv.
if inter_epoch == 0:
    model = custom_resnet50().to(device)
else:
    model = torch.load(f'{model_name}/trained_{inter_epoch}.pth', map_location=device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss().to(device)
for epoch in range(inter_epoch, args.e):
    avg_loss = 0.0
    for x, y in train_loader:
        x = x.to(device)
        y = torch.flatten(y.to(device))
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()  # accumulate the scalar value, not the graph-attached tensor
    avg_loss /= len(train_loader)
    writer.add_scalar(f'loss/train{model_name}', avg_loss, epoch + 1)
    # Save a full-model checkpoint every epoch so training can be resumed with --itv.
    torch.save(model, f'./{model_name}/trained_{epoch + 1}.pth')
    print(f'Epoch[{epoch + 1}/{args.e}] train avg loss {model_name} : {avg_loss:.4f}')
writer.close()
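# To inspect the logged loss curves afterwards (assuming TensorBoard is installed):
#   tensorboard --logdir <model_name>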