diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index e8c814e432..229bba5771 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -50,7 +50,8 @@ public class MLTask implements ToXContentObject, Writeable { @Setter private String modelId; private final MLTaskType taskType; - private final FunctionName functionName; + @Setter + private FunctionName functionName; @Setter private MLTaskState state; private final MLInputDataType inputType; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 75012b56ea..77001b92e7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -94,7 +94,9 @@ public void downloadPrebuiltModelConfig( .url(modelZipFileUrl) .deployModel(deployModel) .modelNodeIds(modelNodeIds) - .modelGroupId(modelGroupId); + .modelGroupId(modelGroupId) + .functionName(FunctionName.from((String) config.get("model_task_type"))); + config.entrySet().forEach(entry -> { switch (entry.getKey().toString()) { case MLRegisterModelInput.MODEL_FORMAT_FIELD: diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 759b0cec9f..1f766ae13a 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -15,6 +15,7 @@ import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLTask.ERROR_FIELD; +import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.COMPLETED; @@ -756,6 +757,14 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa throw new IllegalArgumentException("This model is not in the pre-trained model list, please check your parameters."); } modelHelper.downloadPrebuiltModelConfig(taskId, registerModelInput, ActionListener.wrap(mlRegisterModelInput -> { + mlTask.setFunctionName(mlRegisterModelInput.getFunctionName()); + mlTaskManager + .updateMLTask( + taskId, + ImmutableMap.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()), + TIMEOUT_IN_MILLIS, + false + ); registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion); }, e -> { log.error("Failed to register prebuilt model", e); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 525bc95eda..14034aa93e 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -21,9 +21,11 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH; import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; +import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; @@ -49,8 +51,10 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.security.PrivilegedActionException; import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -168,6 +172,9 @@ public class MLModelManagerTests extends OpenSearchTestCase { @Mock private ScriptService scriptService; + @Mock + private MLTask pretrainedMLTask; + @Before public void setup() throws URISyntaxException { String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="; @@ -373,6 +380,35 @@ public void testRegisterMLModel_DownloadModelFileFailure() { verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any()); } + public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException { + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo")); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); + MLRegisterModelInput pretrainedInput = mockPretrainedInput(); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(pretrainedInput); + return null; + }).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any()); + MLTask pretrainedTask = MLTask + .builder() + .taskId("pretrained") + .modelId("pretrained") + .functionName(FunctionName.TEXT_EMBEDDING) + .build(); + modelManager.registerMLModel(pretrainedInput, pretrainedTask); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING); + verify(mlTaskManager) + .updateMLTask( + eq("pretrained"), + eq(ImmutableMap.of(FUNCTION_NAME_FIELD, FunctionName.SPARSE_ENCODING)), + eq((long) TIMEOUT_IN_MILLIS), + eq(false) + ); + } + @Ignore public void testRegisterMLModel_DownloadModelFile() throws IOException { doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); @@ -916,4 +952,15 @@ private MLRegisterModelMetaInput prepareRequest() { .build(); return input; } + + private MLRegisterModelInput mockPretrainedInput() { + return MLRegisterModelInput + .builder() + .modelName(modelName) + .version(version) + .modelGroupId("modelGroupId") + .modelFormat(modelFormat) + .functionName(FunctionName.SPARSE_ENCODING) + .build(); + } }