Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService (#113623) #113886

Merged
merged 7 commits into from
Oct 7, 2024
6 changes: 6 additions & 0 deletions docs/changelog/113623.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 113623
summary: "Adding chunking settings to `MistralService`, `GoogleAiStudioService`, and\
\ `HuggingFaceService`"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ public ModelConfigurations(String inferenceEntityId, TaskType taskType, String s
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE);
}

/**
 * Creates a {@link ModelConfigurations} with explicit chunking settings but no task settings.
 *
 * <p>Delegates to the full constructor, substituting {@link EmptyTaskSettings#INSTANCE}
 * for the task settings, so models that carry no task-level configuration can still
 * record how their input should be chunked.
 *
 * @param inferenceEntityId the unique id of the inference entity
 * @param taskType          the task this model performs (e.g. text embedding)
 * @param service           the name of the service that created this model
 * @param serviceSettings   the service-level settings for this model
 * @param chunkingSettings  how input text is split into chunks before inference
 */
public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
String service,
ServiceSettings serviceSettings,
ChunkingSettings chunkingSettings
) {
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings);
}

public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand All @@ -23,6 +24,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -71,11 +74,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

GoogleAiStudioModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
Expand All @@ -97,6 +108,7 @@ private static GoogleAiStudioModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -117,6 +129,7 @@ private static GoogleAiStudioModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand All @@ -135,11 +148,17 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand All @@ -150,6 +169,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage
) {
Expand All @@ -158,6 +178,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
Expand All @@ -169,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand Down Expand Up @@ -245,11 +272,22 @@ protected void doChunkedInfer(
GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
googleAiStudioModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
Expand Down Expand Up @@ -38,6 +39,7 @@ public GoogleAiStudioEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -47,6 +49,7 @@ public GoogleAiStudioEmbeddingsModel(
service,
GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
EmptyTaskSettings.INSTANCE,
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -62,10 +65,11 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
String service,
GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
TaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets),
serviceSettings
);
Expand Down Expand Up @@ -98,6 +102,29 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
}
}

// Should only be used directly for testing
/**
 * Visible-for-testing constructor that additionally accepts the request URI as a raw string.
 *
 * <p>Builds the same {@link ModelConfigurations}/{@link ModelSecrets} pair as the production
 * constructors (including the chunking settings) and then overrides the model's URI with the
 * supplied value so tests can point the model at a fixed endpoint.
 *
 * @param inferenceEntityId the unique id of the inference entity
 * @param taskType          the task this model performs
 * @param service           the name of the owning service
 * @param uri               the endpoint to use, as a URI string; must be syntactically valid
 * @param serviceSettings   the embeddings service settings
 * @param taskSettings      the task settings
 * @param chunkingSettings  how input text is split into chunks before inference
 * @param secrets           the secret settings, or {@code null} if none
 * @throws RuntimeException if {@code uri} is not a valid {@link URI}
 */
GoogleAiStudioEmbeddingsModel(
    String inferenceEntityId,
    TaskType taskType,
    String service,
    String uri,
    GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
    TaskSettings taskSettings,
    // Renamed from `chunkingsettings` to follow lowerCamelCase and match the sibling constructors.
    ChunkingSettings chunkingSettings,
    @Nullable DefaultSecretSettings secrets
) {
    super(
        new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
        new ModelSecrets(secrets),
        serviceSettings
    );
    try {
        this.uri = new URI(uri);
    } catch (URISyntaxException e) {
        // A malformed URI in a test is a programming error; rethrow unchecked, preserving the cause.
        throw new RuntimeException(e);
    }
}

@Override
public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() {
return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -26,6 +29,7 @@

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;

Expand All @@ -52,10 +56,18 @@ public void parseRequestConfig(
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

var model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
Expand All @@ -80,10 +92,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
Expand All @@ -94,10 +112,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
Expand All @@ -108,6 +132,7 @@ protected abstract HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -48,6 +50,7 @@ protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -58,6 +61,7 @@ protected HuggingFaceModel createModel(
taskType,
NAME,
serviceSettings,
chunkingSettings,
secretSettings,
context
);
Expand Down Expand Up @@ -111,11 +115,22 @@ protected void doChunkedInfer(
var huggingFaceModel = (HuggingFaceModel) model;
var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
huggingFaceModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = huggingFaceModel.accept(actionCreator);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -56,6 +57,7 @@ protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand Down
Loading