From 19e4302241f57531621325aba8f96be24c5d6d92 Mon Sep 17 00:00:00 2001 From: onaple Date: Tue, 27 Feb 2024 05:55:09 +0800 Subject: [PATCH] Fixes cases where the getEngine method in the EngineProvider class returns null when called concurrently. (#3005) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixes cases where the getEngine method in the EngineProvider class returns null when called concurrently. * Revert "Creates DJL manual engine initialization (#2885)" This reverts commit 6141c480e7047b17b21e59eccfa7c58b0b916646. --------- Co-authored-by: 王旭 Co-authored-by: Frank Liu --- api/src/main/java/ai/djl/engine/Engine.java | 57 +++++-------------- docs/development/troubleshooting.md | 5 -- .../djl/ml/lightgbm/LgbmEngineProvider.java | 17 ++---- .../ai/djl/ml/xgboost/XgbEngineProvider.java | 17 ++---- .../ai/djl/mxnet/engine/MxEngineProvider.java | 17 ++---- .../onnxruntime/engine/OrtEngineProvider.java | 17 ++---- .../paddlepaddle/engine/PpEngineProvider.java | 17 ++---- .../djl/pytorch/engine/PtEngineProvider.java | 8 +-- .../tensorflow/engine/TfEngineProvider.java | 8 +-- .../tensorrt/engine/TrtEngineProvider.java | 17 ++---- .../ai/djl/tensorrt/engine/TrtEngineTest.java | 2 +- .../djl/tensorrt/engine/TrtNDManagerTest.java | 2 +- .../ai/djl/tensorrt/integration/TrtTest.java | 6 +- .../tflite/engine/TfLiteEngineProvider.java | 17 ++---- 14 files changed, 59 insertions(+), 148 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index a799c70f600..8a1fc8871ac 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -59,7 +59,7 @@ public abstract class Engine { private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); - private static String defaultEngine = initEngine(); + private static final String DEFAULT_ENGINE = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); @@ -69,10 +69,6 @@ public abstract class Engine { private Integer seed; private static synchronized String initEngine() { - if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) { - return null; - } - ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); @@ -84,21 +80,21 @@ private static synchronized String initEngine() { } String def = System.getProperty("ai.djl.default_engine"); - String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); - if (newDefaultEngine == null || newDefaultEngine.isEmpty()) { + String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); + if (defaultEngine == null || defaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { - newDefaultEngine = provider.getEngineName(); + defaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } - } else if (!ALL_ENGINES.containsKey(newDefaultEngine)) { - throw new EngineException("Unknown default engine: " + newDefaultEngine); + } else if (!ALL_ENGINES.containsKey(defaultEngine)) { + throw new EngineException("Unknown default engine: " + defaultEngine); } - logger.debug("Found default engine: {}", newDefaultEngine); - Ec2Utils.callHome(newDefaultEngine); - return newDefaultEngine; + logger.debug("Found default engine: {}", defaultEngine); + Ec2Utils.callHome(defaultEngine); + return defaultEngine; } /** @@ -128,7 +124,7 @@ private static synchronized String initEngine() { * @return the default Engine name */ public static String getDefaultEngineName() { - return System.getProperty("ai.djl.default_engine", defaultEngine); + return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); } /** @@ -138,7 +134,7 @@ public static String getDefaultEngineName() { * @see EngineProvider */ public static Engine getInstance() { - if (defaultEngine == null) { + if (DEFAULT_ENGINE == null) { throw new EngineException( "No deep learning engine found." + System.lineSeparator() @@ -167,29 +163,7 @@ public static boolean hasEngine(String engineName) { */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); - ALL_ENGINES.put(provider.getEngineName(), provider); - } - - /** - * Returns the default engine. - * - * @return the default engine - */ - public static String getDefaultEngine() { - return defaultEngine; - } - - /** - * Sets the default engine returned by {@link #getInstance()}. - * - * @param engineName the new default engine's name - */ - public static void setDefaultEngine(String engineName) { - // Requires an engine to be loaded (without exception) before being the default - getEngine(engineName); - - logger.debug("Setting new default engine: {}", engineName); - defaultEngine = engineName; + ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); } /** @@ -213,12 +187,7 @@ public static Engine getEngine(String engineName) { if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } - Engine engine = provider.getEngine(); - if (engine == null) { - throw new IllegalStateException( - "The engine " + engineName + " was not able to initialize"); - } - return engine; + return provider.getEngine(); } /** diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index 1a04592dc12..ff03d32648e 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -105,11 +105,6 @@ For more information, please refer to [DJL Cache Management](cache_management.md It happened when you had a wrong version with DJL and Deep Engines. You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue. -### 1.6 Manual initialization - -If you are using manual engine initialization, you must both register an engine and set it as the default. -This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`. - ## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception. The following exception may appear after running the `./gradlew clean` command: diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index 583cd8132b2..f8c84c753ef 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,9 +18,6 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (LgbmEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = LgbmEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = LgbmEngine.newInstance(); } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 8b534d5196c..5859f3f344d 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,9 +18,6 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (XgbEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = XgbEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = XgbEngine.newInstance(); } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 2a5ab970560..5f45116f615 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,9 +18,6 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (MxEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = MxEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = MxEngine.newInstance(); } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 5616eb80edb..005c0fa25f1 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,9 +18,6 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (OrtEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = OrtEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = OrtEngine.newInstance(); } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index e2fb86974f5..59e5cd90724 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,9 +18,6 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (PpEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = PpEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = PpEngine.newInstance(); } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 24be3e91d7a..42ca3c5b8a5 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,8 +18,7 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD + private static volatile Engine engine; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (PtEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = PtEngine.newInstance(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index fa7813a49fb..ad440a47951 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,8 +18,7 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD + private static volatile Engine engine; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (TfEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = TfEngine.newInstance(); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index 8c90859c6c6..d92ed9e449d 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,9 +18,6 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (TrtEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = TrtEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = TrtEngine.newInstance(); } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index 96066b380e1..efd9d89e509 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,7 +26,7 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Assert.assertEquals(version, "8.4.1"); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 24d734af54c..09001f0e2da 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 105e057ba0a..99cbc6f763e 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Exception ignore) { + } catch (Throwable ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index b46cad53b99..fb61551a3bf 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,9 +18,6 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { - private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -36,14 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { - synchronized (TfLiteEngineProvider.class) { - if (!initialized) { - initialized = true; - engine = TfLiteEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = TfLiteEngine.newInstance(); } }