Skip to content

Commit

Permalink
Merge pull request #26 from Pinjuf/main
Browse files Browse the repository at this point in the history
Better code for image_classification.py
  • Loading branch information
geekjr authored May 6, 2021
2 parents 41bbe9f + af15920 commit 8a179af
Showing 1 changed file with 35 additions and 239 deletions.
274 changes: 35 additions & 239 deletions quickai/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
data_augmentation=False,
epochs=20,
graph=True):
self.model = model
self.model = model.lower()
self.save = save
self.path = path
self.batch_size = batch_size
Expand All @@ -51,10 +51,7 @@ def load_img_data(
:param batch_size is batch size
"""
data_dir = pathlib.Path(path)
if grayscale:
color_mode = "grayscale"
else:
color_mode = "rgb"
color_mode = "grayscale" if grayscale else "rgb"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
Expand Down Expand Up @@ -85,241 +82,40 @@ def use(self):
"""
self.use()
"""
if self.model == "eb0":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB0(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb1":
img_size = 240
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB1(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb2":
img_size = 260
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB2(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb3":
img_size = 300
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB3(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb4":
img_size = 380
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB4(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb5":
img_size = 456
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB5(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb6":
img_size = 528
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB6(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "eb7":
img_size = 600
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.EfficientNetB7(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "vgg16":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.VGG16(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "vgg19":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.VGG19(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "dn121":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.DenseNet121(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "dn169":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.DenseNet169(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "dn201":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.DenseNet201(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "irnv2":
img_size = 299
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.InceptionResNetV2(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "iv3":
img_size = 299
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.InceptionV3(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "mn":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.MobileNet(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "mnv2":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.MobileNetV2(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "mnv3l":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.MobileNetV3Large(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "mnv3s":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.MobileNetV3Small(input_shape=(
img_size, img_size, 3), include_top=False, weights='imagenet')
elif self.model == "rn101":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet101(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "rn101v2":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet101V2(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')

elif self.model == "rn152":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet152(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "rn152v2":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet152V2(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')

elif self.model == "rn50":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet50(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "rn50v2":
img_size = 224
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.ResNet50V2(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
elif self.model == "xception":
img_size = 299
train, val, class_num = self.load_img_data(
self.path, img_size, img_size, self.batch_size)
body = tf.keras.applications.Xception(
input_shape=(
img_size,
img_size,
3),
include_top=False,
weights='imagenet')
else:
print("Model not found")
modeldata = {"eb0": [tf.keras.applications.EfficientNetB0, 224],
"eb1": [tf.keras.applications.EfficientNetB1, 240],
"eb2": [tf.keras.applications.EfficientNetB2, 260],
"eb3": [tf.keras.applications.EfficientNetB3, 300],
"eb4": [tf.keras.applications.EfficientNetB4, 340],
"eb5": [tf.keras.applications.EfficientNetB5, 456],
"eb6": [tf.keras.applications.EfficientNetB6, 528],
"eb7": [tf.keras.applications.EfficientNetB7, 600],
"vgg16": [tf.keras.applications.VGG16, 224],
"vgg19": [tf.keras.applications.VGG19, 224],
"dn121": [tf.keras.applications.DenseNet121, 224],
"dn169": [tf.keras.applications.DenseNet169, 224],
"dn201": [tf.keras.applications.DenseNet201, 224],
"irnv2": [tf.keras.applications.InceptionResNetV2, 299],
"iv3": [tf.keras.applications.InceptionV3, 299],
"mn": [tf.keras.applications.MobileNet, 224],
"mnv2": [tf.keras.applications.MobileNetV2, 224],
"mnv3l": [tf.keras.applications.MobileNetV3Large, 224],
"mnv3s": [tf.keras.applications.MobileNetV3Small, 224],
"rn101": [tf.keras.applications.ResNet101, 224],
"rn101v2": [tf.keras.applications.ResNet101V2, 224],
"rn152": [tf.keras.applications.ResNet152, 224],
"rn152v2": [tf.keras.applications.ResNet152V2, 224],
"rn50": [tf.keras.applications.ResNet50, 224],
"rn50v2": [tf.keras.applications.ResNet50V2, 224],
"xception": [tf.keras.applications.Xception, 299]}

img_size = modeldata[self.model][1]
train, val, class_num = self.load_img_data(self.path, img_size, img_size, self.batch_size)
body = modeldata[self.model][0](input_shape=
(img_size, img_size, 3),
include_top=False,
weights="imagenet")

body.trainable = False
average_layer = GlobalAveragePooling2D()
Expand Down

0 comments on commit 8a179af

Please sign in to comment.