# cotrain.py
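"""Co-training driver script.

Restores one or more pre-trained input models, has them label an unlabeled
pool over several "eras", and trains a final model on the labeled plus
selected pseudo-labeled data. (This summary is inferred from the code below.)
"""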
import argparse
import json
import os

from cox.store import Store
from tensorboardX import SummaryWriter

from datasets.co_training import CoTrainDatasetTrainer
from datasets.datasets import COPRIOR_DATASETS
from models.model_utils import make_and_restore_model, load_model_from_checkpoint_dir
from utils.logging import log_images_hook

def main(args):
    store = Store(args.out_dir)

    # LOG RUN METADATA
    metadata_keys = {'dataset': str,
                     'use_val': bool,
                     'input_dirs': str,
                     'fraction': float,
                     'epochs_per_era': int,
                     'eras': int,
                     'strategy': str,
                     'pure': bool,
                     'arch': str,
                     'epochs': int,
                     'lr': float,
                     'step_lr': int,
                     'step_lr_gamma': float,
                     'additional_transform': str,
                     'spurious': str,
                     'use_gt': bool}
    store.add_table('metadata', metadata_keys)
    args_dict = args.__dict__
    store['metadata'].append_row({k: args_dict[k] for k in metadata_keys})

    # MAKE RESULT TABLES (one per split, plus per-era co-training tables)
    for mode in ['train', 'val']:
        store.add_table(mode,
                        {'loss': float,
                         'acc': float,
                         'epoch': int})
        for i in range(args.eras):  # hacky but backward compatible
            store.add_table(f'{i}_cotrain_{mode}',
                            {'loss': float,
                             'acc': float,
                             'epoch': int})
    store.add_table('unlabeled',
                    {'model': int,
                     'era': int,
                     'acc': float,
                     'selected_acc': float,
                     'num_points': int})
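    # Reading results back after a run (a sketch, not executed here): cox
    # tables are pandas DataFrames and can be reopened via the run's exp_id,
    # e.g. (the exp_id value below is hypothetical):
    #   s = Store(args.out_dir, exp_id='<run-id>')
    #   print(s['val'].df)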
    # MAKE DATASET AND LOADERS
    dataset_name = args.dataset
    data_path = args.data_path
    ds_class = COPRIOR_DATASETS[dataset_name](data_path)
    classes = ds_class.CLASS_NAMES
    train_ds = ds_class.get_dataset('train')
    val_ds = ds_class.get_dataset('val' if args.use_val else 'test')
    unlabelled_ds = ds_class.get_dataset('unlabeled')

    # LOAD INPUT MODELS (one per --input-dirs entry, used for co-training)
    input_models = []
    input_model_args = []
    for input_dir in args.input_dirs:
        model, model_args, _ = load_model_from_checkpoint_dir(input_dir, ds_class, args.out_dir)
        input_models.append(model)
        input_model_args.append(model_args)

    # MAKE FINAL MODEL
    model, model_args, checkpoint = make_and_restore_model(arch_name=args.arch, ds_class=ds_class,
                                                           resume_path=args.resume_path, train_args=args,
                                                           additional_transform=args.additional_transform,
                                                           out_dir=args.out_dir)

    writer = SummaryWriter(args.out_dir)
    cotrainer = CoTrainDatasetTrainer(train_dataset=train_ds,
                                      unlabelled_dataset=unlabelled_ds,
                                      val_dataset=val_ds, writer=writer,
                                      store=store, log_hook=log_images_hook,
                                      start_fraction=args.fraction,
                                      out_dir=args.out_dir,
                                      strategy=args.strategy,
                                      pure=args.pure,
                                      spurious=args.spurious,
                                      use_gt=args.use_gt,
                                      num_classes=len(classes))
    cotrainer.co_train_models(input_model_args=input_model_args,
                              input_models=input_models,
                              co_training_eras=args.eras,
                              co_training_epochs_per_era=args.epochs_per_era,
                              final_model=model,
                              final_model_args=model_args,
                              val_iters=25, checkpoint_iters=50)
    return model

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, help='name of dataset')
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--out-dir', type=str, help='path to dump output')
    parser.add_argument('--use_val', action='store_true', default=False,
                        help='use a training fold for validation')
    parser.add_argument('--input-dirs', type=str, action='append',
                        help='checkpoint directory of an input model; repeat once per model')
    parser.add_argument('--fraction', type=float, default=0.1,
                        help='starting fraction for co-training')
    parser.add_argument('--epochs_per_era', type=int, default=300,
                        help='training epochs per co-training era')
    parser.add_argument('--eras', type=int, default=5,
                        help='number of co-training eras')
    parser.add_argument('--strategy', type=str, default='STANDARD_CONSTANT',
                        help='co-training strategy')
    parser.add_argument('--pure', action='store_true', help='use pure cotraining')
    # final model arguments
    parser.add_argument('--arch', type=str, help='name of model architecture')
    parser.add_argument('--resume_path', type=str, default=None, help='path to load a previous checkpoint')
    parser.add_argument('--epochs', type=int, default=500, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--step_lr', type=int, default=50, help='epochs before LR drop')
    parser.add_argument('--step_lr_gamma', type=float, default=0.1, help='LR drop multiplier')
    parser.add_argument('--additional-transform', type=str, default='NONE', help='type of additional transform')
    parser.add_argument('--spurious', type=str, default=None,
                        help='add a spurious correlation to the '
                             'training dataset')
    parser.add_argument('--use-gt', action='store_true', help='use gt labels')
    return parser.parse_args()

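# Example invocation (a sketch; the dataset name, paths, and architecture are
# placeholders, and --input-dirs is repeated once per input model):
#
#   python cotrain.py --dataset <dataset> --data-path /path/to/data \
#       --out-dir outputs/cotrain \
#       --input-dirs runs/model_a --input-dirs runs/model_b \
#       --arch resnet18 --eras 5 --epochs_per_era 300
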
if __name__ == "__main__":
    args = parse_args()
    # Optionally override parsed arguments from a job_parameters.json file.
    if os.path.exists('job_parameters.json'):
        with open('job_parameters.json') as f:
            job_params = json.load(f)
        for k, v in job_params.items():
            assert hasattr(args, k), f'unknown argument: {k}'  # catch typos
            setattr(args, k, v)
    os.makedirs(args.out_dir, exist_ok=True)
    print(args.__dict__)
    # Record the final configuration next to the run's outputs.
    with open(os.path.join(args.out_dir, 'env.json'), 'w') as f:
        json.dump(args.__dict__, f)
    main(args)
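
# Example job_parameters.json (hypothetical values; each key must match an
# existing argument attribute, which the assert above enforces):
#
#   {
#       "dataset": "<dataset>",
#       "fraction": 0.2,
#       "eras": 3
#   }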