-
Notifications
You must be signed in to change notification settings - Fork 0
/
attacker.py
217 lines (185 loc) · 7.01 KB
/
attacker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import abc
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F
# Converts a perturbation expressed in 0-255 pixel units into the [-1, 1]
# input range that the attackers below clip to (255 pixel levels span a
# width of 2).
IMAGE_SCALE = 2 / 255
def get_kernel(size, nsig, mode="gaussian", device="cuda:0"):
    """Build a normalized, separable 2-D smoothing kernel of shape (size, size).

    Args:
        size: Side length of the square kernel.
        nsig: For "gaussian" mode, the kernel samples exp(-x**2 / 2) at
            ``size`` evenly spaced points on [-nsig, nsig]. Ignored by
            "linear" mode.
        mode: "gaussian" or "linear" (a separable tent kernel).
        device: Device the returned kernel lives on.

    Returns:
        A (size, size) tensor whose entries sum to 1.

    Raises:
        ValueError: If ``mode`` is not recognized.
    """
    if mode == "gaussian":
        # The kernel is normalized at the end, so constant factors such as
        # 1 / (sqrt(2 * pi) * sigma) can be dropped.
        vec = torch.linspace(-nsig, nsig, steps=size).to(device)
        vec = torch.exp(-vec * vec / 2)
    elif mode == "linear":
        # The tent kernel is res[i][j] = (1 - |i|/(k+1)) * (1 - |j|/(k+1)) for
        # offsets i, j in [-k, k], with k = (size - 1) / 2. Because the result
        # is normalized, compute the proportional (k+1-|i|) * (k+1-|j|).
        # linspace yields exactly `size` offsets, so the kernel is
        # (size, size), matching "gaussian" mode.
        half = (size - 1) / 2
        offsets = torch.linspace(-half, half, steps=size)
        vec = ((half + 1) - torch.abs(offsets)).to(device)
    else:
        raise ValueError("no such mode in get_kernel.")
    # Outer product of the 1-D profile with itself gives the separable 2-D kernel.
    res = vec.view(-1, 1) @ vec.view(1, -1)
    return res / torch.sum(res)
class AttackerBase(abc.ABC):
    """Abstract interface for iterative adversarial attackers.

    Subclasses implement ``attack_init`` (produce the starting adversarial
    image and the label to attack) and ``attack_update`` (one update step
    given the gradient with respect to the adversarial image).

    Args:
        num_iter: Number of update iterations.
        num_classes: Number of classes; required by ``_create_random_target``.
        original_attack: Flag returned by ``is_grad_reusable``. In
            ``PGDAttacker`` it selects attacking the true label instead of a
            random target label.
    """

    def __init__(
        self,
        num_iter: int = 0,
        num_classes: typing.Optional[int] = None,
        original_attack: bool = False,
    ):
        self.num_iter = num_iter
        self.num_classes = num_classes
        self.original_attack = original_attack

    def is_grad_reusable(self) -> bool:
        """Whether gradients are reusable.

        Returns ``original_attack``. Presumably this is because an attack on
        the true label can share gradients with the training loss; confirm
        with the caller.
        """
        return self.original_attack

    @abc.abstractmethod
    def attack_init(
        self, image_clean: torch.Tensor, label: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Init the perturbation; return (start_adv, target_label)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def attack_update(
        self,
        image_clean: torch.Tensor,
        adv: torch.Tensor,
        adv_grad: torch.Tensor,
    ) -> torch.Tensor:
        """Update the perturbation; return the new adversarial image."""
        raise NotImplementedError()

    def _create_random_target(self, label: torch.Tensor) -> torch.Tensor:
        """Draw a random target label that differs from ``label``.

        Each target is ``(label + offset) % num_classes`` with ``offset``
        drawn uniformly from ``[1, num_classes)``. The target is therefore
        uniform over the other ``num_classes - 1`` classes and never equals
        the true label.

        Raises:
            RuntimeError: If ``num_classes`` was not provided.
            ValueError: If ``num_classes < 2`` (there is no other class to
                target).
        """
        if self.num_classes is None:
            raise RuntimeError("num_classes must be provided.")
        if self.num_classes < 2:
            raise ValueError(
                "num_classes must be at least 2 to pick a random target."
            )
        # An offset of 0 would make the "target" the true label itself, and a
        # targeted attack would then push toward the correct class.
        label_offset = torch.randint_like(
            label, low=1, high=self.num_classes
        )
        return (label + label_offset) % self.num_classes
class NoOpAttacker(AttackerBase):
    """Placeholder attacker with no implementation.

    Construction succeeds (zero iterations, no class count, flagged as an
    original attack). Both ``attack_init`` and ``attack_update`` defer to the
    abstract base methods, which raise ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__(original_attack=True, num_classes=None, num_iter=0)

    def attack_init(
        self, image_clean: torch.Tensor, label: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        # Deliberately forwards to the unimplemented base method.
        return super(NoOpAttacker, self).attack_init(image_clean, label)

    def attack_update(
        self,
        image_clean: torch.Tensor,
        adv: torch.Tensor,
        adv_grad: torch.Tensor,
    ) -> torch.Tensor:
        # Deliberately forwards to the unimplemented base method.
        return super(NoOpAttacker, self).attack_update(
            image_clean, adv, adv_grad
        )
class CopyAttacker(AttackerBase):
    """Attacker that never perturbs: the "adversarial" image is the clean one.

    This mimics AdvProp's behavior. ``attack_init`` hands back the clean image
    and the unchanged label. ``attack_update`` ignores the current adversarial
    image and its gradient and returns the clean image again.
    """

    def __init__(self):
        super().__init__(original_attack=True, num_classes=None, num_iter=0)

    def attack_init(
        self, image_clean: torch.Tensor, label: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        start = (image_clean, label)
        return start

    def attack_update(self, image_clean, adv, adv_grad):
        del adv, adv_grad  # Unused: the clean image is always returned.
        return image_clean
class PGDAttacker(AttackerBase):
    """Projected Gradient Descent (PGD) attacker.

    Reference: https://arxiv.org/pdf/1706.06083.pdf

    ``epsilon`` and ``step_size`` are given in 0-255 pixel units and are
    converted with ``IMAGE_SCALE``. After every step the adversarial image is
    clipped to ``[-1, 1]`` and to an L-infinity ball of radius ``epsilon``
    around the clean image.

    Args:
        num_iter: Number of PGD iterations.
        epsilon: L-infinity perturbation budget (0-255 pixel units).
        step_size: Per-iteration step (0-255 pixel units). When
            ``num_iter > 0`` it is raised to at least ``epsilon / num_iter``.
        start_from_clean: If True, always start from the clean image.
            Otherwise start from uniform noise in the epsilon ball, gated by
            ``prob_start_from_clean``.
        kernel_size: Side of the Gaussian smoothing kernel used when
            ``translation`` is True. Must be odd in that case.
        prob_start_from_clean: Threshold compared against a standard-normal
            sample. Noise is added only when the sample exceeds it.
        translation: If True, smooth the gradient with a depthwise Gaussian
            convolution before taking its sign. The convolution is built for
            3-channel inputs.
        device: Device for the smoothing convolution.
        num_classes: Number of classes, used to draw random target labels.
        original_attack: If True, step along +sign(grad) using the true
            label. Otherwise step along -sign(grad) toward a random other
            label (targeted).

    Raises:
        ValueError: If ``translation`` is True and ``kernel_size`` is even.
    """

    def __init__(
        self,
        num_iter,
        epsilon,
        step_size,
        start_from_clean=False,
        kernel_size=15,
        prob_start_from_clean=0.0,
        translation=False,
        device="cuda:0",
        num_classes=1000,
        original_attack=False,
    ):
        super().__init__(num_iter, num_classes, original_attack)
        # Ensure num_iter steps can reach the edge of the epsilon ball; the
        # guard avoids dividing by zero when num_iter == 0.
        if num_iter > 0:
            step_size = max(step_size, epsilon / num_iter)
        self.epsilon = epsilon * IMAGE_SCALE
        self.step_size = step_size * IMAGE_SCALE
        self.start_from_clean = start_from_clean
        self.prob_start_from_clean = prob_start_from_clean
        self.device = device
        self.translation = translation
        if translation:
            if kernel_size % 2 == 0:
                raise ValueError(
                    "kernel_size must be odd when translation=True."
                )
            # Depthwise convolution (groups == in_channels == out_channels):
            # each channel of the gradient is smoothed independently by the
            # same kernel. stride=1 with (k-1)//2 padding keeps the spatial
            # size unchanged, which attack_update needs in order to add the
            # smoothed gradient to adv.
            self.conv = nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=False,
                groups=3,
            ).to(self.device)
            self.gkernel = get_kernel(kernel_size, nsig=3, device=self.device)
            # Depthwise Conv2d weights have shape
            # (out_channels, in_channels // groups, k, k) = (3, 1, k, k).
            # A registered parameter can only be replaced by an nn.Parameter,
            # and the kernel is fixed, so it is frozen.
            self.conv.weight = nn.Parameter(
                self.gkernel.expand(3, 1, kernel_size, kernel_size).clone(),
                requires_grad=False,
            )

    def attack_init(self, image_clean, label):
        """Return ``(start_adv, target_label)`` for the attack.

        The target is the true label when ``original_attack`` is set, and a
        random different label otherwise.
        """
        if self.original_attack:
            target_label = label
        else:
            target_label = self._create_random_target(label)
        if self.start_from_clean:
            return image_clean.clone().detach(), target_label
        init_start = torch.empty_like(image_clean).uniform_(
            -self.epsilon, self.epsilon
        )
        # NOTE(meijieru): prob_start_from_clean is compared against a
        # standard-normal sample, so prob_start_from_clean == 1 does not
        # mean the images always start clean.
        start_from_noise_index = (
            torch.randn([]) > self.prob_start_from_clean
        ).float()
        start_adv = image_clean + start_from_noise_index * init_start
        return start_adv, target_label

    def attack_update(
        self,
        image_clean: torch.Tensor,
        adv: torch.Tensor,
        adv_grad: torch.Tensor,
    ) -> torch.Tensor:
        """Take one signed-gradient step and project back into the valid set."""
        lower_bound = torch.clamp(image_clean - self.epsilon, min=-1.0, max=1.0)
        upper_bound = torch.clamp(image_clean + self.epsilon, min=-1.0, max=1.0)
        if self.translation:
            adv_grad = self.conv(adv_grad)
        if self.original_attack:
            # Untargeted: move up the loss gradient.
            adv = adv + torch.sign(adv_grad) * self.step_size
        else:
            # Targeted: move down the loss gradient toward the target label.
            adv = adv - torch.sign(adv_grad) * self.step_size
        # Project onto [lower_bound, upper_bound] elementwise.
        adv = torch.where(adv > lower_bound, adv, lower_bound)
        adv = torch.where(adv < upper_bound, adv, upper_bound).detach()
        return adv

    def attack(
        self,
        image_clean: torch.Tensor,
        label: torch.Tensor,
        model: typing.Callable,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Run the full attack loop and return ``(adv, target_label)``.

        This method uses the standard cross-entropy loss. The original note
        states that, for that reason, it is not used by the main code.
        """
        adv, target_label = self.attack_init(image_clean, label)
        for _ in range(self.num_iter):
            adv.requires_grad_(True)
            logits = model(adv)
            losses = F.cross_entropy(logits, target_label)
            g = torch.autograd.grad(
                losses, adv, retain_graph=False, create_graph=False
            )[0]
            adv = self.attack_update(image_clean, adv, g)
        return adv, target_label