This repository contains the PyTorch implementation of the paper Deep Frank-Wolfe For Neural Network Optimization. If you use this work in your research, please cite the paper:
```bibtex
@Article{berrada2019deep,
author = {Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan},
title = {Deep Frank-Wolfe For Neural Network Optimization},
journal = {International Conference on Learning Representations},
year = {2019},
}
```
Note: you might be interested in our follow-up algorithm ALI-G, which has explicit convergence guarantees and outperforms DFW in our experiments.
This code should work with PyTorch >= 1.0 and Python 3. Detailed requirements are listed in requirements.txt.
- Clone this repository:
```bash
git clone --recursive https://github.com/oval-group/dfw
```
(note that the option `--recursive` is necessary to clone the submodules; these are needed to reproduce the experiments but not for the DFW implementation itself).
- Go to the directory and install the requirements:
```bash
cd dfw && pip install -r requirements.txt
```
- Install the DFW package:
```bash
python setup.py install
```
- Simple usage example:
```python
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

# boilerplate code:
# `model` is a nn.Module
# `x` is an input sample, `y` is a label

# create loss function
svm = MultiClassHingeLoss()

# create DFW optimizer with learning rate of 0.1
optimizer = DFW(model.parameters(), eta=0.1)

# DFW can be used with standard pytorch syntax
optimizer.zero_grad()
loss = svm(model(x), y)
loss.backward()
# NB: DFW needs to have access to the current loss value,
# (this syntax is compatible with standard pytorch optimizers too)
optimizer.step(lambda: float(loss))
```
- Technical requirement: DFW uses a custom step-size at each step. For this update to make sense, the loss function must be piecewise linear and convex. For instance, one can use a multi-class SVM loss or an l1 regression loss.
- Smoothing: sometimes the multi-class SVM loss does not fare well with a large number of classes. This issue can be alleviated by using dual smoothing, which is easy to plug into the code:
```python
from dfw.losses import set_smoothing_enabled
...
with set_smoothing_enabled(True):
    loss = svm(model(x), y)
```
- To reproduce the CIFAR experiments:
```bash
VISION_DATA=[path/to/your/cifar/data] python reproduce/cifar.py
```
- To reproduce the SNLI experiments: follow the preparation instructions and run
```bash
python reproduce/snli.py
```
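The technical requirement above says that DFW computes a custom step-size at each iteration, which is why `optimizer.step` receives the current loss value. The sketch below is a simplified, pure-Python illustration of that step-size under our own assumptions: no weight decay, no momentum, and parameters stored as a flat list of floats. In our reading, the step-size comes from a single Frank-Wolfe line search on the dual of a proximal problem, giving gamma = loss / (eta * ||grad||^2) clipped to [0, 1] and the update w <- w - eta * gamma * grad. `dfw_step` and `l1_loss_and_grad` are hypothetical helpers, not part of the `dfw` package.

```python
def dfw_step(params, grads, loss, eta):
    """Simplified DFW-style update (assumption: no weight decay, no momentum)."""
    # squared norm of the gradient, summed over all parameters
    grad_sq = sum(g * g for g in grads)
    if grad_sq == 0.0:
        return list(params), 0.0  # zero gradient: leave the parameters unchanged
    # closed-form step-size: loss / (eta * ||grad||^2), clipped to [0, 1]
    gamma = min(max(loss / (eta * grad_sq), 0.0), 1.0)
    # gamma == 1 recovers a plain SGD step with learning rate eta
    new_params = [w - eta * gamma * g for w, g in zip(params, grads)]
    return new_params, gamma


def l1_loss_and_grad(w, target=3.0):
    # hypothetical l1 regression on a scalar: piecewise linear and convex in w
    diff = w - target
    grad = 1.0 if diff > 0 else (-1.0 if diff < 0 else 0.0)
    return abs(diff), grad


loss, grad = l1_loss_and_grad(0.0)
print(dfw_step([0.0], [grad], loss, eta=10.0))  # ([3.0], 0.3): lands on the kink
print(dfw_step([0.0], [grad], loss, eta=1.0))   # ([1.0], 1.0): plain SGD step
```

On this l1 regression example, a large eta gives gamma = 0.3 and the step lands exactly on the kink of |w - 3|, whereas a small eta saturates gamma at 1 and recovers a plain SGD step. This hints at why the loss should be piecewise linear: the closed-form step relies on a linear model of the loss, which is exact on the current linear piece.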
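To make the smoothing item above more concrete, here is a pure-Python sketch of a multi-class hinge loss and a dual-smoothed version of it. We assume the Crammer-Singer form with a margin of 1, which may differ in details from `dfw.losses.MultiClassHingeLoss`; `hinge_loss`, `smooth_hinge_loss` and the temperature `tau` are hypothetical names. Adding an entropy term to the dual of the max over classes turns that max into a temperature-scaled log-sum-exp.

```python
import math


def _augmented_scores(scores, y, margin):
    # add the margin to every incorrect class (loss-augmented scores)
    return [s + (0.0 if j == y else margin) for j, s in enumerate(scores)]


def hinge_loss(scores, y, margin=1.0):
    # hard multi-class hinge: max_j (s_j + margin * [j != y]) - s_y
    return max(_augmented_scores(scores, y, margin)) - scores[y]


def smooth_hinge_loss(scores, y, tau=1.0, margin=1.0):
    # dual smoothing: replace the max by tau * log(sum_j exp(a_j / tau))
    augmented = _augmented_scores(scores, y, margin)
    m = max(augmented)  # shift by the max for numerical stability
    lse = m + tau * math.log(sum(math.exp((a - m) / tau) for a in augmented))
    return lse - scores[y]


scores, label = [2.0, 0.5, 1.8], 0
print(hinge_loss(scores, label))                   # about 0.8
print(smooth_hinge_loss(scores, label, tau=1.0))   # about 1.34
print(smooth_hinge_loss(scores, label, tau=1e-3))  # close to the hard hinge
```

The hard hinge is piecewise linear and convex in the scores, matching the technical requirement above. The smoothed loss is differentiable everywhere, upper-bounds the hard hinge, exceeds it by at most tau * log(K) for K classes, and converges to it as tau goes to 0.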
DFW largely outperforms all baselines that do not use a manual schedule for the learning rate. The tables below show the performance on the CIFAR data sets when using data augmentation (AMSGrad, a variant of Adam, is the strongest baseline in our experiments), and on the SNLI data set.
[Tables: CIFAR results for Wide Residual Networks and Densely Connected Networks]
Optimizer | Test Accuracy (%) |
---|---|
Adagrad | 84.6 |
Adam | 85.0 |
AMSGrad | 85.1 |
BPGrad | 84.2 |
DFW | 85.2 |
SGD (with schedule) | 85.2 |
We use the following third-party implementations: