diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 251ea79a1e..6e44a8e8a9 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -42,22 +42,22 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0' implementation group: 'com.google.guava', name: 'guava', version: '32.1.2-jre' implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' - implementation platform("ai.djl:bom:0.21.0") - implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo', version: '0.21.0' + implementation platform("ai.djl:bom:0.28.0") + implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo' implementation group: 'ai.djl', name: 'api' implementation group: 'ai.djl.huggingface', name: 'tokenizers' - implementation("ai.djl.onnxruntime:onnxruntime-engine:0.21.0") { + implementation("ai.djl.onnxruntime:onnxruntime-engine") { exclude group: "com.microsoft.onnxruntime", module: "onnxruntime" } def os = DefaultNativePlatform.currentOperatingSystem //arm/macos doesn't support GPU if (os.macOsX || System.getProperty("os.arch") == "aarch64") { dependencies { - implementation "com.microsoft.onnxruntime:onnxruntime:1.14.0" + implementation "com.microsoft.onnxruntime:onnxruntime:1.16.3!!" } } else { dependencies { - implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0" + implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.16.3!!" } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index 546a01b386..073ae9c87c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -253,6 +253,7 @@ protected void loadModel( ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); try { System.setProperty("PYTORCH_PRECXX11", "true"); + System.setProperty("PYTORCH_VERSION", "1.13.1"); System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString()); // DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw // access denied exception diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index d2d0824a4d..68ec1ca39c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -131,6 +131,7 @@ private void loadModel(File modelZipFile, String modelId, String modelName, Stri ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); try { System.setProperty("PYTORCH_PRECXX11", "true"); + System.setProperty("PYTORCH_VERSION", "1.13.1"); System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString()); // DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw // access denied exception diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java index 0999ef9a9e..9869b42def 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java @@ -163,7 +163,7 @@ public void initModel_predict_ONNX_QuestionAnswering() throws URISyntaxException .modelFormat(MLModelFormat.ONNX) .name("test_model_name") .modelId("test_model_id") - .algorithm(FunctionName.TEXT_SIMILARITY) + .algorithm(FunctionName.QUESTION_ANSWERING) .version("1.0.0") .modelState(MLModelState.TRAINED) .build(); @@ -171,7 +171,7 @@ public void initModel_predict_ONNX_QuestionAnswering() throws URISyntaxException params.put(MODEL_ZIP_FILE, modelZipFile); questionAnsweringModel.initModel(model, params, encryptor); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build(); ModelTensorOutput output = (ModelTensorOutput) questionAnsweringModel.predict(mlInput); List mlModelOutputs = output.getMlModelOutputs(); assertEquals(1, mlModelOutputs.size());