diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index 4d2e8505f02f..e4e903c8672b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -231,10 +231,21 @@ private static Path findJniLibrary(LibTorch libTorch) { String djlVersion = libTorch.apiVersion; String flavor = libTorch.flavor; + // Looking for JNI in libTorch.dir first + Path libDir = libTorch.dir.toAbsolutePath(); + Path path = libDir.resolve(djlVersion + '-' + JNI_LIB_NAME); + if (Files.exists(path)) { + return path; + } + Path path = libDir.resolve(JNI_LIB_NAME); + if (Files.exists(path)) { + return path; + } + // always use cache dir, cache dir might be different from libTorch.dir Path cacheDir = Utils.getEngineCacheDir("pytorch"); Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier); - Path path = dir.resolve(djlVersion + '-' + JNI_LIB_NAME); + path = dir.resolve(djlVersion + '-' + JNI_LIB_NAME); if (Files.exists(path)) { return path; } @@ -554,8 +565,10 @@ private static final class LibTorch { if (flavor == null || flavor.isEmpty()) { if (CudaUtils.getGpuCount() > 0) { flavor = "cu" + CudaUtils.getCudaVersionString() + "-precxx11"; - } else { + } else if ("linux".equals(platform.getOsPrefix())) { flavor = "cpu-precxx11"; + } else { + flavor = "cpu"; } } }