-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathattack_experiment.py
76 lines (58 loc) · 2.61 KB
/
attack_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from collections import OrderedDict
from pprint import pformat
from typing import Optional, Union
import torch
from PIL import Image
from adv_lib.distances.lp_norms import l1_distances, l2_distances, linf_distances
from sacred import Experiment
from sacred.observers import FileStorageObserver
from torchvision.transforms.functional import pil_to_tensor
from attack_ingredient import attack_ingredient, get_attack
from attack_utils import run_attack
from dataset_ingredient import dataset_ingredient, get_dataset
from model_ingredient import get_model, model_ingredient
# Sacred experiment for this script, composed from the dataset, model and attack ingredients
# (their sub-configs are read back in main() through `_config`).
ex = Experiment(
    'segmentation_attack',
    ingredients=[dataset_ingredient, model_ingredient, attack_ingredient],
)
@ex.config
def config():
    # Sacred config scope: every local assigned here becomes a config entry (and a trailing
    # comment becomes its description), so do not introduce helper variables in this body.
    cpu = False  # force experiment to run on CPU
    save_adv = False  # save the attack data (incl. adversarial images); needs a file observer
    target = None  # attack target: path to a PNG label map or an int; None is passed to run_attack as-is
    cudnn_flag = 'benchmark'  # name of the torch.backends.cudnn attribute that main() sets to True
@ex.named_config
def save_adv():
    # Sacred named config that only flips `save_adv` on. Note main() still skips saving
    # unless a FileStorageObserver is attached to the run.
    save_adv = True  # save the attack data (incl. adversarial images) to the observer dir
# Distance functions handed to run_attack() as `metrics`, keyed by norm name.
# Keyword order is preserved by OrderedDict (PEP 468), giving the order linf, l1, l2.
metrics = OrderedDict(
    linf=linf_distances,
    l1=l1_distances,
    l2=l2_distances,
)
# NOTE: helpers must be defined *before* `@ex.automain`, because automain starts the
# experiment at decoration time when this file is executed as a script.
def _first_file_storage_dir(run) -> Optional[str]:
    """Return the directory of the first FileStorageObserver attached to *run*, else None."""
    for observer in run.observers:
        if isinstance(observer, FileStorageObserver):
            return observer.dir
    return None


def _load_target_mask(path: str, size: Union[int, list, tuple], label_func,
                      device: torch.device) -> torch.Tensor:
    """Load an image target from *path* and return it as a long tensor on *device*.

    Args:
        path: file path of the target image (expected to be a PNG label map).
        size: the dataset's configured size; an int means a square image, otherwise a
            2-sequence that is reversed before being given to PIL's ``resize`` (which takes
            (width, height)) -- presumably the config stores (height, width); confirm
            against dataset_ingredient.
        label_func: callable returned by ``get_dataset()``; applied to the resized PIL image
            before tensor conversion.
        device: device the returned tensor is moved to.
    """
    if not isinstance(size, (list, tuple)):
        size = (size, size)
    # The context manager releases the file handle; resize() returns an independent image.
    # Nearest-neighbour resampling presumably keeps label values from being interpolated.
    with Image.open(path) as img:
        img_target = img.resize(size=tuple(size[::-1]), resample=Image.NEAREST)
    return pil_to_tensor(label_func(img_target)).long().to(device)


@ex.automain
def main(cpu: bool,
         cudnn_flag: str,
         save_adv: bool,
         target: Optional[Union[int, str]],
         _config, _run, _log):
    """Run the configured attack on the configured model and dataset, then record the results.

    Args:
        cpu: force running on CPU even when CUDA is available.
        cudnn_flag: name of the ``torch.backends.cudnn`` attribute to set to True.
        save_adv: if True and a FileStorageObserver is attached, request adversarial images
            from ``run_attack`` and ``torch.save`` the full attack data to the observer dir.
        target: a str is treated as an image path and loaded as a target tensor; any other
            value (int or None) is forwarded to ``run_attack`` unchanged.
        _config, _run, _log: injected by Sacred.

    The image entries ('images', 'adv_images') are removed before the attack data is stored
    in ``_run.info`` and logged.
    """
    device = torch.device('cuda' if torch.cuda.is_available() and not cpu else 'cpu')
    setattr(torch.backends.cudnn, cudnn_flag, True)

    loader, label_func = get_dataset()
    model = get_model(dataset=_config['dataset']['name'], device=device)
    attack, attack_name = get_attack()

    save_dir = _first_file_storage_dir(_run)
    # Adversarial images are only requested when there is somewhere to save them.
    keep_adv = save_adv and save_dir is not None

    if isinstance(target, str):
        target = _load_target_mask(target, _config['dataset']['size'], label_func, device)

    attack_data = run_attack(model=model, loader=loader, attack=attack, target=target,
                             metrics=metrics, return_adv=keep_adv)

    if keep_adv:
        dataset_name = _config['dataset']['name']
        model_name = _config['model']['name']
        torch.save(attack_data,
                   os.path.join(save_dir, f'attack_data_{dataset_name}_{model_name}_{attack_name}.pt'))

    # Image tensors are kept out of the run info / log; each key is removed independently so a
    # partial result cannot raise KeyError.
    for key in ('images', 'adv_images'):
        if key in attack_data:
            del attack_data[key]
    _run.info = attack_data
    _log.info(pformat(attack_data))