Skip to content

Commit

Permalink
upgrade djl version to 0.28.0 (opensearch-project#2578) (opensearch-p…
Browse files Browse the repository at this point in the history
…roject#2580)

* upgrade djl version to latest 0.28.0

Signed-off-by: Bhavana Ramaram <[email protected]>

* force onnxruntime_gpu to 1.16.3

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Bhavana Ramaram <[email protected]>
Signed-off-by: Yaliang Wu <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
(cherry picked from commit 01c85cb)

Co-authored-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and rbhavna authored Jun 20, 2024
1 parent 0e27181 commit 63aeaab
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ dependencies {
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
implementation group: 'com.google.guava', name: 'guava', version: '32.1.2-jre'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation platform("ai.djl:bom:0.21.0")
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo', version: '0.21.0'
implementation platform("ai.djl:bom:0.28.0")
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
implementation group: 'ai.djl', name: 'api'
implementation group: 'ai.djl.huggingface', name: 'tokenizers'
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.21.0") {
implementation("ai.djl.onnxruntime:onnxruntime-engine") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
def os = DefaultNativePlatform.currentOperatingSystem
//arm/macos doesn't support GPU
if (os.macOsX || System.getProperty("os.arch") == "aarch64") {
dependencies {
implementation "com.microsoft.onnxruntime:onnxruntime:1.14.0"
implementation "com.microsoft.onnxruntime:onnxruntime:1.16.3!!"
}
} else {
dependencies {
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0"
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.16.3!!"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ protected void loadModel(
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
try {
System.setProperty("PYTORCH_PRECXX11", "true");
System.setProperty("PYTORCH_VERSION", "1.13.1");
System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString());
// DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw
// access denied exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ private void loadModel(File modelZipFile, String modelId, String modelName, Stri
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
try {
System.setProperty("PYTORCH_PRECXX11", "true");
System.setProperty("PYTORCH_VERSION", "1.13.1");
System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString());
// DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw
// access denied exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ public void initModel_predict_ONNX_QuestionAnswering() throws URISyntaxException
.modelFormat(MLModelFormat.ONNX)
.name("test_model_name")
.modelId("test_model_id")
.algorithm(FunctionName.TEXT_SIMILARITY)
.algorithm(FunctionName.QUESTION_ANSWERING)
.version("1.0.0")
.modelState(MLModelState.TRAINED)
.build();
modelZipFile = new File(getClass().getResource("question_answering_onnx.zip").toURI());
params.put(MODEL_ZIP_FILE, modelZipFile);

questionAnsweringModel.initModel(model, params, encryptor);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build();
ModelTensorOutput output = (ModelTensorOutput) questionAnsweringModel.predict(mlInput);
List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
assertEquals(1, mlModelOutputs.size());
Expand Down

0 comments on commit 63aeaab

Please sign in to comment.