-
Notifications
You must be signed in to change notification settings - Fork 74
/
train.py
324 lines (294 loc) · 14.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
A simple script to train an LSTM or Transformer with certified defense using the auto_LiRPA library.
"""
import argparse
import logging
import os
import pickle
import random

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput
from auto_LiRPA.utils import MultiAverageMeter, AverageMeter, logger, scale_gradients
from auto_LiRPA.eps_scheduler import *

from Transformer.Transformer import Transformer
from lstm import LSTM
from data_utils import load_data, clean_data, get_batches
from oracle import oracle
def _build_parser():
    """Create the command-line parser for this script.

    Options are declared as ``(flag, add_argument keyword arguments)`` pairs
    and registered in the order listed, so ``--help`` shows them in the same
    order as they appear here.
    """
    option_specs = (
        # Run mode.
        ('--train', dict(action='store_true')),
        ('--robust', dict(action='store_true')),
        ('--oracle', dict(action='store_true')),
        # Paths, data and runtime environment.
        ('--dir', dict(type=str, default='model')),
        ('--checkpoint', dict(type=int, default=None)),
        ('--data', dict(type=str, default='sst', choices=['sst'])),
        ('--seed', dict(type=int, default=0)),
        ('--device', dict(type=str, default='cuda', choices=['cuda', 'cpu'])),
        ('--load', dict(type=str, default=None)),
        ('--legacy_loading', dict(
            action='store_true',
            help='use a deprecated way of loading checkpoints for previously saved models')),
        ('--auto_test', dict(action='store_true')),
        # Perturbation and bound computation.
        ('--eps', dict(type=float, default=1.0)),
        ('--budget', dict(type=int, default=6)),
        ('--method', dict(
            type=str, default=None,
            choices=['IBP', 'IBP+backward', 'IBP+backward_train', 'forward', 'forward+backward'])),
        ('--model', dict(type=str, default='transformer', choices=['transformer', 'lstm'])),
        # Training schedule and batching.
        ('--num_epochs', dict(type=int, default=25)),
        ('--num_epochs_all_nodes', dict(type=int, default=20)),
        ('--eps_start', dict(type=int, default=1)),
        ('--eps_length', dict(type=int, default=10)),
        ('--log_interval', dict(type=int, default=100)),
        ('--min_word_freq', dict(type=int, default=2)),
        ('--batch_size', dict(type=int, default=32)),
        ('--oracle_batch_size', dict(type=int, default=1024)),
        ('--gradient_accumulation_steps', dict(type=int, default=1)),
        ('--max_sent_length', dict(type=int, default=32)),
        ('--vocab_size', dict(type=int, default=50000)),
        ('--lr', dict(type=float, default=1e-4)),
        ('--lr_decay', dict(type=float, default=1)),
        ('--grad_clip', dict(type=float, default=10.0)),
        # Model architecture.
        ('--num_classes', dict(type=int, default=2)),
        ('--num_layers', dict(type=int, default=1)),
        ('--num_attention_heads', dict(type=int, default=4)),
        ('--hidden_size', dict(type=int, default=64)),
        ('--embedding_size', dict(type=int, default=64)),
        ('--intermediate_size', dict(type=int, default=128)),
        ('--drop_unk', dict(action='store_true')),
        ('--hidden_act', dict(type=str, default='relu')),
        ('--layer_norm', dict(type=str, default='no_var', choices=['standard', 'no', 'no_var'])),
        ('--loss_fusion', dict(action='store_true')),
        ('--dropout', dict(type=float, default=0.1)),
        ('--bound_opts_relu', dict(type=str, default='zero-lb')),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
args = parser.parse_args()
# Log to TensorBoard and to a plain-text file, both under <args.dir>/log.
# The directory is created explicitly so the FileHandler below does not depend
# on SummaryWriter having created it as a side effect.
log_dir = os.path.join(args.dir, 'log')
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir, flush_secs=10)
file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log'))
file_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))
logger.addHandler(file_handler)

# Seed every random number generator used by this script.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

## Step 1: Prepare the dataset
data_train_all_nodes, data_train, data_dev, data_test = load_data(args.data)
if args.robust:
    data_dev, data_test = clean_data(data_dev), clean_data(data_test)
if args.auto_test:
    # Auto-test mode evaluates a fixed random subset of 10 test examples,
    # which must fit into a single batch.
    random.seed(args.seed)
    random.shuffle(data_test)
    data_test = data_test[:10]
    assert args.batch_size >= 10
logger.info('Dataset sizes: {}/{}/{}/{}'.format(
    len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test)))

## Step 2: Initialize the original model as usual, together with the dummy
## inputs that are handed to BoundedModule below
dummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device)
dummy_labels = torch.zeros(1, dtype=torch.long, device=args.device)
if args.model == 'transformer':
    dummy_mask = torch.zeros(1, 1, 1, args.max_sent_length, device=args.device)
    model = Transformer(args, data_train)
elif args.model == 'lstm':
    dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device)
    model = LSTM(args, data_train)

dev_batches = get_batches(data_dev, args.batch_size)
test_batches = get_batches(data_test, args.batch_size)

## Step 3: Define the perturbation range; here we use a synonym replacement
## perturbation constrained by args.budget
ptb = PerturbationSynonym(budget=args.budget)
dummy_embeddings = BoundedTensor(dummy_embeddings, ptb)

## Step 4: Wrap the model with auto_LiRPA
model_ori = model.model_from_embeddings
bound_opts = {'relu': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True}
if isinstance(model_ori, BoundedModule):
    model_bound = model_ori
else:
    model_bound = BoundedModule(
        model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device)
model.model_from_embeddings = model_bound
if args.loss_fusion:
    # NOTE: bound_opts is mutated here and reused by train() when it rebuilds
    # the bounded model after each training epoch.
    bound_opts['loss_fusion'] = True
    # dummy_labels lives on args.device like the other dummy inputs.
    model_loss = BoundedModule(
        CrossEntropyWrapperMultiInput(model_ori),
        (dummy_labels, dummy_embeddings, dummy_mask),
        bound_opts=bound_opts, device=args.device)

ptb.model = model
optimizer = model.build_optimizer()
if args.lr_decay < 1:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=args.lr_decay)
else:
    lr_scheduler = None
if args.robust:
    eps_scheduler = LinearScheduler(args.eps, 'start={},length={}'.format(args.eps_start, args.eps_length))
    # Step the schedule once per value of model.checkpoint (presumably the
    # number of epochs already trained -- main() resumes from that index).
    for _ in range(model.checkpoint):
        eps_scheduler.step_epoch(verbose=False)
else:
    eps_scheduler = None
logger.info('Model converted to support bounds')
def step(model, ptb, batch, eps=1.0, train=False):
    """Run the model on one batch and compute standard and robust metrics.

    When ``train`` is true (or ``args.auto_test`` is set) the robust loss is
    back-propagated and the gradient w.r.t. the word embedding table is
    accumulated into ``model.word_embeddings.weight.grad``. The optimizer is
    not stepped here. In auto-test mode the results and the embedding
    gradient are also pickled to ``res_test.pkl``.

    Args:
        model: project model exposing ``model_from_embeddings``,
            ``get_input(batch)`` and ``word_embeddings``.
        ptb: perturbation object; its eps and train mode are set here.
        batch: one batch, passed through to ``model.get_input``.
        eps: perturbation magnitude for this step.
        train: whether this is a training step.

    Returns:
        Tuple ``(acc, loss, acc_robust, loss_robust)``. With loss fusion
        during robust training, ``acc``, ``loss`` and ``acc_robust`` are the
        placeholder ``-1``. Without robustness the robust values equal the
        standard ones.

    Raises:
        NotImplementedError: if ``args.method`` is not supported by the
            branch that is taken.
    """
    model_bound = model.model_from_embeddings
    if train:
        model.train()
        model_bound.train()
        if args.loss_fusion:
            model_loss.train()
    else:
        model.eval()
        model_bound.eval()

    # Gradients are needed for training and for the auto-test gradient dump.
    needs_grad = train or args.auto_test
    grad = torch.enable_grad() if needs_grad else torch.no_grad()

    with grad:
        ptb.set_eps(eps)
        ptb.set_train(train)
        embeddings_unbounded, mask, tokens, labels = model.get_input(batch)
        aux = (tokens, batch)
        if args.robust and eps > 1e-9:
            embeddings = BoundedTensor(embeddings_unbounded, ptb)
        else:
            # Detached leaf so that embeddings.grad is populated by backward().
            embeddings = embeddings_unbounded.detach().requires_grad_(True)

        # NOTE(review): this threshold (1e-6) differs from the one used above
        # for wrapping the embeddings (1e-9); kept as in the original.
        robust = args.robust and eps > 1e-6

        if train and robust and args.loss_fusion:
            # loss_fusion loss: the robust loss comes from the upper bound of
            # the fused cross-entropy wrapper.
            if args.method == 'IBP+backward_train':
                _, ub = model_loss.compute_bounds(
                    x=(labels, embeddings, mask), aux=aux,
                    C=None, method='IBP+backward', bound_lower=False)
            else:
                raise NotImplementedError
            loss_robust = torch.log(ub).mean()
            loss = acc = acc_robust = -1  # unknown
        else:
            # regular loss
            logits = model_bound(embeddings, mask)
            loss = CrossEntropyLoss()(logits, labels)
            acc = (torch.argmax(logits, dim=1) == labels).float().mean()

            if robust:
                # generate specifications: one row per (true label, other label) pair
                num_class = args.num_classes
                eye = torch.eye(num_class).type_as(embeddings)
                c = eye[labels].unsqueeze(1) - eye.unsqueeze(0)
                # remove specifications to self
                not_self = labels.data.unsqueeze(1) != \
                    torch.arange(num_class).type_as(labels.data).unsqueeze(0)
                c = c[not_self].view(embeddings.size(0), num_class - 1, num_class)
                if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
                    lb, _ = model_bound.compute_bounds(
                        aux=aux, C=c, method=args.method, bound_upper=False)
                elif args.method == 'IBP+backward_train':
                    # CROWN-IBP
                    if 1 - eps > 1e-4:
                        lb, _ = model_bound.compute_bounds(
                            aux=aux, C=c, method='IBP+backward', bound_upper=False)
                        ilb, _ = model_bound.compute_bounds(
                            aux=aux, C=c, method='IBP', reuse_ibp=True)
                        # we use a mixed IBP and CROWN-IBP bounds, leading to better
                        # performance (Zhang et al., ICLR 2020)
                        lb = eps * ilb + (1 - eps) * lb
                    else:
                        lb, _ = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
                else:
                    raise NotImplementedError
                # Pad zero at the beginning for each example, and use fake label "0"
                # for all examples because the margins have already been calculated
                # by specifications
                lb_padded = torch.cat(
                    (torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                    dim=1)
                fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
                loss_robust = CrossEntropyLoss()(-lb_padded, fake_labels)
                # An example counts as robust only if every margin lower bound is >= 0.
                acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float())
            else:
                acc_robust, loss_robust = acc, loss

    if needs_grad:
        loss_robust.backward()
        # Propagate the gradient on the embeddings back to the embedding table.
        grad_embed = torch.autograd.grad(
            embeddings_unbounded, model.word_embeddings.weight,
            grad_outputs=embeddings.grad)[0]
        if model.word_embeddings.weight.grad is None:
            model.word_embeddings.weight.grad = grad_embed
        else:
            model.word_embeddings.weight.grad += grad_embed

    if args.auto_test:
        with open('res_test.pkl', 'wb') as file:
            # .cpu() is required before .numpy() when running on a CUDA device.
            pickle.dump((
                float(acc), float(loss), float(acc_robust), float(loss_robust),
                grad_embed.detach().cpu().numpy()), file)

    return acc, loss, acc_robust, loss_robust
def train(epoch, batches, type):
    """Run one pass over ``batches`` and log the averaged metrics.

    Despite its name this function is also used for evaluation: parameters
    are only updated when ``type == 'train'``.

    Args:
        epoch: epoch number used in log messages, TensorBoard tags and the
            checkpoint name; callers pass ``None`` for evaluation-only runs.
        batches: sequence of batches (``len()`` is taken on it).
        type: split name; ``'train'`` enables optimization, any other value
            (callers use ``'dev'`` and ``'test'``) only evaluates.

    Returns:
        The average robust accuracy (``meter.avg('acc_rob')``) over the pass.
    """
    meter = MultiAverageMeter()
    assert optimizer is not None
    # Named is_train so the local does not shadow this function's own name.
    is_train = type == 'train'
    num_batches = len(batches)

    if args.robust:
        # epsilon grows dynamically according to the scheduler
        eps_scheduler.set_epoch_length(num_batches)
        if is_train:
            eps_scheduler.train()
            eps_scheduler.step_epoch()
        else:
            eps_scheduler.eval()

    for i, batch in enumerate(batches):
        if args.robust:
            eps_scheduler.step_batch()
            eps = eps_scheduler.get_eps()
        else:
            eps = 0
        acc, loss, acc_robust, loss_robust = \
            step(model, ptb, batch, eps=eps, train=is_train)
        meter.update('acc', acc, len(batch))
        meter.update('loss', loss, len(batch))
        meter.update('acc_rob', acc_robust, len(batch))
        meter.update('loss_rob', loss_robust, len(batch))

        if is_train:
            # Apply the accumulated gradients every gradient_accumulation_steps
            # batches, and on the last batch so no gradient is left unapplied.
            if (i + 1) % args.gradient_accumulation_steps == 0 or (i + 1) == num_batches:
                scale_gradients(optimizer, i % args.gradient_accumulation_steps + 1, args.grad_clip)
                optimizer.step()
                optimizer.zero_grad()
            if lr_scheduler is not None:
                lr_scheduler.step()
            writer.add_scalar('loss_train_{}'.format(epoch), meter.avg('loss'), i + 1)
            writer.add_scalar('loss_robust_train_{}'.format(epoch), meter.avg('loss_rob'), i + 1)
            writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'), i + 1)
            writer.add_scalar('acc_robust_train_{}'.format(epoch), meter.avg('acc_rob'), i + 1)

        if (i + 1) % args.log_interval == 0 or (i + 1) == num_batches:
            logger.info('Epoch {}, {} step {}/{}: eps {:.5f}, {}'.format(
                epoch, type, i + 1, num_batches, eps, meter))
            if lr_scheduler is not None:
                # get_last_lr() is the public accessor; get_lr() is meant to be
                # called only from within the scheduler's own step().
                logger.info('lr {}'.format(lr_scheduler.get_last_lr()))

    writer.add_scalar('loss/{}'.format(type), meter.avg('loss'), epoch)
    writer.add_scalar('loss_robust/{}'.format(type), meter.avg('loss_rob'), epoch)
    writer.add_scalar('acc/{}'.format(type), meter.avg('acc'), epoch)
    writer.add_scalar('acc_robust/{}'.format(type), meter.avg('acc_rob'), epoch)

    if is_train:
        if args.loss_fusion:
            # Copy the weights from the loss-fusion wrapper back into the
            # original model (dropping the wrapper's 'model.' key prefix) and
            # rebuild the bounded model from it.
            prefix = 'model.'
            state_dict = {}
            for name, value in model_loss.state_dict().items():
                assert name.startswith(prefix)
                state_dict[name[len(prefix):]] = value
            model_ori.load_state_dict(state_dict)
            model_bound = BoundedModule(
                model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device)
            model.model_from_embeddings = model_bound
        model.save(epoch)

    return meter.avg('acc_rob')
def main():
    """Dispatch on the command-line mode.

    * ``--train``: train from ``model.checkpoint`` up to ``args.num_epochs``,
      evaluating on the dev and test batches after every epoch.
    * ``--oracle``: hand the test data to ``oracle``.
    * otherwise: evaluate on the test batches; with ``--robust`` this is
      repeated for every budget from 1 to ``args.budget`` and the robust
      accuracies are logged.
    """
    if args.train:
        for finished_epochs in range(model.checkpoint, args.num_epochs):
            epoch = finished_epochs + 1
            # The first num_epochs_all_nodes epochs use the "all nodes" training set.
            if epoch <= args.num_epochs_all_nodes:
                train_data = data_train_all_nodes
            else:
                train_data = data_train
            train(epoch, get_batches(train_data, args.batch_size), 'train')
            train(epoch, dev_batches, 'dev')
            train(epoch, test_batches, 'test')
        return

    if args.oracle:
        oracle(args, model, ptb, data_test, 'test')
        return

    if not args.robust:
        train(None, test_batches, 'test')
        return

    # Robust verification: step the eps schedule once per configured epoch
    # before evaluating.
    for _ in range(args.num_epochs):
        eps_scheduler.step_epoch(verbose=False)
    res = []
    for budget in range(1, args.budget + 1):
        logger.info('budget {}'.format(budget))
        ptb.budget = budget
        res.append(train(None, test_batches, 'test'))
    logger.info('Verification results:')
    for budget, acc_rob in enumerate(res, start=1):
        logger.info('budget {} acc_rob {:.3f}'.format(budget, acc_rob))
    logger.info(res)


if __name__ == '__main__':
    main()