-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfruit_sense_model.py
105 lines (84 loc) · 2.89 KB
/
fruit_sense_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.optimizers.experimental import RMSprop
from keras.layers import Dropout, Flatten, Dense
# Load the train/valid splits as batched tf.data datasets. Labels are inferred
# from the class sub-folder names and returned as integer indices.
# NOTE: pixels are NOT rescaled here; any model-specific preprocessing has to
# happen in the model itself.
img_height, img_width = 180, 180
batch_size = 6

# Root of the dataset export (it contains train/, valid/ and test/ folders).
# Set FRUIT_SENSE_DATA_DIR to point elsewhere instead of editing this path.
DATA_DIR = os.environ.get(
    "FRUIT_SENSE_DATA_DIR",
    r"C:\Users\Parth\Desktop\project\Banana Ripeness Classification.v2-original-images_modifiedclasses.folder",
)

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR, "train"),
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA_DIR, "valid"),
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

class_names = train_ds.class_names
# Label i means class_names[i] in both splits only if both splits contain the
# same class folders; fail loudly instead of training against mismatched labels.
if val_ds.class_names != class_names:
    raise ValueError(
        f"Validation classes {val_ds.class_names} do not match "
        f"training classes {class_names}"
    )
# Define and compile the model: a frozen ImageNet-pretrained ResNet50 backbone
# followed by a small trainable classification head.
num_classes = len(class_names)  # one softmax unit per class folder in train/

resnet_model = Sequential()
# Declare the input explicitly, using the same size the datasets resize to.
resnet_model.add(tf.keras.Input(shape=(img_height, img_width, 3)))
# The datasets yield raw RGB pixels in [0, 255], but the ImageNet weights were
# trained on ResNet50-preprocessed input (RGB->BGR, ImageNet mean subtracted).
# Doing it inside the model means every dataset passed to fit()/evaluate()
# receives the same preprocessing.
resnet_model.add(layers.Lambda(tf.keras.applications.resnet50.preprocess_input))

pretrained_model = tf.keras.applications.ResNet50(
    include_top=False,
    input_shape=(img_height, img_width, 3),
    # Global max pooling already yields a flat (batch, features) tensor, so no
    # Flatten layer is needed. `classes` is omitted: it is ignored without the top.
    pooling='max',
    weights='imagenet',
)
# Freeze the whole backbone; only the Dense head below is trained.
pretrained_model.trainable = False
resnet_model.add(pretrained_model)

# Use tf.keras layers throughout so the layers match the tf.keras Sequential model.
resnet_model.add(layers.Dense(512, activation='relu'))
resnet_model.add(layers.Dense(num_classes, activation='softmax'))
resnet_model.compile(
    optimizer=Adam(learning_rate=0.001),
    # Integer (not one-hot) labels come from image_dataset_from_directory.
    loss="sparse_categorical_crossentropy",
    metrics=['accuracy'],
)
# Train the model. Only the Dense head is trainable (the backbone was frozen);
# per-epoch training/validation metrics are kept in `history`.
epochs = 12
history = resnet_model.fit(train_ds, epochs=epochs, validation_data=val_ds)
# Plot training vs. validation curves, one figure per metric.
def _plot_train_val(history, metric, name):
    """Show the training and validation curves of one metric from a fit() History.

    Args:
        history: object returned by ``model.fit``; ``history.history`` must contain
            both ``metric`` and ``'val_' + metric`` (the latter exists because
            ``fit`` was called with ``validation_data``).
        metric: key of the metric in ``history.history`` (e.g. ``'accuracy'``).
        name: display name used in the legend, title and y-axis label.
    """
    plt.figure(figsize=(8, 6))
    plt.plot(history.history[metric], label=f'Training {name}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {name}')
    plt.title(f'Training and Validation {name}')
    plt.xlabel('Epochs')
    plt.ylabel(name)
    plt.legend()
    plt.grid()
    plt.show()


_plot_train_val(history, 'accuracy', 'Accuracy')
_plot_train_val(history, 'loss', 'Loss')
# Evaluate on the held-out test split.
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    r"C:\Users\Parth\Desktop\project\Banana Ripeness Classification.v2-original-images_modifiedclasses.folder\test",
    # Keep file order fixed so evaluation is reproducible and any later
    # predictions can be matched back to test_ds.file_paths.
    shuffle=False,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
# The model's output index i means class_names[i] from the training split; a
# test split with different class folders would silently mislabel every sample.
if test_ds.class_names != class_names:
    raise ValueError(
        f"Test classes {test_ds.class_names} do not match "
        f"training classes {class_names}"
    )
test_loss, test_accuracy = resnet_model.evaluate(test_ds, verbose=1)
print("Loss : ", test_loss)
print("Accuracy :", test_accuracy)