diff --git a/.github/workflows/backport.yml b/.github/workflows/backport.yml
index 85323c08815a6..3b987b10b2de6 100644
--- a/.github/workflows/backport.yml
+++ b/.github/workflows/backport.yml
@@ -37,3 +37,33 @@ jobs:
commit_email: elasticsarchmachine@users.noreply.github.com
auto_merge: 'false'
manual_backport_command_template: 'backport --pr %pullNumber%'
+ backport_and_merge:
+ name: Backport and Merge PR
+ if: |
+ github.event.pull_request.merged == true
+ && contains(github.event.pull_request.labels.*.name, 'auto-backport-and-merge')
+ && (
+ (github.event.action == 'labeled' && github.event.label.name == 'auto-backport-and-merge')
+ || (github.event.action == 'closed')
+ )
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Actions
+ uses: actions/checkout@v2
+ with:
+ repository: 'elastic/kibana-github-actions'
+ ref: main
+ path: ./actions
+
+ - name: Install Actions
+ run: npm install --production --prefix ./actions
+
+ - name: Run Backport
+ uses: ./actions/backport
+ with:
+ github_token: ${{secrets.ELASTICSEARCHMACHINE_TOKEN}}
+ commit_user: elasticsearchmachine
+ commit_email: elasticsarchmachine@users.noreply.github.com
+ target_pr_labels: 'backport, auto-merge'
+ auto_merge: 'false'
+ manual_backport_command_template: 'backport --pr %pullNumber%'
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 5cf789707c58c..3f3eb5218afed 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -5,5 +5,6 @@
+
\ No newline at end of file
diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties
index c4a6b1ae1af02..a5e18172352c7 100644
--- a/build-tools-internal/version.properties
+++ b/build-tools-internal/version.properties
@@ -1,8 +1,8 @@
elasticsearch = 8.0.0
lucene = 8.9.0
-bundled_jdk_vendor = adoptopenjdk
-bundled_jdk = 16.0.1+9
+bundled_jdk_vendor = openjdk
+bundled_jdk = 16.0.2+7@d4a915d82b4c4fbb9bde534da945d746
checkstyle = 8.42
@@ -53,4 +53,4 @@ jimfs = 1.2
jimfs_guava = 30.1-jre
# test framework
-networknt_json_schema_validator = 1.0.48
\ No newline at end of file
+networknt_json_schema_validator = 1.0.48
diff --git a/docs/changelog/75981.yaml b/docs/changelog/75981.yaml
new file mode 100644
index 0000000000000..8b7d8a03136d6
--- /dev/null
+++ b/docs/changelog/75981.yaml
@@ -0,0 +1,9 @@
+pr: 75981
+summary: Bump bundled JDK to 16.0.2
+area: Packaging
+type: upgrade
+issues: []
+versions:
+ - v8.0.0
+ - v7.14.1
+ - v7.15.0
diff --git a/docs/reference/ml/anomaly-detection/apis/delete-job.asciidoc b/docs/reference/ml/anomaly-detection/apis/delete-job.asciidoc
index 82b20e58c78f4..316bbd287a9d9 100644
--- a/docs/reference/ml/anomaly-detection/apis/delete-job.asciidoc
+++ b/docs/reference/ml/anomaly-detection/apis/delete-job.asciidoc
@@ -18,8 +18,6 @@ Deletes an existing {anomaly-job}.
* Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.
-* Before you can delete a job, you must delete the {dfeeds} that are associated
-with it. See <>.
* Before you can delete a job, you must close it (unless you specify the `force`
parameter). See <>.
@@ -36,6 +34,10 @@ are granted to anyone over the `.ml-*` indices.
It is not currently possible to delete multiple jobs using wildcards or a comma
separated list.
+If you delete a job that has a {dfeed}, the request will first attempt to
+delete the {dfeed}, as though <> was called with the same
+`timeout` and `force` parameters as this delete request.
+
[[ml-delete-job-path-parms]]
== {api-path-parms-title}
diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java
index 8fc2a377afc43..d265a514a453c 100644
--- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java
+++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java
@@ -142,6 +142,7 @@
import org.elasticsearch.painless.symbol.IRDecorations.IRDSize;
import org.elasticsearch.painless.symbol.IRDecorations.IRDStoreType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDSymbol;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDThisMethod;
import org.elasticsearch.painless.symbol.IRDecorations.IRDTypeParameters;
import org.elasticsearch.painless.symbol.IRDecorations.IRDUnaryType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDValue;
@@ -1650,6 +1651,7 @@ public void visitInvokeCallMember(InvokeCallMemberNode irInvokeCallMemberNode, W
methodWriter.writeDebugInfo(irInvokeCallMemberNode.getLocation());
LocalFunction localFunction = irInvokeCallMemberNode.getDecorationValue(IRDFunction.class);
+ PainlessMethod thisMethod = irInvokeCallMemberNode.getDecorationValue(IRDThisMethod.class);
PainlessMethod importedMethod = irInvokeCallMemberNode.getDecorationValue(IRDMethod.class);
PainlessClassBinding classBinding = irInvokeCallMemberNode.getDecorationValue(IRDClassBinding.class);
PainlessInstanceBinding instanceBinding = irInvokeCallMemberNode.getDecorationValue(IRDInstanceBinding.class);
@@ -1669,6 +1671,16 @@ public void visitInvokeCallMember(InvokeCallMemberNode irInvokeCallMemberNode, W
} else {
methodWriter.invokeVirtual(CLASS_TYPE, localFunction.getAsmMethod());
}
+ } else if (thisMethod != null) {
+ methodWriter.loadThis();
+
+ for (ExpressionNode irArgumentNode : irArgumentNodes) {
+ visit(irArgumentNode, writeScope);
+ }
+
+ Method asmMethod = new Method(thisMethod.javaMethod.getName(),
+ thisMethod.methodType.dropParameterTypes(0, 1).toMethodDescriptorString());
+ methodWriter.invokeVirtual(CLASS_TYPE, asmMethod);
} else if (importedMethod != null) {
for (ExpressionNode irArgumentNode : irArgumentNodes) {
visit(irArgumentNode, writeScope);
diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java
index 769702afe94b3..e387dba4cb2ff 100644
--- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java
+++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java
@@ -119,6 +119,7 @@
import org.elasticsearch.painless.symbol.Decorations.StandardPainlessMethod;
import org.elasticsearch.painless.symbol.Decorations.StaticType;
import org.elasticsearch.painless.symbol.Decorations.TargetType;
+import org.elasticsearch.painless.symbol.Decorations.ThisPainlessMethod;
import org.elasticsearch.painless.symbol.Decorations.TypeParameters;
import org.elasticsearch.painless.symbol.Decorations.UnaryType;
import org.elasticsearch.painless.symbol.Decorations.UpcastPainlessCast;
@@ -1714,6 +1715,7 @@ public void visitCallLocal(ECallLocal userCallLocalNode, SemanticScope semanticS
ScriptScope scriptScope = semanticScope.getScriptScope();
FunctionTable.LocalFunction localFunction = null;
+ PainlessMethod thisMethod = null;
PainlessMethod importedMethod = null;
PainlessClassBinding classBinding = null;
int classBindingOffset = 0;
@@ -1728,44 +1730,47 @@ public void visitCallLocal(ECallLocal userCallLocalNode, SemanticScope semanticS
localFunction = null;
}
- if (localFunction != null) {
- semanticScope.setUsesInstanceMethod();
- } else {
- importedMethod = scriptScope.getPainlessLookup().lookupImportedPainlessMethod(methodName, userArgumentsSize);
+ if (localFunction == null) {
+ thisMethod = scriptScope.getPainlessLookup().lookupPainlessMethod(
+ scriptScope.getScriptClassInfo().getBaseClass(), false, methodName, userArgumentsSize);
- if (importedMethod == null) {
- classBinding = scriptScope.getPainlessLookup().lookupPainlessClassBinding(methodName, userArgumentsSize);
+ if (thisMethod == null) {
+ importedMethod = scriptScope.getPainlessLookup().lookupImportedPainlessMethod(methodName, userArgumentsSize);
- // check to see if this class binding requires an implicit this reference
- if (classBinding != null && classBinding.typeParameters.isEmpty() == false &&
- classBinding.typeParameters.get(0) == scriptScope.getScriptClassInfo().getBaseClass()) {
- classBinding = null;
- }
+ if (importedMethod == null) {
+ classBinding = scriptScope.getPainlessLookup().lookupPainlessClassBinding(methodName, userArgumentsSize);
- if (classBinding == null) {
- // This extra check looks for a possible match where the class binding requires an implicit this
- // reference. This is a temporary solution to allow the class binding access to data from the
- // base script class without need for a user to add additional arguments. A long term solution
- // will likely involve adding a class instance binding where any instance can have a class binding
- // as part of its API. However, the situation at run-time is difficult and will modifications that
- // are a substantial change if even possible to do.
- classBinding = scriptScope.getPainlessLookup().lookupPainlessClassBinding(methodName, userArgumentsSize + 1);
-
- if (classBinding != null) {
- if (classBinding.typeParameters.isEmpty() == false &&
- classBinding.typeParameters.get(0) == scriptScope.getScriptClassInfo().getBaseClass()) {
- classBindingOffset = 1;
- } else {
- classBinding = null;
- }
+ // check to see if this class binding requires an implicit this reference
+ if (classBinding != null && classBinding.typeParameters.isEmpty() == false &&
+ classBinding.typeParameters.get(0) == scriptScope.getScriptClassInfo().getBaseClass()) {
+ classBinding = null;
}
if (classBinding == null) {
- instanceBinding = scriptScope.getPainlessLookup().lookupPainlessInstanceBinding(methodName, userArgumentsSize);
+ // This extra check looks for a possible match where the class binding requires an implicit this
+ // reference. This is a temporary solution to allow the class binding access to data from the
+ // base script class without need for a user to add additional arguments. A long term solution
+ // will likely involve adding a class instance binding where any instance can have a class binding
+ // as part of its API. However, the situation at run-time is difficult and will modifications that
+ // are a substantial change if even possible to do.
+ classBinding = scriptScope.getPainlessLookup().lookupPainlessClassBinding(methodName, userArgumentsSize + 1);
+
+ if (classBinding != null) {
+ if (classBinding.typeParameters.isEmpty() == false &&
+ classBinding.typeParameters.get(0) == scriptScope.getScriptClassInfo().getBaseClass()) {
+ classBindingOffset = 1;
+ } else {
+ classBinding = null;
+ }
+ }
+
+ if (classBinding == null) {
+ instanceBinding = scriptScope.getPainlessLookup().lookupPainlessInstanceBinding(methodName, userArgumentsSize);
- if (instanceBinding == null) {
- throw userCallLocalNode.createError(new IllegalArgumentException(
- "Unknown call [" + methodName + "] with [" + userArgumentNodes + "] arguments."));
+ if (instanceBinding == null) {
+ throw userCallLocalNode.createError(new IllegalArgumentException(
+ "Unknown call [" + methodName + "] with [" + userArgumentNodes + "] arguments."));
+ }
}
}
}
@@ -1775,10 +1780,18 @@ public void visitCallLocal(ECallLocal userCallLocalNode, SemanticScope semanticS
List> typeParameters;
if (localFunction != null) {
+ semanticScope.setUsesInstanceMethod();
semanticScope.putDecoration(userCallLocalNode, new StandardLocalFunction(localFunction));
typeParameters = new ArrayList<>(localFunction.getTypeParameters());
valueType = localFunction.getReturnType();
+ } else if (thisMethod != null) {
+ semanticScope.setUsesInstanceMethod();
+ semanticScope.putDecoration(userCallLocalNode, new ThisPainlessMethod(thisMethod));
+
+ scriptScope.markNonDeterministic(thisMethod.annotations.containsKey(NonDeterministicAnnotation.class));
+ typeParameters = new ArrayList<>(thisMethod.typeParameters);
+ valueType = thisMethod.returnType;
} else if (importedMethod != null) {
semanticScope.putDecoration(userCallLocalNode, new StandardPainlessMethod(importedMethod));
diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java
index 1c004c985b5b2..d67cc59b95045 100644
--- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java
+++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java
@@ -185,6 +185,7 @@
import org.elasticsearch.painless.symbol.Decorations.StandardPainlessMethod;
import org.elasticsearch.painless.symbol.Decorations.StaticType;
import org.elasticsearch.painless.symbol.Decorations.TargetType;
+import org.elasticsearch.painless.symbol.Decorations.ThisPainlessMethod;
import org.elasticsearch.painless.symbol.Decorations.TypeParameters;
import org.elasticsearch.painless.symbol.Decorations.UnaryType;
import org.elasticsearch.painless.symbol.Decorations.UpcastPainlessCast;
@@ -239,6 +240,7 @@
import org.elasticsearch.painless.symbol.IRDecorations.IRDSize;
import org.elasticsearch.painless.symbol.IRDecorations.IRDStoreType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDSymbol;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDThisMethod;
import org.elasticsearch.painless.symbol.IRDecorations.IRDTypeParameters;
import org.elasticsearch.painless.symbol.IRDecorations.IRDUnaryType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDValue;
@@ -1221,6 +1223,10 @@ public void visitCallLocal(ECallLocal callLocalNode, ScriptScope scriptScope) {
if (scriptScope.hasDecoration(callLocalNode, StandardLocalFunction.class)) {
LocalFunction localFunction = scriptScope.getDecoration(callLocalNode, StandardLocalFunction.class).getLocalFunction();
irInvokeCallMemberNode.attachDecoration(new IRDFunction(localFunction));
+ } else if (scriptScope.hasDecoration(callLocalNode, ThisPainlessMethod.class)) {
+ PainlessMethod thisMethod =
+ scriptScope.getDecoration(callLocalNode, ThisPainlessMethod.class).getThisPainlessMethod();
+ irInvokeCallMemberNode.attachDecoration(new IRDThisMethod(thisMethod));
} else if (scriptScope.hasDecoration(callLocalNode, StandardPainlessMethod.class)) {
PainlessMethod importedMethod =
scriptScope.getDecoration(callLocalNode, StandardPainlessMethod.class).getStandardPainlessMethod();
diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java
index bce418be4ce12..de6f748928870 100644
--- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java
+++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java
@@ -417,6 +417,19 @@ public LocalFunction getLocalFunction() {
}
}
+ public static class ThisPainlessMethod implements Decoration {
+
+ private final PainlessMethod thisPainlessMethod;
+
+ public ThisPainlessMethod(PainlessMethod thisPainlessMethod) {
+ this.thisPainlessMethod = Objects.requireNonNull(thisPainlessMethod);
+ }
+
+ public PainlessMethod getThisPainlessMethod() {
+ return thisPainlessMethod;
+ }
+ }
+
public static class StandardPainlessClassBinding implements Decoration {
private final PainlessClassBinding painlessClassBinding;
diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java
index 303d5d3506f5e..f9e76e5317b17 100644
--- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java
+++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java
@@ -370,6 +370,19 @@ public IRDFunction(LocalFunction value) {
}
}
+ /** describes a method for a node on the script class; which method depends on node type */
+ public static class IRDThisMethod extends IRDecoration {
+
+ public IRDThisMethod(PainlessMethod value) {
+ super(value);
+ }
+
+ @Override
+ public String toString() {
+ return PainlessLookupUtility.buildPainlessMethodKey(getValue().javaMethod.getName(), getValue().typeParameters.size());
+ }
+ }
+
/** describes the call to a class binding */
public static class IRDClassBinding extends IRDecoration {
diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ThisTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ThisTests.java
new file mode 100644
index 0000000000000..1116c17c83190
--- /dev/null
+++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ThisTests.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless;
+
+import org.elasticsearch.painless.spi.Whitelist;
+import org.elasticsearch.painless.spi.WhitelistLoader;
+import org.elasticsearch.script.ScriptContext;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class ThisTests extends ScriptTestCase {
+
+ public abstract static class ThisBaseScript {
+
+ protected String baseString;
+
+ public ThisBaseScript(String baseString) {
+ this.baseString = baseString;
+ }
+
+ public String getBaseString() {
+ return baseString;
+ }
+
+ public void setBaseString(String testString) {
+ this.baseString = testString;
+ }
+
+ public int getBaseLength() {
+ return baseString.length();
+ }
+ }
+
+ public abstract static class ThisScript extends ThisBaseScript {
+
+ protected String thisString;
+
+ public ThisScript(String baseString, String thisString) {
+ super(baseString);
+
+ this.thisString = thisString;
+ }
+
+ public String thisString() {
+ return thisString;
+ }
+
+ public void thisString(String testString) {
+ this.thisString = testString;
+ }
+
+ public int thisLength() {
+ return thisString.length();
+ }
+
+ public abstract Object execute();
+
+ public interface Factory {
+
+ ThisScript newInstance(String baseString, String testString);
+ }
+
+ public static final String[] PARAMETERS = {};
+ public static final ScriptContext CONTEXT =
+ new ScriptContext<>("this_test", ThisScript.Factory.class);
+ }
+
+ @Override
+ protected Map, List> scriptContexts() {
+ Map, List> contexts = new HashMap<>();
+ List whitelists = new ArrayList<>(Whitelist.BASE_WHITELISTS);
+ whitelists.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.painless.this"));
+ contexts.put(ThisScript.CONTEXT, whitelists);
+ return contexts;
+ }
+
+ public Object exec(String script, String baseString, String testString) {
+ ThisScript.Factory factory = scriptEngine.compile(null, script, ThisScript.CONTEXT, new HashMap<>());
+ ThisScript testThisScript = factory.newInstance(baseString, testString);
+ return testThisScript.execute();
+ }
+
+ public void testThisMethods() {
+ assertEquals("basethis", exec("getBaseString() + thisString()", "base", "this"));
+ assertEquals(8, exec("getBaseLength() + thisLength()", "yyy", "xxxxx"));
+
+ List result = new ArrayList<>();
+ result.add("this");
+ result.add("base");
+ assertEquals(result, exec("List result = []; " +
+ "thisString('this');" +
+ "setBaseString('base');" +
+ "result.add(thisString()); " +
+ "result.add(getBaseString());" +
+ "result;", "", ""));
+ }
+}
diff --git a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.this b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.this
new file mode 100644
index 0000000000000..fb5eedf3388c9
--- /dev/null
+++ b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.this
@@ -0,0 +1,12 @@
+class org.elasticsearch.painless.ThisTests$ThisBaseScript @no_import {
+ String getBaseString();
+ void setBaseString(String);
+ int getBaseLength();
+}
+
+
+class org.elasticsearch.painless.ThisTests$ThisScript @no_import {
+ String thisString();
+ void thisString(String);
+ int thisLength();
+}
diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle
index 93c6bcba6cf97..f19b7e7019a0b 100644
--- a/x-pack/plugin/build.gradle
+++ b/x-pack/plugin/build.gradle
@@ -145,6 +145,10 @@ def v7compatibilityNotSupportedTests = {
'rollup/start_job/Test start job twice',
'service_accounts/10_basic/Test service account tokens', // https://github.com/elastic/elasticsearch/pull/75200
+ // temporarily muted awaiting backport of https://github.com/elastic/elasticsearch/pull/76010
+ 'ml/delete_job_force/Test force delete job that is referred by a datafeed',
+ 'ml/jobs_crud/Test delete job that is referred by a datafeed',
+
// a type field was added to cat.ml_trained_models #73660, this is a backwards compatible change.
// still this is a cat api, and we don't support them with rest api compatibility. (the test would be very hard to transform too)
'ml/trained_model_cat_apis/Test cat trained models'
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..ccf567e76a559
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationAction.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class CreateTrainedModelAllocationAction extends ActionType {
+ public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction();
+ public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create";
+
+ private CreateTrainedModelAllocationAction() {
+ super(NAME, CreateTrainedModelAllocationAction.Response::new);
+ }
+
+ public static class Request extends MasterNodeRequest {
+ private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+
+ public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+ this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams");
+ }
+
+ public Request(StreamInput in) throws IOException {
+ super(in);
+ this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
+ }
+
+ @Override
+ public ActionRequestValidationException validate() {
+ return null;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ taskParams.writeTo(out);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Request request = (Request) o;
+ return Objects.equals(taskParams, request.taskParams);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(taskParams);
+ }
+
+ public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
+ return taskParams;
+ }
+ }
+
+ public static class Response extends ActionResponse implements ToXContentObject {
+
+ private static final ParseField ALLOCATION = new ParseField("allocation");
+
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ "create_trained_model_allocation_response",
+ a -> new Response((TrainedModelAllocation) a[0])
+ );
+ static {
+ PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION);
+ }
+ static Response fromXContent(XContentParser parser) {
+ return PARSER.apply(parser, null);
+ }
+
+ private final TrainedModelAllocation trainedModelAllocation;
+
+ public Response(TrainedModelAllocation trainedModelAllocation) {
+ this.trainedModelAllocation = trainedModelAllocation;
+ }
+
+ public Response(StreamInput in) throws IOException {
+ super(in);
+ this.trainedModelAllocation = new TrainedModelAllocation(in);
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ trainedModelAllocation.writeTo(out);
+ }
+
+ public TrainedModelAllocation getTrainedModelAllocation() {
+ return trainedModelAllocation;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ builder.field("allocation", trainedModelAllocation);
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Response response = (Response) o;
+ return Objects.equals(trainedModelAllocation, response.trainedModelAllocation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModelAllocation);
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
new file mode 100644
index 0000000000000..589ae631dece8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationAction.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class DeleteTrainedModelAllocationAction extends ActionType {
+ public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction();
+ public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete";
+
+ private DeleteTrainedModelAllocationAction() {
+ super(NAME, AcknowledgedResponse::readFrom);
+ }
+
+ public static class Request extends MasterNodeRequest {
+ private final String modelId;
+
+ public Request(String modelId) {
+ this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
+ }
+
+ public Request(StreamInput in) throws IOException {
+ super(in);
+ this.modelId = in.readString();
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ @Override
+ public ActionRequestValidationException validate() {
+ return null;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeString(modelId);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Request request = (Request) o;
+ return Objects.equals(modelId, request.modelId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(modelId);
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
index 6e0c4d517d1b7..fb41c1d92a5ec 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
@@ -11,11 +11,15 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -31,7 +35,7 @@
import java.util.Objects;
import java.util.concurrent.TimeUnit;
-public class StartTrainedModelDeploymentAction extends ActionType {
+public class StartTrainedModelDeploymentAction extends ActionType {
public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";
@@ -39,7 +43,7 @@ public class StartTrainedModelDeploymentAction extends ActionType implements ToXContentObject {
@@ -120,9 +124,29 @@ public String toString() {
public static class TaskParams implements PersistentTaskParams, MlTaskParams {
- public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
+ // TODO add support for other roles? If so, it may have to be an instance method...
+ // NOTE, whatever determines allocation should not be dynamically set on the node
+ // Otherwise allocation logic might fail
+ public static boolean mayAllocateToNode(DiscoveryNode node) {
+ return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
+ }
+ public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ "trained_model_deployment_params",
+ true,
+ a -> new TaskParams((String)a[0], (String)a[1], (Long)a[2])
+ );
+ static {
+ PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
+ PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX);
+ PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
+ }
+
+ public static TaskParams fromXContent(XContentParser parser) {
+ return PARSER.apply(parser, null);
+ }
/**
* This has been found to be approximately 300MB on linux by manual testing.
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java
new file mode 100644
index 0000000000000..e1ae0e6c2258c
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateAction.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.MasterNodeRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class UpdateTrainedModelAllocationStateAction extends ActionType {
+ public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction();
+ public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";
+
+ private UpdateTrainedModelAllocationStateAction() {
+ super(NAME, AcknowledgedResponse::readFrom);
+ }
+
+ public static class Request extends MasterNodeRequest {
+ private final String nodeId;
+ private final String modelId;
+ private final RoutingStateAndReason routingState;
+
+ public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
+ this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
+ this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
+ this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
+ }
+
+ public Request(StreamInput in) throws IOException {
+ super(in);
+ this.nodeId = in.readString();
+ this.modelId = in.readString();
+ this.routingState = new RoutingStateAndReason(in);
+ }
+
+ public String getNodeId() {
+ return nodeId;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ public RoutingStateAndReason getRoutingState() {
+ return routingState;
+ }
+
+ @Override
+ public ActionRequestValidationException validate() {
+ return null;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeString(nodeId);
+ out.writeString(modelId);
+ routingState.writeTo(out);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Request request = (Request) o;
+ return Objects.equals(nodeId, request.nodeId)
+ && Objects.equals(modelId, request.modelId)
+ && Objects.equals(routingState, request.routingState);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(nodeId, modelId, routingState);
+ }
+
+ @Override
+ public String toString() {
+ return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
new file mode 100644
index 0000000000000..c9ef574d39f8f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import java.util.Locale;
+
+public enum AllocationState {
+ STARTED,
+ STOPPING;
+
+ public static AllocationState fromString(String value) {
+ return valueOf(value.toUpperCase(Locale.ROOT));
+ }
+
+ @Override
+ public String toString() {
+ return name().toLowerCase(Locale.ROOT);
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java
new file mode 100644
index 0000000000000..865a490cbf64a
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import java.util.Locale;
+
+public enum RoutingState {
+ STARTING,
+ STARTED,
+ STOPPING,
+ FAILED,
+ STOPPED;
+
+ public static RoutingState fromString(String value) {
+ return valueOf(value.toUpperCase(Locale.ROOT));
+ }
+
+ @Override
+ public String toString() {
+ return name().toLowerCase(Locale.ROOT);
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java
new file mode 100644
index 0000000000000..c6f1ce7d71510
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReason.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class RoutingStateAndReason implements ToXContentObject, Writeable {
+
+ private static final ParseField REASON = new ParseField("reason");
+ private static final ParseField ROUTING_STATE = new ParseField("routing_state");
+
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ "trained_model_routing_state",
+ a -> new RoutingStateAndReason(RoutingState.fromString((String) a[0]), (String) a[1])
+ );
+ static {
+ PARSER.declareString(ConstructingObjectParser.constructorArg(), ROUTING_STATE);
+ PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
+ }
+
+ public static RoutingStateAndReason fromXContent(XContentParser parser) {
+ return PARSER.apply(parser, null);
+ }
+
+ private final String reason;
+ private final RoutingState state;
+
+ public RoutingStateAndReason(RoutingState state, String reason) {
+ this.state = ExceptionsHelper.requireNonNull(state, ROUTING_STATE);
+ this.reason = reason;
+ }
+
+ public RoutingStateAndReason(StreamInput in) throws IOException {
+ this.state = in.readEnum(RoutingState.class);
+ this.reason = in.readOptionalString();
+ }
+
+ public String getReason() {
+ return reason;
+ }
+
+ public RoutingState getState() {
+ return state;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeEnum(state);
+ out.writeOptionalString(reason);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ builder.field(ROUTING_STATE.getPreferredName(), state);
+ if (reason != null) {
+ builder.field(REASON.getPreferredName(), reason);
+ }
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RoutingStateAndReason that = (RoutingStateAndReason) o;
+ return Objects.equals(reason, that.reason) && state == that.state;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(reason, state);
+ }
+
+ @Override
+ public String toString() {
+ return "RoutingStateAndReason{" + "reason='" + reason + '\'' + ", state=" + state + '}';
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java
new file mode 100644
index 0000000000000..f15511097d8d0
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java
@@ -0,0 +1,240 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.cluster.AbstractDiffable;
+import org.elasticsearch.cluster.Diffable;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+// TODO implement better diffable logic so that whole diff does not need to be serialized if only one part changes
+/**
+ * Trained model allocation object that contains allocation options and the allocation routing table
+ */
+public class TrainedModelAllocation extends AbstractDiffable
+ implements
+ Diffable,
+ ToXContentObject {
+
+ private static final ParseField ALLOCATION_STATE = new ParseField("allocation_state");
+ private static final ParseField ROUTING_TABLE = new ParseField("routing_table");
+ private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters");
+
+ @SuppressWarnings("unchecked")
+ private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(
+ "trained_model_allocation",
+ true,
+ a -> new TrainedModelAllocation(
+ (StartTrainedModelDeploymentAction.TaskParams) a[0],
+ (Map) a[1],
+ AllocationState.fromString((String)a[2])
+ )
+ );
+ static {
+ PARSER.declareObject(
+ ConstructingObjectParser.constructorArg(),
+ (p, c) -> StartTrainedModelDeploymentAction.TaskParams.fromXContent(p),
+ TASK_PARAMETERS
+ );
+ PARSER.declareObject(
+ ConstructingObjectParser.constructorArg(),
+ (p, c) -> p.map(LinkedHashMap::new, RoutingStateAndReason::fromXContent),
+ ROUTING_TABLE
+ );
+ PARSER.declareString(ConstructingObjectParser.constructorArg(), ALLOCATION_STATE);
+ }
+
+ private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+ private final Map nodeRoutingTable;
+ private final AllocationState allocationState;
+
+ public static TrainedModelAllocation fromXContent(XContentParser parser) throws IOException {
+ return PARSER.apply(parser, null);
+ }
+
+ TrainedModelAllocation(
+ StartTrainedModelDeploymentAction.TaskParams taskParams,
+ Map nodeRoutingTable,
+ AllocationState allocationState
+ ) {
+ this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS);
+ this.nodeRoutingTable = ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE);
+ this.allocationState = ExceptionsHelper.requireNonNull(allocationState, ALLOCATION_STATE);
+ }
+
+ public TrainedModelAllocation(StreamInput in) throws IOException {
+ this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
+ this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new);
+ this.allocationState = in.readEnum(AllocationState.class);
+ }
+
+ public boolean isRoutedToNode(String nodeId) {
+ return nodeRoutingTable.containsKey(nodeId);
+ }
+
+ public Map getNodeRoutingTable() {
+ return Collections.unmodifiableMap(nodeRoutingTable);
+ }
+
+ public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
+ return taskParams;
+ }
+
+ public AllocationState getAllocationState() {
+ return allocationState;
+ }
+
+ public String[] getStartedNodes() {
+ return nodeRoutingTable
+ .entrySet()
+ .stream()
+ .filter(entry -> RoutingState.STARTED.equals(entry.getValue().getState()))
+ .map(Map.Entry::getKey)
+ .toArray(String[]::new);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ TrainedModelAllocation that = (TrainedModelAllocation) o;
+ return Objects.equals(nodeRoutingTable, that.nodeRoutingTable)
+ && Objects.equals(taskParams, that.taskParams)
+ && Objects.equals(allocationState, that.allocationState);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(nodeRoutingTable, taskParams, allocationState);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ builder.field(TASK_PARAMETERS.getPreferredName(), taskParams);
+ builder.field(ROUTING_TABLE.getPreferredName(), nodeRoutingTable);
+ builder.field(ALLOCATION_STATE.getPreferredName(), allocationState);
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ taskParams.writeTo(out);
+ out.writeMap(nodeRoutingTable, StreamOutput::writeString, (o, w) -> w.writeTo(o));
+ out.writeEnum(allocationState);
+ }
+
+ public static class Builder {
+ private final Map nodeRoutingTable;
+ private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+ private AllocationState allocationState;
+ private boolean isChanged;
+
+ public static Builder fromAllocation(TrainedModelAllocation allocation) {
+ return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState);
+ }
+
+ public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+ return new Builder(taskParams);
+ }
+
+ private Builder(
+ StartTrainedModelDeploymentAction.TaskParams taskParams,
+ Map nodeRoutingTable,
+ AllocationState allocationState
+ ) {
+ this.taskParams = taskParams;
+ this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable);
+ this.allocationState = allocationState;
+ }
+
+ private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+ this.nodeRoutingTable = new LinkedHashMap<>();
+ this.taskParams = taskParams;
+ this.allocationState = AllocationState.STARTED;
+ }
+
+ public Builder addNewRoutingEntry(String nodeId) {
+ if (nodeRoutingTable.containsKey(nodeId)) {
+ throw new ResourceAlreadyExistsException(
+ "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId()
+ );
+ }
+ isChanged = true;
+ nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.STARTING, ""));
+ return this;
+ }
+
+ public Builder addNewFailedRoutingEntry(String nodeId, String reason) {
+ if (nodeRoutingTable.containsKey(nodeId)) {
+ throw new ResourceAlreadyExistsException(
+ "routing entry for node [{}] for model [{}] already exists", nodeId, taskParams.getModelId()
+ );
+ }
+ isChanged = true;
+ nodeRoutingTable.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, reason));
+ return this;
+ }
+
+ public Builder updateExistingRoutingEntry(String nodeId, RoutingStateAndReason state) {
+ RoutingStateAndReason stateAndReason = nodeRoutingTable.get(nodeId);
+ if (stateAndReason == null) {
+ throw new ResourceNotFoundException(
+ "routing entry for node [{}] for model [{}] does not exist", nodeId, taskParams.getModelId()
+ );
+ }
+ if (stateAndReason.equals(state)) {
+ return this;
+ }
+ nodeRoutingTable.put(nodeId, state);
+ isChanged = true;
+ return this;
+ }
+
+ public Builder removeRoutingEntry(String nodeId) {
+ if (nodeRoutingTable.remove(nodeId) != null) {
+ isChanged = true;
+ }
+ return this;
+ }
+
+ public Builder stopAllocation() {
+ if (allocationState.equals(AllocationState.STOPPING)) {
+ return this;
+ }
+ isChanged = true;
+ allocationState = AllocationState.STOPPING;
+ return this;
+ }
+
+ public boolean isChanged() {
+ return isChanged;
+ }
+
+ public TrainedModelAllocation build() {
+ return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState);
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
index 98be0e92e85cd..85acf4df9d78a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/IndexLocation.java
@@ -47,7 +47,7 @@ public static IndexLocation fromXContentLenient(XContentParser parser) throws IO
private final String modelId;
private final String indexName;
- IndexLocation(String modelId, String indexName) {
+ public IndexLocation(String modelId, String indexName) {
this.modelId = Objects.requireNonNull(modelId);
this.indexName = Objects.requireNonNull(indexName);
}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java
new file mode 100644
index 0000000000000..ae130f1288653
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Request;
+
+public class CreateTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Request createTestInstance() {
+ return new Request(
+ new StartTrainedModelDeploymentAction.TaskParams(
+ randomAlphaOfLength(10),
+ randomAlphaOfLength(10),
+ randomNonNegativeLong()
+ )
+ );
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Request::new;
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
new file mode 100644
index 0000000000000..980f5ef050559
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionResponseTests.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocationTests;
+
+import java.io.IOException;
+
+public class CreateTrainedModelAllocationActionResponseTests extends AbstractSerializingTestCase {
+
+ @Override
+ protected Response createTestInstance() {
+ return new Response(TrainedModelAllocationTests.randomInstance());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Response::new;
+ }
+
+ @Override
+ protected Response doParseInstance(XContentParser parser) throws IOException {
+ return Response.fromXContent(parser);
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
new file mode 100644
index 0000000000000..933c60dcbf419
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAllocationActionRequestTests.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction.Request;
+
+public class DeleteTrainedModelAllocationActionRequestTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Request createTestInstance() {
+ return new Request(randomAlphaOfLength(10));
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Request::new;
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
new file mode 100644
index 0000000000000..5b7104934f121
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelAllocationStateActionRequestTests.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReasonTests;
+
+public class UpdateTrainedModelAllocationStateActionRequestTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Request createTestInstance() {
+ return new Request(randomAlphaOfLength(10), randomAlphaOfLength(10), RoutingStateAndReasonTests.randomInstance());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return Request::new;
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
new file mode 100644
index 0000000000000..42a62e35fe80d
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class AllocationStateTests extends ESTestCase {
+
+ public void testToAndFromString() {
+ for (AllocationState state : AllocationState.values()) {
+ String value = state.toString();
+ assertThat(AllocationState.fromString(value), equalTo(state));
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
new file mode 100644
index 0000000000000..438372248cee3
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateAndReasonTests.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+public class RoutingStateAndReasonTests extends AbstractSerializingTestCase {
+
+ public static RoutingStateAndReason randomInstance() {
+ return new RoutingStateAndReason(randomFrom(RoutingState.values()), randomBoolean() ? null : randomAlphaOfLength(10));
+ }
+
+ @Override
+ protected RoutingStateAndReason doParseInstance(XContentParser parser) throws IOException {
+ return RoutingStateAndReason.fromXContent(parser);
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return RoutingStateAndReason::new;
+ }
+
+ @Override
+ protected RoutingStateAndReason createTestInstance() {
+ return randomInstance();
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
new file mode 100644
index 0000000000000..883339250ce24
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingStateTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class RoutingStateTests extends ESTestCase {
+
+ public void testToAndFromString() {
+ for (RoutingState state : RoutingState.values()) {
+ String value = state.toString();
+ assertThat(RoutingState.fromString(value), equalTo(state));
+ }
+ }
+
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java
new file mode 100644
index 0000000000000..903c897bfd24c
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
+import static org.hamcrest.Matchers.is;
+
+public class TrainedModelAllocationTests extends AbstractSerializingTestCase {
+
+ public static TrainedModelAllocation randomInstance() {
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
+ new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong())
+ );
+ List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
+ for (String node : nodes) {
+ if (randomBoolean()) {
+ builder.addNewFailedRoutingEntry(node, randomAlphaOfLength(10));
+ } else {
+ builder.addNewRoutingEntry(node);
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ protected TrainedModelAllocation doParseInstance(XContentParser parser) throws IOException {
+ return TrainedModelAllocation.fromXContent(parser);
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return TrainedModelAllocation::new;
+ }
+
+ @Override
+ protected TrainedModelAllocation createTestInstance() {
+ return randomInstance();
+ }
+
+ public void testBuilderChanged() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ assertThat(builder.isChanged(), is(false));
+ String addingNode = "foo";
+
+ assertUnchanged(builder, b -> b.removeRoutingEntry(addingNode));
+
+ if (randomBoolean()) {
+ builder.addNewRoutingEntry(addingNode);
+ } else {
+ builder.addNewFailedRoutingEntry(addingNode, "test failed");
+ }
+ assertThat(builder.isChanged(), is(true));
+
+ TrainedModelAllocation.Builder builderWithNode = TrainedModelAllocation.Builder.fromAllocation(builder.build());
+ assertThat(builderWithNode.isChanged(), is(false));
+
+ builderWithNode.removeRoutingEntry(addingNode);
+ assertThat(builderWithNode.isChanged(), is(true));
+ }
+
+ public void testBuilderAddingExistingRoute() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ String addingNode = "new-node";
+ if (randomBoolean()) {
+ builder.addNewRoutingEntry(addingNode);
+ } else {
+ builder.addNewFailedRoutingEntry(addingNode, "test failed");
+ }
+ expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewFailedRoutingEntry("new-node", "anything"));
+ expectThrows(ResourceAlreadyExistsException.class, () -> builder.addNewRoutingEntry("new-node"));
+ }
+
+ public void testBuilderUpdatingMissingRoute() {
+ TrainedModelAllocation original = randomInstance();
+ TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.fromAllocation(original);
+ String addingNode = "new-node";
+ expectThrows(
+ ResourceNotFoundException.class,
+ () -> builder.updateExistingRoutingEntry(addingNode, RoutingStateAndReasonTests.randomInstance())
+ );
+ }
+
+ public void testGetStartedNodes() {
+ String startedNode1 = "started-node-1";
+ String startedNode2 = "started-node-2";
+ String nodeInAnotherState1 = "another-state-node-1";
+ String nodeInAnotherState2 = "another-state-node-2";
+ TrainedModelAllocation allocation = TrainedModelAllocation.Builder.empty(
+ new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomAlphaOfLength(10), randomNonNegativeLong())
+ )
+ .addNewRoutingEntry(startedNode1)
+ .addNewRoutingEntry(startedNode2)
+ .addNewRoutingEntry(nodeInAnotherState1)
+ .addNewRoutingEntry(nodeInAnotherState2)
+ .updateExistingRoutingEntry(startedNode1, new RoutingStateAndReason(RoutingState.STARTED, ""))
+ .updateExistingRoutingEntry(startedNode2, new RoutingStateAndReason(RoutingState.STARTED, ""))
+ .updateExistingRoutingEntry(
+ nodeInAnotherState1,
+ new RoutingStateAndReason(
+ randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING),
+ randomAlphaOfLength(10)
+ )
+ )
+ .updateExistingRoutingEntry(
+ nodeInAnotherState2,
+ new RoutingStateAndReason(
+ randomFrom(RoutingState.STARTING, RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STOPPING),
+ randomAlphaOfLength(10)
+ )
+ )
+ .build();
+ assertThat(allocation.getStartedNodes(), arrayContainingInAnyOrder(startedNode1, startedNode2));
+ }
+
+ private static void assertUnchanged(
+ TrainedModelAllocation.Builder builder,
+ Function function
+ ) {
+ function.apply(builder);
+ assertThat(builder.isChanged(), is(false));
+ }
+
+}
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
index a3ed3d0c51634..5b366c7e83b38 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
@@ -15,6 +15,8 @@
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
+import org.junit.After;
+import org.junit.Before;
import java.io.IOException;
import java.util.Base64;
@@ -53,6 +55,32 @@ protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
}
+ @Before
+ public void setLogging() throws IOException {
+ Request loggingSettings = new Request("PUT", "_cluster/settings");
+ loggingSettings.setJsonEntity("" +
+ "{" +
+ "\"transient\" : {\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" : \"TRACE\",\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : \"TRACE\"\n" +
+ " }" +
+ "}");
+ client().performRequest(loggingSettings);
+ }
+
+ @After
+ public void unsetLogging() throws IOException {
+ Request loggingSettings = new Request("PUT", "_cluster/settings");
+ loggingSettings.setJsonEntity("" +
+ "{" +
+ "\"transient\" : {\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.allocation\" :null,\n" +
+ " \"logger.org.elasticsearch.xpack.ml.inference.deployment\" : null\n" +
+ " }" +
+ "}");
+ client().performRequest(loggingSettings);
+ }
+
private static final String MODEL_INDEX = "model_store";
private static final String MODEL_ID ="simple_model_to_evaluate";
private static final String BASE_64_ENCODED_MODEL =
@@ -92,8 +120,11 @@ public void testEvaluate() throws IOException {
createTrainedModel();
startDeployment();
try {
- Response inference = infer("my words");
- assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}"));
+ // Adding multiple inference calls to verify different calls get routed to separate nodes
+ for (int i = 0; i < 10; i++) {
+ Response inference = infer("my words");
+ assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}"));
+ }
} finally {
stopDeployment();
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 06439fb04b107..1c4c3572e992c 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -89,6 +89,7 @@
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@@ -100,6 +101,7 @@
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAliasAction;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedRunningStateAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -159,6 +161,7 @@
import org.elasticsearch.xpack.core.ml.action.UpdateJobAction;
import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
@@ -177,6 +180,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.template.TemplateUtils;
import org.elasticsearch.xpack.ml.action.TransportCloseJobAction;
+import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarEventAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteDataFrameAnalyticsAction;
@@ -188,6 +192,7 @@
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAliasAction;
+import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAllocationAction;
import org.elasticsearch.xpack.ml.action.TransportGetDatafeedRunningStateAction;
import org.elasticsearch.xpack.ml.action.TransportInferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
@@ -247,6 +252,7 @@
import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction;
import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction;
+import org.elasticsearch.xpack.ml.action.TransportUpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction;
import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction;
@@ -275,6 +281,9 @@
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -284,6 +293,7 @@
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
+import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilter;
import org.elasticsearch.xpack.ml.job.categorization.FirstNonBlankLineCharFilterFactory;
@@ -855,6 +865,18 @@ public Collection