Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move no-top models to the separate model types #511

Merged
merged 1 commit into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.io.File
*/
fun efficientNetB0Prediction() {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
val modelType = ONNXModels.CV.EfficientNetB0()
val modelType = ONNXModels.CV.EfficientNetB0
val model = modelHub.loadModel(modelType)
model.printSummary()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fun efficientNetB0EasyPrediction() {
val modelHub =
ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

val model = ONNXModels.CV.EfficientNetB0().pretrainedModel(modelHub)
val model = ONNXModels.CV.EfficientNetB0.pretrainedModel(modelHub)

model.printSummary()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fun efficientNetB7LightAPIPrediction() {
val modelHub =
ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

val model = ONNXModels.CV.EfficientNetB7().pretrainedModel(modelHub)
val model = ONNXModels.CV.EfficientNetB7.pretrainedModel(modelHub)
model.printSummary()

model.use {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing
import java.io.File

private const val EPOCHS = 1
Expand Down Expand Up @@ -54,7 +54,7 @@ private val topModel = Sequential.of(
*/
fun efficientNetB0AdditionalTraining() {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
val modelType = ONNXModels.CV.EfficientNetB0(noTop = true)
val modelType = ONNXModels.CVnoTop.EfficientNetB0

modelHub.loadModel(modelType).use { model ->
println(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing
import java.io.File

private const val EPOCHS = 1
Expand All @@ -43,7 +43,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.8
* We use the preprocessing DSL to describe the dataset generation pipeline.
* We demonstrate the workflow on the subset of Kaggle Cats vs Dogs binary classification dataset.
*/
fun runONNXAdditionalTraining(modelType: ONNXModels.CV) {
fun runONNXAdditionalTraining(modelType: ONNXModels.CVnoTop) {
val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.onnx.dataset.preprocessor.onnx
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModelHub
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels.CVnoTop.Companion.createPreprocessing
import java.io.File

private const val EPOCHS = 3
Expand Down Expand Up @@ -56,7 +56,7 @@ fun resnet50additionalTraining() {
val modelHub = ONNXModelHub(
cacheDirectory = File("cache/pretrainedModels")
)
val modelType = ONNXModels.CV.ResNet50noTopCustom
val modelType = ONNXModels.CVnoTop.ResNet50Custom

modelHub.loadModel(modelType).use { model ->
println(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
* NOTE: Input resolution is 224*224
*/
fun nasNetMobilePrediction() {
runImageRecognitionPrediction(modelType = TFModels.CV.NASNetMobile())
runImageRecognitionPrediction(modelType = TFModels.CV.NASNetMobile)
}

/** */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
Expand All @@ -43,7 +43,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.7
*/
fun resnet50noTopAdditionalTraining() {
val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels"))
val modelType = TFModels.CV.ResNet50(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3))
val modelType = TFModels.CVnoTop.ResNet50(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3))
val model = modelHub.loadModel(modelType)

val hdfFile = modelHub.loadWeights(modelType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
Expand All @@ -44,7 +44,7 @@ private const val TRAIN_TEST_SPLIT_RATIO = 0.7
*/
fun resnet50additionalTrainingNoTopWithHelper() {
val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels"))
val modelType = TFModels.CV.ResNet50(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
val modelType = TFModels.CVnoTop.ResNet50(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
val noTopModel = modelHub.loadModel(modelType)

val hdfFile = modelHub.loadWeights(modelType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.summary.printSummary
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
Expand Down Expand Up @@ -49,7 +49,7 @@ private const val EPOCHS = 2
*/
fun vgg16noTopAdditionalTraining() {
val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels"))
val modelType = TFModels.CV.VGG16(noTop = true, inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3))
val modelType = TFModels.CVnoTop.VGG16(inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, 3))
val model = modelHub.loadModel(modelType)

val layers = mutableListOf<Layer>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.jetbrains.kotlinx.dl.api.inference.keras.*
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CV.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels.CVnoTop.Companion.createPreprocessing
import org.jetbrains.kotlinx.dl.api.summary.printSummary
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
Expand All @@ -33,7 +34,7 @@ private const val NUM_CLASSES = 2
private const val TRAIN_TEST_SPLIT_RATIO = 0.7

fun runImageRecognitionTransferLearning(
modelType: TFModels.CV<out GraphTrainableModel>,
modelType: TFModels.CVnoTop<out GraphTrainableModel>,
epochs: Int = 2
) {
val modelHub = TFModelHub(cacheDirectory = File("cache/pretrainedModels"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,72 +30,72 @@ class OnnxEfficientNetTestSuite {

@Test
fun efficientNetB1PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB1())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB1)
}

@Test
fun efficientNetB1AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB1(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB1)
}

@Test
fun efficientNetB2PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB2())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB2)
}

@Test
fun efficientNetB2AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB2(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB2)
}

@Test
fun efficientNetB3PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB3())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB3)
}

@Test
fun efficientNetB3AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB3(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB3)
}

@Test
fun efficientNetB4PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB4())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB4)
}

@Test
fun efficientNetB4AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB4(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB4)
}

@Test
fun efficientNetB5PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB5())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB5)
}

@Test
fun efficientNetB5AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB5(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB5)
}

@Test
fun efficientNetB6PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB6())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB6)
}

@Test
fun efficientNetB6AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB6(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB6)
}

@Test
fun efficientNetB7PredictionTest() {
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB7())
runONNXImageRecognitionPrediction(ONNXModels.CV.EfficientNetB7)
}

@Test
fun efficientNetB7AdditionalTrainingTest() {
runONNXAdditionalTraining(ONNXModels.CV.EfficientNetB7(noTop = true))
runONNXAdditionalTraining(ONNXModels.CVnoTop.EfficientNetB7)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ class InceptionTestSuite {
@Test
fun inceptionV3dditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.Inception(
noTop = true,
modelType = TFModels.CVnoTop.Inception(
inputShape = intArrayOf(99, 99, 3)
)
)
Expand All @@ -33,8 +32,7 @@ class InceptionTestSuite {
@Test
fun xceptionAdditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.Xception(
noTop = true,
modelType = TFModels.CVnoTop.Xception(
inputShape = intArrayOf(90, 90, 3)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class MobileNetTestSuite {
@Test
fun mobilenetAdditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.MobileNet(
noTop = true,
modelType = TFModels.CVnoTop.MobileNet(
inputShape = intArrayOf(100, 100, 3)
)
)
Expand All @@ -38,8 +37,7 @@ class MobileNetTestSuite {
@Test
fun mobilenetv2AdditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.MobileNetV2(
noTop = true,
modelType = TFModels.CVnoTop.MobileNetV2(
inputShape = intArrayOf(120, 120, 3)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class NasNetTestSuite {

@Test
fun nasNetMobileAdditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(modelType = TFModels.CV.NASNetMobile(noTop = true))
runImageRecognitionTransferLearning(modelType = TFModels.CVnoTop.NASNetMobile)
}

@Test
Expand All @@ -28,8 +28,7 @@ class NasNetTestSuite {
@Test
fun nasNetLargeAdditionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.NASNetLarge(
noTop = true,
modelType = TFModels.CVnoTop.NASNetLarge(
inputShape = intArrayOf(370, 370, 3)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class ResNetTestSuite {
@Test
fun resnet50v2additionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.ResNet50v2(
noTop = true,
modelType = TFModels.CVnoTop.ResNet50v2(
inputShape = intArrayOf(150, 150, 3)
)
)
Expand All @@ -74,8 +73,7 @@ class ResNetTestSuite {
@Test
fun resnet101additionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.ResNet101(
noTop = true,
modelType = TFModels.CVnoTop.ResNet101(
inputShape = intArrayOf(100, 100, 3)
)
)
Expand All @@ -89,8 +87,7 @@ class ResNetTestSuite {
@Test
fun resnet101v2additionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.ResNet101v2(
noTop = true,
modelType = TFModels.CVnoTop.ResNet101v2(
inputShape = intArrayOf(120, 120, 3)
)
)
Expand All @@ -104,8 +101,7 @@ class ResNetTestSuite {
@Test
fun resnet152additionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.ResNet152(
noTop = true,
modelType = TFModels.CVnoTop.ResNet152(
inputShape = intArrayOf(152, 152, 3)
)
)
Expand All @@ -119,8 +115,7 @@ class ResNetTestSuite {
@Test
fun resnet152v2additionalTrainingNoTopTest() {
runImageRecognitionTransferLearning(
modelType = TFModels.CV.ResNet152v2(
noTop = true,
modelType = TFModels.CVnoTop.ResNet152v2(
inputShape = intArrayOf(90, 90, 3)
)
)
Expand Down
Loading