From ef3abd96171e7386caf98a6884c48ed1cc24db0a Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Thu, 28 Mar 2024 14:25:51 +0000
Subject: [PATCH] [feature/semantic-text] Simplify the integration of the field
 inference metadata in `IndexMetadata` (#106743)

This change refactors the integration of the field inference metadata in
IndexMetadata. Instead of partial diffs, the new class simply sends the
entire object as a diff if it has changed. This PR also renames the fields
and methods related to the inference fields consistently.

The inference phase (in the transport shard bulk action) is also changed so
that inference is not called if:
- The document contains a value for the inference input.
- The document also contains a value for the inference results of that field
  (in the _inference map).

If the document contains no value for the inference input but an inference
result for that field, it is marked as failed.

---------

Co-authored-by: carlosdelest
---
 .../cluster/ClusterStateDiffIT.java           |   4 +-
 .../action/bulk/BulkOperation.java            |   4 +-
 .../action/bulk/BulkShardRequest.java         |  19 +-
 .../action/update/TransportUpdateAction.java  |  16 +-
 .../metadata/FieldInferenceMetadata.java      | 190 --------------
 .../cluster/metadata/IndexMetadata.java       | 112 ++++----
 .../metadata/InferenceFieldMetadata.java      | 125 +++++++++
 .../metadata/MetadataCreateIndexService.java  |  11 +-
 .../metadata/MetadataMappingService.java      |   8 +-
 .../index/mapper/FieldTypeLookup.java         |  14 -
 .../index/mapper/InferenceFieldMapper.java    |  28 ++
 .../index/mapper/InferenceModelFieldType.java |  21 --
 .../index/mapper/MapperMergeContext.java      |   4 +-
 .../index/mapper/MappingLookup.java           |  35 ++-
 .../cluster/metadata/IndexMetadataTests.java  |  40 ++-
 .../metadata/InferenceFieldMetadataTests.java |  66 +++++
 .../index/mapper/FieldTypeLookupTests.java    |  28 --
 .../index/mapper/MappingLookupTests.java      |  18 --
 .../mapper/MockInferenceModelFieldType.java   |  45 ----
 .../xpack/inference/InferencePlugin.java      |  10 +-
 .../ShardBulkInferenceActionFilter.java       | 248 +++++++++++++-----
 .../mapper/InferenceMetadataFieldMapper.java  |   9 +-
 .../mapper/SemanticTextFieldMapper.java       |  74 ++++--
 .../SemanticTextClusterMetadataTests.java     |  10 +-
 .../ShardBulkInferenceActionFilterTests.java  |  65 ++---
 .../inference/10_semantic_text_inference.yml  | 107 ++++++--
 .../20_semantic_text_field_mapper.yml         |  20 ++
 27 files changed, 758 insertions(+), 573 deletions(-)
 delete mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java
 create mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java
 create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java
 delete mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java
 create mode 100644 server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
 delete mode 100644 test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java

diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java
index fbb3016b925da..e0dbc74567053 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java
@@ -61,7 +61,7 @@
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptySet;
 import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder;
-import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata;
+import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields;
 import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange;
 import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder;
 import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo;
@@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) {
                 builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0));
                 break;
             case 3:
-                builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true));
+                builder.putInferenceFields(randomInferenceFields());
                 break;
             default:
                 throw new IllegalArgumentException("Shouldn't be here");
         }
diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
index a6439769b51b4..e66426562a92e 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
@@ -294,8 +294,8 @@ private void executeBulkRequestsByShard(
             requests.toArray(new BulkItemRequest[0])
         );
         var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName());
-        if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) {
-            bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata());
+        if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) {
+            bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields());
         }
         bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
         bulkShardRequest.timeout(bulkRequest.timeout());
diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
index 39fa791a3e27d..8d1618b443ace 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
@@ -15,7 +15,7 @@
 import org.elasticsearch.action.support.replication.ReplicatedWriteRequest;
 import org.elasticsearch.action.support.replication.ReplicationRequest;
 import org.elasticsearch.action.update.UpdateRequest;
-import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.set.Sets;
@@ -23,6 +23,7 @@
 import org.elasticsearch.transport.RawIndexingDataTransportRequest;
 
 import java.io.IOException;
+import java.util.Map;
 import java.util.Set;
 
 public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
@@ -34,7 +35,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest
 
-    private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null;
+    private transient Map<String, InferenceFieldMetadata> inferenceFieldMap = null;
 
     public BulkShardRequest(StreamInput in) throws IOException {
         super(in);
@@ -51,24 +52,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe
      * Public for test
      * Set the transient metadata indicating that this request requires running inference before proceeding.
      */
-    public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) {
-        this.fieldsInferenceMetadataMap = fieldsInferenceMetadata;
+    public void setInferenceFieldMap(Map<String, InferenceFieldMetadata> fieldInferenceMap) {
+        this.inferenceFieldMap = fieldInferenceMap;
     }
 
     /**
     * Consumes the inference metadata to execute inference on the bulk items just once.
     */
-    public FieldInferenceMetadata consumeFieldInferenceMetadata() {
-        FieldInferenceMetadata ret = fieldsInferenceMetadataMap;
-        fieldsInferenceMetadataMap = null;
+    public Map<String, InferenceFieldMetadata> consumeInferenceFieldMap() {
+        Map<String, InferenceFieldMetadata> ret = inferenceFieldMap;
+        inferenceFieldMap = null;
         return ret;
     }
 
     /**
     * Public for test
     */
-    public FieldInferenceMetadata getFieldsInferenceMetadataMap() {
-        return fieldsInferenceMetadataMap;
+    public Map<String, InferenceFieldMetadata> getInferenceFieldMap() {
+        return inferenceFieldMap;
     }
 
     public long totalSizeInBytes() {
diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java
index 63ae56bfbd047..36a47bc7e02e9 100644
--- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java
+++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java
@@ -40,6 +40,7 @@
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
+import org.elasticsearch.index.mapper.InferenceFieldMapper;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndicesService;
@@ -184,7 +185,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener<
         final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis);
         switch (result.getResponseResult()) {
             case CREATED -> {
-                IndexRequest upsertRequest = result.action();
+                IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action());
                 // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request
                 final BytesReference upsertSourceBytes = upsertRequest.source();
                 client.bulk(
@@ -226,7 +227,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener<
                 );
             }
             case UPDATED -> {
-                IndexRequest indexRequest = result.action();
+                IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action());
                 // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request
                 final BytesReference indexSourceBytes = indexRequest.source();
                 client.bulk(
@@ -335,4 +336,15 @@ private void handleUpdateFailureWithRetry(
         }
         listener.onFailure(cause instanceof Exception ?
(Exception) cause : new NotSerializableExceptionWrapper(cause)); } + + private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) { + var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields(); + if (inferenceMetadata.isEmpty()) { + return request; + } + Map docMap = request.sourceAsMap(); + docMap.remove(InferenceFieldMapper.NAME); + request.source(docMap); + return request; + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java deleted file mode 100644 index 349706c139127..0000000000000 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.cluster.metadata; - -import org.elasticsearch.cluster.Diff; -import org.elasticsearch.cluster.Diffable; -import org.elasticsearch.cluster.DiffableUtils; -import org.elasticsearch.cluster.SimpleDiffable; -import org.elasticsearch.common.collect.ImmutableOpenMap; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.mapper.MappingLookup; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - -/** - * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator - * node, which not necessarily has mapping information. 
- */ -public class FieldInferenceMetadata implements Diffable, ToXContentFragment { - - private final ImmutableOpenMap fieldInferenceOptions; - - public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); - - public FieldInferenceMetadata(MappingLookup mappingLookup) { - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); - mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { - builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); - }); - fieldInferenceOptions = builder.build(); - } - - public FieldInferenceMetadata(StreamInput in) throws IOException { - fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); - } - - public FieldInferenceMetadata(Map fieldsToInferenceMap) { - fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); - } - - public ImmutableOpenMap getFieldInferenceOptions() { - return fieldInferenceOptions; - } - - public boolean isEmpty() { - return fieldInferenceOptions.isEmpty(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeMap(fieldInferenceOptions, (o, v) -> v.writeTo(o)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(fieldInferenceOptions); - return builder; - } - - public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); - } - - @Override - public int hashCode() { - return Objects.hash(fieldInferenceOptions); - } - - @Override - public Diff diff(FieldInferenceMetadata previousState) { - if (previousState == null) { - previousState = EMPTY; - } - return new FieldInferenceMetadataDiff(previousState, this); - } - - static class FieldInferenceMetadataDiff implements Diff { - - public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( - FieldInferenceMetadata.EMPTY, - FieldInferenceMetadata.EMPTY - ); - - private final Diff> fieldInferenceMapDiff; - - private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = - new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); - - FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { - fieldInferenceMapDiff = DiffableUtils.diff( - before.fieldInferenceOptions, - after.fieldInferenceOptions, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - FieldInferenceMetadataDiff(StreamInput in) throws IOException { - fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - public static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - fieldInferenceMapDiff.writeTo(out); - } - - @Override - public FieldInferenceMetadata apply(FieldInferenceMetadata part) { - return 
new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); - } - } - - public record FieldInferenceOptions(String inferenceId, Set sourceFields) - implements - SimpleDiffable, - ToXContentFragment { - - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); - - FieldInferenceOptions(StreamInput in) throws IOException { - this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - out.writeStringCollection(sourceFields); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); - builder.endObject(); - return builder; - } - - public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "field_inference_parser", - false, - (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) - ); - - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 89c925427cf88..b66da654f8a1c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -576,6 +576,8 @@ public Iterator> settings() { @Nullable private final MappingMetadata mapping; + private final ImmutableOpenMap inferenceFields; + private final ImmutableOpenMap customData; private final Map> inSyncAllocationIds; @@ -631,7 +633,6 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -645,6 +646,7 @@ private IndexMetadata( final int numberOfReplicas, final Settings settings, final MappingMetadata mapping, + final ImmutableOpenMap inferenceFields, final ImmutableOpenMap aliases, final ImmutableOpenMap customData, final Map> inSyncAllocationIds, @@ -677,8 +679,7 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast, - @Nullable FieldInferenceMetadata fieldInferenceMetadata + @Nullable Long shardSizeInBytesForecast ) { this.index = index; this.version = version; @@ -696,6 +697,7 @@ private IndexMetadata( this.totalNumberOfShards = numberOfShards * (numberOfReplicas + 1); this.settings = settings; this.mapping = mapping; + this.inferenceFields = inferenceFields; this.customData = customData; this.aliases = aliases; this.inSyncAllocationIds = inSyncAllocationIds; @@ -734,7 +736,6 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = 
shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -753,6 +754,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.numberOfReplicas, this.settings, mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -785,8 +787,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -812,6 +813,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, Maps.copyMapWithAddedOrReplacedEntry(this.inSyncAllocationIds, shardId, Set.copyOf(inSyncSet)), @@ -844,8 +846,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -869,6 +870,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -901,8 +903,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -926,6 +927,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -958,8 +960,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -979,6 +980,7 @@ public IndexMetadata withIncrementedVersion() { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -1011,8 +1013,7 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -1203,6 +1204,10 @@ public MappingMetadata mapping() { return mapping; } + public Map getInferenceFields() { + return inferenceFields; + } + @Nullable public IndexMetadataStats getStats() { return stats; @@ -1216,10 +1221,6 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? 
OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public FieldInferenceMetadata getFieldInferenceMetadata() { - return fieldInferenceMetadata; - } - public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1417,7 +1418,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { + if (inferenceFields.equals(that.inferenceFields) == false) { return false; } if (isSystem != that.isSystem) { @@ -1440,7 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldInferenceMetadata.hashCode(); + result = 31 * result + inferenceFields.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1487,6 +1488,7 @@ private static class IndexMetadataDiff implements Diff { @Nullable private final Diff settingsDiff; private final Diff> mappings; + private final Diff> inferenceFields; private final Diff> aliases; private final Diff> customData; private final Diff>> inSyncAllocationIds; @@ -1496,7 +1498,6 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1519,6 +1520,7 @@ private static class IndexMetadataDiff implements Diff { : ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, after.mapping).build(), DiffableUtils.getStringKeySerializer() ); + inferenceFields = DiffableUtils.diff(before.inferenceFields, after.inferenceFields, DiffableUtils.getStringKeySerializer()); aliases = DiffableUtils.diff(before.aliases, after.aliases, DiffableUtils.getStringKeySerializer()); customData = DiffableUtils.diff(before.customData, after.customData, DiffableUtils.getStringKeySerializer()); inSyncAllocationIds = DiffableUtils.diff( @@ -1533,7 +1535,6 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1544,6 +1545,8 @@ private static class IndexMetadataDiff implements Diff { new DiffableUtils.DiffableValueReader<>(DiffableStringMap::readFrom, DiffableStringMap::readDiffFrom); private static final DiffableUtils.DiffableValueReader ROLLOVER_INFO_DIFF_VALUE_READER = new DiffableUtils.DiffableValueReader<>(RolloverInfo::new, RolloverInfo::readDiffFrom); + private static final DiffableUtils.DiffableValueReader INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(InferenceFieldMetadata::new, InferenceFieldMetadata::readDiffFrom); IndexMetadataDiff(StreamInput in) throws IOException { index = in.readString(); @@ -1566,6 +1569,15 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); 
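        // The constructor reads that follow use the standard BWC gate for newly serialized state: the
        // inference-fields diff is only read when the sender is on a transport version that writes it
        // (SEMANTIC_TEXT_FIELD_ADDED); older senders yield an empty diff, so the rest of the IndexMetadata
        // diff still applies in mixed-version clusters. A minimal sketch of the read-side pattern (the
        // serializer and reader arguments are abbreviated here):
        //
        //   if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) {
        //       inferenceFields = DiffableUtils.readImmutableOpenMapDiff(in, keySerializer, valueReader);
        //   } else {
        //       inferenceFields = DiffableUtils.emptyDiff();
        //   }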
mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER + ); + } else { + inferenceFields = DiffableUtils.emptyDiff(); + } aliases = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), ALIAS_METADATA_DIFF_VALUE_READER); customData = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), CUSTOM_DIFF_VALUE_READER); inSyncAllocationIds = DiffableUtils.readJdkMapDiff( @@ -1593,11 +1605,6 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); - } else { - fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; - } } @Override @@ -1620,6 +1627,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields.writeTo(out); + } aliases.writeTo(out); customData.writeTo(out); inSyncAllocationIds.writeTo(out); @@ -1633,9 +1643,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeOptionalWriteable(fieldInferenceMetadata); - } } @Override @@ -1656,6 +1663,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.mapping = mappings.apply( ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, part.mapping).build() ).get(MapperService.SINGLE_MAPPING_NAME); + builder.inferenceFields.putAllFromMap(inferenceFields.apply(part.inferenceFields)); builder.aliases.putAllFromMap(aliases.apply(part.aliases)); builder.customMetadata.putAllFromMap(customData.apply(part.customData)); builder.inSyncAllocationIds.putAll(inSyncAllocationIds.apply(part.inSyncAllocationIds)); @@ -1665,7 +1673,6 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1702,6 +1709,10 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); + } int aliasesSize = in.readVInt(); for (int i = 0; i < aliasesSize; i++) { AliasMetadata aliasMd = new AliasMetadata(in); @@ -1733,9 +1744,6 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function inferenceFields; private final ImmutableOpenMap.Builder aliases; private final ImmutableOpenMap.Builder customMetadata; private final Map> inSyncAllocationIds; @@ -1834,10 +1843,10 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; public 
Builder(String index) { this.index = index; + this.inferenceFields = ImmutableOpenMap.builder(); this.aliases = ImmutableOpenMap.builder(); this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); @@ -1855,6 +1864,7 @@ public Builder(IndexMetadata indexMetadata) { this.settings = indexMetadata.getSettings(); this.primaryTerms = indexMetadata.primaryTerms.clone(); this.mapping = indexMetadata.mapping; + this.inferenceFields = ImmutableOpenMap.builder(indexMetadata.inferenceFields); this.aliases = ImmutableOpenMap.builder(indexMetadata.aliases); this.customMetadata = ImmutableOpenMap.builder(indexMetadata.customData); this.routingNumShards = indexMetadata.routingNumShards; @@ -1866,7 +1876,6 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2096,8 +2105,13 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); + public Builder putInferenceField(InferenceFieldMetadata value) { + this.inferenceFields.put(value.getName(), value); + return this; + } + + public Builder putInferenceFields(Map values) { + this.inferenceFields.putAllFromMap(values); return this; } @@ -2263,6 +2277,7 @@ IndexMetadata build(boolean repair) { numberOfReplicas, settings, mapping, + inferenceFields.build(), aliasesMap, newCustomMetadata, Map.ofEntries(denseInSyncAllocationIds), @@ -2295,8 +2310,7 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast, - fieldInferenceMetadata + shardSizeInBytesForecast ); } @@ -2422,8 +2436,12 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { - builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); + if (indexMetadata.getInferenceFields().isEmpty() == false) { + builder.startObject(KEY_FIELD_INFERENCE); + for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { + field.toXContent(builder, params); + } + builder.endObject(); } builder.endObject(); @@ -2504,7 +2522,9 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { + private static final String INFERENCE_ID_FIELD = "inference_id"; + private static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private final String name; + private final String inferenceId; + private final String[] sourceFields; + + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this.name = Objects.requireNonNull(name); + this.inferenceId = Objects.requireNonNull(inferenceId); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public InferenceFieldMetadata(StreamInput input) throws IOException { + this.name = input.readString(); + this.inferenceId = input.readString(); + this.sourceFields = input.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeString(name); + out.writeString(inferenceId); + out.writeStringArray(sourceFields); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceFieldMetadata that = (InferenceFieldMetadata) o; + return inferenceId.equals(that.inferenceId) && Arrays.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + int result = Objects.hash(inferenceId); + result = 31 * result + Arrays.hashCode(sourceFields); + return result; + } + + public String getName() { + return name; + } + + public String getInferenceId() { + return inferenceId; + } + + public String[] getSourceFields() { + return sourceFields; + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(name); + builder.field(INFERENCE_ID_FIELD, inferenceId); + builder.array(SOURCE_FIELDS_FIELD, sourceFields); + return builder.endObject(); + } + + public static InferenceFieldMetadata fromXContent(XContentParser parser) throws IOException { + final String name = parser.currentName(); + + XContentParser.Token token = parser.nextToken(); + if (token == null) { + // no data... + return null; + } + String currentFieldName = null; + String inferenceId = null; + List inputFields = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_STRING) { + if (INFERENCE_ID_FIELD.equals(currentFieldName)) { + inferenceId = parser.text(); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_STRING) { + inputFields.add(parser.text()); + } else { + parser.skipChildren(); + } + } + } + } else { + parser.skipChildren(); + } + } + return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index 96ca7a15edc30..52642e1de8ac9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1263,12 +1263,11 @@ static IndexMetadata buildIndexMetadata( indexMetadataBuilder.system(isSystem); // now, update the mappings with the actual source Map mappingsMetadata = new HashMap<>(); - DocumentMapper mapper = documentMapperSupplier.get(); - if (mapper != null) { - MappingMetadata mappingMd = new MappingMetadata(mapper); - mappingsMetadata.put(mapper.type(), mappingMd); - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); - indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); + DocumentMapper docMapper = documentMapperSupplier.get(); + if (docMapper != null) { + MappingMetadata mappingMd = new MappingMetadata(docMapper); + mappingsMetadata.put(docMapper.type(), mappingMd); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } 
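            // Illustrative note: MappingLookup#inferenceFields() (added later in this patch) returns one
            // InferenceFieldMetadata entry per inference-backed field, keyed by field name. For a
            // hypothetical semantic_text-style field "body" using inference id "my-elser" and fed from
            // source paths ["body", "body_text"], the entry registered here serializes into the index
            // metadata as:
            //
            //   "body": {
            //     "inference_id": "my-elser",
            //     "source_fields": [ "body", "body_text" ]
            //   }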
for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 0e31592991369..e7c2bb9ae9b9a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -201,10 +201,10 @@ private static ClusterState applyRequest( IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(indexMetadata); // Mapping updates on a single type may have side-effects on other types so we need to // update mapping metadata on all types - DocumentMapper mapper = mapperService.documentMapper(); - if (mapper != null) { - indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldInferenceMetadata(new FieldInferenceMetadata(mapper.mappers())); + DocumentMapper docMapper = mapperService.documentMapper(); + if (docMapper != null) { + indexMetadataBuilder.putMapping(new MappingMetadata(docMapper)); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 0741cfa682b74..5e3dbe9590b99 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,11 +36,6 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; - /** - * A map from inference model ID to all fields that use the model to generate embeddings. - */ - private final Map inferenceIdsForFields; - private final int maxParentPathDots; FieldTypeLookup( @@ -53,7 +48,6 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -71,9 +65,6 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } - if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); - } } int maxParentPathDots = 0; @@ -106,7 +97,6 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -215,10 +205,6 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map getInferenceIdsForFields() { - return inferenceIdsForFields; - } - /** * If field is a leaf multi-field return the path to the parent field. Otherwise, return null. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java new file mode 100644 index 0000000000000..078ef391f17ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.inference.InferenceService; + +import java.util.Set; + +/** + * Field mapper that requires to transform its input before indexation through the {@link InferenceService}. + */ +public interface InferenceFieldMapper { + String NAME = "_inference"; + + /** + * Retrieve the inference metadata associated with this mapper. + * + * @param sourcePaths The source path that populates the input for the field (before inference) + */ + InferenceFieldMetadata getMetadata(Set sourcePaths); +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java deleted file mode 100644 index 6e12a204ed7d0..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -/** - * Field type that uses an inference model. - */ -public interface InferenceModelFieldType { - /** - * Retrieve inference model used by the field type. 
- * - * @return model id used by the field type - */ - String getInferenceId(); -} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java index 8f8854ad47c7d..ddf6f339cbbb6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java @@ -46,7 +46,7 @@ public static MapperMergeContext from(MapperBuilderContext mapperBuilderContext, * @param name the name of the child context * @return a new {@link MapperMergeContext} with this context as its parent */ - MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { + public MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { return createChildContext(mapperBuilderContext.createChildContext(name, dynamic)); } @@ -60,7 +60,7 @@ MapperMergeContext createChildContext(MapperBuilderContext childContext) { return new MapperMergeContext(childContext, newFieldsBudget); } - MapperBuilderContext getMapperBuilderContext() { + public MapperBuilderContext getMapperBuilderContext() { return mapperBuilderContext; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index c2bd95115f27e..bf879f30e5a29 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -10,9 +10,11 @@ import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.inference.InferenceService; import java.util.ArrayList; import java.util.Collection; @@ -47,6 +49,7 @@ private CacheKey() {} /** Full field name to mapper */ private final Map fieldMappers; private final Map objectMappers; + private final Map inferenceFields; private final int runtimeFieldMappersCount; private final NestedLookup nestedLookup; private final FieldTypeLookup fieldTypeLookup; @@ -84,12 +87,12 @@ private static void collect( Collection fieldMappers, Collection fieldAliasMappers ) { - if (mapper instanceof ObjectMapper) { - objectMappers.add((ObjectMapper) mapper); - } else if (mapper instanceof FieldMapper) { - fieldMappers.add((FieldMapper) mapper); - } else if (mapper instanceof FieldAliasMapper) { - fieldAliasMappers.add((FieldAliasMapper) mapper); + if (mapper instanceof ObjectMapper objectMapper) { + objectMappers.add(objectMapper); + } else if (mapper instanceof FieldMapper fieldMapper) { + fieldMappers.add(fieldMapper); + } else if (mapper instanceof FieldAliasMapper fieldAliasMapper) { + fieldAliasMappers.add(fieldAliasMapper); } else { throw new IllegalStateException("Unrecognized mapper type [" + mapper.getClass().getSimpleName() + "]."); } @@ -174,6 +177,15 @@ private MappingLookup( final Collection runtimeFields = mapping.getRoot().runtimeFields(); this.fieldTypeLookup = new FieldTypeLookup(mappers, aliasMappers, runtimeFields); + + Map inferenceFields = new HashMap<>(); + for (FieldMapper mapper : mappers) { + if (mapper instanceof InferenceFieldMapper inferenceFieldMapper) { + inferenceFields.put(mapper.name(), 
inferenceFieldMapper.getMetadata(fieldTypeLookup.sourcePaths(mapper.name())));
+            }
+        }
+        this.inferenceFields = Map.copyOf(inferenceFields);
+
         if (runtimeFields.isEmpty()) {
             // without runtime fields this is the same as the field type lookup
             this.indexTimeLookup = fieldTypeLookup;
@@ -360,6 +372,13 @@ public Map<String, ObjectMapper> objectMappers() {
         return objectMappers;
     }
 
+    /**
+     * Returns a map containing all fields that require running inference (through the {@link InferenceService}) prior to indexing.
+     */
+    public Map<String, InferenceFieldMetadata> inferenceFields() {
+        return inferenceFields;
+    }
+
     public NestedLookup nestedLookup() {
         return nestedLookup;
     }
@@ -523,8 +542,4 @@ public void validateDoesNotShadow(String name) {
             throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric");
         }
     }
-
-    public Map<String, String> getInferenceIdsForFields() {
-        return fieldTypeLookup.getInferenceIdsForFields();
-    }
 }
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
index b32873df71365..45ffba25eb558 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
@@ -27,7 +27,6 @@
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.shard.ShardId;
@@ -84,7 +83,7 @@ public void testIndexMetadataSerialization() throws IOException {
         IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null;
         Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null;
         Long shardSizeInBytesForecast = randomBoolean() ?
randomLongBetween(1024, 10240) : null;
-        FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true);
+        Map<String, InferenceFieldMetadata> dynamicFields = randomInferenceFields();
 
         IndexMetadata metadata = IndexMetadata.builder("foo")
             .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1))
@@ -110,7 +109,7 @@
             .stats(indexStats)
             .indexWriteLoadForecast(indexWriteLoadForecast)
             .shardSizeInBytesForecast(shardSizeInBytesForecast)
-            .fieldInferenceMetadata(fieldInferenceMetadata)
+            .putInferenceFields(dynamicFields)
             .build();
 
         assertEquals(system, metadata.isSystem());
@@ -145,7 +144,7 @@ public void testIndexMetadataSerialization() throws IOException {
         assertEquals(metadata.getStats(), fromXContentMeta.getStats());
         assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad());
         assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes());
-        assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata());
+        assertEquals(metadata.getInferenceFields(), fromXContentMeta.getInferenceFields());
 
         final BytesStreamOutput out = new BytesStreamOutput();
         metadata.writeTo(out);
@@ -169,7 +168,7 @@ public void testIndexMetadataSerialization() throws IOException {
             assertEquals(metadata.getStats(), deserialized.getStats());
             assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad());
             assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes());
-            assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata());
+            assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields());
         }
     }
@@ -553,35 +552,32 @@ public void testPartialIndexReceivesDataFrozenTierPreference() {
         }
     }
 
-    public void testFieldInferenceMetadata() {
+    public void testInferenceFieldMetadata() {
         Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0);
         IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build();
-        assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY);
+        assertTrue(idxMeta1.getInferenceFields().isEmpty());
 
-        FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false);
-        IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build();
-        assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata));
+        Map<String, InferenceFieldMetadata> dynamicFields = randomInferenceFields();
+        IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).putInferenceFields(dynamicFields).build();
+        assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields));
     }
 
     private static Settings indexSettingsWithDataTier(String dataTier) {
         return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build();
     }
 
-    public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) {
-        if (randomBoolean() && allowNull) {
-            return null;
+    public static Map<String, InferenceFieldMetadata> randomInferenceFields() {
+        Map<String, InferenceFieldMetadata> map = new HashMap<>();
+        int numFields = randomIntBetween(0, 5);
+        for (int i = 0; i < numFields; i++) {
+            String field = randomAlphaOfLengthBetween(5, 10);
+            map.put(field, randomInferenceFieldMetadata(field));
         }
-
-        Map<String, FieldInferenceMetadata.FieldInferenceOptions> fieldInferenceMap = randomMap(
-            0,
-            10,
-            () -> new Tuple<>(randomIdentifier(), randomFieldInference())
-        );
-        return new FieldInferenceMetadata(fieldInferenceMap);
+        return map;
     }
 
-    private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() {
-        return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier));
+    private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) {
+        return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new));
     }
 
     private IndexMetadataStats randomIndexStats(int numberOfShards) {
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
new file mode 100644
index 0000000000000..958d86535ae76
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.function.Predicate;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class InferenceFieldMetadataTests extends AbstractXContentTestCase<InferenceFieldMetadata> {
+
+    public void testSerialization() throws IOException {
+        final InferenceFieldMetadata before = createTestItem();
+        final BytesStreamOutput out = new BytesStreamOutput();
+        before.writeTo(out);
+
+        final StreamInput in = out.bytes().streamInput();
+        final InferenceFieldMetadata after = new InferenceFieldMetadata(in);
+
+        assertThat(after, equalTo(before));
+    }
+
+    @Override
+    protected InferenceFieldMetadata createTestInstance() {
+        return createTestItem();
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field
+    }
+
+    @Override
+    protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException {
+        if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
+            parser.nextToken();
+        }
+        assertEquals(XContentParser.Token.FIELD_NAME, parser.currentToken());
+        InferenceFieldMetadata inferenceMetadata = InferenceFieldMetadata.fromXContent(parser);
+        assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken());
+        return inferenceMetadata;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    private static InferenceFieldMetadata createTestItem() {
+        String name = randomAlphaOfLengthBetween(3, 10);
+        String inferenceId = randomIdentifier();
+        String[] inputFields = generateRandomStringArray(5, 10, false, false);
+        return new InferenceFieldMetadata(name, inferenceId, inputFields);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java
index 932eac3e60d27..3f50b9fdf6621 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java
@@ -16,7 +16,6 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 import static java.util.Collections.emptyList;
@@ -36,10 +35,6 @@ public void testEmpty() {
         Collection<String> names = lookup.getMatchingFieldNames("foo");
         assertNotNull(names);
         assertThat(names, hasSize(0));
-
-        Map<String, String> fieldsForModels = lookup.getInferenceIdsForFields();
-        assertNotNull(fieldsForModels);
-        assertTrue(fieldsForModels.isEmpty());
     }
 
     public void testAddNewField() {
@@ -47,10 +42,6 @@ public void testAddNewField() {
         FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList());
         assertNull(lookup.get("bar"));
         assertEquals(f.fieldType(), lookup.get("foo"));
-
-        Map<String, String> fieldsForModels = lookup.getInferenceIdsForFields();
-        assertNotNull(fieldsForModels);
-        assertTrue(fieldsForModels.isEmpty());
     }
 
     public void testAddFieldAlias() {
@@ -430,25 +421,6 @@ public void testRuntimeFieldNameOutsideContext() {
         }
     }
 
-    public void testInferenceModelFieldType() {
-        MockFieldMapper f1 = new MockFieldMapper(new MockInferenceModelFieldType("foo1", "bar1"));
-        MockFieldMapper f2 = new MockFieldMapper(new MockInferenceModelFieldType("foo2", "bar1"));
-        MockFieldMapper f3 = new MockFieldMapper(new MockInferenceModelFieldType("foo3", "bar2"));
-
-        FieldTypeLookup lookup = new FieldTypeLookup(List.of(f1, f2, f3), emptyList(), emptyList());
-        assertEquals(f1.fieldType(), lookup.get("foo1"));
-        assertEquals(f2.fieldType(), lookup.get("foo2"));
-        assertEquals(f3.fieldType(), lookup.get("foo3"));
-
-        Map<String, String> inferenceIdsForFields = lookup.getInferenceIdsForFields();
-        assertNotNull(inferenceIdsForFields);
-        assertEquals(3, inferenceIdsForFields.size());
-
-        assertEquals("bar1", inferenceIdsForFields.get("foo1"));
-        assertEquals("bar1", inferenceIdsForFields.get("foo2"));
-        assertEquals("bar2", inferenceIdsForFields.get("foo3"));
-    }
-
     private static FlattenedFieldMapper createFlattenedMapper(String fieldName) {
         return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false));
     }
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java
index bb337d0c61c93..0308dac5fa216 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java
@@ -121,8 +121,6 @@ public void testEmptyMappingLookup() {
         assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size());
         assertFalse(mappingLookup.fieldMappers().iterator().hasNext());
         assertEquals(0, mappingLookup.getMatchingFieldNames("*").size());
-        assertNotNull(mappingLookup.getInferenceIdsForFields());
-        assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty());
     }
 
     public void testValidateDoesNotShadow() {
@@ -190,22 +188,6 @@ public MetricType getMetricType() {
         );
     }
 
-    public void testInferenceIdsForFields() {
-        MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id");
-        MappingLookup mappingLookup = createMappingLookup(
-            Collections.singletonList(new MockFieldMapper(fieldType)),
-            emptyList(),
-            emptyList()
-        );
-        assertEquals(1, size(mappingLookup.fieldMappers()));
-        assertEquals(fieldType, mappingLookup.getFieldType("test_field_name"));
-
-        Map<String, String> inferenceIdsForFields = mappingLookup.getInferenceIdsForFields();
-        assertNotNull(inferenceIdsForFields);
-        assertEquals(1, inferenceIdsForFields.size());
-        assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name"));
-    }
-
     private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException {
         try (TokenStream tok = analyzer.tokenStream(field, new StringReader(""))) {
             CharTermAttribute term = tok.addAttribute(CharTermAttribute.class);
diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java
deleted file mode 100644
index 0d21134b5d9a9..0000000000000
--- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-
-package org.elasticsearch.index.mapper;
-
-import org.apache.lucene.search.Query;
-import org.elasticsearch.index.query.SearchExecutionContext;
-
-import java.util.Map;
-
-public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType {
-    private static final String TYPE_NAME = "mock_inference_model_field_type";
-
-    private final String modelId;
-
-    public MockInferenceModelFieldType(String name, String modelId) {
-        super(name, false, false, false, TextSearchInfo.NONE, Map.of());
-        this.modelId = modelId;
-    }
-
-    @Override
-    public String typeName() {
-        return TYPE_NAME;
-    }
-
-    @Override
-    public Query termQuery(Object value, SearchExecutionContext context) {
-        throw new IllegalArgumentException("termQuery not implemented");
-    }
-
-    @Override
-    public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-        return SourceValueFetcher.toString(name(), context, format);
-    }
-
-    @Override
-    public String getInferenceId() {
-        return modelId;
-    }
-}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 3fcd9049ae803..494d6918b6086 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -284,11 +284,17 @@ public Map<String, Mapper.TypeParser> getMappers() {
 
     @Override
     public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
-        return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER);
+        if (SemanticTextFeature.isEnabled()) {
+            return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER);
+        }
+        return Map.of();
     }
 
     @Override
     public Collection<ActionFilter> getActionFilters() {
-        return singletonList(shardBulkInferenceActionFilter.get());
+        if (SemanticTextFeature.isEnabled()) {
+            return singletonList(shardBulkInferenceActionFilter.get());
+        }
+        return List.of();
     }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index 00dc195313a61..fef62051a6471 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -20,12 +20,11 @@ import org.elasticsearch.action.bulk.BulkShardRequest; import org.elasticsearch.action.bulk.TransportShardBulkAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -39,6 +38,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -49,19 +49,66 @@ import java.util.stream.Collectors; /** - * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} - * in the subsequent {@link TransportShardBulkAction} downstream. + * A {@link MappedActionFilter} intercepting {@link BulkShardRequest}s to apply inference on fields declared as + * {@link SemanticTextFieldMapper} in the index mapping. + * The source of each {@link BulkItemRequest} requiring inference is augmented with the results for each field + * under the {@link InferenceMetadataFieldMapper#NAME} section. + * For example, for an index with a semantic_text field named {@code my_semantic_field} the following source document: + *
+ *
+ * {
+ *      "my_semantic_text_field": "these are not the droids you're looking for"
+ * }
+ * 
+ * is rewritten into: + *
+ *
+ * {
+ *      "_inference": {
+ *        "my_semantic_field": {
+ *          "inference_id": "my_inference_id",
+ *                  "model_settings": {
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "chunks": [
+ *                      {
+ *                             "inference": {
+ *                                 "lucas": 0.05212344,
+ *                                 "ty": 0.041213956,
+ *                                 "dragon": 0.50991,
+ *                                 "type": 0.23241979,
+ *                                 "dr": 1.9312073,
+ *                                 "##o": 0.2797593
+ *                             },
+ *                             "text": "these are not the droids you're looking for"
+ *                       }
+ *                  ]
+ *        }
+ *      }
+ *      "my_semantic_field": "these are not the droids you're looking for"
+ * }
+ * 
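+ *
+ * Note that the token keys and weights under {@code chunks.inference} above are illustrative
+ * sparse embedding outputs; the actual entries depend on the model backing {@code my_inference_id}.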
+ * The rewriting process occurs on the bulk coordinator node, and the results are then passed downstream + * to the {@link TransportShardBulkAction} for actual indexing. + * + * TODO: batchSize should be configurable via a cluster setting */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + protected static final int DEFAULT_BATCH_SIZE = 512; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; + private final int batchSize; public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + } + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; + this.batchSize = batchSize; } @Override @@ -86,7 +133,7 @@ public void app switch (action) { case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); @@ -102,33 +149,33 @@ public void app } private void processBulkShardRequest( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); } private record InferenceProvider(InferenceService service, Model model) {} private record FieldInferenceRequest(int id, String field, String input) {} - private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + private record FieldInferenceResponse(String field, @Nullable Model model, @Nullable ChunkedInferenceServiceResults chunkedResults) {} private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { - private final FieldInferenceMetadata fieldInferenceMetadata; + private final Map fieldInferenceMap; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - this.fieldInferenceMetadata = fieldInferenceMetadata; + this.fieldInferenceMap = fieldInferenceMap; this.bulkShardRequest = bulkShardRequest; this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); this.onCompletion = onCompletion; @@ -212,30 +259,49 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + int currentBatchSize = 
Math.min(requests.size(), batchSize); + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = requests.subList(currentBatchSize, requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { - for (int i = 0; i < results.size(); i++) { - var request = requests.get(i); - var result = results.get(i); - var acc = inferenceResults.get(request.id); - acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + try { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + } + } finally { + onFinish(); } } @Override public void onFailure(Exception exc) { - for (int i = 0; i < requests.size(); i++) { - var request = requests.get(i); - inferenceResults.get(request.id).failures.add( - new ElasticsearchException( - "Exception when running inference id [{}] on field [{}]", - exc, - inferenceProvider.model.getInferenceEntityId(), - request.field - ) - ); + try { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } finally { + onFinish(); + } + } + + private void onFinish() { + if (nextBatch.isEmpty()) { + onFinish.close(); + } else { + executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); } } }; @@ -246,14 +312,33 @@ public void onFailure(Exception exc) { Map.of(), InputType.INGEST, new ChunkingOptions(null, null), - ActionListener.runAfter(completionListener, onFinish::close) + completionListener ); } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { + FieldInferenceResponseAccumulator acc = inferenceResults.get(id); + if (acc == null) { + acc = new FieldInferenceResponseAccumulator( + id, + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ); + inferenceResults.set(id, acc); + } + return acc; + } + + private void addInferenceResponseFailure(int id, Exception failure) { + var acc = ensureResponseAccumulatorSlot(id); + acc.failures().add(failure); + } + /** - * Applies the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}. - * If the response contains failures, the bulk item request is mark as failed for the downstream action. - * Otherwise, the source of the request is augmented with the field inference results. + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results under the + * {@link InferenceMetadataFieldMapper#NAME} field. 
*/ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { if (response.failures().isEmpty() == false) { @@ -265,24 +350,37 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); Map newDocMap = indexRequest.sourceAsMap(); - Map inferenceMap = new LinkedHashMap<>(); - // ignore the existing inference map if any + Object inferenceObj = newDocMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()); + Map inferenceMap = XContentMapValues.nodeMapValue(inferenceObj, InferenceMetadataFieldMapper.NAME); newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { - try { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceMap, - fieldResponse.field(), - fieldResponse.model(), - fieldResponse.chunkedResults() - ); - } catch (Exception exc) { - item.abort(item.index(), exc); + if (fieldResponse.chunkedResults != null) { + try { + InferenceMetadataFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } + } else { + inferenceMap.remove(fieldResponse.field); } } indexRequest.source(newDocMap); } + /** + * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. + * If results are already populated for fields in the existing _inference object, + * the inference request for this specific field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link InferenceMetadataFieldMapper} + * during field indexing, where an error will be thrown if they mismatch or if the content is malformed. + * + * TODO: Should we validate the settings for pre-existing results here and apply the inference only if they differ? + */ private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { Map> fieldRequestsMap = new LinkedHashMap<>(); for (var item : bulkShardRequest.items()) { @@ -290,35 +388,57 @@ private Map> createFieldInferenceRequests(Bu // item was already aborted/processed by a filter in the chain upstream (e.g. 
security) continue; } - final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - if (indexRequest == null) { + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + if (updateRequest.script() != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + continue; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request continue; } final Map docMap = indexRequest.sourceAsMap(); - boolean hasInput = false; - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); + final Map inferenceMap = XContentMapValues.nodeMapValue( + docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()), + InferenceMetadataFieldMapper.NAME + ); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + Object inferenceResult = inferenceMap.remove(field); var value = XContentMapValues.extractValue(field, docMap); if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( - item.id(), - new FieldInferenceResponseAccumulator( + if (inferenceResult != null) { + addInferenceResponseFailure( item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); + new ElasticsearchStatusException( + "The field [{}] is referenced in the [{}] metadata field but has no value", + RestStatus.BAD_REQUEST, + field, + InferenceMetadataFieldMapper.NAME + ) + ); + } + continue; } + ensureResponseAccumulatorSlot(item.id()); if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - hasInput = true; } else { - inferenceResults.get(item.id()).failures.add( + addInferenceResponseFailure( + item.id(), new ElasticsearchStatusException( "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, @@ -328,12 +448,6 @@ private Map> createFieldInferenceRequests(Bu ); } } - if (hasInput == false) { - // remove the existing _inference field (if present) since none of the content require inference. 
- if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { - indexRequest.source(docMap); - } - } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 9eeb7a5407bc4..702f686605e56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; @@ -52,6 +53,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; + /** * A mapper for the {@code _inference} field. *
@@ -117,7 +120,7 @@ * */ public class InferenceMetadataFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; + public static final String NAME = InferenceFieldMapper.NAME; public static final String CONTENT_TYPE = "_inference"; public static final String INFERENCE_ID = "inference_id"; @@ -183,7 +186,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( XContentLocation xContentLocation ) { final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); - final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + final String inferenceId = semanticFieldContext.mapper.getInferenceId(); if (newInferenceId.equals(inferenceId) == false) { throw new DocumentParsingException( xContentLocation, @@ -212,7 +215,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( return newMapper.getSubMappers(); } else { SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); - SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 2445d5c8751a5..f8fde0b63e4ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -8,21 +8,23 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MapperMergeContext; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; -import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; @@ -40,6 +42,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; @@ -51,7 +55,7 @@ * This field mapper performs no indexing, as inference results will be included as a 
different field in the document source, and will * be indexed using {@link InferenceMetadataFieldMapper}. */ -public class SemanticTextFieldMapper extends FieldMapper { +public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -66,6 +70,7 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { ); private final IndexVersion indexVersionCreated; + private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -74,11 +79,13 @@ private SemanticTextFieldMapper( MappedFieldType mappedFieldType, CopyTo copyTo, IndexVersion indexVersionCreated, + String inferenceId, SemanticTextModelSettings modelSettings, NestedObjectMapper subMappers ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); this.indexVersionCreated = indexVersionCreated; + this.inferenceId = inferenceId; this.modelSettings = modelSettings; this.subMappers = subMappers; } @@ -111,6 +118,10 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public String getInferenceId() { + return inferenceId; + } + public SemanticTextModelSettings getModelSettings() { return modelSettings; } @@ -119,6 +130,11 @@ public NestedObjectMapper getSubMappers() { return subMappers; } + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + return new InferenceFieldMetadata(name(), inferenceId, sourcePaths.toArray(String[]::new)); + } + public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; @@ -142,11 +158,15 @@ public static class Builder extends FieldMapper.Builder { XContentBuilder::field, (m) -> m == null ? 
"null" : Strings.toString(m) ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + private final Parameter> meta = Parameter.metaParam(); + private Function subFieldsFunction; + public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; + this.subFieldsFunction = c -> createSubFields(c); } public Builder setInferenceId(String id) { @@ -164,9 +184,38 @@ protected Parameter[] getParameters() { return new Parameter[] { inferenceId, modelSettings, meta }; } + @Override + protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { + super.merge(mergeWith, conflicts, mapperMergeContext); + conflicts.check(); + SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var childMergeContext = mapperMergeContext.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + NestedObjectMapper mergedSubFields = (NestedObjectMapper) semanticMergeWith.getSubMappers() + .merge( + subFieldsFunction.apply(childMergeContext.getMapperBuilderContext()), + MapperService.MergeReason.MAPPING_UPDATE, + childMergeContext + ); + subFieldsFunction = c -> mergedSubFields; + } + @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + final NestedObjectMapper subFields = subFieldsFunction.apply(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subFields, meta.getValue()), + copyTo, + indexVersionCreated, + inferenceId.getValue(), + modelSettings.getValue(), + subFields + ); + } + + private NestedObjectMapper createSubFields(MapperBuilderContext context) { NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) @@ -176,20 +225,11 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); } nestedBuilder.add(textMapperBuilder); - var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - var subMappers = nestedBuilder.build(childContext); - return new SemanticTextFieldMapper( - name(), - new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), - copyTo, - indexVersionCreated, - modelSettings.getValue(), - subMappers - ); + return nestedBuilder.build(context); } } - public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -212,7 +252,6 @@ public String typeName() { return CONTENT_TYPE; } - @Override public String getInferenceId() { return inferenceId; } @@ -241,11 +280,6 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return super.syntheticFieldLoader(); - } - private static Mapper.Builder 
createInferenceMapperBuilder( String fieldName, SemanticTextModelSettings modelSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index bf3cc6334433a..4c1cc8fa38bb4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -33,10 +33,7 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); - assertEquals( - indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model"); } public void testAddSemanticTextField() throws Exception { @@ -53,10 +50,7 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals( - resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model"); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 8b18cf74236a0..d734e9998734d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.shard.ShardId; @@ -45,12 +45,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; @@ -75,11 +75,11 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws 
Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { - assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); + assertNull(((BulkShardRequest) request).getInferenceFieldMap()); } finally { chainExecuted.countDown(); } @@ -91,8 +91,8 @@ public void testFilterNoop() throws Exception { WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setFieldInferenceMetadata( - new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -101,12 +101,16 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = randomStaticModel(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); for (BulkItemRequest item : bulkShardRequest.items()) { assertNotNull(item.getPrimaryResponse()); assertTrue(item.getPrimaryResponse().isFailed()); @@ -120,22 +124,20 @@ public void testInferenceNotFound() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( - Map.of( - "field1", - new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), - "field2", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), - "field3", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) - ) + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + "field2", + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + "field3", + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { - items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFieldMap)[0]; } BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setFieldInferenceMetadata(inferenceFields); + request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -150,30 +152,29 @@ public void testManyRandomDocs() 
throws Exception { } int numInferenceFields = randomIntBetween(1, 5); - Map inferenceFieldsMap = new HashMap<>(); + Map inferenceFieldMap = new HashMap<>(); for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); } - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); int numRequests = randomIntBetween(100, 1000); BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFieldMap); originalRequests[id] = res[0]; modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { assertThat(request, instanceOf(BulkShardRequest.class)); BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -192,13 +193,13 @@ public void testManyRandomDocs() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setFieldInferenceMetadata(fieldInferenceMetadata); + original.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @SuppressWarnings("unchecked") - private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { String id = (String) invocationOnMock.getArguments()[0]; @@ -256,20 +257,20 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); return filter; } private static BulkItemRequest[] randomBulkItemRequest( int id, Map modelMap, - FieldInferenceMetadata fieldInferenceMetadata + Map fieldInferenceMap ) { Map docMap = new 
LinkedHashMap<>(); Map inferenceResultsMap = new LinkedHashMap<>(); - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - var model = modelMap.get(entry.getValue().inferenceId()); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + var model = modelMap.get(entry.getInferenceId()); String text = randomAlphaOfLengthBetween(10, 100); docMap.put(field, text); if (model == null) { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 528003e278aeb..8847fb7f7efc1 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -171,15 +171,16 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "another non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -197,6 +198,32 @@ setup: index: test-sparse-index id: doc_1 + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "I am a test" } + - match: { _source.another_inference_field: "I am a teapot" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "I am a test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "I am a teapot" } + - do: update: index: test-sparse-index @@ -211,12 +238,31 @@ setup: index: test-sparse-index id: doc_1 - - 
match: { _source.inference_field: "updated inference test" } - - match: { _source.another_inference_field: "another updated inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "updated inference test" } + - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "bulk inference test", "another_inference_field": "bulk updated inference test"}}' - - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "bulk inference test" } + - match: { _source.another_inference_field: "bulk updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "bulk inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "bulk updated inference test" } --- "Reindex works for semantic_text fields": @@ -268,18 +314,19 @@ setup: index: destination-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- -"Fails for non-existent model": +"Fails for non-existent inference": - do: indices.create: index: incorrect-test-sparse-index @@ -310,3 +357,23 @@ setup: id: doc_1 body: non_inference_field: "non inference test" + +--- +"Updates with script are not allowed": + - do: + bulk: + index: test-sparse-index + body: + - '{"index": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"script": "ctx._source.new_field = \"hello\"", "scripted_upsert": true}' + + - match: { errors: true } + - match: { items.0.update.status: 400 } + - match: { 
items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 27f233436b925..9dc109b3fb81d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -66,3 +66,23 @@ setup: id: doc_1 body: dense_field: "you know, for testing" + +--- +"Inference section contains unreferenced fields": + - do: + catch: /Field \[unknown_field\] is not registered as a \[semantic_text\] field type/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _inference: + unknown_field: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + chunks: + - text: "inference test" + inference: [ 0.1, 0.2, 0.3, 0.4, 0.5 ] + - text: "another inference test" + inference: [ -0.1, -0.2, -0.3, -0.4, -0.5 ]
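
Reviewer note: as a quick way to exercise the new metadata class outside a cluster, the sketch
below round-trips an InferenceFieldMetadata instance through the transport wire format. It only
uses the constructor and accessors exercised by InferenceFieldMetadataTests in this patch; the
harness class name, field name, and inference id are invented for illustration and are not part
of the change.

import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

public class InferenceFieldMetadataRoundTrip {
    public static void main(String[] args) throws Exception {
        // A semantic_text field "my_semantic_field" served by the inference endpoint
        // "my_inference_id", reading its input from the field itself (single source path).
        InferenceFieldMetadata before = new InferenceFieldMetadata(
            "my_semantic_field",
            "my_inference_id",
            new String[] { "my_semantic_field" }
        );

        // Write to the wire format and read it back, mirroring testSerialization above.
        BytesStreamOutput out = new BytesStreamOutput();
        before.writeTo(out);
        try (StreamInput in = out.bytes().streamInput()) {
            InferenceFieldMetadata after = new InferenceFieldMetadata(in);
            assert after.equals(before) : "round trip should preserve equality";
            System.out.println(after.getName() + " -> " + after.getInferenceId());
        }
    }
}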