From 1e1829fe98b8ee67597155bdd27a0514d85fdc6f Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 2 Jan 2024 11:15:43 -0800 Subject: [PATCH] [onnx] Adds yolov8n to model zoo (#2909) --- .../ai/djl/onnxruntime/zoo/OrtModelZoo.java | 1 + .../ai/djl/onnxruntime/yolov8n/metadata.json | 40 +++++++++++++++++++ .../examples/inference/Yolov8Detection.java | 4 +- 3 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 9d8037cfa8b..d61cb81f1ee 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -31,6 +31,7 @@ public class OrtModelZoo extends ModelZoo { OrtModelZoo() { addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1")); } diff --git a/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json new file mode 100644 index 00000000000..1e0169a2561 --- /dev/null +++ b/engines/onnxruntime/onnxruntime-engine/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.onnxruntime", + "artifactId": "yolov8n", + "name": "yolov8n", + "description": "YoloV8 Model", + "website": "http://www.djl.ai/engines/onnxruntime/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n", + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "rescale": true, + "optApplyRatio": true, + "threshold": 0.6, + "translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n.zip", + "name": "", + "sha1Hash": "9fbad7f706713843cbb8c8d6a56c81a640ec6fa2", + "size": 11053839 + } + } + } + ] +} diff --git a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java index 27ee1211d05..3d2cfb26409 100644 --- a/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java +++ b/examples/src/main/java/ai/djl/examples/inference/Yolov8Detection.java @@ -45,14 +45,13 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static DetectedObjects predict() throws IOException, ModelException, TranslateException { - Path modelPath = Paths.get("src/test/resources/yolov8n.onnx"); Path imgPath = Paths.get("src/test/resources/yolov8_test.jpg"); Image img = ImageFactory.getInstance().fromFile(imgPath); Criteria criteria = Criteria.builder() .setTypes(Image.class, DetectedObjects.class) - .optModelPath(modelPath) + .optModelUrls("djl://ai.djl.onnxruntime/yolov8n") .optEngine("OnnxRuntime") .optArgument("width", 640) .optArgument("height", 640) @@ -63,7 +62,6 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran // for performance optimization maxBox parameter can reduce number of // considered boxes from 8400 .optArgument("maxBox", 1000) - .optArgument("synsetFileName", "yolov8_synset.txt") .optTranslatorFactory(new YoloV8TranslatorFactory()) .optProgress(new ProgressBar()) .build();