diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index 51bc98fe1b..c7baa9b3a6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -18,15 +18,18 @@ @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { public static final String TASK_ID_FIELD = "task_id"; + public static final String MODEL_ID_FIELD = "model_id"; public static final String STATUS_FIELD = "status"; private String taskId; private String status; + private String modelId; public MLRegisterModelResponse(StreamInput in) throws IOException { super(in); this.taskId = in.readString(); this.status = in.readString(); + this.modelId = in.readOptionalString(); } public MLRegisterModelResponse(String taskId, String status) { @@ -34,10 +37,17 @@ public MLRegisterModelResponse(String taskId, String status) { this.status= status; } + public MLRegisterModelResponse(String taskId, String status, String modelId) { + this.taskId = taskId; + this.status= status; + this.modelId = modelId; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(taskId); out.writeString(status); + out.writeOptionalString(modelId); } @Override @@ -45,6 +55,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.startObject(); builder.field(TASK_ID_FIELD, taskId); builder.field(STATUS_FIELD, status); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } builder.endObject(); return builder; } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java index 555c9163e5..f222f40200 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java @@ -16,24 +16,27 @@ public class MLRegisterModelResponseTest { private String taskId; private String status; + private String modelId; @Before public void setUp() throws Exception { taskId = "test_id"; status = "test"; + modelId = "model_id"; } @Test public void writeTo_Success() throws IOException { // Setup BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status); + MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId); // Run the test response.writeTo(bytesStreamOutput); MLRegisterModelResponse parsedResponse = new MLRegisterModelResponse(bytesStreamOutput.bytes().streamInput()); // Verify the results assertEquals(response.getTaskId(), parsedResponse.getTaskId()); assertEquals(response.getStatus(), parsedResponse.getStatus()); + assertEquals(response.getModelId(), parsedResponse.getModelId()); } @Test @@ -49,4 +52,18 @@ public void testToXContent() throws IOException { assertEquals("{\"task_id\":\"test_id\"," + "\"status\":\"test\"}", jsonStr); } + + @Test + public void testToXContent_withModelId() throws IOException { + // Setup + MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId); + // Run the test + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = builder.toString(); + // Verify the results + assertEquals("{\"task_id\":\"test_id\"," + + "\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 7c0e2d6b9d..6fcb190dfa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -233,6 +233,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen throw new IllegalArgumentException("URL can't match trusted url regex"); } } + System.out.println("registering the model"); boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE; MLTask mlTask = MLTask .builder() @@ -249,8 +250,8 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - mlModelManager.registerMLModel(registerModelInput, mlTask); - listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name())); + System.out.println("mlModelManager calls registerMLRemoteModel"); + mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener); }, e -> { logException("Failed to register model", e, log); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 56b0a743f9..92c9590e15 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -93,6 +93,7 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; @@ -103,6 +104,7 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.MLExecutable; @@ -222,30 +224,22 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { - Map source = modelGroup.getSourceAsMap(); - int latestVersion = (int) source.get(MLModelGroup.LATEST_VERSION_FIELD); - int newVersion = latestVersion + 1; - source.put(MLModelGroup.LATEST_VERSION_FIELD, newVersion); - source.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateRequest updateModelGroupRequest = new UpdateRequest(); - long seqNo = modelGroup.getSeqNo(); - long primaryTerm = modelGroup.getPrimaryTerm(); - updateModelGroupRequest - .index(ML_MODEL_GROUP_INDEX) - .id(modelGroupId) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .doc(source); - client - .update( - updateModelGroupRequest, - ActionListener - .wrap(r -> { uploadMLModelMeta(mlRegisterModelMetaInput, newVersion + "", listener); }, e -> { - log.error("Failed to update model group", e); - listener.onFailure(e); - }) - ); + Map modelGroupSource = modelGroup.getSourceAsMap(); + int updatedVersion = incrementLatestVersion(modelGroupSource); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + modelGroupSource, + modelGroupId, + modelGroup.getSeqNo(), + modelGroup.getPrimaryTerm(), + updatedVersion + ); + + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + uploadMLModelMeta(mlRegisterModelMetaInput, updatedVersion + "", listener); + }, e -> { + log.error("Failed to update model group", e); + listener.onFailure(e); + })); } else { log.error("Model group not found"); listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); @@ -312,6 +306,80 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput } } + /** + * + * @param mlRegisterModelInput register model input for remote models + * @param mlTask ML task + * @param listener action listener + */ + public void registerMLRemoteModel( + MLRegisterModelInput mlRegisterModelInput, + MLTask mlTask, + ActionListener listener + ) { + checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); + mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + + String modelGroupId = mlRegisterModelInput.getModelGroupId(); + GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); + if (Strings.isBlank(modelGroupId)) { + indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener); + } + + client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { + if (getModelGroupResponse.isExists()) { + Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); + int updatedVersion = incrementLatestVersion(modelGroupSourceMap); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + modelGroupSourceMap, + modelGroupId, + getModelGroupResponse.getSeqNo(), + getModelGroupResponse.getPrimaryTerm(), + updatedVersion + ); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + indexRemoteModel(mlRegisterModelInput, mlTask, updatedVersion + "", listener); + }, e -> { + log.error("Failed to update model group " + modelGroupId, e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + listener.onFailure(e); + })); + } else { + log.error("Model group response is empty"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + new MLValidationException("Model group not found") + ); + listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); + } + }, error -> { + if (error instanceof IndexNotFoundException) { + log.error("Model group Index is missing"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + new MLResourceNotFoundException("Failed to get model group due to index missing") + ); + listener.onFailure(error); + } else { + log.error("Failed to get model group", error); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), error); + listener.onFailure(error); + } + })); + } catch (Exception e) { + log.error("Failed to register remote model", e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + listener.onFailure(e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); + } + } + /** * Register model. Basically download model file, split into chunks and save into model index. * @@ -334,25 +402,19 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { - Map source = modelGroup.getSourceAsMap(); - int latestVersion = (int) source.get(MLModelGroup.LATEST_VERSION_FIELD); - int newVersion = latestVersion + 1; - source.put(MLModelGroup.LATEST_VERSION_FIELD, newVersion); - source.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateRequest updateModelGroupRequest = new UpdateRequest(); - long seqNo = modelGroup.getSeqNo(); - long primaryTerm = modelGroup.getPrimaryTerm(); - updateModelGroupRequest - .index(ML_MODEL_GROUP_INDEX) - .id(modelGroupId) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .doc(source); + Map modelGroupSourceMap = modelGroup.getSourceAsMap(); + int updatedVersion = incrementLatestVersion(modelGroupSourceMap); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + modelGroupSourceMap, + modelGroupId, + modelGroup.getSeqNo(), + modelGroup.getPrimaryTerm(), + updatedVersion + ); client .update( updateModelGroupRequest, - ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, newVersion + ""); }, e -> { + ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, updatedVersion + ""); }, e -> { log.error("Failed to update model group", e); handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); }) @@ -388,6 +450,95 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } } + private UpdateRequest createUpdateModelGroupRequest( + Map modelGroupSourceMap, + String modelGroupId, + long seqNo, + long primaryTerm, + int updatedVersion + ) { + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateModelGroupRequest = new UpdateRequest(); + + updateModelGroupRequest + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .doc(modelGroupSourceMap); + + return updateModelGroupRequest; + } + + private int incrementLatestVersion(Map modelGroupSourceMap) { + return (int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1; + } + + private void indexRemoteModel( + MLRegisterModelInput registerModelInput, + MLTask mlTask, + String modelVersion, + ActionListener listener + ) { + String taskId = mlTask.getTaskId(); + FunctionName functionName = mlTask.getFunctionName(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String modelName = registerModelInput.getModelName(); + String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; + Instant now = Instant.now(); + if (registerModelInput.getConnector() != null) { + registerModelInput.getConnector().encrypt(mlEngine::encrypt); + } + + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(boolResponse -> { + MLModel mlModelMeta = MLModel + .builder() + .name(modelName) + .algorithm(functionName) + .modelGroupId(registerModelInput.getModelGroupId()) + .version(version) + .description(registerModelInput.getDescription()) + .modelFormat(registerModelInput.getModelFormat()) + .modelState(MLModelState.REGISTERED) + .connector(registerModelInput.getConnector()) + .connectorId(registerModelInput.getConnectorId()) + .modelConfig(registerModelInput.getModelConfig()) + .createdTime(now) + .lastUpdateTime(now) + .build(); + + IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); + indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); + indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // index remote model doc + ActionListener indexListener = ActionListener.wrap(modelMetaRes -> { + String modelId = modelMetaRes.getId(); + mlTask.setModelId(modelId); + log.info("create new model meta doc {} for upload task {}", modelId, taskId); + mlTaskManager.updateMLTask(taskId, ImmutableMap.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true); + // if (registerModelInput.isDeployModel()) { + // deployModelAfterRegistering(registerModelInput, modelId); + // } + listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId)); + }, e -> { + log.error("Failed to index model meta doc", e); + handleException(functionName, taskId, e); + listener.onFailure(e); + }); + + client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); + }, error -> { + // failed to initialize the model index + log.error("Failed to init model index", error); + handleException(functionName, taskId, error); + listener.onFailure(error); + })); + } + } + private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); @@ -431,7 +582,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml log.error("Failed to index model meta doc", e); handleException(functionName, taskId, e); }); - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); }, e -> { log.error("Failed to init model index", e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 767d47784f..1a8384f45f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -40,6 +40,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -202,6 +203,7 @@ public void setup() { when(node2.getId()).thenReturn("node2Id"); doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLModel(any(), any()); + doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLRemoteModel(any(), any(), any()); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -358,7 +360,7 @@ public void test_execute_registerRemoteModel_withConnectorId_success() { MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(actionListener).onResponse(argumentCaptor.capture()); + verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); } public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConnectorId() { @@ -424,7 +426,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(actionListener).onResponse(argumentCaptor.capture()); + verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); } public void test_execute_registerRemoteModel_withInternalConnector_connectorIsNull() {