Skip to content

Commit

Permalink
Read function Name from pretrained model (#1529)
Browse files Browse the repository at this point in the history
* read Function name from pretrained config

Signed-off-by: xinyual <[email protected]>

* rewrite mltask

Signed-off-by: xinyual <[email protected]>

* optimize import

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* add test for function name

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* maintain single import

Signed-off-by: xinyual <[email protected]>

* add more test

Signed-off-by: xinyual <[email protected]>

* apply spot less

Signed-off-by: xinyual <[email protected]>

* apply spot less

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Nov 15, 2023
1 parent 9342781 commit 4d53db5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
3 changes: 2 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public class MLTask implements ToXContentObject, Writeable {
@Setter
private String modelId;
private final MLTaskType taskType;
private final FunctionName functionName;
@Setter
private FunctionName functionName;
@Setter
private MLTaskState state;
private final MLInputDataType inputType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ public void downloadPrebuiltModelConfig(
.url(modelZipFileUrl)
.deployModel(deployModel)
.modelNodeIds(modelNodeIds)
.modelGroupId(modelGroupId);
.modelGroupId(modelGroupId)
.functionName(FunctionName.from((String) config.get("model_task_type")));

config.entrySet().forEach(entry -> {
switch (entry.getKey().toString()) {
case MLRegisterModelInput.MODEL_FORMAT_FIELD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.ml.common.CommonValue.UNDEPLOYED;
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
import static org.opensearch.ml.common.MLTask.ERROR_FIELD;
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
Expand Down Expand Up @@ -756,6 +757,14 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa
throw new IllegalArgumentException("This model is not in the pre-trained model list, please check your parameters.");
}
modelHelper.downloadPrebuiltModelConfig(taskId, registerModelInput, ActionListener.wrap(mlRegisterModelInput -> {
mlTask.setFunctionName(mlRegisterModelInput.getFunctionName());
mlTaskManager
.updateMLTask(
taskId,
ImmutableMap.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()),
TIMEOUT_IN_MILLIS,
false
);
registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion);
}, e -> {
log.error("Failed to register prebuilt model", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD;
import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES;
import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH;
import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES;
import static org.opensearch.ml.model.MLModelManager.TIMEOUT_IN_MILLIS;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
Expand All @@ -49,8 +51,10 @@
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.PrivilegedActionException;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -168,6 +172,9 @@ public class MLModelManagerTests extends OpenSearchTestCase {
@Mock
private ScriptService scriptService;

@Mock
private MLTask pretrainedMLTask;

@Before
public void setup() throws URISyntaxException {
String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
Expand Down Expand Up @@ -373,6 +380,35 @@ public void testRegisterMLModel_DownloadModelFileFailure() {
verify(modelHelper).downloadAndSplit(eq(modelFormat), eq(modelId), eq(modelName), eq(version), eq(url), any(), any(), any());
}

public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionException {
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
MLRegisterModelInput pretrainedInput = mockPretrainedInput();
doAnswer(invocation -> {
ActionListener<MLRegisterModelInput> listener = (ActionListener<MLRegisterModelInput>) invocation.getArguments()[2];
listener.onResponse(pretrainedInput);
return null;
}).when(modelHelper).downloadPrebuiltModelConfig(any(), any(), any());
MLTask pretrainedTask = MLTask
.builder()
.taskId("pretrained")
.modelId("pretrained")
.functionName(FunctionName.TEXT_EMBEDDING)
.build();
modelManager.registerMLModel(pretrainedInput, pretrainedTask);
assertEquals(pretrainedTask.getFunctionName(), FunctionName.SPARSE_ENCODING);
verify(mlTaskManager)
.updateMLTask(
eq("pretrained"),
eq(ImmutableMap.of(FUNCTION_NAME_FIELD, FunctionName.SPARSE_ENCODING)),
eq((long) TIMEOUT_IN_MILLIS),
eq(false)
);
}

@Ignore
public void testRegisterMLModel_DownloadModelFile() throws IOException {
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
Expand Down Expand Up @@ -916,4 +952,15 @@ private MLRegisterModelMetaInput prepareRequest() {
.build();
return input;
}

private MLRegisterModelInput mockPretrainedInput() {
return MLRegisterModelInput
.builder()
.modelName(modelName)
.version(version)
.modelGroupId("modelGroupId")
.modelFormat(modelFormat)
.functionName(FunctionName.SPARSE_ENCODING)
.build();
}
}

0 comments on commit 4d53db5

Please sign in to comment.