semantic_text - Field inference #103697

Merged
Changes from 121 commits
941f960
Added fieldsForModels to IndexMetadata & MappingMetadata
Mikep86 Dec 8, 2023
46f1f2e
Updated IndexMetadata tests
Mikep86 Dec 11, 2023
02555e3
Randomize when fieldsForModels is set
Mikep86 Dec 11, 2023
be7a9be
Merge branch 'main' into store-semantic_text-model-info-in-mappings
Mikep86 Dec 12, 2023
3208f74
Added fieldsForModels to FieldTypeLookup
Mikep86 Dec 12, 2023
9a7513e
Updated MappingLookup to add getFieldsForModels
Mikep86 Dec 12, 2023
409c8d5
Updated MappingMetadata to set fieldsForModels
Mikep86 Dec 12, 2023
c5748f7
Ensure that fieldsForModels is immutable
Mikep86 Dec 12, 2023
83319f5
Fix NPE
Mikep86 Dec 13, 2023
b4a6f6e
Update docs/changelog/103319.yaml
Mikep86 Dec 13, 2023
31642b8
Update IndexMetadata equals & hashCode
Mikep86 Dec 13, 2023
206ddb9
Fix NPE
Mikep86 Dec 13, 2023
f2503ea
Fix checkstyle error
Mikep86 Dec 13, 2023
891c02f
Merge branch 'store-semantic_text-model-info-in-mappings' of github.c…
Mikep86 Dec 13, 2023
6c5d541
Fix NPE
Mikep86 Dec 13, 2023
d78af4c
Update MappingMetadata to ensure that fieldsForModels is always non-null
Mikep86 Dec 13, 2023
aa5b800
Resolved TODOs
Mikep86 Dec 13, 2023
84aac32
Merge branch 'main' into store-semantic_text-model-info-in-mappings
Mikep86 Dec 13, 2023
04112d1
Adjusted cluster state diff tests
Mikep86 Dec 13, 2023
a66f69b
IndexMetadata test updates
Mikep86 Dec 13, 2023
6feacd7
Added/updated FieldTypeLookup tests
Mikep86 Dec 13, 2023
7be2f4b
Fix spotless violations
Mikep86 Dec 13, 2023
c6c98a6
Merge branch 'main' into store-semantic_text-model-info-in-mappings
Mikep86 Dec 13, 2023
fa678a3
Added/updated MappingLookup tests
Mikep86 Dec 14, 2023
981ac8f
Delete docs/changelog/103319.yaml
Mikep86 Dec 14, 2023
f45af49
Refactored into separate methods
carlosdelest Dec 21, 2023
4a93be8
Added fieldsForModels to IndexMetadata & MappingMetadata
Mikep86 Dec 8, 2023
1b82261
Moved InferenceAction and result classes to server
carlosdelest Dec 21, 2023
39cbdff
Remove unneeded code for retrieving IndexMetadata.fieldsForModels
carlosdelest Dec 21, 2023
0e493ee
Change inference result classes to server
carlosdelest Dec 21, 2023
90457b2
First version of TransportBulkAction
carlosdelest Dec 22, 2023
14a40d7
Merge remote-tracking branch 'mikep/store-semantic_text-model-info-in…
carlosdelest Dec 22, 2023
41a5274
Working version, no threading yet
carlosdelest Dec 22, 2023
49793f1
Refactoring - used RefCountingRunnable
carlosdelest Dec 22, 2023
1d18936
More refactoring
carlosdelest Dec 22, 2023
89249ec
License headers
carlosdelest Dec 22, 2023
4802b23
License headers
carlosdelest Dec 22, 2023
9bb66e1
More refactoring around runnables
carlosdelest Dec 22, 2023
7f28aba
Spotless
carlosdelest Dec 22, 2023
a0a7b58
Add comments
carlosdelest Dec 22, 2023
607f005
Spotless
carlosdelest Dec 22, 2023
6d781bc
First working test version
carlosdelest Jan 8, 2024
b609472
Multiple inference fields test
carlosdelest Jan 9, 2024
e4a6a67
Added test cases
carlosdelest Jan 9, 2024
477b89c
Style fixes
carlosdelest Jan 9, 2024
ca845c6
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Jan 9, 2024
fa55298
Remove unused import
carlosdelest Jan 9, 2024
9e73ab5
Merge remote-tracking branch 'origin/main' into carlosdelest/semantic…
carlosdelest Jan 10, 2024
39061b4
First attempt on creating an IT test
carlosdelest Jan 10, 2024
91d7771
Revert "Change inference result classes to server"
carlosdelest Jan 10, 2024
36650d3
Revert "Moved InferenceAction and result classes to server"
carlosdelest Jan 10, 2024
360be07
Add a new InferenceProvider interface to avoid moving InferenceAction…
carlosdelest Jan 10, 2024
25e4fc4
Use InferenceProvider in TransportBulkAction instead of directly invo…
carlosdelest Jan 10, 2024
fac2913
First version for adding InferenceProviderPlugin to the InferencePlugin
carlosdelest Jan 10, 2024
fb02ae4
Adding IndexMetadata support for fieldsForModels as it stands today
carlosdelest Jan 10, 2024
3a24ea8
Merge branch 'carlosdelest/semantic-text-inference-ml-alternative' in…
carlosdelest Jan 10, 2024
6034579
Revert "First attempt on creating an IT test"
carlosdelest Jan 10, 2024
6ec089e
spotless
carlosdelest Jan 10, 2024
85eeec0
Remove changes from other branches
carlosdelest Jan 10, 2024
b90d6ad
Add javadoc
carlosdelest Jan 10, 2024
130cc82
Remove changes from other branches
carlosdelest Jan 10, 2024
d66951b
Makes InferenceProvider non null to deal with injection
carlosdelest Jan 11, 2024
d798396
Spotless
carlosdelest Jan 11, 2024
33600bf
Implement missing method
carlosdelest Jan 11, 2024
8816f0b
Fix tests and remove useless exception from interface
carlosdelest Jan 11, 2024
33b325b
Remove references to the removed exception - I'm hopefully tired and …
carlosdelest Jan 11, 2024
8aa0562
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Jan 12, 2024
e14ef02
Add back TransportVersions for semantic text
carlosdelest Jan 12, 2024
d1bc78f
Add inference service param needed
carlosdelest Jan 15, 2024
10a0eda
Performs inference even if text value is the same as previous
carlosdelest Jan 15, 2024
d39f2c9
Fix typo
carlosdelest Jan 16, 2024
f3f008f
Add warn when inference provider is not found
carlosdelest Jan 16, 2024
8389572
Merge remote-tracking branch 'origin/feature/semantic-text' into carl…
carlosdelest Jan 18, 2024
b2aab09
Removed changes from MappingMetadata
carlosdelest Jan 18, 2024
4947b1a
Refactor inference into a separate class
carlosdelest Jan 18, 2024
8ec8d05
Merge remote-tracking branch 'origin/feature/semantic-text' into carl…
carlosdelest Jan 24, 2024
daf0bfc
Move ModelRegistry as an interface to server
carlosdelest Jan 30, 2024
c4e66cf
Take back InferenceProviderPlugin changes
carlosdelest Jan 30, 2024
5397da7
First version with ModelRegistry / InferenceServiceRegistry
carlosdelest Jan 30, 2024
106e8b7
Move required string constants to server, adjust inference results in…
carlosdelest Jan 31, 2024
e19c4df
Replace consumers with listener constructs
carlosdelest Jan 31, 2024
e542655
Baby steps for replacing custom listeners with ActionListeners
carlosdelest Jan 31, 2024
3293be4
Get more similar interfaces for item processors
carlosdelest Jan 31, 2024
d12e31e
More changes to listeners
carlosdelest Jan 31, 2024
cedca07
Refactorings to create the inference provider
carlosdelest Feb 1, 2024
134dd00
Minor refactorings
carlosdelest Feb 1, 2024
3d0b537
Minor refactorings
carlosdelest Feb 1, 2024
dd0f2e9
Remove references from InferenceProvider from tests, remove current i…
carlosdelest Feb 1, 2024
2386a3b
Fix existing semantic text fields
carlosdelest Feb 1, 2024
3c3169d
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Feb 1, 2024
3c9c32f
Include changes from main in ModelRegistry
carlosdelest Feb 1, 2024
fb9b9e5
Fix merge from main
carlosdelest Feb 1, 2024
7a47fd7
Fix merge from main
carlosdelest Feb 1, 2024
246a80d
Merge remote-tracking branch 'origin/feature/semantic-text' into carl…
carlosdelest Feb 1, 2024
7100253
Fix merge from main
carlosdelest Feb 2, 2024
ecd9cf6
Add InferencePlugin changes for providing ModelRegistry and Inference…
carlosdelest Feb 2, 2024
3afd17d
Merge remote-tracking branch 'origin/feature/semantic-text' into carl…
carlosdelest Feb 2, 2024
878611e
Fix index version
carlosdelest Feb 2, 2024
e20bba5
Fix error when marking bulk items as null
carlosdelest Feb 5, 2024
df5f799
First test version
carlosdelest Feb 5, 2024
0465118
Add multiple fields to test
carlosdelest Feb 5, 2024
d97a043
Add failing inference test
carlosdelest Feb 5, 2024
9c8cd37
Add test for inference id not found
carlosdelest Feb 5, 2024
3798944
Tests improvements
carlosdelest Feb 6, 2024
7738460
Add bulk shard failure test
carlosdelest Feb 6, 2024
c4154b9
Avoid removing bulk items from request on failure, fix tests
carlosdelest Feb 6, 2024
175051b
Move semantic_text field mappers to inference plugin
carlosdelest Feb 6, 2024
3b03f7a
Merge branch 'carlosdelest/semantic-text-move-mappers-to-inference' i…
carlosdelest Feb 6, 2024
cb1f270
Remove @Nullable annotations for registries
carlosdelest Feb 6, 2024
7996244
Add YAML REST test scaffolding
carlosdelest Feb 7, 2024
fbce1d4
First test version
carlosdelest Feb 7, 2024
b12ea91
First test version
carlosdelest Feb 7, 2024
b17a4cc
Add tests
carlosdelest Feb 7, 2024
fb7f9d3
Fix bug for re-calculating inference results
carlosdelest Feb 7, 2024
0ad9496
Merge branch 'feature/semantic-text' into carlosdelest/semantic-text-…
carlosdelest Feb 7, 2024
e5ee956
Fix merge
carlosdelest Feb 7, 2024
f818bd0
Fix javadoc for SemanticTextInferenceResultFieldMapper
carlosdelest Feb 7, 2024
4fdd65e
Remove unnecessary class
carlosdelest Feb 7, 2024
75cbe3d
Fix merge with main
carlosdelest Feb 7, 2024
28e64d8
Add comments on failure to load method
carlosdelest Feb 8, 2024
792f3de
Moved yamlRestTest directory from qa to inference
carlosdelest Feb 8, 2024
2b944b1
Add cast checks and refactored a bit error handling
carlosdelest Feb 9, 2024
114 changes: 92 additions & 22 deletions server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
@@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
@@ -35,6 +36,8 @@
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.IndexClosedException;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
@@ -44,6 +47,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
import java.util.function.LongSupplier;

import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY;
@@ -69,6 +73,8 @@ final class BulkOperation extends ActionRunnable<BulkResponse> {
private final LongSupplier relativeTimeProvider;
private IndexNameExpressionResolver indexNameExpressionResolver;
private NodeClient client;
+private final InferenceServiceRegistry inferenceServiceRegistry;
+private final ModelRegistry modelRegistry;

BulkOperation(
Task task,
@@ -82,6 +88,8 @@ final class BulkOperation extends ActionRunnable<BulkResponse> {
IndexNameExpressionResolver indexNameExpressionResolver,
LongSupplier relativeTimeProvider,
long startTimeNanos,
+ModelRegistry modelRegistry,
+InferenceServiceRegistry inferenceServiceRegistry,
ActionListener<BulkResponse> listener
) {
super(listener);
@@ -97,6 +105,8 @@ final class BulkOperation extends ActionRunnable<BulkResponse> {
this.relativeTimeProvider = relativeTimeProvider;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.client = client;
+this.inferenceServiceRegistry = inferenceServiceRegistry;
+this.modelRegistry = modelRegistry;
this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext());
}

@@ -189,37 +199,99 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requ
return;
}

-String nodeId = clusterService.localNode().getId();
+BulkShardRequestInferenceProvider.getInstance(
+inferenceServiceRegistry,
+modelRegistry,
+clusterState,
+requestsByShard.keySet(),
+new ActionListener<BulkShardRequestInferenceProvider>() {
+@Override
+public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) {
+processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider);
+}
+
+@Override
+public void onFailure(Exception e) {
+throw new ElasticsearchException("Error loading inference models", e);
+}
+}
+);
+}
+
+void processRequestsByShards(
+Map<ShardId, List<BulkItemRequest>> requestsByShard,
+ClusterState clusterState,
+BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider
+) {
Runnable onBulkItemsComplete = () -> {
listener.onResponse(
new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos))
);
// Allow memory for bulk shard request items to be reclaimed before all items have been completed
bulkRequest = null;
};

try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) {
for (Map.Entry<ShardId, List<BulkItemRequest>> entry : requestsByShard.entrySet()) {
final ShardId shardId = entry.getKey();
final List<BulkItemRequest> requests = entry.getValue();
+BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests);
+
+Releasable ref = bulkItemRequestCompleteRefCount.acquire();
+final BiConsumer<BulkItemRequest, Exception> bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e);
+bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() {
+@Override
+public void onResponse(BulkShardRequest inferenceBulkShardRequest) {
+executeBulkShardRequest(
+inferenceBulkShardRequest,
+ActionListener.releaseAfter(ActionListener.noop(), ref),
+bulkItemFailedListener
+);
+}
-
-BulkShardRequest bulkShardRequest = new BulkShardRequest(
-shardId,
-bulkRequest.getRefreshPolicy(),
-requests.toArray(new BulkItemRequest[0])
-);
-bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
-bulkShardRequest.timeout(bulkRequest.timeout());
-bulkShardRequest.routedBasedOnClusterVersion(clusterState.version());
-if (task != null) {
-bulkShardRequest.setParentTask(nodeId, task.getId());
-}
-executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire());
+@Override
+public void onFailure(Exception e) {
+throw new ElasticsearchException("Error performing inference", e);
+}
+}, bulkItemFailedListener);
}
}
}

-private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) {
+private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List<BulkItemRequest> requests) {
+BulkShardRequest bulkShardRequest = new BulkShardRequest(
+shardId,
+bulkRequest.getRefreshPolicy(),
+requests.toArray(new BulkItemRequest[0])
+);
+bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
+bulkShardRequest.timeout(bulkRequest.timeout());
+bulkShardRequest.routedBasedOnClusterVersion(clusterState.version());
+if (task != null) {
+bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId());
+}
+return bulkShardRequest;
+}
+
+// When an item fails, store the failure in the responses array
+private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) {
+final String indexName = itemRequest.index();
+
+DocWriteRequest<?> docWriteRequest = itemRequest.request();
+BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
+responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure));
+}
+
+private void executeBulkShardRequest(
+BulkShardRequest bulkShardRequest,
+ActionListener<BulkShardRequest> listener,
+BiConsumer<BulkItemRequest, Exception> bulkItemErrorListener
+) {
+if (bulkShardRequest.items().length == 0) {
+// No requests to execute due to previous errors, terminate early
+listener.onResponse(bulkShardRequest);
+return;
+}
+
client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() {
@Override
public void onResponse(BulkShardResponse bulkShardResponse) {
@@ -230,19 +302,17 @@ public void onResponse(BulkShardResponse bulkShardResponse) {
}
responses.set(bulkItemResponse.getItemId(), bulkItemResponse);
}
-releaseOnFinish.close();
+listener.onResponse(bulkShardRequest);
}

@Override
public void onFailure(Exception e) {
// create failures for all relevant requests
-for (BulkItemRequest request : bulkShardRequest.items()) {
-final String indexName = request.index();
-DocWriteRequest<?> docWriteRequest = request.request();
-BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
-responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure));
+BulkItemRequest[] items = bulkShardRequest.items();
+for (BulkItemRequest item : items) {
+bulkItemErrorListener.accept(item, e);
}
-releaseOnFinish.close();
+listener.onFailure(e);
}
});
}
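
Note on the completion logic in `processRequestsByShards`: it relies on reference counting. The try-with-resources block holds one reference, each shard request acquires another, and `onBulkItemsComplete` runs only after every reference has been released. A minimal, self-contained sketch of that pattern follows. `CountedCompletion`, `Ref`, and `run` are hypothetical names made up for this note; the class is a simplified stand-in, not Elasticsearch's `RefCountingRunnable`, and the queued tasks only simulate asynchronous shard responses.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sketch only; every name here is hypothetical, not an Elasticsearch API.
class RefCountSketch {

    /** A releasable handle whose close() throws no checked exception. */
    interface Ref extends AutoCloseable {
        @Override
        void close();
    }

    /** Runs onComplete once the initial reference and all acquired references are released. */
    static final class CountedCompletion implements Ref {
        private final AtomicInteger refs = new AtomicInteger(1); // initial reference, held by the creator
        private final Runnable onComplete;

        CountedCompletion(Runnable onComplete) {
            this.onComplete = onComplete;
        }

        /** Takes an extra reference; closing the returned handle more than once has no further effect. */
        Ref acquire() {
            refs.incrementAndGet();
            AtomicBoolean released = new AtomicBoolean();
            return () -> {
                if (released.compareAndSet(false, true)) {
                    release();
                }
            };
        }

        private void release() {
            if (refs.decrementAndGet() == 0) {
                onComplete.run();
            }
        }

        /** Releases the initial reference (invoked by try-with-resources). */
        @Override
        public void close() {
            release();
        }
    }

    /** Simulates dispatching one request per shard and receiving the responses later. */
    static List<String> run(int shardCount) {
        List<String> events = new ArrayList<>();
        List<Runnable> pendingResponses = new ArrayList<>();

        try (CountedCompletion completion = new CountedCompletion(() -> events.add("complete"))) {
            for (int i = 0; i < shardCount; i++) {
                final int shard = i;
                Ref ref = completion.acquire();
                // The simulated response arrives later and releases this shard's reference.
                pendingResponses.add(() -> {
                    events.add("shard " + shard);
                    ref.close();
                });
            }
        } // initial reference released here; outstanding shard references still block completion
        events.add("dispatched");

        pendingResponses.forEach(Runnable::run);
        return events;
    }

    public static void main(String[] args) {
        System.out.println(run(2));
    }
}
```

In this sketch, `run(2)` records `complete` last, after both simulated shard responses. With zero shards the callback fires as soon as the try block exits, because the initial reference is the only one held.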