Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change task worker node to list; add target worker node to cache #656

Merged
merged 3 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
Expand Down Expand Up @@ -54,7 +56,7 @@ public class MLTask implements ToXContentObject, Writeable {
private Float progress;
private final String outputIndex;
@Setter
private String workerNode;
private List<String> workerNodes;
private final Instant createTime;
private Instant lastUpdateTime;
@Setter
Expand All @@ -72,7 +74,7 @@ public MLTask(
MLInputDataType inputType,
Float progress,
String outputIndex,
String workerNode,
List<String> workerNodes,
Instant createTime,
Instant lastUpdateTime,
String error,
Expand All @@ -87,7 +89,7 @@ public MLTask(
this.inputType = inputType;
this.progress = progress;
this.outputIndex = outputIndex;
this.workerNode = workerNode;
this.workerNodes = workerNodes;
this.createTime = createTime;
this.lastUpdateTime = lastUpdateTime;
this.error = error;
Expand All @@ -108,7 +110,7 @@ public MLTask(StreamInput input) throws IOException {
}
this.progress = input.readOptionalFloat();
this.outputIndex = input.readOptionalString();
this.workerNode = input.readString();
this.workerNodes = input.readStringList();
this.createTime = input.readInstant();
this.lastUpdateTime = input.readInstant();
this.error = input.readOptionalString();
Expand All @@ -135,7 +137,7 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeOptionalFloat(progress);
out.writeOptionalString(outputIndex);
out.writeString(workerNode);
out.writeStringCollection(workerNodes);
out.writeInstant(createTime);
out.writeInstant(lastUpdateTime);
out.writeOptionalString(error);
Expand Down Expand Up @@ -174,8 +176,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (outputIndex != null) {
builder.field(OUTPUT_INDEX_FIELD, outputIndex);
}
if (workerNode != null) {
builder.field(WORKER_NODE_FIELD, workerNode);
if (workerNodes != null) {
builder.field(WORKER_NODE_FIELD, workerNodes);
}
if (createTime != null) {
builder.field(CREATE_TIME_FIELD, createTime.toEpochMilli());
Expand Down Expand Up @@ -207,7 +209,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
MLInputDataType inputType = null;
Float progress = null;
String outputIndex = null;
String workerNode = null;
List<String> workerNodes = null;
Instant createTime = null;
Instant lastUpdateTime = null;
String error = null;
Expand Down Expand Up @@ -245,7 +247,11 @@ public static MLTask parse(XContentParser parser) throws IOException {
outputIndex = parser.text();
break;
case WORKER_NODE_FIELD:
workerNode = parser.text();
workerNodes = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
workerNodes.add(parser.text());
}
break;
case CREATE_TIME_FIELD:
createTime = Instant.ofEpochMilli(parser.longValue());
Expand Down Expand Up @@ -276,7 +282,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.inputType(inputType)
.progress(progress)
.outputIndex(outputIndex)
.workerNode(workerNode)
.workerNodes(workerNodes)
.createTime(createTime)
.lastUpdateTime(lastUpdateTime)
.error(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -32,7 +33,7 @@ public void setup() {
.functionName(FunctionName.KMEANS)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand All @@ -57,7 +58,7 @@ public void toXContent() throws IOException {
Assert.assertEquals(
"{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\","
+ "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0,"
+ "\"output_index\":\"test_index\",\"worker_node\":\"node1\",\"create_time\":1641599940000,"
+ "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000,"
+ "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}",
taskContent
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.function.Consumer;
Expand All @@ -48,7 +49,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("mlTaskNode1")
.workerNodes(Arrays.asList("mlTaskNode1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.UncheckedIOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;

Expand All @@ -52,7 +53,7 @@ public void setUp() throws Exception {
.functionName(functionName)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("mlTaskNode1")
.workerNodes(Arrays.asList("mlTaskNode1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;

Expand All @@ -44,7 +45,7 @@ public void setUp() throws Exception {
.functionName(FunctionName.LINEAR_REGRESSION)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.InetAddress;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collections;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -70,7 +71,7 @@ public void setUp() throws Exception {
.functionName(FunctionName.LINEAR_REGRESSION)
.state(MLTaskState.RUNNING)
.inputType(MLInputDataType.DATA_FRAME)
.workerNode("node1")
.workerNodes(Arrays.asList("node1"))
.progress(0.0f)
.outputIndex("test_index")
.error("test_error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;

import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
Expand All @@ -35,7 +36,7 @@ public void setUp() {
.inputType(MLInputDataType.DATA_FRAME)
.progress(1.3f)
.outputIndex("some index")
.workerNode("some node")
.workerNodes(Arrays.asList("some node"))
.createTime(Instant.ofEpochMilli(123))
.lastUpdateTime(Instant.ofEpochMilli(123))
.error("error")
Expand All @@ -59,7 +60,7 @@ public void writeTo_Success() throws IOException {
assertEquals(response.mlTask.getInputType(), parsedResponse.mlTask.getInputType());
assertEquals(response.mlTask.getProgress(), parsedResponse.mlTask.getProgress());
assertEquals(response.mlTask.getOutputIndex(), parsedResponse.mlTask.getOutputIndex());
assertEquals(response.mlTask.getWorkerNode(), parsedResponse.mlTask.getWorkerNode());
assertEquals(response.mlTask.getWorkerNodes(), parsedResponse.mlTask.getWorkerNodes());
assertEquals(response.mlTask.getCreateTime(), parsedResponse.mlTask.getCreateTime());
assertEquals(response.mlTask.getLastUpdateTime(), parsedResponse.mlTask.getLastUpdateTime());
assertEquals(response.mlTask.getError(), parsedResponse.mlTask.getError());
Expand All @@ -80,7 +81,7 @@ public void toXContentTest() throws IOException {
"\"input_type\":\"DATA_FRAME\"," +
"\"progress\":1.3," +
"\"output_index\":\"some index\"," +
"\"worker_node\":\"some node\"," +
"\"worker_node\":[\"some node\"]," +
"\"create_time\":123," +
"\"last_update_time\":123," +
"\"error\":\"error\"," +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
return;
}

String workerNodes = String.join(",", nodeIds);
log.warn("Will load model on these nodes: {}", workerNodes);
log.info("Will load model on these nodes: {}", String.join(",", nodeIds));
String localNodeId = clusterService.localNode().getId();

String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
Expand All @@ -156,7 +155,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.workerNode(workerNodes)
.workerNodes(nodeIds)
.build();
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

@Log4j2
Expand Down Expand Up @@ -127,12 +128,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Upload
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.workerNode(clusterService.localNode().getId())
.workerNodes(ImmutableList.of(clusterService.localNode().getId()))
.build();

mlTaskDispatcher.dispatch(ActionListener.wrap(node -> {
String nodeId = node.getId();
mlTask.setWorkerNode(nodeId);
mlTask.setWorkerNodes(ImmutableList.of(nodeId));

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
Expand Down
15 changes: 15 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.model;

import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -29,16 +30,30 @@ public class MLModelCache {
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLModelState modelState;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor;
private @Getter(AccessLevel.PROTECTED) Set<String> targetWorkerNodes;
private final Set<String> workerNodes;
private final Queue<Double> modelInferenceDurationQueue;
private final Queue<Double> predictRequestDurationQueue;

/**
 * Creates an empty cache entry.
 * Both node sets are concurrent key sets obtained from {@code ConcurrentHashMap.newKeySet()},
 * and both duration queues are {@code ConcurrentLinkedQueue} instances.
 */
public MLModelCache() {
// Nodes the model is expected to be loaded on; filled later through setTargetWorkerNodes.
targetWorkerNodes = ConcurrentHashMap.newKeySet();
workerNodes = ConcurrentHashMap.newKeySet();
modelInferenceDurationQueue = new ConcurrentLinkedQueue<>();
predictRequestDurationQueue = new ConcurrentLinkedQueue<>();
}

public void setTargetWorkerNodes(List<String> targetWorkerNodes) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here targetWorkerNodes is a private concurrent set, so I assume MLModelCache needs to be thread-safe. In that case, if setTargetWorkerNodes can be called from multiple threads, we need to protect it with synchronization; otherwise targetWorkerNodes could end up with the wrong contents.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is called by MLModelCacheHelper.initModelState, which is already synchronized.

if (targetWorkerNodes == null || targetWorkerNodes.size() == 0) {
throw new IllegalArgumentException("Null or empty target worker nodes");
}
this.targetWorkerNodes.clear();
this.targetWorkerNodes.addAll(targetWorkerNodes);
}

/**
 * Returns the ids currently held in {@code targetWorkerNodes} as a String array.
 *
 * @return array of target worker node ids; zero-length when the set is empty
 */
public String[] getTargetWorkerNodes() {
return targetWorkerNodes.toArray(new String[0]);
}

/**
 * Removes the given node id from {@code workerNodes}.
 * The {@code targetWorkerNodes} set is not touched by this method.
 *
 * @param nodeId id of the node to remove
 */
public void removeWorkerNode(String nodeId) {
workerNodes.remove(nodeId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -41,14 +42,15 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
* @param state model state
* @param functionName function name
*/
public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName) {
public synchronized void initModelState(String modelId, MLModelState state, FunctionName functionName, List<String> targetWorkerNodes) {
// Refuse to initialize again when isModelRunningOnNode(modelId) reports true.
if (isModelRunningOnNode(modelId)) {
throw new MLLimitExceededException("Duplicate load model task");
}
log.debug("init model state for model {}, state: {}", modelId, state);
MLModelCache modelCache = new MLModelCache();
modelCache.setModelState(state);
modelCache.setFunctionName(functionName);
// MLModelCache.setTargetWorkerNodes throws IllegalArgumentException for a null or empty list,
// in which case the new entry is never put into modelCaches.
modelCache.setTargetWorkerNodes(targetWorkerNodes);
// Register the entry only after it has been fully populated.
modelCaches.put(modelId, modelCache);
}

Expand Down Expand Up @@ -254,6 +256,10 @@ public MLModelProfile getModelProfile(String modelId) {
if (modelCache.getPredictor() != null) {
builder.predictor(modelCache.getPredictor().toString());
}
String[] targetWorkerNodes = modelCache.getTargetWorkerNodes();
if (targetWorkerNodes.length > 0) {
builder.targetWorkerNodes(targetWorkerNodes);
}
String[] workerNodes = modelCache.getWorkerNodes();
if (workerNodes.length > 0) {
builder.workerNodes(workerNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ public void loadModel(
listener.onFailure(new IllegalArgumentException("Exceed max model per node limit"));
return;
}
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName);
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes());
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
checkAndAddRunningTask(mlTask, maxLoadTasksPerNode);
this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class MLModelProfile implements ToXContentFragment, Writeable {

private final MLModelState modelState;
private final String predictor;
private final String[] targetWorkerNodes;
private final String[] workerNodes;
private final MLPredictRequestStats modelInferenceStats;
private final MLPredictRequestStats predictRequestStats;
Expand All @@ -32,12 +33,14 @@ public class MLModelProfile implements ToXContentFragment, Writeable {
public MLModelProfile(
MLModelState modelState,
String predictor,
String[] targetWorkerNodes,
String[] workerNodes,
MLPredictRequestStats modelInferenceStats,
MLPredictRequestStats predictRequestStats
) {
this.modelState = modelState;
this.predictor = predictor;
this.targetWorkerNodes = targetWorkerNodes;
this.workerNodes = workerNodes;
this.modelInferenceStats = modelInferenceStats;
this.predictRequestStats = predictRequestStats;
Expand All @@ -52,6 +55,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (predictor != null) {
builder.field("predictor", predictor);
}
if (targetWorkerNodes != null) {
builder.field("target_worker_nodes", targetWorkerNodes);
}
if (workerNodes != null) {
builder.field("worker_nodes", workerNodes);
}
Expand All @@ -72,6 +78,7 @@ public MLModelProfile(StreamInput in) throws IOException {
this.modelState = null;
}
this.predictor = in.readOptionalString();
this.targetWorkerNodes = in.readOptionalStringArray();
this.workerNodes = in.readOptionalStringArray();
if (in.readBoolean()) {
this.modelInferenceStats = new MLPredictRequestStats(in);
Expand All @@ -94,6 +101,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(predictor);
out.writeOptionalStringArray(targetWorkerNodes);
out.writeOptionalStringArray(workerNodes);
if (modelInferenceStats != null) {
out.writeBoolean(true);
Expand Down
Loading