Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] indicate overall deployment failure if all node routes are failed #88378

Merged
5 changes: 5 additions & 0 deletions docs/changelog/88378.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 88378
summary: Indicate overall deployment failure if all node routes are failed
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
public enum AssignmentState {
STARTING,
STARTED,
STOPPING;
STOPPING,
FAILED; // Not persisted, calculated via route states

public static AssignmentState fromString(String value) {
return valueOf(value.toUpperCase(Locale.ROOT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,15 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalVInt(queueCapacity);
out.writeInstant(startTime);
out.writeList(nodeStats);
out.writeOptionalEnum(state);
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the code could be structure in a way that makes it more explicit that the reason for these checks is that the FAILED state cannot be streamed to an < 8.4 node.

if (AssignmentState.FAILED.equals(state) && out.getVersion().before(Version.V_8_4_0)) {
   out.writeOptionalEnum(AssignmentState.STARTING);
} else {
    out.writeOptionalEnum(state);
}

out.writeOptionalEnum(state);
} else {
if (AssignmentState.FAILED.equals(state)) {
out.writeOptionalEnum(AssignmentState.STARTING);
} else {
out.writeOptionalEnum(state);
}
}
out.writeOptionalString(reason);
out.writeOptionalWriteable(allocationStatus);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
Expand Down Expand Up @@ -347,6 +348,30 @@ public void testLiveDeploymentStats() throws IOException {
}
}

@SuppressWarnings("unchecked")
public void testFailedDeploymentStats() throws Exception {
String badModel = "bad_model";
String poorlyFormattedModelBase64 = "cG9vcmx5IGZvcm1hdHRlZCBtb2RlbAo=";
int length = Base64.getDecoder().decode(poorlyFormattedModelBase64).length;
createTrainedModel(badModel);
putVocabulary(List.of("once", "twice"), badModel);
Request request = new Request("PUT", "_ml/trained_models/" + badModel + "/definition/0");
request.setJsonEntity("""
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(length, poorlyFormattedModelBase64));
client().performRequest(request);
startDeployment(badModel, AllocationStatus.State.STARTING.toString());
assertBusy(() -> {
Response noInferenceCallsStatsResponse = getTrainedModelStats(badModel);
List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(noInferenceCallsStatsResponse).get(
"trained_model_stats"
);
assertThat(stats, hasSize(1));

String assignmentState = (String) XContentMapValues.extractValue("deployment_stats.state", stats.get(0));
assertThat(assignmentState, equalTo(AssignmentState.FAILED.toString()));
});
}

private void assertAtLeastOneOfTheseIsNotNull(String name, List<Map<String, Object>> nodes) {
assertTrue("all nodes have null value for [" + name + "]", nodes.stream().anyMatch(n -> n.get(name) != null));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ protected void doExecute(
TrainedModelAssignment trainedModelAssignment = assignment.getModelAssignment(stats.getModelId());
if (trainedModelAssignment != null) {
stats.setState(trainedModelAssignment.getAssignmentState()).setReason(trainedModelAssignment.getReason().orElse(null));
if (trainedModelAssignment.getNodeRoutingTable()
.values()
.stream()
.allMatch(ri -> ri.getState().equals(RoutingState.FAILED))) {
stats.setState(AssignmentState.FAILED);
if (stats.getReason() == null) {
stats.setReason("All node routes are failed; see node route reason for details");
}
}
if (trainedModelAssignment.getAssignmentState().isAnyOf(AssignmentState.STARTED, AssignmentState.STARTING)) {
stats.setAllocationStatus(trainedModelAssignment.calculateAllocationStatus().orElse(null));
}
Expand Down