From 833469cfb0d72f8da0558d5cda01cb3ac607a664 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 17 Jan 2024 18:40:05 -0500 Subject: [PATCH 01/29] Store semantic_text model info in mappings (#103319) Store semantic_text model info in IndexMetadata: On document ingestion, we need to perform inference only once, in the coordinating node. Otherwise, we would be doing inference for each of the shards the document is stored in. The problem with the coordinating node is that it doesn't necessarily hold mapping information if it is not used for storing an index. A pure coordinating node doesn't have any mapping information at all. We need to understand when we need to generate text embeddings on the coordinating node. This means that the model information associated with index fields needs to be efficiently accessed from there. This information needs to be kept up to date with mapping changes, and not be recomputed otherwise. The model / fields information is going to be included as part of the IndexMetadata, to ensure it is communicated to all nodes in the cluster. --- .../cluster/ClusterStateDiffIT.java | 26 ++++- .../org/elasticsearch/TransportVersions.java | 1 + .../cluster/metadata/IndexMetadata.java | 99 +++++++++++++++++-- .../metadata/MetadataCreateIndexService.java | 2 + .../metadata/MetadataMappingService.java | 1 + .../index/mapper/FieldTypeLookup.java | 19 ++++ .../index/mapper/MappingLookup.java | 4 + .../cluster/metadata/IndexMetadataTests.java | 37 ++++++- .../index/mapper/FieldTypeLookupTests.java | 26 +++++ .../index/mapper/MappingLookupTests.java | 19 ++++ .../metadata/DataStreamTestHelper.java | 3 +- .../mapper/MockInferenceModelFieldType.java | 45 +++++++++ .../SemanticTextClusterMetadataTests.java | 54 ++++++++++ .../xpack/ml/LocalStateMachineLearning.java | 6 ++ 14 files changed, 331 insertions(+), 11 deletions(-) create mode 100644 test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index b869b3a90fbce..433b4bdaf5d98 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -55,6 +55,7 @@ import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -571,7 +572,7 @@ public IndexMetadata randomCreate(String name) { @Override public IndexMetadata randomChange(IndexMetadata part) { IndexMetadata.Builder builder = IndexMetadata.builder(part); - switch (randomIntBetween(0, 2)) { + switch (randomIntBetween(0, 3)) { case 0: builder.settings(Settings.builder().put(part.getSettings()).put(randomSettings(Settings.EMPTY))); break; @@ -585,11 +586,34 @@ public IndexMetadata randomChange(IndexMetadata part) { case 2: builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; + case 3: + builder.fieldsForModels(randomFieldsForModels()); + break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } + + /** + * Generates a random fieldsForModels map + */ + private Map> randomFieldsForModels() { + if (randomBoolean()) 
{ + return null; + } + + Map> fieldsForModels = new HashMap<>(); + for (int i = 0; i < randomIntBetween(0, 5); i++) { + Set fields = new HashSet<>(); + for (int j = 0; j < randomIntBetween(1, 4); j++) { + fields.add(randomAlphaOfLengthBetween(4, 10)); + } + fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); + } + + return fieldsForModels; + } }); } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index a730587f32c20..c914eac4927a0 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -188,6 +188,7 @@ static TransportVersion def(int id) { public static final TransportVersion PEERFINDER_REPORTS_PEERS_MASTERS = def(8_575_00_0); public static final TransportVersion ESQL_MULTI_CLUSTERS_ENRICH = def(8_576_00_0); public static final TransportVersion NESTED_KNN_MORE_INNER_HITS = def(8_577_00_0); + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_578_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 83b1c48e69eb9..a95c3e905d5f4 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -78,6 +78,7 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -540,6 +541,8 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; + public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String INDEX_STATE_FILE_PREFIX = "state-"; static final TransportVersion SYSTEM_INDEX_FLAG_ADDED = TransportVersions.V_7_10_0; @@ -629,6 +632,8 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; + // Key: model ID, Value: Fields that use model + private final ImmutableOpenMap> fieldsForModels; private IndexMetadata( final Index index, @@ -674,7 +679,8 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast + @Nullable Long shardSizeInBytesForecast, + final ImmutableOpenMap> fieldsForModels ) { this.index = index; this.version = version; @@ -730,6 +736,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; + this.fieldsForModels = Objects.requireNonNull(fieldsForModels); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -780,7 +787,8 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -838,7 +846,8 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, 
this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -894,7 +903,8 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -950,7 +960,8 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1002,7 +1013,8 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1206,6 +1218,10 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1404,6 +1420,9 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } + if (fieldsForModels.equals(that.fieldsForModels) == false) { + return false; + } if (isSystem != that.isSystem) { return false; } @@ -1424,6 +1443,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); + result = 31 * result + fieldsForModels.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1479,6 +1499,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; + private final Diff>> fieldsForModels; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1515,6 +1536,12 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; + fieldsForModels = DiffableUtils.diff( + before.fieldsForModels, + after.fieldsForModels, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1574,6 +1601,15 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = DiffableUtils.readJdkMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } else { + fieldsForModels = DiffableUtils.emptyDiff(); + } } @Override @@ -1609,6 +1645,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if 
(out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels.writeTo(out); + } } @Override @@ -1638,6 +1677,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); + builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); return builder.build(true); } } @@ -1705,6 +1745,11 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) + ); + } return builder.build(true); } @@ -1751,6 +1796,9 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalDouble(writeLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -1800,6 +1848,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; + private final ImmutableOpenMap.Builder> fieldsForModels; public Builder(String index) { this.index = index; @@ -1807,6 +1856,7 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); + this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1831,6 +1881,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; + this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); } public Builder index(String index) { @@ -2060,6 +2111,11 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } + public Builder fieldsForModels(Map> fieldsForModels) { + processFieldsForModels(this.fieldsForModels, fieldsForModels); + return this; + } + public IndexMetadata build() { return build(false); } @@ -2254,7 +2310,8 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast + shardSizeInBytesForecast, + fieldsForModels.build() ); } @@ -2380,6 +2437,10 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } + if (indexMetadata.fieldsForModels.isEmpty() == false) { + builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); + } + builder.endObject(); } @@ -2457,6 +2518,19 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) + .entrySet() + .stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) + ) + ); + builder.fieldsForModels(fieldsForModels); + break; default: // assume it's custom index metadata builder.putCustom(currentFieldName, parser.mapStrings()); @@ -2653,6 +2727,17 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } + + private static void processFieldsForModels( + 
ImmutableOpenMap.Builder> builder, + Map> fieldsForModels + ) { + builder.clear(); + if (fieldsForModels != null) { + // Ensure that all field sets contained in the processed map are immutable + fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); + } + } } /** diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index da24f0b9d0dc5..d8fe0b0c19e52 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1267,6 +1267,8 @@ static IndexMetadata buildIndexMetadata( if (mapper != null) { MappingMetadata mappingMd = new MappingMetadata(mapper); mappingsMetadata.put(mapper.type(), mappingMd); + + indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 7a2d20d042f84..8d12ebd36c645 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -199,6 +199,7 @@ private static ClusterState applyRequest( DocumentMapper mapper = mapperService.documentMapper(); if (mapper != null) { indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); + indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 2b4eec2bdd565..564e6f903a2ae 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,6 +36,11 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; + /** + * A map from inference model ID to all fields that use the model to generate embeddings. 
+ */ + private final Map> fieldsForModels; + private final int maxParentPathDots; FieldTypeLookup( @@ -48,6 +53,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); + final Map> fieldsForModels = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -65,6 +71,13 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } + if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { + String inferenceModel = inferenceModelFieldType.getInferenceModel(); + if (inferenceModel != null) { + Set fields = fieldsForModels.computeIfAbsent(inferenceModel, v -> new HashSet<>()); + fields.add(fieldName); + } + } } int maxParentPathDots = 0; @@ -97,6 +110,8 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); + fieldsForModels.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); + this.fieldsForModels = Map.copyOf(fieldsForModels); } public static int dotCount(String path) { @@ -205,6 +220,10 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } + Map> getFieldsForModels() { + return fieldsForModels; + } + /** * If field is a leaf multi-field return the path to the parent field. Otherwise, return null. */ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 4880ce5edc204..2c16a0fda9e60 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -498,4 +498,8 @@ public void validateDoesNotShadow(String name) { throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric"); } } + + public Map> getFieldsForModels() { + return fieldTypeLookup.getFieldsForModels(); + } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index b4c9f670f66b6..58b8adcf53538 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -82,6 +83,8 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? 
randomLongBetween(1024, 10240) : null; + Map> fieldsForModels = randomFieldsForModels(true); + IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) .creationDate(randomLong()) @@ -105,6 +108,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) + .fieldsForModels(fieldsForModels) .build(); assertEquals(system, metadata.isSystem()); @@ -138,6 +142,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -159,8 +164,9 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getCustomData(), deserialized.getCustomData()); assertEquals(metadata.isSystem(), deserialized.isSystem()); assertEquals(metadata.getStats(), deserialized.getStats()); - assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); - assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); + assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); + assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); } } @@ -544,10 +550,37 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } + public void testFieldsForModels() { + Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); + assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); + + Map> fieldsForModels = randomFieldsForModels(false); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); + assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); + } + private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } + private static Map> randomFieldsForModels(boolean allowNull) { + if (allowNull && randomBoolean()) { + return null; + } + + Map> fieldsForModels = new HashMap<>(); + for (int i = 0; i < randomIntBetween(0, 5); i++) { + Set fields = new HashSet<>(); + for (int j = 0; j < randomIntBetween(1, 4); j++) { + fields.add(randomAlphaOfLengthBetween(4, 10)); + } + fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); + } + + return fieldsForModels; + } + private IndexMetadataStats randomIndexStats(int numberOfShards) { IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards); int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 3f50b9fdf6621..27663edde945c 100644 --- 
a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -16,6 +16,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import static java.util.Collections.emptyList; @@ -35,6 +36,10 @@ public void testEmpty() { Collection names = lookup.getMatchingFieldNames("foo"); assertNotNull(names); assertThat(names, hasSize(0)); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertTrue(fieldsForModels.isEmpty()); } public void testAddNewField() { @@ -42,6 +47,10 @@ public void testAddNewField() { FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertTrue(fieldsForModels.isEmpty()); } public void testAddFieldAlias() { @@ -421,6 +430,23 @@ public void testRuntimeFieldNameOutsideContext() { } } + public void testInferenceModelFieldType() { + MockFieldMapper f1 = new MockFieldMapper(new MockInferenceModelFieldType("foo1", "bar1")); + MockFieldMapper f2 = new MockFieldMapper(new MockInferenceModelFieldType("foo2", "bar1")); + MockFieldMapper f3 = new MockFieldMapper(new MockInferenceModelFieldType("foo3", "bar2")); + + FieldTypeLookup lookup = new FieldTypeLookup(List.of(f1, f2, f3), emptyList(), emptyList()); + assertEquals(f1.fieldType(), lookup.get("foo1")); + assertEquals(f2.fieldType(), lookup.get("foo2")); + assertEquals(f3.fieldType(), lookup.get("foo3")); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertEquals(2, fieldsForModels.size()); + assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1")); + assertEquals(Set.of("foo3"), fieldsForModels.get("bar2")); + } + private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index 0308dac5fa216..f512f5d352a43 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -121,6 +122,8 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); + assertNotNull(mappingLookup.getFieldsForModels()); + assertTrue(mappingLookup.getFieldsForModels().isEmpty()); } public void testValidateDoesNotShadow() { @@ -188,6 +191,22 @@ public MetricType getMetricType() { ); } + public void testFieldsForModels() { + MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); + MappingLookup mappingLookup = createMappingLookup( + Collections.singletonList(new MockFieldMapper(fieldType)), + emptyList(), + emptyList() + ); + 
assertEquals(1, size(mappingLookup.fieldMappers())); + assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); + + Map> fieldsForModels = mappingLookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertEquals(1, fieldsForModels.size()); + assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); + } + private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { try (TokenStream tok = analyzer.tokenStream(field, new StringReader(""))) { CharTermAttribute term = tok.addAttribute(CharTermAttribute.class); diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index d0b30bff92f3e..99fb21d652d93 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -539,7 +539,7 @@ public static MetadataRolloverService getMetadataRolloverService( AllocationService allocationService = mock(AllocationService.class); when(allocationService.reroute(any(ClusterState.class), any(String.class), any())).then(i -> i.getArguments()[0]); when(allocationService.getShardRoutingRoleStrategy()).thenReturn(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY); - MappingLookup mappingLookup = null; + MappingLookup mappingLookup = MappingLookup.EMPTY; if (dataStream != null) { RootObjectMapper.Builder root = new RootObjectMapper.Builder("_doc", ObjectMapper.Defaults.SUBOBJECTS); root.add( @@ -616,6 +616,7 @@ public static IndicesService mockIndicesServices(MappingLookup mappingLookup) th DocumentMapper documentMapper = mock(DocumentMapper.class); when(documentMapper.mapping()).thenReturn(mapping); when(documentMapper.mappingSource()).thenReturn(mapping.toCompressedXContent()); + when(documentMapper.mappers()).thenReturn(mappingLookup); RoutingFieldMapper routingFieldMapper = mock(RoutingFieldMapper.class); when(routingFieldMapper.required()).thenReturn(false); when(documentMapper.routingFieldMapper()).thenReturn(routingFieldMapper); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java new file mode 100644 index 0000000000000..854749d6308db --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.index.query.SearchExecutionContext; + +import java.util.Map; + +public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private static final String TYPE_NAME = "mock_inference_model_field_type"; + + private final String modelId; + + public MockInferenceModelFieldType(String name, String modelId) { + super(name, false, false, false, TextSearchInfo.NONE, Map.of()); + this.modelId = modelId; + } + + @Override + public String typeName() { + return TYPE_NAME; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException("termQuery not implemented"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.toString(name(), context, format); + } + + @Override + public String getInferenceModel() { + return modelId; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java new file mode 100644 index 0000000000000..47cae14003c70 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingClusterStateUpdateRequest; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexService; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class SemanticTextClusterMetadataTests extends MlSingleNodeTestCase { + public void testCreateIndexWithSemanticTextField() { + final IndexService indexService = createIndex( + "test", + client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") + ); + assertEquals(Map.of("test_model", Set.of("field")), indexService.getMetadata().getFieldsForModels()); + } + + public void testAddSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { "properties": { "field": { "type": "semantic_text", "model_id": "test_model" }}}"""); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + 
singleTask(request) + ); + assertEquals(Map.of("test_model", Set.of("field")), resultingState.metadata().index("test").getFieldsForModels()); + } + + private static List singleTask(PutMappingClusterStateUpdateRequest request) { + return Collections.singletonList(new MetadataMappingService.PutMappingClusterStateUpdateTask(request, ActionListener.running(() -> { + throw new AssertionError("task should not complete publication"); + }))); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 2d7832d747de4..5af3fd527e31e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; @@ -102,6 +103,11 @@ public Map> getTokeniz return mlPlugin.getTokenizers(); } + @Override + public Map getMappers() { + return mlPlugin.getMappers(); + } + /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. * The MachineLearningLicenseTests attempt to create a datafeed referencing this LocalStateMachineLearning object. From 64b4799671229fc0f059271feec518c44d8f3bb3 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 18 Jan 2024 10:47:01 -0500 Subject: [PATCH 02/29] semantic_text inference results indexing (#103978) Adds SemanticTextInferenceResultFieldMapper, which indexes inference results for semantic_text fields. 
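Taken together with the previous commit, this lets the coordinating node use the per-index
fields-for-models map to work out which inference calls a document needs before it is routed to
the shards. The snippet below is only an illustrative sketch and is not part of this patch: the
helper class and method names are hypothetical, and the only real API it relies on is the
IndexMetadata#getFieldsForModels() accessor introduced in the previous commit.

    import org.elasticsearch.cluster.metadata.IndexMetadata;

    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;

    // Hypothetical helper, for illustration only: not part of this change set.
    final class InferenceFieldResolver {

        /**
         * For each inference model id, returns the subset of the document's fields whose
         * embeddings would need to be generated with that model.
         */
        static Map<String, Set<String>> fieldsRequiringInference(IndexMetadata indexMetadata, Set<String> documentFields) {
            Map<String, Set<String>> result = new HashMap<>();
            for (Map.Entry<String, Set<String>> entry : indexMetadata.getFieldsForModels().entrySet()) {
                Set<String> matching = new HashSet<>(entry.getValue());
                matching.retainAll(documentFields); // keep only the fields present in this document
                if (matching.isEmpty() == false) {
                    result.put(entry.getKey(), matching);
                }
            }
            return result;
        }
    }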
--- .../test/nodes.stats/11_indices_metrics.yml | 12 +- .../xpack/ml/MachineLearning.java | 7 + .../ml/mapper/SemanticTextFieldMapper.java | 2 +- ...emanticTextInferenceResultFieldMapper.java | 306 ++++++++++ ...icTextInferenceResultFieldMapperTests.java | 561 ++++++++++++++++++ 5 files changed, 883 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml index b119a1a1d94f3..146f0e5c62bc9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml @@ -413,7 +413,7 @@ - match: { nodes.$node_id.indices.mappings.total_estimated_overhead_in_bytes: 0 } --- -"indices mappings exact count test for indices level": +"indices mappings count test for indices level": - skip: features: [arbitrary_key] @@ -468,7 +468,7 @@ - do: nodes.stats: { metric: _all, level: "indices", human: true } - # In the below assertions, we expect a field count of 26 because the above mapping expects the following: + # In the below assertions, we expect a field count of at least 26 because the above mapping expects the following: # Field mappers (incl. alias fields and object mappers' flattened leaves): # 1. _data_stream_timestamp # 2. _doc_count @@ -498,13 +498,17 @@ # 25. authors.name # Runtime field mappers: # 26. a_source_field + # + # Plugins (which may or may not be loaded depending on the context in which this test is executed) may add additional + # field mappers: + # 27. 
_semantic_text_inference (from ML plugin) - gte: { nodes.$node_id.indices.mappings.total_count: 26 } - is_true: nodes.$node_id.indices.mappings.total_estimated_overhead - gte: { nodes.$node_id.indices.mappings.total_estimated_overhead_in_bytes: 26624 } - - match: { nodes.$node_id.indices.indices.index1.mappings.total_count: 26 } + - gte: { nodes.$node_id.indices.indices.index1.mappings.total_count: 26 } - is_true: nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead - - match: { nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead_in_bytes: 26624 } + - gte: { nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead_in_bytes: 26624 } --- "indices mappings does not exist in shards level": diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 152d8fde8c86c..2452ecb3b02b9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,6 +49,7 @@ import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -365,6 +366,7 @@ import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -2301,4 +2303,9 @@ public Map getMappers() { } return Map.of(); } + + @Override + public Map getMetadataMappers() { + return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java index cf713546a071a..9546bc4ba9add 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using a different field mapper. + * be indexed using {@link SemanticTextInferenceResultFieldMapper}. 
*/ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java new file mode 100644 index 0000000000000..ff224522034bf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java @@ -0,0 +1,306 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.BooleanFieldMapper; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.script.ScriptCompiler; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A mapper for the {@code _semantic_text_inference} field. + *
+ *
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. + * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: + *
+ *
+ * <pre>
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_semantic_text_inference": {
+ *             "my_semantic_text_field": [
+ *                 {
+ *                     "sparse_embedding": {
+ *                         "is_truncated": false,
+ *                         "embedding": {
+ *                             "lucas": 0.05212344,
+ *                             "ty": 0.041213956,
+ *                             "dragon": 0.50991,
+ *                             "type": 0.23241979,
+ *                             "dr": 1.9312073,
+ *                             "##o": 0.2797593
+ *                         }
+ *                     },
+ *                     "text": "these are not the droids you're looking for"
+ *                 }
+ *             ]
+ *         }
+ *     }
+ * }
+ * </pre>
+ * + * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: + *
+ *
+ * <pre>
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_text_field": {
+ *                 "type": "nested",
+ *                 "properties": {
+ *                     "sparse_embedding": {
+ *                         "properties": {
+ *                             "embedding": {
+ *                                 "type": "sparse_vector"
+ *                             }
+ *                         }
+ *                     },
+ *                     "text": {
+ *                         "type": "text",
+ *                         "index": false
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * </pre>
+ */ +public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { + public static final String CONTENT_TYPE = "_semantic_text_inference"; + public static final String NAME = "_semantic_text_inference"; + public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; + public static final String TEXT_SUBFIELD_NAME = "text"; + public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); + + private static final Map, Set> REQUIRED_SUBFIELDS_MAP = Map.of( + List.of(), + Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME), + List.of(SPARSE_VECTOR_SUBFIELD_NAME), + Set.of(SparseEmbeddingResults.Embedding.EMBEDDING) + ); + + private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + + static class SemanticTextInferenceFieldType extends MappedFieldType { + private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + + SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return null; + } + } + + private SemanticTextInferenceResultFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); + } + + parseInferenceResults(context); + } + + private static void parseInferenceResults(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + token); + } + + parseFieldInferenceResults(context, mapperBuilderContext); + } + } + + private static void parseFieldInferenceResults(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) + throws IOException { + + String fieldName = context.parser().currentName(); + Mapper mapper = context.getMapper(fieldName); + if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + throw new DocumentParsingException( + context.parser().getTokenLocation(), + Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + ); + } + + parseFieldInferenceResultsArray(context, mapperBuilderContext, fieldName); + } + + private static void parseFieldInferenceResultsArray( + DocumentParserContext context, + MapperBuilderContext mapperBuilderContext, + String fieldName + ) throws IOException { + XContentParser parser = context.parser(); + NestedObjectMapper nestedObjectMapper = createNestedObjectMapper(context, mapperBuilderContext, fieldName); + + if (parser.nextToken() != 
XContentParser.Token.START_ARRAY) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_ARRAY, got " + parser.currentToken()); + } + + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { + DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); + parseFieldInferenceResultElement(nestedContext, nestedObjectMapper, new LinkedList<>()); + } + } + + private static void parseFieldInferenceResultElement( + DocumentParserContext context, + ObjectMapper objectMapper, + LinkedList subfieldPath + ) throws IOException { + XContentParser parser = context.parser(); + DocumentParserContext childContext = context.createChildContext(objectMapper); + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); + } + + Set visitedSubfields = new HashSet<>(); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + parser.currentToken()); + } + + String currentName = parser.currentName(); + visitedSubfields.add(currentName); + + Mapper childMapper = objectMapper.getMapper(currentName); + if (childMapper == null) { + logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); + advancePastCurrentFieldName(parser); + continue; + } + + if (childMapper instanceof FieldMapper) { + parser.nextToken(); + ((FieldMapper) childMapper).parse(childContext); + } else if (childMapper instanceof ObjectMapper) { + parser.nextToken(); + subfieldPath.push(currentName); + parseFieldInferenceResultElement(childContext, (ObjectMapper) childMapper, subfieldPath); + subfieldPath.pop(); + } else { + // This should never happen, but fail parsing if it does so that it's not a silent failure + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) + ); + } + } + + Set requiredSubfields = REQUIRED_SUBFIELDS_MAP.get(subfieldPath); + if (requiredSubfields != null && visitedSubfields.containsAll(requiredSubfields) == false) { + Set missingSubfields = requiredSubfields.stream() + .filter(s -> visitedSubfields.contains(s) == false) + .collect(Collectors.toSet()); + throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); + } + } + + private static NestedObjectMapper createNestedObjectMapper( + DocumentParserContext context, + MapperBuilderContext mapperBuilderContext, + String fieldName + ) { + IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); + ObjectMapper.Builder sparseVectorMapperBuilder = new ObjectMapper.Builder( + SPARSE_VECTOR_SUBFIELD_NAME, + ObjectMapper.Defaults.SUBOBJECTS + ).add( + new BooleanFieldMapper.Builder(SparseEmbeddingResults.Embedding.IS_TRUNCATED, ScriptCompiler.NONE, false, indexVersionCreated) + ).add(new SparseVectorFieldMapper.Builder(SparseEmbeddingResults.Embedding.EMBEDDING)); + TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( + TEXT_SUBFIELD_NAME, + indexVersionCreated, + context.indexAnalyzers() + ).index(false).store(false); + + NestedObjectMapper.Builder nestedBuilder = new 
NestedObjectMapper.Builder( + fieldName, + context.indexSettings().getIndexVersionCreated() + ); + nestedBuilder.add(sparseVectorMapperBuilder).add(textMapperBuilder); + + return nestedBuilder.build(mapperBuilderContext); + } + + private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME; + + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { + throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java new file mode 100644 index 0000000000000..bde6da7fe8277 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -0,0 +1,561 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.LuceneDocument; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MetadataMapperTestCase; +import org.elasticsearch.index.mapper.NestedLookup; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.LeafNestedDocuments; +import org.elasticsearch.search.NestedDocuments; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; 
+import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.containsString; + +public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List text) { + private SemanticTextInferenceResults { + if (sparseEmbeddingResults.embeddings().size() != text.size()) { + throw new IllegalArgumentException("Sparse embeddings and text must be the same size"); + } + } + } + + private record VisitedChildDocInfo(String path, int sparseVectorDims) {} + + private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} + + @Override + protected String fieldName() { + return SemanticTextInferenceResultFieldMapper.NAME; + } + + @Override + protected boolean isConfigurable() { + return false; + } + + @Override + protected boolean isSupportedOn(IndexVersion version) { + return version.onOrAfter(IndexVersions.ES_VERSION_8_13); // TODO: Switch to ES_VERSION_8_14 when available + } + + @Override + protected void registerParameters(ParameterChecker checker) throws IOException { + + } + + @Override + protected Collection getPlugins() { + return List.of(new MachineLearning(Settings.EMPTY)); + } + + public void testSuccessfulParse() throws IOException { + final String fieldName1 = randomAlphaOfLengthBetween(5, 15); + final String fieldName2 = randomAlphaOfLengthBetween(5, 15); + + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { + addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); + addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); + })); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")), + generateSemanticTextinferenceResults(fieldName2, List.of("d e f")) + ) + ) + ) + ); + + Set visitedChildDocs = new HashSet<>(); + Set expectedVisitedChildDocs = Set.of( + new VisitedChildDocInfo(fieldName1, 2), + new VisitedChildDocInfo(fieldName1, 1), + new VisitedChildDocInfo(fieldName2, 3) + ); + + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); + assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs); + assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + assertEquals(expectedVisitedChildDocs, visitedChildDocs); + + MapperService nestedMapperService = createMapperService(mapping(b -> { + addInferenceResultsNestedMapping(b, fieldName1); + addInferenceResultsNestedMapping(b, fieldName2); + })); + withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + nestedMapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(fieldName1, 0, null), + new SearchHit.NestedIdentity(fieldName1, 1, null), + new SearchHit.NestedIdentity(fieldName2, 0, null) + ); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + 
assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } + + public void testMissingSubfields() throws IOException { + final String fieldName = randomAlphaOfLengthBetween(5, 15); + + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + + { + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + new SparseVectorSubfieldOptions(false, true, true), + true, + null + ) + ) + ) + ); + assertThat( + ex.getMessage(), + containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + "]") + ); + } + { + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + new SparseVectorSubfieldOptions(true, true, true), + false, + null + ) + ) + ) + ); + assertThat( + ex.getMessage(), + containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME + "]") + ); + } + { + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + new SparseVectorSubfieldOptions(false, true, true), + false, + null + ) + ) + ) + ); + assertThat( + ex.getMessage(), + containsString( + "Missing required subfields: [" + + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + ", " + + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME + + "]" + ) + ); + } + { + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + 
List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + new SparseVectorSubfieldOptions(true, false, false), + false, + null + ) + ) + ) + ); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SparseEmbeddingResults.Embedding.EMBEDDING + "]")); + } + } + + public void testExtraSubfields() throws IOException { + final String fieldName = randomAlphaOfLengthBetween(5, 15); + final List semanticTextInferenceResultsList = List.of( + generateSemanticTextinferenceResults(fieldName, List.of("a b")) + ); + + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + + Consumer checkParsedDocument = d -> { + Set visitedChildDocs = new HashSet<>(); + Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2)); + + List luceneDocs = d.docs(); + assertEquals(2, luceneDocs.size()); + assertValidChildDoc(luceneDocs.get(0), d.rootDoc(), visitedChildDocs); + assertEquals(d.rootDoc(), luceneDocs.get(1)); + assertNull(luceneDocs.get(1).getParent()); + assertEquals(expectedVisitedChildDocs, visitedChildDocs); + }; + + { + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + semanticTextInferenceResultsList, + new SparseVectorSubfieldOptions(true, true, true), + true, + Map.of("extra_key", "extra_value") + ) + ) + ); + + checkParsedDocument.accept(doc); + LuceneDocument childDoc = doc.docs().get(0); + assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); + } + { + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + semanticTextInferenceResultsList, + new SparseVectorSubfieldOptions(true, true, true), + true, + Map.of("extra_key", Map.of("k1", "v1")) + ) + ) + ); + + checkParsedDocument.accept(doc); + LuceneDocument childDoc = doc.docs().get(0); + assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); + } + { + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + semanticTextInferenceResultsList, + new SparseVectorSubfieldOptions(true, true, true), + true, + Map.of("extra_key", List.of("v1")) + ) + ) + ); + + checkParsedDocument.accept(doc); + LuceneDocument childDoc = doc.docs().get(0); + assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); + } + { + Map extraSubfields = new HashMap<>(); + extraSubfields.put("extra_key", null); + + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + semanticTextInferenceResultsList, + new SparseVectorSubfieldOptions(true, true, true), + true, + extraSubfields + ) + ) + ); + + checkParsedDocument.accept(doc); + LuceneDocument childDoc = doc.docs().get(0); + assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); + } + } + + public void testMissingSemanticTextMapping() throws IOException { + final String fieldName = randomAlphaOfLengthBetween(5, 15); + + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> {})); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))))) + ) + ); + assertThat( + ex.getMessage(), + containsString( + Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) 
+ ) + ); + } + + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + mappingBuilder.startObject(fieldName); + mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mappingBuilder.field("model_id", modelId); + mappingBuilder.endObject(); + } + + private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List chunks) { + List embeddings = new ArrayList<>(chunks.size()); + for (String chunk : chunks) { + String[] tokens = chunk.split("\\s+"); + List weightedTokens = Arrays.stream(tokens) + .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat())) + .toList(); + + embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + } + + return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks); + } + + private static void addSemanticTextInferenceResults( + XContentBuilder sourceBuilder, + List semanticTextInferenceResults + ) throws IOException { + addSemanticTextInferenceResults( + sourceBuilder, + semanticTextInferenceResults, + new SparseVectorSubfieldOptions(true, true, true), + true, + null + ); + } + + private static void addSemanticTextInferenceResults( + XContentBuilder sourceBuilder, + List semanticTextInferenceResults, + SparseVectorSubfieldOptions sparseVectorSubfieldOptions, + boolean includeTextSubfield, + Map extraSubfields + ) throws IOException { + + Map>> inferenceResultsMap = new HashMap<>(); + for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { + List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); + + Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() + .embeddings() + .iterator(); + Iterator textIterator = semanticTextInferenceResult.text().iterator(); + while (embeddingsIterator.hasNext() && textIterator.hasNext()) { + SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next(); + String text = textIterator.next(); + + Map subfieldMap = new HashMap<>(); + if (sparseVectorSubfieldOptions.include()) { + Map embeddingMap = embedding.asMap(); + if (sparseVectorSubfieldOptions.includeIsTruncated() == false) { + embeddingMap.remove(SparseEmbeddingResults.Embedding.IS_TRUNCATED); + } + if (sparseVectorSubfieldOptions.includeEmbedding() == false) { + embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING); + } + subfieldMap.put(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); + } + if (includeTextSubfield) { + subfieldMap.put(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME, text); + } + if (extraSubfields != null) { + subfieldMap.putAll(extraSubfields); + } + + parsedInferenceResults.add(subfieldMap); + } + + inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), parsedInferenceResults); + } + + sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); + } + + private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { + mappingBuilder.startObject(semanticTextFieldName); + mappingBuilder.field("type", "nested"); + mappingBuilder.startObject("properties"); + mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME); + mappingBuilder.startObject("properties"); + mappingBuilder.startObject(SparseEmbeddingResults.Embedding.EMBEDDING); + 
mappingBuilder.field("type", "sparse_vector"); + mappingBuilder.endObject(); + mappingBuilder.endObject(); + mappingBuilder.endObject(); + mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME); + mappingBuilder.field("type", "text"); + mappingBuilder.field("index", false); + mappingBuilder.endObject(); + mappingBuilder.endObject(); + mappingBuilder.endObject(); + } + + private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { + NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); + assertNotNull(mapper); + + BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (String token : tokens) { + queryBuilder.add( + new BooleanClause( + new TermQuery( + new Term( + path + + "." + + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + "." + + SparseEmbeddingResults.Embedding.EMBEDDING, + token + ) + ), + BooleanClause.Occur.MUST + ) + ); + } + queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); + + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + } + + private static void assertValidChildDoc( + LuceneDocument childDoc, + LuceneDocument expectedParent, + Set visitedChildDocs + ) { + assertEquals(expectedParent, childDoc.getParent()); + visitedChildDocs.add( + new VisitedChildDocInfo( + childDoc.getPath(), + childDoc.getFields( + childDoc.getPath() + + "." + + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + "." + + SparseEmbeddingResults.Embedding.EMBEDDING + ).size() + ) + ); + } + + private static void assertChildLeafNestedDocument( + LeafNestedDocuments leaf, + int advanceToDoc, + int expectedRootDoc, + Set visitedNestedIdentities + ) throws IOException { + + assertNotNull(leaf.advance(advanceToDoc)); + assertEquals(advanceToDoc, leaf.doc()); + assertEquals(expectedRootDoc, leaf.rootDoc()); + assertNotNull(leaf.nestedIdentity()); + visitedNestedIdentities.add(leaf.nestedIdentity()); + } +} From e3b6a657f43e8823b607ad299eae50e46c7a14f3 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Tue, 6 Feb 2024 16:56:46 +0100 Subject: [PATCH 03/29] Move semantic_text field mappers to inference plugin (#105187) --- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferencePlugin.java | 27 ++++++++++++++++++- .../xpack/inference}/SemanticTextFeature.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- ...emanticTextInferenceResultFieldMapper.java | 2 +- .../SemanticTextClusterMetadataTests.java | 13 +++++++-- .../mapper/SemanticTextFieldMapperTests.java | 6 ++--- ...icTextInferenceResultFieldMapperTests.java | 8 +++--- .../xpack/ml/MachineLearning.java | 21 +-------------- .../xpack/ml/LocalStateMachineLearning.java | 6 ----- 10 files changed, 49 insertions(+), 39 deletions(-) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/SemanticTextFeature.java (93%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextFieldMapper.java (98%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextInferenceResultFieldMapper.java (99%) 
rename x-pack/plugin/{ml => inference}/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java (87%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml => inference/src/test/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextFieldMapperTests.java (96%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml => inference/src/test/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextInferenceResultFieldMapperTests.java (98%) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 2d25a48117778..ddd56c758d67c 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -17,6 +17,7 @@ requires org.apache.httpcomponents.httpasyncclient; requires org.apache.httpcomponents.httpcore.nio; requires org.apache.lucene.core; + requires org.elasticsearch.logging; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 4e44929e7ba9b..a83f8bb5f9b5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,6 +21,8 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -29,6 +31,7 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.InferenceRegistryPlugin; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -52,6 +55,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -67,12 +72,19 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceRegistryPlugin { +public class InferencePlugin extends Plugin + implements + ActionPlugin, + ExtensiblePlugin, + SystemIndexPlugin, + InferenceRegistryPlugin, + MapperPlugin { public static final String NAME = "inference"; public 
static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -254,4 +266,17 @@ public InferenceServiceRegistry getInferenceServiceRegistry() { public ModelRegistry getModelRegistry() { return modelRegistry.get(); } + + @Override + public Map getMappers() { + if (SemanticTextFeature.isEnabled()) { + return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + } + return Map.of(); + } + + @Override + public Map getMetadataMappers() { + return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java similarity index 93% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java index f861760803e56..4f2c5c564bcb8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml; +package org.elasticsearch.xpack.inference; import org.elasticsearch.common.util.FeatureFlag; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 9546bc4ba9add..027b85a9a9f45 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java similarity index 99% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index ff224522034bf..5dda6ae3781ab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java similarity index 87% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 47cae14003c70..69fa64ffa6d1c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -13,14 +13,23 @@ import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexService; -import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferencePlugin; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; -public class SemanticTextClusterMetadataTests extends MlSingleNodeTestCase { +public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return List.of(InferencePlugin.class); + } + public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java similarity index 96% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index ccb8f106e4945..a3a705c9cc902 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.index.IndexableField; import org.elasticsearch.common.Strings; @@ -18,7 +18,7 @@ import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -74,7 +74,7 @@ public void testUpdatesToModelIdNotSupported() throws IOException { @Override protected Collection getPlugins() { - return singletonList(new MachineLearning(Settings.EMPTY)); + return singletonList(new InferencePlugin(Settings.EMPTY)); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java similarity index 98% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index bde6da7fe8277..7f13d34986482 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; @@ -37,7 +37,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; import java.util.ArrayList; @@ -78,7 +78,7 @@ protected boolean isConfigurable() { @Override protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_13); // TODO: Switch to ES_VERSION_8_14 when available + return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available } @Override @@ -88,7 +88,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { @Override protected Collection getPlugins() { - return List.of(new MachineLearning(Settings.EMPTY)); + return List.of(new InferencePlugin(Settings.EMPTY)); } public void testSuccessfulParse() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 10b2ed089d632..70a3b9bab49f1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,8 +49,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MetadataFieldMapper; import 
org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -70,7 +68,6 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.IngestPlugin; -import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Platforms; import org.elasticsearch.plugins.Plugin; @@ -366,8 +363,6 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; -import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -487,8 +482,7 @@ public class MachineLearning extends Plugin PersistentTaskPlugin, SearchPlugin, ShutdownAwarePlugin, - ExtensiblePlugin, - MapperPlugin { + ExtensiblePlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; // Endpoints that were deprecated in 7.x can still be called in 8.x using the REST compatibility layer @@ -2298,17 +2292,4 @@ public void signalShutdown(Collection shutdownNodeIds) { mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds); } } - - @Override - public Map getMappers() { - if (SemanticTextFeature.isEnabled()) { - return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); - } - return Map.of(); - } - - @Override - public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); - } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 5af3fd527e31e..2d7832d747de4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; -import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; @@ -103,11 +102,6 @@ public Map> getTokeniz return mlPlugin.getTokenizers(); } - @Override - public Map getMappers() { - return mlPlugin.getMappers(); - } - /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. * The MachineLearningLicenseTests attempt to create a datafeed referencing this LocalStateMachineLearning object. 
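The preceding patch (03/29) moves mapper registration out of the ML plugin: InferencePlugin now also implements MapperPlugin and exposes the semantic_text mappers itself. The following is a minimal sketch of that registration pattern in isolation; the wrapper class name is invented for illustration, and only SemanticTextFieldMapper, SemanticTextInferenceResultFieldMapper, and SemanticTextFeature are taken from the patch.

    import java.util.Map;

    import org.elasticsearch.index.mapper.Mapper;
    import org.elasticsearch.index.mapper.MetadataFieldMapper;
    import org.elasticsearch.plugins.MapperPlugin;
    import org.elasticsearch.plugins.Plugin;
    import org.elasticsearch.xpack.inference.SemanticTextFeature;
    import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
    import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper;

    // Hypothetical stand-in for InferencePlugin, showing only the MapperPlugin hooks added in patch 03/29.
    public class SemanticTextMapperRegistration extends Plugin implements MapperPlugin {

        @Override
        public Map<String, Mapper.TypeParser> getMappers() {
            // The semantic_text field type is only registered while the feature flag is enabled.
            if (SemanticTextFeature.isEnabled()) {
                return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
            }
            return Map.of();
        }

        @Override
        public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
            // The metadata mapper that stores per-field inference results is registered unconditionally.
            return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER);
        }
    }

Note that, as in the patch, only getMappers() is gated on the feature flag; getMetadataMappers() registers the inference-result metadata mapper regardless.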
From ca65a702344566bb223c82e1c21bc81daed9254d Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:38:08 +0100 Subject: [PATCH 04/29] semantic_text - Field inference (#103697) --- .../action/bulk/BulkOperation.java | 114 +++- .../BulkShardRequestInferenceProvider.java | 324 +++++++++ .../action/bulk/TransportBulkAction.java | 36 +- .../bulk/TransportSimulateBulkAction.java | 4 +- .../action/bulk/BulkOperationTests.java | 642 ++++++++++++++++++ ...ActionIndicesThatCannotBeCreatedTests.java | 8 +- .../bulk/TransportBulkActionIngestTests.java | 8 +- .../action/bulk/TransportBulkActionTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 16 +- .../snapshots/SnapshotResiliencyTests.java | 5 +- x-pack/plugin/inference/build.gradle | 12 + ...emanticTextInferenceResultFieldMapper.java | 30 +- ...icTextInferenceResultFieldMapperTests.java | 61 +- .../xpack/inference/InferenceRestIT.java | 41 ++ .../inference/10_semantic_text_inference.yml | 233 +++++++ 15 files changed, 1422 insertions(+), 116 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java create mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 1d95f430d5c7e..2b84ec8746cd2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -35,6 +36,8 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -44,6 +47,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -69,6 +73,8 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -82,6 +88,8 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { 
super(listener); @@ -97,6 +105,8 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -189,7 +199,30 @@ private void executeBulkRequestsByShard(Map> requ return; } - String nodeId = clusterService.localNode().getId(); + BulkShardRequestInferenceProvider.getInstance( + inferenceServiceRegistry, + modelRegistry, + clusterState, + requestsByShard.keySet(), + new ActionListener() { + @Override + public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { + processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); + } + + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error loading inference models", e); + } + } + ); + } + + void processRequestsByShards( + Map> requestsByShard, + ClusterState clusterState, + BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider + ) { Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -197,29 +230,68 @@ private void executeBulkRequestsByShard(Map> requ // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; - try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); + BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); + + Releasable ref = bulkItemRequestCompleteRefCount.acquire(); + final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); + bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { + @Override + public void onResponse(BulkShardRequest inferenceBulkShardRequest) { + executeBulkShardRequest( + inferenceBulkShardRequest, + ActionListener.releaseAfter(ActionListener.noop(), ref), + bulkItemFailedListener + ); + } - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); - } - executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error performing inference", e); + } + }, bulkItemFailedListener); } } } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { + private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + 
bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } + return bulkShardRequest; + } + + // When an item fails, store the failure in the responses array + private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { + final String indexName = itemRequest.index(); + + DocWriteRequest docWriteRequest = itemRequest.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); + } + + private void executeBulkShardRequest( + BulkShardRequest bulkShardRequest, + ActionListener listener, + BiConsumer bulkItemErrorListener + ) { + if (bulkShardRequest.items().length == 0) { + // No requests to execute due to previous errors, terminate early + listener.onResponse(bulkShardRequest); + return; + } + client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -230,19 +302,17 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - releaseOnFinish.close(); + listener.onResponse(bulkShardRequest); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - for (BulkItemRequest request : bulkShardRequest.items()) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + BulkItemRequest[] items = bulkShardRequest.items(); + for (BulkItemRequest item : items) { + bulkItemErrorListener.accept(item, e); } - releaseOnFinish.close(); + listener.onFailure(e); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java new file mode 100644 index 0000000000000..02f905f7cd87a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -0,0 +1,324 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.action.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.TriConsumer; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +/** + * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. + */ +public class BulkShardRequestInferenceProvider { + + // Root field name for storing inference results + public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; + + // Contains the original text for the field + public static final String TEXT_SUBFIELD_NAME = "text"; + + // Contains the inference result when it's a sparse vector + public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; + + private final ClusterState clusterState; + private final Map inferenceProvidersMap; + + private record InferenceProvider(Model model, InferenceService service) { + private InferenceProvider { + Objects.requireNonNull(model); + Objects.requireNonNull(service); + } + } + + BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { + this.clusterState = clusterState; + this.inferenceProvidersMap = inferenceProvidersMap; + } + + public static void getInstance( + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry, + ClusterState clusterState, + Set shardIds, + ActionListener listener + ) { + Set inferenceIds = new HashSet<>(); + shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { + var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); + inferenceIds.addAll(fieldsForModels.keySet()); + }); + final Map inferenceProviderMap = new ConcurrentHashMap<>(); + Runnable onModelLoadingComplete = () -> listener.onResponse( + new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) + ); + try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { + for (var inferenceId : inferenceIds) { + ActionListener modelLoadingListener = new ActionListener<>() { + @Override + public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + InferenceProvider inferenceProvider = new InferenceProvider( + service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()), + service.get() + ); + inferenceProviderMap.put(inferenceId, inferenceProvider); + } + } + + @Override + 
public void onFailure(Exception e) { + // Failure on loading a model should not prevent the rest from being loaded and used. + // When the model is actually retrieved via the inference ID in the inference process, it will fail + // and the user will get the details on the inference failure. + } + }; + + modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); + } + } + } + + /** + * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from + * the original request will be modified with the inference results, to avoid copying the entire requests from + * the original bulk request. + * + * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. + * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, + * which may have fewer items than the original because of inference failures + * @param onBulkItemFailure invoked when a bulk item fails inference + */ + public void processBulkShardRequest( + BulkShardRequest bulkShardRequest, + ActionListener listener, + BiConsumer onBulkItemFailure + ) { + + Map> fieldsForModels = clusterState.metadata() + .index(bulkShardRequest.shardId().getIndex()) + .getFieldsForModels(); + // No inference fields? Terminate early + if (fieldsForModels.isEmpty()) { + listener.onResponse(bulkShardRequest); + return; + } + + Set failedItems = Collections.synchronizedSet(new HashSet<>()); + Runnable onInferenceComplete = () -> { + if (failedItems.isEmpty()) { + listener.onResponse(bulkShardRequest); + return; + } + // Remove failed items from the original bulk shard request + BulkItemRequest[] originalItems = bulkShardRequest.items(); + BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; + for (int i = 0, j = 0; i < originalItems.length; i++) { + if (failedItems.contains(i) == false) { + newItems[j++] = originalItems[i]; + } + } + BulkShardRequest newBulkShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + newItems + ); + listener.onResponse(newBulkShardRequest); + }; + TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> { + failedItems.add(i); + onBulkItemFailure.accept(bulkItemRequest, e); + }; + try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { + BulkItemRequest[] items = bulkShardRequest.items(); + for (int i = 0; i < items.length; i++) { + BulkItemRequest bulkItemRequest = items[i]; + // Bulk item might be null because of previous errors, skip in that case + if (bulkItemRequest != null) { + performInferenceOnBulkItemRequest( + bulkItemRequest, + fieldsForModels, + i, + onBulkItemFailureWithIndex, + bulkItemReqRef.acquire() + ); + } + } + } + } + + @SuppressWarnings("unchecked") + private void performInferenceOnBulkItemRequest( + BulkItemRequest bulkItemRequest, + Map> fieldsForModels, + Integer itemIndex, + TriConsumer onBulkItemFailure, + Releasable releaseOnFinish + ) { + + DocWriteRequest docWriteRequest = bulkItemRequest.request(); + Map sourceMap = null; + if (docWriteRequest instanceof IndexRequest indexRequest) { + sourceMap = indexRequest.sourceAsMap(); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); + } + if (sourceMap == null || sourceMap.isEmpty()) { + releaseOnFinish.close(); + return; + } + final Map docMap = new ConcurrentHashMap<>(sourceMap); + + // When a document completes processing, update the source with the inference + try (var docRef = new RefCountingRunnable(() -> { + if (docWriteRequest instanceof IndexRequest indexRequest) { + indexRequest.source(docMap); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + if (updateRequest.docAsUpsert()) { + updateRequest.upsertRequest().source(docMap); + } else { + updateRequest.doc().source(docMap); + } + } + releaseOnFinish.close(); + })) { + + Map rootInferenceFieldMap; + try { + rootInferenceFieldMap = (Map) docMap.computeIfAbsent( + ROOT_INFERENCE_FIELD, + k -> new HashMap() + ); + } catch (ClassCastException e) { + onBulkItemFailure.apply( + bulkItemRequest, + itemIndex, + new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object") + ); + return; + } + + for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { + String modelId = fieldModelsEntrySet.getKey(); + List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); + if (inferenceFieldNames.isEmpty()) { + continue; + } + + InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); + if (inferenceProvider == null) { + onBulkItemFailure.apply( + bulkItemRequest, + itemIndex, + new IllegalArgumentException("No inference provider found for model ID " + modelId) + ); + return; + } + ActionListener inferenceResultsListener = new ActionListener<>() { + @Override + public void onResponse(InferenceServiceResults results) { + if (results == null) { + onBulkItemFailure.apply( + bulkItemRequest, + itemIndex, + new IllegalArgumentException( + "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ) + ); + } + + int i = 0; + for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { + String fieldName = inferenceFieldNames.get(i++); + List> inferenceFieldResultList; + try { + inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new ArrayList<>() + ); + } catch (ClassCastException e) { + onBulkItemFailure.apply( + bulkItemRequest, + itemIndex, + new IllegalArgumentException( + "Inference result field [" + ROOT_INFERENCE_FIELD + "." 
+ fieldName + "] is not an object" + ) + ); + return; + } + // Remove previous inference results if any + inferenceFieldResultList.clear(); + + // TODO Check inference result type to change subfield name + var inferenceFieldMap = Map.of( + SPARSE_VECTOR_SUBFIELD_NAME, + inferenceResults.asMap("output").get("output"), + TEXT_SUBFIELD_NAME, + docMap.get(fieldName) + ); + inferenceFieldResultList.add(inferenceFieldMap); + } + } + + @Override + public void onFailure(Exception e) { + onBulkItemFailure.apply(bulkItemRequest, itemIndex, e); + } + }; + inferenceProvider.service() + .infer( + inferenceProvider.model, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + // TODO check for additional settings needed + Map.of(), + InputType.INGEST, + ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) + ); + } + } + } + + private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { + List inferenceFieldNames = new ArrayList<>(); + for (String inferenceField : fieldModelsEntrySet.getValue()) { + Object fieldValue = docMap.get(inferenceField); + + // Perform inference on string, non-null values + if (fieldValue instanceof String) { + inferenceFieldNames.add(inferenceField); + } + } + return inferenceFieldNames; + } +} diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 32566b559410d..86fc511251553 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,6 +57,8 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -98,6 +100,8 @@ public class TransportBulkAction extends HandledTransportAction responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -614,10 +630,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { new BulkOperation( task, @@ -631,6 +647,8 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, + modelRegistry, + inferenceServiceRegistry, listener ).run(); } diff --git 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index f65d0f462fde6..c8dc3e7b7ffd5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -58,7 +58,9 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime + System::nanoTime, + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java new file mode 100644 index 0000000000000..f8ed331d358b2 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -0,0 +1,642 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexAbstraction; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.service.ClusterApplierService; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; +import static org.hamcrest.CoreMatchers.containsString; +import 
static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class BulkOperationTests extends ESTestCase { + + private static final String INDEX_NAME = "test-index"; + private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; + private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; + private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; + private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; + private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; + private static final String SERVICE_1_ID = "elser_v2"; + private static final String SERVICE_2_ID = "e5"; + private static final String INFERENCE_FAILED_MSG = "Inference failed"; + private static TestThreadPool threadPool; + + public void testNoInference() { + + Map> fieldsForModels = Map.of(); + ModelRegistry modelRegistry = createModelRegistry( + Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) + ); + + Model model1 = mock(Model.class); + InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); + Model model2 = mock(Model.class); + InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( + Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) + ); + + Map originalSource = Map.of( + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + true, + bulkOperationListener + ); + verify(bulkOperationListener).onResponse(any()); + + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(1)); + + Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); + // Original doc source is preserved + originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); + + // Check inference not invoked + verifyNoMoreInteractions(modelRegistry); + verifyNoMoreInteractions(inferenceServiceRegistry); + } + + public void testFailedBulkShardRequest() { + + Map> fieldsForModels = Map.of(); + ModelRegistry modelRegistry = createModelRegistry(Map.of()); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); + + Map originalSource = Map.of( + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + ArgumentCaptor bulkResponseCaptor = 
ArgumentCaptor.forClass(BulkResponse.class); + doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + + runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener, + true, + request -> new BulkShardResponse( + request.shardId(), + new BulkItemResponse[] { + BulkItemResponse.failure( + 0, + DocWriteRequest.OpType.INDEX, + new BulkItemResponse.Failure( + INDEX_NAME, + randomIdentifier(), + new IllegalArgumentException("Error on bulk shard request") + ) + ) } + ) + ); + verify(bulkOperationListener).onResponse(any()); + + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse[] items = bulkResponse.getItems(); + assertTrue(items[0].isFailed()); + } + + @SuppressWarnings("unchecked") + public void testInference() { + + Map> fieldsForModels = Map.of( + INFERENCE_SERVICE_1_ID, + Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), + INFERENCE_SERVICE_2_ID, + Set.of(INFERENCE_FIELD_SERVICE_2) + ); + + ModelRegistry modelRegistry = createModelRegistry( + Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) + ); + + Model model1 = mock(Model.class); + InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); + Model model2 = mock(Model.class); + InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( + Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) + ); + + String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); + String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); + String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + firstInferenceTextService1, + SECOND_INFERENCE_FIELD_SERVICE_1, + secondInferenceTextService1, + INFERENCE_FIELD_SERVICE_2, + inferenceTextService2, + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + true, + bulkOperationListener + ); + verify(bulkOperationListener).onResponse(any()); + + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(1)); + + Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); + // Original doc source is preserved + originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); + + // Check inference results + verifyInferenceServiceInvoked( + modelRegistry, + INFERENCE_SERVICE_1_ID, + inferenceService1, + model1, + List.of(firstInferenceTextService1, secondInferenceTextService1) + ); + verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); + checkInferenceResults( + originalSource, + writtenDocSource, + FIRST_INFERENCE_FIELD_SERVICE_1, + SECOND_INFERENCE_FIELD_SERVICE_1, + INFERENCE_FIELD_SERVICE_2 + ); + } + + public void testFailedInference() { + + Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + + ModelRegistry 
modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + firstInferenceTextService1, + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); + + verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); + + } + + public void testInferenceFailsForIncorrectRootObject() { + + Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + + ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + randomAlphaOfLengthBetween(1, 100), + ROOT_INFERENCE_FIELD, + "incorrect_root_object" + ); + + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); + } + + public void testInferenceFailsForIncorrectInferenceFieldObject() { + + Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + + ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + randomAlphaOfLengthBetween(1, 100), + ROOT_INFERENCE_FIELD, + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value") + ); + + ArgumentCaptor 
bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat( + item.getFailure().getCause().getMessage(), + containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object") + ); + } + + public void testInferenceIdNotFound() { + + Map> fieldsForModels = Map.of( + INFERENCE_SERVICE_1_ID, + Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), + INFERENCE_SERVICE_2_ID, + Set.of(INFERENCE_FIELD_SERVICE_2) + ); + + ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + Map originalSource = Map.of( + INFERENCE_FIELD_SERVICE_2, + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat( + item.getFailure().getCause().getMessage(), + equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) + ); + } + + @SuppressWarnings("unchecked") + private static void checkInferenceResults( + Map docSource, + Map writtenDocSource, + String... 
inferenceFieldNames + ) { + + Map inferenceRootResultField = (Map) writtenDocSource.get( + BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD + ); + + for (String inferenceFieldName : inferenceFieldNames) { + List> inferenceService1FieldResults = (List>) inferenceRootResultField.get( + inferenceFieldName + ); + assertNotNull(inferenceService1FieldResults); + assertThat(inferenceService1FieldResults.size(), equalTo(1)); + Map inferenceResultElement = inferenceService1FieldResults.get(0); + assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); + assertThat( + inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), + equalTo(docSource.get(inferenceFieldName)) + ); + } + } + + private static void verifyInferenceServiceInvoked( + ModelRegistry modelRegistry, + String inferenceService1Id, + InferenceService inferenceService, + Model model, + Collection inferenceTexts + ) { + verify(modelRegistry).getModel(eq(inferenceService1Id), any()); + verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); + verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); + verifyNoMoreInteractions(inferenceService); + } + + private static ArgumentMatcher> containsInAnyOrder(Collection expected) { + return new ArgumentMatcher<>() { + @Override + public boolean matches(List argument) { + return argument.containsAll(expected) && argument.size() == expected.size(); + } + + @Override + public String toString() { + return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; + } + }; + } + + private static BulkShardRequest runBulkOperation( + Map docSource, + Map> fieldsForModels, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry, + boolean expectTransportShardBulkActionToExecute, + ActionListener bulkOperationListener + ) { + return runBulkOperation( + docSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener, + expectTransportShardBulkActionToExecute, + successfulBulkShardResponse + ); + } + + private static BulkShardRequest runBulkOperation( + Map docSource, + Map> fieldsForModels, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry, + ActionListener bulkOperationListener, + boolean expectTransportShardBulkActionToExecute, + Function bulkShardResponseSupplier + ) { + Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); + IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) + .fieldsForModels(fieldsForModels) + .settings(settings) + .numberOfShards(1) + .numberOfReplicas(0) + .build(); + ClusterService clusterService = createClusterService(indexMetadata); + + IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class); + when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata)); + + BulkRequest bulkRequest = new BulkRequest(); + bulkRequest.add(new IndexRequest(INDEX_NAME).source(docSource)); + + NodeClient client = mock(NodeClient.class); + + ArgumentCaptor bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class); + doAnswer(invocation -> { + BulkShardRequest request = invocation.getArgument(1); + ActionListener bulkShardResponseListener = invocation.getArgument(2); + 
bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); + return null; + }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); + + Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); + BulkOperation bulkOperation = new BulkOperation( + task, + threadPool, + ThreadPool.Names.WRITE, + clusterService, + bulkRequest, + client, + new AtomicArray<>(bulkRequest.requests.size()), + new HashMap<>(), + indexResolver, + () -> System.nanoTime(), + System.nanoTime(), + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + + bulkOperation.doRun(); + if (expectTransportShardBulkActionToExecute) { + verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); + return bulkShardRequestCaptor.getValue(); + } + + return null; + } + + private static final Function successfulBulkShardResponse = (request) -> { + return new BulkShardResponse( + request.shardId(), + Arrays.stream(request.items()) + .filter(Objects::nonNull) + .map( + item -> BulkItemResponse.success( + item.id(), + DocWriteRequest.OpType.INDEX, + new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) + ) + ) + .toArray(BulkItemResponse[]::new) + ); + }; + + private static InferenceService createInferenceService(Model model, String inferenceServiceId) { + InferenceService inferenceService = mock(InferenceService.class); + when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); + List texts = invocation.getArgument(1); + List inferenceResults = new ArrayList<>(); + for (int i = 0; i < texts.size(); i++) { + inferenceResults.add(createInferenceResults()); + } + doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); + + listener.onResponse(inferenceServiceResults); + return null; + }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); + return inferenceService; + } + + private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) { + InferenceService inferenceService = mock(InferenceService.class); + when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); + return null; + }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); + return inferenceService; + } + + private static InferenceResults createInferenceResults() { + InferenceResults inferenceResults = mock(InferenceResults.class); + when(inferenceResults.asMap(any())).then( + invocation -> Map.of( + (String) invocation.getArguments()[0], + Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat()))) + ) + ); + return inferenceResults; + } + + private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + inferenceServices.forEach((id, service) -> 
when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); + return inferenceServiceRegistry; + } + + private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + // Fails for unknown inference ids + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IllegalArgumentException("Model not found")); + return null; + }).when(modelRegistry).getModel(any(), any()); + inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { + ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceId, + emptyMap(), + emptyMap() + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModel(eq(inferenceId), any()); + }); + + return modelRegistry; + } + + private static ClusterService createClusterService(IndexMetadata indexMetadata) { + Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); + + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); + + ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); + when(clusterService.state()).thenReturn(clusterState); + + ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); + when(clusterApplierService.state()).thenReturn(clusterState); + when(clusterApplierService.threadPool()).thenReturn(threadPool); + when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); + return clusterService; + } + + @BeforeClass + public static void createThreadPool() { + threadPool = new TestThreadPool(getTestClass().getName()); + } + + @AfterClass + public static void stopThreadPool() { + if (threadPool != null) { + threadPool.shutdownNow(); + threadPool = null; + } + } + +} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 3057b00553a22..988a92352649a 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -129,17 +129,19 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null, + null ) { @Override void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 6815d634292a4..2d6492e4e73a4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ 
b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -148,7 +148,9 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null, + null ); } @@ -157,10 +159,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 1a16d9083df55..ad522e36f9bd9 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -98,7 +98,9 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index cb9bdd1f3a827..a2e54a1c7c3b8 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); + super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); + super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); } }; } @@ -253,7 +253,9 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider + relativeTimeProvider, + null, + null ); } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 0d38cdfcafd2b..2b05914a62879 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1941,13 +1941,16 @@ protected void 
assertSnapshotOrGenericThread() { client, null, () -> DocumentParsingObserver.EMPTY_INSTANCE + ), mockFeatureService, client, actionFilters, indexNameExpressionResolver, new IndexingPressure(settings), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null, + null ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index e4f4de0027073..781261c330e78 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -6,6 +6,13 @@ */ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' + +restResources { + restApi { + include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex' + } +} esplugin { name 'x-pack-inference' @@ -24,4 +31,9 @@ dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) testImplementation project(':modules:reindex') + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') +} + +tasks.named('yamlRestTest') { + usesDefaultDistribution() } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 5dda6ae3781ab..9e6c1eb0a6586 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -42,6 +42,9 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; + /** * A mapper for the {@code _semantic_text_inference} field. *
@@ -58,15 +61,12 @@ * "my_semantic_text_field": [ * { * "sparse_embedding": { - * "is_truncated": false, - * "embedding": { - * "lucas": 0.05212344, - * "ty": 0.041213956, - * "dragon": 0.50991, - * "type": 0.23241979, - * "dr": 1.9312073, - * "##o": 0.2797593 - * } + * "lucas": 0.05212344, + * "ty": 0.041213956, + * "dragon": 0.50991, + * "type": 0.23241979, + * "dr": 1.9312073, + * "##o": 0.2797593 * }, * "text": "these are not the droids you're looking for" * } @@ -87,11 +87,7 @@ * "type": "nested", * "properties": { * "sparse_embedding": { - * "properties": { - * "embedding": { - * "type": "sparse_vector" - * } - * } + * "type": "sparse_vector" * }, * "text": { * "type": "text", @@ -107,15 +103,11 @@ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { public static final String CONTENT_TYPE = "_semantic_text_inference"; public static final String NAME = "_semantic_text_inference"; - public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; - public static final String TEXT_SUBFIELD_NAME = "text"; public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); private static final Map, Set> REQUIRED_SUBFIELDS_MAP = Map.of( List.of(), - Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME), - List.of(SPARSE_VECTOR_SUBFIELD_NAME), - Set.of(SparseEmbeddingResults.Embedding.EMBEDDING) + Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME) ); private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 7f13d34986482..aa2ad72941e0e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -51,6 +51,8 @@ import java.util.Set; import java.util.function.Consumer; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { @@ -212,10 +214,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + "]") - ); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]")); } { DocumentParsingException ex = expectThrows( @@ -233,10 +232,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME + "]") - ); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]")); } { DocumentParsingException ex = expectThrows( @@ -256,32 +252,8 @@ public void testMissingSubfields() throws IOException { ); assertThat( ex.getMessage(), - containsString( - "Missing required subfields: [" - + 
SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME - + ", " - + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME - + "]" - ) - ); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), - new SparseVectorSubfieldOptions(true, false, false), - false, - null - ) - ) - ) + containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + ", " + TEXT_SUBFIELD_NAME + "]") ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SparseEmbeddingResults.Embedding.EMBEDDING + "]")); } } @@ -460,10 +432,10 @@ private static void addSemanticTextInferenceResults( if (sparseVectorSubfieldOptions.includeEmbedding() == false) { embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING); } - subfieldMap.put(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); + subfieldMap.put(SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); } if (includeTextSubfield) { - subfieldMap.put(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME, text); + subfieldMap.put(TEXT_SUBFIELD_NAME, text); } if (extraSubfields != null) { subfieldMap.putAll(extraSubfields); @@ -482,14 +454,14 @@ private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuil mappingBuilder.startObject(semanticTextFieldName); mappingBuilder.field("type", "nested"); mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME); + mappingBuilder.startObject(SPARSE_VECTOR_SUBFIELD_NAME); mappingBuilder.startObject("properties"); mappingBuilder.startObject(SparseEmbeddingResults.Embedding.EMBEDDING); mappingBuilder.field("type", "sparse_vector"); mappingBuilder.endObject(); mappingBuilder.endObject(); mappingBuilder.endObject(); - mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME); + mappingBuilder.startObject(TEXT_SUBFIELD_NAME); mappingBuilder.field("type", "text"); mappingBuilder.field("index", false); mappingBuilder.endObject(); @@ -507,14 +479,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook queryBuilder.add( new BooleanClause( new TermQuery( - new Term( - path - + "." - + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME - + "." - + SparseEmbeddingResults.Embedding.EMBEDDING, - token - ) + new Term(path + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token) ), BooleanClause.Occur.MUST ) @@ -535,11 +500,7 @@ private static void assertValidChildDoc( new VisitedChildDocInfo( childDoc.getPath(), childDoc.getFields( - childDoc.getPath() - + "." - + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME - + "." - + SparseEmbeddingResults.Embedding.EMBEDDING + childDoc.getPath() + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." 
+ SparseEmbeddingResults.Embedding.EMBEDDING ).size() ) );
diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
new file mode 100644
index 0000000000000..933e696d29d83
--- /dev/null
+++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
+import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
+import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
+import org.junit.ClassRule;
+
+public class InferenceRestIT extends ESClientYamlSuiteTestCase {
+
+    @ClassRule
+    public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+        .setting("xpack.security.enabled", "false")
+        .setting("xpack.security.http.ssl.enabled", "false")
+        .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin")
+        .distribution(DistributionType.DEFAULT)
+        .build();
+
+    public InferenceRestIT(final ClientYamlTestCandidate testCandidate) {
+        super(testCandidate);
+    }
+
+    @Override
+    protected String getTestRestCluster() {
+        return cluster.getHttpAddresses();
+    }
+
+    @ParametersFactory
+    public static Iterable parameters() throws Exception {
+        return ESClientYamlSuiteTestCase.createParameters();
+    }
+}
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
new file mode 100644
index 0000000000000..0e1b33252153b
--- /dev/null
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
@@ -0,0 +1,233 @@
+setup:
+  - skip:
+      version: " - 8.12.99"
+      reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released
+
+  - do:
+      inference.put_model:
+        task_type: sparse_embedding
+        inference_id: test-inference-id
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                model_id: test-inference-id
+              another_inference_field:
+                type: semantic_text
+                model_id: test-inference-id
+              non_inference_field:
+                type: text
+
+---
+"Calculates embeddings for new documents":
+  - do:
+      index:
+        index: test-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - match: { _source.inference_field: "inference test" }
+  - match: { _source.another_inference_field: "another inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" }
+  - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" }
+
+  - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding
+  - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding
+
+---
+"Updating non semantic_text fields does not recalculate embeddings":
+  - do:
+      index:
+        index: test-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding }
+  - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding }
+
+  - do:
+      update:
+        index: test-index
+        id: doc_1
+        body:
+          doc:
+            non_inference_field: "another non inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - match: { _source.inference_field: "inference test" }
+  - match: { _source.another_inference_field: "another inference test" }
+  - match: { _source.non_inference_field: "another non inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" }
+  - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding }
+  - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding }
+
+---
+"Updating semantic_text fields recalculates embeddings":
+  - do:
+      index:
+        index: test-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - do:
+      update:
+        index: test-index
+        id: doc_1
+        body:
+          doc:
+            inference_field: "updated inference test"
+            another_inference_field: "another updated inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - match: { _source.inference_field: "updated inference test" }
+  - match: { _source.another_inference_field: "another updated inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" }
+  - match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" }
+
+
+---
+"Reindex works for semantic_text fields":
+  - do:
+      index:
+        index: test-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          another_inference_field: "another inference test"
+          non_inference_field: "non inference test"
+
+  - do:
+      get:
+        index: test-index
+        id: doc_1
+
+  - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding }
+  - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding }
+
+  - do:
+      indices.refresh: { }
+
+  - do:
+      indices.create:
+        index: destination-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                model_id: test-inference-id
+              another_inference_field:
+                type: semantic_text
+                model_id: test-inference-id
+              non_inference_field:
+                type: text
+
+  - do:
+      reindex:
+        wait_for_completion: true
+        body:
+          source:
+            index: test-index
+          dest:
+            index: destination-index
+  - do:
+      get:
+        index: destination-index
+        id: doc_1
+
+  - match: { _source.inference_field: "inference test" }
+  - match: { _source.another_inference_field: "another inference test" }
+  - match: { _source.non_inference_field: "non inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" }
+  - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" }
+
+  - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding }
+  - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding }
+
+---
+"Fails for non-existent model":
+  - do:
+      indices.create:
+        index: incorrect-test-index
+        body:
+          mappings:
+            properties:
+              inference_field:
+                type: semantic_text
+                model_id: non-existing-inference-id
+              non_inference_field:
+                type: text
+
+  - do:
+      catch: bad_request
+      index:
+        index: incorrect-test-index
+        id: doc_1
+        body:
+          inference_field: "inference test"
+          non_inference_field: "non inference test"
+
+  - match: { error.reason: "No inference provider found for model ID non-existing-inference-id" }
+
+  # Succeeds when semantic_text field is not used
+  - do:
+      index:
+        index: incorrect-test-index
+        id: doc_1
+        body:
+          non_inference_field: "non inference test"

From b1a3ee864d7b653448d6764d3ca41cdb8e06414c Mon Sep 17 00:00:00 2001
From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com>
Date: Wed, 6 Mar 2024 15:28:38 +0100
Subject: [PATCH 05/29] Semantic text dense vector support (#105515)

---
 .../BulkShardRequestInferenceProvider.java    |  59 +++----
 .../vectors/DenseVectorFieldMapper.java       |  12 +-
 .../inference/InferenceServiceResults.java    |   2 +
 ...tModelSettings.java => ModelSettings.java} |  57 +++---
 .../action/bulk/BulkOperationTests.java       | 145 +++++++++-------
 .../mock/AbstractTestInferenceService.java    |   5 -
 .../TestSparseInferenceServiceExtension.java  |   4 +-
 ...emanticTextInferenceResultFieldMapper.java | 160 ++++++++++-------
 ...icTextInferenceResultFieldMapperTests.java |  86 +++++-----
 .../xpack/inference/InferenceRestIT.java      |   2 +-
 .../inference/10_semantic_text_inference.yml  | 162 +++++++++++++-----
 .../20_semantic_text_field_mapper.yml         | 153 +++++++++++++++++
 .../CoordinatedInferenceIngestIT.java         |   4 +-
 13 files changed, 572 insertions(+), 279 deletions(-)
 rename server/src/main/java/org/elasticsearch/inference/{SemanticTextModelSettings.java => ModelSettings.java} (61%)
 create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml

diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
index 02f905f7cd87a..fdf3af80b8526 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
@@ -24,11 +24,13 @@
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelRegistry;
+import org.elasticsearch.inference.ModelSettings;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -46,10 +48,10 @@ public class BulkShardRequestInferenceProvider {
 
     public
static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; // Contains the original text for the field - public static final String TEXT_SUBFIELD_NAME = "text"; - // Contains the inference result when it's a sparse vector - public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; + public static final String INFERENCE_RESULTS = "inference_results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; private final ClusterState clusterState; private final Map inferenceProvidersMap; @@ -90,7 +92,13 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { InferenceProvider inferenceProvider = new InferenceProvider( - service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ), service.get() ); inferenceProviderMap.put(inferenceId, inferenceProvider); @@ -105,7 +113,7 @@ public void onFailure(Exception e) { } }; - modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); + modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); } } } @@ -259,35 +267,22 @@ public void onResponse(InferenceServiceResults results) { } int i = 0; - for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { - String fieldName = inferenceFieldNames.get(i++); - List> inferenceFieldResultList; - try { - inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new ArrayList<>() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "Inference result field [" + ROOT_INFERENCE_FIELD + "." 
+ fieldName + "] is not an object" + for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { + String inferenceFieldName = inferenceFieldNames.get(i++); + Map inferenceFieldResult = new LinkedHashMap<>(); + inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap()); + inferenceFieldResult.put( + INFERENCE_RESULTS, + List.of( + Map.of( + INFERENCE_CHUNKS_RESULTS, + inferenceResults.asMap("output").get("output"), + INFERENCE_CHUNKS_TEXT, + docMap.get(inferenceFieldName) ) - ); - return; - } - // Remove previous inference results if any - inferenceFieldResultList.clear(); - - // TODO Check inference result type to change subfield name - var inferenceFieldMap = Map.of( - SPARSE_VECTOR_SUBFIELD_NAME, - inferenceResults.asMap("output").get("output"), - TEXT_SUBFIELD_NAME, - docMap.get(fieldName) + ) ); - inferenceFieldResultList.add(inferenceFieldMap); + rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 47efa0ca49771..c6e4d4af926a2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -210,6 +210,16 @@ protected Parameter[] getParameters() { return new Parameter[] { elementType, dims, indexed, similarity, indexOptions, meta }; } + public Builder similarity(VectorSimilarity vectorSimilarity) { + similarity.setValue(vectorSimilarity); + return this; + } + + public Builder dimensions(int dimensions) { + this.dims.setValue(dimensions); + return this; + } + @Override public DenseVectorFieldMapper build(MapperBuilderContext context) { return new DenseVectorFieldMapper( @@ -708,7 +718,7 @@ static Function errorByteElementsAppender(byte[] v ElementType.FLOAT ); - enum VectorSimilarity { + public enum VectorSimilarity { L2_NORM { @Override float score(float similarity, ElementType elementType, int dim) { diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 62166115820f5..14cfeacf76139 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -35,6 +35,8 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragm /** * Convert the result to a map to aid with test assertions + * + * @return a map */ Map asMap(); } diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java similarity index 61% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to server/src/main/java/org/elasticsearch/inference/ModelSettings.java index 78773bfb72a95..957e2f44d5813 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelSettings.java @@ -8,7 +8,6 @@ package org.elasticsearch.inference; -import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,28 +18,22 @@ import 
java.util.Objects; /** - * Model settings that are interesting for semantic_text inference fields. This class is used to serialize common - * ServiceSettings methods when building inference for semantic_text fields. - * - * @param taskType task type - * @param inferenceId inference id - * @param dimensions number of dimensions. May be null if not applicable - * @param similarity similarity used by the service. May be null if not applicable + * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. + * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider} */ -public record SemanticTextModelSettings( - TaskType taskType, - String inferenceId, - @Nullable Integer dimensions, - @Nullable SimilarityMeasure similarity -) { +public class ModelSettings { public static final String NAME = "model_settings"; - private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + private final TaskType taskType; + private final String inferenceId; + private final Integer dimensions; + private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; @@ -49,7 +42,7 @@ public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer this.similarity = similarity; } - public SemanticTextModelSettings(Model model) { + public ModelSettings(Model model) { this( model.getTaskType(), model.getInferenceEntityId(), @@ -58,16 +51,16 @@ public SemanticTextModelSettings(Model model) { ); } - public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { + public static ModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { TaskType taskType = TaskType.fromString((String) args[0]); String inferenceId = (String) args[1]; Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[2]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); + SimilarityMeasure similarity = args[3] == null ? 
null : SimilarityMeasure.fromString((String) args[3]); + return new ModelSettings(taskType, inferenceId, dimensions, similarity); }); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); @@ -88,4 +81,20 @@ public Map asMap() { } return Map.of(NAME, attrsMap); } + + public TaskType taskType() { + return taskType; + } + + public String inferenceId() { + return inferenceId; + } + + public Integer dimensions() { + return dimensions; + } + + public SimilarityMeasure similarity() { + return similarity; + } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index f8ed331d358b2..4b81e089ed2b2 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -33,6 +33,9 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -56,6 +59,9 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -91,10 +97,10 @@ public void testNoInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -130,6 +136,26 @@ public void testNoInference() { verifyNoMoreInteractions(inferenceServiceRegistry); } + private static Model mockModel(String inferenceServiceId) { + Model model = mock(Model.class); + + when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); + TaskType taskType = randomBoolean() ? 
TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; + when(model.getTaskType()).thenReturn(taskType); + + ServiceSettings serviceSettings = mock(ServiceSettings.class); + when(model.getServiceSettings()).thenReturn(serviceSettings); + SimilarityMeasure similarity = switch (randomInt(2)) { + case 0 -> SimilarityMeasure.COSINE; + case 1 -> SimilarityMeasure.DOT_PRODUCT; + default -> null; + }; + when(serviceSettings.similarity()).thenReturn(similarity); + when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); + + return model; + } + public void testFailedBulkShardRequest() { Map> fieldsForModels = Map.of(); @@ -191,10 +217,10 @@ public void testInference() { Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); - Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); - Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + Model model1 = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService1 = createInferenceService(model1); + Model model2 = mockModel(INFERENCE_SERVICE_2_ID); + InferenceService inferenceService2 = createInferenceService(model2); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -257,8 +283,8 @@ public void testFailedInference() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); @@ -291,8 +317,8 @@ public void testInferenceFailsForIncorrectRootObject() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceServiceThatFails(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -315,39 +341,6 @@ public void testInferenceFailsForIncorrectRootObject() { assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); } - public void testInferenceFailsForIncorrectInferenceFieldObject() { - - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - 
Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value") - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object") - ); - } - public void testInferenceIdNotFound() { Map> fieldsForModels = Map.of( @@ -359,8 +352,8 @@ public void testInferenceIdNotFound() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); + Model model = mockModel(INFERENCE_SERVICE_1_ID); + InferenceService inferenceService = createInferenceService(model); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); Map originalSource = Map.of( @@ -400,17 +393,20 @@ private static void checkInferenceResults( ); for (String inferenceFieldName : inferenceFieldNames) { - List> inferenceService1FieldResults = (List>) inferenceRootResultField.get( - inferenceFieldName - ); + Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(1)); - Map inferenceResultElement = inferenceService1FieldResults.get(0); - assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); - assertThat( - inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), - equalTo(docSource.get(inferenceFieldName)) + assertThat(inferenceService1FieldResults.size(), equalTo(2)); + Map modelSettings = (Map) inferenceService1FieldResults.get(ModelSettings.NAME); + assertNotNull(modelSettings); + assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName())); + assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName())); + + List> inferenceResultElement = (List>) inferenceService1FieldResults.get( + INFERENCE_RESULTS ); + assertFalse(inferenceResultElement.isEmpty()); + assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); + assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); } } @@ -421,8 +417,13 @@ private static void verifyInferenceServiceInvoked( Model model, Collection inferenceTexts ) { - verify(modelRegistry).getModel(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); + verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); + verify(inferenceService).parsePersistedConfigWithSecrets( + eq(inferenceService1Id), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ); verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), 
eq(InputType.INGEST), any()); verifyNoMoreInteractions(inferenceService); } @@ -537,9 +538,16 @@ private static BulkShardRequest runBulkOperation( ); }; - private static InferenceService createInferenceService(Model model, String inferenceServiceId) { + private static InferenceService createInferenceService(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); @@ -548,7 +556,7 @@ private static InferenceService createInferenceService(Model model, String infer for (int i = 0; i < texts.size(); i++) { inferenceResults.add(createInferenceResults()); } - doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); + doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); listener.onResponse(inferenceServiceResults); return null; @@ -556,9 +564,16 @@ private static InferenceService createInferenceService(Model model, String infer return inferenceService; } - private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) { + private static InferenceService createInferenceServiceThatFails(Model model) { InferenceService inferenceService = mock(InferenceService.class); - when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when( + inferenceService.parsePersistedConfigWithSecrets( + eq(model.getInferenceEntityId()), + eq(TaskType.SPARSE_EMBEDDING), + anyMap(), + anyMap() + ) + ).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); @@ -591,7 +606,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onFailure(new IllegalArgumentException("Model not found")); return null; - }).when(modelRegistry).getModel(any(), any()); + }).when(modelRegistry).getModelWithSecrets(any(), any()); inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( inferenceId, @@ -604,7 +619,7 @@ private static ModelRegistry createModelRegistry(Map inferenceId ActionListener listener = invocation.getArgument(1); listener.onResponse(unparsedModel); return null; - }).when(modelRegistry).getModel(eq(inferenceId), any()); + }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); }); return modelRegistry; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 99dfc9582eb05..a65b8e43e6adf 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -101,11 +101,6 @@ public TestServiceModel( super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); } - @Override - public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() { - return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings(); - } - @Override public TestTaskSettings getTaskSettings() { return (TestTaskSettings) super.getTaskSettings(); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index e5020774a70f3..33bbc94901e9d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -115,7 +115,7 @@ private SparseEmbeddingResults makeResults(List input) { for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F)); } embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); } @@ -127,7 +127,7 @@ private List makeChunkedResults(List inp for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 9e6c1eb0a6586..dbde641d8f757 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; @@ -25,25 +25,25 @@ import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import 
org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; /** * A mapper for the {@code _semantic_text_inference} field. @@ -102,16 +102,13 @@ */ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = "_semantic_text_inference"; + public static final String NAME = ROOT_INFERENCE_FIELD; public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); - private static final Map, Set> REQUIRED_SUBFIELDS_MAP = Map.of( - List.of(), - Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME) - ); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + static class SemanticTextInferenceFieldType extends MappedFieldType { private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); @@ -142,75 +139,86 @@ private SemanticTextInferenceResultFieldMapper() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parseInferenceResults(context); + parseAllFields(context); } - private static void parseInferenceResults(DocumentParserContext context) throws IOException { + private static void parseAllFields(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + token); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResults(context, mapperBuilderContext); + parseSingleField(context, mapperBuilderContext); } } - private static void parseFieldInferenceResults(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) - 
throws IOException { + private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - String fieldName = context.parser().currentName(); + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); Mapper mapper = context.getMapper(fieldName); if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { throw new DocumentParsingException( - context.parser().getTokenLocation(), + parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ); } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); + parser.nextToken(); + ModelSettings modelSettings = ModelSettings.parse(parser); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResultsArray(context, mapperBuilderContext, fieldName); + String currentName = parser.currentName(); + if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( + context, + mapperBuilderContext, + fieldName, + modelSettings + ); + parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); + } else { + logger.debug("Skipping unrecognized field name [" + currentName + "]"); + advancePastCurrentFieldName(parser); + } + } } - private static void parseFieldInferenceResultsArray( + private static void parseFieldInferenceChunks( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings, + NestedObjectMapper nestedObjectMapper ) throws IOException { XContentParser parser = context.parser(); - NestedObjectMapper nestedObjectMapper = createNestedObjectMapper(context, mapperBuilderContext, fieldName); - if (parser.nextToken() != XContentParser.Token.START_ARRAY) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_ARRAY, got " + parser.currentToken()); - } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceResultElement(nestedContext, nestedObjectMapper, new LinkedList<>()); + parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); } } - private static void parseFieldInferenceResultElement( + private static void parseFieldInferenceChunkElement( DocumentParserContext context, ObjectMapper objectMapper, - LinkedList subfieldPath + ModelSettings modelSettings ) throws IOException { XContentParser parser = context.parser(); DocumentParserContext childContext = context.createChildContext(objectMapper); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); Set visitedSubfields = new HashSet<>(); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != 
XContentParser.Token.FIELD_NAME) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); visitedSubfields.add(currentName); @@ -222,14 +230,9 @@ private static void parseFieldInferenceResultElement( continue; } - if (childMapper instanceof FieldMapper) { + if (childMapper instanceof FieldMapper fieldMapper) { parser.nextToken(); - ((FieldMapper) childMapper).parse(childContext); - } else if (childMapper instanceof ObjectMapper) { - parser.nextToken(); - subfieldPath.push(currentName); - parseFieldInferenceResultElement(childContext, (ObjectMapper) childMapper, subfieldPath); - subfieldPath.pop(); + fieldMapper.parse(childContext); } else { // This should never happen, but fail parsing if it does so that it's not a silent failure throw new DocumentParsingException( @@ -239,29 +242,51 @@ private static void parseFieldInferenceResultElement( } } - Set requiredSubfields = REQUIRED_SUBFIELDS_MAP.get(subfieldPath); - if (requiredSubfields != null && visitedSubfields.containsAll(requiredSubfields) == false) { - Set missingSubfields = requiredSubfields.stream() + if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() .filter(s -> visitedSubfields.contains(s) == false) .collect(Collectors.toSet()); throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); } } - private static NestedObjectMapper createNestedObjectMapper( + private static NestedObjectMapper createInferenceResultsObjectMapper( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings ) { IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - ObjectMapper.Builder sparseVectorMapperBuilder = new ObjectMapper.Builder( - SPARSE_VECTOR_SUBFIELD_NAME, - ObjectMapper.Defaults.SUBOBJECTS - ).add( - new BooleanFieldMapper.Builder(SparseEmbeddingResults.Embedding.IS_TRUNCATED, ScriptCompiler.NONE, false, indexVersionCreated) - ).add(new SparseVectorFieldMapper.Builder(SparseEmbeddingResults.Embedding.EMBEDDING)); + FieldMapper.Builder resultsBuilder; + if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { + resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + Integer dimensions = modelSettings.dimensions(); + if (dimensions == null) { + throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); + } + denseVectorMapperBuilder.dimensions(dimensions); + resultsBuilder = denseVectorMapperBuilder; + } else { + throw new IllegalArgumentException("Unknown task type for field [" 
+ fieldName + "]: " + modelSettings.taskType()); + } + TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - TEXT_SUBFIELD_NAME, + INFERENCE_CHUNKS_TEXT, indexVersionCreated, context.indexAnalyzers() ).index(false).store(false); @@ -270,7 +295,7 @@ private static NestedObjectMapper createNestedObjectMapper( fieldName, context.indexSettings().getIndexVersionCreated() ); - nestedBuilder.add(sparseVectorMapperBuilder).add(textMapperBuilder); + nestedBuilder.add(resultsBuilder).add(textMapperBuilder); return nestedBuilder.build(mapperBuilderContext); } @@ -286,6 +311,15 @@ private static void advancePastCurrentFieldName(XContentParser parser) throws IO } } + private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException( + parser.getTokenLocation(), + "Expected a " + expected.toString() + ", got " + parser.currentToken() + ); + } + } + @Override protected String contentType() { return CONTENT_TYPE; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index aa2ad72941e0e..06a665ade3ab4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -31,6 +31,8 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; @@ -51,8 +53,9 @@ import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { @@ -214,7 +217,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]")); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]")); } { DocumentParsingException ex = expectThrows( @@ -232,7 +235,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]")); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]")); } { DocumentParsingException ex = expectThrows( @@ -252,7 +255,7 @@ public void testMissingSubfields() throws 
IOException { ); assertThat( ex.getMessage(), - containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + ", " + TEXT_SUBFIELD_NAME + "]") + containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]") ); } } @@ -411,8 +414,10 @@ private static void addSemanticTextInferenceResults( Map extraSubfields ) throws IOException { - Map>> inferenceResultsMap = new HashMap<>(); + Map> inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { + Map fieldMap = new HashMap<>(); + fieldMap.put(ModelSettings.NAME, modelSettingsMap()); List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() @@ -425,17 +430,10 @@ private static void addSemanticTextInferenceResults( Map subfieldMap = new HashMap<>(); if (sparseVectorSubfieldOptions.include()) { - Map embeddingMap = embedding.asMap(); - if (sparseVectorSubfieldOptions.includeIsTruncated() == false) { - embeddingMap.remove(SparseEmbeddingResults.Embedding.IS_TRUNCATED); - } - if (sparseVectorSubfieldOptions.includeEmbedding() == false) { - embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING); - } - subfieldMap.put(SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); + subfieldMap.put(INFERENCE_CHUNKS_RESULTS, embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); } if (includeTextSubfield) { - subfieldMap.put(TEXT_SUBFIELD_NAME, text); + subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); } if (extraSubfields != null) { subfieldMap.putAll(extraSubfields); @@ -444,28 +442,42 @@ private static void addSemanticTextInferenceResults( parsedInferenceResults.add(subfieldMap); } - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), parsedInferenceResults); + fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); + inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); } + private static Map modelSettingsMap() { + return Map.of( + ModelSettings.TASK_TYPE_FIELD.getPreferredName(), + TaskType.SPARSE_EMBEDDING.toString(), + ModelSettings.INFERENCE_ID_FIELD.getPreferredName(), + randomAlphaOfLength(8) + ); + } + private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { mappingBuilder.startObject(semanticTextFieldName); - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SPARSE_VECTOR_SUBFIELD_NAME); - mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SparseEmbeddingResults.Embedding.EMBEDDING); - mappingBuilder.field("type", "sparse_vector"); - mappingBuilder.endObject(); - mappingBuilder.endObject(); - mappingBuilder.endObject(); - mappingBuilder.startObject(TEXT_SUBFIELD_NAME); - mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - mappingBuilder.endObject(); - mappingBuilder.endObject(); + { + mappingBuilder.field("type", "nested"); + mappingBuilder.startObject("properties"); + { + mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); + { + mappingBuilder.field("type", "sparse_vector"); + } + mappingBuilder.endObject(); + mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); + { + mappingBuilder.field("type", "text"); + mappingBuilder.field("index", false); + } + mappingBuilder.endObject(); + 
} + mappingBuilder.endObject(); + } mappingBuilder.endObject(); } @@ -477,12 +489,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); for (String token : tokens) { queryBuilder.add( - new BooleanClause( - new TermQuery( - new Term(path + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token) - ), - BooleanClause.Occur.MUST - ) + new BooleanClause(new TermQuery(new Term(path + "." + INFERENCE_CHUNKS_RESULTS, token)), BooleanClause.Occur.MUST) ); } queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); @@ -497,12 +504,7 @@ private static void assertValidChildDoc( ) { assertEquals(expectedParent, childDoc.getParent()); visitedChildDocs.add( - new VisitedChildDocInfo( - childDoc.getPath(), - childDoc.getFields( - childDoc.getPath() + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING - ).size() - ) + new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size()) ); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 933e696d29d83..a397d9864d23d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -21,7 +21,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") - .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") + .plugin("inference-service-test") .distribution(DistributionType.DEFAULT) .build(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 0e1b33252153b..ead7f904ad57b 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -6,7 +6,7 @@ setup: - do: inference.put_model: task_type: sparse_embedding - inference_id: test-inference-id + inference_id: sparse-inference-id body: > { "service": "test_service", @@ -17,27 +17,57 @@ setup: "task_settings": { } } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + inference_field: + type: semantic_text + model_id: sparse-inference-id + another_inference_field: + type: semantic_text + model_id: sparse-inference-id + non_inference_field: + type: text - do: indices.create: - index: test-index + index: test-dense-index body: mappings: properties: inference_field: type: semantic_text - model_id: test-inference-id + model_id: dense-inference-id 
another_inference_field: type: semantic_text - model_id: test-inference-id + model_id: dense-inference-id non_inference_field: type: text --- -"Calculates embeddings for new documents": +"Calculates text expansion results for new documents": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -46,24 +76,73 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "inference test" } - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + + - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference + - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + +--- +"text expansion documents do not create new mappings": + - do: + indices.get_mapping: + index: test-sparse-index + + - match: {test-sparse-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.non_inference_field.type: text} + - length: {test-sparse-index.mappings.properties: 3} + +--- +"Calculates text embeddings results for new documents": + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-dense-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding - - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding + - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference + - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + + +--- +"text embeddings documents do not create new mappings": + - do: + indices.get_mapping: + index: test-dense-index + + - match: {test-dense-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.non_inference_field.type: text} + - length: {test-dense-index.mappings.properties: 3} --- "Updating non semantic_text fields does not recalculate embeddings": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -72,15 +151,15 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - - set: { 
_source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } + - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } - do: update: - index: test-index + index: test-sparse-index id: doc_1 body: doc: @@ -88,24 +167,24 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "inference test" } - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -114,12 +193,12 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - do: update: - index: test-index + index: test-sparse-index id: doc_1 body: doc: @@ -128,22 +207,21 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - match: { _source.inference_field: "updated inference test" } - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" } - + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": - do: index: - index: test-index + index: test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -152,11 +230,11 @@ setup: - do: get: - index: test-index + index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } + - set: { 
_source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -169,10 +247,10 @@ setup: properties: inference_field: type: semantic_text - model_id: test-inference-id + model_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: test-inference-id + model_id: sparse-inference-id non_inference_field: type: text @@ -181,7 +259,7 @@ setup: wait_for_completion: true body: source: - index: test-index + index: test-sparse-index dest: index: destination-index - do: @@ -193,17 +271,17 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } + - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": - do: indices.create: - index: incorrect-test-index + index: incorrect-test-sparse-index body: mappings: properties: @@ -216,7 +294,7 @@ setup: - do: catch: bad_request index: - index: incorrect-test-index + index: incorrect-test-sparse-index id: doc_1 body: inference_field: "inference test" @@ -227,7 +305,7 @@ setup: # Succeeds when semantic_text field is not used - do: index: - index: incorrect-test-index + index: incorrect-test-sparse-index id: doc_1 body: non_inference_field: "non inference test" diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml new file mode 100644 index 0000000000000..da61e6e403ed8 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -0,0 +1,153 @@ +setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + model_id: sparse-inference-id + dense_field: + type: semantic_text + model_id: dense-inference-id + non_inference_field: + type: 
text + +--- +"Sparse vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + inference_id: sparse-inference-id + task_type: sparse_embedding + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + feature_2: 0.2 + feature_3: 0.3 + feature_4: 0.4 + - text: "another inference test" + inference: + feature_1: 0.1 + feature_2: 0.2 + feature_3: 0.3 + feature_4: 0.4 + +--- +"Dense vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + dense_field: + model_settings: + inference_id: sparse-inference-id + task_type: text_embedding + dimensions: 5 + similarity: cosine + inference_results: + - text: "inference test" + inference: [0.1, 0.2, 0.3, 0.4, 0.5] + - text: "another inference test" + inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + +--- +"Model settings inference id not included": + - do: + catch: /Required \[inference_id\]/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + task_type: sparse_embedding + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + +--- +"Model settings task type not included": + - do: + catch: /Required \[task_type\]/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + sparse_field: + model_settings: + inference_id: sparse-inference-id + inference_results: + - text: "inference test" + inference: + feature_1: 0.1 + +--- +"Model settings dense vector dimensions not included": + - do: + catch: /Model settings for field \[dense_field\] must contain dimensions/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _semantic_text_inference: + dense_field: + model_settings: + inference_id: sparse-inference-id + task_type: text_embedding + inference_results: + - text: "inference test" + inference: [0.1, 0.2, 0.3, 0.4, 0.5] + - text: "another inference test" + inference: [-0.1, -0.2, -0.3, -0.4, -0.5] diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java index 4d90d2a186858..d8c9dc2efd927 100644 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java @@ -59,10 +59,10 @@ public void testIngestWithMultipleModelTypes() throws IOException { assertThat(simulatedDocs, hasSize(2)); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0))); var sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)); - assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1))); sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1)); - 
assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); } { From 2039fb357d7b78b5cee66763cbbb44a4bbc0f71f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 6 Mar 2024 16:10:54 +0100 Subject: [PATCH 06/29] This was supposed to be merged into #105515 but didn't make it --- .../bulk/BulkShardRequestInferenceProvider.java | 4 ++-- ...lSettings.java => SemanticTextModelSettings.java} | 12 ++++++------ .../action/bulk/BulkOperationTests.java | 8 ++++---- .../SemanticTextInferenceResultFieldMapper.java | 10 +++++----- .../SemanticTextInferenceResultFieldMapperTests.java | 8 ++++---- 5 files changed, 21 insertions(+), 21 deletions(-) rename server/src/main/java/org/elasticsearch/inference/{ModelSettings.java => SemanticTextModelSettings.java} (86%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index fdf3af80b8526..4b7a67e9ca0e3 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -24,7 +24,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SemanticTextModelSettings; import java.util.ArrayList; import java.util.Collections; @@ -270,7 +270,7 @@ public void onResponse(InferenceServiceResults results) { for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { String inferenceFieldName = inferenceFieldNames.get(i++); Map inferenceFieldResult = new LinkedHashMap<>(); - inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap()); + inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap()); inferenceFieldResult.put( INFERENCE_RESULTS, List.of( diff --git a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java b/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java similarity index 86% rename from server/src/main/java/org/elasticsearch/inference/ModelSettings.java rename to server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java index 957e2f44d5813..3561c2351427c 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java @@ -21,7 +21,7 @@ * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
* See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider} */ -public class ModelSettings { +public class SemanticTextModelSettings { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); @@ -33,7 +33,7 @@ public class ModelSettings { private final Integer dimensions; private final SimilarityMeasure similarity; - public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; @@ -42,7 +42,7 @@ public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, this.similarity = similarity; } - public ModelSettings(Model model) { + public SemanticTextModelSettings(Model model) { this( model.getTaskType(), model.getInferenceEntityId(), @@ -51,16 +51,16 @@ public ModelSettings(Model model) { ); } - public static ModelSettings parse(XContentParser parser) throws IOException { + public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { TaskType taskType = TaskType.fromString((String) args[0]); String inferenceId = (String) args[1]; Integer dimensions = (Integer) args[2]; SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); - return new ModelSettings(taskType, inferenceId, dimensions, similarity); + return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); }); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 4b81e089ed2b2..2ce7b161d3dd1 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -33,7 +33,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -396,10 +396,10 @@ private static void checkInferenceResults( Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); assertNotNull(inferenceService1FieldResults); assertThat(inferenceService1FieldResults.size(), equalTo(2)); - Map modelSettings = (Map) inferenceService1FieldResults.get(ModelSettings.NAME); + Map modelSettings = (Map) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME); assertNotNull(modelSettings); - assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName())); - assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName())); + 
assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName())); + assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName())); List> inferenceResultElement = (List>) inferenceService1FieldResults.get( INFERENCE_RESULTS diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index dbde641d8f757..ad1e0f8c8cb81 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -28,7 +28,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; @@ -168,7 +168,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde parser.nextToken(); failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); parser.nextToken(); - ModelSettings modelSettings = ModelSettings.parse(parser); + SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); @@ -192,7 +192,7 @@ private static void parseFieldInferenceChunks( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, String fieldName, - ModelSettings modelSettings, + SemanticTextModelSettings modelSettings, NestedObjectMapper nestedObjectMapper ) throws IOException { XContentParser parser = context.parser(); @@ -209,7 +209,7 @@ private static void parseFieldInferenceChunks( private static void parseFieldInferenceChunkElement( DocumentParserContext context, ObjectMapper objectMapper, - ModelSettings modelSettings + SemanticTextModelSettings modelSettings ) throws IOException { XContentParser parser = context.parser(); DocumentParserContext childContext = context.createChildContext(objectMapper); @@ -254,7 +254,7 @@ private static NestedObjectMapper createInferenceResultsObjectMapper( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, String fieldName, - ModelSettings modelSettings + SemanticTextModelSettings modelSettings ) { IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); FieldMapper.Builder resultsBuilder; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index 06a665ade3ab4..319f6ef73fa56 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ 
-31,7 +31,7 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SemanticTextModelSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; @@ -417,7 +417,7 @@ private static void addSemanticTextInferenceResults( Map> inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { Map fieldMap = new HashMap<>(); - fieldMap.put(ModelSettings.NAME, modelSettingsMap()); + fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap()); List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() @@ -451,9 +451,9 @@ private static void addSemanticTextInferenceResults( private static Map modelSettingsMap() { return Map.of( - ModelSettings.TASK_TYPE_FIELD.getPreferredName(), + SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING.toString(), - ModelSettings.INFERENCE_ID_FIELD.getPreferredName(), + SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), randomAlphaOfLength(8) ); } From 3ca808b3cf94e7f0526cea6aebd556b7c900c084 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Tue, 19 Mar 2024 11:55:04 +0100 Subject: [PATCH 07/29] semantic_text - extract Index Metadata inference information to separate class (#106328) --- .../cluster/ClusterStateDiffIT.java | 24 +-- .../BulkShardRequestInferenceProvider.java | 39 +++- .../metadata/FieldInferenceMetadata.java | 190 ++++++++++++++++++ .../cluster/metadata/IndexMetadata.java | 95 +++------ .../metadata/MetadataCreateIndexService.java | 4 +- .../metadata/MetadataMappingService.java | 2 +- .../index/mapper/FieldTypeLookup.java | 17 +- .../index/mapper/InferenceModelFieldType.java | 2 +- .../index/mapper/MappingLookup.java | 4 +- .../action/bulk/BulkOperationTests.java | 61 +++--- .../cluster/metadata/IndexMetadataTests.java | 42 ++-- .../index/mapper/FieldTypeLookupTests.java | 16 +- .../index/mapper/MappingLookupTests.java | 15 +- .../mapper/MockInferenceModelFieldType.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../SemanticTextClusterMetadataTests.java | 12 +- 16 files changed, 347 insertions(+), 180 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 3a1f6e20bb288..fbb3016b925da 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -54,7 +54,6 @@ import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -62,6 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; +import static 
org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -587,33 +587,13 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldsForModels(randomFieldsForModels()); + builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true)); break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } - - /** - * Generates a random fieldsForModels map - */ - private Map> randomFieldsForModels() { - if (randomBoolean()) { - return null; - } - - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } - - return fieldsForModels; - } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 4b7a67e9ca0e3..e80530f75cf4b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; @@ -75,11 +76,13 @@ public static void getInstance( Set shardIds, ActionListener listener ) { - Set inferenceIds = new HashSet<>(); - shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); - inferenceIds.addAll(fieldsForModels.keySet()); - }); + Set inferenceIds = shardIds.stream() + .map(ShardId::getIndex) + .collect(Collectors.toSet()) + .stream() + .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values()) + .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)) + .collect(Collectors.toSet()); final Map inferenceProviderMap = new ConcurrentHashMap<>(); Runnable onModelLoadingComplete = () -> listener.onResponse( new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) @@ -134,11 +137,11 @@ public void processBulkShardRequest( BiConsumer onBulkItemFailure ) { - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); + Map> fieldsForInferenceIds = getFieldsForInferenceIds( + clusterState.metadata().index(bulkShardRequest.shardId().getIndex()).getFieldInferenceMetadata().getFieldInferenceOptions() + ); // No inference fields? 
Terminate early - if (fieldsForModels.isEmpty()) { + if (fieldsForInferenceIds.isEmpty()) { listener.onResponse(bulkShardRequest); return; } @@ -176,7 +179,7 @@ public void processBulkShardRequest( if (bulkItemRequest != null) { performInferenceOnBulkItemRequest( bulkItemRequest, - fieldsForModels, + fieldsForInferenceIds, i, onBulkItemFailureWithIndex, bulkItemReqRef.acquire() @@ -186,6 +189,22 @@ } } + private static Map> getFieldsForInferenceIds( + Map fieldInferenceMap + ) { + Map> fieldsForInferenceIdsMap = new HashMap<>(); + for (Map.Entry entry : fieldInferenceMap.entrySet()) { + String fieldName = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + + // Get or create the set associated with the inferenceId + Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); + fields.add(fieldName); + } + + return fieldsForInferenceIdsMap; + } + @SuppressWarnings("unchecked") private void performInferenceOnBulkItemRequest( BulkItemRequest bulkItemRequest, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java new file mode 100644 index 0000000000000..349706c139127 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.Diffable; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.SimpleDiffable; +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator + * node, which does not necessarily have mapping information. 
+ */ +public class FieldInferenceMetadata implements Diffable, ToXContentFragment { + + private final ImmutableOpenMap fieldInferenceOptions; + + public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); + + public FieldInferenceMetadata(MappingLookup mappingLookup) { + ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); + mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { + builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); + }); + fieldInferenceOptions = builder.build(); + } + + public FieldInferenceMetadata(StreamInput in) throws IOException { + fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); + } + + public FieldInferenceMetadata(Map fieldsToInferenceMap) { + fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); + } + + public ImmutableOpenMap getFieldInferenceOptions() { + return fieldInferenceOptions; + } + + public boolean isEmpty() { + return fieldInferenceOptions.isEmpty(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(fieldInferenceOptions, (o, v) -> v.writeTo(o)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.map(fieldInferenceOptions); + return builder; + } + + public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { + return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FieldInferenceMetadata that = (FieldInferenceMetadata) o; + return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); + } + + @Override + public int hashCode() { + return Objects.hash(fieldInferenceOptions); + } + + @Override + public Diff diff(FieldInferenceMetadata previousState) { + if (previousState == null) { + previousState = EMPTY; + } + return new FieldInferenceMetadataDiff(previousState, this); + } + + static class FieldInferenceMetadataDiff implements Diff { + + public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( + FieldInferenceMetadata.EMPTY, + FieldInferenceMetadata.EMPTY + ); + + private final Diff> fieldInferenceMapDiff; + + private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); + + FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { + fieldInferenceMapDiff = DiffableUtils.diff( + before.fieldInferenceOptions, + after.fieldInferenceOptions, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + FieldInferenceMetadataDiff(StreamInput in) throws IOException { + fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + FIELD_INFERENCE_DIFF_VALUE_READER + ); + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + fieldInferenceMapDiff.writeTo(out); + } + + @Override + public FieldInferenceMetadata apply(FieldInferenceMetadata part) { + return 
new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); + } + } + + public record FieldInferenceOptions(String inferenceId, Set sourceFields) + implements + SimpleDiffable, + ToXContentFragment { + + public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); + + FieldInferenceOptions(StreamInput in) throws IOException { + this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeStringCollection(sourceFields); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); + builder.endObject(); + return builder; + } + + public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "field_inference_parser", + false, + (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 81406f0a74ce5..89c925427cf88 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -78,7 +78,6 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -541,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String KEY_FIELD_INFERENCE = "field_inference"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -632,8 +631,7 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - // Key: model ID, Value: Fields that use model - private final ImmutableOpenMap> fieldsForModels; + private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -680,7 +678,7 @@ private IndexMetadata( @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, @Nullable Long shardSizeInBytesForecast, - final ImmutableOpenMap> fieldsForModels + @Nullable FieldInferenceMetadata fieldInferenceMetadata ) { this.index = index; this.version = version; @@ -736,7 +734,7 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of 
" + numberOfShards; - this.fieldsForModels = Objects.requireNonNull(fieldsForModels); + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -788,7 +786,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -847,7 +845,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -904,7 +902,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -961,7 +959,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1014,7 +1012,7 @@ public IndexMetadata withIncrementedVersion() { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldInferenceMetadata ); } @@ -1218,8 +1216,8 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public Map> getFieldsForModels() { - return fieldsForModels; + public FieldInferenceMetadata getFieldInferenceMetadata() { + return fieldInferenceMetadata; } public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; @@ -1419,7 +1417,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldsForModels.equals(that.fieldsForModels) == false) { + if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { return false; } if (isSystem != that.isSystem) { @@ -1442,7 +1440,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldsForModels.hashCode(); + result = 31 * result + fieldInferenceMetadata.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1498,7 +1496,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff>> fieldsForModels; + private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1535,12 +1533,7 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldsForModels = DiffableUtils.diff( - before.fieldsForModels, - after.fieldsForModels, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); + fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1601,13 +1594,9 @@ private static class IndexMetadataDiff implements Diff { shardSizeInBytesForecast = null; } if 
(in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = DiffableUtils.readJdkMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - DiffableUtils.StringSetValueSerializer.getInstance() - ); + fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); } else { - fieldsForModels = DiffableUtils.emptyDiff(); + fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; } } @@ -1645,7 +1634,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels.writeTo(out); + out.writeOptionalWriteable(fieldInferenceMetadata); } } @@ -1676,7 +1665,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); + builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1745,9 +1734,7 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) - ); + builder.fieldInferenceMetadata(new FieldInferenceMetadata(in)); } return builder.build(true); } @@ -1796,7 +1783,7 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalLong(shardSizeInBytesForecast); } if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + fieldInferenceMetadata.writeTo(out); } } @@ -1847,7 +1834,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private final ImmutableOpenMap.Builder> fieldsForModels; + private FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; public Builder(String index) { this.index = index; @@ -1855,7 +1842,6 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); - this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1880,7 +1866,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); + this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2110,8 +2096,8 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldsForModels(Map> fieldsForModels) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); + public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { + this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); return this; } @@ -2310,7 +2296,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels.build() + fieldInferenceMetadata ); } @@ -2436,8 +2422,8 @@ 
public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); + if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { + builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); } builder.endObject(); @@ -2517,18 +2503,8 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) - .entrySet() - .stream() - .collect( - Collectors.toMap( - Map.Entry::getKey, - v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) - ) - ); - builder.fieldsForModels(fieldsForModels); + case KEY_FIELD_INFERENCE: + builder.fieldInferenceMetadata(FieldInferenceMetadata.fromXContent(parser)); break; default: // assume it's custom index metadata @@ -2726,17 +2702,6 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } - - private static void processFieldsForModels( - ImmutableOpenMap.Builder> builder, - Map> fieldsForModels - ) { - builder.clear(); - if (fieldsForModels != null) { - // Ensure that all field sets contained in the processed map are immutable - fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); - } - } } /** diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index d8fe0b0c19e52..96ca7a15edc30 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1267,8 +1267,8 @@ static IndexMetadata buildIndexMetadata( if (mapper != null) { MappingMetadata mappingMd = new MappingMetadata(mapper); mappingsMetadata.put(mapper.type(), mappingMd); - - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); + indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index d913a6465482d..0e31592991369 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -204,7 +204,7 @@ private static ClusterState applyRequest( DocumentMapper mapper = mapperService.documentMapper(); if (mapper != null) { indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldsForModels(mapper.mappers().getFieldsForModels()); + indexMetadataBuilder.fieldInferenceMetadata(new FieldInferenceMetadata(mapper.mappers())); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 372b1412df724..0741cfa682b74 100644 --- 
a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -39,7 +39,7 @@ final class FieldTypeLookup { /** * A map from inference model ID to all fields that use the model to generate embeddings. */ - private final Map> fieldsForModels; + private final Map inferenceIdsForFields; private final int maxParentPathDots; @@ -53,7 +53,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map> fieldsForModels = new HashMap<>(); + final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -72,11 +72,7 @@ final class FieldTypeLookup { fieldToCopiedFields.get(targetField).add(fieldName); } if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - String inferenceModel = inferenceModelFieldType.getInferenceModel(); - if (inferenceModel != null) { - Set fields = fieldsForModels.computeIfAbsent(inferenceModel, v -> new HashSet<>()); - fields.add(fieldName); - } + inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); } } @@ -110,8 +106,7 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - fieldsForModels.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); - this.fieldsForModels = Map.copyOf(fieldsForModels); + this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -220,8 +215,8 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? 
fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map> getFieldsForModels() { - return fieldsForModels; + Map getInferenceIdsForFields() { + return inferenceIdsForFields; } /** diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java index 490d7f36219cf..6e12a204ed7d0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java @@ -17,5 +17,5 @@ public interface InferenceModelFieldType { * * @return model id used by the field type */ - String getInferenceModel(); + String getInferenceId(); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index cf2212110a210..c2bd95115f27e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -524,7 +524,7 @@ public void validateDoesNotShadow(String name) { } } - public Map> getFieldsForModels() { - return fieldTypeLookup.getFieldsForModels(); + public Map getInferenceIdsForFields() { + return fieldTypeLookup.getInferenceIdsForFields(); } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 2ce7b161d3dd1..c3887f506b891 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -92,7 +93,7 @@ public class BulkOperationTests extends ESTestCase { public void testNoInference() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry( Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); @@ -116,7 +117,7 @@ public void testNoInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -158,7 +159,7 @@ private static Model mockModel(String inferenceServiceId) { public void testFailedBulkShardRequest() { - Map> fieldsForModels = Map.of(); + FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; ModelRegistry modelRegistry = createModelRegistry(Map.of()); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); @@ -176,7 +177,7 @@ public void testFailedBulkShardRequest() { runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -206,11 +207,15 @@ public void testFailedBulkShardRequest() { @SuppressWarnings("unchecked") public void testInference() { - Map> fieldsForModels = Map.of( - 
INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry( @@ -244,7 +249,7 @@ public void testInference() { ActionListener bulkOperationListener = mock(ActionListener.class); BulkShardRequest bulkShardRequest = runBulkOperation( originalSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, true, @@ -279,7 +284,9 @@ public void testInference() { public void testFailedInference() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -298,7 +305,7 @@ public void testFailedInference() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -313,7 +320,9 @@ public void testFailedInference() { public void testInferenceFailsForIncorrectRootObject() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) + ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -331,7 +340,7 @@ public void testInferenceFailsForIncorrectRootObject() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -343,11 +352,15 @@ public void testInferenceFailsForIncorrectRootObject() { public void testInferenceIdNotFound() { - Map> fieldsForModels = Map.of( - INFERENCE_SERVICE_1_ID, - Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), - INFERENCE_SERVICE_2_ID, - Set.of(INFERENCE_FIELD_SERVICE_2) + 
FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( + Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + SECOND_INFERENCE_FIELD_SERVICE_1, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), + INFERENCE_FIELD_SERVICE_2, + new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) + ) ); ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); @@ -368,7 +381,7 @@ public void testInferenceIdNotFound() { ActionListener bulkOperationListener = mock(ActionListener.class); doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); @@ -444,7 +457,7 @@ public String toString() { private static BulkShardRequest runBulkOperation( Map docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, boolean expectTransportShardBulkActionToExecute, @@ -452,7 +465,7 @@ private static BulkShardRequest runBulkOperation( ) { return runBulkOperation( docSource, - fieldsForModels, + fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, bulkOperationListener, @@ -463,7 +476,7 @@ private static BulkShardRequest runBulkOperation( private static BulkShardRequest runBulkOperation( Map docSource, - Map> fieldsForModels, + FieldInferenceMetadata fieldInferenceMetadata, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, ActionListener bulkOperationListener, @@ -472,7 +485,7 @@ private static BulkShardRequest runBulkOperation( ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .settings(settings) .numberOfShards(1) .numberOfReplicas(0) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index b2354a4356595..b32873df71365 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.shard.ShardId; @@ -41,7 +42,6 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -84,7 +84,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? 
randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; - Map> fieldsForModels = randomFieldsForModels(true); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -110,7 +110,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(fieldsForModels) + .fieldInferenceMetadata(fieldInferenceMetadata) .build(); assertEquals(system, metadata.isSystem()); @@ -145,7 +145,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -169,7 +169,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), deserialized.getStats()); assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); + assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata()); } } @@ -553,35 +553,35 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldsForModels() { + public void testFieldInferenceMetadata() { Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); + assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY); - Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); - assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); + FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build(); + assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata)); } private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - private static Map> randomFieldsForModels(boolean allowNull) { - if (allowNull && randomBoolean()) { + public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) { + if (randomBoolean() && allowNull) { return null; } - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - 
fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } + Map fieldInferenceMap = randomMap( + 0, + 10, + () -> new Tuple<>(randomIdentifier(), randomFieldInference()) + ); + return new FieldInferenceMetadata(fieldInferenceMap); + } - return fieldsForModels; + private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() { + return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier)); } private IndexMetadataStats randomIndexStats(int numberOfShards) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 27663edde945c..932eac3e60d27 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -37,7 +37,7 @@ public void testEmpty() { assertNotNull(names); assertThat(names, hasSize(0)); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -48,7 +48,7 @@ public void testAddNewField() { assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - Map> fieldsForModels = lookup.getFieldsForModels(); + Map fieldsForModels = lookup.getInferenceIdsForFields(); assertNotNull(fieldsForModels); assertTrue(fieldsForModels.isEmpty()); } @@ -440,11 +440,13 @@ public void testInferenceModelFieldType() { assertEquals(f2.fieldType(), lookup.get("foo2")); assertEquals(f3.fieldType(), lookup.get("foo3")); - Map> fieldsForModels = lookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(2, fieldsForModels.size()); - assertEquals(Set.of("foo1", "foo2"), fieldsForModels.get("bar1")); - assertEquals(Set.of("foo3"), fieldsForModels.get("bar2")); + Map inferenceIdsForFields = lookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(3, inferenceIdsForFields.size()); + + assertEquals("bar1", inferenceIdsForFields.get("foo1")); + assertEquals("bar1", inferenceIdsForFields.get("foo2")); + assertEquals("bar2", inferenceIdsForFields.get("foo3")); } private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index f512f5d352a43..bb337d0c61c93 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,7 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -122,8 +121,8 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getFieldsForModels()); - assertTrue(mappingLookup.getFieldsForModels().isEmpty()); + assertNotNull(mappingLookup.getInferenceIdsForFields()); + assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty()); } public void testValidateDoesNotShadow() { @@ -191,7 +190,7 @@ public MetricType getMetricType() { ); } - 
public void testFieldsForModels() { + public void testInferenceIdsForFields() { MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); MappingLookup mappingLookup = createMappingLookup( Collections.singletonList(new MockFieldMapper(fieldType)), @@ -201,10 +200,10 @@ public void testFieldsForModels() { assertEquals(1, size(mappingLookup.fieldMappers())); assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - Map> fieldsForModels = mappingLookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(1, fieldsForModels.size()); - assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); + Map inferenceIdsForFields = mappingLookup.getInferenceIdsForFields(); + assertNotNull(inferenceIdsForFields); + assertEquals(1, inferenceIdsForFields.size()); + assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name")); } private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java index 854749d6308db..0d21134b5d9a9 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java @@ -39,7 +39,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format) } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..d9e18728615ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -108,7 +108,7 @@ public String typeName() { } @Override - public String getInferenceModel() { + public String getInferenceId() { return modelId; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 69fa64ffa6d1c..a7d3fcce26116 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -20,8 +20,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Set; public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @@ -35,7 +33,10 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") ); - assertEquals(Map.of("test_model", Set.of("field")), indexService.getMetadata().getFieldsForModels()); + assertEquals( + indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), + "test_model" + ); } public void 
testAddSemanticTextField() throws Exception { @@ -52,7 +53,10 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals(Map.of("test_model", Set.of("field")), resultingState.metadata().index("test").getFieldsForModels()); + assertEquals( + resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), + "test_model" + ); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { From 823fb58a3e8c68d3b1fede3bff5bf31650a58a3a Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 13:14:09 +0000 Subject: [PATCH 08/29] [feature/semantic_text] Refactor inference to run as an action filter (#106357) --------- Co-authored-by: carlosdelest --- .../action/bulk/BulkOperation.java | 118 +-- .../action/bulk/BulkShardRequest.java | 27 + .../BulkShardRequestInferenceProvider.java | 338 --------- .../action/bulk/TransportBulkAction.java | 36 +- .../bulk/TransportSimulateBulkAction.java | 4 +- .../vectors/DenseVectorFieldMapper.java | 4 + .../inference/InferenceServiceRegistry.java | 62 +- .../InferenceServiceRegistryImpl.java | 64 -- .../inference/ModelRegistry.java | 99 --- .../elasticsearch/node/NodeConstruction.java | 15 - .../plugins/InferenceRegistryPlugin.java | 22 - .../action/bulk/BulkOperationTests.java | 670 ------------------ ...ActionIndicesThatCannotBeCreatedTests.java | 8 +- .../bulk/TransportBulkActionIngestTests.java | 8 +- .../action/bulk/TransportBulkActionTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 16 +- .../snapshots/SnapshotResiliencyTests.java | 4 +- .../TestSparseInferenceServiceExtension.java | 8 +- ...gistryImplIT.java => ModelRegistryIT.java} | 52 +- .../xpack/inference/InferencePlugin.java | 54 +- .../TransportDeleteInferenceModelAction.java | 2 +- .../TransportGetInferenceModelAction.java | 2 +- .../action/TransportInferenceAction.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- .../ShardBulkInferenceActionFilter.java | 343 +++++++++ ...r.java => InferenceResultFieldMapper.java} | 68 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../mapper}/SemanticTextModelSettings.java | 11 +- ...elRegistryImpl.java => ModelRegistry.java} | 82 ++- .../ShardBulkInferenceActionFilterTests.java | 344 +++++++++ ...a => InferenceResultFieldMapperTests.java} | 147 ++-- ...ImplTests.java => ModelRegistryTests.java} | 34 +- .../inference/10_semantic_text_inference.yml | 48 +- .../20_semantic_text_field_mapper.yml | 20 +- 34 files changed, 1102 insertions(+), 1618 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/ModelRegistry.java delete mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java delete mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryImplIT.java => ModelRegistryIT.java} (86%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapper.java => 
InferenceResultFieldMapper.java} (84%) rename {server/src/main/java/org/elasticsearch/inference => x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper}/SemanticTextModelSettings.java (92%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImpl.java => ModelRegistry.java} (86%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapperTests.java => InferenceResultFieldMapperTests.java} (79%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImplTests.java => ModelRegistryTests.java} (92%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 2b84ec8746cd2..452a9ec90443a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -36,8 +35,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -47,7 +44,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -73,8 +69,6 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; - private final InferenceServiceRegistry inferenceServiceRegistry; - private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -88,8 +82,6 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { super(listener); @@ -105,8 +97,6 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; - this.inferenceServiceRegistry = inferenceServiceRegistry; - this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -199,30 +189,7 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInstance( - inferenceServiceRegistry, - modelRegistry, - clusterState, - requestsByShard.keySet(), - new ActionListener() { - @Override - public void 
onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { - processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); - } - - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error loading inference models", e); - } - } - ); - } - - void processRequestsByShards( - Map> requestsByShard, - ClusterState clusterState, - BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider - ) { + String nodeId = clusterService.localNode().getId(); Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -230,68 +197,33 @@ void processRequestsByShards( // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); - BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); - - Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); - bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { - @Override - public void onResponse(BulkShardRequest inferenceBulkShardRequest) { - executeBulkShardRequest( - inferenceBulkShardRequest, - ActionListener.releaseAfter(ActionListener.noop(), ref), - bulkItemFailedListener - ); - } - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error performing inference", e); - } - }, bulkItemFailedListener); + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); + } + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); } } } - private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } - return bulkShardRequest; - } - - // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { - final String indexName = itemRequest.index(); - - DocWriteRequest docWriteRequest = itemRequest.request(); - 
BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - } - - private void executeBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer bulkItemErrorListener - ) { - if (bulkShardRequest.items().length == 0) { - // No requests to execute due to previous errors, terminate early - listener.onResponse(bulkShardRequest); - return; - } - + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -302,17 +234,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - listener.onResponse(bulkShardRequest); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - BulkItemRequest[] items = bulkShardRequest.items(); - for (BulkItemRequest item : items) { - bulkItemErrorListener.accept(item, e); + for (BulkItemRequest request : bulkShardRequest.items()) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - listener.onFailure(e); + releaseOnFinish.close(); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..39fa791a3e27d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -33,6 +34,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +47,30 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. + */ + public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { + this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. 
+ */ + public FieldInferenceMetadata consumeFieldInferenceMetadata() { + FieldInferenceMetadata ret = fieldsInferenceMetadataMap; + fieldsInferenceMetadataMap = null; + return ret; + } + + /** + * Public for test + */ + public FieldInferenceMetadata getFieldsInferenceMetadataMap() { + return fieldsInferenceMetadataMap; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java deleted file mode 100644 index e80530f75cf4b..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; -import org.elasticsearch.common.TriConsumer; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; -import java.util.stream.Collectors; - -/** - * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. 
- */ -public class BulkShardRequestInferenceProvider { - - // Root field name for storing inference results - public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; - - // Contains the original text for the field - - public static final String INFERENCE_RESULTS = "inference_results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - private final ClusterState clusterState; - private final Map inferenceProvidersMap; - - private record InferenceProvider(Model model, InferenceService service) { - private InferenceProvider { - Objects.requireNonNull(model); - Objects.requireNonNull(service); - } - } - - BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { - this.clusterState = clusterState; - this.inferenceProvidersMap = inferenceProvidersMap; - } - - public static void getInstance( - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - ClusterState clusterState, - Set shardIds, - ActionListener listener - ) { - Set inferenceIds = shardIds.stream() - .map(ShardId::getIndex) - .collect(Collectors.toSet()) - .stream() - .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values()) - .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)) - .collect(Collectors.toSet()); - final Map inferenceProviderMap = new ConcurrentHashMap<>(); - Runnable onModelLoadingComplete = () -> listener.onResponse( - new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) - ); - try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { - for (var inferenceId : inferenceIds) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - InferenceProvider inferenceProvider = new InferenceProvider( - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ), - service.get() - ); - inferenceProviderMap.put(inferenceId, inferenceProvider); - } - } - - @Override - public void onFailure(Exception e) { - // Failure on loading a model should not prevent the rest from being loaded and used. - // When the model is actually retrieved via the inference ID in the inference process, it will fail - // and the user will get the details on the inference failure. - } - }; - - modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); - } - } - } - - /** - * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from - * the original request will be modified with the inference results, to avoid copying the entire requests from - * the original bulk request. - * - * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. 
- * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, - * which may have fewer items than the original because of inference failures - * @param onBulkItemFailure invoked when a bulk item fails inference - */ - public void processBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer onBulkItemFailure - ) { - - Map> fieldsForInferenceIds = getFieldsForInferenceIds( - clusterState.metadata().index(bulkShardRequest.shardId().getIndex()).getFieldInferenceMetadata().getFieldInferenceOptions() - ); - // No inference fields? Terminate early - if (fieldsForInferenceIds.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - - Set failedItems = Collections.synchronizedSet(new HashSet<>()); - Runnable onInferenceComplete = () -> { - if (failedItems.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - // Remove failed items from the original bulk shard request - BulkItemRequest[] originalItems = bulkShardRequest.items(); - BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; - for (int i = 0, j = 0; i < originalItems.length; i++) { - if (failedItems.contains(i) == false) { - newItems[j++] = originalItems[i]; - } - } - BulkShardRequest newBulkShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - newItems - ); - listener.onResponse(newBulkShardRequest); - }; - TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> { - failedItems.add(i); - onBulkItemFailure.accept(bulkItemRequest, e); - }; - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - BulkItemRequest[] items = bulkShardRequest.items(); - for (int i = 0; i < items.length; i++) { - BulkItemRequest bulkItemRequest = items[i]; - // Bulk item might be null because of previous errors, skip in that case - if (bulkItemRequest != null) { - performInferenceOnBulkItemRequest( - bulkItemRequest, - fieldsForInferenceIds, - i, - onBulkItemFailureWithIndex, - bulkItemReqRef.acquire() - ); - } - } - } - } - - private static Map> getFieldsForInferenceIds( - Map fieldInferenceMap - ) { - Map> fieldsForInferenceIdsMap = new HashMap<>(); - for (Map.Entry entry : fieldInferenceMap.entrySet()) { - String fieldName = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); - - // Get or create the set associated with the inferenceId - Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); - fields.add(fieldName); - } - - return fieldsForInferenceIdsMap; - } - - @SuppressWarnings("unchecked") - private void performInferenceOnBulkItemRequest( - BulkItemRequest bulkItemRequest, - Map> fieldsForModels, - Integer itemIndex, - TriConsumer onBulkItemFailure, - Releasable releaseOnFinish - ) { - - DocWriteRequest docWriteRequest = bulkItemRequest.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; - } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - Map rootInferenceFieldMap; - try { - rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_INFERENCE_FIELD, - k -> new HashMap() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object") - ); - return; - } - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - if (inferenceFieldNames.isEmpty()) { - continue; - } - - InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); - if (inferenceProvider == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("No inference provider found for model ID " + modelId) - ); - return; - } - ActionListener inferenceResultsListener = new ActionListener<>() { - @Override - public void onResponse(InferenceServiceResults results) { - if (results == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ) - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { - String inferenceFieldName = inferenceFieldNames.get(i++); - Map inferenceFieldResult = new LinkedHashMap<>(); - inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap()); - inferenceFieldResult.put( - INFERENCE_RESULTS, - List.of( - Map.of( - INFERENCE_CHUNKS_RESULTS, - inferenceResults.asMap("output").get("output"), - INFERENCE_CHUNKS_TEXT, - docMap.get(inferenceFieldName) - ) - ) - ); - rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); - } - } - - @Override - public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkItemRequest, itemIndex, e); - } - }; - inferenceProvider.service() - .infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - InputType.INGEST, - ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) - ); - } - } - } - - private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String) { - inferenceFieldNames.add(inferenceField); - } - } - return inferenceFieldNames; - } -} diff --git 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index b05464b3a10c2..a2445e95a572f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,8 +57,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -100,8 +98,6 @@ public class TransportBulkAction extends HandledTransportAction responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -649,10 +633,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { new BulkOperation( task, @@ -666,8 +650,6 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, - modelRegistry, - inferenceServiceRegistry, listener ).run(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index c8dc3e7b7ffd5..f65d0f462fde6 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -58,9 +58,7 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime, - null, - null + System::nanoTime ); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 85221896f35fd..f4a9e1727abd6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1112,6 +1112,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java 
b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index ce6f1b21b734c..d5973807d9d78 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -13,41 +13,49 @@ import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry implements Closeable { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } -public interface InferenceServiceRegistry extends Closeable { - void init(Client client); - - Map getServices(); - - Optional getService(String serviceName); - - List getNamedWriteables(); - - class NoopInferenceServiceRegistry implements InferenceServiceRegistry { - public NoopInferenceServiceRegistry() {} + public void init(Client client) { + services.values().forEach(s -> s.init(client)); + } - @Override - public void init(Client client) {} + public Map getServices() { + return services; + } - @Override - public Map getServices() { - return Map.of(); - } + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } - @Override - public Optional getService(String serviceName) { - return Optional.empty(); - } + public List getNamedWriteables() { + return namedWriteables; + } - @Override - public List getNamedWriteables() { - return List.of(); + @Override + public void close() throws IOException { + for (var service : services.values()) { + service.close(); } - - @Override - public void close() throws IOException {} } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java deleted file mode 100644 index f0a990ded98ce..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class InferenceServiceRegistryImpl implements InferenceServiceRegistry { - - private final Map services; - private final List namedWriteables = new ArrayList<>(); - - public InferenceServiceRegistryImpl( - List inferenceServicePlugins, - InferenceServiceExtension.InferenceServiceFactoryContext factoryContext - ) { - // TODO check names are unique - services = inferenceServicePlugins.stream() - .flatMap(r -> r.getInferenceServiceFactories().stream()) - .map(factory -> factory.create(factoryContext)) - .collect(Collectors.toMap(InferenceService::name, Function.identity())); - } - - @Override - public void init(Client client) { - services.values().forEach(s -> s.init(client)); - } - - @Override - public Map getServices() { - return services; - } - - @Override - public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); - } - - @Override - public List getNamedWriteables() { - return namedWriteables; - } - - @Override - public void close() throws IOException { - for (var service : services.values()) { - service.close(); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java deleted file mode 100644 index fa90d5ba6f756..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; -import java.util.Map; - -public interface ModelRegistry { - - /** - * Get a model. - * Secret settings are not included - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModel(String inferenceEntityId, ActionListener listener); - - /** - * Get a model with its secret settings - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModelWithSecrets(String inferenceEntityId, ActionListener listener); - - /** - * Get all models of a particular task type. - * Secret settings are not included - * @param taskType The task type - * @param listener Models listener - */ - void getModelsByTaskType(TaskType taskType, ActionListener> listener); - - /** - * Get all models. - * Secret settings are not included - * @param listener Models listener - */ - void getAllModels(ActionListener> listener); - - void storeModel(Model model, ActionListener listener); - - void deleteModel(String modelId, ActionListener listener); - - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. 
- */ - record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) {} - - class NoopModelRegistry implements ModelRegistry { - @Override - public void getModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void getAllModels(ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void storeModel(Model model, ActionListener listener) { - fail(listener); - } - - @Override - public void deleteModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - fail(listener); - } - - private static void fail(ActionListener listener) { - listener.onFailure(new IllegalArgumentException("No model registry configured")); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 15ebe2752451d..5bf19c4b87157 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -127,8 +127,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -147,7 +145,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1114,18 +1111,6 @@ record PluginServiceInstances( ); } - // Register noop versions of inference services if Inference plugin is not available - Optional inferenceRegistryPlugin = getSinglePlugin(InferenceRegistryPlugin.class); - modules.bindToInstance( - InferenceServiceRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getInferenceServiceRegistry) - .orElse(new InferenceServiceRegistry.NoopInferenceServiceRegistry()) - ); - modules.bindToInstance( - ModelRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getModelRegistry).orElse(new ModelRegistry.NoopModelRegistry()) - ); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java deleted file mode 100644 index 696c3a067dad1..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; - -/** - * Plugins that provide inference services should implement this interface. - * There should be a single one in the classpath, as we currently support a single instance for ModelRegistry / InfereceServiceRegistry. - */ -public interface InferenceRegistryPlugin { - InferenceServiceRegistry getInferenceServiceRegistry(); - - ModelRegistry getModelRegistry(); -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java deleted file mode 100644 index c3887f506b891..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ /dev/null @@ -1,670 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; -import org.elasticsearch.cluster.metadata.IndexAbstraction; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.service.ClusterApplierService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatcher; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import 
java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyMap; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -public class BulkOperationTests extends ESTestCase { - - private static final String INDEX_NAME = "test-index"; - private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; - private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; - private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; - private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; - private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; - private static final String SERVICE_1_ID = "elser_v2"; - private static final String SERVICE_2_ID = "e5"; - private static final String INFERENCE_FAILED_MSG = "Inference failed"; - private static TestThreadPool threadPool; - - public void testNoInference() { - - FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> 
assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference not invoked - verifyNoMoreInteractions(modelRegistry); - verifyNoMoreInteractions(inferenceServiceRegistry); - } - - private static Model mockModel(String inferenceServiceId) { - Model model = mock(Model.class); - - when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); - TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; - when(model.getTaskType()).thenReturn(taskType); - - ServiceSettings serviceSettings = mock(ServiceSettings.class); - when(model.getServiceSettings()).thenReturn(serviceSettings); - SimilarityMeasure similarity = switch (randomInt(2)) { - case 0 -> SimilarityMeasure.COSINE; - case 1 -> SimilarityMeasure.DOT_PRODUCT; - default -> null; - }; - when(serviceSettings.similarity()).thenReturn(similarity); - when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); - - return model; - } - - public void testFailedBulkShardRequest() { - - FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; - ModelRegistry modelRegistry = createModelRegistry(Map.of()); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - true, - request -> new BulkShardResponse( - request.shardId(), - new BulkItemResponse[] { - BulkItemResponse.failure( - 0, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure( - INDEX_NAME, - randomIdentifier(), - new IllegalArgumentException("Error on bulk shard request") - ) - ) } - ) - ); - verify(bulkOperationListener).onResponse(any()); - - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse[] items = bulkResponse.getItems(); - assertTrue(items[0].isFailed()); - } - - @SuppressWarnings("unchecked") - public void testInference() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - SECOND_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - INFERENCE_FIELD_SERVICE_2, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) - ) - ); - - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - String 
firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - SECOND_INFERENCE_FIELD_SERVICE_1, - secondInferenceTextService1, - INFERENCE_FIELD_SERVICE_2, - inferenceTextService2, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference results - verifyInferenceServiceInvoked( - modelRegistry, - INFERENCE_SERVICE_1_ID, - inferenceService1, - model1, - List.of(firstInferenceTextService1, secondInferenceTextService1) - ); - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); - checkInferenceResults( - originalSource, - writtenDocSource, - FIRST_INFERENCE_FIELD_SERVICE_1, - SECOND_INFERENCE_FIELD_SERVICE_1, - INFERENCE_FIELD_SERVICE_2 - ); - } - - public void testFailedInference() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); - - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); - - } - - public void testInferenceFailsForIncorrectRootObject() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - 
Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - "incorrect_root_object" - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); - } - - public void testInferenceIdNotFound() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - SECOND_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - INFERENCE_FIELD_SERVICE_2, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) - ) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceService(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - INFERENCE_FIELD_SERVICE_2, - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) - ); - } - - @SuppressWarnings("unchecked") - private static void checkInferenceResults( - Map docSource, - Map writtenDocSource, - String... 
inferenceFieldNames - ) { - - Map inferenceRootResultField = (Map) writtenDocSource.get( - BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD - ); - - for (String inferenceFieldName : inferenceFieldNames) { - Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); - assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(2)); - Map modelSettings = (Map) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME); - assertNotNull(modelSettings); - assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName())); - assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName())); - - List> inferenceResultElement = (List>) inferenceService1FieldResults.get( - INFERENCE_RESULTS - ); - assertFalse(inferenceResultElement.isEmpty()); - assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); - assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); - } - } - - private static void verifyInferenceServiceInvoked( - ModelRegistry modelRegistry, - String inferenceService1Id, - InferenceService inferenceService, - Model model, - Collection inferenceTexts - ) { - verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfigWithSecrets( - eq(inferenceService1Id), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ); - verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); - verifyNoMoreInteractions(inferenceService); - } - - private static ArgumentMatcher> containsInAnyOrder(Collection expected) { - return new ArgumentMatcher<>() { - @Override - public boolean matches(List argument) { - return argument.containsAll(expected) && argument.size() == expected.size(); - } - - @Override - public String toString() { - return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; - } - }; - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - FieldInferenceMetadata fieldInferenceMetadata, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - boolean expectTransportShardBulkActionToExecute, - ActionListener bulkOperationListener - ) { - return runBulkOperation( - docSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - expectTransportShardBulkActionToExecute, - successfulBulkShardResponse - ); - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - FieldInferenceMetadata fieldInferenceMetadata, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - ActionListener bulkOperationListener, - boolean expectTransportShardBulkActionToExecute, - Function bulkShardResponseSupplier - ) { - Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); - IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldInferenceMetadata(fieldInferenceMetadata) - .settings(settings) - .numberOfShards(1) - .numberOfReplicas(0) - .build(); - ClusterService clusterService = createClusterService(indexMetadata); - - IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class); - when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata)); - - 
BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(new IndexRequest(INDEX_NAME).source(docSource)); - - NodeClient client = mock(NodeClient.class); - - ArgumentCaptor bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class); - doAnswer(invocation -> { - BulkShardRequest request = invocation.getArgument(1); - ActionListener bulkShardResponseListener = invocation.getArgument(2); - bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); - return null; - }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); - - Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); - BulkOperation bulkOperation = new BulkOperation( - task, - threadPool, - ThreadPool.Names.WRITE, - clusterService, - bulkRequest, - client, - new AtomicArray<>(bulkRequest.requests.size()), - new HashMap<>(), - indexResolver, - () -> System.nanoTime(), - System.nanoTime(), - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - - bulkOperation.doRun(); - if (expectTransportShardBulkActionToExecute) { - verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); - return bulkShardRequestCaptor.getValue(); - } - - return null; - } - - private static final Function successfulBulkShardResponse = (request) -> { - return new BulkShardResponse( - request.shardId(), - Arrays.stream(request.items()) - .filter(Objects::nonNull) - .map( - item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) - ) - ) - .toArray(BulkItemResponse[]::new) - ); - }; - - private static InferenceService createInferenceService(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); - List texts = invocation.getArgument(1); - List inferenceResults = new ArrayList<>(); - for (int i = 0; i < texts.size(); i++) { - inferenceResults.add(createInferenceResults()); - } - doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); - - listener.onResponse(inferenceServiceResults); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceService createInferenceServiceThatFails(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceResults createInferenceResults() { - InferenceResults inferenceResults = mock(InferenceResults.class); - when(inferenceResults.asMap(any())).then( - invocation -> Map.of( - (String) 
invocation.getArguments()[0], - Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat()))) - ) - ); - return inferenceResults; - } - - private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { - InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); - return inferenceServiceRegistry; - } - - private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { - ModelRegistry modelRegistry = mock(ModelRegistry.class); - // Fails for unknown inference ids - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IllegalArgumentException("Model not found")); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { - ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - serviceId, - emptyMap(), - emptyMap() - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); - }); - - return modelRegistry; - } - - private static ClusterService createClusterService(IndexMetadata indexMetadata) { - Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); - - ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); - when(clusterService.state()).thenReturn(clusterState); - - ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); - when(clusterApplierService.state()).thenReturn(clusterState); - when(clusterApplierService.threadPool()).thenReturn(threadPool); - when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); - return clusterService; - } - - @BeforeClass - public static void createThreadPool() { - threadPool = new TestThreadPool(getTestClass().getName()); - } - - @AfterClass - public static void stopThreadPool() { - if (threadPool != null) { - threadPool.shutdownNow(); - threadPool = null; - } - } - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 988a92352649a..3057b00553a22 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -129,19 +129,17 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) { @Override void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map 
indicesThatCannotBeCreated ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 2d6492e4e73a4..6815d634292a4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -148,9 +148,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } @@ -159,10 +157,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index ad522e36f9bd9..1a16d9083df55 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -98,9 +98,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index a2e54a1c7c3b8..cb9bdd1f3a827 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } @@ -253,9 +253,7 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider, - null, - null + relativeTimeProvider ); } } diff --git 
a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 7f1b5cdaee598..0a53db94b9aaf 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -2360,9 +2360,7 @@ protected void assertSnapshotOrGenericThread() { actionFilters, indexNameExpressionResolver, new IndexingPressure(settings), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 33bbc94901e9d..b6e48d3b1c29a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -123,15 +123,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java similarity index 86% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ccda986a8d280..0f23e0b33d774 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -26,7 +26,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; @@ -55,13 +55,13 @@ import static org.hamcrest.Matchers.nullValue; import 
static org.mockito.Mockito.mock; -public class ModelRegistryImplIT extends ESSingleNodeTestCase { +public class ModelRegistryIT extends ESSingleNodeTestCase { - private ModelRegistryImpl ModelRegistryImpl; + private ModelRegistry modelRegistry; @Before public void createComponents() { - ModelRegistryImpl = new ModelRegistryImpl(client()); + modelRegistry = new ModelRegistry(client()); } @Override @@ -75,7 +75,7 @@ public void testStoreModel() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertThat(storeModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); @@ -87,7 +87,7 @@ public void testStoreModelWithUnknownFields() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); assertNull(storeModelHolder.get()); assertNotNull(exceptionHolder.get()); @@ -106,12 +106,12 @@ public void testGetModel() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -133,13 +133,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); putModelHolder.set(false); // an model with the same id exists - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(false)); assertThat(exceptionHolder.get(), not(nullValue())); assertThat( @@ -154,20 +154,20 @@ public void testDeleteModel() throws Exception { Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), 
putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference deleteResponseHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertTrue(deleteResponseHolder.get()); // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); @@ -187,13 +187,13 @@ public void testGetModelsByTaskType() throws InterruptedException { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) @@ -204,7 +204,7 @@ public void testGetModelsByTaskType() throws InterruptedException { assertThat(m.secrets().keySet(), empty()); }); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(2)); var denseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) @@ -228,13 +228,13 @@ public void testGetAllModels() throws InterruptedException { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, 
exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -258,18 +258,18 @@ public void testGetModelWithSecrets() throws InterruptedException { AtomicReference exceptionHolder = new AtomicReference<>(); var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); // get model without secrets - blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 31aae67770c98..2a9c300e12c13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -26,11 +27,8 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceRegistryImpl; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -49,6 +47,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import 
org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -56,9 +55,9 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -80,13 +79,9 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin - implements - ActionPlugin, - ExtensiblePlugin, - SystemIndexPlugin, - InferenceRegistryPlugin, - MapperPlugin { +import static java.util.Collections.singletonList; + +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** * When this setting is true the verification check that @@ -111,8 +106,7 @@ public class InferencePlugin extends Plugin private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce modelRegistry = new SetOnce<>(); - + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -163,7 +157,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelReg = new ModelRegistryImpl(services.client()); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -174,13 +168,14 @@ public Collection createComponents(PluginServices services) { var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly - var inferenceRegistry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext); - inferenceRegistry.init(services.client()); - inferenceServiceRegistry.set(inferenceRegistry); - modelRegistry.set(modelReg); + var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + registry.init(services.client()); + inferenceServiceRegistry.set(registry); + + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); - // Don't return components as they will be registered using InferenceRegistryPlugin methods to retrieve them - return List.of(); + return List.of(modelRegistry, registry); } @Override @@ -279,16 +274,6 @@ public void close() { IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } - @Override - public InferenceServiceRegistry getInferenceServiceRegistry() { - return inferenceServiceRegistry.get(); - } - - @Override - public ModelRegistry getModelRegistry() { - return modelRegistry.get(); - } - @Override public Map 
getMappers() { if (SemanticTextFeature.isEnabled()) { @@ -299,6 +284,11 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + } + + @Override + public Collection getActionFilters() { + return singletonList(shardBulkInferenceActionFilter.get()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index ad6042581f264..b55e2e6f8ebed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 0f7e48c4f8140..2de1aecea118c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -25,6 +24,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ece4fee1c935f..fb3974fc12e8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,11 +16,11 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import 
org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6667e314a62b8..07d28f8e5b0a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -29,7 +29,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -44,6 +43,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..fbf84762eb314 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,343 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilter; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.MappedActionFilter; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} + * in the subsequent {@link TransportShardBulkAction} downstream. 
+ */ +public class ShardBulkInferenceActionFilter implements MappedActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public String actionName() { + return TransportShardBulkAction.ACTION_NAME; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + FieldInferenceMetadata fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + private record FieldInferenceRequest(int id, String field, String input) {} + + private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} + + private class AsyncBulkShardInferenceAction implements Runnable { + private final FieldInferenceMetadata fieldInferenceMetadata; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + FieldInferenceMetadata fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMetadata = fieldInferenceMetadata; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void 
executeShardBulkInferenceAsync( + final String inferenceId, + @Nullable InferenceProvider inferenceProvider, + final List requests, + final Releasable onFinish + ) { + if (inferenceProvider == null) { + ActionListener modelLoadingListener = new ActionListener<>() { + @Override + public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) + ); + executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + } else { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException( + "Inference id [{}] not found for field [{}]", + inferenceId, + request.field + ) + ); + } + } + } + } + + @Override + public void onFailure(Exception exc) { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field) + ); + } + } + } + }; + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); + return; + } + final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { + @Override + public void onResponse(List results) { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + } + } + + @Override + public void onFailure(Exception exc) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } + }; + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + inputs, + Map.of(), + InputType.INGEST, + new ChunkingOptions(null, null), + ActionListener.runAfter(completionListener, onFinish::close) + ); + } + + /** + * Applies the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is mark as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results. 
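
The async action above groups field inference requests by inference id, resolves the model once per group, runs chunked inference for the whole batch, and releases a reference-counted completion callback when every group has finished. The following is a minimal, self-contained sketch of that batching-and-accumulation pattern using only the JDK; the FieldRequest/FieldResult records and the runInference method are hypothetical stand-ins, not the Elasticsearch classes in this patch, and the actual filter completes asynchronously through ActionListener rather than synchronously as shown here.

    import java.util.*;
    import java.util.concurrent.atomic.AtomicInteger;

    public class BatchedInferenceSketch {
        record FieldRequest(int itemId, String inferenceId, String field, String input) {}
        record FieldResult(int itemId, String field, String embedding) {}

        public static void main(String[] args) {
            List<FieldRequest> requests = List.of(
                new FieldRequest(0, "my-elser", "title", "these are not the droids"),
                new FieldRequest(1, "my-elser", "body", "you're looking for"),
                new FieldRequest(1, "my-e5", "summary", "move along")
            );

            // Group requests by inference id so each model is resolved and invoked once per batch.
            Map<String, List<FieldRequest>> byInferenceId = new LinkedHashMap<>();
            for (FieldRequest r : requests) {
                byInferenceId.computeIfAbsent(r.inferenceId(), k -> new ArrayList<>()).add(r);
            }

            List<FieldResult> results = Collections.synchronizedList(new ArrayList<>());
            // Poor man's RefCountingRunnable: fire the completion callback after the last batch finishes.
            AtomicInteger pending = new AtomicInteger(byInferenceId.size());
            Runnable onCompletion = () -> System.out.println("all batches done: " + results);

            for (var entry : byInferenceId.entrySet()) {
                runInference(entry.getKey(), entry.getValue(), results, () -> {
                    if (pending.decrementAndGet() == 0) {
                        onCompletion.run();
                    }
                });
            }
        }

        // Stand-in for the async chunked-inference call; here it completes synchronously.
        static void runInference(String inferenceId, List<FieldRequest> batch, List<FieldResult> results, Runnable onFinish) {
            try {
                for (FieldRequest r : batch) {
                    results.add(new FieldResult(r.itemId(), r.field(), inferenceId + ":embedding(" + r.input() + ")"));
                }
            } finally {
                onFinish.run();
            }
        }
    }
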
+ */ + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } + return; + } + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + Map newDocMap = indexRequest.sourceAsMap(); + Map inferenceMap = new LinkedHashMap<>(); + // ignore the existing inference map if any + newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + for (FieldInferenceResponse fieldResponse : response.responses()) { + try { + InferenceResultFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } + } + indexRequest.source(newDocMap); + } + + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security) + continue; + } + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + if (indexRequest == null) { + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); + } + } + } + return fieldRequestsMap; + } + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java similarity index 84% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index ad1e0f8c8cb81..2ede5419ab74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -8,7 +8,8 @@ package 
org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; -import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -28,23 +29,27 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.SemanticTextModelSettings; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; - /** * A mapper for the {@code _semantic_text_inference} field. *
@@ -57,7 +62,7 @@ * { * "_source": { * "my_semantic_text_field": "these are not the droids you're looking for", - * "_semantic_text_inference": { + * "_inference": { * "my_semantic_text_field": [ * { * "sparse_embedding": { @@ -100,12 +105,17 @@ * } * */ -public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { - public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = ROOT_INFERENCE_FIELD; - public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); +public class InferenceResultFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String RESULTS = "results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); @@ -132,7 +142,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { } } - private SemanticTextInferenceResultFieldMapper() { + public InferenceResultFieldMapper() { super(SemanticTextInferenceFieldType.INSTANCE); } @@ -173,7 +183,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + if (RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, @@ -329,4 +339,34 @@ protected String contentType() { public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index d9e18728615ba..83272a10f98d4 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link SemanticTextInferenceResultFieldMapper}. + * be indexed using {@link InferenceResultFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java similarity index 92% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 3561c2351427c..1b6bb22c0d6b5 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -1,13 +1,15 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.inference; +package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,7 +21,6 @@ /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
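
To make the stored shape concrete: applyFieldInference above builds, per semantic_text field, an entry under the "_inference" metadata field containing the model settings plus a "results" list of { "text", "inference" } chunks. The sketch below constructs that nested source map with plain JDK collections; only the "_inference", "results", "text" and "inference" keys come from this patch, while the model-settings keys shown ("task_type", "inference_id") are assumptions for illustration.

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class InferenceSourceSketch {
        public static void main(String[] args) {
            Map<String, Object> chunk = new LinkedHashMap<>();
            chunk.put("text", "these are not the droids you're looking for");
            chunk.put("inference", Map.of("droids", 0.91f, "looking", 0.34f)); // sparse embedding tokens

            Map<String, Object> fieldEntry = new LinkedHashMap<>();
            fieldEntry.put("task_type", "sparse_embedding");   // assumed model-settings key
            fieldEntry.put("inference_id", "my-elser");         // assumed model-settings key
            fieldEntry.put("results", List.of(chunk));

            Map<String, Object> inferenceMap = new LinkedHashMap<>();
            inferenceMap.put("my_semantic_text_field", fieldEntry);

            Map<String, Object> source = new LinkedHashMap<>();
            source.put("my_semantic_text_field", "these are not the droids you're looking for");
            source.put("_inference", inferenceMap); // metadata field consumed by InferenceResultFieldMapper

            System.out.println(source);
        }
    }
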
- * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider} */ public class SemanticTextModelSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 40921cd38f181..0f3aa5b82b189 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -24,7 +24,6 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; -import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -32,7 +31,6 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; @@ -57,21 +55,49 @@ import static org.elasticsearch.core.Strings.format; -public class ModelRegistryImpl implements ModelRegistry { +public class ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} + /** + * Semi parsed model where inference entity id, task type and service + * are known but the settings are not parsed. 
+ */ + public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) { + + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } + } + private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; - private static final Logger logger = LogManager.getLogger(ModelRegistryImpl.class); + private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; - @Inject - public ModelRegistryImpl(Client client) { + public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } - @Override + /** + * Get a model with its secret settings + * @param inferenceEntityId Model to get + * @param listener Model listener + */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -80,7 +106,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -101,7 +132,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -116,7 +147,12 @@ public void getModel(String inferenceEntityId, ActionListener lis client.search(modelSearch, searchListener); } - @Override + /** + * Get all models of a particular task type. 
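
A brief illustration of the "semi-parsed model" pattern used by UnparsedModel above: the identity fields are removed from the raw config map and the remainder is kept for the service to parse later. The key names "model_id", "service" and "task_type" mirror the fields used in this patch; the record name, helper and exception type below are simplified stand-ins rather than the real ServiceUtils API.

    import java.util.HashMap;
    import java.util.Map;

    public class UnparsedModelSketch {
        record UnparsedModel(String inferenceEntityId, String taskType, String service, Map<String, Object> settings) {}

        static String removeStringOrThrow(Map<String, Object> map, String key) {
            Object value = map.remove(key);
            if (value instanceof String s) {
                return s;
            }
            throw new IllegalArgumentException("Missing required field [" + key + "]");
        }

        static UnparsedModel fromConfig(Map<String, Object> config) {
            if (config == null) {
                throw new IllegalArgumentException("Missing config map");
            }
            String id = removeStringOrThrow(config, "model_id");
            String service = removeStringOrThrow(config, "service");
            String taskType = removeStringOrThrow(config, "task_type");
            return new UnparsedModel(id, taskType, service, config); // remaining keys are the unparsed settings
        }

        public static void main(String[] args) {
            Map<String, Object> config = new HashMap<>();
            config.put("model_id", "my-elser");
            config.put("service", "elser");
            config.put("task_type", "sparse_embedding");
            config.put("service_settings", Map.of("num_allocations", 1));
            System.out.println(fromConfig(config));
        }
    }
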
+ * Secret settings are not included + * @param taskType The task type + * @param listener Models listener + */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -125,7 +161,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -150,7 +190,7 @@ public void getAllModels(ActionListener> listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); delegate.onResponse(modelConfigs); }); @@ -217,7 +257,6 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEnt ); } - @Override public void storeModel(Model model, ActionListener listener) { ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); @@ -314,7 +353,6 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } - @Override public void deleteModel(String inferenceEntityId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); @@ -339,16 +377,4 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } - - private static UnparsedModel unparsedModelFromMap(ModelRegistryImpl.ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..4a1825303b5a7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,344 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new 
ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setFieldInferenceMetadata( + new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + ); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( + Map.of( + "field1", + new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), + "field2", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), + "field3", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) + ) + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map inferenceFieldsMap = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + } + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, 
inferenceModelMap); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setFieldInferenceMetadata(fieldInferenceMetadata); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) invocationOnMock.getArguments()[1]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[5]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + 
InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + int id, + Map modelMap, + FieldInferenceMetadata fieldInferenceMetadata + ) { + Map docMap = new LinkedHashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + var model = modelMap.get(entry.getValue().inferenceId()); + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; + + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; + + default: + throw new AssertionError("Unknown task type " + taskType.name()); + } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + } + Map expectedDocMap = new LinkedHashMap<>(docMap); + expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + return new BulkItemRequest[] { + new BulkItemRequest(id, new IndexRequest("index").source(docMap)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + } + + private static StaticModel randomStaticModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new StaticModel( + inferenceId, + randomBoolean() ? 
TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults results) { + resultMap.put(text, results); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java similarity index 79% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java index 319f6ef73fa56..b5d75b528c6ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java @@ -31,49 +31,46 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.SemanticTextModelSettings; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static 
org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; -public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List text) { - private SemanticTextInferenceResults { - if (sparseEmbeddingResults.embeddings().size() != text.size()) { - throw new IllegalArgumentException("Sparse embeddings and text must be the same size"); - } - } - } +public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int sparseVectorDims) {} + private record VisitedChildDocInfo(String path, int numChunks) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return SemanticTextInferenceResultFieldMapper.NAME; + return InferenceResultFieldMapper.NAME; } @Override @@ -109,8 +106,8 @@ public void testSuccessfulParse() throws IOException { b -> addSemanticTextInferenceResults( b, List.of( - generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")), - generateSemanticTextinferenceResults(fieldName2, List.of("d e f")) + randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) ) ) ) @@ -209,10 +206,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, - null + Map.of() ) ) ) @@ -227,10 +224,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, - null + Map.of() ) ) ) @@ -245,10 +242,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, - null + Map.of() ) ) ) @@ -263,7 +260,7 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); final List semanticTextInferenceResultsList = List.of( - generateSemanticTextinferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, List.of("a b")) ); DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); @@ -361,7 +358,7 @@ public void testMissingSemanticTextMapping() throws IOException 
{ DocumentParsingException.class, DocumentParsingException.class, () -> documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))))) + source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) ) ); assertThat( @@ -379,18 +376,32 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin mappingBuilder.endObject(); } - private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List chunks) { - List embeddings = new ArrayList<>(chunks.size()); - for (String chunk : chunks) { - String[] tokens = chunk.split("\\s+"); - List weightedTokens = Arrays.stream(tokens) - .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat())) - .toList(); + public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[5]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } - embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); } + return new ChunkedSparseEmbeddingResults(chunks); + } - return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { + return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); } private static void addSemanticTextInferenceResults( @@ -402,10 +413,11 @@ private static void addSemanticTextInferenceResults( semanticTextInferenceResults, new SparseVectorSubfieldOptions(true, true, true), true, - null + Map.of() ); } + @SuppressWarnings("unchecked") private static void addSemanticTextInferenceResults( XContentBuilder sourceBuilder, List semanticTextInferenceResults, @@ -413,48 +425,39 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - - Map> inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - Map fieldMap = new HashMap<>(); - fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap()); - List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); - - Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() - .embeddings() - .iterator(); - Iterator textIterator = semanticTextInferenceResult.text().iterator(); - while (embeddingsIterator.hasNext() && textIterator.hasNext()) { - SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next(); - String text = textIterator.next(); - - Map subfieldMap = new HashMap<>(); - if 
(sparseVectorSubfieldOptions.include()) { - subfieldMap.put(INFERENCE_CHUNKS_RESULTS, embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); - } - if (includeTextSubfield) { - subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); + InferenceResultFieldMapper.applyFieldInference( + inferenceResultsMap, + semanticTextInferenceResult.fieldName, + randomModel(), + semanticTextInferenceResult.results + ); + Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); + List> fieldResultList = (List>) optionsMap.get(RESULTS); + for (var entry : fieldResultList) { + if (includeTextSubfield == false) { + entry.remove(INFERENCE_CHUNKS_TEXT); } - if (extraSubfields != null) { - subfieldMap.putAll(extraSubfields); + if (sparseVectorSubfieldOptions.include == false) { + entry.remove(INFERENCE_CHUNKS_RESULTS); } - - parsedInferenceResults.add(subfieldMap); + entry.putAll(extraSubfields); } - - fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } - - sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); } - private static Map modelSettingsMap() { - return Map.of( - SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), - TaskType.SPARSE_EMBEDDING.toString(), - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - randomAlphaOfLength(8) + private static Model randomModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java similarity index 92% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index fd6a203450c12..2417148c84ac2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ModelRegistryImplTests extends ESTestCase { +public class ModelRegistryTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -65,9 +65,9 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var client = mockClient(); mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY)); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); @@ -79,9 +79,9 @@ public void 
testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); @@ -96,9 +96,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -113,9 +113,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -147,9 +147,9 @@ public void testGetModelWithSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); var modelConfig = listener.actionGet(TIMEOUT); @@ -176,9 +176,9 @@ public void testGetModelNoSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModel("1", listener); registry.getModel("1", listener); @@ -201,7 +201,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -218,7 +218,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -249,7 +249,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); 
+ var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -275,7 +275,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistryImpl(client); + var registry = new ModelRegistry(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index ead7f904ad57b..6008ebbcbedf8 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -83,11 +83,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +120,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference - - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference + - exists: _source._inference.inference_field.results.0.inference + - exists: _source._inference.another_inference_field.results.0.inference --- @@ -154,8 +154,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +174,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: 
"another non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +214,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.results.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +233,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -271,11 +271,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -292,7 +292,7 @@ setup: type: text - do: - catch: bad_request + catch: missing index: index: incorrect-test-sparse-index id: doc_1 @@ -300,7 +300,7 @@ setup: inference_field: "inference test" non_inference_field: "non inference test" - - match: { error.reason: "No inference provider 
found for model ID non-existing-inference-id" } + - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" } # Succeeds when semantic_text field is not used - do: diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index da61e6e403ed8..2c69f49218091 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -56,12 +56,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -83,14 +83,14 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding dimensions: 5 similarity: cosine - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" @@ -105,11 +105,11 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: task_type: sparse_embedding - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -123,11 +123,11 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: sparse_field: model_settings: inference_id: sparse-inference-id - inference_results: + results: - text: "inference test" inference: feature_1: 0.1 @@ -141,12 +141,12 @@ setup: id: doc_1 body: non_inference_field: "you know, for testing" - _semantic_text_inference: + _inference: dense_field: model_settings: inference_id: sparse-inference-id task_type: text_embedding - inference_results: + results: - text: "inference test" inference: [0.1, 0.2, 0.3, 0.4, 0.5] - text: "another inference test" From d4e283dde6a7b4f93c1489ac7ff733100f864376 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 22 Mar 2024 17:31:47 +0000 Subject: [PATCH 09/29] [feature/semantic_text] Register semantic text sub fields in the mapping (#106560) This PR refactors the semantic text field mapper to register its sub fields in the mapping instead of re-creating them each time when parsing documents. It also fixes the generation of these fields in case the semantic text field is defined in an object field. Lastly this change adds a new section called model_settings in the field parameter that is updated by the field mapper when inference results are received from a bulk action. The model settings are available in the fields as soon as the first document with the inference field is ingested and they are used to validate that updates. They are used to ensure consistency between what's used in the bulk action and what's defined in the field mapping. 
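As a sketch of the result, the _inference metadata that the bulk action now writes to _source for a sparse_embedding field looks roughly like the example below (it follows the format documented in InferenceMetadataFieldMapper introduced by this patch; the field name, inference id and token weights are illustrative only):

    {
      "_source": {
        "inference_field": "inference test",
        "_inference": {
          "inference_field": {
            "inference_id": "sparse-inference-id",
            "model_settings": {
              "task_type": "sparse_embedding"
            },
            "chunks": [
              {
                "text": "inference test",
                "inference": {
                  "feature_1": 0.1
                }
              }
            ]
          }
        }
      }
    }

For text_embedding models, model_settings additionally carries dimensions and similarity, which the mapper validates against what is already defined in the field mapping.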
--- .../xcontent/support/XContentMapValues.java | 2 +- .../index/mapper/FieldMapper.java | 8 +- .../elasticsearch/index/mapper/Mapping.java | 2 +- .../vectors/SparseVectorFieldMapper.java | 7 +- .../TestDenseInferenceServiceExtension.java | 2 +- .../xpack/inference/InferencePlugin.java | 4 +- .../ShardBulkInferenceActionFilter.java | 16 +- .../mapper/InferenceMetadataFieldMapper.java | 449 ++++++++++++++++++ .../mapper/InferenceResultFieldMapper.java | 372 --------------- .../mapper/SemanticTextFieldMapper.java | 210 +++++++- .../mapper/SemanticTextModelSettings.java | 136 ++++-- .../SemanticTextClusterMetadataTests.java | 4 +- .../ShardBulkInferenceActionFilterTests.java | 12 +- ...=> InferenceMetadataFieldMapperTests.java} | 392 +++++++++------ .../mapper/SemanticTextFieldMapperTests.java | 235 +++++++-- .../xpack/inference/model/TestModel.java | 11 + .../inference/10_semantic_text_inference.yml | 59 +-- .../20_semantic_text_field_mapper.yml | 97 +--- 18 files changed, 1267 insertions(+), 751 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{InferenceResultFieldMapperTests.java => InferenceMetadataFieldMapperTests.java} (57%) diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java index 805931550ad62..f527b4cd8d684 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java @@ -555,7 +555,7 @@ public static Map nodeMapValue(Object node, String desc) { if (node instanceof Map) { return (Map) node; } else { - throw new ElasticsearchParseException(desc + " should be a hash but was of type: " + node.getClass()); + throw new ElasticsearchParseException(desc + " should be a map but was of type: " + node.getClass()); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 71fd9edd49903..f9354025cab49 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1176,7 +1176,7 @@ public static final class Conflicts { private final String mapperName; private final List conflicts = new ArrayList<>(); - Conflicts(String mapperName) { + public Conflicts(String mapperName) { this.mapperName = mapperName; } @@ -1188,7 +1188,11 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - void check() { + public boolean hasConflicts() { + return conflicts.isEmpty() == false; + } + + public void check() { if (conflicts.isEmpty()) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java index 903e4e5da5b29..da184d6f7a45e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java @@ -76,7 +76,7 @@ public CompressedXContent 
toCompressedXContent() { /** * Returns the root object for the current mapping */ - RootObjectMapper getRoot() { + public RootObjectMapper getRoot() { return root; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 6532abed19044..58286d34dada1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException { } String feature = null; + boolean origIsWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); + if (context.path().isWithinLeafObject() == false) { + context.path().setWithinLeafObject(true); + } for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) { if (token == Token.FIELD_NAME) { feature = context.parser().currentName(); @@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException { context.addToFieldNames(fieldType().name()); } } finally { - context.path().setWithinLeafObject(false); + context.path().setWithinLeafObject(origIsWithLeafObject); } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 54fe6e01946b4..586850eb948d3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -166,7 +166,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2a9c300e12c13..3fcd9049ae803 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,7 +55,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -284,7 +284,7 @@ public 
Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fbf84762eb314..00dc195313a61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -38,7 +38,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +50,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { @@ -267,10 +267,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons Map newDocMap = indexRequest.sourceAsMap(); Map inferenceMap = new LinkedHashMap<>(); // ignore the existing inference map if any - newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { try { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceMap, fieldResponse.field(), fieldResponse.model(), @@ -295,6 +295,7 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); + boolean hasInput = false; for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { String field = entry.getKey(); String inferenceId = entry.getValue().inferenceId(); @@ -315,6 +316,7 @@ private Map> createFieldInferenceRequests(Bu if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + hasInput = true; } else { inferenceResults.get(item.id()).failures.add( new ElasticsearchStatusException( @@ -326,6 +328,12 @@ private Map> createFieldInferenceRequests(Bu ); } } + if (hasInput == false) { + // remove the existing _inference field (if present) since none of the content require inference. 
+ if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { + indexRequest.source(docMap); + } + } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java new file mode 100644 index 0000000000000..9eeb7a5407bc4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -0,0 +1,449 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A mapper for the {@code _inference} field. + *
+ *
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. + * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: + *
+ *
+ *
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_inference": {
+ *             "my_semantic_text_field": {
+ *                  "inference_id": "my_inference_id",
+ *                  "model_settings": {
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "chunks" [
+ *                  "chunks": [
+ *                          "inference": {
+ *                              "lucas": 0.05212344,
+ *                              "ty": 0.041213956,
+ *                              "dragon": 0.50991,
+ *                              "type": 0.23241979,
+ *                              "dr": 1.9312073,
+ *                              "##o": 0.2797593
+ *                          },
+ *                          "text": "these are not the droids you're looking for"
+ *                      }
+ *                  ]
+ *              }
+ *          }
+ *      }
+ * }
+ * 
+ * + * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so: + *
+ *
+ *
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_field": {
+ *             "my_semantic_text_field": {
+ *                      "type": "nested",
+ *                      "properties": {
+ *                          "embedding": {
+ *                              "type": "sparse_vector|dense_vector"
+ *                          },
+ *                          "text": {
+ *                              "type": "keyword",
+ *                              "index": false,
+ *                              "doc_values": false
+ *                          }
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * 
+ */ +public class InferenceMetadataFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String INFERENCE_ID = "inference_id"; + public static final String CHUNKS = "chunks"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); + + private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); + + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + + static class SemanticTextInferenceFieldType extends MappedFieldType { + private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + + SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return null; + } + } + + public InferenceMetadataFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + boolean origWithLeafObject = context.path().isWithinLeafObject(); + try { + // make sure that we don't expand dots in field names while parsing + context.path().setWithinLeafObject(true); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); + parseSingleField(context); + } + } finally { + context.path().setWithinLeafObject(origWithLeafObject); + } + } + + private NestedObjectMapper updateSemanticTextFieldMapper( + DocumentParserContext docContext, + SemanticTextMapperContext semanticFieldContext, + String newInferenceId, + SemanticTextModelSettings newModelSettings, + XContentLocation xContentLocation + ) { + final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); + final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + if (newInferenceId.equals(inferenceId) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID, + inferenceId, + fullFieldName, + INFERENCE_ID, + newInferenceId + ) + ); + } + if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { + throw new DocumentParsingException( + xContentLocation, + "Model settings for field [" + fullFieldName + "] must contain dimensions" + ); + } + if (semanticFieldContext.mapper.getModelSettings() == null) { + SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( + semanticFieldContext.mapper.simpleName(), + docContext.indexSettings().getIndexVersionCreated() + 
).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); + docContext.addDynamicMapper(newMapper); + return newMapper.getSubMappers(); + } else { + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); + } + } + return semanticFieldContext.mapper.getSubMappers(); + } + + private void parseSingleField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); + SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); + if (builderContext == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + ); + } + parser.nextToken(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + + // record the location of the inference field in the original source + XContentLocation xContentLocation = parser.getTokenLocation(); + // parse eagerly to extract the inference id and the model settings first + Map map = parser.mapOrdered(); + + // inference_id + Object inferenceIdObj = map.remove(INFERENCE_ID); + final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); + if (inferenceId == null) { + throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); + } + + // model_settings + Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); + if (modelSettingsObj == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format( + "Missing required [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + } + final SemanticTextModelSettings modelSettings; + try { + modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); + } catch (Exception exc) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "Error parsing [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ), + exc + ); + } + + var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); + + // we know the model settings, so we can (re) parse the results array now + XContentParser subParser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + DocumentParserContext mapContext = context.switchParser(subParser); + parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper); + } + + private void parseFieldInference( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + switch (parser.currentName()) { + case CHUNKS -> parseChunks(xContentLocation, parser, context, 
nestedMapper); + default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); + } + } + } + + private void parseChunks( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { + DocumentParserContext subContext = context.createNestedContext(nestedMapper); + parseResultsObject(xContentLocation, parser, subContext, nestedMapper); + } + } + + private void parseResultsObject( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + Set visited = new HashSet<>(); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); + visited.add(parser.currentName()); + FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); + if (fieldMapper == null) { + if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { + throw new DocumentParsingException( + xContentLocation, + "Missing sub-fields definition for [" + parser.currentName() + "]" + ); + } else { + logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); + advancePastCurrentFieldName(xContentLocation, parser); + continue; + } + } + parser.nextToken(); + fieldMapper.parse(context); + } + if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() + .filter(s -> visited.contains(s) == false) + .collect(Collectors.toSet()); + throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); + } + } + + private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); + } + } + + private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME; + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { + throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) 
{ + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(CHUNKS, chunks); + inferenceMap.put(field, fieldMap); + } + + record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} + + /** + * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} + * and the {@link MapperBuilderContext} that was used to build it. + * If the field is not found or is of the wrong type, this method returns {@code null}. + */ + static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { + ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); + return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\.")); + } + + static SemanticTextMapperContext createSemanticFieldContext( + MapperBuilderContext mapperContext, + ObjectMapper objectMapper, + String[] paths + ) { + Mapper mapper = objectMapper.getMapper(paths[0]); + if (mapper instanceof ObjectMapper newObjectMapper) { + mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); + return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length)); + } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } else { + if (mapper == null || paths.length == 1) { + return null; + } + // check if the semantic field is defined within a multi-field + Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); + if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } + } + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index 2ede5419ab74e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results.
- * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so:
- *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * 
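For comparison, the replacement InferenceMetadataFieldMapper above keys each field's entry on explicit inference_id, model_settings and chunks values (applyFieldInference writes them in that order; for text_embedding models, model_settings also carries dimensions and similarity). A rough sketch of the updated layout for a sparse_embedding model, with an illustrative endpoint name and token weights:

{
    "_source": {
        "my_semantic_text_field": "these are not the droids you're looking for",
        "_inference": {
            "my_semantic_text_field": {
                "inference_id": "my-elser-endpoint",
                "model_settings": {
                    "task_type": "sparse_embedding"
                },
                "chunks": [
                    {
                        "inference": {
                            "dr": 1.9312073,
                            "droid": 0.2797593
                        },
                        "text": "these are not the droids you're looking for"
                    }
                ]
            }
        }
    }
}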
- *
- * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so:
- *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
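With the new SemanticTextFieldMapper.Builder below, the chunk mappers hang off the semantic_text field itself: a chunks nested object (dynamic: false) holding the inference vector field (sparse_vector, or dense_vector with dimensions and similarity for text_embedding models) plus an unindexed keyword text field. The implied structure is roughly equivalent to the following, again assuming a sparse_embedding model and an illustrative inference_id:

{
    "mappings": {
        "properties": {
            "my_semantic_text_field": {
                "type": "semantic_text",
                "inference_id": "my-elser-endpoint",
                "chunks": {
                    "type": "nested",
                    "properties": {
                        "inference": {
                            "type": "sparse_vector"
                        },
                        "text": {
                            "type": "keyword",
                            "index": false,
                            "doc_values": false
                        }
                    }
                }
            }
        }
    }
}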
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 83272a10f98d4..2445d5c8751a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,30 +9,50 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; 
import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; +import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; /** - * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference - * at ingestion and query time. - * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. + * A {@link FieldMapper} for semantic text fields. + * These fields have a reference id reference, that is used for performing inference at ingestion and query time. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link InferenceResultFieldMapper}. + * be indexed using {@link InferenceMetadataFieldMapper}. 
*/ public class SemanticTextFieldMapper extends FieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -40,15 +60,39 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c.indexVersionCreated()), + notInMultiFields(CONTENT_TYPE) + ); + + private final IndexVersion indexVersionCreated; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + CopyTo copyTo, + IndexVersion indexVersionCreated, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers + ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.indexVersionCreated = indexVersionCreated; + this.modelSettings = modelSettings; + this.subMappers = subMappers; + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(subMappers); + return subIterators.iterator(); } @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName()).init(this); + return new Builder(simpleName(), indexVersionCreated).init(this); } @Override @@ -67,39 +111,100 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; + } + public static class Builder extends FieldMapper.Builder { + private final IndexVersion indexVersionCreated; - private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) - .addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); + private final Parameter inferenceId = Parameter.stringParam( + "inference_id", + false, + m -> toType(m).fieldType().inferenceId, + null + ).addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [inference_id] must be specified"); + } + }); + private final Parameter modelSettings = new Parameter<>( + "model_settings", + true, + () -> null, + (n, c, o) -> SemanticTextModelSettings.fromMap(o), + mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, + XContentBuilder::field, + (m) -> m == null ? 
"null" : Strings.toString(m) + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - public Builder(String name) { + public Builder(String name, IndexVersion indexVersionCreated) { super(name); + this.indexVersionCreated = indexVersionCreated; + } + + public Builder setInferenceId(String id) { + this.inferenceId.setValue(id); + return this; + } + + public Builder setModelSettings(SemanticTextModelSettings value) { + this.modelSettings.setValue(value); + return this; } @Override protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta }; + return new Parameter[] { inferenceId, modelSettings, meta }; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + final String fullName = context.buildFullName(name()); + NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); + nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) + .indexed(false) + .docValues(false); + if (modelSettings.get() != null) { + nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + } + nestedBuilder.add(textMapperBuilder); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + var subMappers = nestedBuilder.build(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), + copyTo, + indexVersionCreated, + modelSettings.getValue(), + subMappers + ); } } public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private final String inferenceId; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private final String modelId; - - public SemanticTextFieldType(String name, String modelId, Map meta) { + public SemanticTextFieldType( + String name, + String modelId, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers, + Map meta + ) { super(name, false, false, false, TextSearchInfo.NONE, meta); - this.modelId = modelId; + this.inferenceId = modelId; + this.modelSettings = modelSettings; + this.subMappers = subMappers; } @Override @@ -109,7 +214,15 @@ public String typeName() { @Override public String getInferenceId() { - return modelId; + return inferenceId; + } + + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; } @Override @@ -127,4 +240,59 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return super.syntheticFieldLoader(); + } + + private static Mapper.Builder createInferenceMapperBuilder( + String fieldName, + SemanticTextModelSettings modelSettings, + IndexVersion indexVersionCreated + ) { + return switch (modelSettings.taskType()) { + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + 
case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); + yield denseVectorMapperBuilder; + } + default -> throw new IllegalArgumentException( + "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() + ); + }; + } + + static boolean canMergeModelSettings( + SemanticTextModelSettings previous, + SemanticTextModelSettings current, + FieldMapper.Conflicts conflicts + ) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null) { + return true; + } + if (current == null) { + conflicts.addConflict("model_settings", ""); + return false; + } + conflicts.addConflict("model_settings", ""); + return false; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 1b6bb22c0d6b5..b1d0511008db8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -7,73 +7,100 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; + /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
*/ -public class SemanticTextModelSettings { +public class SemanticTextModelSettings implements ToXContentObject { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); private final TaskType taskType; - private final String inferenceId; private final Integer dimensions; private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public SemanticTextModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; - this.inferenceId = inferenceId; this.dimensions = dimensions; this.similarity = similarity; - } - - public SemanticTextModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); + validate(); } public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); - }); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new SemanticTextModelSettings(taskType, dimensions, similarity); + } + ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } + public static SemanticTextModelSettings fromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, NAME); + if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { + throw new IllegalArgumentException( + "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" + ); + } + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return SemanticTextModelSettings.parse(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + public Map asMap() { Map attrsMap = new HashMap<>(); attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); if (dimensions != null) { attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); } @@ -87,10 +114,6 @@ public TaskType taskType() { return taskType; } - public String inferenceId() { - return inferenceId; - } - public Integer dimensions() { return dimensions; } @@ -98,4 +121,61 @@ public Integer dimensions() { public SimilarityMeasure similarity() { return similarity; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + public void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticTextModelSettings that = (SemanticTextModelSettings) o; + return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, dimensions, similarity); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 
a7d3fcce26116..bf3cc6334433a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -31,7 +31,7 @@ protected Collection> getPlugins() { public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", - client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") + client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); assertEquals( indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), @@ -46,7 +46,7 @@ public void testAddSemanticTextField() throws Exception { final ClusterService clusterService = getInstanceFromNode(ClusterService.class); final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" - { "properties": { "field": { "type": "semantic_text", "model_id": "test_model" }}}"""); + { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}"""); request.indices(new Index[] { indexService.index() }); final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( clusterService.state(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4a1825303b5a7..8b18cf74236a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -32,7 +32,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +51,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -285,7 +285,7 @@ private static BulkItemRequest[] randomBulkItemRequest( final ChunkedInferenceServiceResults results; switch (taskType) { case TEXT_EMBEDDING: - results = randomTextEmbeddings(chunks); + results = randomTextEmbeddings(model, chunks); break; case SPARSE_EMBEDDING: @@ -296,10 
+296,10 @@ private static BulkItemRequest[] randomBulkItemRequest( throw new AssertionError("Unknown task type " + taskType.name()); } model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); return new BulkItemRequest[] { new BulkItemRequest(id, new IndexRequest("index").source(docMap)), new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java similarity index 57% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index b5d75b528c6ab..37e4e5e774bec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -51,26 +53,28 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; -public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} +public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int numChunks) {} + private record VisitedChildDocInfo(String path) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return InferenceResultFieldMapper.NAME; + return 
InferenceMetadataFieldMapper.NAME; } @Override @@ -94,109 +98,129 @@ protected Collection getPlugins() { } public void testSuccessfulParse() throws IOException { - final String fieldName1 = randomAlphaOfLengthBetween(5, 15); - final String fieldName2 = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { - addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); - addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); - })); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) + ) ) ) - ) - ); - - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of( - new VisitedChildDocInfo(fieldName1, 2), - new VisitedChildDocInfo(fieldName1, 1), - new VisitedChildDocInfo(fieldName2, 3) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - - MapperService nestedMapperService = createMapperService(mapping(b -> { - addInferenceResultsNestedMapping(b, fieldName1); - addInferenceResultsNestedMapping(b, fieldName2); - })); - withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - nestedMapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1, 0, null), - new SearchHit.NestedIdentity(fieldName1, 1, null), - new SearchHit.NestedIdentity(fieldName2, 0, null) ); - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - 
assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2); + assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1); + assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null), + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null), + new SearchHit.NestedIdentity(fieldName2 + "." + CHUNKS, 0, null) ); - assertEquals(0, topDocs.totalHits.value); - } - }); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + CHUNKS, + List.of("a") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + CHUNKS, + List.of("a", "b") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." 
+ CHUNKS, + List.of("d") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." + CHUNKS, + List.of("z") + ), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } } public void testMissingSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); { DocumentParsingException ex = expectThrows( @@ -206,7 +230,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, Map.of() @@ -224,7 +248,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, Map.of() @@ -242,7 +266,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, Map.of() @@ -259,15 +283,18 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) ); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); Consumer checkParsedDocument = d -> { Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2)); + Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." 
+ CHUNKS)); List luceneDocs = d.docs(); assertEquals(2, luceneDocs.size()); @@ -358,28 +385,97 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults( + fieldName, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + List.of("a b") + ) + ) + ) + ) ) ); assertThat( ex.getMessage(), containsString( - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ) ); } + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .startObject(SemanticTextModelSettings.NAME) + .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("required [inference_id] is missing")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .startObject(SemanticTextModelSettings.NAME) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString(" Failed to parse [model_settings], required [task_type] is missing")); + } + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("model_id", modelId); + mappingBuilder.field("inference_id", modelId); mappingBuilder.endObject(); } - public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { List chunks = new ArrayList<>(); for (String input : inputs) { - 
double[] values = new double[5]; + double[] values = new double[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = randomDouble(); } @@ -400,8 +496,17 @@ public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List return new ChunkedSparseEmbeddingResults(chunks); } - private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { - return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults( + String semanticTextFieldName, + Model model, + List chunks + ) { + ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(chunks); + default -> throw new AssertionError("unkwnown task type: " + model.getTaskType().name()); + }; + return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks); } private static void addSemanticTextInferenceResults( @@ -425,16 +530,16 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - Map inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceResultsMap, semanticTextInferenceResult.fieldName, - randomModel(), + semanticTextInferenceResult.model, semanticTextInferenceResult.results ); Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(RESULTS); + List> fieldResultList = (List>) optionsMap.get(CHUNKS); for (var entry : fieldResultList) { if (includeTextSubfield == false) { entry.remove(INFERENCE_CHUNKS_TEXT); @@ -445,15 +550,26 @@ private static void addSemanticTextInferenceResults( entry.putAll(extraSubfields); } } - sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); } - private static Model randomModel() { + private static Model randomModel(TaskType taskType) { String serviceName = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomAlphaOfLengthBetween(5, 10); return new TestModel( inferenceId, - TaskType.SPARSE_EMBEDDING, + taskType, serviceName, new TestModel.TestServiceSettings("my-model"), new TestModel.TestTaskSettings(randomIntBetween(1, 100)), @@ -461,29 +577,6 @@ private static Model randomModel() { ); } - private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { - mappingBuilder.startObject(semanticTextFieldName); - { - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - { - mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); - { - mappingBuilder.field("type", "sparse_vector"); - } - mappingBuilder.endObject(); - mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); - { - 
mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); assertNotNull(mapper); @@ -503,12 +596,10 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook private static void assertValidChildDoc( LuceneDocument childDoc, LuceneDocument expectedParent, - Set visitedChildDocs + Collection visitedChildDocs ) { assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add( - new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size()) - ); + visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath())); } private static void assertChildLeafNestedDocument( @@ -524,4 +615,15 @@ private static void assertChildLeafNestedDocument( assertNotNull(leaf.nestedIdentity()); visitedNestedIdentities.add(leaf.nestedIdentity()); } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } + } + assertThat(count, equalTo(expectedCount)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a3a705c9cc902..1b5311ac9effb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -11,11 +11,17 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -26,52 +32,12 @@ import java.util.List; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class SemanticTextFieldMapperTests extends MapperTestCase { - - public void testDefaults() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - 
assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); - - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); - - // No indexable fields - assertTrue(fields.isEmpty()); - } - - public void testModelIdNotPresent() throws IOException { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) - ); - assertThat(e.getMessage(), containsString("field [model_id] must be specified")); - } - - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - - public void testUpdatesToModelIdNotSupported() throws IOException { - MapperService mapperService = createMapperService( - fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) - ); - Exception e = expectThrows( - IllegalArgumentException.class, - () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model"))) - ); - assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); - } - @Override protected Collection getPlugins() { return singletonList(new InferencePlugin(Settings.EMPTY)); @@ -79,7 +45,12 @@ protected Collection getPlugins() { @Override protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text").field("model_id", "test_model"); + b.field("type", "semantic_text").field("inference_id", "test_model"); + } + + @Override + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + return "cannot have nested fields when index is in [index.mode=time_series]"; } @Override @@ -115,4 +86,180 @@ protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) protected IngestScriptSupport ingestScriptSupport() { throw new AssumptionViolatedException("not supported"); } + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testInferenceIdNotPresent() throws IOException { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), containsString("field [inference_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToInferenceIdNotSupported() throws 
IOException { + String fieldName = randomAlphaOfLengthBetween(5, 15); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); + } + + public void testUpdateModelSettings() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + { + Exception exc = expectThrows( + MapperParsingException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .endObject() + .endObject() + ) + ) + ); + assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing")); + } + { + merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "sparse_embedding") + .endObject() + .endObject() + ) + ); + assertSemanticTextField(mapperService, fieldName, true); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]") + ); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "text_embedding") + .field("dimensions", 10) + .field("similarity", "cosine") + .endObject() + .endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [{\"task_type\":\"sparse_embedding\"}] " + + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]" + ) + ); + } + } + } + + static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( + MapperBuilderContext.root(false, false), + mapperService.mappingLookup().getMapping().getRoot(), + fieldName.split("\\.") + ); + Mapper mapper = res.mapper(); + assertNotNull(mapper); + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); + SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; + + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + 
assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); + assertTrue(semanticFieldMapper.getSubMappers() == semanticTextFieldType.getSubMappers()); + assertTrue(semanticFieldMapper.getModelSettings() == semanticTextFieldType.getModelSettings()); + + NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() + .nestedLookup() + .getNestedMappers() + .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS); + assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); + Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); + if (expectedModelSettings) { + assertNotNull(semanticFieldMapper.getModelSettings()); + Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS); + assertNotNull(inferenceMapper); + switch (semanticFieldMapper.getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); + case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); + default -> throw new AssertionError("Invalid task type"); + } + } else { + assertNull(semanticFieldMapper.getModelSettings()); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 75e7ca12c1d56..b64485a3d3fb2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -121,6 +122,16 @@ public void writeTo(StreamOutput out) throws IOException { public ToXContentObject getFilteredXContentObject() { return this; } + + @Override + public SimilarityMeasure similarity() { + return SimilarityMeasure.COSINE; + } + + @Override + public Integer dimensions() { + return 100; + } } public record TestTaskSettings(Integer temperature) implements TaskSettings { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 6008ebbcbedf8..528003e278aeb 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -27,6 +27,7 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, + 
"similarity": "cosine", "api_key": "abc64" }, "task_settings": { @@ -41,10 +42,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -56,10 +57,10 @@ setup: properties: inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id another_inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -83,11 +84,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +121,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- @@ -154,8 +155,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +175,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - 
match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +215,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +234,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -247,10 +248,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -271,11 +272,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -287,7 +288,7 @@ setup: properties: inference_field: type: semantic_text - model_id: non-existing-inference-id + inference_id: non-existing-inference-id non_inference_field: type: text diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 2c69f49218091..27f233436b925 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -27,7 +27,8 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, - "api_key": "abc64" + "api_key": "abc64", + "similarity": "cosine" }, "task_settings": { } @@ -41,10 
+42,10 @@ setup: properties: sparse_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id dense_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -55,25 +56,7 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 - - text: "another inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 + sparse_field: "you know, for testing" --- "Dense vector results format": @@ -82,72 +65,4 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - dimensions: 5 - similarity: cosine - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] - ---- -"Model settings inference id not included": - - do: - catch: /Required \[inference_id\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings task type not included": - - do: - catch: /Required \[task_type\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings dense vector dimensions not included": - - do: - catch: /Model settings for field \[dense_field\] must contain dimensions/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + dense_field: "you know, for testing" From 122e4395d7329547042998f03dbea083ff334c38 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 27 Mar 2024 17:22:49 -0400 Subject: [PATCH 10/29] Fix build error --- .../java/org/elasticsearch/index/mapper/MapperTestCase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java index fa0f0e1b95f54..34ccc4599811b 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java @@ -1030,7 +1030,7 @@ public final void testMinimalIsInvalidInRoutingPath() throws IOException { } } - private String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { if (mapper instanceof FieldMapper fieldMapper && fieldMapper.fieldType().isDimension() == false) { return "All fields that match routing_path must be configured with [time_series_dimension: true] 
" + "or flattened fields with a list of dimensions in [time_series_dimensions] and " From ef3abd96171e7386caf98a6884c48ed1cc24db0a Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 28 Mar 2024 14:25:51 +0000 Subject: [PATCH 11/29] [feature/semantic-text] Simplify the integration of the field inference metadata in `IndexMetadata` (#106743) This change refactors the integration of the field inference metadata in IndexMetadata. Instead of partial diffs, the new class simply sends the entire object as diff if it has changed. This PR also rename the fields and methods related to the inference fields consistently. The inference phase (in the transport shard bulk action) is also changed so that inference is not called if: The document contains a value for the inference input. The document also contains a value for the inference results of that field (in the _inference map). If the document contains no value for the inference input but an inference result for that field, it is marked as failed. --------- Co-authored-by: carlosdelest --- .../cluster/ClusterStateDiffIT.java | 4 +- .../action/bulk/BulkOperation.java | 4 +- .../action/bulk/BulkShardRequest.java | 19 +- .../action/update/TransportUpdateAction.java | 16 +- .../metadata/FieldInferenceMetadata.java | 190 -------------- .../cluster/metadata/IndexMetadata.java | 112 ++++---- .../metadata/InferenceFieldMetadata.java | 125 +++++++++ .../metadata/MetadataCreateIndexService.java | 11 +- .../metadata/MetadataMappingService.java | 8 +- .../index/mapper/FieldTypeLookup.java | 14 - .../index/mapper/InferenceFieldMapper.java | 28 ++ .../index/mapper/InferenceModelFieldType.java | 21 -- .../index/mapper/MapperMergeContext.java | 4 +- .../index/mapper/MappingLookup.java | 35 ++- .../cluster/metadata/IndexMetadataTests.java | 40 ++- .../metadata/InferenceFieldMetadataTests.java | 66 +++++ .../index/mapper/FieldTypeLookupTests.java | 28 -- .../index/mapper/MappingLookupTests.java | 18 -- .../mapper/MockInferenceModelFieldType.java | 45 ---- .../xpack/inference/InferencePlugin.java | 10 +- .../ShardBulkInferenceActionFilter.java | 248 +++++++++++++----- .../mapper/InferenceMetadataFieldMapper.java | 9 +- .../mapper/SemanticTextFieldMapper.java | 74 ++++-- .../SemanticTextClusterMetadataTests.java | 10 +- .../ShardBulkInferenceActionFilterTests.java | 65 ++--- .../inference/10_semantic_text_inference.yml | 107 ++++++-- .../20_semantic_text_field_mapper.yml | 20 ++ 27 files changed, 758 insertions(+), 573 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java create mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java delete mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java create mode 100644 server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java delete mode 100644 test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index fbb3016b925da..e0dbc74567053 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -61,7 +61,7 @@ 
import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; -import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomFieldInferenceMetadata; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -587,7 +587,7 @@ public IndexMetadata randomChange(IndexMetadata part) { builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; case 3: - builder.fieldInferenceMetadata(randomFieldInferenceMetadata(true)); + builder.putInferenceFields(randomInferenceFields()); break; default: throw new IllegalArgumentException("Shouldn't be here"); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index a6439769b51b4..e66426562a92e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -294,8 +294,8 @@ private void executeBulkRequestsByShard( requests.toArray(new BulkItemRequest[0]) ); var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); - if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { - bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); + if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) { + bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields()); } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index 39fa791a3e27d..8d1618b443ace 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,7 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -23,6 +23,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -34,7 +35,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest inferenceFieldMap = null; public BulkShardRequest(StreamInput in) throws IOException { super(in); @@ -51,24 +52,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe * Public for test * Set the transient metadata indicating that this request requires running inference before proceeding. 
*/ - public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { - this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; + public void setInferenceFieldMap(Map fieldInferenceMap) { + this.inferenceFieldMap = fieldInferenceMap; } /** * Consumes the inference metadata to execute inference on the bulk items just once. */ - public FieldInferenceMetadata consumeFieldInferenceMetadata() { - FieldInferenceMetadata ret = fieldsInferenceMetadataMap; - fieldsInferenceMetadataMap = null; + public Map consumeInferenceFieldMap() { + Map ret = inferenceFieldMap; + inferenceFieldMap = null; return ret; } /** * Public for test */ - public FieldInferenceMetadata getFieldsInferenceMetadataMap() { - return fieldsInferenceMetadataMap; + public Map getInferenceFieldMap() { + return inferenceFieldMap; } public long totalSizeInBytes() { diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 63ae56bfbd047..36a47bc7e02e9 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -40,6 +40,7 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; @@ -184,7 +185,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis); switch (result.getResponseResult()) { case CREATED -> { - IndexRequest upsertRequest = result.action(); + IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action()); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference upsertSourceBytes = upsertRequest.source(); client.bulk( @@ -226,7 +227,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< ); } case UPDATED -> { - IndexRequest indexRequest = result.action(); + IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action()); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference indexSourceBytes = indexRequest.source(); client.bulk( @@ -335,4 +336,15 @@ private void handleUpdateFailureWithRetry( } listener.onFailure(cause instanceof Exception ? 
(Exception) cause : new NotSerializableExceptionWrapper(cause)); } + + private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) { + var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields(); + if (inferenceMetadata.isEmpty()) { + return request; + } + Map docMap = request.sourceAsMap(); + docMap.remove(InferenceFieldMapper.NAME); + request.source(docMap); + return request; + } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java deleted file mode 100644 index 349706c139127..0000000000000 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/FieldInferenceMetadata.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.cluster.metadata; - -import org.elasticsearch.cluster.Diff; -import org.elasticsearch.cluster.Diffable; -import org.elasticsearch.cluster.DiffableUtils; -import org.elasticsearch.cluster.SimpleDiffable; -import org.elasticsearch.common.collect.ImmutableOpenMap; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.mapper.MappingLookup; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; - -/** - * Contains field inference information. This is necessary to add to cluster state as inference can be calculated in the coordinator - * node, which not necessarily has mapping information. 
- */ -public class FieldInferenceMetadata implements Diffable, ToXContentFragment { - - private final ImmutableOpenMap fieldInferenceOptions; - - public static final FieldInferenceMetadata EMPTY = new FieldInferenceMetadata(ImmutableOpenMap.of()); - - public FieldInferenceMetadata(MappingLookup mappingLookup) { - ImmutableOpenMap.Builder builder = ImmutableOpenMap.builder(); - mappingLookup.getInferenceIdsForFields().entrySet().forEach(entry -> { - builder.put(entry.getKey(), new FieldInferenceOptions(entry.getValue(), mappingLookup.sourcePaths(entry.getKey()))); - }); - fieldInferenceOptions = builder.build(); - } - - public FieldInferenceMetadata(StreamInput in) throws IOException { - fieldInferenceOptions = in.readImmutableOpenMap(StreamInput::readString, FieldInferenceOptions::new); - } - - public FieldInferenceMetadata(Map fieldsToInferenceMap) { - fieldInferenceOptions = ImmutableOpenMap.builder(fieldsToInferenceMap).build(); - } - - public ImmutableOpenMap getFieldInferenceOptions() { - return fieldInferenceOptions; - } - - public boolean isEmpty() { - return fieldInferenceOptions.isEmpty(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeMap(fieldInferenceOptions, (o, v) -> v.writeTo(o)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(fieldInferenceOptions); - return builder; - } - - public static FieldInferenceMetadata fromXContent(XContentParser parser) throws IOException { - return new FieldInferenceMetadata(parser.map(HashMap::new, FieldInferenceOptions::fromXContent)); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FieldInferenceMetadata that = (FieldInferenceMetadata) o; - return Objects.equals(fieldInferenceOptions, that.fieldInferenceOptions); - } - - @Override - public int hashCode() { - return Objects.hash(fieldInferenceOptions); - } - - @Override - public Diff diff(FieldInferenceMetadata previousState) { - if (previousState == null) { - previousState = EMPTY; - } - return new FieldInferenceMetadataDiff(previousState, this); - } - - static class FieldInferenceMetadataDiff implements Diff { - - public static final FieldInferenceMetadataDiff EMPTY = new FieldInferenceMetadataDiff( - FieldInferenceMetadata.EMPTY, - FieldInferenceMetadata.EMPTY - ); - - private final Diff> fieldInferenceMapDiff; - - private static final DiffableUtils.DiffableValueReader FIELD_INFERENCE_DIFF_VALUE_READER = - new DiffableUtils.DiffableValueReader<>(FieldInferenceOptions::new, FieldInferenceMetadataDiff::readDiffFrom); - - FieldInferenceMetadataDiff(FieldInferenceMetadata before, FieldInferenceMetadata after) { - fieldInferenceMapDiff = DiffableUtils.diff( - before.fieldInferenceOptions, - after.fieldInferenceOptions, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - FieldInferenceMetadataDiff(StreamInput in) throws IOException { - fieldInferenceMapDiff = DiffableUtils.readImmutableOpenMapDiff( - in, - DiffableUtils.getStringKeySerializer(), - FIELD_INFERENCE_DIFF_VALUE_READER - ); - } - - public static Diff readDiffFrom(StreamInput in) throws IOException { - return SimpleDiffable.readDiffFrom(FieldInferenceOptions::new, in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - fieldInferenceMapDiff.writeTo(out); - } - - @Override - public FieldInferenceMetadata apply(FieldInferenceMetadata part) { - return 
new FieldInferenceMetadata(fieldInferenceMapDiff.apply(part.fieldInferenceOptions)); - } - } - - public record FieldInferenceOptions(String inferenceId, Set sourceFields) - implements - SimpleDiffable, - ToXContentFragment { - - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - public static final ParseField SOURCE_FIELDS_FIELD = new ParseField("source_fields"); - - FieldInferenceOptions(StreamInput in) throws IOException { - this(in.readString(), in.readCollectionAsImmutableSet(StreamInput::readString)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(inferenceId); - out.writeStringCollection(sourceFields); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - builder.field(SOURCE_FIELDS_FIELD.getPreferredName(), sourceFields); - builder.endObject(); - return builder; - } - - public static FieldInferenceOptions fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "field_inference_parser", - false, - (args, unused) -> new FieldInferenceOptions((String) args[0], new HashSet<>((List) args[1])) - ); - - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), SOURCE_FIELDS_FIELD); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 89c925427cf88..b66da654f8a1c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -576,6 +576,8 @@ public Iterator> settings() { @Nullable private final MappingMetadata mapping; + private final ImmutableOpenMap inferenceFields; + private final ImmutableOpenMap customData; private final Map> inSyncAllocationIds; @@ -631,7 +633,6 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - private final FieldInferenceMetadata fieldInferenceMetadata; private IndexMetadata( final Index index, @@ -645,6 +646,7 @@ private IndexMetadata( final int numberOfReplicas, final Settings settings, final MappingMetadata mapping, + final ImmutableOpenMap inferenceFields, final ImmutableOpenMap aliases, final ImmutableOpenMap customData, final Map> inSyncAllocationIds, @@ -677,8 +679,7 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast, - @Nullable FieldInferenceMetadata fieldInferenceMetadata + @Nullable Long shardSizeInBytesForecast ) { this.index = index; this.version = version; @@ -696,6 +697,7 @@ private IndexMetadata( this.totalNumberOfShards = numberOfShards * (numberOfReplicas + 1); this.settings = settings; this.mapping = mapping; + this.inferenceFields = inferenceFields; this.customData = customData; this.aliases = aliases; this.inSyncAllocationIds = inSyncAllocationIds; @@ -734,7 +736,6 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = 
shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -753,6 +754,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.numberOfReplicas, this.settings, mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -785,8 +787,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -812,6 +813,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, Maps.copyMapWithAddedOrReplacedEntry(this.inSyncAllocationIds, shardId, Set.copyOf(inSyncSet)), @@ -844,8 +846,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -869,6 +870,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -901,8 +903,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -926,6 +927,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -958,8 +960,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -979,6 +980,7 @@ public IndexMetadata withIncrementedVersion() { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -1011,8 +1013,7 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast, - this.fieldInferenceMetadata + this.shardSizeInBytesForecast ); } @@ -1203,6 +1204,10 @@ public MappingMetadata mapping() { return mapping; } + public Map getInferenceFields() { + return inferenceFields; + } + @Nullable public IndexMetadataStats getStats() { return stats; @@ -1216,10 +1221,6 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? 
OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public FieldInferenceMetadata getFieldInferenceMetadata() { - return fieldInferenceMetadata; - } - public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1417,7 +1418,7 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldInferenceMetadata.equals(that.fieldInferenceMetadata) == false) { + if (inferenceFields.equals(that.inferenceFields) == false) { return false; } if (isSystem != that.isSystem) { @@ -1440,7 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + fieldInferenceMetadata.hashCode(); + result = 31 * result + inferenceFields.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1487,6 +1488,7 @@ private static class IndexMetadataDiff implements Diff { @Nullable private final Diff settingsDiff; private final Diff> mappings; + private final Diff> inferenceFields; private final Diff> aliases; private final Diff> customData; private final Diff>> inSyncAllocationIds; @@ -1496,7 +1498,6 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; - private final Diff fieldInferenceMetadata; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1519,6 +1520,7 @@ private static class IndexMetadataDiff implements Diff { : ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, after.mapping).build(), DiffableUtils.getStringKeySerializer() ); + inferenceFields = DiffableUtils.diff(before.inferenceFields, after.inferenceFields, DiffableUtils.getStringKeySerializer()); aliases = DiffableUtils.diff(before.aliases, after.aliases, DiffableUtils.getStringKeySerializer()); customData = DiffableUtils.diff(before.customData, after.customData, DiffableUtils.getStringKeySerializer()); inSyncAllocationIds = DiffableUtils.diff( @@ -1533,7 +1535,6 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; - fieldInferenceMetadata = after.fieldInferenceMetadata.diff(before.fieldInferenceMetadata); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1544,6 +1545,8 @@ private static class IndexMetadataDiff implements Diff { new DiffableUtils.DiffableValueReader<>(DiffableStringMap::readFrom, DiffableStringMap::readDiffFrom); private static final DiffableUtils.DiffableValueReader ROLLOVER_INFO_DIFF_VALUE_READER = new DiffableUtils.DiffableValueReader<>(RolloverInfo::new, RolloverInfo::readDiffFrom); + private static final DiffableUtils.DiffableValueReader INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(InferenceFieldMetadata::new, InferenceFieldMetadata::readDiffFrom); IndexMetadataDiff(StreamInput in) throws IOException { index = in.readString(); @@ -1566,6 +1569,15 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); 
mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER + ); + } else { + inferenceFields = DiffableUtils.emptyDiff(); + } aliases = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), ALIAS_METADATA_DIFF_VALUE_READER); customData = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), CUSTOM_DIFF_VALUE_READER); inSyncAllocationIds = DiffableUtils.readJdkMapDiff( @@ -1593,11 +1605,6 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldInferenceMetadata = in.readOptionalWriteable(FieldInferenceMetadata.FieldInferenceMetadataDiff::new); - } else { - fieldInferenceMetadata = FieldInferenceMetadata.FieldInferenceMetadataDiff.EMPTY; - } } @Override @@ -1620,6 +1627,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields.writeTo(out); + } aliases.writeTo(out); customData.writeTo(out); inSyncAllocationIds.writeTo(out); @@ -1633,9 +1643,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeOptionalWriteable(fieldInferenceMetadata); - } } @Override @@ -1656,6 +1663,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.mapping = mappings.apply( ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, part.mapping).build() ).get(MapperService.SINGLE_MAPPING_NAME); + builder.inferenceFields.putAllFromMap(inferenceFields.apply(part.inferenceFields)); builder.aliases.putAllFromMap(aliases.apply(part.aliases)); builder.customMetadata.putAllFromMap(customData.apply(part.customData)); builder.inSyncAllocationIds.putAll(inSyncAllocationIds.apply(part.inSyncAllocationIds)); @@ -1665,7 +1673,6 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldInferenceMetadata(fieldInferenceMetadata.apply(part.fieldInferenceMetadata)); return builder.build(true); } } @@ -1702,6 +1709,10 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); + } int aliasesSize = in.readVInt(); for (int i = 0; i < aliasesSize; i++) { AliasMetadata aliasMd = new AliasMetadata(in); @@ -1733,9 +1744,6 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function inferenceFields; private final ImmutableOpenMap.Builder aliases; private final ImmutableOpenMap.Builder customMetadata; private final Map> inSyncAllocationIds; @@ -1834,10 +1843,10 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; public 
Builder(String index) { this.index = index; + this.inferenceFields = ImmutableOpenMap.builder(); this.aliases = ImmutableOpenMap.builder(); this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); @@ -1855,6 +1864,7 @@ public Builder(IndexMetadata indexMetadata) { this.settings = indexMetadata.getSettings(); this.primaryTerms = indexMetadata.primaryTerms.clone(); this.mapping = indexMetadata.mapping; + this.inferenceFields = ImmutableOpenMap.builder(indexMetadata.inferenceFields); this.aliases = ImmutableOpenMap.builder(indexMetadata.aliases); this.customMetadata = ImmutableOpenMap.builder(indexMetadata.customData); this.routingNumShards = indexMetadata.routingNumShards; @@ -1866,7 +1876,6 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldInferenceMetadata = indexMetadata.fieldInferenceMetadata; } public Builder index(String index) { @@ -2096,8 +2105,13 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder fieldInferenceMetadata(FieldInferenceMetadata fieldInferenceMetadata) { - this.fieldInferenceMetadata = Objects.requireNonNullElse(fieldInferenceMetadata, FieldInferenceMetadata.EMPTY); + public Builder putInferenceField(InferenceFieldMetadata value) { + this.inferenceFields.put(value.getName(), value); + return this; + } + + public Builder putInferenceFields(Map values) { + this.inferenceFields.putAllFromMap(values); return this; } @@ -2263,6 +2277,7 @@ IndexMetadata build(boolean repair) { numberOfReplicas, settings, mapping, + inferenceFields.build(), aliasesMap, newCustomMetadata, Map.ofEntries(denseInSyncAllocationIds), @@ -2295,8 +2310,7 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast, - fieldInferenceMetadata + shardSizeInBytesForecast ); } @@ -2422,8 +2436,12 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldInferenceMetadata.isEmpty() == false) { - builder.field(KEY_FIELD_INFERENCE, indexMetadata.fieldInferenceMetadata); + if (indexMetadata.getInferenceFields().isEmpty() == false) { + builder.startObject(KEY_FIELD_INFERENCE); + for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { + field.toXContent(builder, params); + } + builder.endObject(); } builder.endObject(); @@ -2504,7 +2522,9 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { + private static final String INFERENCE_ID_FIELD = "inference_id"; + private static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private final String name; + private final String inferenceId; + private final String[] sourceFields; + + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this.name = Objects.requireNonNull(name); + this.inferenceId = Objects.requireNonNull(inferenceId); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public InferenceFieldMetadata(StreamInput input) throws IOException { + this.name = input.readString(); + this.inferenceId = input.readString(); + this.sourceFields = input.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeString(name); + out.writeString(inferenceId); + out.writeStringArray(sourceFields); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceFieldMetadata that = (InferenceFieldMetadata) o; + return inferenceId.equals(that.inferenceId) && Arrays.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + int result = Objects.hash(inferenceId); + result = 31 * result + Arrays.hashCode(sourceFields); + return result; + } + + public String getName() { + return name; + } + + public String getInferenceId() { + return inferenceId; + } + + public String[] getSourceFields() { + return sourceFields; + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(name); + builder.field(INFERENCE_ID_FIELD, inferenceId); + builder.array(SOURCE_FIELDS_FIELD, sourceFields); + return builder.endObject(); + } + + public static InferenceFieldMetadata fromXContent(XContentParser parser) throws IOException { + final String name = parser.currentName(); + + XContentParser.Token token = parser.nextToken(); + if (token == null) { + // no data... + return null; + } + String currentFieldName = null; + String inferenceId = null; + List inputFields = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_STRING) { + if (INFERENCE_ID_FIELD.equals(currentFieldName)) { + inferenceId = parser.text(); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_STRING) { + inputFields.add(parser.text()); + } else { + parser.skipChildren(); + } + } + } + } else { + parser.skipChildren(); + } + } + return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index 96ca7a15edc30..52642e1de8ac9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1263,12 +1263,11 @@ static IndexMetadata buildIndexMetadata( indexMetadataBuilder.system(isSystem); // now, update the mappings with the actual source Map mappingsMetadata = new HashMap<>(); - DocumentMapper mapper = documentMapperSupplier.get(); - if (mapper != null) { - MappingMetadata mappingMd = new MappingMetadata(mapper); - mappingsMetadata.put(mapper.type(), mappingMd); - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(mapper.mappers()); - indexMetadataBuilder.fieldInferenceMetadata(fieldInferenceMetadata); + DocumentMapper docMapper = documentMapperSupplier.get(); + if (docMapper != null) { + MappingMetadata mappingMd = new MappingMetadata(docMapper); + mappingsMetadata.put(docMapper.type(), mappingMd); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } 
for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 0e31592991369..e7c2bb9ae9b9a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -201,10 +201,10 @@ private static ClusterState applyRequest( IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(indexMetadata); // Mapping updates on a single type may have side-effects on other types so we need to // update mapping metadata on all types - DocumentMapper mapper = mapperService.documentMapper(); - if (mapper != null) { - indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); - indexMetadataBuilder.fieldInferenceMetadata(new FieldInferenceMetadata(mapper.mappers())); + DocumentMapper docMapper = mapperService.documentMapper(); + if (docMapper != null) { + indexMetadataBuilder.putMapping(new MappingMetadata(docMapper)); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 0741cfa682b74..5e3dbe9590b99 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,11 +36,6 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; - /** - * A map from inference model ID to all fields that use the model to generate embeddings. - */ - private final Map inferenceIdsForFields; - private final int maxParentPathDots; FieldTypeLookup( @@ -53,7 +48,6 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map inferenceIdsForFields = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -71,9 +65,6 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } - if (fieldType instanceof InferenceModelFieldType inferenceModelFieldType) { - inferenceIdsForFields.put(fieldName, inferenceModelFieldType.getInferenceId()); - } } int maxParentPathDots = 0; @@ -106,7 +97,6 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); - this.inferenceIdsForFields = Map.copyOf(inferenceIdsForFields); } public static int dotCount(String path) { @@ -215,10 +205,6 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } - Map getInferenceIdsForFields() { - return inferenceIdsForFields; - } - /** * If field is a leaf multi-field return the path to the parent field. Otherwise, return null. 
*/ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java new file mode 100644 index 0000000000000..078ef391f17ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.inference.InferenceService; + +import java.util.Set; + +/** + * Field mapper that requires to transform its input before indexation through the {@link InferenceService}. + */ +public interface InferenceFieldMapper { + String NAME = "_inference"; + + /** + * Retrieve the inference metadata associated with this mapper. + * + * @param sourcePaths The source path that populates the input for the field (before inference) + */ + InferenceFieldMetadata getMetadata(Set sourcePaths); +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java deleted file mode 100644 index 6e12a204ed7d0..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -/** - * Field type that uses an inference model. - */ -public interface InferenceModelFieldType { - /** - * Retrieve inference model used by the field type. 
- * - * @return model id used by the field type - */ - String getInferenceId(); -} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java index 8f8854ad47c7d..ddf6f339cbbb6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java @@ -46,7 +46,7 @@ public static MapperMergeContext from(MapperBuilderContext mapperBuilderContext, * @param name the name of the child context * @return a new {@link MapperMergeContext} with this context as its parent */ - MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { + public MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { return createChildContext(mapperBuilderContext.createChildContext(name, dynamic)); } @@ -60,7 +60,7 @@ MapperMergeContext createChildContext(MapperBuilderContext childContext) { return new MapperMergeContext(childContext, newFieldsBudget); } - MapperBuilderContext getMapperBuilderContext() { + public MapperBuilderContext getMapperBuilderContext() { return mapperBuilderContext; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index c2bd95115f27e..bf879f30e5a29 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -10,9 +10,11 @@ import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.inference.InferenceService; import java.util.ArrayList; import java.util.Collection; @@ -47,6 +49,7 @@ private CacheKey() {} /** Full field name to mapper */ private final Map fieldMappers; private final Map objectMappers; + private final Map inferenceFields; private final int runtimeFieldMappersCount; private final NestedLookup nestedLookup; private final FieldTypeLookup fieldTypeLookup; @@ -84,12 +87,12 @@ private static void collect( Collection fieldMappers, Collection fieldAliasMappers ) { - if (mapper instanceof ObjectMapper) { - objectMappers.add((ObjectMapper) mapper); - } else if (mapper instanceof FieldMapper) { - fieldMappers.add((FieldMapper) mapper); - } else if (mapper instanceof FieldAliasMapper) { - fieldAliasMappers.add((FieldAliasMapper) mapper); + if (mapper instanceof ObjectMapper objectMapper) { + objectMappers.add(objectMapper); + } else if (mapper instanceof FieldMapper fieldMapper) { + fieldMappers.add(fieldMapper); + } else if (mapper instanceof FieldAliasMapper fieldAliasMapper) { + fieldAliasMappers.add(fieldAliasMapper); } else { throw new IllegalStateException("Unrecognized mapper type [" + mapper.getClass().getSimpleName() + "]."); } @@ -174,6 +177,15 @@ private MappingLookup( final Collection runtimeFields = mapping.getRoot().runtimeFields(); this.fieldTypeLookup = new FieldTypeLookup(mappers, aliasMappers, runtimeFields); + + Map inferenceFields = new HashMap<>(); + for (FieldMapper mapper : mappers) { + if (mapper instanceof InferenceFieldMapper inferenceFieldMapper) { + inferenceFields.put(mapper.name(), 
inferenceFieldMapper.getMetadata(fieldTypeLookup.sourcePaths(mapper.name()))); + } + } + this.inferenceFields = Map.copyOf(inferenceFields); + if (runtimeFields.isEmpty()) { // without runtime fields this is the same as the field type lookup this.indexTimeLookup = fieldTypeLookup; @@ -360,6 +372,13 @@ public Map objectMappers() { return objectMappers; } + /** + * Returns a map containing all fields that require to run inference (through the {@link InferenceService} prior to indexation. + */ + public Map inferenceFields() { + return inferenceFields; + } + public NestedLookup nestedLookup() { return nestedLookup; } @@ -523,8 +542,4 @@ public void validateDoesNotShadow(String name) { throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric"); } } - - public Map getInferenceIdsForFields() { - return fieldTypeLookup.getInferenceIdsForFields(); - } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index b32873df71365..45ffba25eb558 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -27,7 +27,6 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.shard.ShardId; @@ -84,7 +83,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? 
randomLongBetween(1024, 10240) : null; - FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(true); + Map dynamicFields = randomInferenceFields(); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -110,7 +109,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldInferenceMetadata(fieldInferenceMetadata) + .putInferenceFields(dynamicFields) .build(); assertEquals(system, metadata.isSystem()); @@ -145,7 +144,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldInferenceMetadata(), fromXContentMeta.getFieldInferenceMetadata()); + assertEquals(metadata.getInferenceFields(), fromXContentMeta.getInferenceFields()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -169,7 +168,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), deserialized.getStats()); assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldInferenceMetadata(), deserialized.getFieldInferenceMetadata()); + assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields()); } } @@ -553,35 +552,32 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldInferenceMetadata() { + public void testInferenceFieldMetadata() { Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertSame(idxMeta1.getFieldInferenceMetadata(), FieldInferenceMetadata.EMPTY); + assertTrue(idxMeta1.getInferenceFields().isEmpty()); - FieldInferenceMetadata fieldInferenceMetadata = randomFieldInferenceMetadata(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldInferenceMetadata(fieldInferenceMetadata).build(); - assertThat(idxMeta2.getFieldInferenceMetadata(), equalTo(fieldInferenceMetadata)); + Map dynamicFields = randomInferenceFields(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).putInferenceFields(dynamicFields).build(); + assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields)); } private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - public static FieldInferenceMetadata randomFieldInferenceMetadata(boolean allowNull) { - if (randomBoolean() && allowNull) { - return null; + public static Map randomInferenceFields() { + Map map = new HashMap<>(); + int numFields = randomIntBetween(0, 5); + for (int i = 0; i < numFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + map.put(field, randomInferenceFieldMetadata(field)); } - - Map fieldInferenceMap = randomMap( - 0, - 10, - () -> new Tuple<>(randomIdentifier(), randomFieldInference()) - ); - return new FieldInferenceMetadata(fieldInferenceMap); + return 
map; } - private static FieldInferenceMetadata.FieldInferenceOptions randomFieldInference() { - return new FieldInferenceMetadata.FieldInferenceOptions(randomIdentifier(), randomSet(0, 5, ESTestCase::randomIdentifier)); + private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) { + return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)); } private IndexMetadataStats randomIndexStats(int numberOfShards) { diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java new file mode 100644 index 0000000000000..958d86535ae76 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.function.Predicate; + +import static org.hamcrest.Matchers.equalTo; + +public class InferenceFieldMetadataTests extends AbstractXContentTestCase { + + public void testSerialization() throws IOException { + final InferenceFieldMetadata before = createTestItem(); + final BytesStreamOutput out = new BytesStreamOutput(); + before.writeTo(out); + + final StreamInput in = out.bytes().streamInput(); + final InferenceFieldMetadata after = new InferenceFieldMetadata(in); + + assertThat(after, equalTo(before)); + } + + @Override + protected InferenceFieldMetadata createTestInstance() { + return createTestItem(); + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field + } + + @Override + protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { + if (parser.nextToken() == XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + assertEquals(XContentParser.Token.FIELD_NAME, parser.currentToken()); + InferenceFieldMetadata inferenceMetadata = InferenceFieldMetadata.fromXContent(parser); + assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken()); + return inferenceMetadata; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + private static InferenceFieldMetadata createTestItem() { + String name = randomAlphaOfLengthBetween(3, 10); + String inferenceId = randomIdentifier(); + String[] inputFields = generateRandomStringArray(5, 10, false, false); + return new InferenceFieldMetadata(name, inferenceId, inputFields); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 932eac3e60d27..3f50b9fdf6621 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -16,7 +16,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import static java.util.Collections.emptyList; @@ -36,10 +35,6 @@ public void testEmpty() { Collection names = lookup.getMatchingFieldNames("foo"); assertNotNull(names); assertThat(names, hasSize(0)); - - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddNewField() { @@ -47,10 +42,6 @@ public void testAddNewField() { FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - - Map fieldsForModels = lookup.getInferenceIdsForFields(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddFieldAlias() { @@ -430,25 +421,6 @@ public void testRuntimeFieldNameOutsideContext() { } } - public void testInferenceModelFieldType() { - MockFieldMapper f1 = new MockFieldMapper(new MockInferenceModelFieldType("foo1", "bar1")); - MockFieldMapper f2 = new MockFieldMapper(new MockInferenceModelFieldType("foo2", "bar1")); - MockFieldMapper f3 = new MockFieldMapper(new MockInferenceModelFieldType("foo3", "bar2")); - - FieldTypeLookup lookup = new FieldTypeLookup(List.of(f1, f2, f3), emptyList(), emptyList()); - assertEquals(f1.fieldType(), lookup.get("foo1")); - assertEquals(f2.fieldType(), lookup.get("foo2")); - assertEquals(f3.fieldType(), lookup.get("foo3")); - - Map inferenceIdsForFields = lookup.getInferenceIdsForFields(); - assertNotNull(inferenceIdsForFields); - assertEquals(3, inferenceIdsForFields.size()); - - assertEquals("bar1", inferenceIdsForFields.get("foo1")); - assertEquals("bar1", inferenceIdsForFields.get("foo2")); - assertEquals("bar2", inferenceIdsForFields.get("foo3")); - } - private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index bb337d0c61c93..0308dac5fa216 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -121,8 +121,6 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getInferenceIdsForFields()); - assertTrue(mappingLookup.getInferenceIdsForFields().isEmpty()); } public void testValidateDoesNotShadow() { @@ -190,22 +188,6 @@ public MetricType getMetricType() { ); } - public void testInferenceIdsForFields() { - MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); - MappingLookup mappingLookup = createMappingLookup( - Collections.singletonList(new MockFieldMapper(fieldType)), - emptyList(), - emptyList() - ); - assertEquals(1, size(mappingLookup.fieldMappers())); - assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - - Map inferenceIdsForFields = 
mappingLookup.getInferenceIdsForFields(); - assertNotNull(inferenceIdsForFields); - assertEquals(1, inferenceIdsForFields.size()); - assertEquals("test_model_id", inferenceIdsForFields.get("test_field_name")); - } - private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { try (TokenStream tok = analyzer.tokenStream(field, new StringReader(""))) { CharTermAttribute term = tok.addAttribute(CharTermAttribute.class); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java deleted file mode 100644 index 0d21134b5d9a9..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.index.query.SearchExecutionContext; - -import java.util.Map; - -public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private static final String TYPE_NAME = "mock_inference_model_field_type"; - - private final String modelId; - - public MockInferenceModelFieldType(String name, String modelId) { - super(name, false, false, false, TextSearchInfo.NONE, Map.of()); - this.modelId = modelId; - } - - @Override - public String typeName() { - return TYPE_NAME; - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented"); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); - } - - @Override - public String getInferenceId() { - return modelId; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3fcd9049ae803..494d6918b6086 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -284,11 +284,17 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); + if (SemanticTextFeature.isEnabled()) { + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); + } + return Map.of(); } @Override public Collection getActionFilters() { - return singletonList(shardBulkInferenceActionFilter.get()); + if (SemanticTextFeature.isEnabled()) { + return singletonList(shardBulkInferenceActionFilter.get()); + } + return List.of(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 00dc195313a61..fef62051a6471 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -20,12 +20,11 @@ import org.elasticsearch.action.bulk.BulkShardRequest; import org.elasticsearch.action.bulk.TransportShardBulkAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.MappedActionFilter; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; @@ -39,6 +38,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -49,19 +49,66 @@ import java.util.stream.Collectors; /** - * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} - * in the subsequent {@link TransportShardBulkAction} downstream. + * A {@link MappedActionFilter} intercepting {@link BulkShardRequest}s to apply inference on fields declared as + * {@link SemanticTextFieldMapper} in the index mapping. + * The source of each {@link BulkItemRequest} requiring inference is augmented with the results for each field + * under the {@link InferenceMetadataFieldMapper#NAME} section. + * For example, for an index with a semantic_text field named {@code my_semantic_field} the following source document: + *
+ *
+ * {
+ *     "my_semantic_field": "these are not the droids you're looking for"
+ * }
+ * 
+ * is rewritten into: + *
+ *
+ * {
+ *     "_inference": {
+ *         "my_semantic_field": {
+ *             "inference_id": "my_inference_id",
+ *             "model_settings": {
+ *                 "task_type": "SPARSE_EMBEDDING"
+ *             },
+ *             "chunks": [
+ *                 {
+ *                     "inference": {
+ *                         "lucas": 0.05212344,
+ *                         "ty": 0.041213956,
+ *                         "dragon": 0.50991,
+ *                         "type": 0.23241979,
+ *                         "dr": 1.9312073,
+ *                         "##o": 0.2797593
+ *                     },
+ *                     "text": "these are not the droids you're looking for"
+ *                 }
+ *             ]
+ *         }
+ *     },
+ *     "my_semantic_field": "these are not the droids you're looking for"
+ * }
+ * 
+ * The rewriting process occurs on the bulk coordinator node, and the results are then passed downstream + * to the {@link TransportShardBulkAction} for actual indexing. + * + * TODO: batchSize should be configurable via a cluster setting */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + protected static final int DEFAULT_BATCH_SIZE = 512; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; + private final int batchSize; public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + } + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; + this.batchSize = batchSize; } @Override @@ -86,7 +133,7 @@ public void app switch (action) { case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); @@ -102,33 +149,33 @@ public void app } private void processBulkShardRequest( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); } private record InferenceProvider(InferenceService service, Model model) {} private record FieldInferenceRequest(int id, String field, String input) {} - private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + private record FieldInferenceResponse(String field, @Nullable Model model, @Nullable ChunkedInferenceServiceResults chunkedResults) {} private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { - private final FieldInferenceMetadata fieldInferenceMetadata; + private final Map fieldInferenceMap; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( - FieldInferenceMetadata fieldInferenceMetadata, + Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - this.fieldInferenceMetadata = fieldInferenceMetadata; + this.fieldInferenceMap = fieldInferenceMap; this.bulkShardRequest = bulkShardRequest; this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); this.onCompletion = onCompletion; @@ -212,30 +259,49 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + int currentBatchSize = 
Math.min(requests.size(), batchSize); + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = requests.subList(currentBatchSize, requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { - for (int i = 0; i < results.size(); i++) { - var request = requests.get(i); - var result = results.get(i); - var acc = inferenceResults.get(request.id); - acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + try { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + } + } finally { + onFinish(); } } @Override public void onFailure(Exception exc) { - for (int i = 0; i < requests.size(); i++) { - var request = requests.get(i); - inferenceResults.get(request.id).failures.add( - new ElasticsearchException( - "Exception when running inference id [{}] on field [{}]", - exc, - inferenceProvider.model.getInferenceEntityId(), - request.field - ) - ); + try { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } finally { + onFinish(); + } + } + + private void onFinish() { + if (nextBatch.isEmpty()) { + onFinish.close(); + } else { + executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); } } }; @@ -246,14 +312,33 @@ public void onFailure(Exception exc) { Map.of(), InputType.INGEST, new ChunkingOptions(null, null), - ActionListener.runAfter(completionListener, onFinish::close) + completionListener ); } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { + FieldInferenceResponseAccumulator acc = inferenceResults.get(id); + if (acc == null) { + acc = new FieldInferenceResponseAccumulator( + id, + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ); + inferenceResults.set(id, acc); + } + return acc; + } + + private void addInferenceResponseFailure(int id, Exception failure) { + var acc = ensureResponseAccumulatorSlot(id); + acc.failures().add(failure); + } + /** - * Applies the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}. - * If the response contains failures, the bulk item request is mark as failed for the downstream action. - * Otherwise, the source of the request is augmented with the field inference results. + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results under the + * {@link InferenceMetadataFieldMapper#NAME} field. 
*/ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { if (response.failures().isEmpty() == false) { @@ -265,24 +350,37 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); Map newDocMap = indexRequest.sourceAsMap(); - Map inferenceMap = new LinkedHashMap<>(); - // ignore the existing inference map if any + Object inferenceObj = newDocMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()); + Map inferenceMap = XContentMapValues.nodeMapValue(inferenceObj, InferenceMetadataFieldMapper.NAME); newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { - try { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceMap, - fieldResponse.field(), - fieldResponse.model(), - fieldResponse.chunkedResults() - ); - } catch (Exception exc) { - item.abort(item.index(), exc); + if (fieldResponse.chunkedResults != null) { + try { + InferenceMetadataFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } + } else { + inferenceMap.remove(fieldResponse.field); } } indexRequest.source(newDocMap); } + /** + * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. + * If results are already populated for fields in the existing _inference object, + * the inference request for this specific field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link InferenceMetadataFieldMapper} + * during field indexing, where an error will be thrown if they mismatch or if the content is malformed. + * + * TODO: Should we validate the settings for pre-existing results here and apply the inference only if they differ? + */ private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { Map> fieldRequestsMap = new LinkedHashMap<>(); for (var item : bulkShardRequest.items()) { @@ -290,35 +388,57 @@ private Map> createFieldInferenceRequests(Bu // item was already aborted/processed by a filter in the chain upstream (e.g. 
security) continue; } - final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - if (indexRequest == null) { + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + if (updateRequest.script() != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + continue; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request continue; } final Map docMap = indexRequest.sourceAsMap(); - boolean hasInput = false; - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); + final Map inferenceMap = XContentMapValues.nodeMapValue( + docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()), + InferenceMetadataFieldMapper.NAME + ); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + Object inferenceResult = inferenceMap.remove(field); var value = XContentMapValues.extractValue(field, docMap); if (value == null) { - continue; - } - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( - item.id(), - new FieldInferenceResponseAccumulator( + if (inferenceResult != null) { + addInferenceResponseFailure( item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); + new ElasticsearchStatusException( + "The field [{}] is referenced in the [{}] metadata field but has no value", + RestStatus.BAD_REQUEST, + field, + InferenceMetadataFieldMapper.NAME + ) + ); + } + continue; } + ensureResponseAccumulatorSlot(item.id()); if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - hasInput = true; } else { - inferenceResults.get(item.id()).failures.add( + addInferenceResponseFailure( + item.id(), new ElasticsearchStatusException( "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, @@ -328,12 +448,6 @@ private Map> createFieldInferenceRequests(Bu ); } } - if (hasInput == false) { - // remove the existing _inference field (if present) since none of the content require inference. 
- if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { - indexRequest.source(docMap); - } - } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 9eeb7a5407bc4..702f686605e56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; @@ -52,6 +53,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; + /** * A mapper for the {@code _inference} field. *
@@ -117,7 +120,7 @@ * */ public class InferenceMetadataFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; + public static final String NAME = InferenceFieldMapper.NAME; public static final String CONTENT_TYPE = "_inference"; public static final String INFERENCE_ID = "inference_id"; @@ -183,7 +186,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( XContentLocation xContentLocation ) { final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); - final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + final String inferenceId = semanticFieldContext.mapper.getInferenceId(); if (newInferenceId.equals(inferenceId) == false) { throw new DocumentParsingException( xContentLocation, @@ -212,7 +215,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper( return newMapper.getSubMappers(); } else { SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); - SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 2445d5c8751a5..f8fde0b63e4ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -8,21 +8,23 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MapperMergeContext; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; -import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; @@ -40,6 +42,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; @@ -51,7 +55,7 @@ * This field mapper performs no indexing, as inference results will be included as a 
different field in the document source, and will * be indexed using {@link InferenceMetadataFieldMapper}. */ -public class SemanticTextFieldMapper extends FieldMapper { +public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -66,6 +70,7 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { ); private final IndexVersion indexVersionCreated; + private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -74,11 +79,13 @@ private SemanticTextFieldMapper( MappedFieldType mappedFieldType, CopyTo copyTo, IndexVersion indexVersionCreated, + String inferenceId, SemanticTextModelSettings modelSettings, NestedObjectMapper subMappers ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); this.indexVersionCreated = indexVersionCreated; + this.inferenceId = inferenceId; this.modelSettings = modelSettings; this.subMappers = subMappers; } @@ -111,6 +118,10 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public String getInferenceId() { + return inferenceId; + } + public SemanticTextModelSettings getModelSettings() { return modelSettings; } @@ -119,6 +130,11 @@ public NestedObjectMapper getSubMappers() { return subMappers; } + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + return new InferenceFieldMetadata(name(), inferenceId, sourcePaths.toArray(String[]::new)); + } + public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; @@ -142,11 +158,15 @@ public static class Builder extends FieldMapper.Builder { XContentBuilder::field, (m) -> m == null ? 
"null" : Strings.toString(m) ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + private final Parameter> meta = Parameter.metaParam(); + private Function subFieldsFunction; + public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; + this.subFieldsFunction = c -> createSubFields(c); } public Builder setInferenceId(String id) { @@ -164,9 +184,38 @@ protected Parameter[] getParameters() { return new Parameter[] { inferenceId, modelSettings, meta }; } + @Override + protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { + super.merge(mergeWith, conflicts, mapperMergeContext); + conflicts.check(); + SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var childMergeContext = mapperMergeContext.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + NestedObjectMapper mergedSubFields = (NestedObjectMapper) semanticMergeWith.getSubMappers() + .merge( + subFieldsFunction.apply(childMergeContext.getMapperBuilderContext()), + MapperService.MergeReason.MAPPING_UPDATE, + childMergeContext + ); + subFieldsFunction = c -> mergedSubFields; + } + @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + final NestedObjectMapper subFields = subFieldsFunction.apply(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subFields, meta.getValue()), + copyTo, + indexVersionCreated, + inferenceId.getValue(), + modelSettings.getValue(), + subFields + ); + } + + private NestedObjectMapper createSubFields(MapperBuilderContext context) { NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) @@ -176,20 +225,11 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); } nestedBuilder.add(textMapperBuilder); - var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - var subMappers = nestedBuilder.build(childContext); - return new SemanticTextFieldMapper( - name(), - new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), - copyTo, - indexVersionCreated, - modelSettings.getValue(), - subMappers - ); + return nestedBuilder.build(context); } } - public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final SemanticTextModelSettings modelSettings; private final NestedObjectMapper subMappers; @@ -212,7 +252,6 @@ public String typeName() { return CONTENT_TYPE; } - @Override public String getInferenceId() { return inferenceId; } @@ -241,11 +280,6 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return super.syntheticFieldLoader(); - } - private static Mapper.Builder 
createInferenceMapperBuilder( String fieldName, SemanticTextModelSettings modelSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index bf3cc6334433a..4c1cc8fa38bb4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -33,10 +33,7 @@ public void testCreateIndexWithSemanticTextField() { "test", client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); - assertEquals( - indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model"); } public void testAddSemanticTextField() throws Exception { @@ -53,10 +50,7 @@ public void testAddSemanticTextField() throws Exception { putMappingExecutor, singleTask(request) ); - assertEquals( - resultingState.metadata().index("test").getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), - "test_model" - ); + assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model"); } private static List singleTask(PutMappingClusterStateUpdateRequest request) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 8b18cf74236a0..d734e9998734d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.shard.ShardId; @@ -45,12 +45,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; @@ -75,11 +75,11 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws 
Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { - assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap()); + assertNull(((BulkShardRequest) request).getInferenceFieldMap()); } finally { chainExecuted.countDown(); } @@ -91,8 +91,8 @@ public void testFilterNoop() throws Exception { WriteRequest.RefreshPolicy.NONE, new BulkItemRequest[0] ); - request.setFieldInferenceMetadata( - new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -101,12 +101,16 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = randomStaticModel(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); for (BulkItemRequest item : bulkShardRequest.items()) { assertNotNull(item.getPrimaryResponse()); assertTrue(item.getPrimaryResponse().isFailed()); @@ -120,22 +124,20 @@ public void testInferenceNotFound() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( - Map.of( - "field1", - new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), - "field2", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), - "field3", - new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) - ) + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + "field2", + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + "field3", + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { - items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFieldMap)[0]; } BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setFieldInferenceMetadata(inferenceFields); + request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @@ -150,30 +152,29 @@ public void testManyRandomDocs() 
throws Exception { } int numInferenceFields = randomIntBetween(1, 5); - Map inferenceFieldsMap = new HashMap<>(); + Map inferenceFieldMap = new HashMap<>(); for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); } - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); int numRequests = randomIntBetween(100, 1000); BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; for (int id = 0; id < numRequests; id++) { - BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFieldMap); originalRequests[id] = res[0]; modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { assertThat(request, instanceOf(BulkShardRequest.class)); BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + assertNull(bulkShardRequest.getInferenceFieldMap()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { @@ -192,13 +193,13 @@ public void testManyRandomDocs() throws Exception { ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); - original.setFieldInferenceMetadata(fieldInferenceMetadata); + original.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @SuppressWarnings("unchecked") - private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { String id = (String) invocationOnMock.getArguments()[0]; @@ -256,20 +257,20 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); return filter; } private static BulkItemRequest[] randomBulkItemRequest( int id, Map modelMap, - FieldInferenceMetadata fieldInferenceMetadata + Map fieldInferenceMap ) { Map docMap = new 
LinkedHashMap<>(); Map inferenceResultsMap = new LinkedHashMap<>(); - for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { - String field = entry.getKey(); - var model = modelMap.get(entry.getValue().inferenceId()); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + var model = modelMap.get(entry.getInferenceId()); String text = randomAlphaOfLengthBetween(10, 100); docMap.put(field, text); if (model == null) { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 528003e278aeb..8847fb7f7efc1 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -171,15 +171,16 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "another non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -197,6 +198,32 @@ setup: index: test-sparse-index id: doc_1 + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "I am a test" } + - match: { _source.another_inference_field: "I am a teapot" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "I am a test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "I am a teapot" } + - do: update: index: test-sparse-index @@ -211,12 +238,31 @@ setup: index: test-sparse-index id: doc_1 - - 
match: { _source.inference_field: "updated inference test" } - - match: { _source.another_inference_field: "another updated inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "updated inference test" } + - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "bulk inference test", "another_inference_field": "bulk updated inference test"}}' - - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field: "bulk inference test" } + - match: { _source.another_inference_field: "bulk updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "bulk inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "bulk updated inference test" } --- "Reindex works for semantic_text fields": @@ -268,18 +314,19 @@ setup: index: destination-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } - - match: { _source.non_inference_field: "non inference test" } + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } + - length: { _source._inference: 2 } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- -"Fails for non-existent model": +"Fails for non-existent inference": - do: indices.create: index: incorrect-test-sparse-index @@ -310,3 +357,23 @@ setup: id: doc_1 body: non_inference_field: "non inference test" + +--- +"Updates with script are not allowed": + - do: + bulk: + index: test-sparse-index + body: + - '{"index": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"script": "ctx._source.new_field = \"hello\"", "scripted_upsert": true}' + + - match: { errors: true } + - match: { items.0.update.status: 400 } + - match: { 
items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 27f233436b925..9dc109b3fb81d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -66,3 +66,23 @@ setup: id: doc_1 body: dense_field: "you know, for testing" + +--- +"Inference section contains unreferenced fields": + - do: + catch: /Field \[unknown_field\] is not registered as a \[semantic_text\] field type/ + index: + index: test-index + id: doc_1 + body: + non_inference_field: "you know, for testing" + _inference: + unknown_field: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + chunks: + - text: "inference test" + inference: [ 0.1, 0.2, 0.3, 0.4, 0.5 ] + - text: "another inference test" + inference: [ -0.1, -0.2, -0.3, -0.4, -0.5 ] From b6ca8d2b1afa33a55e5077d11f3701aefd5eabfe Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:38:18 +0200 Subject: [PATCH 12/29] [feature/semantic-text] semantic text copy to support (#106689) --- .../index/mapper/CopyToMapperTests.java | 7 ++ .../index/mapper/MultiFieldTests.java | 3 + .../ShardBulkInferenceActionFilter.java | 59 ++++++++---- .../mapper/InferenceMetadataFieldMapper.java | 12 ++- .../SemanticTextClusterMetadataTests.java | 45 ++++++++- .../inference/10_semantic_text_inference.yml | 95 +++++++++++++++++++ 6 files changed, 195 insertions(+), 26 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java index 5eacfe6f2e3ab..33341e6b36987 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; @@ -106,6 +107,12 @@ public void testCopyToFieldsParsing() throws Exception { fieldMapper = mapperService.documentMapper().mappers().getMapper("new_field"); assertThat(fieldMapper.typeName(), equalTo("long")); + + MappingLookup mappingLookup = mapperService.mappingLookup(); + assertThat(mappingLookup.sourcePaths("another_field"), equalTo(Set.of("copy_test", "int_to_str_test", "another_field"))); + assertThat(mappingLookup.sourcePaths("new_field"), equalTo(Set.of("new_field", "int_to_str_test"))); + assertThat(mappingLookup.sourcePaths("copy_test"), equalTo(Set.of("copy_test", "cyclic_test"))); + assertThat(mappingLookup.sourcePaths("cyclic_test"), equalTo(Set.of("cyclic_test", "copy_test"))); } public void testCopyToFieldsInnerObjectParsing() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java index d7df41131414e..6446033c07c5b 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java +++ 
b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java @@ -224,6 +224,9 @@ public void testSourcePathFields() throws IOException { final Set fieldsUsingSourcePath = new HashSet<>(); ((FieldMapper) mapper).sourcePathUsedBy().forEachRemaining(mapper1 -> fieldsUsingSourcePath.add(mapper1.name())); assertThat(fieldsUsingSourcePath, equalTo(Set.of("field.subfield1", "field.subfield2"))); + + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield1"), equalTo(Set.of("field"))); + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield2"), equalTo(Set.of("field"))); } public void testUnknownLegacyFieldsUnderKnownRootField() throws Exception { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fef62051a6471..2e6f66c64fa95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -388,10 +388,12 @@ private Map> createFieldInferenceRequests(Bu // item was already aborted/processed by a filter in the chain upstream (e.g. security) continue; } + boolean isUpdateRequest = false; final IndexRequest indexRequest; if (item.request() instanceof IndexRequest ir) { indexRequest = ir; } else if (item.request() instanceof UpdateRequest updateRequest) { + isUpdateRequest = true; if (updateRequest.script() != null) { addInferenceResponseFailure( item.id(), @@ -417,35 +419,50 @@ private Map> createFieldInferenceRequests(Bu String field = entry.getName(); String inferenceId = entry.getInferenceId(); Object inferenceResult = inferenceMap.remove(field); - var value = XContentMapValues.extractValue(field, docMap); - if (value == null) { - if (inferenceResult != null) { + for (var sourceField : entry.getSourceFields()) { + var value = XContentMapValues.extractValue(sourceField, docMap); + if (value == null) { + if (isUpdateRequest) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Field [{}] must be specified on an update request to calculate inference for field [{}]", + RestStatus.BAD_REQUEST, + sourceField, + field + ) + ); + } else if (inferenceResult != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "The field [{}] is referenced in the [{}] metadata field but has no value", + RestStatus.BAD_REQUEST, + field, + InferenceMetadataFieldMapper.NAME + ) + ); + } + continue; + } + ensureResponseAccumulatorSlot(item.id()); + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { addInferenceResponseFailure( item.id(), new ElasticsearchStatusException( - "The field [{}] is referenced in the [{}] metadata field but has no value", + "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, field, - InferenceMetadataFieldMapper.NAME + value.getClass().getSimpleName() ) ); } - continue; - } - ensureResponseAccumulatorSlot(item.id()); - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); - fieldRequests.add(new 
FieldInferenceRequest(item.id(), field, valueStr)); - } else { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - value.getClass().getSimpleName() - ) - ); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index 702f686605e56..89d1037243aac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -348,6 +348,8 @@ private void parseResultsObject( } parser.nextToken(); fieldMapper.parse(context); + // Reset leaf object after parsing the field + context.path().setWithinLeafObject(true); } if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { Set missingSubfields = REQUIRED_SUBFIELDS.stream() @@ -383,6 +385,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + @SuppressWarnings("unchecked") public static void applyFieldInference( Map inferenceMap, String field, @@ -407,11 +410,12 @@ public static void applyFieldInference( results.getWriteableName() ); } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); + + Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(CHUNKS, chunks); - inferenceMap.put(field, fieldMap); + List> fieldChunks = (List>) fieldMap.computeIfAbsent(CHUNKS, k -> new ArrayList<>()); + fieldChunks.addAll(chunks); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); } record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 4c1cc8fa38bb4..1c4a2f561ad4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -16,11 +16,15 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.hamcrest.Matchers; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; +import static org.hamcrest.CoreMatchers.equalTo; + public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { @Override @@ -36,7 +40,7 @@ public void testCreateIndexWithSemanticTextField() { assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model"); } - public void testAddSemanticTextField() throws Exception { + public void testSingleSourceSemanticTextField() throws Exception { final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); final 
MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); @@ -53,6 +57,45 @@ public void testAddSemanticTextField() throws Exception { assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model"); } + public void testCopyToSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { + "properties": { + "semantic": { + "type": "semantic_text", + "inference_id": "test_model" + }, + "copy_origin_1": { + "type": "text", + "copy_to": "semantic" + }, + "copy_origin_2": { + "type": "text", + "copy_to": "semantic" + } + } + } + """); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) + ); + IndexMetadata indexMetadata = resultingState.metadata().index("test"); + InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get("semantic"); + assertThat(inferenceFieldMetadata.getInferenceId(), equalTo("test_model")); + assertThat( + Arrays.asList(inferenceFieldMetadata.getSourceFields()), + Matchers.containsInAnyOrder("semantic", "copy_origin_1", "copy_origin_2") + ); + } + private static List singleTask(PutMappingClusterStateUpdateRequest request) { return Collections.singletonList(new MetadataMappingService.PutMappingClusterStateUpdateTask(request, ActionListener.running(() -> { throw new AssertionError("task should not complete publication"); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 8847fb7f7efc1..0a07a88d230ef 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -377,3 +377,98 @@ setup: - match: { errors: true } - match: { items.0.update.status: 400 } - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } + +--- +"Fails when providing inference results and there is no value for field": + - do: + catch: /The field \[inference_field\] is referenced in the \[_inference\] metadata field but has no value/ + index: + index: test-sparse-index + id: doc_1 + body: + _inference: + inference_field: + chunks: + - text: "inference test" + inference: + "hello": 0.123 + + +--- +"semantic_text copy_to calculate inference for source fields": + - do: + indices.create: + index: test-copy-to-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + source_field: + type: text + copy_to: inference_field + another_source_field: + type: text + copy_to: inference_field + + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + 
source_field: "copy_to inference test" + inference_field: "inference test" + another_source_field: "another copy_to inference test" + + - do: + get: + index: test-copy-to-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - length: { _source._inference.inference_field.chunks: 3 } + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.inference_field.chunks.0.text + - exists: _source._inference.inference_field.chunks.1.inference + - exists: _source._inference.inference_field.chunks.1.text + - exists: _source._inference.inference_field.chunks.2.inference + - exists: _source._inference.inference_field.chunks.2.text + + +--- +"semantic_text copy_to needs values for every source field for updates": + - do: + indices.create: + index: test-copy-to-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + source_field: + type: text + copy_to: inference_field + another_source_field: + type: text + copy_to: inference_field + + # Not every source field needed on creation + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + source_field: "a single source field provided" + inference_field: "inference test" + + # Every source field needed on bulk updates + - do: + bulk: + body: + - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}' + - '{"doc": {"source_field": "a single source field is kept as provided via bulk", "inference_field": "updated inference test" }}' + + - match: { items.0.update.status: 400 } + - match: { items.0.update.error.reason: "Field [another_source_field] must be specified on an update request to calculate inference for field [inference_field]" } From 555676328ac3b2d97290f85349a65a429d0da885 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Fri, 5 Apr 2024 09:55:03 +0200 Subject: [PATCH 13/29] [feature/semantic-text] Move the inference results back to the original field (#107065) This change moves the logic for the parsing of the inference results to the semantic text field mapper. The inference metadata field mapper is no longer needed since the results of the inference action is now part of the original field. This change also moved the entire parsing into its own class so that the action filter and the field mapper shares the logic to read and write the inference format. 
--- .../action/update/TransportUpdateAction.java | 16 +- .../metadata/InferenceFieldMetadata.java | 6 +- .../index/mapper/DocumentParser.java | 4 + .../index/mapper/FieldMapper.java | 4 - .../xpack/inference/InferencePlugin.java | 10 - .../ShardBulkInferenceActionFilter.java | 266 ++++---- .../mapper/InferenceMetadataFieldMapper.java | 456 ------------- .../inference/mapper/SemanticTextField.java | 328 +++++++++ .../mapper/SemanticTextFieldMapper.java | 327 +++++---- .../mapper/SemanticTextModelSettings.java | 181 ----- .../ShardBulkInferenceActionFilterTests.java | 66 +- .../InferenceMetadataFieldMapperTests.java | 629 ------------------ .../mapper/SemanticTextFieldMapperTests.java | 299 ++++++++- .../mapper/SemanticTextFieldTests.java | 219 ++++++ .../inference/10_semantic_text_inference.yml | 134 ++-- .../20_semantic_text_field_mapper.yml | 20 - 16 files changed, 1256 insertions(+), 1709 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 4cd00534d68ad..b899d68107975 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -40,7 +40,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.VersionConflictEngineException; -import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; @@ -185,7 +184,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis); switch (result.getResponseResult()) { case CREATED -> { - IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action()); + IndexRequest upsertRequest = result.action(); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference upsertSourceBytes = upsertRequest.source(); client.bulk( @@ -227,7 +226,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< ); } case UPDATED -> { - IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action()); + IndexRequest indexRequest = result.action(); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference indexSourceBytes = indexRequest.source(); client.bulk( @@ -336,15 +335,4 @@ private void handleUpdateFailureWithRetry( } listener.onFailure(cause instanceof Exception ? 
(Exception) cause : new NotSerializableExceptionWrapper(cause)); } - - private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) { - var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields(); - if (inferenceMetadata.isEmpty()) { - return request; - } - Map docMap = request.sourceAsMap(); - docMap.remove(InferenceFieldMapper.NAME); - request.source(docMap); - return request; - } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 947aa2c82640c..0cd3f05f250a3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -54,12 +54,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InferenceFieldMetadata that = (InferenceFieldMetadata) o; - return inferenceId.equals(that.inferenceId) && Arrays.equals(sourceFields, that.sourceFields); + return Objects.equals(name, that.name) + && Objects.equals(inferenceId, that.inferenceId) + && Arrays.equals(sourceFields, that.sourceFields); } @Override public int hashCode() { - int result = Objects.hash(inferenceId); + int result = Objects.hash(name, inferenceId); result = 31 * result + Arrays.hashCode(sourceFields); return result; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java index 1fda9ababfabd..7357f6f4bdfc6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java @@ -696,6 +696,10 @@ private static void failIfMatchesRoutingPath(DocumentParserContext context, Stri */ private static void parseCopyFields(DocumentParserContext context, List copyToFields) throws IOException { for (String field : copyToFields) { + if (context.mappingLookup().getMapper(field) instanceof InferenceFieldMapper) { + // ignore copy_to that targets inference fields, values are already extracted in the coordinating node to perform inference. 
+ continue; + } // In case of a hierarchy of nested documents, we need to figure out // which document the field should go to LuceneDocument targetDoc = null; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 63d23462e4847..5eddfb7d91df2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1211,10 +1211,6 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - public boolean hasConflicts() { - return conflicts.isEmpty() == false; - } - public void check() { if (conflicts.isEmpty()) { return; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 494d6918b6086..666e7a3bd2043 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -23,7 +23,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -55,7 +54,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -282,14 +280,6 @@ public Map getMappers() { return Map.of(); } - @Override - public Map getMetadataMappers() { - if (SemanticTextFeature.isEnabled()) { - return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); - } - return Map.of(); - } - @Override public Collection getActionFilters() { if (SemanticTextFeature.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 2e6f66c64fa95..eaa62a3aa743a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -37,59 +37,29 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import 
org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; + /** - * A {@link MappedActionFilter} intercepting {@link BulkShardRequest}s to apply inference on fields declared as - * {@link SemanticTextFieldMapper} in the index mapping. - * The source of each {@link BulkItemRequest} requiring inference is augmented with the results for each field - * under the {@link InferenceMetadataFieldMapper#NAME} section. - * For example, for an index with a semantic_text field named {@code my_semantic_field} the following source document: - *
- *
- * {
- *      "my_semantic_text_field": "these are not the droids you're looking for"
- * }
- * 
- * is rewritten into: - *
- *
- * {
- *      "_inference": {
- *        "my_semantic_field": {
- *          "inference_id": "my_inference_id",
- *                  "model_settings": {
- *                      "task_type": "SPARSE_EMBEDDING"
- *                  },
- *                  "chunks": [
- *                      {
- *                             "inference": {
- *                                 "lucas": 0.05212344,
- *                                 "ty": 0.041213956,
- *                                 "dragon": 0.50991,
- *                                 "type": 0.23241979,
- *                                 "dr": 1.9312073,
- *                                 "##o": 0.2797593
- *                             },
- *                             "text": "these are not the droids you're looking for"
- *                       }
- *                  ]
- *        }
- *      }
- *      "my_semantic_field": "these are not the droids you're looking for"
- * }
- * 
- * The rewriting process occurs on the bulk coordinator node, and the results are then passed downstream - * to the {@link TransportShardBulkAction} for actual indexing. + * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified + * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in + * the request source, we generate embeddings and include the results in the source under the semantic text field + * name as a {@link SemanticTextField}. + * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the + * results during indexing on the shard. * * TODO: batchSize should be configurable via a cluster setting */ @@ -158,11 +128,52 @@ private void processBulkShardRequest( private record InferenceProvider(InferenceService service, Model model) {} - private record FieldInferenceRequest(int id, String field, String input) {} + /** + * A field inference request on a single input. + * @param id The id of the request in the original bulk request. + * @param field The target field. + * @param input The input to run inference on. + * @param inputOrder The original order of the input. + * @param isOriginalFieldInput Whether the input is part of the original values of the field. + */ + private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isOriginalFieldInput) {} - private record FieldInferenceResponse(String field, @Nullable Model model, @Nullable ChunkedInferenceServiceResults chunkedResults) {} + /** + * The field inference response. + * @param field The target field. + * @param input The input that was used to run inference. + * @param inputOrder The original order of the input. + * @param isOriginalFieldInput Whether the input is part of the original values of the field. + * @param model The model used to run inference. + * @param chunkedResults The actual results. 
+ */ + private record FieldInferenceResponse( + String field, + String input, + int inputOrder, + boolean isOriginalFieldInput, + Model model, + ChunkedInferenceServiceResults chunkedResults + ) {} - private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} + private record FieldInferenceResponseAccumulator( + int id, + Map> responses, + List failures + ) { + void addOrUpdateResponse(FieldInferenceResponse response) { + synchronized (this) { + var list = responses.computeIfAbsent(response.field, k -> new ArrayList<>()); + list.add(response); + } + } + + void addFailure(Exception exc) { + synchronized (this) { + failures.add(exc); + } + } + } private class AsyncBulkShardInferenceAction implements Runnable { private final Map fieldInferenceMap; @@ -234,8 +245,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var request = requests.get(i); inferenceResults.get(request.id).failures.add( new ResourceNotFoundException( - "Inference id [{}] not found for field [{}]", - inferenceId, + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), request.field ) ); @@ -271,7 +282,16 @@ public void onResponse(List results) { var request = requests.get(i); var result = results.get(i); var acc = inferenceResults.get(request.id); - acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.input(), + request.inputOrder(), + request.isOriginalFieldInput(), + inferenceProvider.model, + result + ) + ); } } finally { onFinish(); @@ -283,7 +303,8 @@ public void onFailure(Exception exc) { try { for (int i = 0; i < requests.size(); i++) { var request = requests.get(i); - inferenceResults.get(request.id).failures.add( + addInferenceResponseFailure( + request.id, new ElasticsearchException( "Exception when running inference id [{}] on field [{}]", exc, @@ -308,6 +329,7 @@ private void onFinish() { inferenceProvider.service() .chunkedInfer( inferenceProvider.model(), + null, inputs, Map.of(), InputType.INGEST, @@ -319,11 +341,7 @@ private void onFinish() { private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { - acc = new FieldInferenceResponseAccumulator( - id, - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ); + acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>()); inferenceResults.set(id, acc); } return acc; @@ -331,14 +349,14 @@ private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) private void addInferenceResponseFailure(int id, Exception failure) { var acc = ensureResponseAccumulatorSlot(id); - acc.failures().add(failure); + acc.addFailure(failure); } /** * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. * If the response contains failures, the bulk item request is marked as failed for the downstream action. * Otherwise, the source of the request is augmented with the field inference results under the - * {@link InferenceMetadataFieldMapper#NAME} field. + * {@link SemanticTextFieldMapper#NAME} field. 
*/ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { if (response.failures().isEmpty() == false) { @@ -349,37 +367,38 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - Map newDocMap = indexRequest.sourceAsMap(); - Object inferenceObj = newDocMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()); - Map inferenceMap = XContentMapValues.nodeMapValue(inferenceObj, InferenceMetadataFieldMapper.NAME); - newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); - for (FieldInferenceResponse fieldResponse : response.responses()) { - if (fieldResponse.chunkedResults != null) { - try { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceMap, - fieldResponse.field(), - fieldResponse.model(), - fieldResponse.chunkedResults() - ); - } catch (Exception exc) { - item.abort(item.index(), exc); - } - } else { - inferenceMap.remove(fieldResponse.field); - } + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : response.responses.entrySet()) { + var fieldName = entry.getKey(); + var responses = entry.getValue(); + var model = responses.get(0).model(); + // ensure that the order in the original field is consistent in case of multiple inputs + Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); + List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + var result = new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), results, indexRequest.getContentType()) + ), + indexRequest.getContentType() + ); + newDocMap.put(fieldName, result); } - indexRequest.source(newDocMap); + indexRequest.source(newDocMap, indexRequest.getContentType()); } /** * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. - * If results are already populated for fields in the existing _inference object, - * the inference request for this specific field is skipped, and the existing results remain unchanged. - * Validation of inference ID and model settings occurs in the {@link InferenceMetadataFieldMapper} - * during field indexing, where an error will be thrown if they mismatch or if the content is malformed. - * - * TODO: Should we validate the settings for pre-existing results here and apply the inference only if they differ? + * If results are already populated for fields in the original index request, the inference request for this specific + * field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, + * where an error will be thrown if they mismatch or if the content is malformed. + *

+ * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? */ private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { Map> fieldRequestsMap = new LinkedHashMap<>(); @@ -411,17 +430,18 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - final Map inferenceMap = XContentMapValues.nodeMapValue( - docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()), - InferenceMetadataFieldMapper.NAME - ); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); - Object inferenceResult = inferenceMap.remove(field); + var originalFieldValue = XContentMapValues.extractValue(field, docMap); + if (originalFieldValue instanceof Map) { + continue; + } + int order = 0; for (var sourceField : entry.getSourceFields()) { - var value = XContentMapValues.extractValue(sourceField, docMap); - if (value == null) { + boolean isOriginalFieldInput = sourceField.equals(field); + var valueObj = XContentMapValues.extractValue(sourceField, docMap); + if (valueObj == null) { if (isUpdateRequest) { addInferenceResponseFailure( item.id(), @@ -432,36 +452,21 @@ private Map> createFieldInferenceRequests(Bu field ) ); - } else if (inferenceResult != null) { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "The field [{}] is referenced in the [{}] metadata field but has no value", - RestStatus.BAD_REQUEST, - field, - InferenceMetadataFieldMapper.NAME - ) - ); + break; } continue; } ensureResponseAccumulatorSlot(item.id()); - if (value instanceof String valueStr) { - List fieldRequests = fieldRequestsMap.computeIfAbsent( - inferenceId, - k -> new ArrayList<>() - ); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); - } else { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - value.getClass().getSimpleName() - ) - ); + final List values; + try { + values = nodeStringValues(field, valueObj); + } catch (Exception exc) { + addInferenceResponseFailure(item.id(), exc); + break; + } + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + for (var v : values) { + fieldRequests.add(new FieldInferenceRequest(item.id(), field, v, order++, isOriginalFieldInput)); } } } @@ -470,6 +475,37 @@ private Map> createFieldInferenceRequests(Bu } } + /** + * This method converts the given {@code valueObj} into a list of strings. + * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. 
+ */ + private static List nodeStringValues(String field, Object valueObj) { + if (valueObj instanceof String value) { + return List.of(value); + } else if (valueObj instanceof Collection values) { + List valuesString = new ArrayList<>(); + for (var v : values) { + if (v instanceof String value) { + valuesString.add(value); + } else { + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + } + return valuesString; + } + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { if (docWriteRequest instanceof IndexRequest indexRequest) { return indexRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java deleted file mode 100644 index 89d1037243aac..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ /dev/null @@ -1,456 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceFieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.XContentLocation; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import 
org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; - -/** - * A mapper for the {@code _inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results.
- * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so:
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": {
- *                  "inference_id": "my_inference_id",
- *                  "model_settings": {
- *                      "task_type": "SPARSE_EMBEDDING"
- *                  },
- *                  "chunks": [
- *                      {
- *                          "inference": {
- *                              "lucas": 0.05212344,
- *                              "ty": 0.041213956,
- *                              "dragon": 0.50991,
- *                              "type": 0.23241979,
- *                              "dr": 1.9312073,
- *                              "##o": 0.2797593
- *                          },
- *                          "text": "these are not the droids you're looking for"
- *                      }
- *                  ]
- *              }
- *          }
- *      }
- * }
- *
- * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so:
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_field": {
- *                 "chunks": {
- *                      "type": "nested",
- *                      "properties": {
- *                          "embedding": {
- *                              "type": "sparse_vector|dense_vector"
- *                          },
- *                          "text": {
- *                              "type": "keyword",
- *                              "index": false,
- *                              "doc_values": false
- *                          }
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceMetadataFieldMapper extends MetadataFieldMapper { - public static final String NAME = InferenceFieldMapper.NAME; - public static final String CONTENT_TYPE = "_inference"; - - public static final String INFERENCE_ID = "inference_id"; - public static final String CHUNKS = "chunks"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceMetadataFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - boolean origWithLeafObject = context.path().isWithinLeafObject(); - try { - // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); - parseSingleField(context); - } - } finally { - context.path().setWithinLeafObject(origWithLeafObject); - } - } - - private NestedObjectMapper updateSemanticTextFieldMapper( - DocumentParserContext docContext, - SemanticTextMapperContext semanticFieldContext, - String newInferenceId, - SemanticTextModelSettings newModelSettings, - XContentLocation xContentLocation - ) { - final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); - final String inferenceId = semanticFieldContext.mapper.getInferenceId(); - if (newInferenceId.equals(inferenceId) == false) { - throw new DocumentParsingException( - xContentLocation, - Strings.format( - "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", - INFERENCE_ID, - inferenceId, - fullFieldName, - INFERENCE_ID, - newInferenceId - ) - ); - } - if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { - throw new DocumentParsingException( - xContentLocation, - "Model settings for field [" + fullFieldName + "] must contain dimensions" - ); - } - if (semanticFieldContext.mapper.getModelSettings() == null) { - SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( - semanticFieldContext.mapper.simpleName(), - docContext.indexSettings().getIndexVersionCreated() - 
).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); - docContext.addDynamicMapper(newMapper); - return newMapper.getSubMappers(); - } else { - SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); - canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); - try { - conflicts.check(); - } catch (Exception exc) { - throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); - } - } - return semanticFieldContext.mapper.getSubMappers(); - } - - private void parseSingleField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); - if (builderContext == null) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - - // record the location of the inference field in the original source - XContentLocation xContentLocation = parser.getTokenLocation(); - // parse eagerly to extract the inference id and the model settings first - Map map = parser.mapOrdered(); - - // inference_id - Object inferenceIdObj = map.remove(INFERENCE_ID); - final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); - if (inferenceId == null) { - throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); - } - - // model_settings - Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); - if (modelSettingsObj == null) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format( - "Missing required [%s] for field [%s] of type [%s]", - SemanticTextModelSettings.NAME, - fieldName, - SemanticTextFieldMapper.CONTENT_TYPE - ) - ); - } - final SemanticTextModelSettings modelSettings; - try { - modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); - } catch (Exception exc) { - throw new DocumentParsingException( - xContentLocation, - Strings.format( - "Error parsing [%s] for field [%s] of type [%s]", - SemanticTextModelSettings.NAME, - fieldName, - SemanticTextFieldMapper.CONTENT_TYPE - ), - exc - ); - } - - var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); - - // we know the model settings, so we can (re) parse the results array now - XContentParser subParser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - DocumentParserContext mapContext = context.switchParser(subParser); - parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper); - } - - private void parseFieldInference( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - parser.nextToken(); - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - switch (parser.currentName()) { - case CHUNKS -> parseChunks(xContentLocation, parser, context, nestedMapper); - default -> throw 
new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); - } - } - } - - private void parseChunks( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - parser.nextToken(); - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext subContext = context.createNestedContext(nestedMapper); - parseResultsObject(xContentLocation, parser, subContext, nestedMapper); - } - } - - private void parseResultsObject( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); - Set visited = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); - visited.add(parser.currentName()); - FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); - if (fieldMapper == null) { - if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { - throw new DocumentParsingException( - xContentLocation, - "Missing sub-fields definition for [" + parser.currentName() + "]" - ); - } else { - logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); - advancePastCurrentFieldName(xContentLocation, parser); - continue; - } - } - parser.nextToken(); - fieldMapper.parse(context); - // Reset leaf object after parsing the field - context.path().setWithinLeafObject(true); - } - if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visited.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); - } - } - - private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); - } - } - - private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - @SuppressWarnings("unchecked") - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof 
ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - - Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - List> fieldChunks = (List>) fieldMap.computeIfAbsent(CHUNKS, k -> new ArrayList<>()); - fieldChunks.addAll(chunks); - fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); - } - - record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} - - /** - * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} - * and the {@link MapperBuilderContext} that was used to build it. - * If the field is not found or is of the wrong type, this method returns {@code null}. - */ - static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { - ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); - return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\.")); - } - - static SemanticTextMapperContext createSemanticFieldContext( - MapperBuilderContext mapperContext, - ObjectMapper objectMapper, - String[] paths - ) { - Mapper mapper = objectMapper.getMapper(paths[0]); - if (mapper instanceof ObjectMapper newObjectMapper) { - mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); - return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length)); - } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { - return new SemanticTextMapperContext(mapperContext, semanticMapper); - } else { - if (mapper == null || paths.length == 1) { - return null; - } - // check if the semantic field is defined within a multi-field - Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); - if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { - return new SemanticTextMapperContext(mapperContext, semanticMapper); - } - } - return null; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java new file mode 100644 index 0000000000000..f0267d144b7b8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. + * The resulting object preserves the original input under the {@link SemanticTextField#TEXT_FIELD} and exposes + * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. + * + * @param fieldName The original field name. + * @param originalValues The original values associated with the field name. + * @param inference The inference result. + * @param contentType The {@link XContentType} used to store the embeddings chunks. 
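+ * <p>
+ * As a rough sketch (the field value and inference id below are made up for illustration), the object serialized by
+ * {@link #toXContent} for a sparse embedding model looks like:
+ * <pre>
+ * {
+ *     "text": "the original input",
+ *     "inference": {
+ *         "inference_id": "my_inference_id",
+ *         "model_settings": {
+ *             "task_type": "sparse_embedding"
+ *         },
+ *         "chunks": [
+ *             {
+ *                 "text": "the original input",
+ *                 "embeddings": { ... }
+ *             }
+ *         ]
+ *     }
+ * }
+ * </pre>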
+ */ +public record SemanticTextField(String fieldName, List originalValues, InferenceResult inference, XContentType contentType) + implements + ToXContentObject { + + static final ParseField TEXT_FIELD = new ParseField("text"); + static final ParseField INFERENCE_FIELD = new ParseField("inference"); + static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + static final ParseField CHUNKS_FIELD = new ParseField("chunks"); + static final ParseField CHUNKED_EMBEDDINGS_FIELD = new ParseField("embeddings"); + static final ParseField CHUNKED_TEXT_FIELD = new ParseField("text"); + static final ParseField MODEL_SETTINGS_FIELD = new ParseField("model_settings"); + static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + + public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} + + public record Chunk(String text, BytesReference rawEmbeddings) {} + + public record ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) implements ToXContentObject { + public ModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { + this.taskType = Objects.requireNonNull(taskType, "task type must not be null"); + this.dimensions = dimensions; + this.similarity = similarity; + validate(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + private void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + } + + public static String getOriginalTextFieldName(String fieldName) { + return fieldName + "." + TEXT_FIELD.getPreferredName(); + } + + public static String getInferenceFieldName(String fieldName) { + return fieldName + "." + INFERENCE_FIELD.getPreferredName(); + } + + public static String getChunksFieldName(String fieldName) { + return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD.getPreferredName(); + } + + public static String getEmbeddingsFieldName(String fieldName) { + return getChunksFieldName(fieldName) + "." 
+ CHUNKED_EMBEDDINGS_FIELD.getPreferredName(); + } + + static SemanticTextField parse(XContentParser parser, Tuple context) throws IOException { + return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context); + } + + static ModelSettings parseModelSettings(XContentParser parser) throws IOException { + return MODEL_SETTINGS_PARSER.parse(parser, null); + } + + static ModelSettings parseModelSettingsFromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, MODEL_SETTINGS_FIELD.getPreferredName()); + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return parseModelSettings(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (originalValues.isEmpty() == false) { + builder.field(TEXT_FIELD.getPreferredName(), originalValues.size() == 1 ? originalValues.get(0) : originalValues); + } + builder.startObject(INFERENCE_FIELD.getPreferredName()); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inference.inferenceId); + builder.field(MODEL_SETTINGS_FIELD.getPreferredName(), inference.modelSettings); + builder.startArray(CHUNKS_FIELD.getPreferredName()); + for (var chunk : inference.chunks) { + builder.startObject(); + builder.field(CHUNKED_TEXT_FIELD.getPreferredName(), chunk.text); + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings, + contentType + ); + builder.field(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()).copyCurrentStructure(parser); + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + builder.endObject(); + return builder; + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser> SEMANTIC_TEXT_FIELD_PARSER = + new ConstructingObjectParser<>( + SemanticTextFieldMapper.CONTENT_TYPE, + true, + (args, context) -> new SemanticTextField( + context.v1(), + (List) (args[0] == null ? List.of() : args[0]), + (InferenceResult) args[1], + context.v2() + ) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( + INFERENCE_FIELD.getPreferredName(), + true, + args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List) args[2]) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( + CHUNKS_FIELD.getPreferredName(), + true, + args -> new Chunk((String) args[0], (BytesReference) args[1]) + ); + + private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( + MODEL_SETTINGS_FIELD.getPreferredName(), + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new ModelSettings(taskType, dimensions, similarity); + } + ); + + static { + SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), TEXT_FIELD); + SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), INFERENCE_FIELD); + + INFERENCE_RESULT_PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD); + INFERENCE_RESULT_PARSER.declareObject(constructorArg(), (p, c) -> MODEL_SETTINGS_PARSER.parse(p, c), MODEL_SETTINGS_FIELD); + INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), (p, c) -> CHUNKS_PARSER.parse(p, c), CHUNKS_FIELD); + + CHUNKS_PARSER.declareString(constructorArg(), CHUNKED_TEXT_FIELD); + CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> { + XContentBuilder b = XContentBuilder.builder(p.contentType().xContent()); + b.copyCurrentStructure(p); + return BytesReference.bytes(b); + }, CHUNKED_EMBEDDINGS_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); + + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); + MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); + } + + /** + * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + */ + public static List toSemanticTextFieldChunks( + String field, + String inferenceId, + List results, + XContentType contentType + ) { + List chunks = new ArrayList<>(); + for (var result : results) { + if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))); + } + } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + inferenceId, + result.getWriteableName() + ); + } + } + return chunks; + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, double[] value) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (double v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + + /** + * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, + * into a {@link BytesReference}. 
+ */ + private static BytesReference toBytesReference(XContent xContent, List tokens) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index f8fde0b63e4ea..2536825a9e0b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,11 +9,16 @@ import org.apache.lucene.search.Query; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.Explicit; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; @@ -21,7 +26,6 @@ import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperMergeContext; -import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.index.mapper.SimpleMappedFieldType; @@ -35,9 +39,13 @@ import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -45,103 +53,33 @@ import java.util.Set; import java.util.function.Function; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; /** * A {@link 
FieldMapper} for semantic text fields. - * These fields have a reference id reference, that is used for performing inference at ingestion and query time. - * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link InferenceMetadataFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { - private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); - public static final String CONTENT_TYPE = "semantic_text"; - private static SemanticTextFieldMapper toType(FieldMapper in) { - return (SemanticTextFieldMapper) in; - } + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final TypeParser PARSER = new TypeParser( (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) ); - private final IndexVersion indexVersionCreated; - private final String inferenceId; - private final SemanticTextModelSettings modelSettings; - private final NestedObjectMapper subMappers; - - private SemanticTextFieldMapper( - String simpleName, - MappedFieldType mappedFieldType, - CopyTo copyTo, - IndexVersion indexVersionCreated, - String inferenceId, - SemanticTextModelSettings modelSettings, - NestedObjectMapper subMappers - ) { - super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); - this.indexVersionCreated = indexVersionCreated; - this.inferenceId = inferenceId; - this.modelSettings = modelSettings; - this.subMappers = subMappers; - } - - @Override - public Iterator iterator() { - List subIterators = new ArrayList<>(); - subIterators.add(subMappers); - return subIterators.iterator(); - } - - @Override - public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName(), indexVersionCreated).init(this); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - // Just parses text - no indexing is performed - context.parser().textOrNull(); - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SemanticTextFieldType fieldType() { - return (SemanticTextFieldType) super.fieldType(); - } - - public String getInferenceId() { - return inferenceId; - } - - public SemanticTextModelSettings getModelSettings() { - return modelSettings; - } - - public NestedObjectMapper getSubMappers() { - return subMappers; - } - - @Override - public InferenceFieldMetadata getMetadata(Set sourcePaths) { - return new InferenceFieldMetadata(name(), inferenceId, sourcePaths.toArray(String[]::new)); - } - public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; private final Parameter inferenceId = Parameter.stringParam( "inference_id", false, - m -> toType(m).fieldType().inferenceId, + mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId, null ).addValidator(v -> { if (Strings.isEmpty(v)) { @@ -149,24 +87,24 @@ public static class Builder extends FieldMapper.Builder { } }); - private final Parameter modelSettings = new Parameter<>( + private final Parameter modelSettings = new Parameter<>( "model_settings", true, () -> null, - (n, c, o) -> SemanticTextModelSettings.fromMap(o), - mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, + (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o), + mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings, XContentBuilder::field, (m) -> m == null 
? "null" : Strings.toString(m) ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - private Function subFieldsFunction; + private Function inferenceFieldBuilder; public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; - this.subFieldsFunction = c -> createSubFields(c); + this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get()); } public Builder setInferenceId(String id) { @@ -174,7 +112,7 @@ public Builder setInferenceId(String id) { return this; } - public Builder setModelSettings(SemanticTextModelSettings value) { + public Builder setModelSettings(SemanticTextField.ModelSettings value) { this.modelSettings.setValue(value); return this; } @@ -188,63 +126,152 @@ protected Parameter[] getParameters() { protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { super.merge(mergeWith, conflicts, mapperMergeContext); conflicts.check(); - SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; - var childMergeContext = mapperMergeContext.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - NestedObjectMapper mergedSubFields = (NestedObjectMapper) semanticMergeWith.getSubMappers() - .merge( - subFieldsFunction.apply(childMergeContext.getMapperBuilderContext()), - MapperService.MergeReason.MAPPING_UPDATE, - childMergeContext - ); - subFieldsFunction = c -> mergedSubFields; + var semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE); + var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); + var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE); + var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), childContext); + inferenceFieldBuilder = c -> mergedInferenceField; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - final NestedObjectMapper subFields = subFieldsFunction.apply(childContext); + final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext); return new SemanticTextFieldMapper( name(), - new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subFields, meta.getValue()), - copyTo, - indexVersionCreated, - inferenceId.getValue(), - modelSettings.getValue(), - subFields + new SemanticTextFieldType( + fullName, + inferenceId.getValue(), + modelSettings.getValue(), + inferenceField, + indexVersionCreated, + meta.getValue() + ), + copyTo ); } + } + + private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + } - private NestedObjectMapper createSubFields(MapperBuilderContext context) { - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); - nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); - KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) - .indexed(false) - .docValues(false); - if (modelSettings.get() != null) { - 
nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(fieldType().getInferenceField()); + return subIterators.iterator(); + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { + return; + } + XContentLocation xContentLocation = parser.getTokenLocation(); + final SemanticTextField field; + boolean isWithinLeaf = context.path().isWithinLeafObject(); + try { + context.path().setWithinLeafObject(true); + field = SemanticTextField.parse(parser, new Tuple<>(name(), context.parser().contentType())); + } finally { + context.path().setWithinLeafObject(isWithinLeaf); + } + final String fullFieldName = fieldType().name(); + if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID_FIELD.getPreferredName(), + field.inference().inferenceId(), + fullFieldName, + INFERENCE_ID_FIELD.getPreferredName(), + fieldType().getInferenceId() + ) + ); + } + final SemanticTextFieldMapper mapper; + if (fieldType().getModelSettings() == null) { + context.path().remove(); + Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + try { + mapper = builder.setModelSettings(field.inference().modelSettings()) + .setInferenceId(field.inference().inferenceId()) + .build(context.createDynamicMapperBuilderContext()); + context.addDynamicMapper(mapper); + } finally { + context.path().add(simpleName()); + } + } else { + Conflicts conflicts = new Conflicts(fullFieldName); + canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); } - nestedBuilder.add(textMapperBuilder); - return nestedBuilder.build(context); + mapper = this; + } + var chunksField = mapper.fieldType().getChunksField(); + var embeddingsField = mapper.fieldType().getEmbeddingsField(); + for (var chunk : field.inference().chunks()) { + XContentParser subParser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings(), + context.parser().contentType() + ); + DocumentParserContext subContext = context.createNestedContext(chunksField).switchParser(subParser); + subParser.nextToken(); + embeddingsField.parse(subContext); } } + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextFieldType fieldType() { + return (SemanticTextFieldType) super.fieldType(); + } + + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + String[] copyFields = sourcePaths.toArray(String[]::new); + // ensure consistent order + Arrays.sort(copyFields); + return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields); + } + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; - private final 
SemanticTextModelSettings modelSettings; - private final NestedObjectMapper subMappers; + private final SemanticTextField.ModelSettings modelSettings; + private final ObjectMapper inferenceField; + private final IndexVersion indexVersionCreated; public SemanticTextFieldType( String name, String modelId, - SemanticTextModelSettings modelSettings, - NestedObjectMapper subMappers, + SemanticTextField.ModelSettings modelSettings, + ObjectMapper inferenceField, + IndexVersion indexVersionCreated, Map meta ) { super(name, false, false, false, TextSearchInfo.NONE, meta); this.inferenceId = modelId; this.modelSettings = modelSettings; - this.subMappers = subMappers; + this.inferenceField = inferenceField; + this.indexVersionCreated = indexVersionCreated; } @Override @@ -256,22 +283,31 @@ public String getInferenceId() { return inferenceId; } - public SemanticTextModelSettings getModelSettings() { + public SemanticTextField.ModelSettings getModelSettings() { return modelSettings; } - public NestedObjectMapper getSubMappers() { - return subMappers; + public ObjectMapper getInferenceField() { + return inferenceField; + } + + public NestedObjectMapper getChunksField() { + return (NestedObjectMapper) inferenceField.getMapper(CHUNKS_FIELD.getPreferredName()); + } + + public FieldMapper getEmbeddingsField() { + return (FieldMapper) getChunksField().getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); } @Override public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented yet"); + throw new IllegalArgumentException(CONTENT_TYPE + " fields do not support term query"); } @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); + // Redirect the fetcher to load the original values of the field + return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format); } @Override @@ -280,16 +316,39 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } - private static Mapper.Builder createInferenceMapperBuilder( - String fieldName, - SemanticTextModelSettings modelSettings, - IndexVersion indexVersionCreated + private static ObjectMapper createInferenceField( + MapperBuilderContext context, + IndexVersion indexVersionCreated, + @Nullable SemanticTextField.ModelSettings modelSettings + ) { + return new ObjectMapper.Builder(INFERENCE_FIELD.getPreferredName(), Explicit.EXPLICIT_TRUE).dynamic(ObjectMapper.Dynamic.FALSE) + .add(createChunksField(indexVersionCreated, modelSettings)) + .build(context); + } + + private static NestedObjectMapper.Builder createChunksField( + IndexVersion indexVersionCreated, + SemanticTextField.ModelSettings modelSettings ) { + NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD.getPreferredName(), indexVersionCreated); + chunksField.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder( + CHUNKED_TEXT_FIELD.getPreferredName(), + indexVersionCreated + ).indexed(false).docValues(false); + if (modelSettings != null) { + chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings)); + } + chunksField.add(chunkTextField); + return chunksField; + } + + private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) { return switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> new 
SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); case TEXT_EMBEDDING -> { DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, + CHUNKED_EMBEDDINGS_FIELD.getPreferredName(), indexVersionCreated ); SimilarityMeasure similarity = modelSettings.similarity(); @@ -298,23 +357,21 @@ private static Mapper.Builder createInferenceMapperBuilder( case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + "Unknown similarity measure in model_settings [" + similarity.name() + "]" ); } } denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); yield denseVectorMapperBuilder; } - default -> throw new IllegalArgumentException( - "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() - ); + default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]"); }; } - static boolean canMergeModelSettings( - SemanticTextModelSettings previous, - SemanticTextModelSettings current, - FieldMapper.Conflicts conflicts + private static boolean canMergeModelSettings( + SemanticTextField.ModelSettings previous, + SemanticTextField.ModelSettings current, + Conflicts conflicts ) { if (Objects.equals(previous, current)) { return true; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java deleted file mode 100644 index b1d0511008db8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; - -/** - * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. - */ -public class SemanticTextModelSettings implements ToXContentObject { - - public static final String NAME = "model_settings"; - public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); - private final TaskType taskType; - private final Integer dimensions; - private final SimilarityMeasure similarity; - - public SemanticTextModelSettings(Model model) { - this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); - } - - public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { - Objects.requireNonNull(taskType, "task type must not be null"); - this.taskType = taskType; - this.dimensions = dimensions; - this.similarity = similarity; - validate(); - } - - public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { - return PARSER.apply(parser, null); - } - - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, - true, - args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - Integer dimensions = (Integer) args[1]; - SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); - return new SemanticTextModelSettings(taskType, dimensions, similarity); - } - ); - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); - } - - public static SemanticTextModelSettings fromMap(Object node) { - if (node == null) { - return null; - } - try { - Map map = XContentMapValues.nodeMapValue(node, NAME); - if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { - throw new IllegalArgumentException( - "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" - ); - } - XContentParser parser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - return SemanticTextModelSettings.parse(parser); - } catch (Exception exc) { - throw new ElasticsearchException(exc); - } - } - - public Map asMap() { - Map attrsMap = new HashMap<>(); - attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - if (dimensions != null) { - attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return Map.of(NAME, attrsMap); - } - - public TaskType taskType() { - return taskType; - } - - public Integer dimensions() { - return dimensions; - } - - public SimilarityMeasure similarity() { - return similarity; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - if (dimensions != null) { - builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return builder.endObject(); - } - - public void validate() { - switch (taskType) { - case TEXT_EMBEDDING: - if (dimensions == null) { - throw new IllegalArgumentException( - "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" - ); - } - if (similarity == null) { - throw new IllegalArgumentException( - "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" - ); - } - break; - case SPARSE_EMBEDDING: - break; - - default: - throw new IllegalArgumentException( - "Wrong [" - + TASK_TYPE_FIELD.getPreferredName() - + "], expected " - + TEXT_EMBEDDING - + " or " - + SPARSE_EMBEDDING - + ", got " - + taskType.name() - ); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SemanticTextModelSettings that = (SemanticTextModelSettings) o; - return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; - } - - @Override - public int hashCode() { - return Objects.hash(taskType, dimensions, similarity); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index d734e9998734d..7cfaeaae4c3a4 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -30,9 +30,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +52,9 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -178,10 +180,10 @@ public void testManyRandomDocs() throws Exception { BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(originalRequests.length)); for (int id = 0; id < items.length; id++) { - IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request()); - IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request()); + IndexRequest actualRequest = getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request()); try { - assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType()); } catch (Exception exc) { throw new IllegalStateException(exc); } @@ -226,9 +228,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceService inferenceService = mock(InferenceService.class); Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; - List inputs = (List) invocationOnMock.getArguments()[1]; + List inputs = (List) invocationOnMock.getArguments()[2]; ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[5]; + List>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { List results = new ArrayList<>(); for (String input : inputs) { @@ -247,7 +249,7 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool } return null; }; - 
doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); Answer modelAnswer = invocationOnMock -> { String inferenceId = (String) invocationOnMock.getArguments()[0]; @@ -267,43 +269,25 @@ private static BulkItemRequest[] randomBulkItemRequest( Map fieldInferenceMap ) { Map docMap = new LinkedHashMap<>(); - Map inferenceResultsMap = new LinkedHashMap<>(); + Map expectedDocMap = new LinkedHashMap<>(); + XContentType requestContentType = randomFrom(XContentType.values()); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); var model = modelMap.get(entry.getInferenceId()); String text = randomAlphaOfLengthBetween(10, 100); docMap.put(field, text); + expectedDocMap.put(field, text); if (model == null) { // ignore results, the doc should fail with a resource not found exception continue; } - int numChunks = randomIntBetween(1, 5); - List chunks = new ArrayList<>(); - for (int i = 0; i < numChunks; i++) { - chunks.add(randomAlphaOfLengthBetween(5, 10)); - } - TaskType taskType = model.getTaskType(); - final ChunkedInferenceServiceResults results; - switch (taskType) { - case TEXT_EMBEDDING: - results = randomTextEmbeddings(model, chunks); - break; - - case SPARSE_EMBEDDING: - results = randomSparseEmbeddings(chunks); - break; - - default: - throw new AssertionError("Unknown task type " + taskType.name()); - } - model.putResult(text, results); - InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + var result = randomSemanticText(field, model, List.of(text), requestContentType); + model.putResult(text, result); + expectedDocMap.put(field, result); } - Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); return new BulkItemRequest[] { - new BulkItemRequest(id, new IndexRequest("index").source(docMap)), - new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; } private static StaticModel randomStaticModel() { @@ -320,7 +304,7 @@ private static StaticModel randomStaticModel() { } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -335,11 +319,15 @@ private static class StaticModel extends TestModel { } ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + SemanticTextField result = resultMap.get(text); + if (result == null) { + return new ChunkedSparseEmbeddingResults(List.of()); + } + return toChunkedResult(result); } - void putResult(String text, ChunkedInferenceServiceResults results) { - resultMap.put(text, results); + void putResult(String text, SemanticTextField result) { + resultMap.put(text, result); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java deleted file mode 100644 index 37e4e5e774bec..0000000000000 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java +++ /dev/null @@ -1,629 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.document.FeatureField; -import org.apache.lucene.index.IndexableField; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.LuceneDocument; -import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MetadataMapperTestCase; -import org.elasticsearch.index.mapper.NestedLookup; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.LeafNestedDocuments; -import org.elasticsearch.search.NestedDocuments; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.model.TestModel; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; - -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; - -public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List 
text) {} - - private record VisitedChildDocInfo(String path) {} - - private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} - - @Override - protected String fieldName() { - return InferenceMetadataFieldMapper.NAME; - } - - @Override - protected boolean isConfigurable() { - return false; - } - - @Override - protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available - } - - @Override - protected void registerParameters(ParameterChecker checker) throws IOException { - - } - - @Override - protected Collection getPlugins() { - return List.of(new InferencePlugin(Settings.EMPTY)); - } - - public void testSuccessfulParse() throws IOException { - for (int depth = 1; depth < 4; depth++) { - final String fieldName1 = randomFieldName(depth); - final String fieldName2 = randomFieldName(depth + 1); - - Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); - Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); - XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); - }); - - MapperService mapperService = createMapperService(mapping); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); - DocumentMapper documentMapper = mapperService.documentMapper(); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) - ) - ) - ) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - for (int i = 0; i < 3; i++) { - assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); - } - // nested docs are in reversed order - assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2); - assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1); - assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - - withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - mapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null), - new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null), - new SearchHit.NestedIdentity(fieldName2 + "." 
+ CHUNKS, 0, null) - ); - - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + CHUNKS, - List.of("a") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + CHUNKS, - List.of("a", "b") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + CHUNKS, - List.of("d") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + CHUNKS, - List.of("z") - ), - 10 - ); - assertEquals(0, topDocs.totalHits.value); - } - }); - } - } - - public void testMissingSubfields() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - - DocumentMapper documentMapper = createDocumentMapper( - mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) - ); - - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - true, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(true, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]") - ); - } - } - - public void testExtraSubfields() throws IOException { - final String 
fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) - ); - - DocumentMapper documentMapper = createDocumentMapper( - mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) - ); - - Consumer checkParsedDocument = d -> { - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + CHUNKS)); - - List luceneDocs = d.docs(); - assertEquals(2, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), d.rootDoc(), visitedChildDocs); - assertEquals(d.rootDoc(), luceneDocs.get(1)); - assertNull(luceneDocs.get(1).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - }; - - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", "extra_value") - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", Map.of("k1", "v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", List.of("v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - Map extraSubfields = new HashMap<>(); - extraSubfields.put("extra_key", null); - - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - extraSubfields - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - } - - public void testMissingSemanticTextMapping() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> {})); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults( - fieldName, - randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), - List.of("a b") - ) - ) - ) - ) - ) - ); - assertThat( - ex.getMessage(), - containsString( - Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ) - ); - } - - public void testMissingInferenceId() throws IOException { - DocumentMapper documentMapper = 
createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - IllegalArgumentException ex = expectThrows( - DocumentParsingException.class, - IllegalArgumentException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .startObject(SemanticTextModelSettings.NAME) - .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING) - .endObject() - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getMessage(), containsString("required [inference_id] is missing")); - } - - public void testMissingModelSettings() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]")); - } - - public void testMissingTaskType() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") - .startObject(SemanticTextModelSettings.NAME) - .endObject() - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getCause().getMessage(), containsString(" Failed to parse [model_settings], required [task_type] is missing")); - } - - private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { - mappingBuilder.startObject(fieldName); - mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("inference_id", modelId); - mappingBuilder.endObject(); - } - - public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - double[] values = new double[model.getServiceSettings().dimensions()]; - for (int j = 0; j < values.length; j++) { - values[j] = randomDouble(); - } - chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); - } - return new ChunkedTextEmbeddingResults(chunks); - } - - public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - var tokens = new ArrayList(); - for (var token : input.split("\\s+")) { - tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); - } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); - } - return new ChunkedSparseEmbeddingResults(chunks); - } - - private static SemanticTextInferenceResults randomSemanticTextInferenceResults( - String semanticTextFieldName, - Model model, - List chunks - ) { - ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks); - case SPARSE_EMBEDDING -> 
randomSparseEmbeddings(chunks); - default -> throw new AssertionError("unkwnown task type: " + model.getTaskType().name()); - }; - return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks); - } - - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults - ) throws IOException { - addSemanticTextInferenceResults( - sourceBuilder, - semanticTextInferenceResults, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of() - ); - } - - @SuppressWarnings("unchecked") - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults, - SparseVectorSubfieldOptions sparseVectorSubfieldOptions, - boolean includeTextSubfield, - Map extraSubfields - ) throws IOException { - Map inferenceResultsMap = new LinkedHashMap<>(); - for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceResultsMap, - semanticTextInferenceResult.fieldName, - semanticTextInferenceResult.model, - semanticTextInferenceResult.results - ); - Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(CHUNKS); - for (var entry : fieldResultList) { - if (includeTextSubfield == false) { - entry.remove(INFERENCE_CHUNKS_TEXT); - } - if (sparseVectorSubfieldOptions.include == false) { - entry.remove(INFERENCE_CHUNKS_RESULTS); - } - entry.putAll(extraSubfields); - } - } - sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); - } - - static String randomFieldName(int numLevel) { - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < numLevel; i++) { - if (i > 0) { - builder.append('.'); - } - builder.append(randomAlphaOfLengthBetween(5, 15)); - } - return builder.toString(); - } - - private static Model randomModel(TaskType taskType) { - String serviceName = randomAlphaOfLengthBetween(5, 10); - String inferenceId = randomAlphaOfLengthBetween(5, 10); - return new TestModel( - inferenceId, - taskType, - serviceName, - new TestModel.TestServiceSettings("my-model"), - new TestModel.TestTaskSettings(randomIntBetween(1, 100)), - new TestModel.TestSecretSettings(randomAlphaOfLength(10)) - ); - } - - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { - NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); - assertNotNull(mapper); - - BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); - BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); - for (String token : tokens) { - queryBuilder.add( - new BooleanClause(new TermQuery(new Term(path + "." 
+ INFERENCE_CHUNKS_RESULTS, token)), BooleanClause.Occur.MUST) - ); - } - queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); - } - - private static void assertValidChildDoc( - LuceneDocument childDoc, - LuceneDocument expectedParent, - Collection visitedChildDocs - ) { - assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath())); - } - - private static void assertChildLeafNestedDocument( - LeafNestedDocuments leaf, - int advanceToDoc, - int expectedRootDoc, - Set visitedNestedIdentities - ) throws IOException { - - assertNotNull(leaf.advance(advanceToDoc)); - assertEquals(advanceToDoc, leaf.doc()); - assertEquals(expectedRootDoc, leaf.rootDoc()); - assertNotNull(leaf.nestedIdentity()); - visitedNestedIdentities.add(leaf.nestedIdentity()); - } - - private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { - int count = 0; - for (IndexableField field : doc.getFields()) { - if (field instanceof FeatureField featureField) { - assertThat(featureField.name(), equalTo(fieldName)); - ++count; - } - } - assertThat(count, equalTo(expectedCount)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 1b5311ac9effb..a6f0fa83eab37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -7,32 +7,65 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.LuceneDocument; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedLookup; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import 
org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.LeafNestedDocuments; +import org.elasticsearch.search.NestedDocuments; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.AssumptionViolatedException; import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.List; +import java.util.Set; import static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomModel; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -55,7 +88,7 @@ protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { @Override protected Object getSampleValueForDocument() { - return "value"; + return null; } @Override @@ -98,7 +131,7 @@ public void testDefaults() throws Exception { assertTrue(fields.isEmpty()); } - public void testInferenceIdNotPresent() throws IOException { + public void testInferenceIdNotPresent() { Exception e = expectThrows( MapperParsingException.class, () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) @@ -112,6 +145,7 @@ public void testCannotBeUsedInMultiFields() { b.startObject("fields"); b.startObject("semantic"); b.field("type", "semantic_text"); + b.field("inference_id", "my_inference_id"); b.endObject(); b.endObject(); }))); @@ -136,7 +170,7 @@ public void testUpdatesToInferenceIdNotSupported() throws IOException { public void testUpdateModelSettings() throws IOException { for (int depth = 1; depth < 5; depth++) { - String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth); + String fieldName = randomFieldName(depth); MapperService mapperService = createMapperService( mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) ); @@ -157,7 +191,7 @@ public void testUpdateModelSettings() throws IOException { ) ) ); - assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing")); + assertThat(exc.getMessage(), containsString("Required [task_type]")); } { merge( @@ -220,12 +254,7 @@ public void testUpdateModelSettings() throws IOException { } 
static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( - MapperBuilderContext.root(false, false), - mapperService.mappingLookup().getMapping().getRoot(), - fieldName.split("\\.") - ); - Mapper mapper = res.mapper(); + Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; @@ -235,31 +264,257 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); - assertTrue(semanticFieldMapper.getSubMappers() == semanticTextFieldType.getSubMappers()); - assertTrue(semanticFieldMapper.getModelSettings() == semanticTextFieldType.getModelSettings()); - NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() + NestedObjectMapper chunksMapper = mapperService.mappingLookup() .nestedLookup() .getNestedMappers() - .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS); - assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); - Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); + .get(getChunksFieldName(fieldName)); + assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); + Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD.getPreferredName()); assertNotNull(textMapper); assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; assertFalse(textFieldMapper.fieldType().isIndexed()); assertFalse(textFieldMapper.fieldType().hasDocValues()); if (expectedModelSettings) { - assertNotNull(semanticFieldMapper.getModelSettings()); - Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS); + assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); + Mapper inferenceMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); assertNotNull(inferenceMapper); - switch (semanticFieldMapper.getModelSettings().taskType()) { + switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); default -> throw new AssertionError("Invalid task type"); } } else { - assertNull(semanticFieldMapper.getModelSettings()); + assertNull(semanticFieldMapper.fieldType().getModelSettings()); + } + } + + public void testSuccessfulParse() throws IOException { + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService 
mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticText(fieldName1, model1, List.of("a b", "c"), XContentType.JSON), + randomSemanticText(fieldName2, model2, List.of("d e f"), XContentType.JSON) + ) + ) + ) + ); + + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + } + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); + assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); + assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) + ); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } + } + + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( 
+ b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field( + MODEL_SETTINGS_FIELD.getPreferredName(), + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null) + ) + .field(CHUNKS_FIELD.getPreferredName(), List.of()) + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .startObject(MODEL_SETTINGS_FIELD.getPreferredName()) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("failed to parse field [model_settings]")); + } + + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + mappingBuilder.startObject(fieldName); + mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mappingBuilder.field("inference_id", modelId); + mappingBuilder.endObject(); + } + + private static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List semanticTextInferenceResults) + throws IOException { + for (var field : semanticTextInferenceResults) { + sourceBuilder.field(field.fieldName()); + sourceBuilder.value(field); + } + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); + } + + private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String fieldName, List tokens) { + NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(getChunksFieldName(fieldName)); + assertNotNull(mapper); + + BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (String token : tokens) { + queryBuilder.add( + new BooleanClause(new TermQuery(new Term(getEmbeddingsFieldName(fieldName), token)), BooleanClause.Occur.MUST) + ); + } + queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); + + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + } + + private static void assertChildLeafNestedDocument( + LeafNestedDocuments leaf, + int advanceToDoc, + int expectedRootDoc, + Set visitedNestedIdentities + ) throws IOException { + + 
assertNotNull(leaf.advance(advanceToDoc)); + assertEquals(advanceToDoc, leaf.doc()); + assertEquals(expectedRootDoc, leaf.rootDoc()); + assertNotNull(leaf.nestedIdentity()); + visitedNestedIdentities.add(leaf.nestedIdentity()); + } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } } + assertThat(count, equalTo(expectedCount)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java new file mode 100644 index 0000000000000..c5ab6eb1f9e15 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.model.TestModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextFieldTests extends AbstractXContentTestCase { + private static final String NAME = "field"; + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + } + + @Override + protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { + assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); + assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); + assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); + assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); + SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); + for (int i = 0; i < 
newInstance.inference().chunks().size(); i++) { + assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); + switch (modelSettings.taskType()) { + case TEXT_EMBEDDING -> { + double[] expectedVector = parseDenseVector( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + expectedInstance.contentType() + ); + double[] newVector = parseDenseVector( + newInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + newInstance.contentType() + ); + assertArrayEquals(expectedVector, newVector, 0f); + } + case SPARSE_EMBEDDING -> { + List expectedTokens = parseWeightedTokens( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + expectedInstance.contentType() + ); + List newTokens = parseWeightedTokens( + newInstance.inference().chunks().get(i).rawEmbeddings(), + newInstance.contentType() + ); + assertThat(newTokens, equalTo(expectedTokens)); + } + default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); + } + } + } + + @Override + protected SemanticTextField createTestInstance() { + List rawValues = randomList(1, 5, () -> randomAlphaOfLengthBetween(10, 20)); + return randomSemanticText( + NAME, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + rawValues, + randomFrom(XContentType.values()) + ); + } + + @Override + protected SemanticTextField doParseInstance(XContentParser parser) throws IOException { + return SemanticTextField.parse(parser, new Tuple<>(NAME, parser.contentType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[model.getServiceSettings().dimensions()]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } + + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + + public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) { + ChunkedInferenceServiceResults results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, inputs); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); + }; + return new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), List.of(results), contentType) + ), + contentType + ); + } + + public static Model randomModel(TaskType taskType) { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + 
inferenceId, + taskType, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) { + switch (field.inference().modelSettings().taskType()) { + case SPARSE_EMBEDDING -> { + List chunks = new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + case TEXT_EMBEDDING -> { + List chunks = + new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + double[] values = parseDenseVector( + chunk.rawEmbeddings(), + field.inference().modelSettings().dimensions(), + field.contentType() + ); + chunks.add( + new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk( + chunk.text(), + values + ) + ); + } + return new ChunkedTextEmbeddingResults(chunks); + } + default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); + } + } + + private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + parser.nextToken(); + assertThat(parser.currentToken(), equalTo(XContentParser.Token.START_ARRAY)); + double[] values = new double[numDims]; + for (int i = 0; i < numDims; i++) { + assertThat(parser.nextToken(), equalTo(XContentParser.Token.VALUE_NUMBER)); + values[i] = parser.doubleValue(); + } + assertThat(parser.nextToken(), equalTo(XContentParser.Token.END_ARRAY)); + return values; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static List parseWeightedTokens(BytesReference value, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + Map map = parser.map(); + List weightedTokens = new ArrayList<>(); + for (var entry : map.entrySet()) { + weightedTokens.add(new TextExpansionResults.WeightedToken(entry.getKey(), ((Number) entry.getValue()).floatValue())); + } + return weightedTokens; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 0a07a88d230ef..e567b2103e527 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -80,16 +80,14 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.text: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - exists: 
_source.another_inference_field.inference.chunks.0.embeddings + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.another_inference_field.chunks.0.inference - --- "text expansion documents do not create new mappings": - do: @@ -117,16 +115,14 @@ setup: index: test-dense-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.text: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - exists: _source.another_inference_field.inference.chunks.0.embeddings - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.another_inference_field.chunks.0.inference - --- "text embeddings documents do not create new mappings": @@ -155,8 +151,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } + - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } - do: update: @@ -171,17 +167,14 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } - match: { _source.non_inference_field: "another non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } - --- "Updating semantic_text fields recalculates embeddings": - do: @@ -198,12 +191,11 @@ setup: index: test-sparse-index id: doc_1 - - 
match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - do: bulk: @@ -217,12 +209,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "I am a test" } - - match: { _source.another_inference_field: "I am a teapot" } + - match: { _source.inference_field.text: "I am a test" } + - match: { _source.inference_field.inference.chunks.0.text: "I am a test" } + - match: { _source.another_inference_field.text: "I am a teapot" } + - match: { _source.another_inference_field.inference.chunks.0.text: "I am a teapot" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "I am a test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "I am a teapot" } - do: update: @@ -238,12 +229,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "updated inference test" } - - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.inference_field.text: "updated inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "updated inference test" } + - match: { _source.another_inference_field.text: "another updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } - do: bulk: @@ -257,12 +247,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "bulk inference test" } - - match: { _source.another_inference_field: "bulk updated inference test" } + - match: { _source.inference_field.text: "bulk inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "bulk inference test" } + - match: { _source.another_inference_field.text: "bulk updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "bulk updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "bulk inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "bulk updated inference test" } --- "Reindex works for semantic_text fields": @@ -280,8 +269,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } + - set: { 
_source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } - do: indices.refresh: { } @@ -314,17 +303,14 @@ setup: index: destination-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } - --- "Fails for non-existent inference": - do: @@ -378,22 +364,6 @@ setup: - match: { items.0.update.status: 400 } - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } ---- -"Fails when providing inference results and there is no value for field": - - do: - catch: /The field \[inference_field\] is referenced in the \[_inference\] metadata field but has no value/ - index: - index: test-sparse-index - id: doc_1 - body: - _inference: - inference_field: - chunks: - - text: "inference test" - inference: - "hello": 0.123 - - --- "semantic_text copy_to calculate inference for source fields": - do: @@ -426,14 +396,14 @@ setup: index: test-copy-to-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - length: { _source._inference.inference_field.chunks: 3 } - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.inference_field.chunks.0.text - - exists: _source._inference.inference_field.chunks.1.inference - - exists: _source._inference.inference_field.chunks.1.text - - exists: _source._inference.inference_field.chunks.2.inference - - exists: _source._inference.inference_field.chunks.2.text + - match: { _source.inference_field.text: "inference test" } + - length: { _source.inference_field.inference.chunks: 3 } + - match: { _source.inference_field.inference.chunks.0.text: "another copy_to inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.1.text: "inference test" } + - exists: _source.inference_field.inference.chunks.1.embeddings + - match: { _source.inference_field.inference.chunks.2.text: "copy_to inference test" } + - exists: _source.inference_field.inference.chunks.2.embeddings --- diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 9dc109b3fb81d..27f233436b925 100644 --- 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -66,23 +66,3 @@ setup: id: doc_1 body: dense_field: "you know, for testing" - ---- -"Inference section contains unreferenced fields": - - do: - catch: /Field \[unknown_field\] is not registered as a \[semantic_text\] field type/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - unknown_field: - inference_id: dense-inference-id - model_settings: - task_type: text_embedding - chunks: - - text: "inference test" - inference: [ 0.1, 0.2, 0.3, 0.4, 0.5 ] - - text: "another inference test" - inference: [ -0.1, -0.2, -0.3, -0.4, -0.5 ] From 17f1fde937949b6649d84a16fae8eedfee9f189e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 5 Apr 2024 12:15:02 +0200 Subject: [PATCH 14/29] semantic_text: Add cluster metadata information for inference field mappers --- .../cluster/ClusterStateDiffIT.java | 6 +- .../org/elasticsearch/TransportVersions.java | 1 + .../cluster/metadata/IndexMetadata.java | 70 ++++++++++ .../metadata/InferenceFieldMetadata.java | 127 ++++++++++++++++++ .../metadata/MetadataCreateIndexService.java | 9 +- .../metadata/MetadataMappingService.java | 7 +- .../index/mapper/InferenceFieldMapper.java | 28 ++++ .../index/mapper/MappingLookup.java | 31 ++++- .../cluster/metadata/IndexMetadataTests.java | 33 ++++- .../metadata/InferenceFieldMetadataTests.java | 66 +++++++++ 10 files changed, 362 insertions(+), 16 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java create mode 100644 server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 409fbdd70333e..e0dbc74567053 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -61,6 +61,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; import static org.elasticsearch.cluster.metadata.AliasMetadata.newAliasMetadataBuilder; +import static org.elasticsearch.cluster.metadata.IndexMetadataTests.randomInferenceFields; import static org.elasticsearch.cluster.routing.RandomShardRoutingMutator.randomChange; import static org.elasticsearch.cluster.routing.TestShardRouting.shardRoutingBuilder; import static org.elasticsearch.cluster.routing.UnassignedInfoTests.randomUnassignedInfo; @@ -571,7 +572,7 @@ public IndexMetadata randomCreate(String name) { @Override public IndexMetadata randomChange(IndexMetadata part) { IndexMetadata.Builder builder = IndexMetadata.builder(part); - switch (randomIntBetween(0, 2)) { + switch (randomIntBetween(0, 3)) { case 0: builder.settings(Settings.builder().put(part.getSettings()).put(randomSettings(Settings.EMPTY))); break; @@ -585,6 +586,9 @@ public IndexMetadata randomChange(IndexMetadata part) { case 2: builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; + case 3: + 
builder.putInferenceFields(randomInferenceFields()); + break; default: throw new IllegalArgumentException("Shouldn't be here"); } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 57a3afe083707..95e3353de4120 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -163,6 +163,7 @@ static TransportVersion def(int id) { public static final TransportVersion CCR_STATS_API_TIMEOUT_PARAM = def(8_622_00_0); public static final TransportVersion ESQL_ORDINAL_BLOCK = def(8_623_00_0); public static final TransportVersion ML_INFERENCE_COHERE_RERANK = def(8_624_00_0); + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_625_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 22672756bdaf0..b66da654f8a1c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -540,6 +540,8 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; + public static final String KEY_FIELD_INFERENCE = "field_inference"; + public static final String INDEX_STATE_FILE_PREFIX = "state-"; static final TransportVersion SYSTEM_INDEX_FLAG_ADDED = TransportVersions.V_7_10_0; @@ -574,6 +576,8 @@ public Iterator> settings() { @Nullable private final MappingMetadata mapping; + private final ImmutableOpenMap inferenceFields; + private final ImmutableOpenMap customData; private final Map> inSyncAllocationIds; @@ -642,6 +646,7 @@ private IndexMetadata( final int numberOfReplicas, final Settings settings, final MappingMetadata mapping, + final ImmutableOpenMap inferenceFields, final ImmutableOpenMap aliases, final ImmutableOpenMap customData, final Map> inSyncAllocationIds, @@ -692,6 +697,7 @@ private IndexMetadata( this.totalNumberOfShards = numberOfShards * (numberOfReplicas + 1); this.settings = settings; this.mapping = mapping; + this.inferenceFields = inferenceFields; this.customData = customData; this.aliases = aliases; this.inSyncAllocationIds = inSyncAllocationIds; @@ -748,6 +754,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.numberOfReplicas, this.settings, mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -806,6 +813,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, Maps.copyMapWithAddedOrReplacedEntry(this.inSyncAllocationIds, shardId, Set.copyOf(inSyncSet)), @@ -862,6 +870,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -918,6 +927,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ -970,6 +980,7 @@ public IndexMetadata withIncrementedVersion() { this.numberOfReplicas, this.settings, this.mapping, + this.inferenceFields, this.aliases, this.customData, this.inSyncAllocationIds, @@ 
-1193,6 +1204,10 @@ public MappingMetadata mapping() { return mapping; } + public Map getInferenceFields() { + return inferenceFields; + } + @Nullable public IndexMetadataStats getStats() { return stats; @@ -1403,6 +1418,9 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } + if (inferenceFields.equals(that.inferenceFields) == false) { + return false; + } if (isSystem != that.isSystem) { return false; } @@ -1423,6 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); + result = 31 * result + inferenceFields.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1469,6 +1488,7 @@ private static class IndexMetadataDiff implements Diff { @Nullable private final Diff settingsDiff; private final Diff> mappings; + private final Diff> inferenceFields; private final Diff> aliases; private final Diff> customData; private final Diff>> inSyncAllocationIds; @@ -1500,6 +1520,7 @@ private static class IndexMetadataDiff implements Diff { : ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, after.mapping).build(), DiffableUtils.getStringKeySerializer() ); + inferenceFields = DiffableUtils.diff(before.inferenceFields, after.inferenceFields, DiffableUtils.getStringKeySerializer()); aliases = DiffableUtils.diff(before.aliases, after.aliases, DiffableUtils.getStringKeySerializer()); customData = DiffableUtils.diff(before.customData, after.customData, DiffableUtils.getStringKeySerializer()); inSyncAllocationIds = DiffableUtils.diff( @@ -1524,6 +1545,8 @@ private static class IndexMetadataDiff implements Diff { new DiffableUtils.DiffableValueReader<>(DiffableStringMap::readFrom, DiffableStringMap::readDiffFrom); private static final DiffableUtils.DiffableValueReader ROLLOVER_INFO_DIFF_VALUE_READER = new DiffableUtils.DiffableValueReader<>(RolloverInfo::new, RolloverInfo::readDiffFrom); + private static final DiffableUtils.DiffableValueReader INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER = + new DiffableUtils.DiffableValueReader<>(InferenceFieldMetadata::new, InferenceFieldMetadata::readDiffFrom); IndexMetadataDiff(StreamInput in) throws IOException { index = in.readString(); @@ -1546,6 +1569,15 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields = DiffableUtils.readImmutableOpenMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + INFERENCE_FIELDS_METADATA_DIFF_VALUE_READER + ); + } else { + inferenceFields = DiffableUtils.emptyDiff(); + } aliases = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), ALIAS_METADATA_DIFF_VALUE_READER); customData = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), CUSTOM_DIFF_VALUE_READER); inSyncAllocationIds = DiffableUtils.readJdkMapDiff( @@ -1595,6 +1627,9 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + inferenceFields.writeTo(out); + } aliases.writeTo(out); customData.writeTo(out); inSyncAllocationIds.writeTo(out); @@ -1628,6 +1663,7 @@ 
public IndexMetadata apply(IndexMetadata part) { builder.mapping = mappings.apply( ImmutableOpenMap.builder(1).fPut(MapperService.SINGLE_MAPPING_NAME, part.mapping).build() ).get(MapperService.SINGLE_MAPPING_NAME); + builder.inferenceFields.putAllFromMap(inferenceFields.apply(part.inferenceFields)); builder.aliases.putAllFromMap(aliases.apply(part.aliases)); builder.customMetadata.putAllFromMap(customData.apply(part.customData)); builder.inSyncAllocationIds.putAll(inSyncAllocationIds.apply(part.inSyncAllocationIds)); @@ -1673,6 +1709,10 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); + } int aliasesSize = in.readVInt(); for (int i = 0; i < aliasesSize; i++) { AliasMetadata aliasMd = new AliasMetadata(in); @@ -1733,6 +1773,9 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException mapping.writeTo(out); } } + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeCollection(inferenceFields.values()); + } out.writeCollection(aliases.values()); out.writeMap(customData, StreamOutput::writeWriteable); out.writeMap( @@ -1788,6 +1831,7 @@ public static class Builder { private long[] primaryTerms = null; private Settings settings = Settings.EMPTY; private MappingMetadata mapping; + private final ImmutableOpenMap.Builder inferenceFields; private final ImmutableOpenMap.Builder aliases; private final ImmutableOpenMap.Builder customMetadata; private final Map> inSyncAllocationIds; @@ -1802,6 +1846,7 @@ public static class Builder { public Builder(String index) { this.index = index; + this.inferenceFields = ImmutableOpenMap.builder(); this.aliases = ImmutableOpenMap.builder(); this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); @@ -1819,6 +1864,7 @@ public Builder(IndexMetadata indexMetadata) { this.settings = indexMetadata.getSettings(); this.primaryTerms = indexMetadata.primaryTerms.clone(); this.mapping = indexMetadata.mapping; + this.inferenceFields = ImmutableOpenMap.builder(indexMetadata.inferenceFields); this.aliases = ImmutableOpenMap.builder(indexMetadata.aliases); this.customMetadata = ImmutableOpenMap.builder(indexMetadata.customData); this.routingNumShards = indexMetadata.routingNumShards; @@ -2059,6 +2105,16 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } + public Builder putInferenceField(InferenceFieldMetadata value) { + this.inferenceFields.put(value.getName(), value); + return this; + } + + public Builder putInferenceFields(Map values) { + this.inferenceFields.putAllFromMap(values); + return this; + } + public IndexMetadata build() { return build(false); } @@ -2221,6 +2277,7 @@ IndexMetadata build(boolean repair) { numberOfReplicas, settings, mapping, + inferenceFields.build(), aliasesMap, newCustomMetadata, Map.ofEntries(denseInSyncAllocationIds), @@ -2379,6 +2436,14 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } + if (indexMetadata.getInferenceFields().isEmpty() == false) { + builder.startObject(KEY_FIELD_INFERENCE); + for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { + field.toXContent(builder, params); + } + builder.endObject(); + } + builder.endObject(); } @@ -2456,6 +2521,11 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { + private static final String INFERENCE_ID_FIELD = 
"inference_id"; + private static final String SOURCE_FIELDS_FIELD = "source_fields"; + + private final String name; + private final String inferenceId; + private final String[] sourceFields; + + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { + this.name = Objects.requireNonNull(name); + this.inferenceId = Objects.requireNonNull(inferenceId); + this.sourceFields = Objects.requireNonNull(sourceFields); + } + + public InferenceFieldMetadata(StreamInput input) throws IOException { + this.name = input.readString(); + this.inferenceId = input.readString(); + this.sourceFields = input.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeString(inferenceId); + out.writeStringArray(sourceFields); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceFieldMetadata that = (InferenceFieldMetadata) o; + return Objects.equals(name, that.name) + && Objects.equals(inferenceId, that.inferenceId) + && Arrays.equals(sourceFields, that.sourceFields); + } + + @Override + public int hashCode() { + int result = Objects.hash(name, inferenceId); + result = 31 * result + Arrays.hashCode(sourceFields); + return result; + } + + public String getName() { + return name; + } + + public String getInferenceId() { + return inferenceId; + } + + public String[] getSourceFields() { + return sourceFields; + } + + public static Diff readDiffFrom(StreamInput in) throws IOException { + return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(name); + builder.field(INFERENCE_ID_FIELD, inferenceId); + builder.array(SOURCE_FIELDS_FIELD, sourceFields); + return builder.endObject(); + } + + public static InferenceFieldMetadata fromXContent(XContentParser parser) throws IOException { + final String name = parser.currentName(); + + XContentParser.Token token = parser.nextToken(); + if (token == null) { + // no data... 
+ return null; + } + String currentFieldName = null; + String inferenceId = null; + List inputFields = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.VALUE_STRING) { + if (INFERENCE_ID_FIELD.equals(currentFieldName)) { + inferenceId = parser.text(); + } + } else if (token == XContentParser.Token.START_ARRAY) { + if (SOURCE_FIELDS_FIELD.equals(currentFieldName)) { + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_STRING) { + inputFields.add(parser.text()); + } else { + parser.skipChildren(); + } + } + } + } else { + parser.skipChildren(); + } + } + return new InferenceFieldMetadata(name, inferenceId, inputFields.toArray(String[]::new)); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index da24f0b9d0dc5..52642e1de8ac9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1263,10 +1263,11 @@ static IndexMetadata buildIndexMetadata( indexMetadataBuilder.system(isSystem); // now, update the mappings with the actual source Map mappingsMetadata = new HashMap<>(); - DocumentMapper mapper = documentMapperSupplier.get(); - if (mapper != null) { - MappingMetadata mappingMd = new MappingMetadata(mapper); - mappingsMetadata.put(mapper.type(), mappingMd); + DocumentMapper docMapper = documentMapperSupplier.get(); + if (docMapper != null) { + MappingMetadata mappingMd = new MappingMetadata(docMapper); + mappingsMetadata.put(docMapper.type(), mappingMd); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } for (MappingMetadata mappingMd : mappingsMetadata.values()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java index 7e2c0849a6fad..e7c2bb9ae9b9a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java @@ -201,9 +201,10 @@ private static ClusterState applyRequest( IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(indexMetadata); // Mapping updates on a single type may have side-effects on other types so we need to // update mapping metadata on all types - DocumentMapper mapper = mapperService.documentMapper(); - if (mapper != null) { - indexMetadataBuilder.putMapping(new MappingMetadata(mapper)); + DocumentMapper docMapper = mapperService.documentMapper(); + if (docMapper != null) { + indexMetadataBuilder.putMapping(new MappingMetadata(docMapper)); + indexMetadataBuilder.putInferenceFields(docMapper.mappers().inferenceFields()); } if (updatedMapping) { indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java new file mode 100644 index 0000000000000..078ef391f17ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -0,0 
+1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.inference.InferenceService; + +import java.util.Set; + +/** + * Field mapper that requires to transform its input before indexation through the {@link InferenceService}. + */ +public interface InferenceFieldMapper { + String NAME = "_inference"; + + /** + * Retrieve the inference metadata associated with this mapper. + * + * @param sourcePaths The source path that populates the input for the field (before inference) + */ + InferenceFieldMetadata getMetadata(Set sourcePaths); +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 673593cc6e240..bf879f30e5a29 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -10,9 +10,11 @@ import org.apache.lucene.codecs.PostingsFormat; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; +import org.elasticsearch.inference.InferenceService; import java.util.ArrayList; import java.util.Collection; @@ -47,6 +49,7 @@ private CacheKey() {} /** Full field name to mapper */ private final Map fieldMappers; private final Map objectMappers; + private final Map inferenceFields; private final int runtimeFieldMappersCount; private final NestedLookup nestedLookup; private final FieldTypeLookup fieldTypeLookup; @@ -84,12 +87,12 @@ private static void collect( Collection fieldMappers, Collection fieldAliasMappers ) { - if (mapper instanceof ObjectMapper) { - objectMappers.add((ObjectMapper) mapper); - } else if (mapper instanceof FieldMapper) { - fieldMappers.add((FieldMapper) mapper); - } else if (mapper instanceof FieldAliasMapper) { - fieldAliasMappers.add((FieldAliasMapper) mapper); + if (mapper instanceof ObjectMapper objectMapper) { + objectMappers.add(objectMapper); + } else if (mapper instanceof FieldMapper fieldMapper) { + fieldMappers.add(fieldMapper); + } else if (mapper instanceof FieldAliasMapper fieldAliasMapper) { + fieldAliasMappers.add(fieldAliasMapper); } else { throw new IllegalStateException("Unrecognized mapper type [" + mapper.getClass().getSimpleName() + "]."); } @@ -174,6 +177,15 @@ private MappingLookup( final Collection runtimeFields = mapping.getRoot().runtimeFields(); this.fieldTypeLookup = new FieldTypeLookup(mappers, aliasMappers, runtimeFields); + + Map inferenceFields = new HashMap<>(); + for (FieldMapper mapper : mappers) { + if (mapper instanceof InferenceFieldMapper inferenceFieldMapper) { + inferenceFields.put(mapper.name(), inferenceFieldMapper.getMetadata(fieldTypeLookup.sourcePaths(mapper.name()))); + } + } + this.inferenceFields = Map.copyOf(inferenceFields); + if (runtimeFields.isEmpty()) { // without runtime fields this is the same as the field type lookup this.indexTimeLookup = 
fieldTypeLookup; @@ -360,6 +372,13 @@ public Map objectMappers() { return objectMappers; } + /** + * Returns a map containing all fields that require to run inference (through the {@link InferenceService} prior to indexation. + */ + public Map inferenceFields() { + return inferenceFields; + } + public NestedLookup nestedLookup() { return nestedLookup; } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 5cc1a7206e7e4..45ffba25eb558 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -83,6 +83,8 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; + Map dynamicFields = randomInferenceFields(); + IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) .creationDate(randomLong()) @@ -107,6 +109,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) + .putInferenceFields(dynamicFields) .build(); assertEquals(system, metadata.isSystem()); @@ -141,6 +144,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getInferenceFields(), fromXContentMeta.getInferenceFields()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -162,8 +166,9 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getCustomData(), deserialized.getCustomData()); assertEquals(metadata.isSystem(), deserialized.isSystem()); assertEquals(metadata.getStats(), deserialized.getStats()); - assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); - assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); + assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); + assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields()); } } @@ -547,10 +552,34 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } + public void testInferenceFieldMetadata() { + Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); + assertTrue(idxMeta1.getInferenceFields().isEmpty()); + + Map dynamicFields = randomInferenceFields(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).putInferenceFields(dynamicFields).build(); + assertThat(idxMeta2.getInferenceFields(), equalTo(dynamicFields)); + } + private static Settings indexSettingsWithDataTier(String dataTier) { return 
indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } + public static Map randomInferenceFields() { + Map map = new HashMap<>(); + int numFields = randomIntBetween(0, 5); + for (int i = 0; i < numFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + map.put(field, randomInferenceFieldMetadata(field)); + } + return map; + } + + private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) { + return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)); + } + private IndexMetadataStats randomIndexStats(int numberOfShards) { IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards); int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java new file mode 100644 index 0000000000000..958d86535ae76 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.function.Predicate; + +import static org.hamcrest.Matchers.equalTo; + +public class InferenceFieldMetadataTests extends AbstractXContentTestCase { + + public void testSerialization() throws IOException { + final InferenceFieldMetadata before = createTestItem(); + final BytesStreamOutput out = new BytesStreamOutput(); + before.writeTo(out); + + final StreamInput in = out.bytes().streamInput(); + final InferenceFieldMetadata after = new InferenceFieldMetadata(in); + + assertThat(after, equalTo(before)); + } + + @Override + protected InferenceFieldMetadata createTestInstance() { + return createTestItem(); + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field + } + + @Override + protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { + if (parser.nextToken() == XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + assertEquals(XContentParser.Token.FIELD_NAME, parser.currentToken()); + InferenceFieldMetadata inferenceMetadata = InferenceFieldMetadata.fromXContent(parser); + assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken()); + return inferenceMetadata; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + private static InferenceFieldMetadata createTestItem() { + String name = randomAlphaOfLengthBetween(3, 10); + String inferenceId = randomIdentifier(); + String[] inputFields = generateRandomStringArray(5, 10, false, false); + return new InferenceFieldMetadata(name, inferenceId, inputFields); + } +} 
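The metadata added in this patch is what lets a pure coordinating node decide whether a bulk request needs inference without holding any mappings. As an illustration only (not part of the patch series), a coordinator-side component could group an index's source fields by inference endpoint as sketched below; the helper class and method names are hypothetical, while IndexMetadata#getInferenceFields, InferenceFieldMetadata#getInferenceId and InferenceFieldMetadata#getSourceFields are the accessors introduced above.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;

final class InferenceFieldGrouping {

    /**
     * Groups the source fields of an index by the inference endpoint that has to process them,
     * using only cluster-state metadata (no mapping lookup needed on the coordinating node).
     */
    static Map<String, List<String>> sourceFieldsByInferenceId(IndexMetadata indexMetadata) {
        Map<String, List<String>> byInferenceId = new HashMap<>();
        for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) {
            for (String sourceField : field.getSourceFields()) {
                byInferenceId.computeIfAbsent(field.getInferenceId(), k -> new ArrayList<>()).add(sourceField);
            }
        }
        return byInferenceId;
    }
}

Grouping by inference id is the same shape of batching that the bulk inference action filter introduced in the later patches relies on: one batched call per endpoint rather than one call per field.
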
From 4025d2ccddafdaa16fd3f2b2bbca191e5b08b2f0 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 5 Apr 2024 12:28:43 +0200 Subject: [PATCH 15/29] Add javadoc --- .../cluster/metadata/InferenceFieldMetadata.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 0cd3f05f250a3..23608f075f0b5 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -22,6 +22,13 @@ import java.util.List; import java.util.Objects; +/** + * Contains inference field data for fields. + * As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need + * to perform inference for specific fields in an index. + * Given that the coordinator node does not necessarily have mapping information for all indices (only for those that have shards + * in the node), the field inference information must be stored in the IndexMetadata and broadcasted to all nodes. + */ public final class InferenceFieldMetadata implements SimpleDiffable, ToXContentFragment { private static final String INFERENCE_ID_FIELD = "inference_id"; private static final String SOURCE_FIELDS_FIELD = "source_fields"; From d78acc3d2d00a930361e472c37ab42504c917370 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 5 Apr 2024 12:59:14 +0200 Subject: [PATCH 16/29] Fix test helper --- .../org/elasticsearch/cluster/metadata/DataStreamTestHelper.java | 1 + 1 file changed, 1 insertion(+) diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index 2980b8a48636a..d91cf625ab944 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -726,6 +726,7 @@ public static IndicesService mockIndicesServices(MappingLookup mappingLookup) th Mapping mapping = new Mapping(root, new MetadataFieldMapper[0], null); DocumentMapper documentMapper = mock(DocumentMapper.class); when(documentMapper.mapping()).thenReturn(mapping); + when(documentMapper.mappers()).thenReturn(MappingLookup.EMPTY); when(documentMapper.mappingSource()).thenReturn(mapping.toCompressedXContent()); RoutingFieldMapper routingFieldMapper = mock(RoutingFieldMapper.class); when(routingFieldMapper.required()).thenReturn(false); From 7a2b70b9c33b064feedbce2aa5b7131b9f8a8b83 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 5 Apr 2024 14:24:14 +0200 Subject: [PATCH 17/29] PR Review comments --- .../java/org/elasticsearch/TransportVersions.java | 2 +- .../cluster/metadata/IndexMetadata.java | 14 +++++++------- .../index/mapper/InferenceFieldMapper.java | 1 - 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 95e3353de4120..10f80ae739fdf 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -163,7 +163,7 @@ static TransportVersion def(int id) { public static final TransportVersion CCR_STATS_API_TIMEOUT_PARAM = 
def(8_622_00_0); public static final TransportVersion ESQL_ORDINAL_BLOCK = def(8_623_00_0); public static final TransportVersion ML_INFERENCE_COHERE_RERANK = def(8_624_00_0); - public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_625_00_0); + public static final TransportVersion INFERENCE_FIELDS_METADATA = def(8_625_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index b66da654f8a1c..529814e83ba38 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -540,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_FIELD_INFERENCE = "field_inference"; + public static final String KEY_INFERENCE_FIELDS = "field_inference"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -1569,7 +1569,7 @@ private static class IndexMetadataDiff implements Diff { } primaryTerms = in.readVLongArray(); mappings = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), MAPPING_DIFF_VALUE_READER); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { inferenceFields = DiffableUtils.readImmutableOpenMapDiff( in, DiffableUtils.getStringKeySerializer(), @@ -1627,7 +1627,7 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeVLongArray(primaryTerms); mappings.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { inferenceFields.writeTo(out); } aliases.writeTo(out); @@ -1709,7 +1709,7 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function builder.putInferenceField(f)); } @@ -1773,7 +1773,7 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException mapping.writeTo(out); } } - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_FIELDS_METADATA)) { out.writeCollection(inferenceFields.values()); } out.writeCollection(aliases.values()); @@ -2437,7 +2437,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build } if (indexMetadata.getInferenceFields().isEmpty() == false) { - builder.startObject(KEY_FIELD_INFERENCE); + builder.startObject(KEY_INFERENCE_FIELDS); for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { field.toXContent(builder, params); } @@ -2521,7 +2521,7 @@ public static IndexMetadata fromXContent(XContentParser parser, Map Date: Mon, 8 Apr 2024 10:31:50 +0200 Subject: [PATCH 18/29] [feature/semantic-text] Handle chunked error (#107192) This PR handles inference inputs that return an error and marks the whole document as a failure when it happens. 
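
In code terms, the change boils down to checking each chunked result for an error before accumulating it. A minimal sketch of that check, assuming only the types used in the diff below (ErrorChunkedInferenceResults, ChunkedInferenceServiceResults, ElasticsearchException); the helper class and method names are hypothetical:

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;

final class InferenceResultChecks {

    /**
     * Returns the exception that should fail the whole bulk item when the inference service
     * reported an error for one of its inputs, or null when the result contains usable chunks.
     */
    static ElasticsearchException asItemFailure(String inferenceId, String field, ChunkedInferenceServiceResults result) {
        if (result instanceof ErrorChunkedInferenceResults error) {
            return new ElasticsearchException(
                "Exception when running inference id [{}] on field [{}]",
                error.getException(),
                inferenceId,
                field
            );
        }
        return null; // success: the chunked embeddings are kept and indexed with the document
    }
}

The new testItemFailures test added below exercises exactly this path by registering an ErrorChunkedInferenceResults for one of the inputs and asserting that the corresponding bulk items are marked as failed while the successful item is left untouched.
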
--- .../ShardBulkInferenceActionFilter.java | 32 ++++++--- .../ShardBulkInferenceActionFilterTests.java | 71 ++++++++++++++++--- 2 files changed, 84 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index eaa62a3aa743a..ddb11613fa5ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -37,6 +37,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -282,16 +283,27 @@ public void onResponse(List results) { var request = requests.get(i); var result = results.get(i); var acc = inferenceResults.get(request.id); - acc.addOrUpdateResponse( - new FieldInferenceResponse( - request.field(), - request.input(), - request.inputOrder(), - request.isOriginalFieldInput(), - inferenceProvider.model, - result - ) - ); + if (result instanceof ErrorChunkedInferenceResults error) { + acc.addFailure( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + error.getException(), + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } else { + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.input(), + request.inputOrder(), + request.isOriginalFieldInput(), + inferenceProvider.model, + result + ) + ); + } } } finally { onFinish(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 7cfaeaae4c3a4..f2f9a2229f8c3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceService; @@ -33,7 +34,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -54,7 +55,9 @@ 
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -144,6 +147,60 @@ public void testInferenceNotFound() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testItemFailures() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); + model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(3)); + + // item 0 is a failure + assertNotNull(bulkShardRequest.items()[0].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[0].getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + + // item 1 is a success + assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); + IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); + assertThat(XContentMapValues.extractValue("field1.text", actualRequest.sourceAsMap()), equalTo("I am a success")); + + // item 2 is a failure + assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); + failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[3]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + 
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { Map inferenceModelMap = new HashMap<>(); @@ -282,7 +339,7 @@ private static BulkItemRequest[] randomBulkItemRequest( continue; } var result = randomSemanticText(field, model, List.of(text), requestContentType); - model.putResult(text, result); + model.putResult(text, toChunkedResult(result)); expectedDocMap.put(field, result); } return new BulkItemRequest[] { @@ -304,7 +361,7 @@ private static StaticModel randomStaticModel() { } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -319,14 +376,10 @@ private static class StaticModel extends TestModel { } ChunkedInferenceServiceResults getResults(String text) { - SemanticTextField result = resultMap.get(text); - if (result == null) { - return new ChunkedSparseEmbeddingResults(List.of()); - } - return toChunkedResult(result); + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); } - void putResult(String text, SemanticTextField result) { + void putResult(String text, ChunkedInferenceServiceResults result) { resultMap.put(text, result); } } From dc46e88381eac36eb62da901729c1acd1eb376af Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 8 Apr 2024 19:50:18 +0200 Subject: [PATCH 19/29] Add test coverage for null constructor args --- .../cluster/metadata/InferenceFieldMetadataTests.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 958d86535ae76..bd4c87be51157 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -63,4 +63,10 @@ private static InferenceFieldMetadata createTestItem() { String[] inputFields = generateRandomStringArray(5, 10, false, false); return new InferenceFieldMetadata(name, inferenceId, inputFields); } + + public void testNullCtorArgsThrowException() { + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0])); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null)); + } } From 937572d00699223e7e5c936d1024ec2daba0e286 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Apr 2024 10:11:40 +0200 Subject: [PATCH 20/29] Add first query tests --- x-pack/plugin/inference/build.gradle | 2 +- .../20_semantic_text_field_mapper.yml | 59 ++++++++++++++++--- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 781261c330e78..a2811ffefca91 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -10,7 +10,7 @@ apply plugin: 'elasticsearch.internal-yaml-rest-test' restResources { restApi { - include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex' + include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex', 'search' } } diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 27f233436b925..b574aeec632ae 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -17,6 +17,7 @@ setup: "task_settings": { } } + - do: inference.put_model: task_type: text_embedding @@ -50,19 +51,61 @@ setup: type: text --- -"Sparse vector results format": +"Dense vector results are indexed as nested chunks and searchable": + - do: + bulk: + index: test-index + refresh: true + body: | + {"index":{}} + {"dense_field": "you know, for testing"} + {"index":{}} + {"dense_field": "some more tests"} + - do: - index: + search: index: test-index - id: doc_1 body: - sparse_field: "you know, for testing" + query: + nested: + path: dense_field.inference.chunks + query: + knn: + field: dense_field.inference.chunks.embeddings + query_vector_builder: + text_embedding: + model_id: dense-inference-id + model_text: "you know, for testing" + + - match: { hits.total.value: 2 } + - match: { hits.total.relation: eq } --- -"Dense vector results format": +"Sparse vector results are indexed as nested chunks and searchable": - do: - index: + bulk: + index: test-index + refresh: true + body: | + {"index":{}} + {"sparse_field": "you know, for testing"} + {"index":{}} + {"sparse_field": "some more tests"} + + - do: + search: index: test-index - id: doc_1 body: - dense_field: "you know, for testing" + query: + nested: + path: sparse_field.inference.chunks + query: + text_expansion: + sparse_field.inference.chunks.embeddings: + model_id: sparse-inference-id + model_text: "you know, for testing" + + - match: { hits.total.value: 2 } + - match: { hits.total.relation: eq } + + From bef2214c4eb65429711708e5137fe65182fc2b1e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Apr 2024 10:57:00 +0200 Subject: [PATCH 21/29] Add inner_hits tests --- .../20_semantic_text_field_mapper.yml | 56 +++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index b574aeec632ae..df5073cfed525 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -58,9 +58,9 @@ setup: refresh: true body: | {"index":{}} - {"dense_field": "you know, for testing"} + {"dense_field": ["you know, for testing", "now with chunks"]} {"index":{}} - {"dense_field": "some more tests"} + {"dense_field": ["some more tests", "that include chunks"]} - do: search: @@ -80,6 +80,33 @@ setup: - match: { hits.total.value: 2 } - match: { hits.total.relation: eq } + # Search with inner hits + - do: + search: + _source: false + index: test-index + body: + query: + nested: + path: dense_field.inference.chunks + inner_hits: + _source: false + fields: [dense_field.inference.chunks.text] + query: + knn: + field: dense_field.inference.chunks.embeddings + query_vector_builder: + text_embedding: + model_id: 
dense-inference-id + model_text: "you know, for testing" + + - match: { hits.total.value: 2 } + - match: { hits.total.relation: eq } + - match: { hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.total.value: 2 } + - exists: hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.0.fields.dense_field\.inference\.chunks.0.text + - exists: hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.1.fields.dense_field\.inference\.chunks.0.text + + --- "Sparse vector results are indexed as nested chunks and searchable": - do: @@ -88,17 +115,35 @@ setup: refresh: true body: | {"index":{}} - {"sparse_field": "you know, for testing"} + {"sparse_field": ["you know, for testing", "now with chunks"]} {"index":{}} - {"sparse_field": "some more tests"} + {"sparse_field": ["some more tests", "that include chunks"]} + + - do: + search: + index: test-index + body: + query: + nested: + path: sparse_field.inference.chunks + query: + text_expansion: + sparse_field.inference.chunks.embeddings: + model_id: sparse-inference-id + model_text: "you know, for testing" + # Search with inner hits - do: search: + _source: false index: test-index body: query: nested: path: sparse_field.inference.chunks + inner_hits: + _source: false + fields: [sparse_field.inference.chunks.text] query: text_expansion: sparse_field.inference.chunks.embeddings: @@ -107,5 +152,8 @@ setup: - match: { hits.total.value: 2 } - match: { hits.total.relation: eq } + - match: { hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.total.value: 2 } + - exists: hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.0.fields.sparse_field\.inference\.chunks.0.text + - exists: hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.1.fields.sparse_field\.inference\.chunks.0.text From 84a2735740edbafe95020f07c1f0b1940f9fa71c Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Apr 2024 12:23:55 +0200 Subject: [PATCH 22/29] Add mapping incompatibility tests --- .../mapper/SemanticTextFieldMapper.java | 10 +- ...tic_text_inference_incompatible_models.yml | 190 ++++++++++++++++++ 2 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_incompatible_models.yml diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 2536825a9e0b7..08d11f7bd41f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -217,7 +217,15 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio try { conflicts.check(); } catch (Exception exc) { - throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); + throw new DocumentParsingException( + xContentLocation, + "Incompatible model settings for field [" + + name() + + "]. 
Check that the " + + INFERENCE_ID_FIELD.getPreferredName() + + " is not using different model settings", + exc + ); } mapper = this; } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_incompatible_models.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_incompatible_models.yml new file mode 100644 index 0000000000000..48a73a02ef645 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/30_semantic_text_inference_incompatible_models.yml @@ -0,0 +1,190 @@ +setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + inference_id: sparse-inference-id + dense_field: + type: semantic_text + inference_id: dense-inference-id + + # Index a doc to set mappings internally + - do: + index: + index: test-index + id: doc_1 + body: + dense_field: "inference test" + sparse_field: "another inference test" + +--- +"Fails for non-compatible dimensions": + + - do: + inference.delete_model: + task_type: text_embedding + inference_id: dense-inference-id + + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 20, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: "some other test" + +--- +"Fails for non-compatible similarity": + + - do: + inference.delete_model: + task_type: text_embedding + inference_id: dense-inference-id + + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "dot_product", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: "some other test" + +--- +"Fails for non-compatible task type for dense vectors": + + - do: + inference.delete_model: + task_type: text_embedding + inference_id: dense-inference-id + + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: dense-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: "some other test" + +--- +"Fails for non-compatible task type for sparse 
vectors": + + - do: + inference.delete_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + + - do: + inference.put_model: + task_type: text_embedding + inference_id: sparse-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + catch: /Incompatible model settings for field \[sparse_field\].+/ + index: + index: test-index + id: doc_2 + body: + sparse_field: "some other test" + + From b0e6d430f3dbbb03b4187993db9fa3e3107a8be1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Apr 2024 13:13:45 +0200 Subject: [PATCH 23/29] Add semantic_text field mapper and inference generation --- .../org/elasticsearch/TransportVersions.java | 2 +- .../action/bulk/BulkOperation.java | 4 + .../action/bulk/BulkShardRequest.java | 28 + .../cluster/metadata/IndexMetadata.java | 6 +- .../metadata/InferenceFieldMetadata.java | 7 - .../index/mapper/DocumentParser.java | 4 + .../index/mapper/FieldMapper.java | 4 +- .../index/mapper/InferenceModelFieldType.java | 21 - .../index/mapper/MapperMergeContext.java | 4 +- .../vectors/DenseVectorFieldMapper.java | 16 +- .../vectors/SparseVectorFieldMapper.java | 7 +- .../inference/InferenceServiceResults.java | 2 + .../inference/SemanticTextModelSettings.java | 91 --- .../metadata/InferenceFieldMetadataTests.java | 6 - .../index/mapper/CopyToMapperTests.java | 7 + .../index/mapper/MultiFieldTests.java | 3 + .../metadata/DataStreamTestHelper.java | 4 +- .../index/mapper/MapperTestCase.java | 2 +- x-pack/plugin/inference/build.gradle | 12 + .../mock/AbstractTestInferenceService.java | 5 - .../TestDenseInferenceServiceExtension.java | 2 +- .../TestSparseInferenceServiceExtension.java | 12 +- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferencePlugin.java | 30 +- .../xpack/inference}/SemanticTextFeature.java | 2 +- .../ShardBulkInferenceActionFilter.java | 530 ++++++++++++++++++ .../inference/mapper/SemanticTextField.java | 328 +++++++++++ .../mapper/SemanticTextFieldMapper.java | 389 +++++++++++++ .../SemanticTextClusterMetadataTests.java | 104 ++++ .../ShardBulkInferenceActionFilterTests.java | 386 +++++++++++++ .../mapper/SemanticTextFieldMapperTests.java | 520 +++++++++++++++++ .../mapper/SemanticTextFieldTests.java | 219 ++++++++ .../xpack/inference/model/TestModel.java | 11 + .../xpack/inference/InferenceRestIT.java | 41 ++ .../inference/10_semantic_text_inference.yml | 444 +++++++++++++++ .../20_semantic_text_field_mapper.yml | 68 +++ .../CoordinatedInferenceIngestIT.java | 4 +- .../xpack/ml/MachineLearning.java | 14 +- .../ml/mapper/SemanticTextFieldMapper.java | 130 ----- .../mapper/SemanticTextFieldMapperTests.java | 118 ---- 40 files changed, 3173 insertions(+), 415 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/SemanticTextFeature.java (93%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 4bb00826bd4c3..c9d730138494e 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -164,7 +164,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_ORDINAL_BLOCK = def(8_623_00_0); public static final TransportVersion ML_INFERENCE_COHERE_RERANK = def(8_624_00_0); public static final TransportVersion INDEXING_PRESSURE_DOCUMENT_REJECTIONS_COUNT = def(8_625_00_0); - public static final TransportVersion INFERENCE_FIELDS_METADATA = def(8_625_00_0); + public static final TransportVersion INFERENCE_FIELDS_METADATA = def(8_626_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 412e4f3c875e8..6df34c6430d0c 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -294,6 +294,10 @@ private void executeBulkRequestsByShard( bulkRequest.getRefreshPolicy(), requests.toArray(new BulkItemRequest[0]) ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getInferenceFields().isEmpty() == false) { + bulkShardRequest.setInferenceFieldMap(indexMetadata.getInferenceFields()); + } bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); bulkShardRequest.timeout(bulkRequest.timeout()); bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..8d1618b443ace 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -22,6 +23,7 @@ import org.elasticsearch.transport.RawIndexingDataTransportRequest; import java.io.IOException; +import java.util.Map; import java.util.Set; public final class BulkShardRequest extends ReplicatedWriteRequest @@ -33,6 +35,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest inferenceFieldMap = null; + public BulkShardRequest(StreamInput in) throws IOException { super(in); items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +48,30 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. + */ + public void setInferenceFieldMap(Map fieldInferenceMap) { + this.inferenceFieldMap = fieldInferenceMap; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. 
+ */ + public Map consumeInferenceFieldMap() { + Map ret = inferenceFieldMap; + inferenceFieldMap = null; + return ret; + } + + /** + * Public for test + */ + public Map getInferenceFieldMap() { + return inferenceFieldMap; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 529814e83ba38..3a852f20a761e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -540,7 +540,7 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; - public static final String KEY_INFERENCE_FIELDS = "field_inference"; + public static final String KEY_FIELD_INFERENCE = "field_inference"; public static final String INDEX_STATE_FILE_PREFIX = "state-"; @@ -2437,7 +2437,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build } if (indexMetadata.getInferenceFields().isEmpty() == false) { - builder.startObject(KEY_INFERENCE_FIELDS); + builder.startObject(KEY_FIELD_INFERENCE); for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) { field.toXContent(builder, params); } @@ -2521,7 +2521,7 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment { private static final String INFERENCE_ID_FIELD = "inference_id"; private static final String SOURCE_FIELDS_FIELD = "source_fields"; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java index 1fda9ababfabd..7357f6f4bdfc6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java @@ -696,6 +696,10 @@ private static void failIfMatchesRoutingPath(DocumentParserContext context, Stri */ private static void parseCopyFields(DocumentParserContext context, List copyToFields) throws IOException { for (String field : copyToFields) { + if (context.mappingLookup().getMapper(field) instanceof InferenceFieldMapper) { + // ignore copy_to that targets inference fields, values are already extracted in the coordinating node to perform inference. 
+ continue; + } // In case of a hierarchy of nested documents, we need to figure out // which document the field should go to LuceneDocument targetDoc = null; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index fe9bdd73cfa10..5eddfb7d91df2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1199,7 +1199,7 @@ public static final class Conflicts { private final String mapperName; private final List conflicts = new ArrayList<>(); - Conflicts(String mapperName) { + public Conflicts(String mapperName) { this.mapperName = mapperName; } @@ -1211,7 +1211,7 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - void check() { + public void check() { if (conflicts.isEmpty()) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java deleted file mode 100644 index 490d7f36219cf..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -/** - * Field type that uses an inference model. - */ -public interface InferenceModelFieldType { - /** - * Retrieve inference model used by the field type. 
- * - * @return model id used by the field type - */ - String getInferenceModel(); -} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java index 1e3f69baf86dd..48e04a938d2b2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java @@ -55,7 +55,7 @@ public static MapperMergeContext from(MapperBuilderContext mapperBuilderContext, * @param name the name of the child context * @return a new {@link MapperMergeContext} with this context as its parent */ - MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { + public MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) { return createChildContext(mapperBuilderContext.createChildContext(name, dynamic)); } @@ -69,7 +69,7 @@ MapperMergeContext createChildContext(MapperBuilderContext childContext) { return new MapperMergeContext(childContext, newFieldsBudget); } - MapperBuilderContext getMapperBuilderContext() { + public MapperBuilderContext getMapperBuilderContext() { return mapperBuilderContext; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index db958dc8a8acb..3bb82bea58acf 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -230,6 +230,16 @@ protected Parameter[] getParameters() { return new Parameter[] { elementType, dims, indexed, similarity, indexOptions, meta }; } + public Builder similarity(VectorSimilarity vectorSimilarity) { + similarity.setValue(vectorSimilarity); + return this; + } + + public Builder dimensions(int dimensions) { + this.dims.setValue(dimensions); + return this; + } + @Override public DenseVectorFieldMapper build(MapperBuilderContext context) { return new DenseVectorFieldMapper( @@ -754,7 +764,7 @@ public static ElementType fromString(String name) { ElementType.FLOAT ); - enum VectorSimilarity { + public enum VectorSimilarity { L2_NORM { @Override float score(float similarity, ElementType elementType, int dim) { @@ -1122,6 +1132,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 6532abed19044..58286d34dada1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException { } String feature = null; + boolean origIsWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); + if (context.path().isWithinLeafObject() == false) { + context.path().setWithinLeafObject(true); + } for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = 
context.parser().nextToken()) { if (token == Token.FIELD_NAME) { feature = context.parser().currentName(); @@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException { context.addToFieldNames(fieldType().name()); } } finally { - context.path().setWithinLeafObject(false); + context.path().setWithinLeafObject(origIsWithLeafObject); } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 62166115820f5..14cfeacf76139 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -35,6 +35,8 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragm /** * Convert the result to a map to aid with test assertions + * + * @return a map */ Map asMap(); } diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java deleted file mode 100644 index 78773bfb72a95..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -/** - * Model settings that are interesting for semantic_text inference fields. This class is used to serialize common - * ServiceSettings methods when building inference for semantic_text fields. - * - * @param taskType task type - * @param inferenceId inference id - * @param dimensions number of dimensions. May be null if not applicable - * @param similarity similarity used by the service. 
May be null if not applicable - */ -public record SemanticTextModelSettings( - TaskType taskType, - String inferenceId, - @Nullable Integer dimensions, - @Nullable SimilarityMeasure similarity -) { - - public static final String NAME = "model_settings"; - private static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - private static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); - private static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - private static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); - - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { - Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); - this.taskType = taskType; - this.inferenceId = inferenceId; - this.dimensions = dimensions; - this.similarity = similarity; - } - - public SemanticTextModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); - } - - public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { - return PARSER.apply(parser, null); - } - - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[2]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); - }); - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); - } - - public Map asMap() { - Map attrsMap = new HashMap<>(); - attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); - if (dimensions != null) { - attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return Map.of(NAME, attrsMap); - } -} diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index bd4c87be51157..958d86535ae76 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -63,10 +63,4 @@ private static InferenceFieldMetadata createTestItem() { String[] inputFields = generateRandomStringArray(5, 10, false, false); return new InferenceFieldMetadata(name, inferenceId, inputFields); } - - public void testNullCtorArgsThrowException() { - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0])); - assertThrows(NullPointerException.class, () -> new 
InferenceFieldMetadata("name", "inferenceId", null)); - } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java index 5eacfe6f2e3ab..33341e6b36987 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; @@ -106,6 +107,12 @@ public void testCopyToFieldsParsing() throws Exception { fieldMapper = mapperService.documentMapper().mappers().getMapper("new_field"); assertThat(fieldMapper.typeName(), equalTo("long")); + + MappingLookup mappingLookup = mapperService.mappingLookup(); + assertThat(mappingLookup.sourcePaths("another_field"), equalTo(Set.of("copy_test", "int_to_str_test", "another_field"))); + assertThat(mappingLookup.sourcePaths("new_field"), equalTo(Set.of("new_field", "int_to_str_test"))); + assertThat(mappingLookup.sourcePaths("copy_test"), equalTo(Set.of("copy_test", "cyclic_test"))); + assertThat(mappingLookup.sourcePaths("cyclic_test"), equalTo(Set.of("cyclic_test", "copy_test"))); } public void testCopyToFieldsInnerObjectParsing() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java index d7df41131414e..6446033c07c5b 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java @@ -224,6 +224,9 @@ public void testSourcePathFields() throws IOException { final Set fieldsUsingSourcePath = new HashSet<>(); ((FieldMapper) mapper).sourcePathUsedBy().forEachRemaining(mapper1 -> fieldsUsingSourcePath.add(mapper1.name())); assertThat(fieldsUsingSourcePath, equalTo(Set.of("field.subfield1", "field.subfield2"))); + + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield1"), equalTo(Set.of("field"))); + assertThat(mapperService.mappingLookup().sourcePaths("field.subfield2"), equalTo(Set.of("field"))); } public void testUnknownLegacyFieldsUnderKnownRootField() throws Exception { diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index e2b03c6b81af3..d5cc1b137c456 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -652,7 +652,7 @@ public static MetadataRolloverService getMetadataRolloverService( AllocationService allocationService = mock(AllocationService.class); when(allocationService.reroute(any(ClusterState.class), any(String.class), any())).then(i -> i.getArguments()[0]); when(allocationService.getShardRoutingRoleStrategy()).thenReturn(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY); - MappingLookup mappingLookup = null; + MappingLookup mappingLookup = MappingLookup.EMPTY; if (dataStream != null) { RootObjectMapper.Builder root = new RootObjectMapper.Builder("_doc", ObjectMapper.Defaults.SUBOBJECTS); root.add( @@ -729,8 +729,8 @@ public static IndicesService mockIndicesServices(MappingLookup mappingLookup) th 
Mapping mapping = new Mapping(root, new MetadataFieldMapper[0], null); DocumentMapper documentMapper = mock(DocumentMapper.class); when(documentMapper.mapping()).thenReturn(mapping); - when(documentMapper.mappers()).thenReturn(MappingLookup.EMPTY); when(documentMapper.mappingSource()).thenReturn(mapping.toCompressedXContent()); + when(documentMapper.mappers()).thenReturn(mappingLookup); RoutingFieldMapper routingFieldMapper = mock(RoutingFieldMapper.class); when(routingFieldMapper.required()).thenReturn(false); when(documentMapper.routingFieldMapper()).thenReturn(routingFieldMapper); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java index fa0f0e1b95f54..34ccc4599811b 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java @@ -1030,7 +1030,7 @@ public final void testMinimalIsInvalidInRoutingPath() throws IOException { } } - private String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { if (mapper instanceof FieldMapper fieldMapper && fieldMapper.fieldType().isDimension() == false) { return "All fields that match routing_path must be configured with [time_series_dimension: true] " + "or flattened fields with a list of dimensions in [time_series_dimensions] and " diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index e4f4de0027073..781261c330e78 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -6,6 +6,13 @@ */ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' + +restResources { + restApi { + include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex' + } +} esplugin { name 'x-pack-inference' @@ -24,4 +31,9 @@ dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) testImplementation project(':modules:reindex') + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') +} + +tasks.named('yamlRestTest') { + usesDefaultDistribution() } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 99dfc9582eb05..a65b8e43e6adf 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -101,11 +101,6 @@ public TestServiceModel( super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings)); } - @Override - public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() { - return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings(); - } - @Override public TestTaskSettings getTaskSettings() { return (TestTaskSettings) super.getTaskSettings(); diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index c53ed82b9fe50..39d8c9c9c3c7f 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -169,7 +169,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 30977c23ef5aa..dc7bfe14bada1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -118,7 +118,7 @@ private SparseEmbeddingResults makeResults(List input) { for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new SparseEmbeddingResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F)); } embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); } @@ -126,15 +126,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j)); + tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 2d25a48117778..ddd56c758d67c 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -17,6 +17,7 @@ requires org.apache.httpcomponents.httpasyncclient; requires org.apache.httpcomponents.httpcore.nio; requires org.apache.lucene.core; + requires org.elasticsearch.logging; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c707f99e7eb65..666e7a3bd2043 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -21,11 +22,13 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -43,6 +46,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -50,6 +54,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -66,12 +71,15 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin { +import static java.util.Collections.singletonList; + +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** * When this setting is true the verification check that @@ -96,6 +104,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public 
InferencePlugin(Settings settings) { @@ -161,6 +170,9 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); + return List.of(modelRegistry, registry); } @@ -259,4 +271,20 @@ public void close() { IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } + + @Override + public Map getMappers() { + if (SemanticTextFeature.isEnabled()) { + return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + } + return Map.of(); + } + + @Override + public Collection getActionFilters() { + if (SemanticTextFeature.isEnabled()) { + return singletonList(shardBulkInferenceActionFilter.get()); + } + return List.of(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java similarity index 93% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java index f861760803e56..4f2c5c564bcb8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml; +package org.elasticsearch.xpack.inference; import org.elasticsearch.common.util.FeatureFlag; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..92749ebaf9d4f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,530 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.MappedActionFilter; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; + +/** + * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified + * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in + * the request source, we generate embeddings and include the results in the source under the semantic text field + * name as a {@link SemanticTextField}. + * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the + * results during indexing on the shard. 
+ * + * TODO: batchSize should be configurable via a cluster setting + */ +public class ShardBulkInferenceActionFilter implements MappedActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + protected static final int DEFAULT_BATCH_SIZE = 512; + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + private final int batchSize; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + } + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + this.batchSize = batchSize; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public String actionName() { + return TransportShardBulkAction.ACTION_NAME; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + Map fieldInferenceMap, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + /** + * A field inference request on a single input. + * @param id The id of the request in the original bulk request. + * @param field The target field. + * @param input The input to run inference on. + * @param inputOrder The original order of the input. + * @param isOriginalFieldInput Whether the input is part of the original values of the field. + */ + private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isOriginalFieldInput) {} + + /** + * The field inference response. + * @param field The target field. + * @param input The input that was used to run inference. + * @param inputOrder The original order of the input. + * @param isOriginalFieldInput Whether the input is part of the original values of the field. + * @param model The model used to run inference. + * @param chunkedResults The actual results. 
+ */ + private record FieldInferenceResponse( + String field, + String input, + int inputOrder, + boolean isOriginalFieldInput, + Model model, + ChunkedInferenceServiceResults chunkedResults + ) {} + + private record FieldInferenceResponseAccumulator( + int id, + Map> responses, + List failures + ) { + void addOrUpdateResponse(FieldInferenceResponse response) { + synchronized (this) { + var list = responses.computeIfAbsent(response.field, k -> new ArrayList<>()); + list.add(response); + } + } + + void addFailure(Exception exc) { + synchronized (this) { + failures.add(exc); + } + } + } + + private class AsyncBulkShardInferenceAction implements Runnable { + private final Map fieldInferenceMap; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + Map fieldInferenceMap, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMap = fieldInferenceMap; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void executeShardBulkInferenceAsync( + final String inferenceId, + @Nullable InferenceProvider inferenceProvider, + final List requests, + final Releasable onFinish + ) { + if (inferenceProvider == null) { + ActionListener modelLoadingListener = new ActionListener<>() { + @Override + public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + var provider = new InferenceProvider( + service.get(), + service.get() + .parsePersistedConfigWithSecrets( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ) + ); + executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish); + } else { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException( + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), + request.field + ) + ); + } + } + } + } + + @Override + public void onFailure(Exception exc) { + try (onFinish) { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + inferenceResults.get(request.id).failures.add( + new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field) + ); + } + } + } + }; + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); + return; + } + int currentBatchSize = Math.min(requests.size(), batchSize); + final List currentBatch = requests.subList(0, currentBatchSize); + final List nextBatch = 
requests.subList(currentBatchSize, requests.size()); + final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { + @Override + public void onResponse(List results) { + try { + for (int i = 0; i < results.size(); i++) { + var request = requests.get(i); + var result = results.get(i); + var acc = inferenceResults.get(request.id); + if (result instanceof ErrorChunkedInferenceResults error) { + acc.addFailure( + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + error.getException(), + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } else { + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.input(), + request.inputOrder(), + request.isOriginalFieldInput(), + inferenceProvider.model, + result + ) + ); + } + } + } finally { + onFinish(); + } + } + + @Override + public void onFailure(Exception exc) { + try { + for (int i = 0; i < requests.size(); i++) { + var request = requests.get(i); + addInferenceResponseFailure( + request.id, + new ElasticsearchException( + "Exception when running inference id [{}] on field [{}]", + exc, + inferenceProvider.model.getInferenceEntityId(), + request.field + ) + ); + } + } finally { + onFinish(); + } + } + + private void onFinish() { + if (nextBatch.isEmpty()) { + onFinish.close(); + } else { + executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish); + } + } + }; + inferenceProvider.service() + .chunkedInfer( + inferenceProvider.model(), + null, + inputs, + Map.of(), + InputType.INGEST, + new ChunkingOptions(null, null), + completionListener + ); + } + + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { + FieldInferenceResponseAccumulator acc = inferenceResults.get(id); + if (acc == null) { + acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>()); + inferenceResults.set(id, acc); + } + return acc; + } + + private void addInferenceResponseFailure(int id, Exception failure) { + var acc = ensureResponseAccumulatorSlot(id); + acc.addFailure(failure); + } + + /** + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. + * If the response contains failures, the bulk item request is marked as failed for the downstream action. + * Otherwise, the source of the request is augmented with the field inference results under the + * {@link SemanticTextField#INFERENCE_FIELD} field. 
+ */ + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } + return; + } + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : response.responses.entrySet()) { + var fieldName = entry.getKey(); + var responses = entry.getValue(); + var model = responses.get(0).model(); + // ensure that the order in the original field is consistent in case of multiple inputs + Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); + List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + var result = new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), results, indexRequest.getContentType()) + ), + indexRequest.getContentType() + ); + newDocMap.put(fieldName, result); + } + indexRequest.source(newDocMap, indexRequest.getContentType()); + } + + /** + * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. + * If results are already populated for fields in the original index request, the inference request for this specific + * field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, + * where an error will be thrown if they mismatch or if the content is malformed. + *

+ * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? + */ + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security) + continue; + } + boolean isUpdateRequest = false; + final IndexRequest indexRequest; + if (item.request() instanceof IndexRequest ir) { + indexRequest = ir; + } else if (item.request() instanceof UpdateRequest updateRequest) { + isUpdateRequest = true; + if (updateRequest.script() != null) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Cannot apply update with a script on indices that contain [{}] field(s)", + RestStatus.BAD_REQUEST, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + continue; + } + indexRequest = updateRequest.doc(); + } else { + // ignore delete request + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + String inferenceId = entry.getInferenceId(); + var originalFieldValue = XContentMapValues.extractValue(field, docMap); + if (originalFieldValue instanceof Map) { + continue; + } + int order = 0; + for (var sourceField : entry.getSourceFields()) { + boolean isOriginalFieldInput = sourceField.equals(field); + var valueObj = XContentMapValues.extractValue(sourceField, docMap); + if (valueObj == null) { + if (isUpdateRequest) { + addInferenceResponseFailure( + item.id(), + new ElasticsearchStatusException( + "Field [{}] must be specified on an update request to calculate inference for field [{}]", + RestStatus.BAD_REQUEST, + sourceField, + field + ) + ); + break; + } + continue; + } + ensureResponseAccumulatorSlot(item.id()); + final List values; + try { + values = nodeStringValues(field, valueObj); + } catch (Exception exc) { + addInferenceResponseFailure(item.id(), exc); + break; + } + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + for (var v : values) { + fieldRequests.add(new FieldInferenceRequest(item.id(), field, v, order++, isOriginalFieldInput)); + } + } + } + } + return fieldRequestsMap; + } + } + + /** + * This method converts the given {@code valueObj} into a list of strings. + * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. 
+ */ + private static List nodeStringValues(String field, Object valueObj) { + if (valueObj instanceof String value) { + return List.of(value); + } else if (valueObj instanceof Collection values) { + List valuesString = new ArrayList<>(); + for (var v : values) { + if (v instanceof String value) { + valuesString.add(value); + } else { + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + } + return valuesString; + } + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java new file mode 100644 index 0000000000000..f0267d144b7b8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static 
org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. + * The resulting object preserves the original input under the {@link SemanticTextField#TEXT_FIELD} and exposes + * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. + * + * @param fieldName The original field name. + * @param originalValues The original values associated with the field name. + * @param inference The inference result. + * @param contentType The {@link XContentType} used to store the embeddings chunks. + */ +public record SemanticTextField(String fieldName, List originalValues, InferenceResult inference, XContentType contentType) + implements + ToXContentObject { + + static final ParseField TEXT_FIELD = new ParseField("text"); + static final ParseField INFERENCE_FIELD = new ParseField("inference"); + static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + static final ParseField CHUNKS_FIELD = new ParseField("chunks"); + static final ParseField CHUNKED_EMBEDDINGS_FIELD = new ParseField("embeddings"); + static final ParseField CHUNKED_TEXT_FIELD = new ParseField("text"); + static final ParseField MODEL_SETTINGS_FIELD = new ParseField("model_settings"); + static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + + public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} + + public record Chunk(String text, BytesReference rawEmbeddings) {} + + public record ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) implements ToXContentObject { + public ModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { + this.taskType = Objects.requireNonNull(taskType, "task type must not be null"); + this.dimensions = dimensions; + this.similarity = similarity; + validate(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + private void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + } + + public 
static String getOriginalTextFieldName(String fieldName) { + return fieldName + "." + TEXT_FIELD.getPreferredName(); + } + + public static String getInferenceFieldName(String fieldName) { + return fieldName + "." + INFERENCE_FIELD.getPreferredName(); + } + + public static String getChunksFieldName(String fieldName) { + return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD.getPreferredName(); + } + + public static String getEmbeddingsFieldName(String fieldName) { + return getChunksFieldName(fieldName) + "." + CHUNKED_EMBEDDINGS_FIELD.getPreferredName(); + } + + static SemanticTextField parse(XContentParser parser, Tuple context) throws IOException { + return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context); + } + + static ModelSettings parseModelSettings(XContentParser parser) throws IOException { + return MODEL_SETTINGS_PARSER.parse(parser, null); + } + + static ModelSettings parseModelSettingsFromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, MODEL_SETTINGS_FIELD.getPreferredName()); + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return parseModelSettings(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (originalValues.isEmpty() == false) { + builder.field(TEXT_FIELD.getPreferredName(), originalValues.size() == 1 ? originalValues.get(0) : originalValues); + } + builder.startObject(INFERENCE_FIELD.getPreferredName()); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inference.inferenceId); + builder.field(MODEL_SETTINGS_FIELD.getPreferredName(), inference.modelSettings); + builder.startArray(CHUNKS_FIELD.getPreferredName()); + for (var chunk : inference.chunks) { + builder.startObject(); + builder.field(CHUNKED_TEXT_FIELD.getPreferredName(), chunk.text); + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings, + contentType + ); + builder.field(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()).copyCurrentStructure(parser); + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + builder.endObject(); + return builder; + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser> SEMANTIC_TEXT_FIELD_PARSER = + new ConstructingObjectParser<>( + SemanticTextFieldMapper.CONTENT_TYPE, + true, + (args, context) -> new SemanticTextField( + context.v1(), + (List) (args[0] == null ? 
List.of() : args[0]), + (InferenceResult) args[1], + context.v2() + ) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( + INFERENCE_FIELD.getPreferredName(), + true, + args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List) args[2]) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( + CHUNKS_FIELD.getPreferredName(), + true, + args -> new Chunk((String) args[0], (BytesReference) args[1]) + ); + + private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( + MODEL_SETTINGS_FIELD.getPreferredName(), + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]); + return new ModelSettings(taskType, dimensions, similarity); + } + ); + + static { + SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), TEXT_FIELD); + SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), INFERENCE_FIELD); + + INFERENCE_RESULT_PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD); + INFERENCE_RESULT_PARSER.declareObject(constructorArg(), (p, c) -> MODEL_SETTINGS_PARSER.parse(p, c), MODEL_SETTINGS_FIELD); + INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), (p, c) -> CHUNKS_PARSER.parse(p, c), CHUNKS_FIELD); + + CHUNKS_PARSER.declareString(constructorArg(), CHUNKED_TEXT_FIELD); + CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> { + XContentBuilder b = XContentBuilder.builder(p.contentType().xContent()); + b.copyCurrentStructure(p); + return BytesReference.bytes(b); + }, CHUNKED_EMBEDDINGS_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); + + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); + MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); + } + + /** + * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + */ + public static List toSemanticTextFieldChunks( + String field, + String inferenceId, + List results, + XContentType contentType + ) { + List chunks = new ArrayList<>(); + for (var result : results) { + if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))); + } + } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + inferenceId, + result.getWriteableName() + ); + } + } + return chunks; + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. 
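Stepping back from the serialisation helpers for a moment, a hedged sketch of how a complete value in this format could be read back, assuming the parse method and the (field name, content type) context shown above; the JSON content and field name are placeholders:

    // Illustrative only: parse a semantic_text value from its JSON form.
    String json = "{ ... }"; // an object in the shape produced by toXContent above
    try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, json)) {
        SemanticTextField parsed = SemanticTextField.parse(parser, new Tuple<>("my_semantic_field", XContentType.JSON));
    }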
+ */ + private static BytesReference toBytesReference(XContent xContent, double[] value) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (double v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + + /** + * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, + * into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, List tokens) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java new file mode 100644 index 0000000000000..2536825a9e0b7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -0,0 +1,389 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.Explicit; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.fielddata.FieldDataContext; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.InferenceFieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MapperMergeContext; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SimpleMappedFieldType; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import 
org.elasticsearch.xcontent.XContentParserConfiguration; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; + +/** + * A {@link FieldMapper} for semantic text fields. + */ +public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { + public static final String CONTENT_TYPE = "semantic_text"; + + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); + + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c.indexVersionCreated()), + notInMultiFields(CONTENT_TYPE) + ); + + public static class Builder extends FieldMapper.Builder { + private final IndexVersion indexVersionCreated; + + private final Parameter inferenceId = Parameter.stringParam( + "inference_id", + false, + mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId, + null + ).addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [inference_id] must be specified"); + } + }); + + private final Parameter modelSettings = new Parameter<>( + "model_settings", + true, + () -> null, + (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o), + mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings, + XContentBuilder::field, + (m) -> m == null ? 
"null" : Strings.toString(m) + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + + private final Parameter> meta = Parameter.metaParam(); + + private Function inferenceFieldBuilder; + + public Builder(String name, IndexVersion indexVersionCreated) { + super(name); + this.indexVersionCreated = indexVersionCreated; + this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get()); + } + + public Builder setInferenceId(String id) { + this.inferenceId.setValue(id); + return this; + } + + public Builder setModelSettings(SemanticTextField.ModelSettings value) { + this.modelSettings.setValue(value); + return this; + } + + @Override + protected Parameter[] getParameters() { + return new Parameter[] { inferenceId, modelSettings, meta }; + } + + @Override + protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { + super.merge(mergeWith, conflicts, mapperMergeContext); + conflicts.check(); + var semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE); + var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); + var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE); + var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), childContext); + inferenceFieldBuilder = c -> mergedInferenceField; + } + + @Override + public SemanticTextFieldMapper build(MapperBuilderContext context) { + final String fullName = context.buildFullName(name()); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType( + fullName, + inferenceId.getValue(), + modelSettings.getValue(), + inferenceField, + indexVersionCreated, + meta.getValue() + ), + copyTo + ); + } + } + + private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(fieldType().getInferenceField()); + return subIterators.iterator(); + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { + return; + } + XContentLocation xContentLocation = parser.getTokenLocation(); + final SemanticTextField field; + boolean isWithinLeaf = context.path().isWithinLeafObject(); + try { + context.path().setWithinLeafObject(true); + field = SemanticTextField.parse(parser, new Tuple<>(name(), context.parser().contentType())); + } finally { + context.path().setWithinLeafObject(isWithinLeaf); + } + final String fullFieldName = fieldType().name(); + if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID_FIELD.getPreferredName(), + 
field.inference().inferenceId(), + fullFieldName, + INFERENCE_ID_FIELD.getPreferredName(), + fieldType().getInferenceId() + ) + ); + } + final SemanticTextFieldMapper mapper; + if (fieldType().getModelSettings() == null) { + context.path().remove(); + Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + try { + mapper = builder.setModelSettings(field.inference().modelSettings()) + .setInferenceId(field.inference().inferenceId()) + .build(context.createDynamicMapperBuilderContext()); + context.addDynamicMapper(mapper); + } finally { + context.path().add(simpleName()); + } + } else { + Conflicts conflicts = new Conflicts(fullFieldName); + canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); + } + mapper = this; + } + var chunksField = mapper.fieldType().getChunksField(); + var embeddingsField = mapper.fieldType().getEmbeddingsField(); + for (var chunk : field.inference().chunks()) { + XContentParser subParser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings(), + context.parser().contentType() + ); + DocumentParserContext subContext = context.createNestedContext(chunksField).switchParser(subParser); + subParser.nextToken(); + embeddingsField.parse(subContext); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextFieldType fieldType() { + return (SemanticTextFieldType) super.fieldType(); + } + + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + String[] copyFields = sourcePaths.toArray(String[]::new); + // ensure consistent order + Arrays.sort(copyFields); + return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields); + } + + public static class SemanticTextFieldType extends SimpleMappedFieldType { + private final String inferenceId; + private final SemanticTextField.ModelSettings modelSettings; + private final ObjectMapper inferenceField; + private final IndexVersion indexVersionCreated; + + public SemanticTextFieldType( + String name, + String modelId, + SemanticTextField.ModelSettings modelSettings, + ObjectMapper inferenceField, + IndexVersion indexVersionCreated, + Map meta + ) { + super(name, false, false, false, TextSearchInfo.NONE, meta); + this.inferenceId = modelId; + this.modelSettings = modelSettings; + this.inferenceField = inferenceField; + this.indexVersionCreated = indexVersionCreated; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + public String getInferenceId() { + return inferenceId; + } + + public SemanticTextField.ModelSettings getModelSettings() { + return modelSettings; + } + + public ObjectMapper getInferenceField() { + return inferenceField; + } + + public NestedObjectMapper getChunksField() { + return (NestedObjectMapper) inferenceField.getMapper(CHUNKS_FIELD.getPreferredName()); + } + + public FieldMapper getEmbeddingsField() { + return (FieldMapper) getChunksField().getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException(CONTENT_TYPE + " fields do not support term query"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + // Redirect the fetcher to load the 
original values of the field + return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format); + } + + @Override + public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { + throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); + } + } + + private static ObjectMapper createInferenceField( + MapperBuilderContext context, + IndexVersion indexVersionCreated, + @Nullable SemanticTextField.ModelSettings modelSettings + ) { + return new ObjectMapper.Builder(INFERENCE_FIELD.getPreferredName(), Explicit.EXPLICIT_TRUE).dynamic(ObjectMapper.Dynamic.FALSE) + .add(createChunksField(indexVersionCreated, modelSettings)) + .build(context); + } + + private static NestedObjectMapper.Builder createChunksField( + IndexVersion indexVersionCreated, + SemanticTextField.ModelSettings modelSettings + ) { + NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD.getPreferredName(), indexVersionCreated); + chunksField.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder( + CHUNKED_TEXT_FIELD.getPreferredName(), + indexVersionCreated + ).indexed(false).docValues(false); + if (modelSettings != null) { + chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings)); + } + chunksField.add(chunkTextField); + return chunksField; + } + + private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) { + return switch (modelSettings.taskType()) { + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + CHUNKED_EMBEDDINGS_FIELD.getPreferredName(), + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure in model_settings [" + similarity.name() + "]" + ); + } + } + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); + yield denseVectorMapperBuilder; + } + default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]"); + }; + } + + private static boolean canMergeModelSettings( + SemanticTextField.ModelSettings previous, + SemanticTextField.ModelSettings current, + Conflicts conflicts + ) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null) { + return true; + } + if (current == null) { + conflicts.addConflict("model_settings", ""); + return false; + } + conflicts.addConflict("model_settings", ""); + return false; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java new file mode 100644 index 0000000000000..1c4a2f561ad4a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.cluster.metadata; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingClusterStateUpdateRequest; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.hamcrest.Matchers; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; + +public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return List.of(InferencePlugin.class); + } + + public void testCreateIndexWithSemanticTextField() { + final IndexService indexService = createIndex( + "test", + client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") + ); + assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model"); + } + + public void testSingleSourceSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}"""); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) + ); + assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model"); + } + + public void testCopyToSemanticTextField() throws Exception { + final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test")); + final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class); + final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor(); + final ClusterService clusterService = getInstanceFromNode(ClusterService.class); + + final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" + { + "properties": { + "semantic": { + "type": "semantic_text", + "inference_id": "test_model" + }, + "copy_origin_1": { + "type": "text", + "copy_to": "semantic" + }, + "copy_origin_2": { + "type": "text", + "copy_to": "semantic" + } + } + } + """); + request.indices(new Index[] { indexService.index() }); + final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( + clusterService.state(), + putMappingExecutor, + singleTask(request) 
+ ); + IndexMetadata indexMetadata = resultingState.metadata().index("test"); + InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get("semantic"); + assertThat(inferenceFieldMetadata.getInferenceId(), equalTo("test_model")); + assertThat( + Arrays.asList(inferenceFieldMetadata.getSourceFields()), + Matchers.containsInAnyOrder("semantic", "copy_origin_1", "copy_origin_2") + ); + } + + private static List singleTask(PutMappingClusterStateUpdateRequest request) { + return Collections.singletonList(new MetadataMappingService.PutMappingClusterStateUpdateTask(request, ActionListener.running(() -> { + throw new AssertionError("task should not complete publication"); + }))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..f2f9a2229f8c3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,386 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + 
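The filter tests that follow exercise the path where a shard-level bulk request carries the per-field inference metadata computed from the index mappings. A hedged sketch of that wiring, using only calls that appear in the tests below; the field name and inference id are placeholders, and items is assumed to be an existing BulkItemRequest[]:

    // Illustrative only: attach inference field metadata to a shard bulk request
    // so the coordinating-node filter knows which fields need embeddings.
    BulkShardRequest shardRequest = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
    shardRequest.setInferenceFieldMap(
        Map.of("body", new InferenceFieldMetadata("body", "my-elser-endpoint", new String[] { "body" }))
    );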
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; +import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getInferenceFieldMap()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) + ); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + "field2", + new InferenceFieldMetadata("field2", "inference_0", 
new String[] { "field2" }), + "field3", + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFieldMap)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testItemFailures() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + randomIntBetween(1, 10) + ); + model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(3)); + + // item 0 is a failure + assertNotNull(bulkShardRequest.items()[0].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[0].getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + + // item 1 is a success + assertNull(bulkShardRequest.items()[1].getPrimaryResponse()); + IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request()); + assertThat(XContentMapValues.extractValue("field1.text", actualRequest.sourceAsMap()), equalTo("I am a success")); + + // item 2 is a failure + assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse()); + assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed()); + failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure(); + assertThat(failure.getCause().getCause().getMessage(), containsString("boom")); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "field1", + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[3]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success")); + items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int 
numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map inferenceFieldMap = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); + } + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFieldMap); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap, int batchSize) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = 
invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) invocationOnMock.getArguments()[2]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[6]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + int id, + Map modelMap, + Map fieldInferenceMap + ) { + Map docMap = new LinkedHashMap<>(); + Map expectedDocMap = new LinkedHashMap<>(); + XContentType requestContentType = randomFrom(XContentType.values()); + for (var entry : fieldInferenceMap.values()) { + String field = entry.getName(); + var model = modelMap.get(entry.getInferenceId()); + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + expectedDocMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + var result = randomSemanticText(field, model, List.of(text), requestContentType); + model.putResult(text, toChunkedResult(result)); + expectedDocMap.put(field, result); + } + return new BulkItemRequest[] { + new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; + } + + private static StaticModel randomStaticModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new StaticModel( + inferenceId, + randomBoolean() ? 
TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults result) { + resultMap.put(text, result); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java new file mode 100644 index 0000000000000..a6f0fa83eab37 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -0,0 +1,520 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.LuceneDocument; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperParsingException; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedLookup; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.LeafNestedDocuments; +import 
org.elasticsearch.search.NestedDocuments; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.junit.AssumptionViolatedException; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomModel; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class SemanticTextFieldMapperTests extends MapperTestCase { + @Override + protected Collection getPlugins() { + return singletonList(new InferencePlugin(Settings.EMPTY)); + } + + @Override + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "semantic_text").field("inference_id", "test_model"); + } + + @Override + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + return "cannot have nested fields when index is in [index.mode=time_series]"; + } + + @Override + protected Object getSampleValueForDocument() { + return null; + } + + @Override + protected boolean supportsIgnoreMalformed() { + return false; + } + + @Override + protected boolean supportsStoredFields() { + return false; + } + + @Override + protected void registerParameters(ParameterChecker checker) throws IOException {} + + @Override + protected Object generateRandomInputValue(MappedFieldType ft) { + assumeFalse("doc_values are not supported in semantic_text", true); + return null; + } + + @Override + protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { + throw new AssumptionViolatedException("not supported"); + } + + @Override + protected IngestScriptSupport ingestScriptSupport() { + throw new AssumptionViolatedException("not supported"); + } + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testInferenceIdNotPresent() { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), 
containsString("field [inference_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.field("inference_id", "my_inference_id"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToInferenceIdNotSupported() throws IOException { + String fieldName = randomAlphaOfLengthBetween(5, 15); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); + } + + public void testUpdateModelSettings() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String fieldName = randomFieldName(depth); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + { + Exception exc = expectThrows( + MapperParsingException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .endObject() + .endObject() + ) + ) + ); + assertThat(exc.getMessage(), containsString("Required [task_type]")); + } + { + merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "sparse_embedding") + .endObject() + .endObject() + ) + ); + assertSemanticTextField(mapperService, fieldName, true); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]") + ); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "text_embedding") + .field("dimensions", 10) + .field("similarity", "cosine") + .endObject() + .endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [{\"task_type\":\"sparse_embedding\"}] " + + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]" + ) + ); + } + } + } + + static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean 
expectedModelSettings) { + Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); + assertNotNull(mapper); + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); + SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; + + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); + + NestedObjectMapper chunksMapper = mapperService.mappingLookup() + .nestedLookup() + .getNestedMappers() + .get(getChunksFieldName(fieldName)); + assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); + Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD.getPreferredName()); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); + if (expectedModelSettings) { + assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); + Mapper inferenceMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + assertNotNull(inferenceMapper); + switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); + case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); + default -> throw new AssertionError("Invalid task type"); + } + } else { + assertNull(semanticFieldMapper.fieldType().getModelSettings()); + } + } + + public void testSuccessfulParse() throws IOException { + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticText(fieldName1, model1, List.of("a b", "c"), XContentType.JSON), + randomSemanticText(fieldName2, model2, List.of("d e f"), XContentType.JSON) + ) + ) + ) + ); + + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + } + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); + assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); + assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + 
assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) + ); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } + } + + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field( + MODEL_SETTINGS_FIELD.getPreferredName(), + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null) + ) + .field(CHUNKS_FIELD.getPreferredName(), List.of()) + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + } + + public void testMissingTaskType() throws IOException { + 
DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .startObject(MODEL_SETTINGS_FIELD.getPreferredName()) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("failed to parse field [model_settings]")); + } + + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + mappingBuilder.startObject(fieldName); + mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mappingBuilder.field("inference_id", modelId); + mappingBuilder.endObject(); + } + + private static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List semanticTextInferenceResults) + throws IOException { + for (var field : semanticTextInferenceResults) { + sourceBuilder.field(field.fieldName()); + sourceBuilder.value(field); + } + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); + } + + private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String fieldName, List tokens) { + NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(getChunksFieldName(fieldName)); + assertNotNull(mapper); + + BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (String token : tokens) { + queryBuilder.add( + new BooleanClause(new TermQuery(new Term(getEmbeddingsFieldName(fieldName), token)), BooleanClause.Occur.MUST) + ); + } + queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); + + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + } + + private static void assertChildLeafNestedDocument( + LeafNestedDocuments leaf, + int advanceToDoc, + int expectedRootDoc, + Set visitedNestedIdentities + ) throws IOException { + + assertNotNull(leaf.advance(advanceToDoc)); + assertEquals(advanceToDoc, leaf.doc()); + assertEquals(expectedRootDoc, leaf.rootDoc()); + assertNotNull(leaf.nestedIdentity()); + visitedNestedIdentities.add(leaf.nestedIdentity()); + } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } + } + assertThat(count, equalTo(expectedCount)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java new file mode 100644 index 0000000000000..c5ab6eb1f9e15 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.model.TestModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextFieldTests extends AbstractXContentTestCase { + private static final String NAME = "field"; + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + } + + @Override + protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { + assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); + assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); + assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); + assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); + SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); + for (int i = 0; i < newInstance.inference().chunks().size(); i++) { + assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); + switch (modelSettings.taskType()) { + case TEXT_EMBEDDING -> { + double[] expectedVector = parseDenseVector( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + expectedInstance.contentType() + ); + double[] newVector = parseDenseVector( + newInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + newInstance.contentType() + ); + assertArrayEquals(expectedVector, newVector, 0f); + } + case SPARSE_EMBEDDING -> { + List expectedTokens = parseWeightedTokens( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + expectedInstance.contentType() + ); + List newTokens = parseWeightedTokens( + newInstance.inference().chunks().get(i).rawEmbeddings(), + newInstance.contentType() + ); + assertThat(newTokens, equalTo(expectedTokens)); + } + default -> throw new AssertionError("Invalid task type " + 
modelSettings.taskType()); + } + } + } + + @Override + protected SemanticTextField createTestInstance() { + List rawValues = randomList(1, 5, () -> randomAlphaOfLengthBetween(10, 20)); + return randomSemanticText( + NAME, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + rawValues, + randomFrom(XContentType.values()) + ); + } + + @Override + protected SemanticTextField doParseInstance(XContentParser parser) throws IOException { + return SemanticTextField.parse(parser, new Tuple<>(NAME, parser.contentType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[model.getServiceSettings().dimensions()]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } + + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + + public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) { + ChunkedInferenceServiceResults results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, inputs); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); + }; + return new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), List.of(results), contentType) + ), + contentType + ); + } + + public static Model randomModel(TaskType taskType) { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + taskType, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) { + switch (field.inference().modelSettings().taskType()) { + case SPARSE_EMBEDDING -> { + List chunks = new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + case TEXT_EMBEDDING -> { + List chunks = + new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + double[] values = parseDenseVector( + chunk.rawEmbeddings(), + field.inference().modelSettings().dimensions(), + field.contentType() + ); + chunks.add( + new 
org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk( + chunk.text(), + values + ) + ); + } + return new ChunkedTextEmbeddingResults(chunks); + } + default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); + } + } + + private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + parser.nextToken(); + assertThat(parser.currentToken(), equalTo(XContentParser.Token.START_ARRAY)); + double[] values = new double[numDims]; + for (int i = 0; i < numDims; i++) { + assertThat(parser.nextToken(), equalTo(XContentParser.Token.VALUE_NUMBER)); + values[i] = parser.doubleValue(); + } + assertThat(parser.nextToken(), equalTo(XContentParser.Token.END_ARRAY)); + return values; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static List parseWeightedTokens(BytesReference value, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + Map map = parser.map(); + List weightedTokens = new ArrayList<>(); + for (var entry : map.entrySet()) { + weightedTokens.add(new TextExpansionResults.WeightedToken(entry.getKey(), ((Number) entry.getValue()).floatValue())); + } + return weightedTokens; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 75e7ca12c1d56..b64485a3d3fb2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -121,6 +122,16 @@ public void writeTo(StreamOutput out) throws IOException { public ToXContentObject getFilteredXContentObject() { return this; } + + @Override + public SimilarityMeasure similarity() { + return SimilarityMeasure.COSINE; + } + + @Override + public Integer dimensions() { + return 100; + } } public record TestTaskSettings(Integer temperature) implements TaskSettings { diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java new file mode 100644 index 0000000000000..a397d9864d23d --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; + +public class InferenceRestIT extends ESClientYamlSuiteTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .setting("xpack.security.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") + .plugin("inference-service-test") + .distribution(DistributionType.DEFAULT) + .build(); + + public InferenceRestIT(final ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return ESClientYamlSuiteTestCase.createParameters(); + } +} diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml new file mode 100644 index 0000000000000..e567b2103e527 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -0,0 +1,444 @@ +setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + another_inference_field: + type: semantic_text + inference_id: sparse-inference-id + non_inference_field: + type: text + + - do: + indices.create: + index: test-dense-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + another_inference_field: + type: semantic_text + inference_id: dense-inference-id + non_inference_field: + type: text + +--- +"Calculates text expansion results for new documents": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - exists: _source.another_inference_field.inference.chunks.0.embeddings + - match: { _source.another_inference_field.inference.chunks.0.text: 
"another inference test" } + - match: { _source.non_inference_field: "non inference test" } + +--- +"text expansion documents do not create new mappings": + - do: + indices.get_mapping: + index: test-sparse-index + + - match: {test-sparse-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-sparse-index.mappings.properties.non_inference_field.type: text} + - length: {test-sparse-index.mappings.properties: 3} + +--- +"Calculates text embeddings results for new documents": + - do: + index: + index: test-dense-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-dense-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - exists: _source.another_inference_field.inference.chunks.0.embeddings + - match: { _source.non_inference_field: "non inference test" } + + +--- +"text embeddings documents do not create new mappings": + - do: + indices.get_mapping: + index: test-dense-index + + - match: {test-dense-index.mappings.properties.inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.another_inference_field.type: semantic_text} + - match: {test-dense-index.mappings.properties.non_inference_field.type: text} + - length: {test-dense-index.mappings.properties: 3} + +--- +"Updating non semantic_text fields does not recalculate embeddings": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } + + - do: + update: + index: test-sparse-index + id: doc_1 + body: + doc: + non_inference_field: "another non inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } + - match: { _source.non_inference_field: "another non inference test" } + +--- +"Updating semantic_text fields recalculates embeddings": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - match: { 
_source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "I am a test" } + - match: { _source.inference_field.inference.chunks.0.text: "I am a test" } + - match: { _source.another_inference_field.text: "I am a teapot" } + - match: { _source.another_inference_field.inference.chunks.0.text: "I am a teapot" } + - match: { _source.non_inference_field: "non inference test" } + + - do: + update: + index: test-sparse-index + id: doc_1 + body: + doc: + inference_field: "updated inference test" + another_inference_field: "another updated inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "updated inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "updated inference test" } + - match: { _source.another_inference_field.text: "another updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "bulk inference test", "another_inference_field": "bulk updated inference test"}}' + + - do: + get: + index: test-sparse-index + id: doc_1 + + - match: { _source.inference_field.text: "bulk inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "bulk inference test" } + - match: { _source.another_inference_field.text: "bulk updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "bulk updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + +--- +"Reindex works for semantic_text fields": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-sparse-index + id: doc_1 + + - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } + + - do: + indices.refresh: { } + + - do: + indices.create: + index: destination-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + another_inference_field: + type: semantic_text + inference_id: sparse-inference-id + non_inference_field: + type: text + + - do: + reindex: + wait_for_completion: true + body: + source: + index: test-sparse-index + dest: + index: destination-index + - do: + get: + index: destination-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.text: "another inference test" } + 
- match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } + - match: { _source.non_inference_field: "non inference test" } + +--- +"Fails for non-existent inference": + - do: + indices.create: + index: incorrect-test-sparse-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: non-existing-inference-id + non_inference_field: + type: text + + - do: + catch: missing + index: + index: incorrect-test-sparse-index + id: doc_1 + body: + inference_field: "inference test" + non_inference_field: "non inference test" + + - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" } + + # Succeeds when semantic_text field is not used + - do: + index: + index: incorrect-test-sparse-index + id: doc_1 + body: + non_inference_field: "non inference test" + +--- +"Updates with script are not allowed": + - do: + bulk: + index: test-sparse-index + body: + - '{"index": {"_id": "doc_1"}}' + - '{"doc":{"inference_field": "I am a test", "another_inference_field": "I am a teapot"}}' + + - do: + bulk: + index: test-sparse-index + body: + - '{"update": {"_id": "doc_1"}}' + - '{"script": "ctx._source.new_field = \"hello\"", "scripted_upsert": true}' + + - match: { errors: true } + - match: { items.0.update.status: 400 } + - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } + +--- +"semantic_text copy_to calculate inference for source fields": + - do: + indices.create: + index: test-copy-to-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + source_field: + type: text + copy_to: inference_field + another_source_field: + type: text + copy_to: inference_field + + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + source_field: "copy_to inference test" + inference_field: "inference test" + another_source_field: "another copy_to inference test" + + - do: + get: + index: test-copy-to-index + id: doc_1 + + - match: { _source.inference_field.text: "inference test" } + - length: { _source.inference_field.inference.chunks: 3 } + - match: { _source.inference_field.inference.chunks.0.text: "another copy_to inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.1.text: "inference test" } + - exists: _source.inference_field.inference.chunks.1.embeddings + - match: { _source.inference_field.inference.chunks.2.text: "copy_to inference test" } + - exists: _source.inference_field.inference.chunks.2.embeddings + + +--- +"semantic_text copy_to needs values for every source field for updates": + - do: + indices.create: + index: test-copy-to-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + source_field: + type: text + copy_to: inference_field + another_source_field: + type: text + copy_to: inference_field + + # Not every source field needed on creation + - do: + index: + index: test-copy-to-index + id: doc_1 + body: + source_field: "a single source field provided" + inference_field: "inference test" + + # Every source field needed on bulk updates + - do: + bulk: + body: + - '{"update": {"_index": "test-copy-to-index", "_id": "doc_1"}}' + - '{"doc": {"source_field": "a single source field is kept as 
provided via bulk", "inference_field": "updated inference test" }}' + + - match: { items.0.update.status: 400 } + - match: { items.0.update.error.reason: "Field [another_source_field] must be specified on an update request to calculate inference for field [inference_field]" } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml new file mode 100644 index 0000000000000..27f233436b925 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -0,0 +1,68 @@ +setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + + - do: + inference.put_model: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + - do: + inference.put_model: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64", + "similarity": "cosine" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + inference_id: sparse-inference-id + dense_field: + type: semantic_text + inference_id: dense-inference-id + non_inference_field: + type: text + +--- +"Sparse vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + sparse_field: "you know, for testing" + +--- +"Dense vector results format": + - do: + index: + index: test-index + id: doc_1 + body: + dense_field: "you know, for testing" diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java index 4d90d2a186858..d8c9dc2efd927 100644 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/CoordinatedInferenceIngestIT.java @@ -59,10 +59,10 @@ public void testIngestWithMultipleModelTypes() throws IOException { assertThat(simulatedDocs, hasSize(2)); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0))); var sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)); - assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1))); sparseEmbedding = (Map) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1)); - assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1")); + assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1")); } { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5ef7311179e4f..7fa2bcca952bf 
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,7 +49,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; -import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -69,7 +68,6 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.IngestPlugin; -import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Platforms; import org.elasticsearch.plugins.Plugin; @@ -365,7 +363,6 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; -import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -485,8 +482,7 @@ public class MachineLearning extends Plugin PersistentTaskPlugin, SearchPlugin, ShutdownAwarePlugin, - ExtensiblePlugin, - MapperPlugin { + ExtensiblePlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; // Endpoints that were deprecated in 7.x can still be called in 8.x using the REST compatibility layer @@ -2308,12 +2304,4 @@ public void signalShutdown(Collection shutdownNodeIds) { mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds); } } - - @Override - public Map getMappers() { - if (SemanticTextFeature.isEnabled()) { - return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); - } - return Map.of(); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java deleted file mode 100644 index cf713546a071a..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.ml.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.fielddata.FieldDataContext; -import org.elasticsearch.index.fielddata.IndexFieldData; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceModelFieldType; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.SimpleMappedFieldType; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.query.SearchExecutionContext; - -import java.io.IOException; -import java.util.Map; - -/** - * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference - * at ingestion and query time. - * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. - * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using a different field mapper. - */ -public class SemanticTextFieldMapper extends FieldMapper { - - public static final String CONTENT_TYPE = "semantic_text"; - - private static SemanticTextFieldMapper toType(FieldMapper in) { - return (SemanticTextFieldMapper) in; - } - - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); - - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { - super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); - } - - @Override - public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName()).init(this); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - // Just parses text - no indexing is performed - context.parser().textOrNull(); - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SemanticTextFieldType fieldType() { - return (SemanticTextFieldType) super.fieldType(); - } - - public static class Builder extends FieldMapper.Builder { - - private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) - .addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); - - private final Parameter> meta = Parameter.metaParam(); - - public Builder(String name) { - super(name); - } - - @Override - protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta }; - } - - @Override - public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); - } - } - - public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - - private final String modelId; - - public SemanticTextFieldType(String name, String modelId, Map meta) { - super(name, false, false, false, TextSearchInfo.NONE, meta); - this.modelId = modelId; - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - 
public String getInferenceModel() { - return modelId; - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented yet"); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); - } - - @Override - public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { - throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); - } - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java deleted file mode 100644 index ccb8f106e4945..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.mapper; - -import org.apache.lucene.index.IndexableField; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.MapperParsingException; -import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MapperTestCase; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.MachineLearning; -import org.junit.AssumptionViolatedException; - -import java.io.IOException; -import java.util.Collection; -import java.util.List; - -import static java.util.Collections.singletonList; -import static org.hamcrest.Matchers.containsString; - -public class SemanticTextFieldMapperTests extends MapperTestCase { - - public void testDefaults() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); - - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); - - // No indexable fields - assertTrue(fields.isEmpty()); - } - - public void testModelIdNotPresent() throws IOException { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) - ); - assertThat(e.getMessage(), containsString("field [model_id] must be specified")); - } - - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - - public void testUpdatesToModelIdNotSupported() throws IOException { - MapperService mapperService = createMapperService( - 
fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) - ); - Exception e = expectThrows( - IllegalArgumentException.class, - () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model"))) - ); - assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); - } - - @Override - protected Collection getPlugins() { - return singletonList(new MachineLearning(Settings.EMPTY)); - } - - @Override - protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text").field("model_id", "test_model"); - } - - @Override - protected Object getSampleValueForDocument() { - return "value"; - } - - @Override - protected boolean supportsIgnoreMalformed() { - return false; - } - - @Override - protected boolean supportsStoredFields() { - return false; - } - - @Override - protected void registerParameters(ParameterChecker checker) throws IOException {} - - @Override - protected Object generateRandomInputValue(MappedFieldType ft) { - assumeFalse("doc_values are not supported in semantic_text", true); - return null; - } - - @Override - protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { - throw new AssumptionViolatedException("not supported"); - } - - @Override - protected IngestScriptSupport ingestScriptSupport() { - throw new AssumptionViolatedException("not supported"); - } -} From 3bce501a5e9fa5ed729f61fd0ebfe73ab0b88cd8 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Apr 2024 15:21:30 +0200 Subject: [PATCH 24/29] Add tests pending from #107256 --- .../inference/20_semantic_text_field_mapper.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index df5073cfed525..0a8e7d7c5f4a6 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -103,8 +103,11 @@ setup: - match: { hits.total.value: 2 } - match: { hits.total.relation: eq } - match: { hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.total.value: 2 } - - exists: hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.0.fields.dense_field\.inference\.chunks.0.text - - exists: hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.1.fields.dense_field\.inference\.chunks.0.text + - match: { hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.0.fields.dense_field\.inference\.chunks.0.text: ["you know, for testing"] } + - match: { hits.hits.0.inner_hits.dense_field\.inference\.chunks.hits.hits.1.fields.dense_field\.inference\.chunks.0.text: ["now with chunks"] } + - match: { hits.hits.1.inner_hits.dense_field\.inference\.chunks.hits.total.value: 2 } + - match: { hits.hits.1.inner_hits.dense_field\.inference\.chunks.hits.hits.0.fields.dense_field\.inference\.chunks.0.text: ["some more tests"] } + - match: { hits.hits.1.inner_hits.dense_field\.inference\.chunks.hits.hits.1.fields.dense_field\.inference\.chunks.0.text: ["that include chunks"] } --- @@ -153,7 +156,10 @@ setup: - match: { hits.total.value: 2 } - match: { hits.total.relation: eq } - match: { 
hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.total.value: 2 }
-  - exists: hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.0.fields.sparse_field\.inference\.chunks.0.text
-  - exists: hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.1.fields.sparse_field\.inference\.chunks.0.text
+  - match: { hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.0.fields.sparse_field\.inference\.chunks.0.text: ["you know, for testing"] }
+  - match: { hits.hits.0.inner_hits.sparse_field\.inference\.chunks.hits.hits.1.fields.sparse_field\.inference\.chunks.0.text: ["now with chunks"] }
+  - match: { hits.hits.1.inner_hits.sparse_field\.inference\.chunks.hits.total.value: 2 }
+  - match: { hits.hits.1.inner_hits.sparse_field\.inference\.chunks.hits.hits.0.fields.sparse_field\.inference\.chunks.0.text: ["some more tests"] }
+  - match: { hits.hits.1.inner_hits.sparse_field\.inference\.chunks.hits.hits.1.fields.sparse_field\.inference\.chunks.0.text: ["that include chunks"] }

From 0f57a5b65dea675662c7512fbcc721131a63b4c9 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Tue, 9 Apr 2024 17:26:48 +0200
Subject: [PATCH 25/29] Fix merge

---
 .../test/nodes.stats/11_indices_metrics.yml          | 12 ++++--------
 .../cluster/metadata/IndexMetadata.java              |  6 +++---
 .../cluster/metadata/InferenceFieldMetadata.java     |  7 +++++++
 .../index/mapper/vectors/DenseVectorFieldMapper.java |  4 ----
 .../mapper/vectors/SparseVectorFieldMapper.java      |  7 ++-----
 .../metadata/InferenceFieldMetadataTests.java        |  6 ++++++
 6 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml
index 146f0e5c62bc9..b119a1a1d94f3 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/nodes.stats/11_indices_metrics.yml
@@ -413,7 +413,7 @@
   - match: { nodes.$node_id.indices.mappings.total_estimated_overhead_in_bytes: 0 }

 ---
-"indices mappings count test for indices level":
+"indices mappings exact count test for indices level":
   - skip:
       features: [arbitrary_key]

@@ -468,7 +468,7 @@
   - do:
       nodes.stats: { metric: _all, level: "indices", human: true }

-  # In the below assertions, we expect a field count of at least 26 because the above mapping expects the following:
+  # In the below assertions, we expect a field count of 26 because the above mapping expects the following:
   # Field mappers (incl. alias fields and object mappers' flattened leaves):
   # 1. _data_stream_timestamp
   # 2. _doc_count
   # 25. authors.name
   # Runtime field mappers:
   # 26. a_source_field
-  #
-  # Plugins (which may or may not be loaded depending on the context in which this test is executed) may add additional
-  # field mappers:
-  # 27. _semantic_text_inference (from ML plugin)
   - gte: { nodes.$node_id.indices.mappings.total_count: 26 }
   - is_true: nodes.$node_id.indices.mappings.total_estimated_overhead
   - gte: { nodes.$node_id.indices.mappings.total_estimated_overhead_in_bytes: 26624 }

-  - gte: { nodes.$node_id.indices.indices.index1.mappings.total_count: 26 }
+  - match: { nodes.$node_id.indices.indices.index1.mappings.total_count: 26 }
   - is_true: nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead
-  - gte: { nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead_in_bytes: 26624 }
+  - match: { nodes.$node_id.indices.indices.index1.mappings.total_estimated_overhead_in_bytes: 26624 }

 ---
 "indices mappings does not exist in shards level":
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
index 3a852f20a761e..529814e83ba38 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
@@ -540,7 +540,7 @@ public Iterator> settings() {

     public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast";

-    public static final String KEY_FIELD_INFERENCE = "field_inference";
+    public static final String KEY_INFERENCE_FIELDS = "field_inference";

     public static final String INDEX_STATE_FILE_PREFIX = "state-";

@@ -2437,7 +2437,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build
         }

         if (indexMetadata.getInferenceFields().isEmpty() == false) {
-            builder.startObject(KEY_FIELD_INFERENCE);
+            builder.startObject(KEY_INFERENCE_FIELDS);
             for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) {
                 field.toXContent(builder, params);
             }
@@ -2521,7 +2521,7 @@ public static IndexMetadata fromXContent(XContentParser parser, Map, ToXContentFragment {
     private static final String INFERENCE_ID_FIELD = "inference_id";
     private static final String SOURCE_FIELDS_FIELD = "source_fields";
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 3bb82bea58acf..3f83a8819b4c1 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -1132,10 +1132,6 @@ public String typeName() {
         return CONTENT_TYPE;
     }

-    public Integer getDims() {
-        return dims;
-    }
-
     @Override
     public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
         if (format != null) {
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
index 58286d34dada1..6532abed19044 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
@@ -171,12 +171,9 @@ public void parse(DocumentParserContext context) throws IOException {
         }

         String feature = null;
-        boolean origIsWithLeafObject = context.path().isWithinLeafObject();
         try {
             // make sure that we don't expand dots in field names while parsing
-            if (context.path().isWithinLeafObject() == false) {
-                context.path().setWithinLeafObject(true);
-            }
+            context.path().setWithinLeafObject(true);
             for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) {
                 if (token == Token.FIELD_NAME) {
                     feature = context.parser().currentName();
@@ -210,7 +207,7 @@ public void parse(DocumentParserContext context) throws IOException {
                 context.addToFieldNames(fieldType().name());
             }
         } finally {
-            context.path().setWithinLeafObject(origIsWithLeafObject);
+            context.path().setWithinLeafObject(false);
         }
     }
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
index 958d86535ae76..bd4c87be51157 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
@@ -63,4 +63,10 @@ private static InferenceFieldMetadata createTestItem() {
         String[] inputFields = generateRandomStringArray(5, 10, false, false);
         return new InferenceFieldMetadata(name, inferenceId, inputFields);
     }
+
+    public void testNullCtorArgsThrowException() {
+        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", new String[0]));
+        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, new String[0]));
+        assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null));
+    }
 }

From 3e847f4bc90acb8e74f74adfe9f54b8d90986402 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Tue, 9 Apr 2024 17:40:20 +0200
Subject: [PATCH 26/29] Fix merge

---
 .../org/elasticsearch/inference/InferenceServiceResults.java | 2 --
 .../xpack/inference/mapper/SemanticTextFieldMapper.java      | 1 +
 .../xpack/inference/mapper/SemanticTextFieldMapperTests.java | 1 +
 3 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
index 14cfeacf76139..62166115820f5 100644
--- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
+++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
@@ -35,8 +35,6 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragm

     /**
      * Convert the result to a map to aid with test assertions
-     *
-     * @return a map
      */
     Map asMap();
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 08d11f7bd41f2..c577de4fa5ee4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -62,6 +62,7 @@

 /**
  * A {@link FieldMapper} for semantic text fields.
+ *
 */
 public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
     public static final String CONTENT_TYPE = "semantic_text";
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index a6f0fa83eab37..e26d8016bc541 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -71,6 +71,7 @@ import static org.hamcrest.Matchers.instanceOf;

 public class SemanticTextFieldMapperTests extends MapperTestCase {
+
     @Override
     protected Collection getPlugins() {
         return singletonList(new InferencePlugin(Settings.EMPTY));

From c0960fac4aba2c4d8cb807d39aaa2a1444695bc8 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Tue, 9 Apr 2024 17:40:29 +0200
Subject: [PATCH 27/29] Fix merge

---
 .../xpack/inference/mapper/SemanticTextFieldMapper.java      | 1 -
 .../xpack/inference/mapper/SemanticTextFieldMapperTests.java | 1 -
 2 files changed, 2 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index c577de4fa5ee4..08d11f7bd41f2 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -62,7 +62,6 @@

 /**
  * A {@link FieldMapper} for semantic text fields.
- *
 */
 public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
     public static final String CONTENT_TYPE = "semantic_text";
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index e26d8016bc541..a6f0fa83eab37 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -71,7 +71,6 @@ import static org.hamcrest.Matchers.instanceOf;

 public class SemanticTextFieldMapperTests extends MapperTestCase {
-
     @Override
     protected Collection getPlugins() {
         return singletonList(new InferencePlugin(Settings.EMPTY));

From ff8365a2a078f10b4ab3d3498743fd588d5a60a9 Mon Sep 17 00:00:00 2001
From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com>
Date: Wed, 10 Apr 2024 17:01:01 +0200
Subject: [PATCH 28/29] Update docs/changelog/107262.yaml

---
 docs/changelog/107262.yaml | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 docs/changelog/107262.yaml

diff --git a/docs/changelog/107262.yaml b/docs/changelog/107262.yaml
new file mode 100644
index 0000000000000..ced725d3b7dd4
--- /dev/null
+++ b/docs/changelog/107262.yaml
@@ -0,0 +1,5 @@
+pr: 107262
+summary: Semantic_text field mapper and inference
+area: Mapping
+type: feature
+issues: []

From 7dbb53b6f81868e9454229074452cbb347a3ba32 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Wed, 10 Apr 2024 17:04:17 +0200
Subject: [PATCH 29/29] Update changelog

---
 docs/changelog/107262.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/changelog/107262.yaml b/docs/changelog/107262.yaml
index ced725d3b7dd4..20bb2d12fae1d 100644
--- a/docs/changelog/107262.yaml
+++ b/docs/changelog/107262.yaml
@@ -1,5 +1,5 @@
 pr: 107262
-summary: Semantic_text field mapper and inference
+summary: semantic_text field mapper and inference generation
 area: Mapping
 type: feature
 issues: []