diff --git a/docs/changelog/100388.yaml b/docs/changelog/100388.yaml new file mode 100644 index 000000000000..0ff5228ef36d --- /dev/null +++ b/docs/changelog/100388.yaml @@ -0,0 +1,7 @@ +pr: 100388 +summary: Fix for inference requests being sent to every node with a model allocation. +If there are more nodes than items in the original request then empty requests were sent. +area: Machine Learning +type: bug +issues: + - 100180 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index 3664e4f62026..96ac12035628 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -213,7 +213,9 @@ public List> selectRandomStartedNodesWeighedOnAllocations var nodeCounts = new ArrayList>(); for (int i = 0; i < counts.length; i++) { - nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); + if (counts[i] > 0) { + nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); + } } return nodeCounts; } @@ -232,7 +234,10 @@ public List> selectRandomStartedNodesWeighedOnAllocations var nodeCounts = new ArrayList>(); for (int i = 0; i < counts.length; i++) { - nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); + // filter out zero counts + if (counts[i] > 0) { + nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i])); + } } return nodeCounts; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java index c3b3a8f7d88f..7785f8785a21 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; @@ -22,7 +21,6 @@ public class ErrorInferenceResults implements InferenceResults { public static final String NAME = "error"; - public static final ParseField WARNING = new ParseField("error"); private final Exception exception; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index c85729b5a631..ca777be21b3b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -182,6 +182,17 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSin assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1))); } + public void testSingleRequestWith2Nodes() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1); + assertThat(nodes, hasSize(1)); + assertEquals(nodes.get(0).v2(), Integer.valueOf(1)); + } + public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() { TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index bc25afba066b..2827967c42cd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -354,17 +354,19 @@ private void sendResponse() { } else { for (int i = 0; i < results.length(); i++) { var resultList = results.get(i); - if (resultList != null) { - for (var result : resultList) { - if (result instanceof ErrorInferenceResults errorResult) { - // Any failure fails all requests - // TODO is this the correct behaviour for batched requests? - finalListener.onFailure(errorResult.getException()); - return; - } + if (resultList == null) { + continue; + } + + for (var result : resultList) { + if (result instanceof ErrorInferenceResults errorResult) { + // Any failure fails all requests + // TODO is this the correct behaviour for batched requests? + finalListener.onFailure(errorResult.getException()); + return; } - responseBuilder.addInferenceResults(resultList); } + responseBuilder.addInferenceResults(resultList); } finalListener.onResponse(responseBuilder.build()); } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java index 4912bff3518f..b9fbf0b6b1f0 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java @@ -97,7 +97,6 @@ public void removeLogging() throws IOException { client().performRequest(request); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/100180") public void testTrainedModelDeployment() throws Exception { assumeTrue("NLP model deployments added in 8.0", UPGRADE_FROM_VERSION.onOrAfter(Version.V_8_0_0));