-
Notifications
You must be signed in to change notification settings - Fork 3
/
mobile_student_training.py
146 lines (121 loc) · 5.56 KB
/
mobile_student_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Import necessary libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import ConcatDataset
import os
from torchvision.models import resnet50
from torchvision.models import mobilenet_v2
class CustomCIFAR10CDataset(Dataset):
    """Map-style dataset over in-memory CIFAR-10-C images and labels.

    Args:
        data: indexable collection of images; each item is converted with
            ``PIL.Image.fromarray`` (presumably uint8 HxWx3 arrays as stored
            in the CIFAR-10-C ``.npy`` files -- confirm).
        labels: indexable collection of labels aligned index-by-index with
            ``data``; returned unchanged.
        transform: optional callable applied to each PIL image; skipped when
            falsy.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # Size is defined by the image collection alone.
        return len(self.data)

    def __getitem__(self, index):
        raw, target = self.data[index], self.labels[index]
        sample = Image.fromarray(raw)
        return (self.transform(sample) if self.transform else sample), target
# For reproducible runs, uncomment the seeding line below.
# torch.manual_seed(42)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy_percent)``: the mean of the per-batch losses and
        the top-1 accuracy measured on the training-mode outputs of each batch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100 * correct / total


def _evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``.

    Switches the model to eval mode and disables autograd; weights are not
    modified.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            _, predicted = model(images).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


if __name__ == '__main__':
    # Hyper-parameters
    num_epochs = 100
    # Square input resolutions for the robustness sweep. Plain ints so that
    # transforms.Resize receives Python ints rather than numpy scalars.
    inference_sizes = [20, 24, 28, 32]
    # "original" is the clean CIFAR-10 test split; every other entry is read
    # from CIFAR-10-C/<name>.npy.
    corruption_names = ["original", "brightness", "contrast", "defocus_blur",
                        "elastic_transform", "fog", "frost", "gaussian_blur",
                        "gaussian_noise", "glass_blur", "impulse_noise",
                        "jpeg_compression", "motion_blur", "pixelate",
                        "saturate", "shot_noise", "snow", "spatter",
                        "speckle_noise", "zoom_blur"]

    # Define the device, model, loss function, and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    # NOTE: mobilenet_v2() keeps its default 1000-way classifier; CIFAR-10
    # labels only ever target logits 0-9. Left unchanged so the saved
    # state_dict keeps the same layer shapes.
    model = mobilenet_v2().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    training_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root='./', train=True, transform=training_transform, download=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=200, shuffle=True, num_workers=2)

    # "Validation" uses the CIFAR-10 test split (train=False), i.e. the same
    # images reported later as "original".
    validation_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    validation_dataset = torchvision.datasets.CIFAR10(
        root='./', train=False, transform=validation_transform, download=False)
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=100, shuffle=False, num_workers=2)

    # Training process, with a validation pass after every epoch.
    for epoch in range(num_epochs):
        epoch_loss, epoch_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
        validation_accuracy = _evaluate(model, validation_loader, device)
        print(f'Validation Accuracy: {validation_accuracy:.2f}%')

    # Save immediately after training so the checkpoint survives a failure
    # (e.g. a missing CIFAR-10-C file) during the sweep below. Evaluation runs
    # under no_grad in eval mode and never changes the weights, so the file is
    # identical to one saved after the sweep.
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/mobile_student_model1" + ".pth")

    # Robustness sweep: every corruption at every inference resolution.
    for corruption in corruption_names:
        # Load each CIFAR-10-C array once per corruption, not once per size.
        if corruption == "original":
            corrupt_images = corrupt_labels = None
        else:
            corrupt_images = np.load(os.path.join('CIFAR-10-C', corruption + '.npy'))
            corrupt_labels = np.load(os.path.join('CIFAR-10-C', 'labels.npy'))

        for image_size in inference_sizes:
            test_transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            if corrupt_images is None:
                test_dataset = torchvision.datasets.CIFAR10(
                    root='./', train=False, transform=test_transform, download=False)
            else:
                test_dataset = CustomCIFAR10CDataset(
                    corrupt_images, corrupt_labels, transform=test_transform)
            test_loader = torch.utils.data.DataLoader(
                test_dataset, batch_size=100, shuffle=False, num_workers=2)
            test_accuracy = _evaluate(model, test_loader, device)
            print(corruption)
            # Image size appended so the four results per corruption can be
            # told apart; the "Test Accuracy: X%" prefix is unchanged.
            print(f'Test Accuracy: {test_accuracy:.2f}% (image size {image_size})')