Skip to content

Commit

Permalink
Refactoring of extension functions with ImageProxy input (#454)
Browse files Browse the repository at this point in the history
* Introduce doWithRotation extension function to reduce code duplication

* Move extension functions to the base classes (e.g. SinglePoseDetectionModelBase instead of SinglePoseDetectionModel)

Co-authored-by: Julia Beliaeva <[email protected]>
  • Loading branch information
ermolenkodev and juliabeliaeva committed Sep 29, 2022
1 parent 6bac806 commit 10be87e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,16 @@ public interface CameraXCompatibleModel {
*/
public var targetRotation: Int
}

/**
 * Convenience function to execute arbitrary code with the target rotation temporarily set to [rotation].
 *
 * The previous target rotation is restored once [function] finishes, whether it returns normally
 * or throws, so a failed call does not leave the model with the temporary rotation.
 *
 * @param rotation target rotation to be set for the duration of the code execution
 * @param function arbitrary code to be executed
 * @return the value produced by [function]
 */
public fun <R> CameraXCompatibleModel.doWithRotation(rotation: Int, function: () -> R): R {
    val previousRotation = targetRotation
    targetRotation = rotation
    return try {
        function()
    } finally {
        // Runs on both normal return and exception, unlike a scope function chained on the result.
        targetRotation = previousRotation
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionM
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.detectPose
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap
Expand Down Expand Up @@ -54,11 +56,13 @@ public open class ImageRecognitionModel(
*
* @return The label of the recognized object with the highest probability.
*/
public fun ImageRecognitionModel.predictObject(imageProxy: ImageProxy): String {
    // Remember the caller's rotation so it can be put back after this single prediction.
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    return try {
        predictObject(imageProxy.toBitmap())
    } finally {
        // Restore even if bitmap conversion or inference throws; otherwise the model
        // would keep the rotation of this frame for all later calls.
        targetRotation = previousRotation
    }
}
public fun ImageRecognitionModelBase<Bitmap>.predictObject(imageProxy: ImageProxy): String {
    // Models that do not expose a target rotation: ask toBitmap to apply the rotation instead.
    if (this !is CameraXCompatibleModel) {
        return predictObject(imageProxy.toBitmap(applyRotation = true))
    }
    // CameraX-compatible models: set targetRotation to the frame's rotation for the duration of the call.
    val frameRotation = imageProxy.imageInfo.rotationDegrees
    return doWithRotation(frameRotation) { predictObject(imageProxy.toBitmap()) }
}

/**
* Predicts [topK] objects for the given [imageProxy].
Expand All @@ -70,11 +74,13 @@ public fun ImageRecognitionModel.predictObject(imageProxy: ImageProxy): String {
*
* @return The list of pairs <label, probability> sorted from the most probable to the least probable.
*/
public fun ImageRecognitionModel.predictTopKObjects(
// NOTE(review): two signatures for predictTopKObjects follow each other here (receivers
// ImageRecognitionModel and ImageRecognitionModelBase<Bitmap>), and two bodies follow the shared
// parameter list. This looks like two versions of the same function interleaved — confirm which one
// is meant to remain before relying on this span.
public fun ImageRecognitionModelBase<Bitmap>.predictTopKObjects(
imageProxy: ImageProxy,
topK: Int = 5
): List<Pair<String, Float>> {
// Block-body variant: saves targetRotation, sets it to the frame's rotationDegrees, predicts on the
// converted bitmap, then writes the saved value back via `also` (only reached when no exception is thrown).
val currentRotation = targetRotation
targetRotation = imageProxy.imageInfo.rotationDegrees
return predictTopKObjects(imageProxy.toBitmap(), topK).also { targetRotation = currentRotation }
}
): List<Pair<String, Float>> =
// Expression-body variant: dispatches on whether the receiver is a CameraXCompatibleModel.
when (this) {
is CameraXCompatibleModel -> {
// Runs the prediction inside doWithRotation with the frame's rotationDegrees.
doWithRotation(imageProxy.imageInfo.rotationDegrees) { predictTopKObjects(imageProxy.toBitmap(), topK) }
}
// Other receivers: toBitmap is called with applyRotation = true instead.
else -> predictTopKObjects(imageProxy.toBitmap(applyRotation = true), topK)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.detectPose
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap
Expand Down Expand Up @@ -52,8 +54,10 @@ public class SSDLikeModel(override val internalModel: OnnxInferenceModel, metada
* @param [topK] The number of the detected objects with the highest score to be returned.
* @return List of [DetectedObject] sorted by score.
*/
public fun SSDLikeModel.detectObjects(imageProxy: ImageProxy, topK: Int = 3): List<DetectedObject> {
    // Remember the caller's rotation so it can be put back after this single detection.
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    return try {
        detectObjects(imageProxy.toBitmap(), topK)
    } finally {
        // Restore even if bitmap conversion or inference throws; otherwise the model
        // would keep the rotation of this frame for all later calls.
        targetRotation = previousRotation
    }
}
public fun ObjectDetectionModelBase<Bitmap>.detectObjects(imageProxy: ImageProxy, topK: Int = 3): List<DetectedObject> =
    if (this is CameraXCompatibleModel) {
        // CameraX-compatible models: set targetRotation to the frame's rotation for the duration of the call.
        val frameRotation = imageProxy.imageInfo.rotationDegrees
        doWithRotation(frameRotation) { detectObjects(imageProxy.toBitmap(), topK) }
    } else {
        // Models that do not expose a target rotation: ask toBitmap to apply the rotation instead.
        detectObjects(imageProxy.toBitmap(applyRotation = true), topK)
    }
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionP
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap

Expand Down Expand Up @@ -59,8 +60,10 @@ public class SinglePoseDetectionModel(override val internalModel: OnnxInferenceM
*
* @param [imageProxy] input image.
*/
public fun SinglePoseDetectionModel.detectPose(imageProxy: ImageProxy): DetectedPose {
    // Remember the caller's rotation so it can be put back after this single detection.
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    return try {
        detectPose(imageProxy.toBitmap())
    } finally {
        // Restore even if bitmap conversion or inference throws; otherwise the model
        // would keep the rotation of this frame for all later calls.
        targetRotation = previousRotation
    }
}
public fun SinglePoseDetectionModelBase<Bitmap>.detectPose(imageProxy: ImageProxy): DetectedPose {
    // Models that do not expose a target rotation: ask toBitmap to apply the rotation instead.
    if (this !is CameraXCompatibleModel) {
        return detectPose(imageProxy.toBitmap(applyRotation = true))
    }
    // CameraX-compatible models: set targetRotation to the frame's rotation for the duration of the call.
    val frameRotation = imageProxy.imageInfo.rotationDegrees
    return doWithRotation(frameRotation) { detectPose(imageProxy.toBitmap()) }
}

0 comments on commit 10be87e

Please sign in to comment.