diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 409fbdd70333e..e0dbc74567053 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -61,6 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -571,7 +572,7 @@ public IndexMetadata randomCreate(String name) { @Override public IndexMetadata randomChange(IndexMetadata part) { IndexMetadata.Builder builder = IndexMetadata.builder(part); - switch (randomIntBetween(0, 2)) { + switch (randomIntBetween(0, 3)) { case 0: builder.settings(Settings.builder().put(part.getSettings()).put(randomSettings(Settings.EMPTY))); break; @@ -585,6 +586,9 @@ public IndexMetadata randomChange(IndexMetadata part) { case 2: builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; + case 3: + builder.putInferenceFields(randomInferenceFields()); + break; default: throw new IllegalArgumentException("Shouldn't be here"); } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 8589e183a150e..9d21e9fe5d794 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ 
b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -166,6 +166,7 @@ static TransportVersion def(int id) { public static final TransportVersion INDEXING_PRESSURE_DOCUMENT_REJECTIONS_COUNT = def(8_625_00_0); public static final TransportVersion ALIAS_ACTION_RESULTS = def(8_626_00_0); public static final TransportVersion HISTOGRAM_AGGS_KEY_SORTED = def(8_627_00_0); + public static final TransportVersion INFERENCE_FIELDS_METADATA = def(8_628_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 22672756bdaf0..529814e83ba38 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -540,6 +540,8 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; + public static final String KEY_INFERENCE_FIELDS = "field_inference"; + public static final String INDEX_STATE_FILE_PREFIX = "state-"; static final TransportVersion SYSTEM_INDEX_FLAG_ADDED = TransportVersions.V_7_10_0; @@ -574,6 +576,8 @@ public Iterator> settings() { @Nullable private final MappingMetadata mapping; + private final ImmutableOpenMap inferenceFields; + private final ImmutableOpenMap customData; private final Map> inSyncAllocationIds; @@ -642,6 +646,7 @@ private IndexMetadata( final int numberOfReplicas, final Settings settings, final MappingMetadata mapping, + final ImmutableOpenMap inferenceFields, final ImmutableOpenMap aliases, final ImmutableOpenMap customData, final Map> inSyncAllocationIds, @@ -692,6 +697,7 @@ private IndexMetadata( this.totalNumberOfShards = numberOfShards * (numberOfReplicas + 1); this.settings = settings; this.mapping = mapping; + this.inferenceFields = inferenceFields; this.customData = customData; this.aliases = aliases; 
this.inSyncAllocationIds = inSyncAllocationIds; @@ -748,6 +754,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.numberOfReplicas, this.settings, mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -806,6 +813,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, Maps.copyMapWithAddedOrReplacedEntry(this.inSyncAllocationIds, shardId, Set.copyOf(inSyncSet)), @@ -862,6 +870,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -918,6 +927,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -970,6 +980,7 @@ public IndexMetadata withIncrementedVersion() { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -1193,6 +1204,10 @@ public MappingMetadata mapping() { return mapping; } + public Map getInferenceFields() { + return inferenceFields; + } + @Nullable public IndexMetadataStats getStats() { return stats; @@ -1403,6 +1418,9 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } + if (inferenceFields.equals(that.inferenceFields) == false) { + return false; + } if (isSystem != that.isSystem) { return false; } @@ -1423,6 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); + result = 31 * result + inferenceFields.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ 
-1469,6 +1488,7 @@ private static class IndexMetadataDiff implements Diff { @Nullable private final Diff settingsDiff; private final Diff> mappings; + private final Diff> inferenceFields; private final Diff> aliases; private final Diff> customData; private final Diff>> inSyncAllocationIds; @@ -1500,6 +1520,7 @@ private static class IndexMetadataDiff implements Diff { : ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, after.mapping).build(), DiffableUtils.getStringKeySerializer() ); + inferenceFields = DiffableUtils.diff(before.inferenceFields, after.inferenceFields, DiffableUtils.getStringKeySerializer()); aliases = DiffableUtils.diff(before.aliases, after.aliases, DiffableUtils.getStringKeySerializer()); customData = DiffableUtils.diff(before.customData, after.customData, DiffableUtils.getStringKeySerializer()); inSyncAllocationIds = DiffableUtils.diff( @@ -1524,6 +1545,8 @@ private static class IndexMetadataDiff implements Diff { new DiffableUtils.DiffableValueReader<>(DiffableStringMap::readFrom, DiffableStringMap::readDiffFrom); private static final DiffableUtils.DiffableValueReader ROLLOVER_INFO_DIFF_VALUE_READER = new DiffableUtils.DiffableValueReader<>(RolloverInfo::new, RolloverInfo::readDiffFrom); + private static final DiffableUtils.DiffableValueReader INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(InferenceFieldMetadata::new, InferenceFieldMetadata::readDiffFrom); IndexMetadataDiff(StreamInput in) throws IOException { index = in.readString(); @@ -1546,6 +1569,15 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { + inferenceFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + 
INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER + ); + } else { + inferenceFields = DiffableUtils.emptyDiff(); + } aliases = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), ALIAS_METADATA_DIFF_VALUE_READER); customData = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), CUSTOM_DIFF_VALUE_READER); inSyncAllocationIds = DiffableUtils.readJdkMapDiff( @@ -1595,6 +1627,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { + inferenceFields.writeTo(out); + } aliases.writeTo(out); customData.writeTo(out); inSyncAllocationIds.writeTo(out); @@ -1628,6 +1663,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.mapping = mappings.apply( ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, part.mapping).build() ).get(MapperService.SINGLE_MAPPING_NAME); + builder.inferenceFields.putAllFromMap(inferenceFields.apply(part.inferenceFields)); builder.aliases.putAllFromMap(aliases.apply(part.aliases)); builder.customMetadata.putAllFromMap(customData.apply(part.customData)); builder.inSyncAllocationIds.putAll(inSyncAllocationIds.apply(part.inSyncAllocationIds)); @@ -1673,6 +1709,10 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); + } int aliasesSize = in.readVInt(); for (int i = 0; i < aliasesSize; i++) { AliasMetadata aliasMd = new AliasMetadata(in); @@ -1733,6 +1773,9 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException mapping.writeTo(out); } } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { + out.writeCollection(inferenceFields.values()); + } out.writeCollection(aliases.values()); out.writeMap(customData, StreamOutput::writeWriteable); out.writeMap( @@ -1788,6 +1831,7 @@ public static class Builder { private 
long[] primaryTerms = null; private Settings settings = Settings.EMPTY; private MappingMetadata mapping; + private final ImmutableOpenMap.Builder inferenceFields; private final ImmutableOpenMap.Builder aliases; private final ImmutableOpenMap.Builder customMetadata; private final Map> inSyncAllocationIds; @@ -1802,6 +1846,7 @@ public static class Builder { public Builder(String index) { this.index = index; + this.inferenceFields = ImmutableOpenMap.builder(); this.aliases = ImmutableOpenMap.builder(); this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); @@ -1819,6 +1864,7 @@ public Builder(IndexMetadata indexMetadata) { this.settings = indexMetadata.getSettings(); this.primaryTerms = indexMetadata.primaryTerms.clone(); this.mapping = indexMetadata.mapping; + this.inferenceFields = ImmutableOpenMap.builder(indexMetadata.inferenceFields); this.aliases = ImmutableOpenMap.builder(indexMetadata.aliases); this.customMetadata = ImmutableOpenMap.builder(indexMetadata.customData); this.routingNumShards = indexMetadata.routingNumShards; @@ -2059,6 +2105,16 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } + public Builder putInferenceField(InferenceFieldMetadata value) { + this.inferenceFields.put(value.getName(), value); + return this; + } + + public Builder putInferenceFields(Map values) { + this.inferenceFields.putAllFromMap(values); + return this; + } + public IndexMetadata build() { return build(false); } @@ -2221,6 +2277,7 @@ IndexMetadata build(boolean repair) { numberOfReplicas, settings, mapping, + inferenceFields.build(), aliasesMap, newCustomMetadata, Map.ofEntries(denseInSyncAllocationIds), @@ -2379,6 +2436,14 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } + if (indexMetadata.getInferenceFields().isEmpty() == false) { + builder.startObject(KEY_INFERENCE_FIELDS); + for 
(InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { + field.toXContent(builder, params); + } + builder.endObject(); + } + builder.endObject(); } @@ -2456,6 +2521,11 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { + private static final String INFERENCE_ID_FIELD = "inference_id"; + private static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private final String name; + private final String inferenceId; + private final String[] sourceFields; + + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this.name = Objects.requireNonNull(name); + this.inferenceId = Objects.requireNonNull(inferenceId); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public InferenceFieldMetadata(StreamInput input) throws IOException { + this.name = input.readString(); + this.inferenceId = input.readString(); + this.sourceFields = input.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeString(inferenceId); + out.writeStringArray(sourceFields); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceFieldMetadata that = (InferenceFieldMetadata) o; + return Objects.equals(name, that.name) + && Objects.equals(inferenceId, that.inferenceId) + && Arrays.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + int result = Objects.hash(name, inferenceId); + result = 31 * result + Arrays.hashCode(sourceFields); + return result; + } + + public String getName() { + return name; + } + + public String getInferenceId() { + return inferenceId; + } + + public String[] getSourceFields() { + return sourceFields; + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); + 
} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(name); + builder.field(INFERENCE_ID_FIELD, inferenceId); + builder.array(SOURCE_FIELDS_FIELD, sourceFields); + return builder.endObject(); + } + + public static InferenceFieldMetadata fromXContent(XContentParser parser) throws IOException { + final String name = parser.currentName(); + + XContentParser.Token token = parser.nextToken(); + Objects.requireNonNull(token, "Expected InferenceFieldMetadata but got EOF"); + + String currentFieldName = null; + String inferenceId = null; + List inputFields = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_STRING) { + if (INFERENCE_ID_FIELD.equals(currentFieldName)) { + inferenceId = parser.text(); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_STRING) { + inputFields.add(parser.text()); + } else { + parser.skipChildren(); + } + } + } + } else { + parser.skipChildren(); + } + } + return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index da24f0b9d0dc5..52642e1de8ac9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1263,10 +1263,11 @@ static IndexMetadata buildIndexMetadata( indexMetadataBuilder.system(isSystem); // now, update the mappings with the 
actual source Map mappingsMetadata = new HashMap<>(); - DocumentMapper mapper = documentMapperSupplier.get(); - if (mapper != null) { - MappingMetadata mappingMd = new MappingMetadata(mapper); - mappingsMetadata.put(mapper.type(), mappingMd); + DocumentMapper docMapper = documentMapperSupplier.get(); + if (docMapper != null) { + MappingMetadata mappingMd = new MappingMetadata(docMapper); + mappingsMetadata.put(docMapper.type(), mappingMd); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 3ca206eaddb28..4e714b96f64c7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -193,9 +193,10 @@ private static ClusterState applyRequest( IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(indexMetadata); // Mapping updates on a single type may have side-effects on other types so we need to // update mapping metadata on all types - DocumentMapper mapper = mapperService.documentMapper(); - if (mapper != null) { - indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); + DocumentMapper docMapper = mapperService.documentMapper(); + if (docMapper != null) { + indexMetadataBuilder.putMapping(new MappingMetadata(docMapper)); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java new file mode 100644 index 0000000000000..2b0833c72021b --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.inference.InferenceService; + +import java.util.Set; + +/** + * Field mapper whose input must be transformed by the {@link InferenceService} before indexing. + */ +public interface InferenceFieldMapper { + + /** + * Retrieve the inference metadata associated with this mapper. + * + * @param sourcePaths The source paths that populate the input for the field (before inference) + */ + InferenceFieldMetadata getMetadata(Set sourcePaths); +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 673593cc6e240..bf879f30e5a29 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -10,9 +10,11 @@ import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.inference.InferenceService; import java.util.ArrayList; import java.util.Collection; @@ -47,6 +49,7 @@ private CacheKey() {} /** Full field name to mapper */ private final Map fieldMappers; private final Map objectMappers; + 
private final Map inferenceFields; private final int runtimeFieldMappersCount; private final NestedLookup nestedLookup; private final FieldTypeLookup fieldTypeLookup; @@ -84,12 +87,12 @@ private static void collect( Collection fieldMappers, Collection fieldAliasMappers ) { - if (mapper instanceof ObjectMapper) { - objectMappers.add((ObjectMapper) mapper); - } else if (mapper instanceof FieldMapper) { - fieldMappers.add((FieldMapper) mapper); - } else if (mapper instanceof FieldAliasMapper) { - fieldAliasMappers.add((FieldAliasMapper) mapper); + if (mapper instanceof ObjectMapper objectMapper) { + objectMappers.add(objectMapper); + } else if (mapper instanceof FieldMapper fieldMapper) { + fieldMappers.add(fieldMapper); + } else if (mapper instanceof FieldAliasMapper fieldAliasMapper) { + fieldAliasMappers.add(fieldAliasMapper); } else { throw new IllegalStateException("Unrecognized mapper type [" + mapper.getClass().getSimpleName() + "]."); } @@ -174,6 +177,15 @@ private MappingLookup( final Collection runtimeFields = mapping.getRoot().runtimeFields(); this.fieldTypeLookup = new FieldTypeLookup(mappers, aliasMappers, runtimeFields); + + Map inferenceFields = new HashMap<>(); + for (FieldMapper mapper : mappers) { + if (mapper instanceof InferenceFieldMapper inferenceFieldMapper) { + inferenceFields.put(mapper.name(), inferenceFieldMapper.getMetadata(fieldTypeLookup.sourcePaths(mapper.name()))); + } + } + this.inferenceFields = Map.copyOf(inferenceFields); + if (runtimeFields.isEmpty()) { // without runtime fields this is the same as the field type lookup this.indexTimeLookup = fieldTypeLookup; @@ -360,6 +372,13 @@ public Map objectMappers() { return objectMappers; } + /** + * Returns a map containing all fields that require running inference (through the {@link InferenceService}) prior to indexing. 
+ */ + public Map inferenceFields() { + return inferenceFields; + } + public NestedLookup nestedLookup() { return nestedLookup; } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 5cc1a7206e7e4..116acf938fcbc 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -83,6 +83,8 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; + Map inferenceFields = randomInferenceFields(); + IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) .creationDate(randomLong()) @@ -107,6 +109,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) + .putInferenceFields(inferenceFields) .build(); assertEquals(system, metadata.isSystem()); @@ -141,6 +144,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getInferenceFields(), fromXContentMeta.getInferenceFields()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -162,8 +166,9 @@ public void testIndexMetadataSerialization() throws IOException { 
assertEquals(metadata.getCustomData(), deserialized.getCustomData()); assertEquals(metadata.isSystem(), deserialized.isSystem()); assertEquals(metadata.getStats(), deserialized.getStats()); - assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); - assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); + assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); + assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields()); } } @@ -547,10 +552,34 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } + public void testInferenceFieldMetadata() { + Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); + assertTrue(idxMeta1.getInferenceFields().isEmpty()); + + Map dynamicFields = randomInferenceFields(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).putInferenceFields(dynamicFields).build(); + assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields)); + } + private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } + public static Map randomInferenceFields() { + Map map = new HashMap<>(); + int numFields = randomIntBetween(0, 5); + for (int i = 0; i < numFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + map.put(field, randomInferenceFieldMetadata(field)); + } + return map; + } + + private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) { + return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)); + } + private IndexMetadataStats randomIndexStats(int 
numberOfShards) { IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards); int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java new file mode 100644 index 0000000000000..bd4c87be51157 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.function.Predicate; + +import static org.hamcrest.Matchers.equalTo; + +public class InferenceFieldMetadataTests extends AbstractXContentTestCase { + + public void testSerialization() throws IOException { + final InferenceFieldMetadata before = createTestItem(); + final BytesStreamOutput out = new BytesStreamOutput(); + before.writeTo(out); + + final StreamInput in = out.bytes().streamInput(); + final InferenceFieldMetadata after = new InferenceFieldMetadata(in); + + assertThat(after, equalTo(before)); + } + + @Override + protected InferenceFieldMetadata createTestInstance() { + return createTestItem(); + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return p -> p.equals(""); // do not add elements at the top-level as any 
element at this level is parsed as a new inference field + } + + @Override + protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { + if (parser.nextToken() == XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + assertEquals(XContentParser.Token.FIELD_NAME, parser.currentToken()); + InferenceFieldMetadata inferenceMetadata = InferenceFieldMetadata.fromXContent(parser); + assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken()); + return inferenceMetadata; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + private static InferenceFieldMetadata createTestItem() { + String name = randomAlphaOfLengthBetween(3, 10); + String inferenceId = randomIdentifier(); + String[] inputFields = generateRandomStringArray(5, 10, false, false); + return new InferenceFieldMetadata(name, inferenceId, inputFields); + } + + public void testNullCtorArgsThrowException() { + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null)); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index c83caa617e16e..e2b03c6b81af3 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -729,6 +729,7 @@ public static IndicesService mockIndicesServices(MappingLookup mappingLookup) th Mapping mapping = new Mapping(root, new MetadataFieldMapper[0], null); DocumentMapper documentMapper = mock(DocumentMapper.class); 
when(documentMapper.mapping()).thenReturn(mapping); + when(documentMapper.mappers()).thenReturn(MappingLookup.EMPTY); when(documentMapper.mappingSource()).thenReturn(mapping.toCompressedXContent()); RoutingFieldMapper routingFieldMapper = mock(RoutingFieldMapper.class); when(routingFieldMapper.required()).thenReturn(false);