Commit v2.11.0

Harry24k committed Dec 12, 2020
1 parent 3613a1a commit 82d5b82
Showing 15 changed files with 985 additions and 190 deletions.
7 changes: 3 additions & 4 deletions .gitignore
@@ -1,6 +1,9 @@
build/*
dist/*
robustbench/*
torchattacks.egg-info/*
data/*
models/*
demos/data/MNIST/*
demos/data/cifar*
*/.*
@@ -9,8 +12,4 @@ _build
MENIFEST.in
setup.cfg
setup.py
data/
.ipynb_checkpoints/Untitled-checkpoint.ipynb
Untitled.ipynb
_commit.bat
_PGD BIM.pptx
62 changes: 60 additions & 2 deletions README.md
@@ -43,8 +43,28 @@

```python
import torchattacks

# Untargeted (Default)
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
adversarial_images = atk(images, labels)

# Targeted (User-Defined)
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
target_map_function = lambda images, labels: labels.fill_(300)
atk.set_attack_mode("targeted", target_map_function=target_map_function)
adversarial_images = atk(images, labels)

# Targeted (Least Likely)
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
atk.set_attack_mode("least_likely")
adversarial_images = atk(images, labels)

# Type of Return
atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
atk.set_return_type('int')

# Save adversarial images and show accuracy
atk.save(data_loader=test_loader, save_path="./data/cifar10_pgd.pt", verbose=True)
```
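
A minimal sketch of loading the saved file back for evaluation, assuming `save` stores an `(adversarial images, labels)` tuple via `torch.save` (an assumption to verify against the installed version):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: atk.save stores a (adv_images, labels) tuple via torch.save.
adv_images, adv_labels = torch.load("./data/cifar10_pgd.pt")

# With set_return_type('int'), images are saved as integers in [0, 255],
# so scale them back to floats in [0, 1] before feeding them to a model.
adv_dataset = TensorDataset(adv_images.float() / 255, adv_labels)
adv_loader = DataLoader(adv_dataset, batch_size=128, shuffle=False)
```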


@@ -85,7 +105,7 @@ The distance measure in parentheses.
- PGD (Linf)

* **Boosting Adversarial Attacks with Momentum (Oct 2017)**: [Paper](https://arxiv.org/abs/1710.06081)
* MIFGSM (Linf) - :heart_eyes: Contributor [zhuangzi926](https://github.com/zhuangzi926)
* MIFGSM (Linf) - :heart_eyes: Contributor [zhuangzi926](https://github.com/zhuangzi926), [huitailangyz](https://github.com/huitailangyz)

* **Theoretically Principled Trade-off between Robustness and Accuracy (Jan 2019)**: [Paper](https://arxiv.org/abs/1901.08573)
- TPGD (Linf)
@@ -97,7 +117,45 @@ The distance measure in parentheses.
- FFGSM (Linf)


## Performance Comparison

All experiments were run on a GeForce RTX 2080.

For a fair comparison, [Robustbench](https://github.com/RobustBench/robustbench) is used.

As comparison baselines, the most actively updated and most highly cited packages were selected (Google Scholar citation counts linked):

* **Foolbox**: [178](https://scholar.google.com/scholar?q=Foolbox%3A%20A%20Python%20toolbox%20to%20benchmark%20the%20robustness%20of%20machine%20learning%20models.%20arXiv%202018)

* **ART**: [102](https://scholar.google.com/scholar?cluster=5391305326811305758&hl=ko&as_sdt=0,5&sciodt=0,5)

For other methods, please refer to each project's GitHub page listed in [Recommended Sites and Packages](#Recommended-Sites-and-Packages).

The code is here ([code](https://github.com/Harry24k/adversarial-attacks-pytorch/blob/master/demos/Performance%20Comparison%20(CIFAR10).ipynb), [nbviewer](https://nbviewer.jupyter.org/github/Harry24k/adversarial-attacks-pytorch/blob/master/demos/Performance%20Comparison%20(CIFAR10).ipynb)).

Accuracy and elapsed time on the first 50 images of CIFAR10. For L2 attacks, the average L2 distance between the adversarial and the original images is also reported.

| Attack | Package | Wong2020Fast | Rice2020Overfitting | Carmon2019Unlabeled | Remark |
| ------------- | ------------ | ------------------------ | ------------------------ | ------------------------- | -------------------------------------- |
| FGSM (Linf) | torchattacks | 48% (15 ms) | 62% (88 ms) | 68% (11 ms) | |
| | foolbox | 48% (15 ms) | 62% (55 ms) | 68% (24 ms) | |
| | ART | 48% (64 ms) | 62% (750 ms) | 68% (223 ms) | |
| BIM (Linf) | torchattacks | 46% (83 ms) | 58% (671 ms) | 64% (119 ms) | |
| | foolbox | 46% (80 ms) | 58% (1169 ms) | 64% (256 ms) | |
| | ART | 46% (248 ms) | 58% (2571 ms) | 64% (760 ms) | |
| PGD (Linf) | torchattacks | 46% (64 ms) | 58% (593 ms) | 64% (95 ms) | |
| | foolbox | 46% (70 ms) | 58% (1177 ms) | 64% (264 ms) | |
| | ART | 46% (243 ms) | 58% (2569 ms) | 64% (759 ms) | |
| CW (L2) | torchattacks | 14% / 0.00016 (4361 ms) | 22% / 0.00013 (4361 ms) | 26% / 8.5e-05 (13052 ms) | Different Results |
| | foolbox | 32% / 0.00016 (4564 ms) | 34% / 0.00017 (4361 ms) | 32% / 0.00016 (13332 ms) | |
| | ART | 32% / 0.00016 (72684 ms) | 34% / 0.00017 (4361 ms) | 32% / 0.00016 (206290 ms) | Slower than others |
| DeepFool (L2) | torchattacks | 20% / 0.00063 (12942 ms) | 14% / 0.00094 (46856 ms) | 10% / 0.0021 (14232 ms) | Different Results / Slower than others |
| | foolbox | 40% / 0.00018 (1959 ms) | 36% / 0.00019 (20410 ms) | 46% / 0.00021 (5936 ms) | |
| | ART | 40% / 0.00018 (2193 ms) | 36% / 0.00019 (19941 ms) | 46% / 0.00021 (5905 ms) | |

* **Note**:
  * In torchattacks, there is no binary search algorithm for the constant `c` in CW; it will be added in the future. Until then, MultiAttack with several fixed values of `c` is recommended (see the sketch below).
  * In torchattacks, DeepFool takes longer than the other packages. Although it produces stronger adversarial examples, please use another package until this is fixed.
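
A minimal, hedged sketch of that workaround, assuming `MultiAttack` accepts a list of attack instances built on the same model and keeps the first successful adversarial example per image (verify against the installed version's API):

```python
import torchattacks

# Sketch only: emulate a search over the CW constant c with a few fixed values,
# since torchattacks does not binary-search c internally.
atk_weak = torchattacks.CW(model, c=0.1, kappa=0, steps=1000, lr=0.01)
atk_strong = torchattacks.CW(model, c=1.0, kappa=0, steps=1000, lr=0.01)
atk = torchattacks.MultiAttack([atk_weak, atk_strong])
adv_images = atk(images, labels)
```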

## Documentation

@@ -109,7 +167,7 @@ Here is a [documentation](https://adversarial-attacks-pytorch.readthedocs.io/en/

### :mag_right: Update Records

Here is [update records](Update%20Records.md) of this package.
Here are the [update records](update_records.md) of this package.



168 changes: 168 additions & 0 deletions contributions.md
@@ -0,0 +1,168 @@
# Guidelines for Contributors

First of all, thank you for participating in this project :heart_eyes:.
The goal of this project is a _paper-oriented_ and _easy_ implementation that is understandable even to those who are new to adversarial attacks.
To that end, here are some guidelines to help contributors create new attacks; please refer to them.

## Add a New Attack

### Fork the repo

Please use git or GitHub Desktop to fork torchattacks.

### Class Inheritance

Torchattacks uses `torchattacks.Attack`.

`Attack` is the base class of all attacks in torchattacks.

In this class, there are functions to control the training mode of the model, the untargeted/targeted mode of the attack, and the return type of adversarial images, as well as a method to save adversarial images, as shown below:

```python
class Attack(object):
    #~~~~~#
    def __init__(self, name, model):
        r"""
        Initializes internal attack state.

        Arguments:
            name (str): name of an attack.
            model (torch.nn.Module): model to attack.
        """
        self.attack = name
        self.model = model
        self.model_name = str(model).split("(")[0]

        self.training = model.training
        self.device = next(model.parameters()).device

        self._targeted = 1
        self._attack_mode = 'default'
        self._return_type = 'float'
        self._target_map_function = lambda images, labels: labels
    #~~~~~#
    def set_attack_mode(self, mode, target_map_function=None):
    #~~~~~#
    def set_return_type(self, type):
    #~~~~~#
    def save(self, data_loader, save_path=None, verbose=True):
```

The most important point is that `Attack` takes only `model` when it is instantiated.

For now, all methods are built on the assumption that users feed in the original images and labels.

However, this may change in the future if a new attack requires other inputs.
**Any ideas to further improve the `Attack` class are welcome!!!**
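
As a minimal sketch of that contract (the `MyAttack` class below is purely illustrative, not part of the package), a new attack only needs to call the parent constructor with its name and the model, and override `forward`; the full `PGD` example in the next section fills in the details:

```python
import torch

from ..attack import Attack  # same relative import used by the files in torchattacks/attacks


class MyAttack(Attack):
    r"""
    Illustrative skeleton only; a real attack would cite its paper here.
    """
    def __init__(self, model, eps=8/255):
        super(MyAttack, self).__init__("MyAttack", model)
        self.eps = eps

    def forward(self, images, labels):
        r"""
        Overridden. Must return adversarial images in the range [0, 1].
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        labels = self._transform_label(images, labels)

        # Placeholder perturbation; a real attack computes model gradients here.
        adv_images = torch.clamp(images + self.eps, min=0, max=1).detach()
        return adv_images
```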

### Define a New Attack

#### 1. Name a new method.
You are free to choose any name, but if the attack already has a common name, please use it.

#### 2. Make a file in torchattacks/attacks
Now, create a file named '[name].py' in './attacks/'.

Here, **the file name must be written in lowercase**, following PEP 8.

For example, `fgsm.py` or `pgd.py`.

#### 3. Fill the file.
There are two things to prepare:
1. The original paper.
2. The attack algorithm.

As an example, here is the code of `torchattacks.PGD`:

```python
import torch
import torch.nn as nn

from ..attack import Attack


class PGD(Attack):
    r"""
    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (DEFAULT: 0.3)
        alpha (float): step size. (DEFAULT: 2/255)
        steps (int): number of steps. (DEFAULT: 40)
        random_start (bool): using random initialization of delta. (DEFAULT: False)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=40, random_start=False)
        >>> adv_images = attack(images, labels)
    """
    def __init__(self, model, eps=0.3, alpha=2/255, steps=40, random_start=False):
        super(PGD, self).__init__("PGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        labels = self._transform_label(images, labels)

        loss = nn.CrossEntropyLoss()

        adv_images = images.clone().detach()

        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        for i in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.model(adv_images)

            cost = self._targeted*loss(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            adv_images = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images

```

As above, the paper information is noted in the docstring right after the class definition.
Likewise, **the class name must be written in uppercase** (e.g., `PGD`, `FGSM`), matching the package convention.
The rest of the code should be easy to adapt from the other attack implementations. (This is how I do it :blush:)

#### 4. Pull Request

Finally, **push** your new branch to GitHub and open a **pull request**!!!


## TODO List

Here is a **TODO list** for those who are interested in contributing.

* **Fix DeepFool**
  * Process images in batches instead of iterating over them one by one.
  * Fix the memory error when calculating Jacobian gradients.

* **Add JSMA**
* https://github.com/Harry24k/adversarial-attacks-pytorch/issues/9
44 changes: 22 additions & 22 deletions demos/Adversairal Training (MNIST).ipynb
@@ -145,26 +145,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/5], lter [100/468], Loss: 2.1549\n",
"Epoch [1/5], lter [200/468], Loss: 1.6238\n",
"Epoch [1/5], lter [300/468], Loss: 1.6483\n",
"Epoch [1/5], lter [400/468], Loss: 1.3744\n",
"Epoch [2/5], lter [100/468], Loss: 1.1771\n",
"Epoch [2/5], lter [200/468], Loss: 0.9683\n",
"Epoch [2/5], lter [300/468], Loss: 1.1367\n",
"Epoch [2/5], lter [400/468], Loss: 0.8875\n",
"Epoch [3/5], lter [100/468], Loss: 0.8134\n",
"Epoch [3/5], lter [200/468], Loss: 0.6230\n",
"Epoch [3/5], lter [300/468], Loss: 0.8044\n",
"Epoch [3/5], lter [400/468], Loss: 0.5878\n",
"Epoch [4/5], lter [100/468], Loss: 0.5479\n",
"Epoch [4/5], lter [200/468], Loss: 0.4738\n",
"Epoch [4/5], lter [300/468], Loss: 0.5863\n",
"Epoch [4/5], lter [400/468], Loss: 0.4373\n",
"Epoch [5/5], lter [100/468], Loss: 0.4429\n",
"Epoch [5/5], lter [200/468], Loss: 0.4137\n",
"Epoch [5/5], lter [300/468], Loss: 0.4763\n",
"Epoch [5/5], lter [400/468], Loss: 0.3299\n"
"Epoch [1/5], lter [100/468], Loss: 2.2584\n",
"Epoch [1/5], lter [200/468], Loss: 1.9181\n",
"Epoch [1/5], lter [300/468], Loss: 1.9304\n",
"Epoch [1/5], lter [400/468], Loss: 1.6101\n",
"Epoch [2/5], lter [100/468], Loss: 1.4774\n",
"Epoch [2/5], lter [200/468], Loss: 1.2040\n",
"Epoch [2/5], lter [300/468], Loss: 1.4303\n",
"Epoch [2/5], lter [400/468], Loss: 1.2566\n",
"Epoch [3/5], lter [100/468], Loss: 1.1698\n",
"Epoch [3/5], lter [200/468], Loss: 0.8809\n",
"Epoch [3/5], lter [300/468], Loss: 1.0796\n",
"Epoch [3/5], lter [400/468], Loss: 0.9179\n",
"Epoch [4/5], lter [100/468], Loss: 0.8632\n",
"Epoch [4/5], lter [200/468], Loss: 0.6719\n",
"Epoch [4/5], lter [300/468], Loss: 0.8642\n",
"Epoch [4/5], lter [400/468], Loss: 0.6858\n",
"Epoch [5/5], lter [100/468], Loss: 0.7318\n",
"Epoch [5/5], lter [200/468], Loss: 0.5638\n",
"Epoch [5/5], lter [300/468], Loss: 0.7381\n",
"Epoch [5/5], lter [400/468], Loss: 0.5195\n"
]
}
],
@@ -212,7 +212,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Standard accuracy: 97.31 %\n"
"Standard accuracy: 95.87 %\n"
]
}
],
@@ -251,7 +251,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Robust accuracy: 89.67 %\n"
"Robust accuracy: 86.07 %\n"
]
}
],