-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
69 lines (51 loc) · 1.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Andrew Player
September 2022
Script for training and testing a network for MSTAR SAR Target Detection.
"""
import os
import time
# Set TensorFlow's C++ log-level filter before the project imports below run.
# NOTE(review): presumably mstar_io/training import TensorFlow, which reads
# this variable at import time — confirm; if so, this line must stay above them.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from mstar_io import test_with_val_data, make_dataset
from training import train_model
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Dataset settings (consumed by make_dataset).
DATA_DIR: str = "./data"
VALIDATION_SPLIT: float = 0.1

# Training settings (passed positionally to train_model).
EPOCHS: int = 10
BATCH_SIZE: int = 1

# The current timestamp is embedded so each run gets a distinct model name.
MODEL_NAME: str = f"Testing_{time.time()}"
# Stage 1: build the dataset from the raw data directory and report on it.
print("CREATING DATASET:")
print("-----------------")

"""
Right now, the data dir should just contain the target folders:
./data
|_ BRDM_2/
|_ .../
"""
# make_dataset hands back four values; target_types is unpacked but not used
# again in this script.
dataset_info = make_dataset(DATA_DIR, validation_split=VALIDATION_SPLIT)
save_directory, count, label_dict, target_types = dataset_info

# The later stages address the dataset through this "./"-prefixed path.
training_set_dir = "./" + str(save_directory)

print(f"Dataset saved to \"{training_set_dir}\"")
print(f"Dataset contains {count} samples.")
print(f"Labels to array index: {label_dict}\n")
# Stage 2: train the network, then print the history object it returns.
print("TRAINING MODEL:")
print("--------------")

"""
Basic conv to dense model for now.
"""
# NOTE(review): the literal 128 is train_model's third positional argument;
# its meaning is defined in training.py (not visible here) — confirm.
training_args = (MODEL_NAME, training_set_dir, 128, EPOCHS, BATCH_SIZE)
history = train_model(*training_args)

print("\nMODEL HISTORY:")
print("--------------")
print(history.history)
print("")
# Stage 3: evaluate the trained model against the dataset's validation data.
print("TEST RESULT:")
print("------------")

# NOTE(review): assumes train_model saved its checkpoint under
# ./models/checkpoints/<MODEL_NAME> — confirm against training.py.
checkpoint_path = "./models/checkpoints/" + MODEL_NAME
test_with_val_data(training_set_dir, checkpoint_path)

print("")