diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 3a1f6e20bb288..fbb3016b925da 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -54,7 +54,6 @@ import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -62,6 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -587,33 +587,13 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldsForModels(randomFieldsForModels()); + builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true)); break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } - - /** - * Generates a random fieldsForModels map - */ - private Map> randomFieldsForModels() { - if (randomBoolean()) { - return null; - } - - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } - - return fieldsForModels; - } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 4b7a67e9ca0e3..e80530f75cf4b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; @@ -75,11 +76,13 @@ public static void getInstance( Set shardIds, ActionListener listener ) { - Set inferenceIds = new HashSet<>(); - shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); - inferenceIds.addAll(fieldsForModels.keySet()); - }); + Set inferenceIds = shardIds.stream() + .map(ShardId::getIndex) + .collect(Collectors.toSet()) + .stream() + .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values()) + .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)) + .collect(Collectors.toSet()); final Map inferenceProviderMap = new ConcurrentHashMap<>(); Runnable onModelLoadingComplete = () -> listener.onResponse( new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) @@ -134,11 +137,11 @@ public void processBulkShardRequest( BiConsumer onBulkItemFailure ) { - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); + Map> fieldsForInferenceIds = getFieldsForInferenceIds( + clusterState.metadata().index(bulkShardRequest.shardId().getIndex()).getFieldInferenceMetadata().getFieldInferenceOptions() + ); // No inference fields? Terminate early - if (fieldsForModels.isEmpty()) { + if (fieldsForInferenceIds.isEmpty()) { listener.onResponse(bulkShardRequest); return; } @@ -176,7 +179,7 @@ public void processBulkShardRequest( if (bulkItemRequest != null) { performInferenceOnBulkItemRequest( bulkItemRequest, - fieldsForModels, + fieldsForInferenceIds, i, onBulkItemFailureWithIndex, bulkItemReqRef.acquire() @@ -186,6 +189,22 @@ public void processBulkShardRequest( } } + private static Map> getFieldsForInferenceIds( + Map fieldInferenceMap + ) { + Map> fieldsForInferenceIdsMap = new HashMap<>(); + for (Map.Entry entry : fieldInferenceMap.entrySet()) { + String fieldName = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + + // Get or create the set associated with the inferenceId + Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); + fields.add(fieldName); + } + + return fieldsForInferenceIdsMap; + } + @SuppressWarnings("unchecked") private void performInferenceOnBulkItemRequest( BulkItemRequest bulkItemRequest, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java new file mode 100644 index 0000000000000..349706c139127 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.SimpleDiffable; +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator + * node, which not necessarily has mapping information. + */ +public class FieldInferenceMetadata implements Diffable, ToXContentFragment { + + private final ImmutableOpenMap fieldInferenceOptions; + + public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); + + public FieldInferenceMetadata(MappingLookup mappingLookup) { + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { + builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); + }); + fieldInferenceOptions = builder.build(); + } + + public FieldInferenceMetadata(StreamInput in) throws IOException { + fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); + } + + public FieldInferenceMetadata(Map fieldsToInferenceMap) { + fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); + } + + public ImmutableOpenMap getFieldInferenceOptions() { + return fieldInferenceOptions; + } + + public boolean isEmpty() { + return fieldInferenceOptions.isEmpty(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(fieldInferenceOptions, (o, v) -> v.writeTo(o)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.map(fieldInferenceOptions); + return builder; + } + + public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { + return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FieldInferenceMetadata that = (FieldInferenceMetadata) o; + return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); + } + + @Override + public int hashCode() { + return Objects.hash(fieldInferenceOptions); + } + + @Override + public Diff diff(FieldInferenceMetadata previousState) { + if (previousState == null) { + previousState = EMPTY; + } + return new FieldInferenceMetadataDiff(previousState, this); + } + + static class FieldInferenceMetadataDiff implements Diff { + + public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( + FieldInferenceMetadata.EMPTY, + FieldInferenceMetadata.EMPTY + ); + + private final Diff> fieldInferenceMapDiff; + + private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); + + FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { + fieldInferenceMapDiff = DiffableUtils.diff( + before.fieldInferenceOptions, + after.fieldInferenceOptions, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + FieldInferenceMetadataDiff(StreamInput in) throws IOException { + fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + fieldInferenceMapDiff.writeTo(out); + } + + @Override + public FieldInferenceMetadata apply(FieldInferenceMetadata part) { + return new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); + } + } + + public record FieldInferenceOptions(String inferenceId, Set sourceFields) + implements + SimpleDiffable, + ToXContentFragment { + + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); + + FieldInferenceOptions(StreamInput in) throws IOException { + this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeStringCollection(sourceFields); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); + builder.endObject(); + return builder; + } + + public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field_inference_parser", + false, + (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 81406f0a74ce5..89c925427cf88 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -78,7 +78,6 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -541,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String KEY_FIELD_INFERENCE = "field_inference"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -632,8 +631,7 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - // Key: model ID, Value: Fields that use model - private final ImmutableOpenMap> fieldsForModels; + private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -680,7 +678,7 @@ private IndexMetadata( @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, @Nullable Long shardSizeInBytesForecast, - final ImmutableOpenMap> fieldsForModels + @Nullable FieldInferenceMetadata fieldInferenceMetadata ) { this.index = index; this.version = version; @@ -736,7 +734,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldsForModels = Objects.requireNonNull(fieldsForModels); + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -788,7 +786,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -847,7 +845,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -904,7 +902,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -961,7 +959,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1014,7 +1012,7 @@ public IndexMetadata withIncrementedVersion() { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1218,8 +1216,8 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public Map> getFieldsForModels() { - return fieldsForModels; + public FieldInferenceMetadata getFieldInferenceMetadata() { + return fieldInferenceMetadata; } public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; @@ -1419,7 +1417,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldsForModels.equals(that.fieldsForModels) == false) { + if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { return false; } if (isSystem != that.isSystem) { @@ -1442,7 +1440,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldsForModels.hashCode(); + result = 31 * result + fieldInferenceMetadata.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1498,7 +1496,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff>> fieldsForModels; + private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1535,12 +1533,7 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldsForModels = DiffableUtils.diff( - before.fieldsForModels, - after.fieldsForModels, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); + fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1601,13 +1594,9 @@ private static class IndexMetadataDiff implements Diff { shardSizeInBytesForecast = null; } if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = DiffableUtils.readJdkMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); + fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); } else { - fieldsForModels = DiffableUtils.emptyDiff(); + fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; } } @@ -1645,7 +1634,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels.writeTo(out); + out.writeOptionalWriteable(fieldInferenceMetadata); } } @@ -1676,7 +1665,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); + builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1745,9 +1734,7 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) - ); + builder.fieldInferenceMetadata(new FieldInferenceMetadata(in)); } return builder.build(true); } @@ -1796,7 +1783,7 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + fieldInferenceMetadata.writeTo(out); } } @@ -1847,7 +1834,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private final ImmutableOpenMap.Builder> fieldsForModels; + private FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; public Builder(String index) { this.index = index; @@ -1855,7 +1842,6 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); - this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1880,7 +1866,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); + this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2110,8 +2096,8 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldsForModels(Map> fieldsForModels) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); + public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); return this; } @@ -2310,7 +2296,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels.build() + fieldInferenceMetadata ); } @@ -2436,8 +2422,8 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); + if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { + builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); } builder.endObject(); @@ -2517,18 +2503,8 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) - .entrySet() - .stream() - .collect( - Collectors.toMap( - Map.Entry::getKey, - v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) - ) - ); - builder.fieldsForModels(fieldsForModels); + case KEY_FIELD_INFERENCE: + builder.fieldInferenceMetadata(FieldInferenceMetadata.fromXContent(parser)); break; default: // assume it's custom index metadata @@ -2726,17 +2702,6 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } - - private static void processFieldsForModels( - ImmutableOpenMap.Builder> builder, - Map> fieldsForModels - ) { - builder.clear(); - if (fieldsForModels != null) { - // Ensure that all field sets contained in the processed map are immutable - fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); - } - } } /** diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index d8fe0b0c19e52..96ca7a15edc30 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1267,8 +1267,8 @@ static IndexMetadata buildIndexMetadata( if (mapper != null) { MappingMetadata mappingMd = new MappingMetadata(mapper); mappingsMetadata.put(mapper.type(), mappingMd); - - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); + indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index d913a6465482d..0e31592991369 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -204,7 +204,7 @@ private static ClusterState applyRequest( DocumentMapper mapper = mapperService.documentMapper(); if (mapper != null) { indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + indexMetadataBuilder.fieldInferenceMetadata(new FieldInferenceMetadata(mapper.mappers())); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 372b1412df724..0741cfa682b74 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -39,7 +39,7 @@ final class FieldTypeLookup { /** * A map from inference model ID to all fields that use the model to generate embeddings. */ - private final Map> fieldsForModels; + private final Map inferenceIdsForFields; private final int maxParentPathDots; @@ -53,7 +53,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map> fieldsForModels = new HashMap<>(); + final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -72,11 +72,7 @@ final class FieldTypeLookup { fieldToCopiedFields.get(targetField).add(fieldName); } if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - String inferenceModel = inferenceModelFieldType.getInferenceModel(); - if (inferenceModel != null) { - Set fields = fieldsForModels.computeIfAbsent(inferenceModel, v -> new HashSet<>()); - fields.add(fieldName); - } + inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); } } @@ -110,8 +106,7 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - fieldsForModels.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); - this.fieldsForModels = Map.copyOf(fieldsForModels); + this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -220,8 +215,8 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map> getFieldsForModels() { - return fieldsForModels; + Map getInferenceIdsForFields() { + return inferenceIdsForFields; } /** diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java index 490d7f36219cf..6e12a204ed7d0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java @@ -17,5 +17,5 @@ public interface InferenceModelFieldType { * * @return model id used by the field type */ - String getInferenceModel(); + String getInferenceId(); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index cf2212110a210..c2bd95115f27e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -524,7 +524,7 @@ public void validateDoesNotShadow(String name) { } } - public Map> getFieldsForModels() { - return fieldTypeLookup.getFieldsForModels(); + public Map getInferenceIdsForFields() { + return fieldTypeLookup.getInferenceIdsForFields(); } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 2ce7b161d3dd1..c3887f506b891 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -92,7 +93,7 @@ public class BulkOperationTests extends ESTestCase { public void testNoInference() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry( Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); @@ -116,7 +117,7 @@ public void testNoInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -158,7 +159,7 @@ private static Model mockModel(String inferenceServiceId) { public void testFailedBulkShardRequest() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry(Map.of()); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); @@ -176,7 +177,7 @@ public void testFailedBulkShardRequest() { runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -206,11 +207,15 @@ public void testFailedBulkShardRequest() { @SuppressWarnings("unchecked") public void testInference() { - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry( @@ -244,7 +249,7 @@ public void testInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -279,7 +284,9 @@ public void testInference() { public void testFailedInference() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -298,7 +305,7 @@ public void testFailedInference() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -313,7 +320,9 @@ public void testFailedInference() { public void testInferenceFailsForIncorrectRootObject() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -331,7 +340,7 @@ public void testInferenceFailsForIncorrectRootObject() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -343,11 +352,15 @@ public void testInferenceFailsForIncorrectRootObject() { public void testInferenceIdNotFound() { - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -368,7 +381,7 @@ public void testInferenceIdNotFound() { ActionListener bulkOperationListener = mock(ActionListener.class); doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -444,7 +457,7 @@ public String toString() { private static BulkShardRequest runBulkOperation( Map docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, boolean expectTransportShardBulkActionToExecute, @@ -452,7 +465,7 @@ private static BulkShardRequest runBulkOperation( ) { return runBulkOperation( docSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -463,7 +476,7 @@ private static BulkShardRequest runBulkOperation( private static BulkShardRequest runBulkOperation( Map docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, ActionListener bulkOperationListener, @@ -472,7 +485,7 @@ private static BulkShardRequest runBulkOperation( ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .settings(settings) .numberOfShards(1) .numberOfReplicas(0) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index b2354a4356595..b32873df71365 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.shard.ShardId; @@ -41,7 +42,6 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -84,7 +84,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; - Map> fieldsForModels = randomFieldsForModels(true); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -110,7 +110,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .build(); assertEquals(system, metadata.isSystem()); @@ -145,7 +145,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -169,7 +169,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), deserialized.getStats()); assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata()); } } @@ -553,35 +553,35 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldsForModels() { + public void testFieldInferenceMetadata() { Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); + assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY); - Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); - assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build(); + assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata)); } private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - private static Map> randomFieldsForModels(boolean allowNull) { - if (allowNull && randomBoolean()) { + public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) { + if (randomBoolean() && allowNull) { return null; } - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } + Map fieldInferenceMap = randomMap( + 0, + 10, + () -> new Tuple<>(randomIdentifier(), randomFieldInference()) + ); + return new FieldInferenceMetadata(fieldInferenceMap); + } - return fieldsForModels; + private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() { + return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier)); } private IndexMetadataStats randomIndexStats(int numberOfShards) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 27663edde945c..932eac3e60d27 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -37,7 +37,7 @@ public void testEmpty() { assertNotNull(names); assertThat(names, hasSize(0)); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -48,7 +48,7 @@ public void testAddNewField() { assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -440,11 +440,13 @@ public void testInferenceModelFieldType() { assertEquals(f2.fieldType(), lookup.get("foo2")); assertEquals(f3.fieldType(), lookup.get("foo3")); - Map> fieldsForModels = lookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(2, fieldsForModels.size()); - assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1")); - assertEquals(Set.of("foo3"), fieldsForModels.get("bar2")); + Map inferenceIdsForFields = lookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(3, inferenceIdsForFields.size()); + + assertEquals("bar1", inferenceIdsForFields.get("foo1")); + assertEquals("bar1", inferenceIdsForFields.get("foo2")); + assertEquals("bar2", inferenceIdsForFields.get("foo3")); } private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index f512f5d352a43..bb337d0c61c93 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,7 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -122,8 +121,8 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getFieldsForModels()); - assertTrue(mappingLookup.getFieldsForModels().isEmpty()); + assertNotNull(mappingLookup.getInferenceIdsForFields()); + assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty()); } public void testValidateDoesNotShadow() { @@ -191,7 +190,7 @@ public MetricType getMetricType() { ); } - public void testFieldsForModels() { + public void testInferenceIdsForFields() { MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); MappingLookup mappingLookup = createMappingLookup( Collections.singletonList(new MockFieldMapper(fieldType)), @@ -201,10 +200,10 @@ public void testFieldsForModels() { assertEquals(1, size(mappingLookup.fieldMappers())); assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - Map> fieldsForModels = mappingLookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(1, fieldsForModels.size()); - assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); + Map inferenceIdsForFields = mappingLookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(1, inferenceIdsForFields.size()); + assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name")); } private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java index 854749d6308db..0d21134b5d9a9 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java @@ -39,7 +39,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..d9e18728615ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -108,7 +108,7 @@ public String typeName() { } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 69fa64ffa6d1c..a7d3fcce26116 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -20,8 +20,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Set; public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @@ -35,7 +33,10 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") ); - assertEquals(Map.of("test_model", Set.of("field")), indexService.getMetadata().getFieldsForModels()); + assertEquals( + indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), + "test_model" + ); } public void testAddSemanticTextField() throws Exception { @@ -52,7 +53,10 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals(Map.of("test_model", Set.of("field")), resultingState.metadata().index("test").getFieldsForModels()); + assertEquals( + resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), + "test_model" + ); } private static List singleTask(PutMappingClusterStateUpdateRequest request) {