This repository contains the implementation and comparison of several neural network architectures: a simple Multi-Layer Perceptron (MLP), a Convolutional Neural Network (CNN), a deeper CNN, and two ResNet architectures. The aim is to explore and analyze their performance on CIFAR-10, providing insights into their strengths and weaknesses for image classification, and to replicate the results reported in the paper Deep Residual Learning for Image Recognition (https://arxiv.org/abs/1512.03385). The following points will be explored:
- MLP vs CNN: Which architecture performs better in image recognition?
- CNN vs deeper CNN: Is it true that the deeper a model is, the better its performance?
- Residual Networks: The benefits of residual connections
- CNN vs ResNet: How residual connections can improve the accuracy and efficiency of the model
- Grad-CAM: What image regions does the model focus on? Can we understand why the model made a given prediction?
The dataset used for training and evaluation is CIFAR-10.
Install the requirements:
pip install -r requirements.txt
Run the training script for each model:
python main.py --model resnet18 --epochs 50 --batch-size 128 --lr 0.1 --num-workers 2 --log
You can track experiments and compare results with Weights & Biases by passing the --log argument. Training logs, metrics, and gradients will then be automatically uploaded to WandB.
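For reference, here is a rough sketch of how the --log flag could wire into Weights & Biases; the project name and the exact logic are assumptions and may differ from the repository's main.py:

```python
# Rough sketch of enabling W&B tracking via a --log flag
# (illustrative only; main.py in this repository may differ).
import argparse
import wandb

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet18")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument("--log", action="store_true")
args = parser.parse_args()

if args.log:
    # "cifar10-architectures" is a hypothetical project name.
    wandb.init(project="cifar10-architectures", config=vars(args))
    # wandb.watch(model, log="gradients") would also upload gradient histograms.
```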
As we can observe, the CNN outperforms the MLP. The latter has the disadvantage of a very large number of parameters, since it consists of fully connected layers in which every neuron is connected to every neuron in the adjacent layers. Additionally, it takes flattened vectors as input, which discards spatial information. In contrast, the layers of a CNN are sparsely/partially connected, and thanks to convolutional operations they preserve spatial relationships in the data and capture important features. This allows CNNs to handle image data more efficiently and achieve better performance in tasks like image recognition.
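To make the contrast concrete, here is a minimal PyTorch sketch of the two kinds of architecture; the layer sizes are illustrative and do not necessarily match the models defined in this repository:

```python
# Minimal sketch contrasting an MLP and a small CNN for CIFAR-10
# (illustrative only; the repository's actual model definitions may differ).
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # The image is flattened, so all spatial structure is discarded.
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3 * 32 * 32, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.net(x)

class SimpleCNN(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Convolutions share weights over local neighbourhoods,
        # preserving spatial relationships with far fewer parameters.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(64 * 8 * 8, num_classes))

    def forward(self, x):
        return self.classifier(self.features(x))
```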
Looking at the first epochs, the shallower CNN converges more quickly. As training progresses, CNN30 manages to achieve better accuracy, but this is partly due to overfitting. In general, as shown in the paper (https://arxiv.org/abs/1512.03385), very deep plain architectures suffer from a degradation problem linked to vanishing gradients.
To address vanishing gradients, the authors of the paper (https://arxiv.org/abs/1512.03385) propose residual connections, which provide a 'shortcut' for gradients to flow backwards directly from later layers to earlier layers. As the figures show, residual connections lead to faster convergence and better performance.
This highlights how residual connections mitigate gradient vanishing and improve training efficiency and accuracy in deep networks.
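For illustration, a basic residual block in the style of ResNet-18 could look like the following sketch; details such as downsampling and channel counts are assumptions and may differ from the repository's implementation:

```python
# Minimal sketch of a basic residual block (ResNet-18 style).
# Illustrative only; the repository's block may differ.
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection so the identity can be added when shapes change.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The residual connection: gradients can flow through the identity path.
        out = out + self.shortcut(x)
        return torch.relu(out)
```

The block only has to learn a residual correction on top of the identity, which is the intuition behind the smaller response magnitudes discussed in the conclusions below.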
Let's implement Grad-CAM following the paper Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization.
The idea is to analyze the gradients and the activations of a target layer to gain insights into why the model made its predictions.
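A minimal sketch of the idea, using PyTorch hooks, is shown below; the actual gradcam.py may be organized differently, and the function name grad_cam here is purely illustrative:

```python
# Minimal Grad-CAM sketch using forward/backward hooks
# (illustrative; gradcam.py in this repository may be structured differently).
import torch
import torch.nn.functional as F

def grad_cam(model, image, target_layer, class_idx=None):
    activations, gradients = {}, {}

    def fwd_hook(module, inputs, output):
        activations["value"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradients["value"] = grad_output[0].detach()

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)

    logits = model(image)                      # image: (1, 3, H, W)
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()
    model.zero_grad()
    logits[0, class_idx].backward()            # gradient of the class score

    h1.remove()
    h2.remove()

    # Weight each activation map by its spatially averaged gradient, then ReLU.
    weights = gradients["value"].mean(dim=(2, 3), keepdim=True)    # (1, C, 1, 1)
    cam = F.relu((weights * activations["value"]).sum(dim=1))       # (1, h, w)
    cam = F.interpolate(cam.unsqueeze(1), size=image.shape[-2:],
                        mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)        # normalize to [0, 1]
    return cam.squeeze()
```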
Run the gradcam.py script:
python gradcam.py --model cnn19 --label "horse" --img-size 32 --target-layer "model.conv_block3[9]"
python gradcam.py --model cnn19 --label "frog" --img-size 32 --target-layer "model.conv_block3[9]"
python gradcam.py --model cnn30 --label "horse" --img-size 32 --target-layer "model.conv_block3[9]"
python gradcam.py --model cnn30 --label "frog" --img-size 32 --target-layer "model.conv_block3[9]"
python gradcam.py --model resnet18 --label "horse" --img-size 32 --target-layer "model.layer4[-1]"
python gradcam.py --model resnet18 --label "frog" --img-size 32 --target-layer "model.layer4[-1]"
- Convolutional neural networks outperform MLPs for image recognition: convolutional layers naturally retain spatial information that is lost in fully connected layers.
- As the depth of a plain network increases beyond a certain point, its performance on both the training and test sets starts to degrade due to vanishing gradients.
- Residual connections address the degradation problem by providing a "shortcut" for gradients to flow backwards directly from later layers to earlier layers. This helps mitigate the vanishing gradient problem and enables the training of significantly deeper networks.
- ResNets generally have smaller response magnitudes than plain networks. This means the residual functions are closer to zero than the non-residual functions and are therefore easier to optimize: the model only has to learn small corrections.
- Residual connections are great, but be careful with very deep architectures: they can still lead to overfitting!
- By analyzing the gradients and the activations, we can gain insights into why the model made a given prediction (Grad-CAM).