Rounding to nearest pixel value breaks almost all attacks #44
Hi, the first things I'd try would be 1) to run the attack without rounding and just round the final output, and 2) to exclude the rounding from the backward pass so that the gradients are computed normally. In a PGD-like attack you could round the current iterate after the projection step to ensure that it belongs to the desired image domain; in the end, rounding is just a particular projection. Let me know if this helps!
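The two suggestions can be sketched as an L∞ PGD loop in which rounding to the 8-bit grid is treated as one more projection. This is a minimal sketch, not code from this repository: `pgd_linf` and `quantize` are hypothetical names, and `eps=8/255`, `alpha=2/255` and `steps=10` are assumed settings. It assumes `model` takes images in [0, 1] and either contains no rounding or uses a straight-through rounding like the one below, so the gradients are computed normally.

```python
import torch
import torch.nn.functional as F


def quantize(x: torch.Tensor) -> torch.Tensor:
    """Project onto the 8-bit image domain: nearest multiple of 1/255 in [0, 1]."""
    return torch.round(x.clamp(0, 1) * 255) / 255


def pgd_linf(model, x, y, eps=8 / 255, alpha=2 / 255, steps=10, round_each_step=True):
    """Hypothetical L-inf PGD sketch; `model` maps [0, 1] images to logits.

    round_each_step=True  -> round the iterate after every projection step
    round_each_step=False -> attack in continuous space, round only the final output
    """
    x = x.detach()
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)  # ordinary gradient, no rounding involved
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()                     # ascent step
            x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)  # project onto the eps-ball
            x_adv = x_adv.clamp(0, 1)                               # project onto [0, 1]
            if round_each_step:
                x_adv = quantize(x_adv)                             # rounding as one more projection
        x_adv = x_adv.detach()
    return quantize(x_adv)  # the returned image always lies on the 8-bit grid
```

If the clean `x` was loaded from uint8 and `eps` is a multiple of 1/255, then `x - eps` and `x + eps` are grid points, so rounding can never push the iterate out of the eps-ball. Note also that with sign steps and an `alpha` that is itself a multiple of 1/255, the iterates stay on the grid anyway; the per-step rounding matters mainly for off-grid step sizes or non-sign (e.g. L2) updates.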
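Hi, thank you for your input. I've implemented a custom round function with a straight-through backward pass:

```python
import torch

# Assumed: `resnet18` is the poster's own ResNet-18 builder (not shown in the thread);
# the normalization constants below are the CIFAR-10 channel statistics.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CustomRound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Snap to the nearest integer in the forward pass.
        return torch.round(x)

    @staticmethod
    def backward(ctx, g):
        # Send the gradient g straight through on the backward pass
        # (rounding is treated as the identity when differentiating).
        return g


class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model.eval()
        # Registered as buffers so that model.to(device) moves them too.
        self.register_buffer("mean", torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.2470, 0.2435, 0.2616]).view(1, 3, 1, 1))
        self.round = CustomRound.apply

    def forward(self, x):
        x = x.clamp(0, 1)   # inputs are expected in [0, 1]
        x = x * 255         # rescale to the uint8 range
        x = self.round(x)   # quantize to integer pixel values (straight-through gradient)
        x = x / 255         # back to [0, 1]
        x = (x - self.mean) / self.std
        return self.model(x)


model = ModelWrapper(resnet18(pretrained=True))
model = model.to(device).eval()
```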
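The same straight-through behaviour can also be obtained without a custom `torch.autograd.Function`, using the common "detach trick": the forward value equals `round(x)`, but autograd only sees the `x` term, so the gradient is the identity. This is a minimal sketch; `round_ste` is a hypothetical helper name, not part of PyTorch or this repository.

```python
import torch


def round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass, identity gradient in the backward pass.

    The detached term (round(x) - x) shifts the value onto the integer grid
    but contributes no gradient, so d(output)/dx == 1 everywhere.
    """
    return x + (torch.round(x) - x).detach()


# Quick check: values are rounded, gradients pass straight through.
x = torch.tensor([0.2, 1.6, 2.5, 254.7], requires_grad=True)
y = round_ste(x)
y.sum().backward()
print(y.detach())  # values 0, 2, 2, 255 (torch.round rounds halves to even)
print(x.grad)      # all ones
```

Inside `ModelWrapper.forward`, `self.round(x)` could be replaced by `round_ste(x)` with the same forward output and the same gradient.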
Thanks for sharing. Did you see any difference in the robustness of the model with rounding done this way?
Images are usually stored in uint8 format, with values in [0, 255].
So when I round the pixel values of an image to the nearest integers, all attacks fail to reach the desired attack success rate.
I know that torch.round() doesn't give useful gradients to the adversary (its gradient is zero almost everywhere), hence the drop in attack success.
So how can I make sure the inputs to the model correspond to valid integer values in [0, 255] while still achieving a high attack success rate?
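One way to make sure the evaluated inputs really are valid 8-bit images is to convert the adversarial batch to uint8 and back before computing accuracy, and to check that a float batch lies on the 1/255 grid. A minimal sketch; `to_uint8`, `from_uint8` and `is_valid_8bit` are hypothetical helpers, and the `atol` tolerance is an assumed value.

```python
import torch


def to_uint8(x: torch.Tensor) -> torch.Tensor:
    """Map a float image batch in [0, 1] to uint8 pixels {0, ..., 255}."""
    return torch.round(x.clamp(0, 1) * 255).to(torch.uint8)


def from_uint8(x_u8: torch.Tensor) -> torch.Tensor:
    """Map uint8 pixels back to float32 values in [0, 1] for the model."""
    return x_u8.to(torch.float32) / 255


def is_valid_8bit(x: torch.Tensor, atol: float = 1e-4) -> bool:
    """True if every value lies in [0, 1] and on the 1/255 grid (up to atol)."""
    scaled = x * 255
    in_range = bool(x.min() >= 0) and bool(x.max() <= 1)
    return in_range and torch.allclose(scaled, torch.round(scaled), rtol=0.0, atol=atol)
```

Reporting accuracy on `from_uint8(to_uint8(x_adv))` measures exactly what the classifier would see if the adversarial image were saved as an 8-bit file.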