Commit 276439e

tidying up

davidkyle committed Nov 28, 2023
1 parent 98758da commit 276439e
Showing 11 changed files with 55 additions and 93 deletions.

This file was deleted.

@@ -35,7 +35,6 @@
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction;
import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction;
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
@@ -81,8 +80,7 @@ public InferencePlugin(Settings settings) {
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
new ActionHandler<>(DeleteInferenceModelAction.INSTANCE, TransportDeleteInferenceModelAction.class),
new ActionHandler<>(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)
new ActionHandler<>(DeleteInferenceModelAction.INSTANCE, TransportDeleteInferenceModelAction.class)
);
}

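The import removed in the first hunk refers to a class in the org.elasticsearch.xpack.ml.action package, so the CoordinatedInferenceAction handler dropped from InferencePlugin would presumably be registered on the ML plugin side instead. A minimal sketch of what that registration might look like, assuming the usual ActionPlugin getActions() override; the enclosing plugin class and the rest of its handler list are assumptions, not shown in this commit:

    // Hypothetical sketch only: the enclosing plugin class and method are assumptions;
    // the two class names are taken from the diff above.
    @Override
    public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
        return List.of(
            // ... other ML action handlers ...
            new ActionHandler<>(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)
        );
    }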

This file was deleted.

@@ -91,7 +91,8 @@
import org.elasticsearch.xpack.ilm.IndexLifecycle;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.slm.SnapshotLifecycle;
import org.elasticsearch.xpack.slm.history.SnapshotLifecycleTemplateRegistry;
import org.elasticsearch.xpack.transform.Transform;
@@ -24,9 +24,9 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentStateUtils;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentUtils;

import java.util.ArrayList;
import java.util.List;
@@ -77,36 +77,17 @@ protected void doExecute(Task task, CoordinatedInferenceAction.Request request,
}
}



private boolean hasTrainedModelAssignment(String modelId, ClusterState state) {
String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getId()))
.orElse(request.getId());

responseBuilder.setId(concreteModelId);

TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
TrainedModelAssignment assignment = trainedModelAssignmentMetadata.getDeploymentAssignment(concreteModelId);
List<TrainedModelAssignment> assignments;
if (assignment == null) {
// look up by model
assignments = trainedModelAssignmentMetadata.getDeploymentsUsingModel(concreteModelId);
}
}


private void forNlp(CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
logger.info("[CoordAction] forNlp [{}]", request.getModelId());
var clusterState = clusterService.state();
var assignments = TrainedModelAssignmentStateUtils.modelAssignments(request.getModelId(), clusterState);
var assignments = TrainedModelAssignmentUtils.modelAssignments(request.getModelId(), clusterState);
if (assignments == null || assignments.isEmpty()) {
doInferenceServiceModel(request, listener);
} else {
doInClusterModel(request, listener);
}
}


private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
logger.info("[CoordAction] doInferenceServiceModel [{}]", request.getModelId());
executeAsyncWithOrigin(
@@ -150,23 +131,23 @@ static InferModelAction.Request translateRequest(CoordinatedInferenceAction.Requ
return inferModelRequest;
}

// private ActionListener<InferModelAction.Response> wrapModelNotFoundInBoostedTreeModelCheck(ActionListener<InferModelAction.Response> listener) {
//// return ActionListener.wrap(
//// listener::onResponse,
//// e -> {
//// executeAsyncWithOrigin(
//// client,
//// ML_ORIGIN,
//// GetTrainedModelsAction.INSTANCE,
//// new InferenceAction.Request(TaskType.ANY, request.getModelId(), request.getInputs(), request.getTaskSettings()),
//// ActionListener.wrap(r -> translateInferenceServiceResponse(r.getResults(), listener), listener::onFailure)
//// );
//// }
//// );
// }


/*
// private ActionListener<InferModelAction.Response> wrapModelNotFoundInBoostedTreeModelCheck(ActionListener<InferModelAction.Response>
// listener) {
//// return ActionListener.wrap(
//// listener::onResponse,
//// e -> {
//// executeAsyncWithOrigin(
//// client,
//// ML_ORIGIN,
//// GetTrainedModelsAction.INSTANCE,
//// new InferenceAction.Request(TaskType.ANY, request.getModelId(), request.getInputs(), request.getTaskSettings()),
//// ActionListener.wrap(r -> translateInferenceServiceResponse(r.getResults(), listener), listener::onFailure)
//// );
//// }
//// );
// }

/*
private void handleInferenceServiceModelFailure(
Exception error,
CoordinatedInferenceAction.Request request,
@@ -220,7 +201,7 @@ private void lookForInferenceServiceModelWithId(String modelId, ActionListener<P
listener
);
}
*/
*/
static void translateInferenceServiceResponse(
List<? extends InferenceResults> inferenceResults,
ActionListener<InferModelAction.Response> listener
@@ -7,10 +7,16 @@

package org.elasticsearch.xpack.ml.inference.assignment;

import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;

import java.util.List;
import java.util.Optional;

public class TrainedModelAssignmentUtils {
public static final String NODES_CHANGED_REASON = "nodes changed";
@@ -24,5 +30,22 @@ public static RoutingInfo createShuttingDownRoute(RoutingInfo existingRoute) {
return routeUpdate.apply(existingRoute);
}

public static List<TrainedModelAssignment> modelAssignments(String modelId, ClusterState state) {
String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(state).getModelId(modelId)).orElse(modelId);

List<TrainedModelAssignment> assignments;

TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(state);
TrainedModelAssignment assignment = trainedModelAssignmentMetadata.getDeploymentAssignment(concreteModelId);
if (assignment != null) {
assignments = List.of(assignment);
} else {
// look up by model
assignments = trainedModelAssignmentMetadata.getDeploymentsUsingModel(concreteModelId);
}

return assignments;
}

private TrainedModelAssignmentUtils() {}
}
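The new modelAssignments helper centralises the model-alias resolution and deployment lookup that the transport action previously did inline. A minimal usage sketch, mirroring forNlp above; clusterService and modelId are assumed to be in scope, as they are in TransportCoordinatedInferenceAction:

    // Sketch only: resolve the model id (or model alias) to its deployment assignments.
    // An empty result means the model has no in-cluster deployment, so the coordinated
    // action falls back to the inference service path.
    List<TrainedModelAssignment> assignments =
        TrainedModelAssignmentUtils.modelAssignments(modelId, clusterService.state());
    if (assignments == null || assignments.isEmpty()) {
        // no ML-node deployment: route the request to an inference service model
    } else {
        // at least one deployment exists: run the request against the in-cluster model
    }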
@@ -24,8 +24,8 @@
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -135,7 +135,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
// The model is hosted either on a ml node or in an inference service
inferRequest.setModelType(CoordinatedInferenceAction.Request.ModelType.FOR_NLP_MODEL);
inferRequest.setModelType(CoordinatedInferenceAction.Request.ModelType.NLP_MODEL);

SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
queryRewriteContext.registerAsyncAction((client, listener) -> {
@@ -151,10 +151,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return;
}

if (inferenceResponse.getInferenceResults().get(0)instanceof TextExpansionResults textExpansionResults) {
if (inferenceResponse.getInferenceResults().get(0) instanceof TextExpansionResults textExpansionResults) {
textExpansionResultsSupplier.set(textExpansionResults);
listener.onResponse(null);
} else if (inferenceResponse.getInferenceResults().get(0)instanceof WarningInferenceResults warning) {
} else if (inferenceResponse.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
listener.onFailure(
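An aside on the whitespace-only changes above: get(0)instanceof and get(0) instanceof compile identically, since both use Java's pattern matching for instanceof. A self-contained illustration of that construct with hypothetical types; none of these names come from the commit:

    // Hypothetical, self-contained demo of pattern matching for instanceof (Java 17+).
    public class InstanceofPatternDemo {
        sealed interface InferenceOutcome permits TextOutcome, WarningOutcome {}
        record TextOutcome(String text) implements InferenceOutcome {}
        record WarningOutcome(String warning) implements InferenceOutcome {}

        static String handle(InferenceOutcome outcome) {
            if (outcome instanceof TextOutcome textOutcome) {
                return "text: " + textOutcome.text();
            } else if (outcome instanceof WarningOutcome warningOutcome) {
                return "warning: " + warningOutcome.warning();
            }
            return "unexpected outcome type";
        }

        public static void main(String[] args) {
            System.out.println(handle(new TextOutcome("hello")));
            System.out.println(handle(new WarningOutcome("model is starting up")));
        }
    }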
@@ -104,17 +104,17 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
inferRequest.setHighPriority(true);
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
// The model is hosted either on a ml node or in an inference service
inferRequest.setModelType(CoordinatedInferenceAction.Request.ModelType.FOR_NLP_MODEL);
inferRequest.setModelType(CoordinatedInferenceAction.Request.ModelType.NLP_MODEL);

executeAsyncWithOrigin(client, ML_ORIGIN, CoordinatedInferenceAction.INSTANCE, inferRequest, ActionListener.wrap(response -> {
if (response.getInferenceResults().isEmpty()) {
listener.onFailure(new IllegalStateException("text embedding inference response contain no results"));
return;
}

if (response.getInferenceResults().get(0)instanceof TextEmbeddingResults textEmbeddingResults) {
if (response.getInferenceResults().get(0) instanceof TextEmbeddingResults textEmbeddingResults) {
listener.onResponse(textEmbeddingResults.getInferenceAsFloat());
} else if (response.getInferenceResults().get(0)instanceof WarningInferenceResults warning) {
} else if (response.getInferenceResults().get(0) instanceof WarningInferenceResults warning) {
listener.onFailure(new IllegalStateException(warning.getWarning()));
} else {
throw new IllegalStateException(
@@ -81,7 +81,7 @@ protected Object simulateMethod(Method method, Object[] args) {
CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1];
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout());
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType());
assertEquals(CoordinatedInferenceAction.Request.ModelType.FOR_NLP_MODEL, request.getModelType());
assertEquals(CoordinatedInferenceAction.Request.ModelType.NLP_MODEL, request.getModelType());

// Randomisation cannot be used here as {@code #doAssertLuceneQuery}
// asserts that 2 rewritten queries are the same
@@ -42,7 +42,7 @@ protected void doAssertClientRequest(ActionRequest request, TextEmbeddingQueryVe
assertEquals(builder.getModelId(), inferRequest.getModelId());
assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, inferRequest.getInferenceTimeout());
assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, inferRequest.getPrefixType());
assertEquals(CoordinatedInferenceAction.Request.ModelType.FOR_NLP_MODEL, inferRequest.getModelType());
assertEquals(CoordinatedInferenceAction.Request.ModelType.NLP_MODEL, inferRequest.getModelType());
}

public ActionResponse createResponse(float[] array, TextEmbeddingQueryVectorBuilder builder) {
