From e98873997b6d77c29877e43d26944f85413afd18 Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Thu, 29 Dec 2022 21:58:06 +0100 Subject: [PATCH] Move no-top models to the separate model types Since no-top models could not be wrapped into ImageRecognitionModelBase they should not be loaded by CV model types. --- .../onnx/cv/efficicentnet/EfficientNetB0.kt | 2 +- .../lightAPI/EfficientNetB0LightAPI.kt | 2 +- .../lightAPI/EfficientNetB7LightAPI.kt | 2 +- ...ntNetB0additionalTrainingWithTensorFlow.kt | 4 +- .../onnx/cv/onnxAdditionalTraining.kt | 4 +- ...esNet50additionalTrainingWithTensorFlow.kt | 4 +- .../modelhub/nasnet/NasNetMobile.kt | 2 +- ...50_prediction_additional_training_noTop.kt | 4 +- ...n_additional_training_noTop_with_helper.kt | 4 +- ...ample_5_VGG16_additional_training_noTop.kt | 4 +- .../transferLearningRunner.kt | 3 +- .../onnx/cv/OnnxEfficientNetTestSuite.kt | 28 +- .../modelhub/inception/InceptionTestSuite.kt | 6 +- .../modelhub/mobilenet/MobileNetTestSuite.kt | 6 +- .../modelhub/nasnet/NasNetTestSuite.kt | 5 +- .../modelhub/resnet/ResNetTestSuite.kt | 15 +- .../kotlinx/dl/onnx/inference/ONNXModels.kt | 259 +++++++++++------- .../dl/onnx/summary/OnnxModelsSummaryTests.kt | 2 +- .../dl/api/inference/loaders/TFModels.kt | 226 ++++++++++++--- 19 files changed, 396 insertions(+), 186 deletions(-) diff --git a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNetB0.kt b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNetB0.kt index 5517e1344..5434bb049 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNetB0.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNetB0.kt @@ -23,7 +23,7 @@ import java.io.File */ fun efficientNetB0Prediction() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = ONNXModels.CV.EfficientNetB0() + val modelType = ONNXModels.CV.EfficientNetB0 val model = modelHub.loadModel(modelType) model.printSummary() diff --git a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB0LightAPI.kt b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB0LightAPI.kt index dcfe1257d..5da17e69e 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB0LightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB0LightAPI.kt @@ -20,7 +20,7 @@ fun efficientNetB0EasyPrediction() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val model = ONNXModels.CV.EfficientNetB0().pretrainedModel(modelHub) + val model = ONNXModels.CV.EfficientNetB0.pretrainedModel(modelHub) model.printSummary() diff --git a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB7LightAPI.kt b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB7LightAPI.kt index b537a66f8..1558f6833 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB7LightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/lightAPI/EfficientNetB7LightAPI.kt @@ -20,7 +20,7 @@ fun efficientNetB7LightAPIPrediction() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val model = ONNXModels.CV.EfficientNetB7().pretrainedModel(modelHub) + val model = ONNXModels.CV.EfficientNetB7.pretrainedModel(modelHub) model.printSummary() model.use { diff --git a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/notop/EfficientNetB0additionalTrainingWithTensorFlow.kt b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/notop/EfficientNetB0additionalTrainingWithTensorFlow.kt index d1507286a..793ed2fa6 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/notop/EfficientNetB0additionalTrainingWithTensorFlow.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/notop/EfficientNetB0additionalTrainingWithTensorFlow.kt @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels -import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing import java.io.File private const val EPOCHS = 1 @@ -54,7 +54,7 @@ private val topModel = Sequential.of( */ fun efficientNetB0AdditionalTraining() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = ONNXModels.CV.EfficientNetB0(noTop = true) + val modelType = ONNXModels.CVnoTop.EfficientNetB0 modelHub.loadModel(modelType).use { model -> println(model) diff --git a/examples/src/main/kotlin/examples/onnx/cv/onnxAdditionalTraining.kt b/examples/src/main/kotlin/examples/onnx/cv/onnxAdditionalTraining.kt index dd47cc51a..b738ca2f5 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/onnxAdditionalTraining.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/onnxAdditionalTraining.kt @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels -import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing import java.io.File private const val EPOCHS = 1 @@ -43,7 +43,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.8 * We use the preprocessing DSL to describe the dataset generation pipeline. * We demonstrate the workflow on the subset of Kaggle Cats vs Dogs binary classification dataset. */ -fun runONNXAdditionalTraining(modelType: ONNXModels.CV) { +fun runONNXAdditionalTraining(modelType: ONNXModels.CVnoTop) { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath() diff --git a/examples/src/main/kotlin/examples/onnx/cv/resnet/notop/ResNet50additionalTrainingWithTensorFlow.kt b/examples/src/main/kotlin/examples/onnx/cv/resnet/notop/ResNet50additionalTrainingWithTensorFlow.kt index 16656f980..efbf264dd 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/resnet/notop/ResNet50additionalTrainingWithTensorFlow.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/resnet/notop/ResNet50additionalTrainingWithTensorFlow.kt @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels -import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing import java.io.File private const val EPOCHS = 3 @@ -56,7 +56,7 @@ fun resnet50additionalTraining() { val modelHub = ONNXModelHub( cacheDirectory = File("cache/pretrainedModels") ) - val modelType = ONNXModels.CV.ResNet50noTopCustom + val modelType = ONNXModels.CVnoTop.ResNet50Custom modelHub.loadModel(modelType).use { model -> println(model) diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/nasnet/NasNetMobile.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/nasnet/NasNetMobile.kt index b6fae0a5f..59af0d6e9 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/nasnet/NasNetMobile.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/nasnet/NasNetMobile.kt @@ -20,7 +20,7 @@ import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels * NOTE: Input resolution is 224*224 */ fun nasNetMobilePrediction() { - runImageRecognitionPrediction(modelType = TFModels.CV.NASNetMobile()) + runImageRecognitionPrediction(modelType = TFModels.CV.NASNetMobile) } /** */ diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_5_ResNet50_prediction_additional_training_noTop.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_5_ResNet50_prediction_additional_training_noTop.kt index b2466c6fc..e2203cc47 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_5_ResNet50_prediction_additional_training_noTop.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_5_ResNet50_prediction_additional_training_noTop.kt @@ -17,7 +17,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels -import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders @@ -43,7 +43,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.7 */ fun resnet50noTopAdditionalTraining() { val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = TFModels.CV.ResNet50(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3)) + val modelType = TFModels.CVnoTop.ResNet50(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3)) val model = modelHub.loadModel(modelType) val hdfFile = modelHub.loadWeights(modelType) diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_6_ResNet50_prediction_additional_training_noTop_with_helper.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_6_ResNet50_prediction_additional_training_noTop_with_helper.kt index fc2f0fbc6..41398923a 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_6_ResNet50_prediction_additional_training_noTop_with_helper.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/resnet/Example_6_ResNet50_prediction_additional_training_noTop_with_helper.kt @@ -17,7 +17,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels -import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders @@ -44,7 +44,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.7 */ fun resnet50additionalTrainingNoTopWithHelper() { val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = TFModels.CV.ResNet50(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) + val modelType = TFModels.CVnoTop.ResNet50(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) val noTopModel = modelHub.loadModel(modelType) val hdfFile = modelHub.loadWeights(modelType) diff --git a/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg16/Example_5_VGG16_additional_training_noTop.kt b/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg16/Example_5_VGG16_additional_training_noTop.kt index 591d792ab..797397fc9 100644 --- a/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg16/Example_5_VGG16_additional_training_noTop.kt +++ b/examples/src/main/kotlin/examples/transferlearning/modelhub/vgg16/Example_5_VGG16_additional_training_noTop.kt @@ -18,7 +18,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels -import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing import org.jetbrains.kotlinx.dl.api.summary.printSummary import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath @@ -49,7 +49,7 @@ private const val EPOCHS = 2 */ fun vgg16noTopAdditionalTraining() { val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = TFModels.CV.VGG16(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3)) + val modelType = TFModels.CVnoTop.VGG16(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3)) val model = modelHub.loadModel(modelType) val layers = mutableListOf() diff --git a/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt b/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt index 0e12d0bd7..5d18af177 100644 --- a/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt +++ b/examples/src/main/kotlin/examples/transferlearning/transferLearningRunner.kt @@ -20,6 +20,7 @@ import org.jetbrains.kotlinx.dl.api.inference.keras.* import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing +import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing import org.jetbrains.kotlinx.dl.api.summary.printSummary import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath @@ -33,7 +34,7 @@ private const val NUM_CLASSES = 2 private const val TRAIN_TEST_SPLIT_RATIO = 0.7 fun runImageRecognitionTransferLearning( - modelType: TFModels.CV, + modelType: TFModels.CVnoTop, epochs: Int = 2 ) { val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels")) diff --git a/examples/src/test/kotlin/examples/onnx/cv/OnnxEfficientNetTestSuite.kt b/examples/src/test/kotlin/examples/onnx/cv/OnnxEfficientNetTestSuite.kt index 1e4cdcea2..005f0dd29 100644 --- a/examples/src/test/kotlin/examples/onnx/cv/OnnxEfficientNetTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/cv/OnnxEfficientNetTestSuite.kt @@ -30,72 +30,72 @@ class OnnxEfficientNetTestSuite { @Test fun efficientNetB1PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB1()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB1) } @Test fun efficientNetB1AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB1(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB1) } @Test fun efficientNetB2PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB2()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB2) } @Test fun efficientNetB2AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB2(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB2) } @Test fun efficientNetB3PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB3()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB3) } @Test fun efficientNetB3AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB3(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB3) } @Test fun efficientNetB4PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB4()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB4) } @Test fun efficientNetB4AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB4(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB4) } @Test fun efficientNetB5PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB5()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB5) } @Test fun efficientNetB5AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB5(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB5) } @Test fun efficientNetB6PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB6()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB6) } @Test fun efficientNetB6AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB6(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB6) } @Test fun efficientNetB7PredictionTest() { - runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB7()) + runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB7) } @Test fun efficientNetB7AdditionalTrainingTest() { - runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB7(noTop = true)) + runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB7) } @Test diff --git a/examples/src/test/kotlin/examples/transferlearning/modelhub/inception/InceptionTestSuite.kt b/examples/src/test/kotlin/examples/transferlearning/modelhub/inception/InceptionTestSuite.kt index cf7cff7e2..c1a954e58 100644 --- a/examples/src/test/kotlin/examples/transferlearning/modelhub/inception/InceptionTestSuite.kt +++ b/examples/src/test/kotlin/examples/transferlearning/modelhub/inception/InceptionTestSuite.kt @@ -18,8 +18,7 @@ class InceptionTestSuite { @Test fun inceptionV3dditionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.Inception( - noTop = true, + modelType = TFModels.CVnoTop.Inception( inputShape = intArrayOf(99, 99, 3) ) ) @@ -33,8 +32,7 @@ class InceptionTestSuite { @Test fun xceptionAdditionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.Xception( - noTop = true, + modelType = TFModels.CVnoTop.Xception( inputShape = intArrayOf(90, 90, 3) ) ) diff --git a/examples/src/test/kotlin/examples/transferlearning/modelhub/mobilenet/MobileNetTestSuite.kt b/examples/src/test/kotlin/examples/transferlearning/modelhub/mobilenet/MobileNetTestSuite.kt index 96740f0ff..f0a5035eb 100644 --- a/examples/src/test/kotlin/examples/transferlearning/modelhub/mobilenet/MobileNetTestSuite.kt +++ b/examples/src/test/kotlin/examples/transferlearning/modelhub/mobilenet/MobileNetTestSuite.kt @@ -23,8 +23,7 @@ class MobileNetTestSuite { @Test fun mobilenetAdditionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.MobileNet( - noTop = true, + modelType = TFModels.CVnoTop.MobileNet( inputShape = intArrayOf(100, 100, 3) ) ) @@ -38,8 +37,7 @@ class MobileNetTestSuite { @Test fun mobilenetv2AdditionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.MobileNetV2( - noTop = true, + modelType = TFModels.CVnoTop.MobileNetV2( inputShape = intArrayOf(120, 120, 3) ) ) diff --git a/examples/src/test/kotlin/examples/transferlearning/modelhub/nasnet/NasNetTestSuite.kt b/examples/src/test/kotlin/examples/transferlearning/modelhub/nasnet/NasNetTestSuite.kt index 8146844ff..61e5b4c79 100644 --- a/examples/src/test/kotlin/examples/transferlearning/modelhub/nasnet/NasNetTestSuite.kt +++ b/examples/src/test/kotlin/examples/transferlearning/modelhub/nasnet/NasNetTestSuite.kt @@ -17,7 +17,7 @@ class NasNetTestSuite { @Test fun nasNetMobileAdditionalTrainingNoTopTest() { - runImageRecognitionTransferLearning(modelType = TFModels.CV.NASNetMobile(noTop = true)) + runImageRecognitionTransferLearning(modelType = TFModels.CVnoTop.NASNetMobile) } @Test @@ -28,8 +28,7 @@ class NasNetTestSuite { @Test fun nasNetLargeAdditionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.NASNetLarge( - noTop = true, + modelType = TFModels.CVnoTop.NASNetLarge( inputShape = intArrayOf(370, 370, 3) ) ) diff --git a/examples/src/test/kotlin/examples/transferlearning/modelhub/resnet/ResNetTestSuite.kt b/examples/src/test/kotlin/examples/transferlearning/modelhub/resnet/ResNetTestSuite.kt index 615285b0a..b2cc7dc1b 100644 --- a/examples/src/test/kotlin/examples/transferlearning/modelhub/resnet/ResNetTestSuite.kt +++ b/examples/src/test/kotlin/examples/transferlearning/modelhub/resnet/ResNetTestSuite.kt @@ -59,8 +59,7 @@ class ResNetTestSuite { @Test fun resnet50v2additionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.ResNet50v2( - noTop = true, + modelType = TFModels.CVnoTop.ResNet50v2( inputShape = intArrayOf(150, 150, 3) ) ) @@ -74,8 +73,7 @@ class ResNetTestSuite { @Test fun resnet101additionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.ResNet101( - noTop = true, + modelType = TFModels.CVnoTop.ResNet101( inputShape = intArrayOf(100, 100, 3) ) ) @@ -89,8 +87,7 @@ class ResNetTestSuite { @Test fun resnet101v2additionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.ResNet101v2( - noTop = true, + modelType = TFModels.CVnoTop.ResNet101v2( inputShape = intArrayOf(120, 120, 3) ) ) @@ -104,8 +101,7 @@ class ResNetTestSuite { @Test fun resnet152additionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.ResNet152( - noTop = true, + modelType = TFModels.CVnoTop.ResNet152( inputShape = intArrayOf(152, 152, 3) ) ) @@ -119,8 +115,7 @@ class ResNetTestSuite { @Test fun resnet152v2additionalTrainingNoTopTest() { runImageRecognitionTransferLearning( - modelType = TFModels.CV.ResNet152v2( - noTop = true, + modelType = TFModels.CVnoTop.ResNet152v2( inputShape = intArrayOf(90, 90, 3) ) ) diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt index 2cd14086a..27cb536a8 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt @@ -38,13 +38,10 @@ public object ONNXModels { * Note: the wrong choice of this parameter can significantly impact the model's performance. * */ public sealed class CV( - relativePath: String, - protected val channelsFirst: Boolean, - private val inputColorMode: ColorMode = ColorMode.RGB, - /** If true, model is shipped without last few layers and could be used for transfer learning and fine-tuning with TF Runtime. */ - noTop: Boolean = false + override val modelRelativePath: String, + internal val channelsFirst: Boolean, + internal val inputColorMode: ColorMode = ColorMode.RGB ) : OnnxModelType { - override val modelRelativePath: String = if (noTop) "$relativePath-notop" else relativePath override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel { return ImageRecognitionModel( @@ -317,7 +314,7 @@ public object ONNXModels { * - an input with the shape (1x224x224x3) * - an output with the shape (1x1000) * - * NOTE: This model is converted from Keras.applications and could be used to be compared with the [ResNet50noTopCustom] model. + * No-top version is available as [CVnoTop.ResNet50Custom]. * * @see * Deep Residual Learning for Image Recognition (CVPR 2015) @@ -334,29 +331,6 @@ public object ONNXModels { get() = InputType.CAFFE.preprocessing() } - /** - * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes - * (labels are available via [org.jetbrains.kotlinx.dl.impl.dataset.Imagenet.V1k.labels] method). - * - * This model has 50 layers with ResNetv1 architecture. - * - * The model have - * - an input with the shape (1x224x224x3) - * - an output with the shape (N,M3,M4,2048) - * - * NOTE: This model is converted from Keras.applications, the last two layers in the model have been removed so that the user can fine-tune the model for his specific task. - * - * @see - * Deep Residual Learning for Image Recognition (CVPR 2015) - * @see - * Official ResNet model from Keras.applications. - */ - public object ResNet50noTopCustom : - CV("models/onnx/cv/custom/resnet50notop", channelsFirst = false) { - override val preprocessor: Operation - get() = InputType.CAFFE.preprocessing() - } - /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes * (labels are available via [org.jetbrains.kotlinx.dl.impl.dataset.Imagenet.V1k.labels] method). @@ -366,21 +340,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x224x224x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x7x7x1280) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB0]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB0 model from Keras.applications. */ - public class EfficientNetB0(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b0", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB0 : CV("models/onnx/cv/efficientnet/efficientnet-b0", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -391,21 +359,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x240x240x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x7x7x1280) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB1]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB1 model from Keras.applications. */ - public class EfficientNetB1(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b1", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB1 : CV("models/onnx/cv/efficientnet/efficientnet-b1", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -416,21 +378,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x260x260x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x8x8x1408) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB2]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB2 model from Keras.applications. */ - public class EfficientNetB2(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b2", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB2 : CV("models/onnx/cv/efficientnet/efficientnet-b2", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -441,21 +397,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x300x300x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x9x9x1536) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB3]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB3 model from Keras.applications. */ - public class EfficientNetB3(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b3", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB3 : CV("models/onnx/cv/efficientnet/efficientnet-b3", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -466,21 +416,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x380x380x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x11x11x1792) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB4]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB4 model from Keras.applications. */ - public class EfficientNetB4(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b4", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB4 : CV("models/onnx/cv/efficientnet/efficientnet-b4", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -491,21 +435,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x456x456x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x14x14x2048) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB5]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB5 model from Keras.applications. */ - public class EfficientNetB5(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b5", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB5 : CV("models/onnx/cv/efficientnet/efficientnet-b5", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -516,21 +454,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x528x528x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x16x16x2304) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB6]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB6 model from Keras.applications. */ - public class EfficientNetB6(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b6", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB6 : CV("models/onnx/cv/efficientnet/efficientnet-b6", channelsFirst = false) /** * This model is a neural network for image classification that take images as input and classify the major object in the image into a set of 1000 different classes @@ -541,21 +473,15 @@ public object ONNXModels { * The model have * - an input with the shape (1x600x600x3) * - an output with the shape (1x1000) - * - an output for noTop model with the shape (1x18x18x2560) * - * NOTE: This model is converted from Keras.applications, the last two layers in the noTop model have been removed so that the user can fine-tune the model for his specific task. + * No-top version is available as [CVnoTop.EfficientNetB7]. * * @see * EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks * @see * Official EfficientNetB7 model from Keras.applications. */ - public class EfficientNetB7(noTop: Boolean = false) : - CV( - "models/onnx/cv/efficientnet/efficientnet-b7", - channelsFirst = false, - noTop = noTop - ) + public object EfficientNetB7 : CV("models/onnx/cv/efficientnet/efficientnet-b7", channelsFirst = false) /** * This model is a neural network for digit classification that take grey-scale images of digits as input and classify the major object in the image into a set of 10 different classes. @@ -578,6 +504,151 @@ public object ONNXModels { } } + /** + * Image classification models without top layers. + */ + public sealed class CVnoTop( + override val modelRelativePath: String, + protected val channelsFirst: Boolean, + private val inputColorMode: ColorMode, + override val preprocessor: Operation + ) : OnnxModelType { + + protected constructor(baseModelType: CV) : this( + "${baseModelType.modelRelativePath}-notop", + baseModelType.channelsFirst, + baseModelType.inputColorMode, + baseModelType.preprocessor + ) + + override fun pretrainedModel(modelHub: ModelHub): OnnxInferenceModel { + return modelHub.loadModel(this) + } + + /** + * This model is a no-top version of the [CV.ResNet50custom] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x224x224x3) + * - an output with the shape (N,M3,M4,2048) + * + * @see CV.ResNet50custom + */ + public object ResNet50Custom : + CVnoTop( + "models/onnx/cv/custom/resnet50notop", + inputColorMode = ColorMode.RGB, + channelsFirst = false, + preprocessor = InputType.CAFFE.preprocessing() + ) + + /** + * This model is a no-top version of the [CV.EfficientNetB0] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x224x224x3) + * - an output with the shape (1x7x7x1280) + * + * @see CV.EfficientNetB0 + */ + public object EfficientNetB0 : CVnoTop(CV.EfficientNetB0) + + /** + * This model is a no-top version of the [CV.EfficientNetB1] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x240x240x3) + * - an output with the shape (1x7x7x1280) + * + * @see CV.EfficientNetB1 + */ + public object EfficientNetB1 : CVnoTop(CV.EfficientNetB1) + + /** + * This model is a no-top version of the [CV.EfficientNetB2] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x260x260x3) + * - an output with the shape (1x8x8x1408) + * + * @see CV.EfficientNetB2 + */ + public object EfficientNetB2 : CVnoTop(CV.EfficientNetB3) + + /** + * This model is a no-top version of the [CV.EfficientNetB3] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x300x300x3) + * - an output with the shape (1x9x9x1536) + * + * @see CV.EfficientNetB3 + */ + public object EfficientNetB3 : CVnoTop(CV.EfficientNetB3) + + /** + * This model is a no-top version of the [CV.EfficientNetB4] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x380x380x3) + * - an output with the shape (1x11x11x1792) + * + * @see CV.EfficientNetB4 + */ + public object EfficientNetB4 : CVnoTop(CV.EfficientNetB4) + + /** + * This model is a no-top version of the [CV.EfficientNetB5] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x456x456x3) + * - an output with the shape (1x14x14x2048) + * + * @see CV.EfficientNetB5 + */ + public object EfficientNetB5 : CVnoTop(CV.EfficientNetB5) + + /** + * This model is a no-top version of the [CV.EfficientNetB6] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x528x528x3) + * - an output with the shape (1x16x16x2304) + * + * @see CV.EfficientNetB6 + */ + public object EfficientNetB6 : CVnoTop(CV.EfficientNetB6) + + /** + * This model is a no-top version of the [CV.EfficientNetB7] model. + * The last two layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * The model has + * - an input with the shape (1x600x600x3) + * - an output with the shape (1x18x18x2560) + * + * @see CV.EfficientNetB7 + */ + public object EfficientNetB7 : CVnoTop(CV.EfficientNetB7) + + public companion object { + /** + * Creates a preprocessing [Operation] which converts given [BufferedImage] to [FloatData] suitable for this [model]. + */ + public fun CVnoTop.createPreprocessing(model: InferenceModel): Operation { + return createPreprocessing(model, channelsFirst, inputColorMode, preprocessor) + } + } + } + /** Object detection models and preprocessing. */ public sealed class ObjectDetection(override val modelRelativePath: String) : OnnxModelType { diff --git a/onnx/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/onnx/summary/OnnxModelsSummaryTests.kt b/onnx/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/onnx/summary/OnnxModelsSummaryTests.kt index de9aca24d..0cda66f79 100644 --- a/onnx/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/onnx/summary/OnnxModelsSummaryTests.kt +++ b/onnx/src/jvmTest/kotlin/org/jetbrains/kotlinx/dl/onnx/summary/OnnxModelsSummaryTests.kt @@ -53,7 +53,7 @@ class OnnxModelsSummaryTests { private val efficientNetB0Summary = ModelHubModelSummary( internalSummary = efficientNetB0InternalModelSummary, - modelKindDescription = ONNXModels.CV.EfficientNetB0()::class.simpleName + modelKindDescription = ONNXModels.CV.EfficientNetB0::class.simpleName ) private val internalModelExpectedFormat = listOf( diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/loaders/TFModels.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/loaders/TFModels.kt index 8134a0823..4bc26809e 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/loaders/TFModels.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/loaders/TFModels.kt @@ -37,15 +37,12 @@ public object TFModels { * Note: the wrong choice of this parameter can significantly impact the model's performance. * */ public sealed class CV( - relativePath: String, - private val channelsFirst: Boolean = false, - private val inputColorMode: ColorMode = ColorMode.RGB, - public var inputShape: IntArray? = null, - noTop: Boolean = false + override val modelRelativePath: String, + internal val channelsFirst: Boolean = false, + internal val inputColorMode: ColorMode = ColorMode.RGB, + public var inputShape: IntArray? = null ) : TFModelType { - override val modelRelativePath: String = if (noTop) "$relativePath/notop" else relativePath - init { if (inputShape != null) { require(inputShape!!.size == 3) { "Input shape for the model ${this.javaClass.kotlin.simpleName} should contain 3 number: height, weight and number of channels." } @@ -80,11 +77,10 @@ public object TFModels { * @see * Official VGG16 model from Keras.applications. */ - public class VGG16(noTop: Boolean = false, inputShape: IntArray? = null) : + public class VGG16(inputShape: IntArray? = null) : CV( "models/tensorflow/cv/vgg16", inputShape = inputShape, - noTop = noTop, inputColorMode = ColorMode.BGR ) { override val preprocessor: Operation @@ -112,11 +108,10 @@ public object TFModels { * @see * Official VGG19 model from Keras.applications. */ - public class VGG19(noTop: Boolean = false, inputShape: IntArray? = null) : + public class VGG19(inputShape: IntArray? = null) : CV( "models/tensorflow/cv/vgg19", inputShape = inputShape, - noTop = noTop, inputColorMode = ColorMode.BGR ) { override val preprocessor: Operation @@ -188,11 +183,10 @@ public object TFModels { * @see * Official ResNet50 model from Keras.applications. */ - public class ResNet50(noTop: Boolean = false, inputShape: IntArray? = null) : + public class ResNet50(inputShape: IntArray? = null) : CV( "models/tensorflow/cv/resnet50", inputShape = inputShape, - noTop = noTop, inputColorMode = ColorMode.BGR ) { override val preprocessor: Operation @@ -218,11 +212,10 @@ public object TFModels { * @see * Official ResNet101 model from Keras.applications. */ - public class ResNet101(noTop: Boolean = false, inputShape: IntArray? = null) : + public class ResNet101(inputShape: IntArray? = null) : CV( "models/tensorflow/cv/resnet101", inputShape = inputShape, - noTop = noTop, inputColorMode = ColorMode.BGR ) { override val preprocessor: Operation @@ -248,11 +241,10 @@ public object TFModels { * @see * Official ResNet152 model from Keras.applications. */ - public class ResNet152(noTop: Boolean = false, inputShape: IntArray? = null) : + public class ResNet152(inputShape: IntArray? = null) : CV( "models/tensorflow/cv/resnet152", inputShape = inputShape, - noTop = noTop, inputColorMode = ColorMode.BGR ) { override val preprocessor: Operation @@ -278,8 +270,8 @@ public object TFModels { * @see * Official ResNet50v2 model from Keras.applications. */ - public class ResNet50v2(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/resnet50v2", inputShape = inputShape, noTop = noTop) { + public class ResNet50v2(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/resnet50v2", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -303,8 +295,8 @@ public object TFModels { * @see * Official ResNet101v2 model from Keras.applications. */ - public class ResNet101v2(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/resnet101v2", inputShape = inputShape, noTop = noTop) { + public class ResNet101v2(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/resnet101v2", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -328,8 +320,8 @@ public object TFModels { * @see * Official ResNet152v2 model from Keras.applications. */ - public class ResNet152v2(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/resnet152v2", inputShape = inputShape, noTop = noTop) { + public class ResNet152v2(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/resnet152v2", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -351,8 +343,8 @@ public object TFModels { * @see * Official MobileNet model from Keras.applications. */ - public class MobileNet(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/mobilenet", inputShape = inputShape, noTop = noTop) { + public class MobileNet(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/mobilenet", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -374,8 +366,8 @@ public object TFModels { * @see * Official MobileNetV2 model from Keras.applications. */ - public class MobileNetV2(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/mobilenetv2", inputShape = inputShape, noTop = noTop) { + public class MobileNetV2(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/mobilenetv2", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -397,8 +389,8 @@ public object TFModels { * @see * Official InceptionV3 model from Keras.applications. */ - public class Inception(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/inception", inputShape = inputShape, noTop = noTop) { + public class Inception(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/inception", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -420,8 +412,8 @@ public object TFModels { * @see * Official Xception model from Keras.applications. */ - public class Xception(noTop: Boolean = false, inputShape: IntArray? = null) : - CV("models/tensorflow/cv/xception", inputShape = inputShape, noTop = noTop) { + public class Xception(inputShape: IntArray? = null) : + CV("models/tensorflow/cv/xception", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -444,7 +436,7 @@ public object TFModels { * Official DenseNet121 model from Keras.applications. */ public class DenseNet121(inputShape: IntArray? = null) : - CV("models/tensorflow/cv/densenet121", inputShape = inputShape, noTop = false) { + CV("models/tensorflow/cv/densenet121", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TORCH.preprocessing() } @@ -467,7 +459,7 @@ public object TFModels { * Official DenseNet169 model from Keras.applications. */ public class DenseNet169(inputShape: IntArray? = null) : - CV("models/tensorflow/cv/densenet169", inputShape = inputShape, noTop = false) { + CV("models/tensorflow/cv/densenet169", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TORCH.preprocessing() } @@ -490,7 +482,7 @@ public object TFModels { * Official DenseNet201 model from Keras.applications. */ public class DenseNet201(inputShape: IntArray? = null) : - CV("models/tensorflow/cv/densenet201", inputShape = inputShape, noTop = false) { + CV("models/tensorflow/cv/densenet201", inputShape = inputShape) { override val preprocessor: Operation get() = InputType.TORCH.preprocessing() } @@ -512,8 +504,8 @@ public object TFModels { * @see * Official NASNetMobile model from Keras.applications. */ - public class NASNetMobile(noTop: Boolean = false) : - CV("models/tensorflow/cv/nasnetmobile", inputShape = intArrayOf(224, 224, 3), noTop = noTop) { + public object NASNetMobile : + CV("models/tensorflow/cv/nasnetmobile", inputShape = intArrayOf(224, 224, 3)) { override val preprocessor: Operation get() = InputType.TF.preprocessing() } @@ -535,8 +527,8 @@ public object TFModels { * @see * Official NASNetLarge model from Keras.applications. */ - public class NASNetLarge(noTop: Boolean = false, inputShape: IntArray? = intArrayOf(331, 331, 3)) : - CV("models/tensorflow/cv/nasnetlarge", inputShape = inputShape, noTop = noTop) { + public class NASNetLarge(inputShape: IntArray? = intArrayOf(331, 331, 3)) : + CV("models/tensorflow/cv/nasnetlarge", inputShape = inputShape) { init { require(inputShape!![0] >= 331 && inputShape[1] >= 331) { "Width and height should be no smaller than 331 for the model ${this.javaClass.kotlin.simpleName}." } } @@ -555,7 +547,163 @@ public object TFModels { } } - private fun loadModel(modelHub: ModelHub, modelType: CV): GraphTrainableModel { + /** + * Image classification models without top layers. + */ + public sealed class CVnoTop(private val baseModelType: CV) : TFModelType { + override val modelRelativePath: String = "${baseModelType.modelRelativePath}/notop" + + override val preprocessor: Operation + get() = baseModelType.preprocessor + + override fun pretrainedModel(modelHub: ModelHub): T { + return loadModel(modelHub, this) + } + + override fun loadModelConfiguration(jsonFile: File): T { + return baseModelType.loadModelConfiguration(jsonFile) + } + + /** + * This model is a no-top version of the [CV.VGG16] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.VGG16 + */ + public class VGG16(inputShape: IntArray? = null) : CVnoTop(CV.VGG16(inputShape)) + + /** + * This model is a no-top version of the [CV.VGG19] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.VGG19 + */ + public class VGG19(inputShape: IntArray? = null) : CVnoTop(CV.VGG19(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet18] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet18 + */ + public class ResNet18(inputShape: IntArray? = null) : CVnoTop(CV.ResNet18(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet34] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet34 + */ + public class ResNet34(inputShape: IntArray? = null) : CVnoTop(CV.ResNet34(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet50] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet50 + */ + public class ResNet50(inputShape: IntArray? = null) : CVnoTop(CV.ResNet50(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet101] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet101 + */ + public class ResNet101(inputShape: IntArray? = null) : CVnoTop(CV.ResNet101(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet152] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet152 + */ + public class ResNet152(inputShape: IntArray? = null) : CVnoTop(CV.ResNet152(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet50v2] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet50v2 + */ + public class ResNet50v2(inputShape: IntArray? = null) : CVnoTop(CV.ResNet50v2(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet101v2] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet101v2 + */ + public class ResNet101v2(inputShape: IntArray? = null) : CVnoTop(CV.ResNet101v2(inputShape)) + + /** + * This model is a no-top version of the [CV.ResNet152v2] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.ResNet152v2 + */ + public class ResNet152v2(inputShape: IntArray? = null) : CVnoTop(CV.ResNet152v2(inputShape)) + + /** + * This model is a no-top version of the [CV.MobileNet] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.MobileNet + */ + public class MobileNet(inputShape: IntArray? = null) : CVnoTop(CV.MobileNet(inputShape)) + + /** + * This model is a no-top version of the [CV.MobileNetV2] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.MobileNetV2 + */ + public class MobileNetV2(inputShape: IntArray? = null) : CVnoTop(CV.MobileNetV2(inputShape)) + + /** + * This model is a no-top version of the [CV.Inception] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.Inception + */ + public class Inception(inputShape: IntArray? = null) : CVnoTop(CV.Inception(inputShape)) + + /** + * This model is a no-top version of the [CV.Xception] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.Xception + */ + public class Xception(inputShape: IntArray? = null) : CVnoTop(CV.Xception(inputShape)) + + /** + * This model is a no-top version of the [CV.NASNetMobile] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.NASNetMobile + */ + public object NASNetMobile : CVnoTop(CV.NASNetMobile) + + /** + * This model is a no-top version of the [CV.NASNetLarge] model. + * The last few layers of the base model have been removed so that the model could be fine-tuned for users specific task. + * + * @see CV.NASNetLarge + */ + public class NASNetLarge(inputShape: IntArray? = intArrayOf(331, 331, 3)) : + CVnoTop(CV.NASNetLarge(inputShape)) + + public companion object { + /** + * Creates a preprocessing [Operation] which converts given [BufferedImage] to [FloatData] suitable for this [model]. + */ + public fun CVnoTop<*>.createPreprocessing(model: InferenceModel): Operation { + return createPreprocessing(model, baseModelType.channelsFirst, baseModelType.inputColorMode, preprocessor) + } + } + } + + private fun loadModel(modelHub: ModelHub, modelType: TFModelType): T { modelHub as TFModelHub val model = modelHub.loadModel(modelType) // TODO: this part is not needed for inference (if we could add manually Softmax at the end of the graph)