From 1948ef9a32a7dd85921183481de31012250327a4 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Fri, 5 Mar 2021 10:31:08 -0800 Subject: [PATCH] fix Change-Id: I7f44c815d0d9293c5493057c5255d247fbb98e18 --- .../src/main/java/ai/djl/dlr/engine/DlrEngine.java | 7 ++++++- .../src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java | 4 ++++ .../main/java/ai/djl/onnxruntime/engine/OrtNDManager.java | 2 +- .../src/test/java/ai/djl/onnxruntime/engine/OrtTest.java | 4 +++- .../src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java | 4 ++++ .../src/main/java/ai/djl/tflite/engine/TfLiteEngine.java | 4 ++++ 6 files changed, 22 insertions(+), 3 deletions(-) diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index e28426e048fb..90c401794dad 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -34,8 +34,12 @@ public final class DlrEngine extends Engine { public static final String ENGINE_NAME = "DLR"; private Engine alternativeEngine; + private boolean disableAlternative; - private DlrEngine() {} + private DlrEngine() { + disableAlternative = + Boolean.parseBoolean(System.getProperty("djl_dlr_disable_alternative", "false")); + } static Engine newInstance() { try { @@ -47,6 +51,7 @@ static Engine newInstance() { } private Engine getAlternativeEngine() { + if (disableAlternative) return null; if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 5396d0e4e2db..70e21a8fb4c0 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -34,10 +34,13 @@ public final class OrtEngine extends Engine { private Engine alternativeEngine; private OrtEnvironment env; + private boolean disableAlternative; private OrtEngine() { // init OrtRuntime this.env = OrtEnvironment.getEnvironment(); + disableAlternative = + Boolean.parseBoolean(System.getProperty("djl_onnx_disable_alternative", "false")); } static Engine newInstance() { @@ -57,6 +60,7 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (disableAlternative) return null; if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index 8d4ef42da98b..b0f4dd19e027 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -39,7 +39,7 @@ private OrtNDManager(NDManager parent, Device device, OrtEnvironment env) { this.env = env; } - public static OrtNDManager getSystemManager() { + static OrtNDManager getSystemManager() { return SYSTEM_MANAGER; } diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java index dbcbe7013cb4..bd9c48f6c3cf 100644 --- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java +++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException { public void testStringTensor() throws MalformedModelException, ModelNotFoundException, IOException, TranslateException { + System.setProperty("djl_onnx_disable_alternative", "true"); Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) @@ -82,12 +83,13 @@ public void testStringTensor() .build(); try (ZooModel model = ModelZoo.loadModel(criteria); Predictor predictor = model.newPredictor()) { - OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager(); + OrtNDManager manager = (OrtNDManager) model.getNDManager(); NDArray stringNd = manager.create( new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"}, new Shape(1, 2)); predictor.predict(new NDList(stringNd)); } + System.clearProperty("djl_onnx_disable_alternative"); } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index 20468243717d..c379aceec966 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -34,9 +34,12 @@ public final class PpEngine extends Engine { private Engine alternativeEngine; private String version; + private boolean disableAlternative; private PpEngine() { version = JniUtils.getVersion(); + disableAlternative = + Boolean.parseBoolean(System.getProperty("djl_paddle_disable_alternative", "false")); } static Engine newInstance() { @@ -57,6 +60,7 @@ public int getRank() { } Engine getAlternativeEngine() { + if (disableAlternative) return null; if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) { diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index a28ec7c9fe13..d935452d985a 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -32,9 +32,12 @@ public final class TfLiteEngine extends Engine { public static final String ENGINE_NAME = "TFLite"; private Engine alternativeEngine; + private boolean disableAlternative; private TfLiteEngine() { LibUtils.loadLibrary(); + disableAlternative = + Boolean.parseBoolean(System.getProperty("djl_tflite_disable_alternative", "false")); } static Engine newInstance() { @@ -54,6 +57,7 @@ public int getRank() { } private Engine getAlternativeEngine() { + if (disableAlternative) return null; if (alternativeEngine == null) { Engine engine = Engine.getInstance(); if (engine.getRank() < getRank()) {