From e0b2487333bf3e26fd64d10836dc1e6b9bd70e7e Mon Sep 17 00:00:00 2001 From: xinyual Date: Wed, 18 Oct 2023 12:27:56 +0800 Subject: [PATCH] read function Name from pretrained model Signed-off-by: xinyual --- .../src/main/java/org/opensearch/ml/engine/ModelHelper.java | 3 ++- .../src/main/java/org/opensearch/ml/model/MLModelManager.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 1af6f90990..2c249c4088 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -87,7 +87,8 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi .url(modelZipFileUrl) .deployModel(deployModel) .modelNodeIds(modelNodeIds) - .modelGroupId(modelGroupId); + .modelGroupId(modelGroupId) + .functionName(FunctionName.from((String) config.get("model_task_type")));; config.entrySet().forEach(entry -> { switch (entry.getKey().toString()) { case MLRegisterModelInput.MODEL_FORMAT_FIELD: diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 759b0cec9f..dac1a50029 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -609,7 +609,7 @@ private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask, private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { String taskId = mlTask.getTaskId(); - FunctionName functionName = mlTask.getFunctionName(); + FunctionName functionName = registerModelInput.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { String modelName = registerModelInput.getModelName(); String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;