diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java index c00e8196057f..559950135b5a 100644 --- a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java +++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java @@ -17,6 +17,7 @@ package org.apache.mxnet.jna; +import com.sun.jna.Library; import com.sun.jna.Native; import java.io.File; import java.io.IOException; @@ -29,6 +30,8 @@ import java.util.Collections; import java.util.Enumeration; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.zip.GZIPInputStream; @@ -69,13 +72,13 @@ public static MxnetLibrary loadLibrary() { String libName = getLibName(); logger.debug("Loading mxnet library from: {}", libName); - // TODO: consider Linux platform - // if (System.getProperty("os.name").startsWith("Linux")) { - // Map options = new ConcurrentHashMap<>(); - // int rtld = 1; // Linux RTLD lazy + local - // options.put(Library.OPTION_OPEN_FLAGS, rtld); - // return Native.load(libName, MxnetLibrary.class, options); - // } + if (System.getProperty("os.name").startsWith("Linux")) { + logger.info("Loading on Linux platform"); + Map options = new ConcurrentHashMap<>(); + int rtld = 1; // Linux RTLD lazy + local + options.put(Library.OPTION_OPEN_FLAGS, rtld); + return Native.load(libName, MxnetLibrary.class, options); + } return Native.load(libName, MxnetLibrary.class); }