Commit

[ML] Move InferenceInputs up a level (#112726) (#113564)
Refactor before streaming support is added: move InferenceInputs up a
level so that construction happens at the top level rather than in each
individual implementation.

What was an up-front UnsupportedOperationException will now be thrown as
an IllegalStateException later in the call chain; both reach the caller
through the listener's onFailure method anyway.

Backport of
6c1aaa4
prwhelan authored Sep 25, 2024
1 parent 37ebafd commit c644dbb
Showing 21 changed files with 121 additions and 406 deletions.
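
For orientation, here is a minimal sketch of the input hierarchy the diff below dispatches on. The class shapes are inferred from usage visible in this diff (the two constructors and DocumentsOnlyInput's getInputs()); any field or accessor name beyond those is illustrative, not the actual Elasticsearch source.

    import java.util.List;

    // Sketch only: shapes inferred from usage in this commit's diff.
    abstract class InferenceInputs {}

    // Documents-only case; getInputs() is what doChunkedInfer consumes.
    class DocumentsOnlyInput extends InferenceInputs {
        private final List<String> input;

        DocumentsOnlyInput(List<String> input) {
            this.input = input;
        }

        List<String> getInputs() {
            return input;
        }
    }

    // Query-plus-documents case (rerank-style requests).
    // Accessor names here are illustrative.
    class QueryAndDocsInputs extends InferenceInputs {
        private final String query;
        private final List<String> input;

        QueryAndDocsInputs(String query, List<String> input) {
            this.query = query;
            this.input = input;
        }

        String getQuery() {
            return query;
        }

        List<String> getInputs() {
            return input;
        }
    }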
────────────────────────────────────────
@@ -17,7 +17,10 @@
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 
 import java.io.IOException;
@@ -55,9 +58,9 @@ public void infer(
     ) {
         init();
         if (query != null) {
-            doInfer(model, query, input, taskSettings, inputType, timeout, listener);
+            doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener);
         } else {
-            doInfer(model, input, taskSettings, inputType, timeout, listener);
+            doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener);
         }
     }
 
@@ -86,22 +89,13 @@ public void chunkedInfer(
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         init();
-        doChunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener);
+        // a non-null query is not supported and is dropped by all providers
+        doChunkedInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, chunkingOptions, timeout, listener);
     }
 
-    protected abstract void doInfer(
-        Model model,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    );
-
     protected abstract void doInfer(
         Model model,
-        String query,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -110,8 +104,7 @@ protected abstract void doInfer(
 
     protected abstract void doChunkedInfer(
         Model model,
-        @Nullable String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
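
With construction moved up, each implementation now receives an InferenceInputs and narrows it itself. A hedged sketch of the failure path the commit message describes, where a mismatched input type surfaces through onFailure as an IllegalStateException rather than an up-front UnsupportedOperationException; this assumes the surrounding SenderService subclass context, and the exact check and message are illustrative, not the actual Elasticsearch code:

    // Illustrative documents-only implementation.
    @Override
    protected void doInfer(
        Model model,
        InferenceInputs inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (inputs instanceof DocumentsOnlyInput == false) {
            // Formerly a dedicated overload threw UnsupportedOperationException
            // immediately; now the mismatch is reported later in the call chain,
            // and it still reaches the caller via listener.onFailure.
            listener.onFailure(new IllegalStateException("Unexpected inference inputs type: " + inputs.getClass()));
            return;
        }
        var docsOnly = (DocumentsOnlyInput) inputs;
        // ... build the action and call action.execute(docsOnly, timeout, listener) ...
    }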
────────────────────────────────────────
@@ -28,7 +28,7 @@
 import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
-import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
@@ -204,8 +204,7 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
     @Override
     public void doInfer(
         Model model,
-        String query,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -220,35 +219,13 @@ public void doInfer(
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
         var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
-        action.execute(new QueryAndDocsInputs(query, input), timeout, listener);
-    }
-
-    @Override
-    public void doInfer(
-        Model model,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    ) {
-        if (model instanceof AlibabaCloudSearchModel == false) {
-            listener.onFailure(createInvalidModelException(model));
-            return;
-        }
-
-        AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
-        var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
-
-        var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
-        action.execute(new DocumentsOnlyInput(input), timeout, listener);
+        action.execute(inputs, timeout, listener);
     }
 
     @Override
     protected void doChunkedInfer(
         Model model,
-        @Nullable String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
@@ -263,8 +240,11 @@ protected void doChunkedInfer(
         AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
-        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
-            .batchRequestsWithListeners(listener);
+        var batchedRequests = new EmbeddingRequestChunker(
+            inputs.getInputs(),
+            EMBEDDING_MAX_BATCH_SIZE,
+            EmbeddingRequestChunker.EmbeddingType.FLOAT
+        ).batchRequestsWithListeners(listener);
         for (var request : batchedRequests) {
             var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
             action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
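
The chunked path in each service follows the same batching pattern: EmbeddingRequestChunker splits the documents into provider-sized batches, and one action executes per batch. A self-contained sketch of just the splitting step, under the assumption that batches are contiguous slices (the real chunker also wires a per-batch listener that reassembles results for the caller, which is omitted here):

    import java.util.ArrayList;
    import java.util.List;

    public final class BatchingSketch {
        // Split inputs into contiguous batches of at most maxBatchSize documents.
        static List<List<String>> batch(List<String> inputs, int maxBatchSize) {
            List<List<String>> batches = new ArrayList<>();
            for (int start = 0; start < inputs.size(); start += maxBatchSize) {
                batches.add(inputs.subList(start, Math.min(start + maxBatchSize, inputs.size())));
            }
            return batches;
        }

        public static void main(String[] args) {
            // Five documents with a batch size of 2 -> three requests.
            System.out.println(batch(List.of("a", "b", "c", "d", "e"), 2));
            // prints [[a, b], [c, d], [e]]
        }
    }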
────────────────────────────────────────
@@ -28,6 +28,7 @@
 import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
@@ -71,7 +72,7 @@ public AmazonBedrockService(
     @Override
     protected void doInfer(
         Model model,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -80,30 +81,16 @@ protected void doInfer(
         var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
         if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
             var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
-            action.execute(new DocumentsOnlyInput(input), timeout, listener);
+            action.execute(inputs, timeout, listener);
         } else {
             listener.onFailure(createInvalidModelException(model));
         }
     }
 
-    @Override
-    protected void doInfer(
-        Model model,
-        String query,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    ) {
-        throw new UnsupportedOperationException("Amazon Bedrock service does not support inference with query input");
-    }
 
     @Override
     protected void doChunkedInfer(
         Model model,
-        String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
@@ -113,7 +100,7 @@ protected void doChunkedInfer(
         var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
         if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
-            var batchedRequests = new EmbeddingRequestChunker(input, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+            var batchedRequests = new EmbeddingRequestChunker(inputs.getInputs(), maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT)
                 .batchRequestsWithListeners(listener);
             for (var request : batchedRequests) {
                 var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
────────────────────────────────────────
@@ -25,6 +25,7 @@
 import org.elasticsearch.xpack.inference.external.action.anthropic.AnthropicActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -165,7 +166,7 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta
     @Override
     public void doInfer(
         Model model,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -180,27 +181,13 @@ public void doInfer(
         var actionCreator = new AnthropicActionCreator(getSender(), getServiceComponents());
 
         var action = anthropicModel.accept(actionCreator, taskSettings);
-        action.execute(new DocumentsOnlyInput(input), timeout, listener);
-    }
-
-    @Override
-    protected void doInfer(
-        Model model,
-        String query,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    ) {
-        throw new UnsupportedOperationException("Anthropic service does not support inference with query input");
+        action.execute(inputs, timeout, listener);
     }
 
     @Override
     protected void doChunkedInfer(
         Model model,
-        @Nullable String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
────────────────────────────────────────
@@ -28,6 +28,7 @@
 import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -62,7 +63,7 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents
     @Override
     protected void doInfer(
         Model model,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -72,30 +73,16 @@ protected void doInfer(
 
         if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
             var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
-            action.execute(new DocumentsOnlyInput(input), timeout, listener);
+            action.execute(inputs, timeout, listener);
         } else {
             listener.onFailure(createInvalidModelException(model));
         }
     }
 
-    @Override
-    protected void doInfer(
-        Model model,
-        String query,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    ) {
-        throw new UnsupportedOperationException("Azure AI Studio service does not support inference with query input");
-    }
 
     @Override
     protected void doChunkedInfer(
         Model model,
-        String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
@@ -104,8 +91,11 @@
     ) {
         if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
             var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
-            var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
-                .batchRequestsWithListeners(listener);
+            var batchedRequests = new EmbeddingRequestChunker(
+                inputs.getInputs(),
+                EMBEDDING_MAX_BATCH_SIZE,
+                EmbeddingRequestChunker.EmbeddingType.FLOAT
+            ).batchRequestsWithListeners(listener);
             for (var request : batchedRequests) {
                 var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
                 action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
────────────────────────────────────────
@@ -28,6 +28,7 @@
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -185,7 +186,7 @@ public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType
     @Override
     protected void doInfer(
         Model model,
-        List<String> input,
+        InferenceInputs inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -200,27 +201,13 @@ protected void doInfer(
         var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
 
         var action = azureOpenAiModel.accept(actionCreator, taskSettings);
-        action.execute(new DocumentsOnlyInput(input), timeout, listener);
-    }
-
-    @Override
-    protected void doInfer(
-        Model model,
-        String query,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        TimeValue timeout,
-        ActionListener<InferenceServiceResults> listener
-    ) {
-        throw new UnsupportedOperationException("Azure OpenAI service does not support inference with query input");
+        action.execute(inputs, timeout, listener);
     }
 
     @Override
     protected void doChunkedInfer(
         Model model,
-        String query,
-        List<String> input,
+        DocumentsOnlyInput inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
@@ -233,8 +220,11 @@
         }
         AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
         var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
-        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
-            .batchRequestsWithListeners(listener);
+        var batchedRequests = new EmbeddingRequestChunker(
+            inputs.getInputs(),
+            EMBEDDING_MAX_BATCH_SIZE,
+            EmbeddingRequestChunker.EmbeddingType.FLOAT
+        ).batchRequestsWithListeners(listener);
         for (var request : batchedRequests) {
             var action = azureOpenAiModel.accept(actionCreator, taskSettings);
             action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());