-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backport 2.x] Separates RegisterModelStep into RegisterLocalModelSte…
…p and RegisterRemoteModelStep. Adds GetMLTaskStep. Handles optional params (#161) Separates RegisterModelStep into RegisterLocalModelStep and RegisterRemoteModelStep. Adds GetMLTaskStep. Handles optional params (#155) * added RegisterRemoteModelStep and tests * Adding RegisterLocalModelStep, fixing tests, adding input/ouput definitions to workflow step json * Fixing javadoc warnings, fixing log message * Addressing PR comments,making description field optional for RegisterRemoteModelStep and RegisterLocalModelStep * moving modelConfig builder before adding allConfig * handling optional description field for remote/local model * Removing poolingMode, modelMaxLenth, normalizeResult * adding modelType to required fields check * Fixing RegisterLocalModelStep to output a task ID instead of a model id * Adding GetMLTaskStep and tests * Adding todo for GetMLTask retry capability --------- (cherry picked from commit 2142874) Signed-off-by: Joshua Palis <[email protected]> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Joshua Palis <[email protected]>
- Loading branch information
1 parent
4b9bea4
commit 4813fc0
Showing
12 changed files
with
634 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.MLTask; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; | ||
import static org.opensearch.flowframework.common.CommonValue.TASK_ID; | ||
|
||
/** | ||
* Step to retrieve an ML Task | ||
*/ | ||
public class GetMLTaskStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class); | ||
private MachineLearningNodeClient mlClient; | ||
static final String NAME = "get_ml_task"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param mlClient client to instantiate MLClient | ||
*/ | ||
public GetMLTaskStep(MachineLearningNodeClient mlClient) { | ||
this.mlClient = mlClient; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) { | ||
|
||
CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>(); | ||
|
||
ActionListener<MLTask> actionListener = ActionListener.wrap(response -> { | ||
|
||
// TODO : Add retry capability if response status is not COMPLETED : | ||
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/158 | ||
|
||
logger.info("ML Task retrieval successful"); | ||
getMLTaskFuture.complete( | ||
new WorkflowData( | ||
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name())) | ||
) | ||
); | ||
}, exception -> { | ||
logger.error("Failed to retrieve ML Task"); | ||
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); | ||
}); | ||
|
||
String taskId = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case TASK_ID: | ||
taskId = (String) content.get(TASK_ID); | ||
break; | ||
default: | ||
break; | ||
} | ||
} | ||
} | ||
|
||
if (taskId == null) { | ||
logger.error("Failed to retrieve ML Task"); | ||
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); | ||
} else { | ||
mlClient.getTask(taskId, actionListener); | ||
} | ||
|
||
return getMLTaskFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
200 changes: 200 additions & 0 deletions
200
src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; | ||
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; | ||
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; | ||
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; | ||
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; | ||
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; | ||
import static org.opensearch.flowframework.common.CommonValue.TASK_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.URL; | ||
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; | ||
|
||
/** | ||
* Step to register a local model | ||
*/ | ||
public class RegisterLocalModelStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class); | ||
|
||
private MachineLearningNodeClient mlClient; | ||
|
||
static final String NAME = "register_local_model"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param mlClient client to instantiate MLClient | ||
*/ | ||
public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { | ||
this.mlClient = mlClient; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) { | ||
|
||
CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>(); | ||
|
||
ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() { | ||
@Override | ||
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { | ||
logger.info("Local Model registration task creation successful"); | ||
registerLocalModelFuture.complete( | ||
new WorkflowData( | ||
Map.ofEntries( | ||
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), | ||
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) | ||
) | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
logger.error("Failed to register local model"); | ||
registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); | ||
} | ||
}; | ||
|
||
String modelName = null; | ||
String modelVersion = null; | ||
String description = null; | ||
MLModelFormat modelFormat = null; | ||
String modelGroupId = null; | ||
String modelContentHashValue = null; | ||
String modelType = null; | ||
String embeddingDimension = null; | ||
FrameworkType frameworkType = null; | ||
String allConfig = null; | ||
String url = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case NAME_FIELD: | ||
modelName = (String) content.get(NAME_FIELD); | ||
break; | ||
case VERSION_FIELD: | ||
modelVersion = (String) content.get(VERSION_FIELD); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
description = (String) content.get(DESCRIPTION_FIELD); | ||
break; | ||
case MODEL_FORMAT: | ||
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); | ||
break; | ||
case MODEL_GROUP_ID: | ||
modelGroupId = (String) content.get(MODEL_GROUP_ID); | ||
break; | ||
case MODEL_TYPE: | ||
modelType = (String) content.get(MODEL_TYPE); | ||
break; | ||
case EMBEDDING_DIMENSION: | ||
embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); | ||
break; | ||
case FRAMEWORK_TYPE: | ||
frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); | ||
break; | ||
case ALL_CONFIG: | ||
allConfig = (String) content.get(ALL_CONFIG); | ||
break; | ||
case MODEL_CONTENT_HASH_VALUE: | ||
modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); | ||
break; | ||
case URL: | ||
url = (String) content.get(URL); | ||
break; | ||
default: | ||
break; | ||
|
||
} | ||
} | ||
} | ||
|
||
if (Stream.of( | ||
modelName, | ||
modelVersion, | ||
modelFormat, | ||
modelGroupId, | ||
modelType, | ||
embeddingDimension, | ||
frameworkType, | ||
modelContentHashValue, | ||
url | ||
).allMatch(x -> x != null)) { | ||
|
||
// Create Model configudation | ||
TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() | ||
.modelType(modelType) | ||
.embeddingDimension(Integer.valueOf(embeddingDimension)) | ||
.frameworkType(frameworkType); | ||
if (allConfig != null) { | ||
modelConfigBuilder.allConfig(allConfig); | ||
} | ||
MLModelConfig modelConfig = modelConfigBuilder.build(); | ||
|
||
// Create register local model input | ||
MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder() | ||
.modelName(modelName) | ||
.version(modelVersion) | ||
.modelFormat(modelFormat) | ||
.modelGroupId(modelGroupId) | ||
.hashValue(modelContentHashValue) | ||
.modelConfig(modelConfig) | ||
.url(url); | ||
if (description != null) { | ||
mlInputBuilder.description(description); | ||
} | ||
|
||
MLRegisterModelInput mlInput = mlInputBuilder.build(); | ||
|
||
mlClient.register(mlInput, actionListener); | ||
} else { | ||
registerLocalModelFuture.completeExceptionally( | ||
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) | ||
); | ||
} | ||
|
||
return registerLocalModelFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
} |
Oops, something went wrong.