diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index fbb3016b925da..e0dbc74567053 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -61,7 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; -import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true)); + builder.putInferenceFields(randomInferenceFields()); break; default: throw new IllegalArgumentException("Shouldn't be here"); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index a6439769b51b4..e66426562a92e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -294,8 +294,8 @@ private void executeBulkRequestsByShard( requests.toArray(new BulkItemRequest[0]) ); var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); - if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { - bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); + if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) { + bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields()); } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index 39fa791a3e27d..8d1618b443ace 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,7 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -23,6 +23,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends 
ReplicatedWriteRequest @@ -34,7 +35,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest inferenceFieldMap = null; public BulkShardRequest(StreamInput in) throws IOException { super(in); @@ -51,24 +52,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe * Public for test * Set the transient metadata indicating that this request requires running inference before proceeding. */ - public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { - this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; + public void setInferenceFieldMap(Map fieldInferenceMap) { + this.inferenceFieldMap = fieldInferenceMap; } /** * Consumes the inference metadata to execute inference on the bulk items just once. */ - public FieldInferenceMetadata consumeFieldInferenceMetadata() { - FieldInferenceMetadata ret = fieldsInferenceMetadataMap; - fieldsInferenceMetadataMap = null; + public Map consumeInferenceFieldMap() { + Map ret = inferenceFieldMap; + inferenceFieldMap = null; return ret; } /** * Public for test */ - public FieldInferenceMetadata getFieldsInferenceMetadataMap() { - return fieldsInferenceMetadataMap; + public Map getInferenceFieldMap() { + return inferenceFieldMap; } public long totalSizeInBytes() { diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 63ae56bfbd047..36a47bc7e02e9 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -40,6 +40,7 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; @@ -184,7 +185,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis); switch (result.getResponseResult()) { case CREATED -> { - IndexRequest upsertRequest = result.action(); + IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action()); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference upsertSourceBytes = upsertRequest.source(); client.bulk( @@ -226,7 +227,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< ); } case UPDATED -> { - IndexRequest indexRequest = result.action(); + IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action()); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference indexSourceBytes = indexRequest.source(); client.bulk( @@ -335,4 +336,15 @@ private void handleUpdateFailureWithRetry( } listener.onFailure(cause instanceof Exception ? 
(Exception) cause : new NotSerializableExceptionWrapper(cause)); } + + private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) { + var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields(); + if (inferenceMetadata.isEmpty()) { + return request; + } + Map docMap = request.sourceAsMap(); + docMap.remove(InferenceFieldMapper.NAME); + request.source(docMap); + return request; + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java deleted file mode 100644 index 349706c139127..0000000000000 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.cluster.metadata; - -import org.elasticsearch.cluster.Diff; -import org.elasticsearch.cluster.Diffable; -import org.elasticsearch.cluster.DiffableUtils; -import org.elasticsearch.cluster.SimpleDiffable; -import org.elasticsearch.common.collect.ImmutableOpenMap; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.mapper.MappingLookup; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - -/** - * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator - * node, which not necessarily has mapping information. 
- */ -public class FieldInferenceMetadata implements Diffable, ToXContentFragment { - - private final ImmutableOpenMap fieldInferenceOptions; - - public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); - - public FieldInferenceMetadata(MappingLookup mappingLookup) { - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); - mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { - builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); - }); - fieldInferenceOptions = builder.build(); - } - - public FieldInferenceMetadata(StreamInput in) throws IOException { - fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); - } - - public FieldInferenceMetadata(Map fieldsToInferenceMap) { - fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); - } - - public ImmutableOpenMap getFieldInferenceOptions() { - return fieldInferenceOptions; - } - - public boolean isEmpty() { - return fieldInferenceOptions.isEmpty(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeMap(fieldInferenceOptions, (o, v) -> v.writeTo(o)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(fieldInferenceOptions); - return builder; - } - - public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); - } - - @Override - public int hashCode() { - return Objects.hash(fieldInferenceOptions); - } - - @Override - public Diff diff(FieldInferenceMetadata previousState) { - if (previousState == null) { - previousState = EMPTY; - } - return new FieldInferenceMetadataDiff(previousState, this); - } - - static class FieldInferenceMetadataDiff implements Diff { - - public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( - FieldInferenceMetadata.EMPTY, - FieldInferenceMetadata.EMPTY - ); - - private final Diff> fieldInferenceMapDiff; - - private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = - new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); - - FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { - fieldInferenceMapDiff = DiffableUtils.diff( - before.fieldInferenceOptions, - after.fieldInferenceOptions, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - FieldInferenceMetadataDiff(StreamInput in) throws IOException { - fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - public static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - fieldInferenceMapDiff.writeTo(out); - } - - @Override - public FieldInferenceMetadata apply(FieldInferenceMetadata part) { - return 
new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); - } - } - - public record FieldInferenceOptions(String inferenceId, Set sourceFields) - implements - SimpleDiffable, - ToXContentFragment { - - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); - - FieldInferenceOptions(StreamInput in) throws IOException { - this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - out.writeStringCollection(sourceFields); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); - builder.endObject(); - return builder; - } - - public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "field_inference_parser", - false, - (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) - ); - - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 89c925427cf88..b66da654f8a1c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -576,6 +576,8 @@ public Iterator> settings() { @Nullable private final MappingMetadata mapping; + private final ImmutableOpenMap inferenceFields; + private final ImmutableOpenMap customData; private final Map> inSyncAllocationIds; @@ -631,7 +633,6 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -645,6 +646,7 @@ private IndexMetadata( final int numberOfReplicas, final Settings settings, final MappingMetadata mapping, + final ImmutableOpenMap inferenceFields, final ImmutableOpenMap aliases, final ImmutableOpenMap customData, final Map> inSyncAllocationIds, @@ -677,8 +679,7 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast, - @Nullable FieldInferenceMetadata fieldInferenceMetadata + @Nullable Long shardSizeInBytesForecast ) { this.index = index; this.version = version; @@ -696,6 +697,7 @@ private IndexMetadata( this.totalNumberOfShards = numberOfShards * (numberOfReplicas + 1); this.settings = settings; this.mapping = mapping; + this.inferenceFields = inferenceFields; this.customData = customData; this.aliases = aliases; this.inSyncAllocationIds = inSyncAllocationIds; @@ -734,7 +736,6 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = 
shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -753,6 +754,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.numberOfReplicas, this.settings, mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -785,8 +787,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -812,6 +813,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, Maps.copyMapWithAddedOrReplacedEntry(this.inSyncAllocationIds, shardId, Set.copyOf(inSyncSet)), @@ -844,8 +846,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -869,6 +870,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -901,8 +903,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -926,6 +927,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -958,8 +960,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -979,6 +980,7 @@ public IndexMetadata withIncrementedVersion() { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -1011,8 +1013,7 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -1203,6 +1204,10 @@ public MappingMetadata mapping() { return mapping; } + public Map getInferenceFields() { + return inferenceFields; + } + @Nullable public IndexMetadataStats getStats() { return stats; @@ -1216,10 +1221,6 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? 
OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public FieldInferenceMetadata getFieldInferenceMetadata() { - return fieldInferenceMetadata; - } - public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1417,7 +1418,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { + if (inferenceFields.equals(that.inferenceFields) == false) { return false; } if (isSystem != that.isSystem) { @@ -1440,7 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldInferenceMetadata.hashCode(); + result = 31 * result + inferenceFields.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1487,6 +1488,7 @@ private static class IndexMetadataDiff implements Diff { @Nullable private final Diff settingsDiff; private final Diff> mappings; + private final Diff> inferenceFields; private final Diff> aliases; private final Diff> customData; private final Diff>> inSyncAllocationIds; @@ -1496,7 +1498,6 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1519,6 +1520,7 @@ private static class IndexMetadataDiff implements Diff { : ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, after.mapping).build(), DiffableUtils.getStringKeySerializer() ); + inferenceFields = DiffableUtils.diff(before.inferenceFields, after.inferenceFields, DiffableUtils.getStringKeySerializer()); aliases = DiffableUtils.diff(before.aliases, after.aliases, DiffableUtils.getStringKeySerializer()); customData = DiffableUtils.diff(before.customData, after.customData, DiffableUtils.getStringKeySerializer()); inSyncAllocationIds = DiffableUtils.diff( @@ -1533,7 +1535,6 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1544,6 +1545,8 @@ private static class IndexMetadataDiff implements Diff { new DiffableUtils.DiffableValueReader<>(DiffableStringMap::readFrom, DiffableStringMap::readDiffFrom); private static final DiffableUtils.DiffableValueReader ROLLOVER_INFO_DIFF_VALUE_READER = new DiffableUtils.DiffableValueReader<>(RolloverInfo::new, RolloverInfo::readDiffFrom); + private static final DiffableUtils.DiffableValueReader INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(InferenceFieldMetadata::new, InferenceFieldMetadata::readDiffFrom); IndexMetadataDiff(StreamInput in) throws IOException { index = in.readString(); @@ -1566,6 +1569,15 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); 
mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER + ); + } else { + inferenceFields = DiffableUtils.emptyDiff(); + } aliases = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), ALIAS_METADATA_DIFF_VALUE_READER); customData = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), CUSTOM_DIFF_VALUE_READER); inSyncAllocationIds = DiffableUtils.readJdkMapDiff( @@ -1593,11 +1605,6 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); - } else { - fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; - } } @Override @@ -1620,6 +1627,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields.writeTo(out); + } aliases.writeTo(out); customData.writeTo(out); inSyncAllocationIds.writeTo(out); @@ -1633,9 +1643,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeOptionalWriteable(fieldInferenceMetadata); - } } @Override @@ -1656,6 +1663,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.mapping = mappings.apply( ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, part.mapping).build() ).get(MapperService.SINGLE_MAPPING_NAME); + builder.inferenceFields.putAllFromMap(inferenceFields.apply(part.inferenceFields)); builder.aliases.putAllFromMap(aliases.apply(part.aliases)); builder.customMetadata.putAllFromMap(customData.apply(part.customData)); builder.inSyncAllocationIds.putAll(inSyncAllocationIds.apply(part.inSyncAllocationIds)); @@ -1665,7 +1673,6 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1702,6 +1709,10 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); + } int aliasesSize = in.readVInt(); for (int i = 0; i < aliasesSize; i++) { AliasMetadata aliasMd = new AliasMetadata(in); @@ -1733,9 +1744,6 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function inferenceFields; private final ImmutableOpenMap.Builder aliases; private final ImmutableOpenMap.Builder customMetadata; private final Map> inSyncAllocationIds; @@ -1834,10 +1843,10 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; public 
Builder(String index) { this.index = index; + this.inferenceFields = ImmutableOpenMap.builder(); this.aliases = ImmutableOpenMap.builder(); this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); @@ -1855,6 +1864,7 @@ public Builder(IndexMetadata indexMetadata) { this.settings = indexMetadata.getSettings(); this.primaryTerms = indexMetadata.primaryTerms.clone(); this.mapping = indexMetadata.mapping; + this.inferenceFields = ImmutableOpenMap.builder(indexMetadata.inferenceFields); this.aliases = ImmutableOpenMap.builder(indexMetadata.aliases); this.customMetadata = ImmutableOpenMap.builder(indexMetadata.customData); this.routingNumShards = indexMetadata.routingNumShards; @@ -1866,7 +1876,6 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2096,8 +2105,13 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); + public Builder putInferenceField(InferenceFieldMetadata value) { + this.inferenceFields.put(value.getName(), value); + return this; + } + + public Builder putInferenceFields(Map values) { + this.inferenceFields.putAllFromMap(values); return this; } @@ -2263,6 +2277,7 @@ IndexMetadata build(boolean repair) { numberOfReplicas, settings, mapping, + inferenceFields.build(), aliasesMap, newCustomMetadata, Map.ofEntries(denseInSyncAllocationIds), @@ -2295,8 +2310,7 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast, - fieldInferenceMetadata + shardSizeInBytesForecast ); } @@ -2422,8 +2436,12 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { - builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); + if (indexMetadata.getInferenceFields().isEmpty() == false) { + builder.startObject(KEY_FIELD_INFERENCE); + for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { + field.toXContent(builder, params); + } + builder.endObject(); } builder.endObject(); @@ -2504,7 +2522,9 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { + private static final String INFERENCE_ID_FIELD = "inference_id"; + private static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private final String name; + private final String inferenceId; + private final String[] sourceFields; + + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this.name = Objects.requireNonNull(name); + this.inferenceId = Objects.requireNonNull(inferenceId); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public InferenceFieldMetadata(StreamInput input) throws IOException { + this.name = input.readString(); + this.inferenceId = input.readString(); + this.sourceFields = input.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeString(name); + out.writeString(inferenceId); + out.writeStringArray(sourceFields); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceFieldMetadata that = (InferenceFieldMetadata) o; + return inferenceId.equals(that.inferenceId) && Arrays.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + int result = Objects.hash(inferenceId); + result = 31 * result + Arrays.hashCode(sourceFields); + return result; + } + + public String getName() { + return name; + } + + public String getInferenceId() { + return inferenceId; + } + + public String[] getSourceFields() { + return sourceFields; + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(name); + builder.field(INFERENCE_ID_FIELD, inferenceId); + builder.array(SOURCE_FIELDS_FIELD, sourceFields); + return builder.endObject(); + } + + public static InferenceFieldMetadata fromXContent(XContentParser parser) throws IOException { + final String name = parser.currentName(); + + XContentParser.Token token = parser.nextToken(); + if (token == null) { + // no data... + return null; + } + String currentFieldName = null; + String inferenceId = null; + List inputFields = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_STRING) { + if (INFERENCE_ID_FIELD.equals(currentFieldName)) { + inferenceId = parser.text(); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_STRING) { + inputFields.add(parser.text()); + } else { + parser.skipChildren(); + } + } + } + } else { + parser.skipChildren(); + } + } + return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index 96ca7a15edc30..52642e1de8ac9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1263,12 +1263,11 @@ static IndexMetadata buildIndexMetadata( indexMetadataBuilder.system(isSystem); // now, update the mappings with the actual source Map mappingsMetadata = new HashMap<>(); - DocumentMapper mapper = documentMapperSupplier.get(); - if (mapper != null) { - MappingMetadata mappingMd = new MappingMetadata(mapper); - mappingsMetadata.put(mapper.type(), mappingMd); - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); - indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); + DocumentMapper docMapper = documentMapperSupplier.get(); + if (docMapper != null) { + MappingMetadata mappingMd = new MappingMetadata(docMapper); + mappingsMetadata.put(docMapper.type(), mappingMd); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } 
for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 0e31592991369..e7c2bb9ae9b9a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -201,10 +201,10 @@ private static ClusterState applyRequest( IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(indexMetadata); // Mapping updates on a single type may have side-effects on other types so we need to // update mapping metadata on all types - DocumentMapper mapper = mapperService.documentMapper(); - if (mapper != null) { - indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldInferenceMetadata(new FieldInferenceMetadata(mapper.mappers())); + DocumentMapper docMapper = mapperService.documentMapper(); + if (docMapper != null) { + indexMetadataBuilder.putMapping(new MappingMetadata(docMapper)); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 0741cfa682b74..5e3dbe9590b99 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,11 +36,6 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; - /** - * A map from inference model ID to all fields that use the model to generate embeddings. - */ - private final Map inferenceIdsForFields; - private final int maxParentPathDots; FieldTypeLookup( @@ -53,7 +48,6 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -71,9 +65,6 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } - if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); - } } int maxParentPathDots = 0; @@ -106,7 +97,6 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -215,10 +205,6 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map getInferenceIdsForFields() { - return inferenceIdsForFields; - } - /** * If field is a leaf multi-field return the path to the parent field. Otherwise, return null. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java new file mode 100644 index 0000000000000..078ef391f17ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.inference.InferenceService; + +import java.util.Set; + +/** + * A field mapper whose input must be transformed through the {@link InferenceService} before indexing. + */ +public interface InferenceFieldMapper { + String NAME = "_inference"; + + /** + * Retrieve the inference metadata associated with this mapper. + * + * @param sourcePaths The source paths that populate the input for the field (before inference) + */ + InferenceFieldMetadata getMetadata(Set<String> sourcePaths); +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java deleted file mode 100644 index 6e12a204ed7d0..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -/** - * Field type that uses an inference model. - */ -public interface InferenceModelFieldType { - /** - * Retrieve inference model used by the field type.
- * - * @return model id used by the field type - */ - String getInferenceId(); -} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java index 8f8854ad47c7d..ddf6f339cbbb6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java @@ -46,7 +46,7 @@ public static MapperMergeContext from(MapperBuilderContext mapperBuilderContext, * @param name the name of the child context * @return a new {@link MapperMergeContext} with this context as its parent */ - MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { + public MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { return createChildContext(mapperBuilderContext.createChildContext(name, dynamic)); } @@ -60,7 +60,7 @@ MapperMergeContext createChildContext(MapperBuilderContext childContext) { return new MapperMergeContext(childContext, newFieldsBudget); } - MapperBuilderContext getMapperBuilderContext() { + public MapperBuilderContext getMapperBuilderContext() { return mapperBuilderContext; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index c2bd95115f27e..bf879f30e5a29 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -10,9 +10,11 @@ import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.inference.InferenceService; import java.util.ArrayList; import java.util.Collection; @@ -47,6 +49,7 @@ private CacheKey() {} /** Full field name to mapper */ private final Map fieldMappers; private final Map objectMappers; + private final Map inferenceFields; private final int runtimeFieldMappersCount; private final NestedLookup nestedLookup; private final FieldTypeLookup fieldTypeLookup; @@ -84,12 +87,12 @@ private static void collect( Collection fieldMappers, Collection fieldAliasMappers ) { - if (mapper instanceof ObjectMapper) { - objectMappers.add((ObjectMapper) mapper); - } else if (mapper instanceof FieldMapper) { - fieldMappers.add((FieldMapper) mapper); - } else if (mapper instanceof FieldAliasMapper) { - fieldAliasMappers.add((FieldAliasMapper) mapper); + if (mapper instanceof ObjectMapper objectMapper) { + objectMappers.add(objectMapper); + } else if (mapper instanceof FieldMapper fieldMapper) { + fieldMappers.add(fieldMapper); + } else if (mapper instanceof FieldAliasMapper fieldAliasMapper) { + fieldAliasMappers.add(fieldAliasMapper); } else { throw new IllegalStateException("Unrecognized mapper type [" + mapper.getClass().getSimpleName() + "]."); } @@ -174,6 +177,15 @@ private MappingLookup( final Collection runtimeFields = mapping.getRoot().runtimeFields(); this.fieldTypeLookup = new FieldTypeLookup(mappers, aliasMappers, runtimeFields); + + Map inferenceFields = new HashMap<>(); + for (FieldMapper mapper : mappers) { + if (mapper instanceof InferenceFieldMapper inferenceFieldMapper) { + inferenceFields.put(mapper.name(), 
inferenceFieldMapper.getMetadata(fieldTypeLookup.sourcePaths(mapper.name()))); + } + } + this.inferenceFields = Map.copyOf(inferenceFields); + if (runtimeFields.isEmpty()) { // without runtime fields this is the same as the field type lookup this.indexTimeLookup = fieldTypeLookup; @@ -360,6 +372,13 @@ public Map<String, ObjectMapper> objectMappers() { return objectMappers; } + /** + * Returns a map of all fields that require inference to be run (through the {@link InferenceService}) prior to indexing. + */ + public Map<String, InferenceFieldMetadata> inferenceFields() { + return inferenceFields; + } + public NestedLookup nestedLookup() { return nestedLookup; } @@ -523,8 +542,4 @@ public void validateDoesNotShadow(String name) { throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric"); } } - - public Map<String, String> getInferenceIdsForFields() { - return fieldTypeLookup.getInferenceIdsForFields(); - } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index b32873df71365..45ffba25eb558 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -27,7 +27,6 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.shard.ShardId; @@ -84,7 +83,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ?
randomLongBetween(1024, 10240) : null; - FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true); + Map dynamicFields = randomInferenceFields(); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -110,7 +109,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldInferenceMetadata(fieldInferenceMetadata) + .putInferenceFields(dynamicFields) .build(); assertEquals(system, metadata.isSystem()); @@ -145,7 +144,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata()); + assertEquals(metadata.getInferenceFields(), fromXContentMeta.getInferenceFields()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -169,7 +168,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), deserialized.getStats()); assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata()); + assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields()); } } @@ -553,35 +552,32 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldInferenceMetadata() { + public void testInferenceFieldMetadata() { Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY); + assertTrue(idxMeta1.getInferenceFields().isEmpty()); - FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build(); - assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata)); + Map dynamicFields = randomInferenceFields(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).putInferenceFields(dynamicFields).build(); + assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields)); } private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) { - if (randomBoolean() && allowNull) { - return null; + public static Map randomInferenceFields() { + Map map = new HashMap<>(); + int numFields = randomIntBetween(0, 5); + for (int i = 0; i < numFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + map.put(field, randomInferenceFieldMetadata(field)); } - - Map fieldInferenceMap = randomMap( - 0, - 10, - () -> new Tuple<>(randomIdentifier(), randomFieldInference()) - ); - return new FieldInferenceMetadata(fieldInferenceMap); + return 
map; } - private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() { - return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier)); + private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) { + return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)); } private IndexMetadataStats randomIndexStats(int numberOfShards) { diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java new file mode 100644 index 0000000000000..958d86535ae76 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.function.Predicate; + +import static org.hamcrest.Matchers.equalTo; + +public class InferenceFieldMetadataTests extends AbstractXContentTestCase { + + public void testSerialization() throws IOException { + final InferenceFieldMetadata before = createTestItem(); + final BytesStreamOutput out = new BytesStreamOutput(); + before.writeTo(out); + + final StreamInput in = out.bytes().streamInput(); + final InferenceFieldMetadata after = new InferenceFieldMetadata(in); + + assertThat(after, equalTo(before)); + } + + @Override + protected InferenceFieldMetadata createTestInstance() { + return createTestItem(); + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field + } + + @Override + protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { + if (parser.nextToken() == XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + assertEquals(XContentParser.Token.FIELD_NAME, parser.currentToken()); + InferenceFieldMetadata inferenceMetadata = InferenceFieldMetadata.fromXContent(parser); + assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken()); + return inferenceMetadata; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + private static InferenceFieldMetadata createTestItem() { + String name = randomAlphaOfLengthBetween(3, 10); + String inferenceId = randomIdentifier(); + String[] inputFields = generateRandomStringArray(5, 10, false, false); + return new InferenceFieldMetadata(name, inferenceId, inputFields); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 932eac3e60d27..3f50b9fdf6621 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -16,7 +16,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import static java.util.Collections.emptyList; @@ -36,10 +35,6 @@ public void testEmpty() { Collection names = lookup.getMatchingFieldNames("foo"); assertNotNull(names); assertThat(names, hasSize(0)); - - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddNewField() { @@ -47,10 +42,6 @@ public void testAddNewField() { FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddFieldAlias() { @@ -430,25 +421,6 @@ public void testRuntimeFieldNameOutsideContext() { } } - public void testInferenceModelFieldType() { - MockFieldMapper f1 = new MockFieldMapper(new MockInferenceModelFieldType("foo1", "bar1")); - MockFieldMapper f2 = new MockFieldMapper(new MockInferenceModelFieldType("foo2", "bar1")); - MockFieldMapper f3 = new MockFieldMapper(new MockInferenceModelFieldType("foo3", "bar2")); - - FieldTypeLookup lookup = new FieldTypeLookup(List.of(f1, f2, f3), emptyList(), emptyList()); - assertEquals(f1.fieldType(), lookup.get("foo1")); - assertEquals(f2.fieldType(), lookup.get("foo2")); - assertEquals(f3.fieldType(), lookup.get("foo3")); - - Map inferenceIdsForFields = lookup.getInferenceIdsForFields(); - assertNotNull(inferenceIdsForFields); - assertEquals(3, inferenceIdsForFields.size()); - - assertEquals("bar1", inferenceIdsForFields.get("foo1")); - assertEquals("bar1", inferenceIdsForFields.get("foo2")); - assertEquals("bar2", inferenceIdsForFields.get("foo3")); - } - private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index bb337d0c61c93..0308dac5fa216 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -121,8 +121,6 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getInferenceIdsForFields()); - assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty()); } public void testValidateDoesNotShadow() { @@ -190,22 +188,6 @@ public MetricType getMetricType() { ); } - public void testInferenceIdsForFields() { - MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); - MappingLookup mappingLookup = createMappingLookup( - Collections.singletonList(new MockFieldMapper(fieldType)), - emptyList(), - emptyList() - ); - assertEquals(1, size(mappingLookup.fieldMappers())); - assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - - Map inferenceIdsForFields = 
mappingLookup.getInferenceIdsForFields(); - assertNotNull(inferenceIdsForFields); - assertEquals(1, inferenceIdsForFields.size()); - assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name")); - } - private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { try (TokenStream tok = analyzer.tokenStream(field, new StringReader(""))) { CharTermAttribute term = tok.addAttribute(CharTermAttribute.class); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java deleted file mode 100644 index 0d21134b5d9a9..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.index.query.SearchExecutionContext; - -import java.util.Map; - -public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private static final String TYPE_NAME = "mock_inference_model_field_type"; - - private final String modelId; - - public MockInferenceModelFieldType(String name, String modelId) { - super(name, false, false, false, TextSearchInfo.NONE, Map.of()); - this.modelId = modelId; - } - - @Override - public String typeName() { - return TYPE_NAME; - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented"); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); - } - - @Override - public String getInferenceId() { - return modelId; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3fcd9049ae803..494d6918b6086 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -284,11 +284,17 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); + if (SemanticTextFeature.isEnabled()) { + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); + } + return Map.of(); } @Override public Collection getActionFilters() { - return singletonList(shardBulkInferenceActionFilter.get()); + if (SemanticTextFeature.isEnabled()) { + return singletonList(shardBulkInferenceActionFilter.get()); + } + return List.of(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 00dc195313a61..fef62051a6471 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -20,12 +20,11 @@ import org.elasticsearch.action.bulk.BulkShardRequest; import org.elasticsearch.action.bulk.TransportShardBulkAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -39,6 +38,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -49,19 +49,66 @@ import java.util.stream.Collectors; /** - * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} - * in the subsequent {@link TransportShardBulkAction} downstream. + * A {@link MappedActionFilter} intercepting {@link BulkShardRequest}s to apply inference on fields declared as + * {@link SemanticTextFieldMapper} in the index mapping. + * The source of each {@link BulkItemRequest} requiring inference is augmented with the results for each field + * under the {@link InferenceMetadataFieldMapper#NAME} section. + * For example, for an index with a semantic_text field named {@code my_semantic_field} the following source document: + *
+ *
+ * {
+ *      "my_semantic_field": "these are not the droids you're looking for"
+ * }
+ * 
+ * is rewritten into: + *
+ *
+ * {
+ *      "_inference": {
+ *        "my_semantic_field": {
+ *          "inference_id": "my_inference_id",
+ *          "model_settings": {
+ *            "task_type": "SPARSE_EMBEDDING"
+ *          },
+ *          "chunks": [
+ *            {
+ *              "inference": {
+ *                "lucas": 0.05212344,
+ *                "ty": 0.041213956,
+ *                "dragon": 0.50991,
+ *                "type": 0.23241979,
+ *                "dr": 1.9312073,
+ *                "##o": 0.2797593
+ *              },
+ *              "text": "these are not the droids you're looking for"
+ *            }
+ *          ]
+ *        }
+ *      },
+ *      "my_semantic_field": "these are not the droids you're looking for"
+ * }
+ * 
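+ *
+ * The per-field metadata driving this rewrite is attached to the {@link BulkShardRequest} by the
+ * coordinator via {@code setInferenceFieldMap}. A minimal sketch of what that map looks like for
+ * the example above (the field name and inference ID are illustrative):
+ *
+ * Map<String, InferenceFieldMetadata> fieldMap = Map.of(
+ *     "my_semantic_field",
+ *     new InferenceFieldMetadata("my_semantic_field", "my_inference_id", new String[] { "my_semantic_field" })
+ * );
+ * bulkShardRequest.setInferenceFieldMap(fieldMap);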
+ * The rewriting process occurs on the bulk coordinator node, and the results are then passed downstream + * to the {@link TransportShardBulkAction} for actual indexing. + * + * TODO: batchSize should be configurable via a cluster setting */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + protected static final int DEFAULT_BATCH_SIZE = 512; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; + private final int batchSize; public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + } + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; + this.batchSize = batchSize; } @Override @@ -86,7 +133,7 @@ public void app switch (action) { case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); @@ -102,33 +149,33 @@ public void app } private void processBulkShardRequest( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); } private record InferenceProvider(InferenceService service, Model model) {} private record FieldInferenceRequest(int id, String field, String input) {} - private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + private record FieldInferenceResponse(String field, @Nullable Model model, @Nullable ChunkedInferenceServiceResults chunkedResults) {} private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { - private final FieldInferenceMetadata fieldInferenceMetadata; + private final Map fieldInferenceMap; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - this.fieldInferenceMetadata = fieldInferenceMetadata; + this.fieldInferenceMap = fieldInferenceMap; this.bulkShardRequest = bulkShardRequest; this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); this.onCompletion = onCompletion; @@ -212,30 +259,49 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + int currentBatchSize = 
Math.min(requests.size(), batchSize); + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = requests.subList(currentBatchSize, requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { - for (int i = 0; i < results.size(); i++) { - var request = requests.get(i); - var result = results.get(i); - var acc = inferenceResults.get(request.id); - acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + try { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + } + } finally { + onFinish(); } } @Override public void onFailure(Exception exc) { - for (int i = 0; i < requests.size(); i++) { - var request = requests.get(i); - inferenceResults.get(request.id).failures.add( - new ElasticsearchException( - "Exception when running inference id [{}] on field [{}]", - exc, - inferenceProvider.model.getInferenceEntityId(), - request.field - ) - ); + try { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } finally { + onFinish(); + } + } + + private void onFinish() { + if (nextBatch.isEmpty()) { + onFinish.close(); + } else { + executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); } } }; @@ -246,14 +312,33 @@ public void onFailure(Exception exc) { Map.of(), InputType.INGEST, new ChunkingOptions(null, null), - ActionListener.runAfter(completionListener, onFinish::close) + completionListener ); } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { + FieldInferenceResponseAccumulator acc = inferenceResults.get(id); + if (acc == null) { + acc = new FieldInferenceResponseAccumulator( + id, + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ); + inferenceResults.set(id, acc); + } + return acc; + } + + private void addInferenceResponseFailure(int id, Exception failure) { + var acc = ensureResponseAccumulatorSlot(id); + acc.failures().add(failure); + } + /** - * Applies the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}. - * If the response contains failures, the bulk item request is mark as failed for the downstream action. - * Otherwise, the source of the request is augmented with the field inference results. + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results under the + * {@link InferenceMetadataFieldMapper#NAME} field. 
*/ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { if (response.failures().isEmpty() == false) { @@ -265,24 +350,37 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); Map newDocMap = indexRequest.sourceAsMap(); - Map inferenceMap = new LinkedHashMap<>(); - // ignore the existing inference map if any + Object inferenceObj = newDocMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()); + Map inferenceMap = XContentMapValues.nodeMapValue(inferenceObj, InferenceMetadataFieldMapper.NAME); newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { - try { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceMap, - fieldResponse.field(), - fieldResponse.model(), - fieldResponse.chunkedResults() - ); - } catch (Exception exc) { - item.abort(item.index(), exc); + if (fieldResponse.chunkedResults != null) { + try { + InferenceMetadataFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } + } else { + inferenceMap.remove(fieldResponse.field); } } indexRequest.source(newDocMap); } + /** + * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. + * If results are already populated for fields in the existing _inference object, + * the inference request for this specific field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link InferenceMetadataFieldMapper} + * during field indexing, where an error will be thrown if they mismatch or if the content is malformed. + * + * TODO: Should we validate the settings for pre-existing results here and apply the inference only if they differ? + */ private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { Map> fieldRequestsMap = new LinkedHashMap<>(); for (var item : bulkShardRequest.items()) { @@ -290,35 +388,57 @@ private Map> createFieldInferenceRequests(Bu // item was already aborted/processed by a filter in the chain upstream (e.g. 
security) continue; } - final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - if (indexRequest == null) { + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + if (updateRequest.script() != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + continue; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request continue; } final Map docMap = indexRequest.sourceAsMap(); - boolean hasInput = false; - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); + final Map inferenceMap = XContentMapValues.nodeMapValue( + docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()), + InferenceMetadataFieldMapper.NAME + ); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + Object inferenceResult = inferenceMap.remove(field); var value = XContentMapValues.extractValue(field, docMap); if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( - item.id(), - new FieldInferenceResponseAccumulator( + if (inferenceResult != null) { + addInferenceResponseFailure( item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); + new ElasticsearchStatusException( + "The field [{}] is referenced in the [{}] metadata field but has no value", + RestStatus.BAD_REQUEST, + field, + InferenceMetadataFieldMapper.NAME + ) + ); + } + continue; } + ensureResponseAccumulatorSlot(item.id()); if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - hasInput = true; } else { - inferenceResults.get(item.id()).failures.add( + addInferenceResponseFailure( + item.id(), new ElasticsearchStatusException( "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, @@ -328,12 +448,6 @@ private Map> createFieldInferenceRequests(Bu ); } } - if (hasInput == false) { - // remove the existing _inference field (if present) since none of the content require inference. 
- if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { - indexRequest.source(docMap); - } - } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 9eeb7a5407bc4..702f686605e56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; @@ -52,6 +53,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; + /** * A mapper for the {@code _inference} field. *
@@ -117,7 +120,7 @@ * */ public class InferenceMetadataFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; + public static final String NAME = InferenceFieldMapper.NAME; public static final String CONTENT_TYPE = "_inference"; public static final String INFERENCE_ID = "inference_id"; @@ -183,7 +186,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( XContentLocation xContentLocation ) { final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); - final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + final String inferenceId = semanticFieldContext.mapper.getInferenceId(); if (newInferenceId.equals(inferenceId) == false) { throw new DocumentParsingException( xContentLocation, @@ -212,7 +215,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( return newMapper.getSubMappers(); } else { SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); - SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 2445d5c8751a5..f8fde0b63e4ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -8,21 +8,23 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MapperMergeContext; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; -import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; @@ -40,6 +42,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; @@ -51,7 +55,7 @@ * This field mapper performs no indexing, as inference results will be included as a 
different field in the document source, and will * be indexed using {@link InferenceMetadataFieldMapper}. */ -public class SemanticTextFieldMapper extends FieldMapper { +public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -66,6 +70,7 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { ); private final IndexVersion indexVersionCreated; + private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -74,11 +79,13 @@ private SemanticTextFieldMapper( MappedFieldType mappedFieldType, CopyTo copyTo, IndexVersion indexVersionCreated, + String inferenceId, SemanticTextModelSettings modelSettings, NestedObjectMapper subMappers ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); this.indexVersionCreated = indexVersionCreated; + this.inferenceId = inferenceId; this.modelSettings = modelSettings; this.subMappers = subMappers; } @@ -111,6 +118,10 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public String getInferenceId() { + return inferenceId; + } + public SemanticTextModelSettings getModelSettings() { return modelSettings; } @@ -119,6 +130,11 @@ public NestedObjectMapper getSubMappers() { return subMappers; } + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + return new InferenceFieldMetadata(name(), inferenceId, sourcePaths.toArray(String[]::new)); + } + public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; @@ -142,11 +158,15 @@ public static class Builder extends FieldMapper.Builder { XContentBuilder::field, (m) -> m == null ? 
"null" : Strings.toString(m) ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + private final Parameter> meta = Parameter.metaParam(); + private Function subFieldsFunction; + public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; + this.subFieldsFunction = c -> createSubFields(c); } public Builder setInferenceId(String id) { @@ -164,9 +184,38 @@ protected Parameter[] getParameters() { return new Parameter[] { inferenceId, modelSettings, meta }; } + @Override + protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { + super.merge(mergeWith, conflicts, mapperMergeContext); + conflicts.check(); + SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var childMergeContext = mapperMergeContext.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + NestedObjectMapper mergedSubFields = (NestedObjectMapper) semanticMergeWith.getSubMappers() + .merge( + subFieldsFunction.apply(childMergeContext.getMapperBuilderContext()), + MapperService.MergeReason.MAPPING_UPDATE, + childMergeContext + ); + subFieldsFunction = c -> mergedSubFields; + } + @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + final NestedObjectMapper subFields = subFieldsFunction.apply(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subFields, meta.getValue()), + copyTo, + indexVersionCreated, + inferenceId.getValue(), + modelSettings.getValue(), + subFields + ); + } + + private NestedObjectMapper createSubFields(MapperBuilderContext context) { NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) @@ -176,20 +225,11 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); } nestedBuilder.add(textMapperBuilder); - var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - var subMappers = nestedBuilder.build(childContext); - return new SemanticTextFieldMapper( - name(), - new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), - copyTo, - indexVersionCreated, - modelSettings.getValue(), - subMappers - ); + return nestedBuilder.build(context); } } - public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -212,7 +252,6 @@ public String typeName() { return CONTENT_TYPE; } - @Override public String getInferenceId() { return inferenceId; } @@ -241,11 +280,6 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return super.syntheticFieldLoader(); - } - private static Mapper.Builder 
createInferenceMapperBuilder( String fieldName, SemanticTextModelSettings modelSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index bf3cc6334433a..4c1cc8fa38bb4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -33,10 +33,7 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); - assertEquals( - indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model"); } public void testAddSemanticTextField() throws Exception { @@ -53,10 +50,7 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals( - resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model"); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 8b18cf74236a0..d734e9998734d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.shard.ShardId; @@ -45,12 +45,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; @@ -75,11 +75,11 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws 
Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { - assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); + assertNull(((BulkShardRequest) request).getInferenceFieldMap()); } finally { chainExecuted.countDown(); } @@ -91,8 +91,8 @@ public void testFilterNoop() throws Exception { WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setFieldInferenceMetadata( - new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -101,12 +101,16 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = randomStaticModel(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); for (BulkItemRequest item : bulkShardRequest.items()) { assertNotNull(item.getPrimaryResponse()); assertTrue(item.getPrimaryResponse().isFailed()); @@ -120,22 +124,20 @@ public void testInferenceNotFound() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( - Map.of( - "field1", - new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), - "field2", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), - "field3", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) - ) + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + "field2", + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + "field3", + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { - items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFieldMap)[0]; } BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setFieldInferenceMetadata(inferenceFields); + request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -150,30 +152,29 @@ public void testManyRandomDocs() 
throws Exception { } int numInferenceFields = randomIntBetween(1, 5); - Map inferenceFieldsMap = new HashMap<>(); + Map inferenceFieldMap = new HashMap<>(); for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); } - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); int numRequests = randomIntBetween(100, 1000); BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFieldMap); originalRequests[id] = res[0]; modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { assertThat(request, instanceOf(BulkShardRequest.class)); BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -192,13 +193,13 @@ public void testManyRandomDocs() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setFieldInferenceMetadata(fieldInferenceMetadata); + original.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @SuppressWarnings("unchecked") - private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { String id = (String) invocationOnMock.getArguments()[0]; @@ -256,20 +257,20 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); return filter; } private static BulkItemRequest[] randomBulkItemRequest( int id, Map modelMap, - FieldInferenceMetadata fieldInferenceMetadata + Map fieldInferenceMap ) { Map docMap = new 
LinkedHashMap<>(); Map inferenceResultsMap = new LinkedHashMap<>(); - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - var model = modelMap.get(entry.getValue().inferenceId()); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + var model = modelMap.get(entry.getInferenceId()); String text = randomAlphaOfLengthBetween(10, 100); docMap.put(field, text); if (model == null) { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 528003e278aeb..8847fb7f7efc1 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -171,15 +171,16 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "another non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -197,6 +198,32 @@ setup: index: test-sparse-index id: doc_1 + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "I am a test" } + - match: { _source.another_inference_field: "I am a teapot" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "I am a test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "I am a teapot" } + - do: update: index: test-sparse-index @@ -211,12 +238,31 @@ setup: index: test-sparse-index id: doc_1 - - 
match: { _source.inference_field: "updated inference test" } - - match: { _source.another_inference_field: "another updated inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "updated inference test" } + - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "bulk inference test", "another_inference_field": "bulk updated inference test"}}' - - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "bulk inference test" } + - match: { _source.another_inference_field: "bulk updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "bulk inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "bulk updated inference test" } --- "Reindex works for semantic_text fields": @@ -268,18 +314,19 @@ setup: index: destination-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- -"Fails for non-existent model": +"Fails for non-existent inference": - do: indices.create: index: incorrect-test-sparse-index @@ -310,3 +357,23 @@ setup: id: doc_1 body: non_inference_field: "non inference test" + +--- +"Updates with script are not allowed": + - do: + bulk: + index: test-sparse-index + body: + - '{"index": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"script": "ctx._source.new_field = \"hello\"", "scripted_upsert": true}' + + - match: { errors: true } + - match: { items.0.update.status: 400 } + - match: { 
items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 27f233436b925..9dc109b3fb81d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -66,3 +66,23 @@ setup: id: doc_1 body: dense_field: "you know, for testing" + +--- +"Inference section contains unreferenced fields": + - do: + catch: /Field \[unknown_field\] is not registered as a \[semantic_text\] field type/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _inference: + unknown_field: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + chunks: + - text: "inference test" + inference: [ 0.1, 0.2, 0.3, 0.4, 0.5 ] + - text: "another inference test" + inference: [ -0.1, -0.2, -0.3, -0.4, -0.5 ]
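The filter above processes inference requests in slices of at most batchSize, chaining each next slice from the completion listener of the current one. The following is a minimal, self-contained sketch of that chaining pattern under simplified assumptions — it is not the production code, and executeAsync stands in for the chunked-inference service call:

import java.util.List;
import java.util.function.BiConsumer;

final class BatchChainingSketch {

    // Runs executeAsync on slices of at most batchSize items; the callback passed to
    // executeAsync schedules the next slice, and onCompletion fires once no items remain.
    static <T> void runBatched(List<T> requests, int batchSize, BiConsumer<List<T>, Runnable> executeAsync, Runnable onCompletion) {
        if (requests.isEmpty()) {
            onCompletion.run(); // mirrors onFinish.close() when nextBatch is empty
            return;
        }
        int currentBatchSize = Math.min(requests.size(), batchSize);
        List<T> currentBatch = requests.subList(0, currentBatchSize);
        List<T> nextBatch = requests.subList(currentBatchSize, requests.size());
        executeAsync.accept(currentBatch, () -> runBatched(nextBatch, batchSize, executeAsync, onCompletion));
    }

    public static void main(String[] args) {
        runBatched(
            List.of("doc_0", "doc_1", "doc_2", "doc_3", "doc_4"),
            2,
            (batch, next) -> {
                System.out.println("inference on " + batch);
                next.run(); // in the real filter this runs from the inference ActionListener
            },
            () -> System.out.println("all batches done")
        );
    }
}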