# train.py (forked from tristandeleu/pytorch-maml)
import argparse
import json
import logging
import math
import os
import time

import torch
from torchmeta.utils.data import BatchMetaDataLoader

from maml.datasets import get_benchmark_by_name
from maml.metalearners import ModelAgnosticMetaLearning
from maml.utils import tensors_to_device
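# Meta-trains a model with MAML and keeps the best checkpoint by validation
# performance. Example invocation (paths and values are illustrative only):
#
#   python train.py /path/to/data --dataset doublenmnist --num-ways 5 \
#       --num-shots 1 --num-shots-test 10 --output-folder ./results \
#       --use-cuda --verbose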
parser = argparse.ArgumentParser('MAML')

# General
parser.add_argument('folder', type=str,
    help='Path to the folder the data is downloaded to.')
parser.add_argument('--dataset', type=str,
    choices=['sinusoid', 'omniglot', 'miniimagenet', 'doublenmnist',
             'doublenmnistsequence'],
    default='doublenmnist',
    help='Name of the dataset (default: doublenmnist).')
parser.add_argument('--output-folder', type=str, default=None,
    help='Path to the output folder to save the model.')
parser.add_argument('--num-ways', type=int, default=5,
    help='Number of classes per task (N in "N-way", default: 5).')
parser.add_argument('--num-shots', type=int, default=1,
    help='Number of training examples per class (k in "k-shot", default: 1).')
parser.add_argument('--num-shots-test', type=int, default=10,
    help='Number of test examples per class. If zero or negative, same as '
         'the number of training examples `--num-shots` (default: 10).')

# Model
parser.add_argument('--hidden-size', type=int, default=64,
    help='Number of channels in each convolution layer of the VGG network '
         '(default: 64).')

# Optimization
parser.add_argument('--batch-size', type=int, default=25,
    help='Number of tasks in a batch of tasks (default: 25).')
parser.add_argument('--num-steps', type=int, default=1,
    help='Number of fast adaptation steps, i.e. gradient descent '
         'updates (default: 1).')
parser.add_argument('--num-epochs', type=int, default=50,
    help='Number of epochs of meta-training (default: 50).')
parser.add_argument('--num-batches', type=int, default=100,
    help='Number of batches of tasks per epoch (default: 100).')
parser.add_argument('--step-size', type=float, default=0.1,
    help='Size of the fast adaptation step, i.e. learning rate in the '
         'gradient descent update (default: 0.1).')
parser.add_argument('--first-order', action='store_true',
    help='Use the first-order approximation, i.e. do not use higher-order '
         'derivatives during meta-optimization.')
parser.add_argument('--meta-lr', type=float, default=0.001,
    help='Learning rate for the meta-optimizer (optimization of the outer '
         'loss). The default optimizer is Adam (default: 1e-3).')

# Misc
parser.add_argument('--num-workers', type=int, default=4,
    help='Number of workers to use for data-loading (default: 4).')
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--use-cuda', action='store_true')

args = parser.parse_args()
if args.num_shots_test <= 0:
    args.num_shots_test = args.num_shots
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
device = torch.device('cuda' if args.use_cuda and torch.cuda.is_available()
                      else 'cpu')
if args.output_folder is not None:
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
        logging.debug('Creating folder `{0}`'.format(args.output_folder))

    folder = os.path.join(args.output_folder,
                          time.strftime('%Y-%m-%d_%H%M%S'))
    os.makedirs(folder)
    logging.debug('Creating folder `{0}`'.format(folder))

    args.folder = os.path.abspath(args.folder)
    args.model_path = os.path.abspath(os.path.join(folder, 'model.th'))
    # Save the configuration in a config.json file
    with open(os.path.join(folder, 'config.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)
    logging.info('Saving configuration file in `{0}`'.format(
        os.path.abspath(os.path.join(folder, 'config.json'))))
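# `get_benchmark_by_name` bundles the meta-train/val datasets, a matching
# model, and the loss function for the chosen dataset into one object.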
benchmark = get_benchmark_by_name(args.dataset,
                                  args.folder,
                                  args.num_ways,
                                  args.num_shots,
                                  args.num_shots_test,
                                  hidden_size=args.hidden_size)

meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers,
                                            pin_memory=True)
meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          pin_memory=True)
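# If the model exposes `get_trainable_parameters` (presumably to exclude
# frozen parameters), optimize that subset with Adamax; otherwise fall back
# to Adam over all parameters.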
if hasattr(benchmark.model, 'get_trainable_parameters'):
    logging.info('Using get_trainable_parameters instead of parameters '
                 'for optimization')
    meta_optimizer = torch.optim.Adamax(benchmark.model.get_trainable_parameters(),
                                        lr=args.meta_lr, betas=(0., 0.95))
else:
    meta_optimizer = torch.optim.Adam(benchmark.model.parameters(),
                                      lr=args.meta_lr)
metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                        meta_optimizer,
                                        first_order=args.first_order,
                                        num_adaptation_steps=args.num_steps,
                                        step_size=args.step_size,
                                        loss_function=benchmark.loss_function,
                                        device=device)
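# MAML: for each task, the meta-learner adapts a copy of the parameters with
# `num_adaptation_steps` inner gradient steps of size `step_size`, then
# meta-optimizes the post-adaptation (outer) loss across tasks.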
best_value = None

# Sanity-check the data pipeline: draw one batch of tasks and move it to the
# target device before starting meta-training.
sample_batch = next(iter(meta_train_dataloader))
sample_batch = tensors_to_device(sample_batch, device=device)

# Warm-start from a checkpoint in the working directory, if present.
if os.path.isfile('model.th'):
    benchmark.model.load_state_dict(torch.load('model.th',
                                               map_location=device))
# Training loop
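# Pad the epoch number to a fixed width so progress descriptions line up.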
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(args.num_epochs)))
for epoch in range(args.num_epochs):
    logging.debug('Starting epoch {0}'.format(epoch + 1))
    metalearner.train(meta_train_dataloader,
                      max_batches=args.num_batches,
                      verbose=args.verbose,
                      desc='Training',
                      leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=args.num_batches,
                                   verbose=args.verbose,
                                   desc=epoch_desc.format(epoch + 1))
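    # Model selection: prefer validation accuracy when the benchmark reports
    # it (classification); otherwise use the mean outer loss (e.g. sinusoid
    # regression), where lower is better.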
    # Save best model
    if 'accuracies_after' in results:
        if (best_value is None) or (best_value < results['accuracies_after']):
            best_value = results['accuracies_after']
            save_model = True
        else:
            save_model = False
    elif (best_value is None) or (best_value > results['mean_outer_loss']):
        best_value = results['mean_outer_loss']
        save_model = True
    else:
        save_model = False

    if save_model and (args.output_folder is not None):
        with open(args.model_path, 'wb') as f:
            torch.save(benchmark.model.state_dict(), f)
if hasattr(benchmark.meta_train_dataset, 'close'):
    benchmark.meta_train_dataset.close()
    benchmark.meta_val_dataset.close()