Using VGG16 to extract features from images to train an ML model.
Learn to use a Keras pretrained model to extract features from images and train a machine learning model.
Classify the Fish and People superclasses from the CIFAR-100 dataset.
CIFAR-100: This dataset has 100 classes containing 600 images each. There are 500 training images and 100 testing images per class. The 100 classes in CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the class to which it belongs) and a "coarse" label (the superclass to which it belongs).
Link: https://www.cs.toronto.edu/~kriz/cifar.html
- Load the dataset.
- Extract the People and Fish images from the dataset.
- Reshape and preprocess the images.
- Load the VGG16 model from Keras with ImageNet weights.
- Extract features with VGG16.
- Train an ML model (we use LogisticRegression).
- Extract features for the test data.
- Test the model.
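The loading and filtering steps above can be sketched as follows. The coarse-label indices 1 (fish) and 14 (people) are assumed from CIFAR-100's alphabetical superclass ordering; verify them before relying on this sketch:

```python
import numpy as np
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.applications.vgg16 import preprocess_input

# Coarse-label indices for the two superclasses (assumed from the
# alphabetical ordering of CIFAR-100 superclass names).
FISH, PEOPLE = 1, 14

def extract_fish_people(x, y):
    """Keep only fish/people images; relabel fish -> 0, people -> 1."""
    y = y.ravel()
    mask = np.isin(y, [FISH, PEOPLE])
    return x[mask], (y[mask] == PEOPLE).astype(int)

# Load CIFAR-100 with the 20 coarse (superclass) labels.
(x_train, y_train), (x_test, y_test) = cifar100.load_data(label_mode="coarse")
x_train, y_train = extract_fish_people(x_train, y_train)
x_test, y_test = extract_fish_people(x_test, y_test)

# VGG16's own preprocessing (RGB -> BGR, channel mean subtraction).
x_train = preprocess_input(x_train.astype("float32"))
x_test = preprocess_input(x_test.astype("float32"))
```

Each superclass contributes 5 fine classes of 500 training and 100 test images, so the filtered splits hold 5,000 training and 1,000 test images.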
Tried using the complete VGG16 for feature extraction and Logistic Regression for classification.
Network Summary:

```
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 8, 8, 256)         295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 4, 4, 256)         0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 4, 4, 512)         1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 2, 2, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 2, 2, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________
```
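A minimal sketch of this setup (the full convolutional base as feature extractor, plus Logistic Regression). The random arrays are placeholders standing in for the preprocessed fish/people splits:

```python
import numpy as np
from tensorflow.keras.applications import VGG16
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Placeholder data with CIFAR-100 shapes; substitute the real
# preprocessed fish/people splits in practice.
x_train = np.random.rand(20, 32, 32, 3).astype("float32")
y_train = np.arange(20) % 2
x_test = np.random.rand(10, 32, 32, 3).astype("float32")
y_test = np.arange(10) % 2

# include_top=False keeps only the convolutional base, matching the
# summary above (final block5_pool output: (None, 1, 1, 512)).
base = VGG16(weights="imagenet", include_top=False,
             input_shape=(32, 32, 3))

# Flatten each (1, 1, 512) output into a 512-d feature vector.
train_feats = base.predict(x_train).reshape(len(x_train), -1)
test_feats = base.predict(x_test).reshape(len(x_test), -1)

clf = LogisticRegression(max_iter=1000).fit(train_feats, y_train)
print("Accuracy:", accuracy_score(y_test, clf.predict(test_feats)))
```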
Accuracy Score: 0.871
Classification Report:
| | Label | f1-score | precision | recall | support |
|---|---|---|---|---|---|
| 0 | Fish | 0.871129 | 0.872000 | 0.870259 | 501.000 |
| 1 | People | 0.870871 | 0.870000 | 0.871743 | 499.000 |
| | accuracy | 0.871000 | 0.871000 | 0.871000 | 0.871 |
| | macro avg | 0.871000 | 0.871000 | 0.871001 | 1000.000 |
| | weighted avg | 0.871000 | 0.871002 | 0.871000 | 1000.000 |
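For reference, a table like the one above can be produced with scikit-learn's `classification_report` and pandas; the labels below are dummy placeholders, not the actual predictions:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report

# Dummy labels standing in for the logistic-regression output.
y_true = np.array([0, 0, 1, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])

# output_dict=True gives a nested dict; transposing the DataFrame
# puts one class per row. The scalar "accuracy" entry is broadcast
# across every column, which is why the accuracy row repeats the
# same value (even under "support") in the report above.
report = pd.DataFrame(
    classification_report(
        y_true, y_pred, target_names=["Fish", "People"], output_dict=True
    )
).T
print(report)
print("Accuracy:", accuracy_score(y_true, y_pred))
```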
- Removed the last 8 layers of VGG16 and extracted features to train Logistic Regression.
Network Summary:

```
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 8, 8, 256)         295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 8, 8, 256)         590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 4, 4, 256)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 4096)              0
=================================================================
Total params: 1,735,488
Trainable params: 1,735,488
Non-trainable params: 0
_________________________________________________________________
```
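One way to build this truncated extractor is a functional-API cut at `block3_pool`, which drops the 8 block4/block5 layers and is equivalent to the Sequential model summarized above; the random input is a placeholder:

```python
import numpy as np
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Flatten
from tensorflow.keras.models import Model

full = VGG16(weights="imagenet", include_top=False,
             input_shape=(32, 32, 3))

# Cut at block3_pool and flatten its (4, 4, 256) output into a
# 4096-d feature vector, as in the summary above.
outputs = Flatten()(full.get_layer("block3_pool").output)
truncated = Model(inputs=full.input, outputs=outputs)

x = np.random.rand(5, 32, 32, 3).astype("float32")  # placeholder batch
features = truncated.predict(x)
print(features.shape)  # (5, 4096)
```

These 4096-d vectors feed the Logistic Regression classifier exactly as in the full-model experiment.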
Accuracy Score: 0.931
Classification Report:
| | Label | f1-score | precision | recall | support |
|---|---|---|---|---|---|
| 0 | Fish | 0.930931 | 0.930000 | 0.931864 | 499.000 |
| 1 | People | 0.931069 | 0.932000 | 0.930140 | 501.000 |
| | accuracy | 0.931000 | 0.931000 | 0.931000 | 0.931 |
| | macro avg | 0.931000 | 0.931000 | 0.931002 | 1000.000 |
| | weighted avg | 0.931000 | 0.931002 | 0.931000 | 1000.000 |
The pretrained VGG16 model is trained on ImageNet, a dataset whose full hierarchy spans more than 20,000 categories (the Keras weights use the 1,000-class ILSVRC subset). The last convolutional layers therefore capture complex, high-level features tuned to those categories, which our two-class problem does not need. Hence, the reduced model performed better on our dataset.