Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new models to ModelHub: PoseNet and EfficientDet family #317

Merged
merged 20 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public object TFModels {
): ImageRecognitionModel {
modelHub as TFModelHub
val model = modelHub.loadModel(modelType)

// TODO: this part would not be needed for inference if we could manually add Softmax at the end of the graph
model.compile(
optimizer = Adam(),
loss = Losses.MAE,
Expand All @@ -221,6 +221,7 @@ public object TFModels {
}

/** Basic interface for models loaded from S3. */
// TODO: add information about T and U types
public interface ModelType<T : InferenceModel, U : InferenceModel> {
/** Relative path to model for local and S3 buckets storages. */
public val modelRelativePath: String
Expand Down Expand Up @@ -256,4 +257,8 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
public fun model(modelHub: ModelHub): T {
return modelHub.loadModel(this)
}

/**
 * Stub default implementation: it is not implemented yet, so every call
 * throws [NotImplementedError] (raised by [TODO]) instead of returning a model.
 */
public fun preInit(): InferenceModel = TODO()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.jetbrains.kotlinx.dl.api.inference.posedetection

/**
 * A single detected pose, described by its landmarks and the edges between them.
 *
 * @property [poseLandmarks] The landmarks that make up the pose.
 * @property [edges] The edges of the pose; each [PoseEdge] references a start and an end [PoseLandmark].
 */
public data class DetectedPose (
val poseLandmarks: List<PoseLandmark>,
val edges: List<PoseEdge>
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.jetbrains.kotlinx.dl.api.inference.posedetection

import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject

/**
 * Result of a multi-pose detection: a collection of poses, each paired with a detected object.
 *
 * @property [multiplePoses] Mutable list of pairs, each combining a [DetectedObject] with a [DetectedPose].
 * NOTE(review): presumably the [DetectedObject] is the detection (e.g. a person) the pose belongs to — confirm
 * against the code that fills this list.
 */
public data class MultiPoseDetectionResult (
val multiplePoses: MutableList<Pair<DetectedObject, DetectedPose>>
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.jetbrains.kotlinx.dl.api.inference.posedetection

/**
 * An edge of a detected pose that connects two landmarks.
 *
 * @property [poseEdgeLabel] The predicted pose edge label.
 * @property [probability] The probability associated with the predicted edge.
 * @property [start] The landmark at which the edge starts.
 * @property [end] The landmark at which the edge ends.
 */
public data class PoseEdge(
val poseEdgeLabel: String,
val probability: Float,
val start: PoseLandmark,
val end: PoseLandmark,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.jetbrains.kotlinx.dl.api.inference.posedetection

/**
 * A single landmark of a detected pose.
 *
 * @property [poseLandmarkLabel] The predicted pose landmark label.
 * @property [probability] The probability associated with the predicted landmark.
 * @property [x] The x coordinate of the landmark.
 * NOTE(review): the coordinate system (pixels vs. normalized to the image size) is not visible here — confirm.
 * @property [y] The y coordinate of the landmark; same coordinate system as [x].
 */
public data class PoseLandmark (
val poseLandmarkLabel: String,
val probability: Float,
val x: Float,
val y: Float,
)
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,93 @@ public val cocoCategories: Map<Int, String> = mapOf(
79 to "hair drier",
80 to "toothbrush"
)


/**
 * 80 object categories in COCO dataset, keyed by non-contiguous label ids.
 *
 * Unlike [cocoCategories], the keys of this map run from 1 to 90 with gaps:
 * there are no entries for 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83,
 * which leaves exactly 80 categories.
 *
 * NOTE(review): the name suggests that these are the label ids produced by the EfficientDet models —
 * confirm against the code that consumes this map.
 *
 * @see <a href="https://cocodataset.org/#home">
 * COCO dataset</a>
 */
public val cocoCategoriesForEfficientDet: Map<Int, String> = mapOf(
1 to "person",
2 to "bicycle",
3 to "car",
4 to "motorcycle",
5 to "airplane",
6 to "bus",
7 to "train",
8 to "truck",
9 to "boat",
10 to "traffic light",
11 to "fire hydrant",
13 to "stop sign",
14 to "parking meter",
15 to "bench",
16 to "bird",
17 to "cat",
18 to "dog",
19 to "horse",
20 to "sheep",
21 to "cow",
22 to "elephant",
23 to "bear",
24 to "zebra",
25 to "giraffe",
27 to "backpack",
28 to "umbrella",
31 to "handbag",
32 to "tie",
33 to "suitcase",
34 to "frisbee",
35 to "skis",
36 to "snowboard",
37 to "sports ball",
38 to "kite",
39 to "baseball bat",
40 to "baseball glove",
41 to "skateboard",
42 to "surfboard",
43 to "tennis racket",
44 to "bottle",
46 to "wine glass",
47 to "cup",
48 to "fork",
49 to "knife",
50 to "spoon",
51 to "bowl",
52 to "banana",
53 to "apple",
54 to "sandwich",
55 to "orange",
56 to "broccoli",
57 to "carrot",
58 to "hot dog",
59 to "pizza",
60 to "donut",
61 to "cake",
62 to "chair",
63 to "couch",
64 to "potted plant",
65 to "bed",
67 to "dining table",
70 to "toilet",
72 to "tv",
73 to "laptop",
74 to "mouse",
75 to "remote",
76 to "keyboard",
77 to "cell phone",
78 to "microwave",
79 to "oven",
80 to "toaster",
81 to "sink",
82 to "refrigerator",
84 to "book",
85 to "clock",
86 to "vase",
87 to "scissors",
88 to "teddy bear",
89 to "hair drier",
90 to "toothbrush"
)
2 changes: 1 addition & 1 deletion examples/src/main/kotlin/examples/cnn/cifar10/VGG.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert
import java.io.File

private const val PATH_TO_MODEL = "savedmodels/vgg11"
private const val EPOCHS = 3
private const val EPOCHS = 1 // 10, at least, is recommended
private const val TRAINING_BATCH_SIZE = 128
private const val TEST_BATCH_SIZE = 1000
private const val NUM_LABELS = 10
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/kotlin/examples/cnn/fashionmnist/VGG.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.dataset.fashionMnist

private const val EPOCHS = 2
private const val EPOCHS = 1
private const val TRAINING_BATCH_SIZE = 100
private const val TEST_BATCH_SIZE = 1000
private const val NUM_LABELS = 10
Expand Down
5 changes: 3 additions & 2 deletions examples/src/main/kotlin/examples/cnn/fsdd/SoundNet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import org.jetbrains.kotlinx.dl.dataset.FSDD_SOUND_DATA_SIZE
import org.jetbrains.kotlinx.dl.dataset.freeSpokenDigits
import org.jetbrains.kotlinx.dl.dataset.handler.NUMBER_OF_CLASSES

private const val EPOCHS = 10
private const val TRAINING_BATCH_SIZE = 500
private const val EPOCHS = 20 // 20, at least, is recommended
private const val TRAINING_BATCH_SIZE = 64
private const val TEST_BATCH_SIZE = 500
private const val NUM_CHANNELS = 1L
private const val SEED = 12L
Expand Down Expand Up @@ -122,6 +122,7 @@ private val soundNet = Sequential.of(
*/
fun soundNet() {
val (train, test) = freeSpokenDigits()
train.shuffle()

soundNet.use {
it.compile(
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/kotlin/examples/cnn/mnist/VGG.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.core.summary.logSummary
import org.jetbrains.kotlinx.dl.dataset.mnist

private const val EPOCHS = 5
private const val EPOCHS = 1
private const val TRAINING_BATCH_SIZE = 200
private const val TEST_BATCH_SIZE = 1000
private const val NUM_LABELS = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ fun lenetOnMnistExportImportToJson() {
it.fit(
dataset = train,
validationRate = 0.1,
epochs = 5,
epochs = 2,
trainBatchSize = 1000,
validationBatchSize = 100
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv.efficicentnet

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.core.util.loadImageNetClassLabels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/
package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv.efficicentnet.lightAPI

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/
package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv.efficicentnet.lightAPI

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
Expand All @@ -15,7 +15,7 @@ import java.io.File
* - Model is obtained from [ONNXModelHub].
* - Model predicts on a few images located in resources.
*/
fun efficientNetB7EasyPrediction() {
fun efficientNetB7LightAPIPrediction() {
val modelHub =
ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

Expand All @@ -35,4 +35,4 @@ fun efficientNetB7EasyPrediction() {
}

/** */
fun main(): Unit = efficientNetB7EasyPrediction()
fun main(): Unit = efficientNetB7LightAPIPrediction()
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv.efficicentnet.notop

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.mnist
package examples.onnx.cv.lenet

import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom.efficientnet
package examples.onnx.cv

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.core.util.loadImageNetClassLabels
Expand All @@ -16,8 +16,6 @@ import org.jetbrains.kotlinx.dl.dataset.preprocessor.*
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize
import java.io.File
import java.net.URISyntaxException
import java.net.URL

fun runONNXImageRecognitionPrediction(
modelType: ONNXModels.CV<out OnnxInferenceModel>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package examples.onnx.cv.resnet

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import java.io.File
Expand All @@ -16,7 +15,7 @@ import java.io.File
* - Model is obtained from [ONNXModelHub].
* - Model predicts on a few images located in resources.
*/
fun resnet18easyPrediction() {
fun resnet18LightAPIPrediction() {
val modelHub =
ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

Expand All @@ -36,4 +35,4 @@ fun resnet18easyPrediction() {
}

/** */
fun main(): Unit = resnet18easyPrediction()
fun main(): Unit = resnet18LightAPIPrediction()
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom
package examples.onnx.cv.resnet.notop

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.core.util.loadImageNetClassLabels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.cv.custom
package examples.onnx.cv.resnet.notop

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
Expand Down
4 changes: 2 additions & 2 deletions examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fun main() {
val inputData = modelType.preprocessInput(preprocessing)

val yhat = it.predictRaw(inputData)
println(yhat.toTypedArray().contentDeepToString())
println(yhat.values.toTypedArray().contentDeepToString())

visualiseLandMarks(imageFile, yhat)
}
Expand All @@ -58,7 +58,7 @@ fun main() {

fun visualiseLandMarks(
imageFile: File,
landmarks: List<Array<*>>
landmarks: Map<String, Any>
) {
val preprocessing: Preprocessing = preprocess {
load {
Expand Down
Loading