This implements training of popular model architectures on the ImageNet dataset. AlexNet, SqueezeNet, ResNet, DenseNet and VGG are currently supported.
- PyTorch 0.4.0
- CUDA and cuDNN
- Download the ImageNet dataset and move the validation images into labeled subfolders
  - To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
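
The linked script handles this step with shell commands. As a rough Python equivalent, the sketch below moves each validation image into a subfolder named after its class. The function name `sort_val_images` and the `labels` mapping (image file name to class folder name) are hypothetical; how that mapping is obtained is not covered here.

```python
import os
import shutil


def sort_val_images(val_dir, labels):
    """Move each image in val_dir into a subfolder named after its label.

    `labels` is a hypothetical dict mapping image file name -> class folder name.
    Returns the number of files that were moved.
    """
    moved = 0
    for filename, class_name in labels.items():
        src = os.path.join(val_dir, filename)
        if not os.path.isfile(src):
            continue  # skip entries that are missing or were already moved
        class_dir = os.path.join(val_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        shutil.move(src, os.path.join(class_dir, filename))
        moved += 1
    return moved
```

After this step the `val` folder should have the same one-subfolder-per-class layout as the `train` folder.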
To train a model, run main.py with the desired model architecture and the path to the ImageNet dataset:
python main.py [imagenet-folder with train and val folders] -a alexnet --lr 0.01
The default learning rate schedule starts at 0.01 and decays by a factor of 10 every 30 epochs.
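
The schedule described above is a step decay. A minimal sketch of the arithmetic follows, assuming epochs are counted from 0; the function name `step_decay_lr` and its parameter names are hypothetical and may not match what `main.py` actually uses.

```python
def step_decay_lr(initial_lr, epoch, step=30, gamma=0.1):
    """Learning rate for a 0-indexed epoch under step decay.

    The rate is multiplied by `gamma` once every `step` epochs,
    so gamma=0.1 and step=30 give a 10x decay every 30 epochs.
    """
    return initial_lr * (gamma ** (epoch // step))


if __name__ == "__main__":
    # Epochs 0-29 use the initial rate, 30-59 a tenth of it, and so on.
    for epoch in (0, 29, 30, 60):
        print(epoch, step_decay_lr(0.01, epoch))
```

In a PyTorch training loop, the returned value would typically be written to the `lr` entry of each optimizer parameter group at the start of every epoch.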
usage: main.py [-h] [-a ARCH] [--epochs N] [--start-epoch N] [-b N] [--lr LR]
[--momentum M] [--weight-decay W] [-j N] [-m] [-p]
[--print-freq N] [--resume PATH] [-e]
DIR
PyTorch ImageNet Training
positional arguments:
DIR path to dataset
optional arguments:
-h, --help show this help message and exit
-a ARCH, --arch ARCH model architecture: alexnet | squeezenet1_0 |
squeezenet1_1 | densenet121 | densenet169 |
densenet201 | densenet161 | vgg11 |
vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19
| vgg19_bn | resnet18 | resnet34 | resnet50 |
resnet101 | resnet152 (default: alexnet)
--epochs N number of total epochs to run
--start-epoch N manual epoch number (useful on restarts)
-b N, --batch-size N mini-batch size (default: 256)
--lr LR, --learning-rate LR
initial learning rate
--momentum M momentum
--weight-decay W, --wd W
Weight decay (default: 1e-4)
-j N, --workers N number of data loading workers (default: 4)
-m, --pin-memory use pin memory
-p, --pretrained use pre-trained model
--print-freq N, -f N print frequency (default: 10)
--resume PATH path to latest checkpoint (default: None)
-e, --evaluate evaluate model on validation set
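
The option list above can be mirrored with `argparse`. The following is a sketch reconstructed from the help text only, not the actual contents of `main.py`: `build_parser` is a hypothetical name, and options whose defaults are not documented above are left without a default.

```python
import argparse

# Architecture names copied from the help text above.
ARCHS = [
    "alexnet", "squeezenet1_0", "squeezenet1_1",
    "densenet121", "densenet169", "densenet201", "densenet161",
    "vgg11", "vgg11_bn", "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
]


def build_parser():
    """Build a parser that accepts the options listed in the help text."""
    p = argparse.ArgumentParser(description="PyTorch ImageNet Training")
    p.add_argument("data", metavar="DIR", help="path to dataset")
    p.add_argument("-a", "--arch", metavar="ARCH", default="alexnet",
                   choices=ARCHS, help="model architecture (default: alexnet)")
    # Defaults for the next two options are not documented, so none is set.
    p.add_argument("--epochs", type=int, metavar="N",
                   help="number of total epochs to run")
    p.add_argument("--start-epoch", type=int, metavar="N",
                   help="manual epoch number (useful on restarts)")
    p.add_argument("-b", "--batch-size", type=int, default=256, metavar="N",
                   help="mini-batch size (default: 256)")
    # Assumption: 0.01 is taken from the schedule description, not the help text.
    p.add_argument("--lr", "--learning-rate", type=float, default=0.01,
                   metavar="LR", help="initial learning rate")
    # Default momentum is not documented, so none is set here.
    p.add_argument("--momentum", type=float, metavar="M", help="momentum")
    p.add_argument("--weight-decay", "--wd", type=float, default=1e-4,
                   metavar="W", help="weight decay (default: 1e-4)")
    p.add_argument("-j", "--workers", type=int, default=4, metavar="N",
                   help="number of data loading workers (default: 4)")
    p.add_argument("-m", "--pin-memory", action="store_true",
                   help="use pin memory")
    p.add_argument("-p", "--pretrained", action="store_true",
                   help="use pre-trained model")
    p.add_argument("--print-freq", "-f", type=int, default=10, metavar="N",
                   help="print frequency (default: 10)")
    p.add_argument("--resume", type=str, default=None, metavar="PATH",
                   help="path to latest checkpoint (default: None)")
    p.add_argument("-e", "--evaluate", action="store_true",
                   help="evaluate model on validation set")
    return p


if __name__ == "__main__":
    print(build_parser().parse_args())
```

With a parser like this, combining `-p` and `-e` (for example `-a resnet18 -p -e`) would select a pre-trained model and evaluate it on the validation set instead of training.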
Results of a single model on the ILSVRC-2012 validation set:
Model | Top-1 precision (val) | Top-5 precision (val) | Parameters | Model size (MB) |
---|---|---|---|---|
AlexNet | 56.522% | 79.066% | | 244 |
SqueezeNet1_0 | 58.092% | 80.420% | | 5 |
SqueezeNet1_1 | 58.178% | 80.624% | | 5 |
DenseNet121 | 74.434% | 91.972% | | 32 |
DenseNet169 | 75.600% | 92.806% | | 57 |
DenseNet201 | 76.896% | 93.370% | | 81 |
DenseNet161 | 77.138% | 93.560% | | 116 |
Vgg11 | 69.020% | 88.628% | | 532 |
Vgg13 | 69.928% | 89.246% | | 532 |
Vgg16 | 71.592% | 90.382% | | 554 |
Vgg19 | 72.376% | 90.876% | | 574 |
Vgg11_bn | 70.370% | 89.810% | | 532 |
Vgg13_bn | 71.586% | 90.374% | | 532 |
Vgg16_bn | 73.360% | 91.516% | | 554 |
Vgg19_bn | 74.218% | 91.842% | | 574 |
ResNet18 | 69.758% | 89.078% | | 47 |
ResNet34 | 73.314% | 91.420% | | 87 |
ResNet50 | 76.130% | 92.862% | | 103 |
ResNet101 | 77.374% | 93.546% | | 179 |
ResNet152 | 78.312% | 94.046% | | 242 |
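
For readers unfamiliar with the two precision columns: assuming they follow the usual definition, a prediction counts as correct at top-k when the true class is among the k highest-scoring classes. The sketch below computes that percentage in plain Python; `topk_precision` is a hypothetical helper written for illustration, not the function used to produce the table.

```python
def topk_precision(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest scores.

    scores: list of per-sample score lists (one score per class).
    labels: list of true class indices, one per sample.
    """
    if not labels:
        raise ValueError("labels must not be empty")
    correct = 0
    for sample_scores, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(sample_scores)),
                        key=lambda i: sample_scores[i], reverse=True)
        if label in ranked[:k]:
            correct += 1
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    # Toy example with 3 samples and 3 classes.
    scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    labels = [1, 1, 0]
    print(topk_precision(scores, labels, k=1))
    print(topk_precision(scores, labels, k=2))
```

Top-5 precision is always at least as high as top-1, since any sample that is correct at k=1 is also correct at k=5.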