-
Notifications
You must be signed in to change notification settings - Fork 94
/
Main.py
86 lines (57 loc) · 2.66 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import numpy as np
import time
import sys
from ChexnetTrainer import ChexnetTrainer
#--------------------------------------------------------------------------------
def main():
    """Script entry point: evaluate the stored checkpoint on the test split.

    To train a new model instead, replace the call below with ``runTrain()``.
    """
    runTest()
#--------------------------------------------------------------------------------
def runTrain():
    """Train a network via ChexnetTrainer, then evaluate the resulting checkpoint.

    All configuration is hard-coded below. A single launch tag of the form
    ``DDMMYYYY-HHMMSS`` identifies this run; after training, the model at
    ``'m-' + timestampLaunch + '.pth.tar'`` is passed to ``ChexnetTrainer.test``.
    """
    DENSENET121 = 'DENSE-NET-121'
    # Other architecture names this script previously defined (unused here):
    # 'DENSE-NET-169', 'DENSE-NET-201' — presumably also accepted by
    # ChexnetTrainer; confirm before switching nnArchitecture.

    # Take ONE clock reading for the launch tag. Formatting the date and the
    # time with two separate strftime() calls can straddle a second/midnight
    # boundary and produce a tag matching neither instant; this tag names the
    # checkpoint reloaded for testing below, so it must be self-consistent.
    timestampLaunch = time.strftime("%d%m%Y-%H%M%S")

    #---- Path to the directory with images
    pathDirData = './database'

    #---- Paths to the files with training, validation and testing sets.
    #---- Each file should contain pairs [path to image, output vector]
    #---- Example: images_011/00027736_001.png 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    pathFileTrain = './dataset/train_1.txt'
    pathFileVal = './dataset/val_1.txt'
    pathFileTest = './dataset/test_1.txt'

    #---- Neural network parameters: type of the network, is it pre-trained
    #---- on imagenet, number of classes
    nnArchitecture = DENSENET121
    nnIsTrained = True
    nnClassCount = 14

    #---- Training settings: batch size, maximum number of epochs
    trBatchSize = 16
    trMaxEpoch = 100

    #---- Parameters related to image transforms: size of the down-scaled image, cropped image
    imgtransResize = 256
    imgtransCrop = 224

    # NOTE(review): assumes ChexnetTrainer.train saves its checkpoint under
    # exactly this name — confirm against ChexnetTrainer.
    pathModel = 'm-' + timestampLaunch + '.pth.tar'

    print('Training NN architecture = ', nnArchitecture)
    # Final argument None: presumably "no checkpoint to resume from" — confirm.
    ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, imgtransResize, imgtransCrop, timestampLaunch, None)

    print('Testing the trained model')
    ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, imgtransResize, imgtransCrop, timestampLaunch)
#--------------------------------------------------------------------------------
def runTest():
    """Evaluate a previously saved checkpoint on the test split via ChexnetTrainer.test.

    Uses a DenseNet-121 configuration with 14 classes, batch size 16, and
    images resized to 256 then cropped to 224. The launch tag is left empty
    because this is a standalone evaluation rather than part of a training run.
    """
    dataRoot = './database'
    testList = './dataset/test_1.txt'
    checkpoint = './models/m-25012018-123527.pth.tar'

    architecture = 'DENSE-NET-121'
    imagenetPretrained = True
    classCount = 14

    batchSize = 16
    resizeTo, cropTo = 256, 224

    launchTag = ''

    # Argument order must match ChexnetTrainer.test's positional signature.
    ChexnetTrainer.test(
        dataRoot,
        testList,
        checkpoint,
        architecture,
        classCount,
        imagenetPretrained,
        batchSize,
        resizeTo,
        cropTo,
        launchTag,
    )
#--------------------------------------------------------------------------------
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()