The implementation of our Pattern Recognition 2022 paper: "FocusNet: Classifying better by focusing on confusing classes"
Paper: https://www.sciencedirect.com/science/article/abs/pii/S003132032200190X?via%3Dihub
- This repository is largely built on the official "ImageNet training in PyTorch" example, so it is helpful to refer to that example's documentation.
- The first version of our architecture was named ClonalNet; after the second revision we renamed it to FocusNet. Therefore, `clonalnet` in the commands and file names below refers to FocusNet.
This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset.
- Install PyTorch (pytorch.org)
- Install the remaining dependencies with `pip install -r requirements.txt`. Note: the `requirements.txt` in this repository is not the same as the official one; if something goes wrong, please fall back to the official requirements.
- Download the ImageNet dataset from http://www.image-net.org/
- Then, move the validation images into labeled subfolders using the following shell script. (A quick check of the resulting layout is sketched below.)
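Since the training script is built on the PyTorch ImageNet example, both splits are expected in the standard `ImageFolder` layout (one subfolder per class). Below is a minimal sketch of a layout check using the standard torchvision API; `/path/to/ILSVRC2012` is a placeholder, not a path used by this repository:

```python
# Quick sanity check of the ImageNet directory layout (a sketch; assumes the
# usual ImageFolder structure: /path/to/ILSVRC2012/{train,val}/<class>/*.JPEG).
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_root = "/path/to/ILSVRC2012"  # placeholder path
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

for split in ("train", "val"):
    dataset = datasets.ImageFolder(os.path.join(data_root, split), transform)
    print(f"{split}: {len(dataset)} images, {len(dataset.classes)} classes")
# Both splits should report 1000 classes; if 'val' does not, the validation
# images still need to be moved into labeled subfolders (see the step above).
```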
To train our network, run `clonalnet_main.py` with the desired model architecture and the path to the ImageNet dataset:

    # resnet18
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc
    # resnet34
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet34 --seed 42 --gpu 0 -ebc
    # mobilenet_v2
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 42 --gpu 0 -ebc
The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs.
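For reference, this is the same step decay used by the PyTorch ImageNet example. A minimal sketch of the rule with `torch.optim.lr_scheduler.StepLR` follows; the momentum and weight-decay values are that example's defaults and may differ from this repository's settings:

```python
# Step decay: lr = 0.1 * (0.1 ** (epoch // 30)), i.e. divide by 10 every 30 epochs.
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the actual network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)  # assumed defaults
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # ... train for one epoch ...
    scheduler.step()

# lr is 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89.
print(scheduler.get_last_lr())
```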
To evaluate our network, run `clonalnet_main.py` with the desired model architecture, the path to the ImageNet dataset, and the corresponding checkpoint:

    # resnet18
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet18_model_best.pth.tar
    # resnet34
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet34 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet34_model_best.pth.tar
    # mobilenet_v2
    python clonalnet_main.py --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 42 --gpu 0 -ebc -e --resume clonalnet_mobilenet_v2_model_best.pth.tar
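The `--resume` checkpoints can be inspected directly. A small sketch follows, assuming the dictionary keys saved by the PyTorch ImageNet example (`epoch`, `state_dict`, `best_acc1`); these names are an assumption and are not verified against this repository's checkpoints:

```python
# Inspect a *_model_best.pth.tar checkpoint (a sketch; key names are assumed to
# follow the PyTorch ImageNet example's save_checkpoint format).
import torch

checkpoint = torch.load("clonalnet_resnet18_model_best.pth.tar", map_location="cpu")
print("keys:", list(checkpoint.keys()))
print("epoch:", checkpoint.get("epoch"), "best Acc@1:", checkpoint.get("best_acc1"))

# The weights live under 'state_dict'; strip a possible 'module.' prefix if the
# model was saved while wrapped in nn.DataParallel.
state_dict = {k.replace("module.", "", 1): v
              for k, v in checkpoint["state_dict"].items()}
print("number of tensors:", len(state_dict))
```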
`clonal_resnet18_from_scratch.log` and `clonal_resnet34_from_scratch.log` are the training logs of clonalnet_resnet18 and clonalnet_resnet34, respectively.
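If the logs follow the PyTorch ImageNet example's output, each validation pass ends with a summary line like `* Acc@1 70.422 Acc@5 89.562`, and the per-epoch accuracy can be pulled out with a few lines of Python. This log format is an assumption, not something documented by this repository:

```python
# Extract validation accuracy from a training log (a sketch; assumes summary
# lines of the form " * Acc@1 70.422 Acc@5 89.562").
import re

pattern = re.compile(r"\*\s+Acc@1\s+([\d.]+)\s+Acc@5\s+([\d.]+)")
acc1 = []
with open("clonal_resnet18_from_scratch.log") as f:
    for line in f:
        match = pattern.search(line)
        if match:
            acc1.append(float(match.group(1)))

if acc1:
    print(f"{len(acc1)} validation points, best Acc@1 = {max(acc1):.3f}")
```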
To validate the baseline results, please run:

    # resnet18 / resnet34
    python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet18 --seed 10 -e --pretrained --gpu 0
    python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet34 --seed 10 -e --pretrained --gpu 0
    # mobilenet_v2
    python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 10 -e --pretrained --gpu 0 --resume models/_pytorch_pretrained_checkpoints/baseline_mobilenet_v2_model_best.pth.tar
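The ResNet baselines above evaluate the torchvision pretrained weights directly (`--pretrained`), so those rows of the table below can also be reproduced with a plain evaluation loop outside this repository. A minimal sketch, assuming the standard ImageNet preprocessing; the dataset path, batch size, and worker count are placeholders:

```python
# Evaluate torchvision's pretrained ResNet-18 on the ImageNet validation set
# (a sketch of the baseline measurement, not this repository's evaluation code).
import torch
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
model = models.resnet18(pretrained=True).to(device).eval()

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder("/path/to/ILSVRC2012/val", transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])),
    batch_size=256, num_workers=8, shuffle=False)

top1 = top5 = total = 0
with torch.no_grad():
    for images, target in val_loader:
        images, target = images.to(device), target.to(device)
        _, pred = model(images).topk(5, dim=1)    # top-5 class indices
        correct = pred.eq(target.view(-1, 1))     # (N, 5) boolean matches
        top1 += correct[:, 0].sum().item()
        top5 += correct.any(dim=1).sum().item()
        total += target.size(0)

print(f"Acc@1 {100 * top1 / total:.3f}  Acc@5 {100 * top5 / total:.3f}")
```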
| Models | Acc@1 (%) | Acc@5 (%) | Checkpoint |
|---|---|---|---|
| ResNet18 | 69.760 | 89.082 | PyTorch pre-trained |
| ClonalNet (r18) | 70.422 | 89.562 | Baidu (code: 1234); Google Drive |
| ResNet34 | 73.310 | 91.420 | PyTorch pre-trained |
| ClonalNet (r34) | 74.366 | 91.884 | Baidu (code: 1234); Google Drive |
| MobileNet_v2 | 65.558 | 86.744 | Baidu (code: 1234); Google Drive |
| ClonalNet (MobileNet_v2) | 66.300 | 87.232 | Baidu; Google Drive |
You can also download more checkpoints here: Baidu (code: 1234); Google Drive.
If you find our work helpful, please cite it:
    @article{zhang2022focusnet,
      title={FocusNet: Classifying better by focusing on confusing classes},
      author={Zhang, Xue and Sheng, Zehua and Shen, Hui-Liang},
      journal={Pattern Recognition},
      pages={108709},
      year={2022},
      publisher={Elsevier}
    }