[feature/semantic-text] Simplify the integration of the field inference metadata in `IndexMetadata` (#106743)

This change refactors the integration of the field inference metadata in `IndexMetadata`. Instead of sending partial diffs, the new class simply sends the entire object as the diff whenever it has changed.
This PR also renames the fields and methods related to inference fields so that they are consistent.
The inference phase (in the transport shard bulk action) is also changed so that inference is not called when both of the following hold:

* the document contains a value for the inference input, and
* the document also contains a value for the inference results of that field (in the `_inference` map).

If the document contains an inference result for a field but no value for that field's inference input, the document is marked as failed.
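
A minimal sketch of that per-field decision, with hypothetical names (`InferenceAction` and `decide` do not appear in this commit):

// Sketch of the inference-phase decision described above. The two inputs are
// "does the doc have a value for the inference input?" and "does it carry a
// result for that field in the _inference map?".
enum InferenceAction { RUN_INFERENCE, SKIP, FAIL }

static InferenceAction decide(boolean hasInputValue, boolean hasStoredResult) {
    if (hasInputValue) {
        // Input and stored result both present: inference already ran upstream,
        // so the shard bulk action does not call it again.
        return hasStoredResult ? InferenceAction.SKIP : InferenceAction.RUN_INFERENCE;
    }
    // A stored result without its input is inconsistent: the doc is marked failed.
    return hasStoredResult ? InferenceAction.FAIL : InferenceAction.SKIP;
}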
---------

Co-authored-by: carlosdelest <[email protected]>
jimczi and carlosdelest authored Mar 28, 2024
1 parent 122e439 commit ef3abd9
Showing 27 changed files with 758 additions and 573 deletions.
@@ -61,7 +61,7 @@
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptySet;
 import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder;
-import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata;
+import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields;
 import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange;
 import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder;
 import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo;

@@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) {
                 builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0));
                 break;
             case 3:
-                builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true));
+                builder.putInferenceFields(randomInferenceFields());
                 break;
             default:
                 throw new IllegalArgumentException("Shouldn't be here");

@@ -294,8 +294,8 @@ private void executeBulkRequestsByShard(
                 requests.toArray(new BulkItemRequest[0])
             );
             var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName());
-            if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) {
-                bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata());
+            if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) {
+                bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields());
             }
             bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
             bulkShardRequest.timeout(bulkRequest.timeout());
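
Note (inferred from the hunk above, not stated in the commit message): the inference field map is attached to the BulkShardRequest as transient state resolved once from the cluster state on the coordinating node, so the per-shard requests do not need to re-read IndexMetadata for every item.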

@@ -15,14 +15,15 @@
 import org.elasticsearch.action.support.replication.ReplicatedWriteRequest;
 import org.elasticsearch.action.support.replication.ReplicationRequest;
 import org.elasticsearch.action.update.UpdateRequest;
-import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.transport.RawIndexingDataTransportRequest;
 
 import java.io.IOException;
+import java.util.Map;
 import java.util.Set;
 
 public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>

@@ -34,7 +35,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
 
     private final BulkItemRequest[] items;
 
-    private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null;
+    private transient Map<String, InferenceFieldMetadata> inferenceFieldMap = null;
 
     public BulkShardRequest(StreamInput in) throws IOException {
         super(in);

@@ -51,24 +52,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRequest[] items) {
      * Public for test
      * Set the transient metadata indicating that this request requires running inference before proceeding.
      */
-    public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) {
-        this.fieldsInferenceMetadataMap = fieldsInferenceMetadata;
+    public void setInferenceFieldMap(Map<String, InferenceFieldMetadata> fieldInferenceMap) {
+        this.inferenceFieldMap = fieldInferenceMap;
     }
 
     /**
      * Consumes the inference metadata to execute inference on the bulk items just once.
      */
-    public FieldInferenceMetadata consumeFieldInferenceMetadata() {
-        FieldInferenceMetadata ret = fieldsInferenceMetadataMap;
-        fieldsInferenceMetadataMap = null;
+    public Map<String, InferenceFieldMetadata> consumeInferenceFieldMap() {
+        Map<String, InferenceFieldMetadata> ret = inferenceFieldMap;
+        inferenceFieldMap = null;
         return ret;
     }
 
     /**
      * Public for test
      */
-    public FieldInferenceMetadata getFieldsInferenceMetadataMap() {
-        return fieldsInferenceMetadataMap;
+    public Map<String, InferenceFieldMetadata> getInferenceFieldMap() {
+        return inferenceFieldMap;
     }
 
     public long totalSizeInBytes() {
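
A hedged usage sketch of the consume-once contract above (`runInferenceOnItems` is a hypothetical stand-in for the real caller):

// The transient map is read exactly once, so inference executes a single
// time for the bulk items even if the request object is inspected again.
Map<String, InferenceFieldMetadata> fieldMap = bulkShardRequest.consumeInferenceFieldMap();
if (fieldMap != null) {
    runInferenceOnItems(bulkShardRequest.items(), fieldMap); // hypothetical entry point
}
// A second read now yields null:
assert bulkShardRequest.consumeInferenceFieldMap() == null;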

@@ -40,6 +40,7 @@
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
+import org.elasticsearch.index.mapper.InferenceFieldMapper;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndicesService;

@@ -184,7 +185,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener<
         final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis);
         switch (result.getResponseResult()) {
             case CREATED -> {
-                IndexRequest upsertRequest = result.action();
+                IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action());
                 // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request
                 final BytesReference upsertSourceBytes = upsertRequest.source();
                 client.bulk(

@@ -226,7 +227,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener<
                 );
             }
             case UPDATED -> {
-                IndexRequest indexRequest = result.action();
+                IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action());
                 // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request
                 final BytesReference indexSourceBytes = indexRequest.source();
                 client.bulk(

@@ -335,4 +336,15 @@ private void handleUpdateFailureWithRetry(
         }
         listener.onFailure(cause instanceof Exception ? (Exception) cause : new NotSerializableExceptionWrapper(cause));
     }
+
+    private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) {
+        var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields();
+        if (inferenceMetadata.isEmpty()) {
+            return request;
+        }
+        Map<String, Object> docMap = request.sourceAsMap();
+        docMap.remove(InferenceFieldMapper.NAME);
+        request.source(docMap);
+        return request;
+    }
 }
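
A note on removeInferenceMetadataField (an inference from the commit message, not spelled out in the diff): InferenceFieldMapper.NAME presumably corresponds to the `_inference` key in `_source`, so stripping it before the update is re-bulked forces inference results to be recomputed rather than carried over stale from the previous document version.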

This file was deleted.
