diff --git a/quickai/image_classification.py b/quickai/image_classification.py index 1d885d0..a70e4f0 100644 --- a/quickai/image_classification.py +++ b/quickai/image_classification.py @@ -27,7 +27,7 @@ def __init__( data_augmentation=False, epochs=20, graph=True): - self.model = model + self.model = model.lower() self.save = save self.path = path self.batch_size = batch_size @@ -51,10 +51,7 @@ def load_img_data( :param batch_size is batch size """ data_dir = pathlib.Path(path) - if grayscale: - color_mode = "grayscale" - else: - color_mode = "rgb" + color_mode = "grayscale" if grayscale else "rgb" train_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, @@ -85,241 +82,40 @@ def use(self): """ self.use() """ - if self.model == "eb0": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB0(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb1": - img_size = 240 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB1(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb2": - img_size = 260 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB2(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb3": - img_size = 300 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB3(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb4": - img_size = 380 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB4(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb5": - img_size = 456 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB5(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb6": - img_size = 528 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB6(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "eb7": - img_size = 600 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.EfficientNetB7(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "vgg16": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.VGG16( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "vgg19": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.VGG19( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "dn121": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.DenseNet121( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "dn169": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.DenseNet169( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "dn201": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.DenseNet201( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "irnv2": - img_size = 299 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.InceptionResNetV2(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "iv3": - img_size = 299 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.InceptionV3( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "mn": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.MobileNet( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "mnv2": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.MobileNetV2( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "mnv3l": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.MobileNetV3Large(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "mnv3s": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.MobileNetV3Small(input_shape=( - img_size, img_size, 3), include_top=False, weights='imagenet') - elif self.model == "rn101": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet101( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "rn101v2": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet101V2( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - - elif self.model == "rn152": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet152( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "rn152v2": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet152V2( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "rn50": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet50( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "rn50v2": - img_size = 224 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.ResNet50V2( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - elif self.model == "xception": - img_size = 299 - train, val, class_num = self.load_img_data( - self.path, img_size, img_size, self.batch_size) - body = tf.keras.applications.Xception( - input_shape=( - img_size, - img_size, - 3), - include_top=False, - weights='imagenet') - else: - print("Model not found") + modeldata = {"eb0": [tf.keras.applications.EfficientNetB0, 224], + "eb1": [tf.keras.applications.EfficientNetB1, 240], + "eb2": [tf.keras.applications.EfficientNetB2, 260], + "eb3": [tf.keras.applications.EfficientNetB3, 300], + "eb4": [tf.keras.applications.EfficientNetB4, 340], + "eb5": [tf.keras.applications.EfficientNetB5, 456], + "eb6": [tf.keras.applications.EfficientNetB6, 528], + "eb7": [tf.keras.applications.EfficientNetB7, 600], + "vgg16": [tf.keras.applications.VGG16, 224], + "vgg19": [tf.keras.applications.VGG19, 224], + "dn121": [tf.keras.applications.DenseNet121, 224], + "dn169": [tf.keras.applications.DenseNet169, 224], + "dn201": [tf.keras.applications.DenseNet201, 224], + "irnv2": [tf.keras.applications.InceptionResNetV2, 299], + "iv3": [tf.keras.applications.InceptionV3, 299], + "mn": [tf.keras.applications.MobileNet, 224], + "mnv2": [tf.keras.applications.MobileNetV2, 224], + "mnv3l": [tf.keras.applications.MobileNetV3Large, 224], + "mnv3s": [tf.keras.applications.MobileNetV3Small, 224], + "rn101": [tf.keras.applications.ResNet101, 224], + "rn101v2": [tf.keras.applications.ResNet101V2, 224], + "rn152": [tf.keras.applications.ResNet152, 224], + "rn152v2": [tf.keras.applications.ResNet152V2, 224], + "rn50": [tf.keras.applications.ResNet50, 224], + "rn50v2": [tf.keras.applications.ResNet50V2, 224], + "xception": [tf.keras.applications.Xception, 299]} + + img_size = modeldata[self.model][1] + train, val, class_num = self.load_img_data(self.path, img_size, img_size, self.batch_size) + body = modeldata[self.model][0](input_shape= + (img_size, img_size, 3), + include_top=False, + weights="imagenet") body.trainable = False average_layer = GlobalAveragePooling2D()