From 941f960b9f1cc3977e3df0878657b1032dfaff18 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 8 Dec 2023 18:10:26 -0500 Subject: [PATCH 001/106] Added fieldsForModels to IndexMetadata & MappingMetadata --- .../org/elasticsearch/TransportVersions.java | 1 + .../cluster/metadata/IndexMetadata.java | 83 +++++++++++++++++-- .../cluster/metadata/MappingMetadata.java | 21 ++++- 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index ad29384b16f45..855ac7a2a400f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -192,6 +192,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_SERVICE_EMBEDDING_SIZE_ADDED = def(8_559_00_0); public static final TransportVersion ENRICH_ELASTICSEARCH_VERSION_REMOVED = def(8_560_00_0); public static final TransportVersion NODE_STATS_REQUEST_SIMPLIFIED = def(8_561_00_0); + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_562_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 742b52365c8d7..48f7f152d4cb2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -79,6 +79,7 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -546,6 +547,8 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; + public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String INDEX_STATE_FILE_PREFIX = "state-"; static final TransportVersion SYSTEM_INDEX_FLAG_ADDED = TransportVersions.V_7_10_0; @@ -635,6 +638,8 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; + // Key: model ID, Value: Fields that use model + private final Map> fieldsForModels; private IndexMetadata( final Index index, @@ -680,7 +685,8 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast + @Nullable Long shardSizeInBytesForecast, + final Map> fieldsForModels ) { this.index = index; this.version = version; @@ -736,6 +742,8 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; + this.fieldsForModels = fieldsForModels; + assert fieldsForModels != null; } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -786,7 +794,8 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -844,7 +853,8 @@ public IndexMetadata 
withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -900,7 +910,8 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -956,7 +967,8 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1008,7 +1020,8 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1212,6 +1225,10 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1476,6 +1493,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; + private final Diff>> fieldsForModels; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1512,6 +1530,12 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; + fieldsForModels = DiffableUtils.diff( + before.fieldsForModels, + after.fieldsForModels, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1571,6 +1595,15 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = DiffableUtils.readJdkMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } else { + fieldsForModels = DiffableUtils.emptyDiff(); + } } @Override @@ -1606,6 +1639,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels.writeTo(out); + } } @Override @@ -1797,6 +1833,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; + private Map> fieldsForModels = Map.of(); public Builder(String index) { this.index = index; @@ -1828,6 +1865,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; 
this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; + this.fieldsForModels = indexMetadata.fieldsForModels; } public Builder index(String index) { @@ -1909,6 +1947,11 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; + Map> fieldsForModels = mappingMd.getFieldsForModels(); + // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? + if (fieldsForModels != null) { + this.fieldsForModels = fieldsForModels; + } return this; } @@ -2057,6 +2100,12 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } + public Builder fieldsForModels(Map> fieldsForModels) { + // TODO: How to handle null value? Clear this.fieldsForModels? + this.fieldsForModels = fieldsForModels; + return this; + } + public IndexMetadata build() { return build(false); } @@ -2251,7 +2300,8 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast + shardSizeInBytesForecast, + fieldsForModels ); } @@ -2377,6 +2427,12 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } + // TODO: Need null check? + Map> fieldsForModels = indexMetadata.getFieldsForModels(); + if (fieldsForModels != null && !fieldsForModels.isEmpty()) { + builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); + } + builder.endObject(); } @@ -2454,6 +2510,19 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) + .entrySet() + .stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) + ) + ); + builder.fieldsForModels(fieldsForModels); + break; default: // assume it's custom index metadata builder.putCustom(currentFieldName, parser.mapStrings()); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index b629ab5d5f710..3db9ad884ceff 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -23,6 +23,7 @@ import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -42,10 +43,13 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; + private final Map> fieldsForModels; + public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -57,6 +61,7 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -72,6 +77,7 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = 
routingRequired(withoutType); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -158,12 +164,19 @@ public String getSha256() { return source.getSha256(); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -176,19 +189,25 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; + if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired); + return Objects.hash(type, source, routingRequired, fieldsForModels); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); + } else { + fieldsForModels = Map.of(); + } } public static Diff readDiffFrom(StreamInput in) throws IOException { From 46f1f2eeb4327e0cfc9a10341b840f8f34833be6 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 11 Dec 2023 17:01:11 -0500 Subject: [PATCH 002/106] Updated IndexMetadata tests --- .../cluster/metadata/IndexMetadata.java | 17 +++++++++++++++-- .../cluster/metadata/IndexMetadataTests.java | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 48f7f152d4cb2..3b942f3944934 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1671,6 +1671,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); + builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); return builder.build(true); } } @@ -1738,6 +1739,11 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) + ); + } return builder.build(true); } @@ -1784,6 +1790,9 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalDouble(writeLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -2101,8 +2110,12 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { } public Builder fieldsForModels(Map> fieldsForModels) { - // TODO: How to handle null value? Clear this.fieldsForModels? 
- this.fieldsForModels = fieldsForModels; + if (fieldsForModels != null) { + this.fieldsForModels = fieldsForModels; + } else { + this.fieldsForModels.clear(); + } + return this; } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 0680392ffb3f0..8911754a42242 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -82,6 +83,15 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; + Map> fieldsForModels = new HashMap<>(); + for (int i = 0; i < randomIntBetween(0, 5); i++) { + Set fields = new HashSet<>(); + for (int j = 0; j < randomIntBetween(1, 4); j++) { + fields.add(randomAlphaOfLengthBetween(4, 10)); + } + fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); + } + IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) .creationDate(randomLong()) @@ -105,6 +115,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) + .fieldsForModels(fieldsForModels) .build(); assertEquals(system, metadata.isSystem()); @@ -136,6 +147,7 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -157,8 +169,9 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getCustomData(), deserialized.getCustomData()); assertEquals(metadata.isSystem(), deserialized.isSystem()); assertEquals(metadata.getStats(), deserialized.getStats()); - assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); - assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); + assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); + assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); + assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); } } From 02555e383d5975a14c9d5037b51b2e1826c008dc Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 11 Dec 2023 17:17:26 -0500 Subject: [PATCH 003/106] Randomize when fieldsForModels is set --- .../org/elasticsearch/cluster/metadata/IndexMetadata.java | 4 ++-- .../elasticsearch/cluster/metadata/IndexMetadataTests.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 3b942f3944934..6f7a503207508 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -2112,8 +2112,8 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { public Builder fieldsForModels(Map> fieldsForModels) { if (fieldsForModels != null) { this.fieldsForModels = fieldsForModels; - } else { - this.fieldsForModels.clear(); + } else if (!this.fieldsForModels.isEmpty()) { + this.fieldsForModels = Map.of(); } return this; diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 8911754a42242..88ca1a110a676 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -115,7 +115,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(fieldsForModels) + .fieldsForModels(randomBoolean() ? fieldsForModels : null) .build(); assertEquals(system, metadata.isSystem()); From 3208f7457d10526974aedc6cd9e930bab32a8131 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 12 Dec 2023 14:52:41 -0500 Subject: [PATCH 004/106] Added fieldsForModels to FieldTypeLookup --- .../index/mapper/FieldTypeLookup.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 2b4eec2bdd565..cda44d252a72a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,6 +36,11 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; + /** + * A map from inference model ID to all fields that use the model to generate embeddings. 
+ */ + private final Map> fieldsForModels; + private final int maxParentPathDots; FieldTypeLookup( @@ -48,6 +53,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); + final Map> fieldsForModels = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -65,6 +71,13 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } + if (fieldType instanceof InferenceModelFieldType) { + String inferenceModel = ((InferenceModelFieldType) fieldType).getInferenceModel(); + if (inferenceModel != null) { + Set fields = fieldsForModels.computeIfAbsent(inferenceModel, v -> new HashSet<>()); + fields.add(fieldName); + } + } } int maxParentPathDots = 0; @@ -97,6 +110,8 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); + fieldsForModels.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); + this.fieldsForModels = Map.copyOf(fieldsForModels); } public static int dotCount(String path) { @@ -205,6 +220,15 @@ Set sourcePaths(String field) { return fieldToCopiedFields.containsKey(resolvedField) ? fieldToCopiedFields.get(resolvedField) : Set.of(resolvedField); } + Set getFieldsForModel(String model) { + Set fields = fieldsForModels.get(model); + return fields != null ? fields : Collections.emptySet(); + } + + Map> getFieldsForModels() { + return fieldsForModels; + } + /** * If field is a leaf multi-field return the path to the parent field. Otherwise, return null. 
*/ From 9a7513eb7457d73f66d4bbf2adcf32c116317513 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 12 Dec 2023 15:01:18 -0500 Subject: [PATCH 005/106] Updated MappingLookup to add getFieldsForModels --- .../java/org/elasticsearch/index/mapper/MappingLookup.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 4880ce5edc204..e89ac5e050ddd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -140,7 +140,7 @@ private MappingLookup( nestedMappers.add((NestedObjectMapper) mapper); } } - this.nestedLookup = NestedLookup.build(nestedMappers); + this.nestedLookup = NestedLookup.build(nestedMappers); // TODO: Update to handle models in nested mappings final Map indexAnalyzersMap = new HashMap<>(); final Set completionFields = new HashSet<>(); @@ -498,4 +498,8 @@ public void validateDoesNotShadow(String name) { throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric"); } } + + public Map> getFieldsForModels() { + return fieldTypeLookup.getFieldsForModels(); + } } From 409c8d5d9405e66e04d70279ff3b30495b740c4e Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 12 Dec 2023 15:15:26 -0500 Subject: [PATCH 006/106] Updated MappingMetadata to set fieldsForModels --- .../org/elasticsearch/cluster/metadata/MappingMetadata.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 3db9ad884ceff..a7c3867772c3c 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -49,7 +49,7 @@ public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); - this.fieldsForModels = null; // TODO: Set fieldsForModels here + this.fieldsForModels = docMapper.mappers().getFieldsForModels(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -61,7 +61,7 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = null; // TODO: Set fieldsForModels here + this.fieldsForModels = null; } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -77,7 +77,7 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = null; // TODO: Set fieldsForModels here + this.fieldsForModels = null; } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { From c5748f70ea4ff132567b5cb1b41a0402bd065b45 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Tue, 12 Dec 2023 16:28:58 -0500 Subject: [PATCH 007/106] Ensure that fieldsForModels is immutable --- .../cluster/metadata/IndexMetadata.java | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java 
b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 6f7a503207508..eaa1d27dc6b1a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -795,7 +795,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels + this.fieldsForModels // TODO: Need to take fieldsForModels from MappingMetadata in this constructor? ); } @@ -1842,7 +1842,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private Map> fieldsForModels = Map.of(); + private final ImmutableOpenMap.Builder> fieldsForModels; public Builder(String index) { this.index = index; @@ -1850,6 +1850,7 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); + this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1874,7 +1875,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = indexMetadata.fieldsForModels; + this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); } public Builder index(String index) { @@ -1959,7 +1960,7 @@ public Builder putMapping(MappingMetadata mappingMd) { Map> fieldsForModels = mappingMd.getFieldsForModels(); // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? 
if (fieldsForModels != null) { - this.fieldsForModels = fieldsForModels; + processFieldsForModels(this.fieldsForModels, fieldsForModels); } return this; } @@ -2110,12 +2111,7 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { } public Builder fieldsForModels(Map> fieldsForModels) { - if (fieldsForModels != null) { - this.fieldsForModels = fieldsForModels; - } else if (!this.fieldsForModels.isEmpty()) { - this.fieldsForModels = Map.of(); - } - + processFieldsForModels(this.fieldsForModels, fieldsForModels); return this; } @@ -2314,7 +2310,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels + fieldsForModels.build() ); } @@ -2732,6 +2728,17 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } + + private static void processFieldsForModels( + ImmutableOpenMap.Builder> builder, + Map> fieldsForModels + ) { + builder.clear(); + if (fieldsForModels != null) { + // Ensure that all field sets contained in the processed map are immutable + fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); + } + } } /** From 83319f5bbf220c490fde0618641badddab2e236e Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 08:47:52 -0500 Subject: [PATCH 008/106] Fix NPE --- .../elasticsearch/cluster/metadata/IndexMetadata.java | 10 ++++++---- .../cluster/metadata/MappingMetadata.java | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index eaa1d27dc6b1a..6b4174eccb682 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1957,10 +1957,12 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; - Map> fieldsForModels = mappingMd.getFieldsForModels(); - // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? - if (fieldsForModels != null) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); + if (mappingMd != null) { + Map> fieldsForModels = mappingMd.getFieldsForModels(); + // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? 
+ if (fieldsForModels != null) { + processFieldsForModels(this.fieldsForModels, fieldsForModels); + } } return this; } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index a7c3867772c3c..6c0b9ec6d5634 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -206,7 +206,7 @@ public MappingMetadata(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); } else { - fieldsForModels = Map.of(); + fieldsForModels = null; } } From b4a6f6e37ce4e53ea6f00a6e6da8ba20b905df5e Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 08:49:31 -0500 Subject: [PATCH 009/106] Update docs/changelog/103319.yaml --- docs/changelog/103319.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/103319.yaml diff --git a/docs/changelog/103319.yaml b/docs/changelog/103319.yaml new file mode 100644 index 0000000000000..3bfced8300e49 --- /dev/null +++ b/docs/changelog/103319.yaml @@ -0,0 +1,5 @@ +pr: 103319 +summary: Store `semantic_text` model info in mappings +area: Mapping +type: feature +issues: [] From 31642b85325c40dd417bcf3b1bd79784d01b1cd3 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 08:57:40 -0500 Subject: [PATCH 010/106] Update IndexMetadata equals & hashCode --- .../org/elasticsearch/cluster/metadata/IndexMetadata.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 6b4174eccb682..a1e045a68077b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1418,6 +1418,9 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } + if (fieldsForModels.equals(that.fieldsForModels) == false) { + return false; + } if (isSystem != that.isSystem) { return false; } @@ -1438,6 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); + result = 31 * result + fieldsForModels.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } From 206ddb90d7fe9cebfe3ac22b4fc777b6675deaa7 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 09:16:38 -0500 Subject: [PATCH 011/106] Fix NPE --- .../org/elasticsearch/cluster/metadata/MappingMetadata.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 6c0b9ec6d5634..63699a139091a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; 
+import org.elasticsearch.index.mapper.MappingLookup; import java.io.IOException; import java.io.UncheckedIOException; @@ -49,7 +50,9 @@ public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); - this.fieldsForModels = docMapper.mappers().getFieldsForModels(); + + MappingLookup mappingLookup = docMapper.mappers(); + this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : null; } @SuppressWarnings({ "this-escape", "unchecked" }) From f2503ea6ad4a2e5ac35f01214d32b6d6c6d822df Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 09:19:10 -0500 Subject: [PATCH 012/106] Fix checkstyle error --- .../java/org/elasticsearch/cluster/metadata/IndexMetadata.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index a1e045a68077b..955ef5f570a98 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -2444,7 +2444,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build // TODO: Need null check? Map> fieldsForModels = indexMetadata.getFieldsForModels(); - if (fieldsForModels != null && !fieldsForModels.isEmpty()) { + if (fieldsForModels != null && fieldsForModels.isEmpty() == false) { builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); } From 6c5d54189b7324253099933b18a9133d25adc140 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 10:42:25 -0500 Subject: [PATCH 013/106] Fix NPE --- .../org/elasticsearch/cluster/metadata/MappingMetadata.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 63699a139091a..bf7ad43da7200 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -177,7 +177,7 @@ public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED) && fieldsForModels != null) { out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); } } From d78af4cfbc9ff7b2f041f4971babfdc354625225 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 11:31:18 -0500 Subject: [PATCH 014/106] Update MappingMetadata to ensure that fieldsForModels is always non-null --- .../elasticsearch/cluster/metadata/IndexMetadata.java | 7 ++----- .../cluster/metadata/MappingMetadata.java | 10 +++++----- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 955ef5f570a98..2472126cb9800 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -795,7 +795,7 @@ IndexMetadata 
withMappingMetadata(MappingMetadata mapping) { this.stats, this.writeLoadForecast, this.shardSizeInBytesForecast, - this.fieldsForModels // TODO: Need to take fieldsForModels from MappingMetadata in this constructor? + this.fieldsForModels ); } @@ -1963,10 +1963,7 @@ public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; if (mappingMd != null) { Map> fieldsForModels = mappingMd.getFieldsForModels(); - // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? - if (fieldsForModels != null) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); - } + processFieldsForModels(this.fieldsForModels, fieldsForModels); } return this; } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index bf7ad43da7200..64a61f854b9da 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -52,7 +52,7 @@ public MappingMetadata(DocumentMapper docMapper) { this.routingRequired = docMapper.routingFieldMapper().required(); MappingLookup mappingLookup = docMapper.mappers(); - this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : null; + this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -64,7 +64,7 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = null; + this.fieldsForModels = Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -80,7 +80,7 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = null; + this.fieldsForModels = Map.of(); } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -177,7 +177,7 @@ public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED) && fieldsForModels != null) { + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); } } @@ -209,7 +209,7 @@ public MappingMetadata(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); } else { - fieldsForModels = null; + fieldsForModels = Map.of(); } } From aa5b80055a4d41be73b28934541108d5967c36eb Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 12:07:34 -0500 Subject: [PATCH 015/106] Resolved TODOs --- .../org/elasticsearch/cluster/metadata/IndexMetadata.java | 6 ++---- .../java/org/elasticsearch/index/mapper/MappingLookup.java | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 2472126cb9800..e67eca4c4c97b 100644 --- 
a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -2439,10 +2439,8 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - // TODO: Need null check? - Map> fieldsForModels = indexMetadata.getFieldsForModels(); - if (fieldsForModels != null && fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); + if (indexMetadata.fieldsForModels.isEmpty() == false) { + builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); } builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index e89ac5e050ddd..2c16a0fda9e60 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -140,7 +140,7 @@ private MappingLookup( nestedMappers.add((NestedObjectMapper) mapper); } } - this.nestedLookup = NestedLookup.build(nestedMappers); // TODO: Update to handle models in nested mappings + this.nestedLookup = NestedLookup.build(nestedMappers); final Map indexAnalyzersMap = new HashMap<>(); final Set completionFields = new HashSet<>(); From 04112d179870c701920db69d198d061adad1af6a Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 14:41:47 -0500 Subject: [PATCH 016/106] Adjusted cluster state diff tests --- .../cluster/ClusterStateDiffIT.java | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index b869b3a90fbce..433b4bdaf5d98 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -55,6 +55,7 @@ import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -571,7 +572,7 @@ public IndexMetadata randomCreate(String name) { @Override public IndexMetadata randomChange(IndexMetadata part) { IndexMetadata.Builder builder = IndexMetadata.builder(part); - switch (randomIntBetween(0, 2)) { + switch (randomIntBetween(0, 3)) { case 0: builder.settings(Settings.builder().put(part.getSettings()).put(randomSettings(Settings.EMPTY))); break; @@ -585,11 +586,34 @@ public IndexMetadata randomChange(IndexMetadata part) { case 2: builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; + case 3: + builder.fieldsForModels(randomFieldsForModels()); + break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } + + /** + * Generates a random fieldsForModels map + */ + private Map> randomFieldsForModels() { + if (randomBoolean()) { + return null; + } + + Map> fieldsForModels = new HashMap<>(); + for (int i = 0; i < randomIntBetween(0, 5); i++) { + Set fields = new HashSet<>(); + for (int j = 0; j < randomIntBetween(1, 4); j++) { + fields.add(randomAlphaOfLengthBetween(4, 10)); + } + fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); + } + + return fieldsForModels; + } }); } From 
a66f69bda0fb83f8daef85e67774dd8b2394e8d1 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 15:34:18 -0500 Subject: [PATCH 017/106] IndexMetadata test updates --- .../cluster/metadata/IndexMetadataTests.java | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 88ca1a110a676..a41d068385c7a 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -83,14 +83,7 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } + Map> fieldsForModels = randomFieldsForModels(true); IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) @@ -115,7 +108,7 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(randomBoolean() ? 
fieldsForModels : null) + .fieldsForModels(fieldsForModels) .build(); assertEquals(system, metadata.isSystem()); @@ -555,10 +548,39 @@ public void testPartialIndexReceivesDataFrozenTierPreference() { } } + public void testFieldsForModels() { + Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); + IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); + assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); + + Map> fieldsForModels = randomFieldsForModels(false); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1) + .fieldsForModels(fieldsForModels) + .build(); + assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); + } + private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } + private static Map> randomFieldsForModels(boolean allowNull) { + if (allowNull && randomBoolean()) { + return null; + } + + Map> fieldsForModels = new HashMap<>(); + for (int i = 0; i < randomIntBetween(0, 5); i++) { + Set fields = new HashSet<>(); + for (int j = 0; j < randomIntBetween(1, 4); j++) { + fields.add(randomAlphaOfLengthBetween(4, 10)); + } + fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); + } + + return fieldsForModels; + } + private IndexMetadataStats randomIndexStats(int numberOfShards) { IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards); int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards); From 6feacd7ee604b631311e7daed7f76950462b8b65 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 16:44:50 -0500 Subject: [PATCH 018/106] Added/updated FieldTypeLookup tests --- .../index/mapper/FieldTypeLookupTests.java | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 3f50b9fdf6621..728096cb75d0a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -8,7 +8,9 @@ package org.elasticsearch.index.mapper; +import org.apache.lucene.search.Query; import org.elasticsearch.index.mapper.flattened.FlattenedFieldMapper; +import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -16,6 +18,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import static java.util.Collections.emptyList; @@ -35,6 +38,10 @@ public void testEmpty() { Collection names = lookup.getMatchingFieldNames("foo"); assertNotNull(names); assertThat(names, hasSize(0)); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertTrue(fieldsForModels.isEmpty()); } public void testAddNewField() { @@ -42,6 +49,10 @@ public void testAddNewField() { FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertTrue(fieldsForModels.isEmpty()); } public void testAddFieldAlias() { @@ -421,7 +432,51 @@ public void testRuntimeFieldNameOutsideContext() 
{ } } + public void testInferenceModelFieldType() { + MockFieldMapper f = new MockFieldMapper(new MockInferenceModelFieldType("foo", "bar")); + FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); + assertEquals(f.fieldType(), lookup.get("foo")); + assertEquals(Collections.emptySet(), lookup.getFieldsForModel("baz")); + assertEquals(Collections.singleton("foo"), lookup.getFieldsForModel("bar")); + + Map> fieldsForModels = lookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertEquals(1, fieldsForModels.size()); + assertEquals(Collections.singleton("foo"), fieldsForModels.get("bar")); + } + private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } + + private static class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private static final String TYPE_NAME = "mock_inference_model_field_type"; + + private final String modelId; + + MockInferenceModelFieldType(String name, String modelId) { + super(name, false, false, false, TextSearchInfo.NONE, Map.of()); + this.modelId = modelId; + } + + @Override + public String typeName() { + return TYPE_NAME; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException("termQuery not implemented"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.toString(name(), context, format); + } + + @Override + public String getInferenceModel() { + return modelId; + } + } } From 7be2f4b3b3c686a1e5b75b7c797a808d937383c3 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 13 Dec 2023 16:49:50 -0500 Subject: [PATCH 019/106] Fix spotless violations --- .../elasticsearch/cluster/metadata/IndexMetadataTests.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index a41d068385c7a..1925c869cdb81 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -554,9 +554,7 @@ public void testFieldsForModels() { assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1) - .fieldsForModels(fieldsForModels) - .build(); + IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); } From fa678a3c33505d25c9e94927933a917308e14630 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 14 Dec 2023 08:18:55 -0500 Subject: [PATCH 020/106] Added/updated MappingLookup tests --- .../index/mapper/FieldTypeLookupTests.java | 33 -------------- .../index/mapper/MappingLookupTests.java | 19 ++++++++ .../mapper/MockInferenceModelFieldType.java | 45 +++++++++++++++++++ 3 files changed, 64 insertions(+), 33 deletions(-) create mode 100644 test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java 
b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 728096cb75d0a..8db9c09f0d098 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -8,9 +8,7 @@ package org.elasticsearch.index.mapper; -import org.apache.lucene.search.Query; import org.elasticsearch.index.mapper.flattened.FlattenedFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -448,35 +446,4 @@ public void testInferenceModelFieldType() { private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } - - private static class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private static final String TYPE_NAME = "mock_inference_model_field_type"; - - private final String modelId; - - MockInferenceModelFieldType(String name, String modelId) { - super(name, false, false, false, TextSearchInfo.NONE, Map.of()); - this.modelId = modelId; - } - - @Override - public String typeName() { - return TYPE_NAME; - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented"); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); - } - - @Override - public String getInferenceModel() { - return modelId; - } - } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index 0308dac5fa216..f512f5d352a43 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -121,6 +122,8 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); + assertNotNull(mappingLookup.getFieldsForModels()); + assertTrue(mappingLookup.getFieldsForModels().isEmpty()); } public void testValidateDoesNotShadow() { @@ -188,6 +191,22 @@ public MetricType getMetricType() { ); } + public void testFieldsForModels() { + MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); + MappingLookup mappingLookup = createMappingLookup( + Collections.singletonList(new MockFieldMapper(fieldType)), + emptyList(), + emptyList() + ); + assertEquals(1, size(mappingLookup.fieldMappers())); + assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); + + Map> fieldsForModels = mappingLookup.getFieldsForModels(); + assertNotNull(fieldsForModels); + assertEquals(1, fieldsForModels.size()); + assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); + } + private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { try (TokenStream tok = 
analyzer.tokenStream(field, new StringReader(""))) { CharTermAttribute term = tok.addAttribute(CharTermAttribute.class); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java new file mode 100644 index 0000000000000..854749d6308db --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.index.query.SearchExecutionContext; + +import java.util.Map; + +public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private static final String TYPE_NAME = "mock_inference_model_field_type"; + + private final String modelId; + + public MockInferenceModelFieldType(String name, String modelId) { + super(name, false, false, false, TextSearchInfo.NONE, Map.of()); + this.modelId = modelId; + } + + @Override + public String typeName() { + return TYPE_NAME; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException("termQuery not implemented"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.toString(name(), context, format); + } + + @Override + public String getInferenceModel() { + return modelId; + } +} From 981ac8f89db131deae4c6406014ad7a61ca49eb5 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 14 Dec 2023 11:25:21 -0500 Subject: [PATCH 021/106] Delete docs/changelog/103319.yaml --- docs/changelog/103319.yaml | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 docs/changelog/103319.yaml diff --git a/docs/changelog/103319.yaml b/docs/changelog/103319.yaml deleted file mode 100644 index 3bfced8300e49..0000000000000 --- a/docs/changelog/103319.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 103319 -summary: Store `semantic_text` model info in mappings -area: Mapping -type: feature -issues: [] From f45af491aa5ec6b9aa1d552da32556b396db873f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Dec 2023 13:25:17 +0100 Subject: [PATCH 022/106] Refactored into separate methods --- .../action/bulk/TransportBulkAction.java | 81 +++++++++++-------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index b89b5e2de7924..be976af717cd3 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -598,6 +598,11 @@ protected void doRun() { if (handleBlockExceptions(clusterState)) { return; } + Map> requestsByShard = groupRequestsByShards(clusterState); + executeBulkRequestsByShard(requestsByShard, clusterState); + } + + private Map> groupRequestsByShards(ClusterState clusterState) { final ConcreteIndices concreteIndices = new ConcreteIndices(clusterState, 
indexNameExpressionResolver); Metadata metadata = clusterState.metadata(); // Group the requests by ShardId -> Operations mapping @@ -616,7 +621,7 @@ protected void doRun() { continue; } IndexAbstraction ia = null; - boolean includeDataStreams = docWriteRequest.opType() == DocWriteRequest.OpType.CREATE; + boolean includeDataStreams = docWriteRequest.opType() == OpType.CREATE; try { ia = concreteIndices.resolveIfAbsent(docWriteRequest); if (ia.isDataStreamRelated() && includeDataStreams == false) { @@ -629,7 +634,7 @@ protected void doRun() { // avoid valid cases when directly indexing into a backing index // (for example when directly indexing into .ds-logs-foobar-000001) ia.getName().equals(docWriteRequest.index()) == false - && docWriteRequest.opType() != DocWriteRequest.OpType.CREATE) { + && docWriteRequest.opType() != OpType.CREATE) { throw new IllegalArgumentException("only write ops with an op_type of create are allowed in data streams"); } @@ -658,7 +663,10 @@ protected void doRun() { bulkRequest.requests.set(i, null); } } + return requestsByShard; + } + private void executeBulkRequestsByShard(Map> requestsByShard, ClusterState clusterState) { if (requestsByShard.isEmpty()) { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -682,44 +690,49 @@ protected void doRun() { if (task != null) { bulkShardRequest.setParentTask(nodeId, task.getId()); } - client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { - @Override - public void onResponse(BulkShardResponse bulkShardResponse) { - for (BulkItemResponse bulkItemResponse : bulkShardResponse.getResponses()) { - // we may have no response if item failed - if (bulkItemResponse.getResponse() != null) { - bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo()); - } - responses.set(bulkItemResponse.getItemId(), bulkItemResponse); + + executeBulkShardRequest(bulkShardRequest, requests, counter); + } + bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed + } + + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { + client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { + @Override + public void onResponse(BulkShardResponse bulkShardResponse) { + for (BulkItemResponse bulkItemResponse : bulkShardResponse.getResponses()) { + // we may have no response if item failed + if (bulkItemResponse.getResponse() != null) { + bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo()); } - maybeFinishHim(); + responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } + maybeFinishHim(); + } - @Override - public void onFailure(Exception e) { - // create failures for all relevant requests - for (BulkItemRequest request : requests) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - } - maybeFinishHim(); + @Override + public void onFailure(Exception e) { + // create failures for all relevant requests + for (BulkItemRequest request : requests) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new 
BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } + maybeFinishHim(); + } - private void maybeFinishHim() { - if (counter.decrementAndGet() == 0) { - listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) - ); - } + private void maybeFinishHim() { + if (counter.decrementAndGet() == 0) { + listener.onResponse( + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos) + ) + ); } - }); - } - bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed + } + }); } private boolean handleBlockExceptions(ClusterState state) { From 4a93be8732847c62f02c6cd201ded18584aee564 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Sat, 9 Dec 2023 00:10:26 +0100 Subject: [PATCH 023/106] Added fieldsForModels to IndexMetadata & MappingMetadata --- .../org/elasticsearch/TransportVersions.java | 1 + .../cluster/metadata/IndexMetadata.java | 83 +++++++++++++++++-- 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 625871d25734b..226d82a9d5b73 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -193,6 +193,7 @@ static TransportVersion def(int id) { public static final TransportVersion ENRICH_ELASTICSEARCH_VERSION_REMOVED = def(8_560_00_0); public static final TransportVersion NODE_STATS_REQUEST_SIMPLIFIED = def(8_561_00_0); public static final TransportVersion TEXT_EXPANSION_TOKEN_PRUNING_CONFIG_ADDED = def(8_562_00_0); + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_563_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 742b52365c8d7..48f7f152d4cb2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -79,6 +79,7 @@ import java.util.OptionalLong; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM; import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM; @@ -546,6 +547,8 @@ public Iterator> settings() { public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast"; + public static final String KEY_FIELDS_FOR_MODELS = "fields_for_models"; + public static final String INDEX_STATE_FILE_PREFIX = "state-"; static final TransportVersion SYSTEM_INDEX_FLAG_ADDED = TransportVersions.V_7_10_0; @@ -635,6 +638,8 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; + // Key: model ID, Value: Fields that use model + private final Map> fieldsForModels; private IndexMetadata( final Index index, @@ -680,7 +685,8 @@ private IndexMetadata( final IndexVersion indexCompatibilityVersion, @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, - @Nullable Long shardSizeInBytesForecast + @Nullable Long shardSizeInBytesForecast, + final Map> fieldsForModels ) { this.index = index; this.version = version; @@ -736,6 +742,8 @@ private IndexMetadata( this.writeLoadForecast = writeLoadForecast; this.shardSizeInBytesForecast = shardSizeInBytesForecast; assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards; + this.fieldsForModels = fieldsForModels; + assert fieldsForModels != null; } IndexMetadata withMappingMetadata(MappingMetadata mapping) { @@ -786,7 +794,8 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -844,7 +853,8 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -900,7 +910,8 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -956,7 +967,8 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1008,7 +1020,8 @@ public IndexMetadata withIncrementedVersion() { this.indexCompatibilityVersion, this.stats, this.writeLoadForecast, - this.shardSizeInBytesForecast + this.shardSizeInBytesForecast, + this.fieldsForModels ); } @@ -1212,6 +1225,10 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? 
OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid"; public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name"; public static final Setting INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY); @@ -1476,6 +1493,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; + private final Diff>> fieldsForModels; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1512,6 +1530,12 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; + fieldsForModels = DiffableUtils.diff( + before.fieldsForModels, + after.fieldsForModels, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1571,6 +1595,15 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = DiffableUtils.readJdkMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } else { + fieldsForModels = DiffableUtils.emptyDiff(); + } } @Override @@ -1606,6 +1639,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels.writeTo(out); + } } @Override @@ -1797,6 +1833,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; + private Map> fieldsForModels = Map.of(); public Builder(String index) { this.index = index; @@ -1828,6 +1865,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; + this.fieldsForModels = indexMetadata.fieldsForModels; } public Builder index(String index) { @@ -1909,6 +1947,11 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; + Map> fieldsForModels = mappingMd.getFieldsForModels(); + // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? + if (fieldsForModels != null) { + this.fieldsForModels = fieldsForModels; + } return this; } @@ -2057,6 +2100,12 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } + public Builder fieldsForModels(Map> fieldsForModels) { + // TODO: How to handle null value? Clear this.fieldsForModels? 
+ this.fieldsForModels = fieldsForModels; + return this; + } + public IndexMetadata build() { return build(false); } @@ -2251,7 +2300,8 @@ IndexMetadata build(boolean repair) { SETTING_INDEX_VERSION_COMPATIBILITY.get(settings), stats, indexWriteLoadForecast, - shardSizeInBytesForecast + shardSizeInBytesForecast, + fieldsForModels ); } @@ -2377,6 +2427,12 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } + // TODO: Need null check? + Map> fieldsForModels = indexMetadata.getFieldsForModels(); + if (fieldsForModels != null && !fieldsForModels.isEmpty()) { + builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); + } + builder.endObject(); } @@ -2454,6 +2510,19 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> fieldsForModels = parser.map(HashMap::new, XContentParser::list) + .entrySet() + .stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + v -> v.getValue().stream().map(Object::toString).collect(Collectors.toUnmodifiableSet()) + ) + ); + builder.fieldsForModels(fieldsForModels); + break; default: // assume it's custom index metadata builder.putCustom(currentFieldName, parser.mapStrings()); From 1b8226146e856fd06ea3eb203a94f5f8bda2ee05 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Dec 2023 13:46:33 +0100 Subject: [PATCH 024/106] Moved InferenceAction and result classes to server --- .../action/bulk/TransportBulkAction.java | 13 +++++++++++- .../action/inference}/InferenceAction.java | 8 +++---- .../results/LegacyTextEmbeddingResults.java | 2 +- .../results/NlpInferenceResults.java | 10 ++++++++- .../results/SparseEmbeddingResults.java | 12 +++++++++-- .../results/TextEmbeddingResults.java | 11 +++++++++- .../results/TextExpansionResults.java | 13 ++++++++++-- .../cluster/metadata/MappingMetadata.java | 21 ++++++++++++++++++- .../MlInferenceNamedXContentProvider.java | 2 +- .../core/ml/inference/results/NerResults.java | 1 + .../NlpClassificationInferenceResults.java | 1 + .../results/PyTorchPassThroughResults.java | 1 + .../QuestionAnsweringInferenceResults.java | 1 + .../results/TextEmbeddingResults.java | 1 + .../TextSimilarityInferenceResults.java | 1 + .../action/InferModelActionResponseTests.java | 2 +- .../results/InferenceResultsTestCase.java | 1 + .../results/TextExpansionResultsTests.java | 1 + .../mock/TestInferenceServiceExtension.java | 2 +- .../InferenceNamedWriteablesProvider.java | 6 +++--- .../xpack/inference/InferencePlugin.java | 2 +- .../action/TransportInferenceAction.java | 2 +- .../HuggingFaceElserResponseEntity.java | 2 +- .../HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../OpenAiEmbeddingsResponseEntity.java | 2 +- .../inference/rest/RestInferenceAction.java | 2 +- .../inference/services/ServiceUtils.java | 2 +- .../services/elser/ElserMlNodeService.java | 2 +- .../action/InferenceActionRequestTests.java | 2 +- .../action/InferenceActionResponseTests.java | 4 ++-- .../HuggingFaceElserResponseEntityTests.java | 2 +- ...gingFaceEmbeddingsResponseEntityTests.java | 2 +- .../OpenAiEmbeddingsResponseEntityTests.java | 2 +- .../LegacyTextEmbeddingResultsTests.java | 2 +- .../results/SparseEmbeddingResultsTests.java | 4 ++-- .../results/TextEmbeddingResultsTests.java | 2 +- .../TransportCoordinatedInferenceAction.java | 2 +- .../inference/nlp/TextExpansionProcessor.java | 2 +- .../ml/queries/TextExpansionQueryBuilder.java | 2 +- .../queries/WeightedTokensQueryBuilder.java | 2 +- 
.../nlp/TextExpansionProcessorTests.java | 2 +- .../TextExpansionQueryBuilderTests.java | 2 +- .../WeightedTokensQueryBuilderTests.java | 4 ++-- 43 files changed, 117 insertions(+), 45 deletions(-) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action => server/src/main/java/org/elasticsearch/action/inference}/InferenceAction.java (97%) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core => server/src/main/java/org/elasticsearch/action}/inference/results/LegacyTextEmbeddingResults.java (98%) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml => server/src/main/java/org/elasticsearch/action}/inference/results/NlpInferenceResults.java (86%) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core => server/src/main/java/org/elasticsearch/action}/inference/results/SparseEmbeddingResults.java (93%) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core => server/src/main/java/org/elasticsearch/action}/inference/results/TextEmbeddingResults.java (90%) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml => server/src/main/java/org/elasticsearch/action}/inference/results/TextExpansionResults.java (87%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index be976af717cd3..54ab93e35a250 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -691,11 +691,22 @@ private void executeBulkRequestsByShard(Map> requ bulkShardRequest.setParentTask(nodeId, task.getId()); } - executeBulkShardRequest(bulkShardRequest, requests, counter); + IndexMetadata indexMetadata = clusterState.metadata().index(shardId.getIndex()); + if (indexMetadata.getFieldsForModels().isEmpty()) { + executeBulkShardRequest(bulkShardRequest, requests, counter); + } else { + performInference(bulkShardRequest, requests, clusterState); + } + + } bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed } + private void performInference(BulkShardRequest bulkShardRequest, List requests, ClusterState clusterState) { + + } + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java similarity index 97% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java rename to server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java index 53391aca84622..f924f2537ec6b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.core.inference.action; +package org.elasticsearch.action.inference; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersions; @@ -13,6 +13,9 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.InferenceResults; @@ -24,9 +27,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java similarity index 98% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java rename to server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java index 8f03a75c61c11..3cbbf0c6863cc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.core.inference.results; +package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java similarity index 86% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java rename to server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java index 4efb719137c65..fee00551bb528 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java @@ -1,3 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -5,7 +13,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.core.ml.inference.results; +package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java similarity index 93% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java rename to server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java index 910ea5cab214d..85fecdae6f39c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java @@ -1,3 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -5,7 +13,7 @@ * 2.0. */ -package org.elasticsearch.xpack.core.inference.results; +package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -91,7 +99,7 @@ public List transformToLegacyFormat() { return embeddings.stream() .map( embedding -> new TextExpansionResults( - DEFAULT_RESULTS_FIELD, + InferenceConfig.DEFAULT_RESULTS_FIELD, embedding.tokens() .stream() .map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token, weightedToken.weight)) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java similarity index 90% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java rename to server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java index ace5974866038..49674f85b90df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java @@ -1,3 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -5,8 +13,9 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.core.inference.results; +package org.elasticsearch.action.inference.results; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java similarity index 87% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java rename to server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java index 45aa4d51e0ad6..e54089187e121 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java @@ -1,3 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -5,12 +13,13 @@ * 2.0. */ -package org.elasticsearch.xpack.core.ml.inference.results; +package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; @@ -87,7 +96,7 @@ public Object predictedValue() { } @Override - void doXContentBody(XContentBuilder builder, Params params) throws IOException { + void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(resultsField); for (var weightedToken : weightedTokens) { weightedToken.toXContent(builder, params); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index b629ab5d5f710..3db9ad884ceff 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -23,6 +23,7 @@ import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -42,10 +43,13 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; + private final Map> fieldsForModels; + public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -57,6 +61,7 @@ public 
MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -72,6 +77,7 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); + this.fieldsForModels = null; // TODO: Set fieldsForModels here } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -158,12 +164,19 @@ public String getSha256() { return source.getSha256(); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -176,19 +189,25 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; + if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired); + return Objects.hash(type, source, routingRequired, fieldsForModels); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); + } else { + fieldsForModels = Map.of(); + } } public static Diff readDiffFrom(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 00587936848f8..9f08667f85572 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -29,7 +29,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index 
b077c93c141a5..ba5ba3fcb7a5c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java index a49e81e40a7a6..cf6e29be1746c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index de49fb2252ad0..83d4f204cd174 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index e9e41ce963bec..9f2c8f2b77a70 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java index 526c2ec7b7aaa..a92c53378a5dc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java index b8b75e2bf7eb4..1fe69a16b41b3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 4d8035864729a..0d85b173be27b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java index bda9eed40659c..8db86bbb2e2af 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.TestIngestDocument; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java index 82487960dfe8f..84b14e4122d6e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.ingest.IngestDocument; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java index eee6f68c20ff7..e30294b50b780 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java @@ -25,7 +25,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index c632c568fea16..2cf7dabdbb2ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -14,9 +14,9 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 33d71c65ed643..c8b67a2cdc59d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -35,7 +35,7 @@ import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; 
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index db98aeccc556b..a80156f065e45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -19,7 +19,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java index 247537b9958d0..1008dd3118874 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index b74b03891034f..4259c139e98e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import 
org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 4926ba3f0ef6b..e000390bb8575 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -13,7 +13,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index beecf75da38ab..272544e6c1c06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1686cd32d4a6b..ddfa310a6a71b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -16,7 +16,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.net.URI; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 01fe828d723d2..de6710c67c9a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -21,7 +21,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; -import 
org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index aa540694ba564..e5d1cbcc7034e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 759411cec1212..8b79aa5f90f83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; @@ -26,7 +26,7 @@ import static org.elasticsearch.TransportVersions.INFERENCE_SERVICE_RESULTS_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_OPENAI_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED; -import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Response.transformToServiceResults; +import static org.elasticsearch.action.inference.InferenceAction.Response.transformToServiceResults; public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index bdb8e38fa8228..35cc3f5243ae5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentEOFException; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index 2b6e11fdfafa7..237d4145866bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 010e990a3ce80..910ee3eed00c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -10,7 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java index 605411343533f..ace4bb9c77729 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import 
org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 6f8fa0c453d09..8fecf975aedb2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -10,8 +10,8 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 09d9894d98853..1ece52cd247f4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 7442f1db0a662..57c325c30937a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -25,7 +25,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 6483b9d9b3da9..57709bdb33e42 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java 
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 24383e51b0ed2..f8b06d168464d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index a09bcadaacfc0..835fd611b5a7d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -27,7 +27,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults.WeightedToken; +import org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java index c94775b1785c9..22dab2e3801d2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 13f12f3cdc1e1..d7c1204249ff0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -33,7 +33,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java index 4d91c66de4b9e..4d05f02f946c4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java @@ -33,7 +33,7 @@ import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; @@ -41,7 +41,7 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults.WeightedToken; +import static org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import static org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.TOKENS_FIELD; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; From 39cbdffe9171c3fdd29a622aeff979f4a266e8c5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Dec 2023 19:08:00 +0100 Subject: [PATCH 025/106] Remove unneeded code for retrieving IndexMetadata.fieldsForModels --- .../cluster/metadata/IndexMetadata.java | 7 +------ .../cluster/metadata/MappingMetadata.java | 21 +------------------ 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 48f7f152d4cb2..7dc99d8bc527e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1947,11 +1947,6 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; - Map> fieldsForModels = mappingMd.getFieldsForModels(); - // TODO: Need to clear fieldsForModels if the version from MappingMetadata is null? 
- if (fieldsForModels != null) { - this.fieldsForModels = fieldsForModels; - } return this; } @@ -2429,7 +2424,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build // TODO: Need null check? Map> fieldsForModels = indexMetadata.getFieldsForModels(); - if (fieldsForModels != null && !fieldsForModels.isEmpty()) { + if (fieldsForModels != null && fieldsForModels.isEmpty() == false) { builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 3db9ad884ceff..b629ab5d5f710 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -23,7 +23,6 @@ import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -43,13 +42,10 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; - private final Map> fieldsForModels; - public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); - this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -61,7 +57,6 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = null; // TODO: Set fieldsForModels here } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -77,7 +72,6 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = null; // TODO: Set fieldsForModels here } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -164,19 +158,12 @@ public String getSha256() { return source.getSha256(); } - public Map> getFieldsForModels() { - return fieldsForModels; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); - } } @Override @@ -189,25 +176,19 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; - if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired, fieldsForModels); + return Objects.hash(type, source, routingRequired); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); - } 
else { - fieldsForModels = Map.of(); - } } public static Diff readDiffFrom(StreamInput in) throws IOException { From 0e493ee4696aae0f1b4bf81fef9a61c986025a0b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Dec 2023 19:09:01 +0100 Subject: [PATCH 026/106] Change inference result classes to server --- server/src/main/java/module-info.java | 3 +++ .../results/NlpInferenceResults.java | 19 ++++++------------- .../results/SparseEmbeddingResults.java | 12 ++---------- .../results/TextEmbeddingResults.java | 3 +-- .../results/TextExpansionResults.java | 13 +++---------- .../results/TextEmbeddingResults.java | 13 +++++++------ .../core/src/main/java/module-info.java | 1 - .../MlInferenceNamedXContentProvider.java | 2 +- .../ml/inference/results/FillMaskResults.java | 6 +++--- .../core/ml/inference/results/NerResults.java | 6 +++--- .../NlpClassificationInferenceResults.java | 6 +++--- .../results/PyTorchPassThroughResults.java | 6 +++--- .../QuestionAnsweringInferenceResults.java | 6 +++--- .../TextSimilarityInferenceResults.java | 6 +++--- .../action/InferModelActionResponseTests.java | 2 +- .../results/FillMaskResultsTests.java | 2 +- .../ml/inference/results/NerResultsTests.java | 2 +- .../PyTorchPassThroughResultsTests.java | 4 ++-- .../results/TextEmbeddingResultsTests.java | 5 +++-- .../results/TextEmbeddingResultsTests.java | 4 ++-- x-pack/plugin/ml/build.gradle | 1 + .../TransportCoordinatedInferenceAction.java | 2 +- .../inference/nlp/TextEmbeddingProcessor.java | 2 +- .../inference/nlp/TextExpansionProcessor.java | 2 +- .../ml/queries/TextExpansionQueryBuilder.java | 2 +- .../queries/WeightedTokensQueryBuilder.java | 2 +- .../TextEmbeddingQueryVectorBuilder.java | 2 +- .../nlp/TextExpansionProcessorTests.java | 2 +- .../TextExpansionQueryBuilderTests.java | 2 +- .../WeightedTokensQueryBuilderTests.java | 2 +- .../TextEmbeddingQueryVectorBuilderTests.java | 2 +- 31 files changed, 62 insertions(+), 80 deletions(-) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core => server/src/main/java/org/elasticsearch/action}/ml/inference/results/TextEmbeddingResults.java (85%) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 613e6868b8e9f..904876984ad91 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -135,6 +135,9 @@ exports org.elasticsearch.action.get; exports org.elasticsearch.action.index; exports org.elasticsearch.action.ingest; + exports org.elasticsearch.action.inference; + exports org.elasticsearch.action.inference.results; + exports org.elasticsearch.action.ml.inference.results; exports org.elasticsearch.action.resync; exports org.elasticsearch.action.search; exports org.elasticsearch.action.support; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java index fee00551bb528..98f13b6711a06 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java @@ -6,13 +6,6 @@ * Side Public License, v 1. */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -26,23 +19,23 @@ import java.util.Map; import java.util.Objects; -abstract class NlpInferenceResults implements InferenceResults { +public abstract class NlpInferenceResults implements InferenceResults { protected final boolean isTruncated; - NlpInferenceResults(boolean isTruncated) { + public NlpInferenceResults(boolean isTruncated) { this.isTruncated = isTruncated; } - NlpInferenceResults(StreamInput in) throws IOException { + public NlpInferenceResults(StreamInput in) throws IOException { this.isTruncated = in.readBoolean(); } - abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException; + protected abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException; - abstract void doWriteTo(StreamOutput out) throws IOException; + protected abstract void doWriteTo(StreamOutput out) throws IOException; - abstract void addMapFields(Map map); + protected abstract void addMapFields(Map map); public boolean isTruncated() { return isTruncated; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java index 85fecdae6f39c..0a90df280a6c6 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java @@ -6,13 +6,6 @@ * Side Public License, v 1. */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -25,7 +18,6 @@ import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; @@ -34,12 +26,12 @@ import java.util.Map; import java.util.stream.Collectors; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; public record SparseEmbeddingResults(List embeddings) implements InferenceServiceResults { public static final String NAME = "sparse_embedding_results"; public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString(); + public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public SparseEmbeddingResults(StreamInput in) throws IOException { this(in.readCollectionAsList(Embedding::new)); @@ -99,7 +91,7 @@ public List transformToLegacyFormat() { return embeddings.stream() .map( embedding -> new TextExpansionResults( - InferenceConfig.DEFAULT_RESULTS_FIELD, + DEFAULT_RESULTS_FIELD, embedding.tokens() .stream() .map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token, weightedToken.weight)) diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java index 49674f85b90df..fc6f30d0f661a 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java @@ -15,7 +15,6 @@ package org.elasticsearch.action.inference.results; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -91,7 +90,7 @@ public String getWriteableName() { public List transformToCoordinationFormat() { return embeddings.stream() .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) - .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .map(values -> new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) .toList(); } diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java index e54089187e121..f5d8f68d144e5 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java @@ -6,13 +6,6 @@ * Side Public License, v 1. */ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -96,7 +89,7 @@ public Object predictedValue() { } @Override - void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(resultsField); for (var weightedToken : weightedTokens) { weightedToken.toXContent(builder, params); @@ -119,13 +112,13 @@ public int hashCode() { } @Override - void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(resultsField); out.writeCollection(weightedTokens); } @Override - void addMapFields(Map map) { + public void addMapFields(Map map) { map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java similarity index 85% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java rename to server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java index a92c53378a5dc..632e03f69bcd7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
*/ -package org.elasticsearch.xpack.core.ml.inference.results; +package org.elasticsearch.action.ml.inference.results; import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; @@ -53,7 +54,7 @@ public float[] getInferenceAsFloat() { } @Override - void doXContentBody(XContentBuilder builder, Params params) throws IOException { + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, inference); } @@ -63,13 +64,13 @@ public String getWriteableName() { } @Override - void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeDoubleArray(inference); out.writeString(resultsField); } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { map.put(resultsField, inference); } diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index f747d07224454..38913adbd0f28 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -74,7 +74,6 @@ exports org.elasticsearch.xpack.core.ilm; exports org.elasticsearch.xpack.core.indexing; exports org.elasticsearch.xpack.core.inference.action; - exports org.elasticsearch.xpack.core.inference.results; exports org.elasticsearch.xpack.core.inference; exports org.elasticsearch.xpack.core.logstash; exports org.elasticsearch.xpack.core.ml.action; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 9f08667f85572..d65ceea656079 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -28,7 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java index 4fad9b535e4e1..ded62c3e12bfb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java @@ -40,7 +40,7 @@ public FillMaskResults(StreamInput in) throws IOException { } @Override - public void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { super.doWriteTo(out); out.writeString(predictedSequence); } @@ -50,7 +50,7 @@ public String getPredictedSequence() { 
} @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { super.addMapFields(map); map.put(resultsField + "_sequence", predictedSequence); } @@ -68,7 +68,7 @@ public String getWriteableName() { } @Override - public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { super.doXContentBody(builder, params); builder.field(resultsField + "_sequence", predictedSequence); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index ba5ba3fcb7a5c..b95642ecda753 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -47,7 +47,7 @@ public NerResults(StreamInput in) throws IOException { } @Override - void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, annotatedResult); builder.startArray("entities"); for (EntityGroup entity : entityGroups) { @@ -62,14 +62,14 @@ public String getWriteableName() { } @Override - void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeCollection(entityGroups); out.writeString(resultsField); out.writeString(annotatedResult); } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { map.put(resultsField, annotatedResult); map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList())); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java index cf6e29be1746c..8a7be15294548 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java @@ -60,7 +60,7 @@ public List getTopClasses() { } @Override - public void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(classificationLabel); out.writeCollection(topClasses); out.writeString(resultsField); @@ -99,7 +99,7 @@ public Object predictedValue() { } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { map.put(resultsField, classificationLabel); if (topClasses.isEmpty() == false) { map.put( @@ -118,7 +118,7 @@ public String getWriteableName() { } @Override - public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, classificationLabel); if (topClasses.size() > 0) { builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index 83d4f204cd174..550e2e73ae0d2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -41,7 +41,7 @@ public double[][] getInference() { } @Override - void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, inference); } @@ -51,7 +51,7 @@ public String getWriteableName() { } @Override - public void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeDoubleArray, inference); out.writeString(resultsField); } @@ -62,7 +62,7 @@ public String getResultsField() { } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { map.put(resultsField, inference); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index 9f2c8f2b77a70..87f781e41c8b0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -77,7 +77,7 @@ public List getTopClasses() { } @Override - public void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(answer); out.writeVInt(startOffset); out.writeVInt(endOffset); @@ -120,7 +120,7 @@ public String predictedValue() { } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { map.put(resultsField, answer); addSupportingFieldsToMap(map); } @@ -151,7 +151,7 @@ public String getWriteableName() { } @Override - public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, answer); builder.field(START_OFFSET.getPreferredName(), startOffset); builder.field(END_OFFSET.getPreferredName(), endOffset); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java index 1fe69a16b41b3..c390bb6c19323 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java @@ -35,7 +35,7 @@ public TextSimilarityInferenceResults(StreamInput in) throws IOException { } @Override - public void doWriteTo(StreamOutput out) throws IOException { + protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(resultsField); out.writeDouble(score); } @@ -65,7 +65,7 @@ public Double predictedValue() { } @Override - void addMapFields(Map map) { + protected void addMapFields(Map map) { 
map.put(resultsField, score); } @@ -82,7 +82,7 @@ public String getWriteableName() { } @Override - public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, score); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 0d85b173be27b..76b96135e897d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java index 432e05b9cc680..7ca86dbaf195a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -65,7 +65,7 @@ public void testAsMap() { assertThat(asMap.get(PREDICTION_PROBABILITY), equalTo(testInstance.getPredictionProbability())); assertThat(asMap.get(DEFAULT_RESULTS_FIELD + "_sequence"), equalTo(testInstance.getPredictedSequence())); List> resultList = (List>) asMap.get(DEFAULT_TOP_CLASSES_RESULTS_FIELD); - if (testInstance.isTruncated) { + if (testInstance.isTruncated()) { assertThat(asMap.get("is_truncated"), is(true)); } else { assertThat(asMap, not(hasKey("is_truncated"))); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java index 4be49807d27b0..252e0491314a8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -68,7 +68,7 @@ public void testAsMap() { } assertThat(resultList, hasSize(testInstance.getEntityGroups().size())); assertThat(asMap.get(testInstance.getResultsField()), equalTo(testInstance.getAnnotatedResult())); - if (testInstance.isTruncated) { + if (testInstance.isTruncated()) { assertThat(asMap.get("is_truncated"), is(true)); } else { assertThat(asMap, not(hasKey("is_truncated"))); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java index e6b38a08a75ba..898f7a9f39916 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java @@ -49,10 +49,10 @@ protected PyTorchPassThroughResults mutateInstance(PyTorchPassThroughResults ins public void testAsMap() { PyTorchPassThroughResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); - int size = testInstance.isTruncated ? 2 : 1; + int size = testInstance.isTruncated() ? 2 : 1; assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[][]) asMap.get(DEFAULT_RESULTS_FIELD)); - if (testInstance.isTruncated) { + if (testInstance.isTruncated()) { assertThat(asMap.get("is_truncated"), is(true)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java index fd3ac7f8c0d12..4f8f433c1d2cd 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.ingest.IngestDocument; @@ -46,10 +47,10 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) { public void testAsMap() { TextEmbeddingResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); - int size = testInstance.isTruncated ? 2 : 1; + int size = testInstance.isTruncated() ? 
2 : 1; assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[]) asMap.get(DEFAULT_RESULTS_FIELD), 1e-10); - if (testInstance.isTruncated) { + if (testInstance.isTruncated()) { assertThat(asMap.get("is_truncated"), is(true)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 1ece52cd247f4..55c5d01530e0b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -109,12 +109,12 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults( TextEmbeddingResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false ), - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults( TextEmbeddingResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index 22cdb752d1e8d..d4c53f3f58c31 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -73,6 +73,7 @@ esplugin.bundleSpec.exclude 'platform/licenses/**' } dependencies { + compileOnly project(":server") compileOnly project(':modules:lang-painless:spi') compileOnly project(path: xpackModule('core')) compileOnly project(path: xpackModule('autoscaling')) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 57c325c30937a..478b80e6e58be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.internal.Client; @@ -25,7 +26,6 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index 453b689d59cc0..7b98be7974b87 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 57709bdb33e42..6dd9249d44be0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index f8b06d168464d..91f57637ab092 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -27,7 +28,6 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index 835fd611b5a7d..ed57ef157978a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -17,6 +17,7 @@ import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import 
org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -27,7 +28,6 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java index bd0916065ec5f..a3570bebd8055 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -10,6 +10,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,7 +23,6 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java index 22dab2e3801d2..0380c5e37863b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index d7c1204249ff0..217475de84659 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import 
org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; @@ -33,7 +34,6 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java index 4d05f02f946c4..14880a5cf85e7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; @@ -33,7 +34,6 @@ import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 8575c7e1f4bf3..8bbc3351e3f33 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.SearchPlugin; @@ -17,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.ml.MachineLearningTests; import java.io.IOException; From 90457b2f54b6a2e74d25555574cc4aaa3e09ceb1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 11:02:25 +0100 Subject: [PATCH 027/106] First version of TransportBulkAction --- .../action/bulk/TransportBulkAction.java | 102 +++++++++++++++++- 1 file changed, 97 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java 
b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 54ab93e35a250..4d16e04a5b29e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -26,9 +26,11 @@ import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.ingest.IngestActionForwarder; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.WriteResponse; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.update.UpdateRequest; @@ -60,6 +62,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -76,6 +80,7 @@ import java.util.Objects; import java.util.Set; import java.util.SortedMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; @@ -106,6 +111,10 @@ public class TransportBulkAction extends HandledTransportAction> requ for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); + BulkShardRequest bulkShardRequest = new BulkShardRequest( shardId, bulkRequest.getRefreshPolicy(), @@ -692,19 +702,101 @@ private void executeBulkRequestsByShard(Map> requ } IndexMetadata indexMetadata = clusterState.metadata().index(shardId.getIndex()); - if (indexMetadata.getFieldsForModels().isEmpty()) { - executeBulkShardRequest(bulkShardRequest, requests, counter); + if (indexMetadata.getFieldsForModels().isEmpty() == false) { + performInference(requests, indexMetadata.getModelsForFields(), () -> executeBulkShardRequest(bulkShardRequest, requests, counter)); } else { - performInference(bulkShardRequest, requests, clusterState); + executeBulkShardRequest(bulkShardRequest, requests, counter); } - } bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed } - private void performInference(BulkShardRequest bulkShardRequest, List requests, ClusterState clusterState) { + private void performInference(List requests, Map modelsForFields, Runnable onCompletion) { + + // TODO Should we create a specific ThreadPool? 
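+        // Descriptive note on the block below: the outer RefCountingRunnable defers onCompletion (the actual
+        // shard-level execution) until inference for every bulk item has finished. For each item, the source
+        // is copied into a concurrent map, the fields that are mapped to an inference model are collected,
+        // and one InferenceAction request is issued per model-backed string field (batching by model id is
+        // left as a TODO). Each response is written back into the copied source under
+        // ROOT_RESULT_FIELD + "." + field + "." + INFERENCE_FIELD, with the original text kept under the
+        // corresponding TEXT_FIELD key; failures are recorded as item-level failures and the failed item is
+        // cleared from the bulk request so it is not processed again.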
+ try (var bulkItemReqRef = new RefCountingRunnable(() -> { + onCompletion.run(); + })) { + + for (BulkItemRequest request : requests) { + DocWriteRequest docWriteRequest = request.request(); + Map sourceMap = null; + if (docWriteRequest instanceof IndexRequest indexRequest) { + sourceMap = indexRequest.sourceAsMap(); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + sourceMap = updateRequest.doc().sourceAsMap(); + } + if (sourceMap == null || sourceMap.isEmpty()) { + continue; + } + bulkItemReqRef.acquire(); + final Map docMap = new ConcurrentHashMap<>(sourceMap); + + Set inferenceFields = new HashSet<>(sourceMap.keySet()); + inferenceFields.retainAll(modelsForFields.keySet()); + + // When a document completes processing, update the source with the inference + try (var docRef = new RefCountingRunnable(() -> { + if (docWriteRequest instanceof IndexRequest indexRequest) { + indexRequest.source(docMap); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + updateRequest.doc().source(docMap); + } + })) { + + for (String inferenceField : inferenceFields) { + Object fieldValue = sourceMap.get(inferenceField); + if (fieldValue instanceof String == false) { + continue; + } + + docRef.acquire(); + + // TODO batch by model id, and multiple docs + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified + modelsForFields.get(inferenceField), + List.of((String) fieldValue), + Map.of() + ); + + client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { + @Override + public void onResponse(InferenceAction.Response response) { + // Transform into two subfields, one with the actual text and other with the inference + InferenceServiceResults results = response.getResults(); + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for field " + inferenceField + " in document " + docWriteRequest.id() + ); + } + + Object inferenceResult = results.transformToLegacyFormat().get(0).predictedValue(); + + docMap.put(ROOT_RESULT_FIELD + "." + inferenceField + "." + INFERENCE_FIELD, inferenceResult); + docMap.put(ROOT_RESULT_FIELD + "." + inferenceField + "." 
+ TEXT_FIELD, fieldValue); + + docRef.close(); + } + + @Override + public void onFailure(Exception e) { + + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + // make sure the request gets never processed again + bulkRequest.requests.set(request.id(), null); + docRef.close(); + } + }); + } + } + } + } } private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { From 41a527420e90eea0fbb70d1c0482380d380a3aa3 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 16:54:18 +0100 Subject: [PATCH 028/106] Working version, no threading yet --- .../action/bulk/TransportBulkAction.java | 112 ++++++++++++++---- 1 file changed, 92 insertions(+), 20 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 4d16e04a5b29e..02a71a30534f5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -62,6 +62,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.ingest.IngestService; @@ -72,6 +73,7 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -111,7 +113,7 @@ public class TransportBulkAction extends HandledTransportAction> requ IndexMetadata indexMetadata = clusterState.metadata().index(shardId.getIndex()); if (indexMetadata.getFieldsForModels().isEmpty() == false) { - performInference(requests, indexMetadata.getModelsForFields(), () -> executeBulkShardRequest(bulkShardRequest, requests, counter)); + performInference( + bulkShardRequest, + indexMetadata.getFieldsForModels(), + () -> { + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + executeBulkShardRequest(errorsFilteredShardRequest, requests, counter); + } + ); } else { executeBulkShardRequest(bulkShardRequest, requests, counter); } - } - bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed } - private void performInference(List requests, Map modelsForFields, Runnable onCompletion) { + private void performInference(BulkShardRequest bulkShardRequest, Map> fieldsForModels, Runnable onCompletion) { // TODO Should we create a specific ThreadPool? 
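+        // Descriptive note on this revision (not part of the original patch): fields are grouped by
+        // model id, a single inference request per model carries every field value that still needs
+        // inference, and the responses are mapped back to their fields by position.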
try (var bulkItemReqRef = new RefCountingRunnable(() -> { onCompletion.run(); })) { - for (BulkItemRequest request : requests) { + for (BulkItemRequest request : bulkShardRequest.items()) { DocWriteRequest docWriteRequest = request.request(); Map sourceMap = null; if (docWriteRequest instanceof IndexRequest indexRequest) { @@ -733,9 +744,6 @@ private void performInference(List requests, Map docMap = new ConcurrentHashMap<>(sourceMap); - Set inferenceFields = new HashSet<>(sourceMap.keySet()); - inferenceFields.retainAll(modelsForFields.keySet()); - // When a document completes processing, update the source with the inference try (var docRef = new RefCountingRunnable(() -> { if (docWriteRequest instanceof IndexRequest indexRequest) { @@ -743,11 +751,34 @@ private void performInference(List requests, Map> fieldModelsEntrySet : fieldsForModels.entrySet()) { + String modelId = fieldModelsEntrySet.getKey(); + + @SuppressWarnings("unchecked") + Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( + ROOT_RESULT_FIELD, + k -> new HashMap() + ); + + List inferenceFieldNames = new ArrayList<>(); + for (String inferenceField : fieldModelsEntrySet.getValue()) { + Object fieldValue = docMap.get(inferenceField); + + // Perform inference on string, non-null values + if (fieldValue instanceof String fieldStringValue) { + + // Only do inference if the previous text value doesn't match the new one + String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, inferenceField, TEXT_FIELD); + if (fieldStringValue.equals(previousValue) == false) { + inferenceFieldNames.add(inferenceField); + } + } + } + + if (inferenceFieldNames.isEmpty()) { continue; } @@ -756,8 +787,8 @@ private void performInference(List requests, Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() + ); - docMap.put(ROOT_RESULT_FIELD + "." + inferenceField + "." + INFERENCE_FIELD, inferenceResult); - docMap.put(ROOT_RESULT_FIELD + "." + inferenceField + "." + TEXT_FIELD, fieldValue); + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); + } docRef.close(); } @@ -785,10 +824,11 @@ public void onFailure(Exception e) { final String indexName = request.index(); DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e)); responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); // make sure the request gets never processed again - bulkRequest.requests.set(request.id(), null); + bulkShardRequest.items()[request.id()] = null; docRef.close(); } @@ -799,7 +839,37 @@ public void onFailure(Exception e) { } } + @SuppressWarnings("unchecked") + private static String findMapValue(Map map, String... path) { + Map currentMap = map; + for (int i = 0; i < path.length - 1; i++) { + Object value = currentMap.get(path[i]); + + if (value instanceof Map) { + currentMap = (Map) value; + } else { + // Invalid path or non-Map value encountered + return null; + } + } + + // Retrieve the final value in the map, if it's a String + Object finalValue = currentMap.get(path[path.length - 1]); + + return (finalValue instanceof String) ? 
(String) finalValue : null; + } + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { + if (bulkShardRequest.items().length == 0) { + // No requests to execute due to previous errors, terminate early + listener.onResponse( + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos) + ) + ); + return; + } client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -833,6 +903,8 @@ private void maybeFinishHim() { buildTookInMillis(startTimeNanos) ) ); + // Allow memory for bulk shard request items to be reclaimed before all items have been completed + bulkRequest = null; } } }); From 49793f13b7fe16ccf5fea06decaebcd6d3f22eed Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 17:26:10 +0100 Subject: [PATCH 029/106] Refactoring - used RefCountingRunnable --- .../action/bulk/TransportBulkAction.java | 104 +++++++++--------- 1 file changed, 53 insertions(+), 51 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 02a71a30534f5..e4b13ea3452ce 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -685,49 +685,62 @@ private void executeBulkRequestsByShard(Map> requ return; } - final AtomicInteger counter = new AtomicInteger(requestsByShard.size()); String nodeId = clusterService.localNode().getId(); - for (Map.Entry> entry : requestsByShard.entrySet()) { - final ShardId shardId = entry.getKey(); - final List requests = entry.getValue(); - - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); - } - - IndexMetadata indexMetadata = clusterState.metadata().index(shardId.getIndex()); - if (indexMetadata.getFieldsForModels().isEmpty() == false) { - performInference( - bulkShardRequest, - indexMetadata.getFieldsForModels(), - () -> { - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) - ); - executeBulkShardRequest(errorsFilteredShardRequest, requests, counter); - } + try(RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(() -> sendResponse())) { + for (Map.Entry> entry : requestsByShard.entrySet()) { + final ShardId shardId = entry.getKey(); + final List requests = entry.getValue(); + + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) ); - } else { - executeBulkShardRequest(bulkShardRequest, requests, counter); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, 
task.getId()); + } + + performInferenceAndExecute(bulkShardRequest, clusterState, bulkItemRequestCompleteRefCount.acquire()); } } } - private void performInference(BulkShardRequest bulkShardRequest, Map> fieldsForModels, Runnable onCompletion) { + private void sendResponse() { + listener.onResponse( + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos) + ) + ); + // Allow memory for bulk shard request items to be reclaimed before all items have been completed + bulkRequest = null; + } + + private void performInferenceAndExecute( + BulkShardRequest bulkShardRequest, + ClusterState clusterState, + Releasable releaseOnFinish + ) { + Map> fieldsForModels = clusterState.metadata() + .index(bulkShardRequest.shardId().getIndex()) + .getFieldsForModels(); + if (fieldsForModels.isEmpty()) { + executeBulkShardRequest(bulkShardRequest, releaseOnFinish); + } + + AtomicArray newBulkItemsRequest = new AtomicArray<>(bulkShardRequest.items().length); // TODO Should we create a specific ThreadPool? try (var bulkItemReqRef = new RefCountingRunnable(() -> { - onCompletion.run(); + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkRequest.getRefreshPolicy(), + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish); })) { for (BulkItemRequest request : bulkShardRequest.items()) { @@ -859,7 +872,7 @@ private static String findMapValue(Map map, String... path) { return (finalValue instanceof String) ? (String) finalValue : null; } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early listener.onResponse( @@ -868,8 +881,10 @@ private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -880,32 +895,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - maybeFinishHim(); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - for (BulkItemRequest request : requests) { + for (BulkItemRequest request : bulkShardRequest.items()) { final String indexName = request.index(); DocWriteRequest docWriteRequest = request.request(); BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - maybeFinishHim(); - } - - private void maybeFinishHim() { - if (counter.decrementAndGet() == 0) { - listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) - ); - // Allow memory for bulk shard request items to be reclaimed before all items have been completed - bulkRequest = null; - } + releaseOnFinish.close(); } }); } From 1d1893667fc89ae76450d5cc6d5277f3b6c46290 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 17:56:19 +0100 Subject: [PATCH 030/106] More refactoring --- .../action/bulk/TransportBulkAction.java | 229 ++++++++++-------- 1 file 
changed, 124 insertions(+), 105 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index e4b13ea3452ce..ca1dec311e5a7 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -686,7 +686,16 @@ private void executeBulkRequestsByShard(Map> requ } String nodeId = clusterService.localNode().getId(); - try(RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(() -> sendResponse())) { + try(RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(() -> { + listener.onResponse( + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos) + ) + ); + // Allow memory for bulk shard request items to be reclaimed before all items have been completed + bulkRequest = null; + })) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); @@ -708,17 +717,6 @@ private void executeBulkRequestsByShard(Map> requ } } - private void sendResponse() { - listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) - ); - // Allow memory for bulk shard request items to be reclaimed before all items have been completed - bulkRequest = null; - } - private void performInferenceAndExecute( BulkShardRequest bulkShardRequest, ClusterState clusterState, @@ -732,7 +730,6 @@ private void performInferenceAndExecute( executeBulkShardRequest(bulkShardRequest, releaseOnFinish); } - AtomicArray newBulkItemsRequest = new AtomicArray<>(bulkShardRequest.items().length); // TODO Should we create a specific ThreadPool? try (var bulkItemReqRef = new RefCountingRunnable(() -> { BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( @@ -744,112 +741,134 @@ private void performInferenceAndExecute( })) { for (BulkItemRequest request : bulkShardRequest.items()) { - DocWriteRequest docWriteRequest = request.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.doc().sourceAsMap(); + performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef); + } + } + } + + private void performInferenceOnBulkItemRequest( + BulkShardRequest bulkShardRequest, + BulkItemRequest request, + Map> fieldsForModels, + RefCountingRunnable bulkItemReqRef + ) { + DocWriteRequest docWriteRequest = request.request(); + Map sourceMap = null; + if (docWriteRequest instanceof IndexRequest indexRequest) { + sourceMap = indexRequest.sourceAsMap(); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + sourceMap = updateRequest.docAsUpsert() + ? 
updateRequest.upsertRequest().sourceAsMap() + : updateRequest.doc().sourceAsMap(); + } + if (sourceMap == null || sourceMap.isEmpty()) { + return; + } + bulkItemReqRef.acquire(); + final Map docMap = new ConcurrentHashMap<>(sourceMap); + + // When a document completes processing, update the source with the inference + try (var docRef = new RefCountingRunnable(() -> { + if (docWriteRequest instanceof IndexRequest indexRequest) { + indexRequest.source(docMap); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + if (updateRequest.docAsUpsert()) { + updateRequest.upsertRequest().source(docMap); + } else { + updateRequest.doc().source(docMap); } - if (sourceMap == null || sourceMap.isEmpty()) { + } + bulkItemReqRef.close(); + })) { + + for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { + String modelId = fieldModelsEntrySet.getKey(); + + @SuppressWarnings("unchecked") + Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( + ROOT_RESULT_FIELD, + k -> new HashMap() + ); + + List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); + + if (inferenceFieldNames.isEmpty()) { continue; } - bulkItemReqRef.acquire(); - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - updateRequest.doc().source(docMap); - } - bulkItemReqRef.close(); - })) { - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - - @SuppressWarnings("unchecked") - Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_RESULT_FIELD, - k -> new HashMap() - ); - - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String fieldStringValue) { - - // Only do inference if the previous text value doesn't match the new one - String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, inferenceField, TEXT_FIELD); - if (fieldStringValue.equals(previousValue) == false) { - inferenceFieldNames.add(inferenceField); - } - } + + docRef.acquire(); + + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified + modelId, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + Map.of() + ); + + client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { + @Override + public void onResponse(InferenceAction.Response response) { + // Transform into two subfields, one with the actual text and other with the inference + InferenceServiceResults results = response.getResults(); + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ); } - if (inferenceFieldNames.isEmpty()) { - continue; + int i = 0; + for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() + 
); + + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); } - docRef.acquire(); - - // TODO batch by model id, and multiple docs - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified - modelId, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - Map.of() - ); - - client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { - @Override - public void onResponse(InferenceAction.Response response) { - // Transform into two subfields, one with the actual text and other with the inference - InferenceServiceResults results = response.getResults(); - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() - ); - - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); - } - - docRef.close(); - } + docRef.close(); + } - @Override - public void onFailure(Exception e) { + @Override + public void onFailure(Exception e) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e)); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e)); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; - docRef.close(); - } - }); + docRef.close(); } + }); + } + } + } + + private static List getFieldNamesForInference( + Map.Entry> fieldModelsEntrySet, + Map docMap + ) { + List inferenceFieldNames = new ArrayList<>(); + for (String inferenceField : fieldModelsEntrySet.getValue()) { + Object fieldValue = docMap.get(inferenceField); + + // Perform inference on string, non-null values + if (fieldValue instanceof String fieldStringValue) { + + // Only do inference if the previous text value doesn't match the new one + String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, inferenceField, TEXT_FIELD); + if (fieldStringValue.equals(previousValue) == false) { + inferenceFieldNames.add(inferenceField); } } } + return inferenceFieldNames; } @SuppressWarnings("unchecked") From 89249ec6c458c8d17a5b3276b1d294db88655473 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 18:28:06 +0100 Subject: [PATCH 031/106] License headers --- 
.../org/elasticsearch/action/inference/InferenceAction.java | 5 +++-- .../action/inference/results/LegacyTextEmbeddingResults.java | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java index f924f2537ec6b..7d472b0d47813 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java +++ b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java @@ -1,8 +1,9 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ package org.elasticsearch.action.inference; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java index 3cbbf0c6863cc..da3e34c79e9fd 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java @@ -1,8 +1,9 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ package org.elasticsearch.action.inference.results; From 4802b2398d7d6e9b2db0c1342602f4d0df574bf4 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 18:39:40 +0100 Subject: [PATCH 032/106] License headers --- .../action/inference/results/SparseEmbeddingResults.java | 1 - .../action/inference/results/TextEmbeddingResults.java | 7 ------- 2 files changed, 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java index 0a90df280a6c6..4dc9c8cb88a77 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java @@ -26,7 +26,6 @@ import java.util.Map; import java.util.stream.Collectors; - public record SparseEmbeddingResults(List embeddings) implements InferenceServiceResults { public static final String NAME = "sparse_embedding_results"; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java index fc6f30d0f661a..6597c8d36a92c 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java @@ -6,13 +6,6 @@ * Side Public License, v 1. */ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; From 9bb66e18b942e7f4f30b07f55b0a10b8dadab63b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 18:39:55 +0100 Subject: [PATCH 033/106] More refactoring around runnables --- .../action/bulk/TransportBulkAction.java | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index ca1dec311e5a7..9662c4ca0fd2a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -644,8 +644,7 @@ private Map> groupRequestsByShards(ClusterState c if (ia.getParentDataStream() != null && // avoid valid cases when directly indexing into a backing index // (for example when directly indexing into .ds-logs-foobar-000001) - ia.getName().equals(docWriteRequest.index()) == false - && docWriteRequest.opType() != OpType.CREATE) { + ia.getName().equals(docWriteRequest.index()) == false && docWriteRequest.opType() != OpType.CREATE) { throw new IllegalArgumentException("only write ops with an op_type of create are allowed in data streams"); } @@ -686,16 +685,15 @@ private void executeBulkRequestsByShard(Map> requ } String nodeId = clusterService.localNode().getId(); - try(RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(() -> { + Runnable onBulkItemsComplete = () -> { listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) + new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) ); // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; - })) { + }; + + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); @@ -717,11 +715,7 @@ private void executeBulkRequestsByShard(Map> requ } } - private void performInferenceAndExecute( - BulkShardRequest bulkShardRequest, - ClusterState clusterState, - Releasable releaseOnFinish - ) { + private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) { Map> fieldsForModels = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) @@ -730,18 +724,18 @@ private void performInferenceAndExecute( executeBulkShardRequest(bulkShardRequest, releaseOnFinish); } - // TODO Should we create a specific ThreadPool? 
- try (var bulkItemReqRef = new RefCountingRunnable(() -> { + Runnable onInferenceComplete = () -> { BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( bulkShardRequest.shardId(), bulkRequest.getRefreshPolicy(), Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) ); executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish); - })) { + }; + try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { for (BulkItemRequest request : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef); + performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire()); } } } @@ -750,21 +744,19 @@ private void performInferenceOnBulkItemRequest( BulkShardRequest bulkShardRequest, BulkItemRequest request, Map> fieldsForModels, - RefCountingRunnable bulkItemReqRef + Releasable releaseOnFinish ) { DocWriteRequest docWriteRequest = request.request(); Map sourceMap = null; if (docWriteRequest instanceof IndexRequest indexRequest) { sourceMap = indexRequest.sourceAsMap(); } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() - ? updateRequest.upsertRequest().sourceAsMap() - : updateRequest.doc().sourceAsMap(); + sourceMap = updateRequest.docAsUpsert() ? updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); } if (sourceMap == null || sourceMap.isEmpty()) { + releaseOnFinish.close(); return; } - bulkItemReqRef.acquire(); final Map docMap = new ConcurrentHashMap<>(sourceMap); // When a document completes processing, update the source with the inference @@ -778,7 +770,7 @@ private void performInferenceOnBulkItemRequest( updateRequest.doc().source(docMap); } } - bulkItemReqRef.close(); + releaseOnFinish.close(); })) { for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { @@ -837,8 +829,11 @@ public void onFailure(Exception e) { final String indexName = request.index(); DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e)); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure( + indexName, + docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) + ); responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); // make sure the request gets never processed again bulkShardRequest.items()[request.id()] = null; @@ -895,10 +890,7 @@ private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasab if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) + new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) ); releaseOnFinish.close(); return; From 7f28aba480a3532b05be8fc1e85546df05580caa Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 18:58:09 +0100 Subject: [PATCH 034/106] Spotless --- .../xpack/inference/mock/TestInferenceServiceExtension.java | 2 +- .../xpack/inference/InferenceNamedWriteablesProvider.java | 6 +++--- 
.../org/elasticsearch/xpack/inference/InferencePlugin.java | 2 +- .../xpack/inference/action/TransportInferenceAction.java | 2 +- .../huggingface/HuggingFaceElserResponseEntity.java | 2 +- .../huggingface/HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../response/openai/OpenAiEmbeddingsResponseEntity.java | 2 +- .../xpack/inference/rest/RestInferenceAction.java | 2 +- .../xpack/inference/services/ServiceUtils.java | 2 +- .../xpack/inference/services/elser/ElserMlNodeService.java | 2 +- .../xpack/inference/action/InferenceActionRequestTests.java | 2 +- .../inference/action/InferenceActionResponseTests.java | 2 +- .../huggingface/HuggingFaceElserResponseEntityTests.java | 2 +- .../HuggingFaceEmbeddingsResponseEntityTests.java | 2 +- .../openai/OpenAiEmbeddingsResponseEntityTests.java | 2 +- .../inference/results/LegacyTextEmbeddingResultsTests.java | 2 +- .../inference/results/SparseEmbeddingResultsTests.java | 4 ++-- .../xpack/inference/results/TextEmbeddingResultsTests.java | 2 +- 18 files changed, 21 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java index e30294b50b780..0a56cb7d7d7b2 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -25,7 +26,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 2cf7dabdbb2ab..a39f6378af7ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.inference; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; @@ -14,9 +17,6 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; -import 
org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index c8b67a2cdc59d..b3726abc48591 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -35,7 +36,6 @@ import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index a80156f065e45..cd60e3017c93c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.inject.Inject; @@ -19,7 +20,6 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java index 1008dd3118874..f3836d4d7528f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index 4259c139e98e0..69f156509b803 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index e000390bb8575..70d0031730c48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.openai; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index 272544e6c1c06..761c0b1f069a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.inference.rest; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.action.inference.InferenceAction; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index ddfa310a6a71b..7dfaff5dc8a9c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Strings; @@ -16,7 +17,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.net.URI; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index de6710c67c9a1..2da6669f05c4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; @@ -21,7 +22,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index e5d1cbcc7034e..205ab2a9688dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -7,12 +7,12 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.action.inference.InferenceAction; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 8b79aa5f90f83..3b6be0f5aa9b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index 35cc3f5243ae5..bcb4468d92d12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -8,11 +8,11 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; import org.apache.http.HttpResponse; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentEOFException; import org.elasticsearch.xcontent.XContentParseException; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index 237d4145866bc..54a469cfdb517 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; import org.apache.http.HttpResponse; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 910ee3eed00c8..6bba5f37f4afd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.openai; import org.apache.http.HttpResponse; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java index ace4bb9c77729..1ad94598e2ef9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.results; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -14,7 +15,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 8fecf975aedb2..0f725a5c91e8b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.inference.results; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 55c5d01530e0b..acd8697de450b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.inference.results; +import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; From a0a7b58d213f5ccfebb5601e3af0b7c58c19b402 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 19:11:00 +0100 Subject: [PATCH 035/106] Add comments --- .../org/elasticsearch/action/bulk/TransportBulkAction.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 9662c4ca0fd2a..39cc586118eb5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -720,11 +720,14 @@ private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, Clust Map> fieldsForModels = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) .getFieldsForModels(); + // No inference fields? 
Just execute the request if (fieldsForModels.isEmpty()) { executeBulkShardRequest(bulkShardRequest, releaseOnFinish); } Runnable onInferenceComplete = () -> { + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( bulkShardRequest.shardId(), bulkRequest.getRefreshPolicy(), From 607f005ab1d179e8ff5a5db707798ff725070781 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 22 Dec 2023 19:11:07 +0100 Subject: [PATCH 036/106] Spotless --- .../core/ml/inference/MlInferenceNamedXContentProvider.java | 4 ++-- .../xpack/core/ml/action/InferModelActionResponseTests.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index d65ceea656079..18cadafb9a6a5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.spi.NamedXContentProvider; @@ -28,8 +30,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 76b96135e897d..ca661da89dcb8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; @@ -25,9 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import 
org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; From 6d781bca19808748a6c849a87c93df2dc6c563e1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 8 Jan 2024 20:03:14 +0100 Subject: [PATCH 037/106] First working test version --- .../TransportBulkActionInferenceTests.java | 192 ++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java new file mode 100644 index 0000000000000..b00df3861a090 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -0,0 +1,192 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.inference.InferenceAction; +import org.elasticsearch.action.inference.results.SparseEmbeddingResults; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.AutoCreateIndex; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexingPressure; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.indices.TestIndexNameExpressionResolver; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.test.ClusterServiceUtils; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.CapturingTransport; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import 
org.junit.After; +import org.junit.Before; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportBulkActionInferenceTests extends ESTestCase { + + public static final String INDEX_NAME = "index"; + public static final String INFERENCE_FIELD = "inference_field"; + public static final String MODEL_ID = "model_id"; + private TransportService transportService; + private CapturingTransport capturingTransport; + private ClusterService clusterService; + private ThreadPool threadPool; + private NodeClient nodeClient; + private TransportBulkAction transportBulkAction; + + + @Before + public void setup() { + threadPool = new TestThreadPool(getClass().getName()); + + nodeClient = mock(NodeClient.class); + + DiscoveryNodes nodes = mock(DiscoveryNodes.class); + DiscoveryNode remoteNode = mock(DiscoveryNode.class); + Map ingestNodes = Map.of("node", remoteNode); + when(nodes.getIngestNodes()).thenReturn(ingestNodes); + Metadata metadata = Metadata.builder() + .indices( + Map.of( + INDEX_NAME, + IndexMetadata.builder(INDEX_NAME) + .settings(settings(IndexVersion.current())) + .fieldsForModels(Map.of(MODEL_ID, Set.of(INFERENCE_FIELD))) + .numberOfShards(1) + .numberOfReplicas(1) + .build() + )) + .build(); + + DiscoveryNode masterNode = DiscoveryNodeUtils.create(UUIDs.base64UUID()); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT) + .metadata(metadata) + .nodes(DiscoveryNodes.builder().add(masterNode).localNodeId(masterNode.getId()).masterNodeId(masterNode.getId())) + .build(); + + clusterService = ClusterServiceUtils.createClusterService(state, threadPool); + + capturingTransport = new CapturingTransport(); + transportService = capturingTransport.createTransportService( + clusterService.getSettings(), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> clusterService.localNode(), + null, + Collections.emptySet() + ); + transportService.start(); + transportService.acceptIncomingRequests(); + + IngestService ingestService = mock(IngestService.class); + transportBulkAction = new TransportBulkAction( + threadPool, + transportService, + clusterService, + ingestService, + nodeClient, + new ActionFilters(Collections.emptySet()), + TestIndexNameExpressionResolver.newInstance(), + new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), + EmptySystemIndices.INSTANCE + ); + } + + @After + public void tearDown() throws Exception { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + threadPool = null; + clusterService.close(); + super.tearDown(); + } + + + public void testBulkRequestWithInference() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); + indexRequest.source(INFERENCE_FIELD, "some text"); + bulkRequest.add(indexRequest); + + doAnswer(invocation -> { + + InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; + assertThat(request.getModelId(), equalTo(MODEL_ID)); + assertThat(request.getInput(), equalTo(List.of("some text"))); + + @SuppressWarnings("unchecked") + var listener = 
(ActionListener) invocation.getArguments()[2]; + listener.onResponse( + new InferenceAction.Response( + new SparseEmbeddingResults( + List.of( + new SparseEmbeddingResults.Embedding( + List.of( + new SparseEmbeddingResults.WeightedToken("some", 0.5f), + new SparseEmbeddingResults.WeightedToken("text", 0.5f) + ), false) + ) + ) + )); + return Void.TYPE; + } + ).when(nodeClient).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + ShardId shardId = new ShardId(INDEX_NAME, "UUID", 0); + BulkItemResponse successResponse = BulkItemResponse.success( + 0, + DocWriteRequest.OpType.INDEX, + new IndexResponse(shardId, "id", 0, 0, 0, true) + ); + listener.onResponse(new BulkShardResponse(shardId, new BulkItemResponse[] { successResponse })); + return null; + }).when(nodeClient).executeLocally(eq(TransportShardBulkAction.TYPE), any(BulkShardRequest.class), any()); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + + assertEquals(1, response.getItems().length); + verify(nodeClient).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + + } +} From b609472891af4f5c99bd244abee0236d6ffa41c6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Jan 2024 10:43:42 +0100 Subject: [PATCH 038/106] Multiple inference fields test --- .../TransportBulkActionInferenceTests.java | 149 ++++++++++++++---- 1 file changed, 114 insertions(+), 35 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index b00df3861a090..37889ea6a6739 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; @@ -43,6 +44,7 @@ import org.elasticsearch.transport.TransportService; import org.junit.After; import org.junit.Before; +import org.mockito.verification.VerificationMode; import java.util.Collections; import java.util.List; @@ -50,19 +52,26 @@ import java.util.Set; import java.util.concurrent.TimeUnit; -import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class TransportBulkActionInferenceTests extends ESTestCase { +public class TransportBulkActionInferenceTests extends ESTestCase { public static final String INDEX_NAME = "index"; - public static final String INFERENCE_FIELD = "inference_field"; - public static final String MODEL_ID = 
"model_id"; + public static final String INFERENCE_FIELD_1_MODEL_A = "inference_field_1_model_a"; + public static final String MODEL_A_ID = "model_a_id"; + private static final String INFERENCE_FIELD_2_MODEL_A = "inference_field_2_model_a"; + public static final String MODEL_B_ID = "model_b_id"; + private static final String INFERENCE_FIELD_MODEL_B = "inference_field_model_b"; private TransportService transportService; private CapturingTransport capturingTransport; private ClusterService clusterService; @@ -70,7 +79,6 @@ public class TransportBulkActionInferenceTests extends ESTestCase { private NodeClient nodeClient; private TransportBulkAction transportBulkAction; - @Before public void setup() { threadPool = new TestThreadPool(getClass().getName()); @@ -87,11 +95,19 @@ public void setup() { INDEX_NAME, IndexMetadata.builder(INDEX_NAME) .settings(settings(IndexVersion.current())) - .fieldsForModels(Map.of(MODEL_ID, Set.of(INFERENCE_FIELD))) + .fieldsForModels( + Map.of( + MODEL_A_ID, + Set.of(INFERENCE_FIELD_1_MODEL_A, INFERENCE_FIELD_2_MODEL_A), + MODEL_B_ID, + Set.of(INFERENCE_FIELD_MODEL_B) + ) + ) .numberOfShards(1) .numberOfReplicas(1) .build() - )) + ) + ) .build(); DiscoveryNode masterNode = DiscoveryNodeUtils.create(UUIDs.base64UUID()); @@ -136,37 +152,78 @@ public void tearDown() throws Exception { super.tearDown(); } + public void testBulkRequestWithoutInference() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); + indexRequest.source("non_inference_field", "text", "another_non_inference_field", "other text"); + bulkRequest.add(indexRequest); + + expectTransportShardBulkActionRequest(); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertEquals(1, response.getItems().length); + verifyInferenceExecuted(never()); + } public void testBulkRequestWithInference() { BulkRequest bulkRequest = new BulkRequest(); IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); - indexRequest.source(INFERENCE_FIELD, "some text"); + String inferenceFieldText = "some text"; + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldText, "non_inference_field", "other text"); bulkRequest.add(indexRequest); - doAnswer(invocation -> { + expectInferenceRequest(Map.of(MODEL_A_ID, Set.of(inferenceFieldText))); - InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; - assertThat(request.getModelId(), equalTo(MODEL_ID)); - assertThat(request.getInput(), equalTo(List.of("some text"))); + expectTransportShardBulkActionRequest(); - @SuppressWarnings("unchecked") - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse( - new InferenceAction.Response( - new SparseEmbeddingResults( - List.of( - new SparseEmbeddingResults.Embedding( - List.of( - new SparseEmbeddingResults.WeightedToken("some", 0.5f), - new SparseEmbeddingResults.WeightedToken("text", 0.5f) - ), false) - ) - ) - )); - return Void.TYPE; - } - ).when(nodeClient).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertEquals(1, response.getItems().length); + verifyInferenceExecuted(times(1)); + } + + public void testBulkRequestWithMultipleFieldsInference() { + 
BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); + String inferenceField1Text = "some text"; + String inferenceField2Text = "some other text"; + String inferenceField3Text = "more inference text"; + indexRequest.source( + INFERENCE_FIELD_1_MODEL_A, + inferenceField1Text, + INFERENCE_FIELD_2_MODEL_A, + inferenceField2Text, + INFERENCE_FIELD_MODEL_B, + inferenceField3Text, + "non_inference_field", + "other text" + ); + bulkRequest.add(indexRequest); + + expectInferenceRequest( + Map.of(MODEL_A_ID, Set.of(inferenceField1Text, inferenceField2Text), MODEL_B_ID, Set.of(inferenceField3Text)) + ); + expectTransportShardBulkActionRequest(); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertEquals(1, response.getItems().length); + verifyInferenceExecuted(times(2)); + } + + private void verifyInferenceExecuted(VerificationMode times) { + verify(nodeClient, times).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + } + + private void expectTransportShardBulkActionRequest() { doAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; @@ -179,14 +236,36 @@ public void testBulkRequestWithInference() { listener.onResponse(new BulkShardResponse(shardId, new BulkItemResponse[] { successResponse })); return null; }).when(nodeClient).executeLocally(eq(TransportShardBulkAction.TYPE), any(BulkShardRequest.class), any()); + } - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - + private void expectInferenceRequest(Map> modelsAndInferenceTextMap) { + doAnswer(invocation -> { + InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; + Set textsForModel = modelsAndInferenceTextMap.get(request.getModelId()); + assertThat("model is not expected", textsForModel, notNullValue()); + assertThat("unexpected inference field values", request.getInput(), containsInAnyOrder(textsForModel.toArray())); - assertEquals(1, response.getItems().length); - verify(nodeClient).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse( + new InferenceAction.Response( + new SparseEmbeddingResults( + request.getInput().stream() + .map( + text -> new SparseEmbeddingResults.Embedding( + List.of(new SparseEmbeddingResults.WeightedToken(text.toString(), 1.0f)), + false + ) + ) + .toList() + ) + ) + ); + return Void.TYPE; + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelsAndInferenceTextMap.keySet())), any()); + } + private boolean inferenceRequestMatches(ActionRequest request, Set models) { + return request instanceof InferenceAction.Request && models.contains(((InferenceAction.Request) request).getModelId()); } } From e4a6a678717cbabc8d79850871b460c70e1049b7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Jan 2024 12:02:21 +0100 Subject: [PATCH 039/106] Added test cases --- .../action/bulk/TransportBulkAction.java | 4 +- .../TransportBulkActionInferenceTests.java | 173 ++++++++++++------ 2 files changed, 119 insertions(+), 58 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 39cc586118eb5..39b45f92dc6da 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -723,6 +723,7 @@ private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, Clust // No inference fields? Just execute the request if (fieldsForModels.isEmpty()) { executeBulkShardRequest(bulkShardRequest, releaseOnFinish); + return; } Runnable onInferenceComplete = () -> { @@ -892,9 +893,6 @@ private static String findMapValue(Map map, String... path) { private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early - listener.onResponse( - new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) - ); releaseOnFinish.close(); return; } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 37889ea6a6739..71fd62df1620d 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -38,7 +38,6 @@ import org.elasticsearch.ingest.IngestService; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -46,14 +45,14 @@ import org.junit.Before; import org.mockito.verification.VerificationMode; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -72,8 +71,6 @@ public class TransportBulkActionInferenceTests extends ESTestCase { private static final String INFERENCE_FIELD_2_MODEL_A = "inference_field_2_model_a"; public static final String MODEL_B_ID = "model_b_id"; private static final String INFERENCE_FIELD_MODEL_B = "inference_field_model_b"; - private TransportService transportService; - private CapturingTransport capturingTransport; private ClusterService clusterService; private ThreadPool threadPool; private NodeClient nodeClient; @@ -82,13 +79,9 @@ public class TransportBulkActionInferenceTests extends ESTestCase { @Before public void setup() { threadPool = new TestThreadPool(getClass().getName()); - nodeClient = mock(NodeClient.class); - DiscoveryNodes nodes = mock(DiscoveryNodes.class); - DiscoveryNode remoteNode = mock(DiscoveryNode.class); - Map ingestNodes = Map.of("node", remoteNode); - when(nodes.getIngestNodes()).thenReturn(ingestNodes); + // Contains the fields for models for the index Metadata metadata = Metadata.builder() .indices( Map.of( @@ -118,30 +111,31 @@ public void 
setup() { clusterService = ClusterServiceUtils.createClusterService(state, threadPool); - capturingTransport = new CapturingTransport(); - transportService = capturingTransport.createTransportService( - clusterService.getSettings(), - threadPool, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - boundAddress -> clusterService.localNode(), - null, - Collections.emptySet() - ); - transportService.start(); - transportService.acceptIncomingRequests(); - - IngestService ingestService = mock(IngestService.class); transportBulkAction = new TransportBulkAction( threadPool, - transportService, + mock(TransportService.class), clusterService, - ingestService, + mock(IngestService.class), nodeClient, new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), EmptySystemIndices.INSTANCE ); + + // Default answers to avoid hanging tests due to unexpected invocations + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unexpected invocation")); + return Void.TYPE; + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), any(), any()); + when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unexpected invocation")); + return null; + }); } @After @@ -158,13 +152,14 @@ public void testBulkRequestWithoutInference() { indexRequest.source("non_inference_field", "text", "another_non_inference_field", "other text"); bulkRequest.add(indexRequest); - expectTransportShardBulkActionRequest(); + expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(never()); } @@ -175,15 +170,16 @@ public void testBulkRequestWithInference() { indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldText, "non_inference_field", "other text"); bulkRequest.add(indexRequest); - expectInferenceRequest(Map.of(MODEL_A_ID, Set.of(inferenceFieldText))); + expectInferenceRequest(MODEL_A_ID, inferenceFieldText); - expectTransportShardBulkActionRequest(); + expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(times(1)); } @@ -205,47 +201,102 @@ public void testBulkRequestWithMultipleFieldsInference() { ); bulkRequest.add(indexRequest); - expectInferenceRequest( - Map.of(MODEL_A_ID, Set.of(inferenceField1Text, inferenceField2Text), MODEL_B_ID, Set.of(inferenceField3Text)) - ); + expectInferenceRequest(MODEL_A_ID, inferenceField1Text, inferenceField2Text); + expectInferenceRequest(MODEL_B_ID, inferenceField3Text); - expectTransportShardBulkActionRequest(); + 
expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(times(2)); } - private void verifyInferenceExecuted(VerificationMode times) { - verify(nodeClient, times).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + public void testBulkRequestWithMultipleDocs() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc1 = "some text"; + bulkRequest.add(indexRequest); + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "other text"); + indexRequest = new IndexRequest(INDEX_NAME).id("id2"); + String inferenceFieldTextDoc2 = "some other text"; + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc2, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + + expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc1); + expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc2); + + expectTransportShardBulkActionRequest(2); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertThat(response.getItems().length, equalTo(2)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); + verifyInferenceExecuted(times(2)); } - private void expectTransportShardBulkActionRequest() { - doAnswer(invocation -> { + public void testFailingInference() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc1 = "some text"; + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc2 = "some text"; + indexRequest.source(INFERENCE_FIELD_MODEL_B, inferenceFieldTextDoc2, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + + expectInferenceRequestFails(MODEL_A_ID, inferenceFieldTextDoc1); + expectInferenceRequest(MODEL_B_ID, inferenceFieldTextDoc2); + + // Only non-failing inference requests will be executed + expectTransportShardBulkActionRequest(1); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertThat(response.getItems().length, equalTo(2)); + assertTrue(response.getItems()[0].isFailed()); + assertFalse(response.getItems()[1].isFailed()); + verifyInferenceExecuted(times(2)); + } + + private void verifyInferenceExecuted(VerificationMode verificationMode) { + verify(nodeClient, verificationMode).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + } + + private void expectTransportShardBulkActionRequest(int requestSize) { + when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), argThat(r -> matchBulkShardRequest(r, requestSize)), any())) + .thenAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) 
invocation.getArguments()[2]; + var bulkShardRequest = (BulkShardRequest) invocation.getArguments()[1]; ShardId shardId = new ShardId(INDEX_NAME, "UUID", 0); - BulkItemResponse successResponse = BulkItemResponse.success( - 0, - DocWriteRequest.OpType.INDEX, - new IndexResponse(shardId, "id", 0, 0, 0, true) - ); - listener.onResponse(new BulkShardResponse(shardId, new BulkItemResponse[] { successResponse })); + BulkItemResponse[] bulkItemResponses = Arrays.stream(bulkShardRequest.items()).map(item -> BulkItemResponse.success(item.id(), DocWriteRequest.OpType.INDEX, new IndexResponse( + shardId, + "id", + 0, 0, 0, true + ))).toArray(BulkItemResponse[]::new); + + listener.onResponse(new BulkShardResponse(shardId, bulkItemResponses)); return null; - }).when(nodeClient).executeLocally(eq(TransportShardBulkAction.TYPE), any(BulkShardRequest.class), any()); + }); } - private void expectInferenceRequest(Map> modelsAndInferenceTextMap) { + private boolean matchBulkShardRequest(ActionRequest request, int requestSize) { + return (request instanceof BulkShardRequest) && ((BulkShardRequest) request).items().length == requestSize; + } + + @SuppressWarnings("unchecked") + private void expectInferenceRequest(String modelId, String... inferenceTexts) { doAnswer(invocation -> { InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; - Set textsForModel = modelsAndInferenceTextMap.get(request.getModelId()); - assertThat("model is not expected", textsForModel, notNullValue()); - assertThat("unexpected inference field values", request.getInput(), containsInAnyOrder(textsForModel.toArray())); - - @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; listener.onResponse( new InferenceAction.Response( @@ -262,10 +313,22 @@ private void expectInferenceRequest(Map> modelsAndInferenceT ) ); return Void.TYPE; - }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelsAndInferenceTextMap.keySet())), any()); + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); + } + + private void expectInferenceRequestFails(String modelId, String... 
inferenceTexts) { + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Inference failed")); + return Void.TYPE; + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); } - private boolean inferenceRequestMatches(ActionRequest request, Set models) { - return request instanceof InferenceAction.Request && models.contains(((InferenceAction.Request) request).getModelId()); + private boolean inferenceRequestMatches(ActionRequest request, String modelId, String[] inferenceTexts) { + if (request instanceof InferenceAction.Request inferenceRequest) { + return inferenceRequest.getModelId().equals(modelId) && inferenceRequest.getInput().containsAll(List.of(inferenceTexts)); + } + return false; } } From 477b89c47a220823e94addc758e42a6285cd6d16 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Jan 2024 12:07:55 +0100 Subject: [PATCH 040/106] Style fixes --- .../TransportBulkActionInferenceTests.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 71fd62df1620d..c10687055c632 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -278,11 +278,16 @@ private void expectTransportShardBulkActionRequest(int requestSize) { var listener = (ActionListener) invocation.getArguments()[2]; var bulkShardRequest = (BulkShardRequest) invocation.getArguments()[1]; ShardId shardId = new ShardId(INDEX_NAME, "UUID", 0); - BulkItemResponse[] bulkItemResponses = Arrays.stream(bulkShardRequest.items()).map(item -> BulkItemResponse.success(item.id(), DocWriteRequest.OpType.INDEX, new IndexResponse( - shardId, - "id", - 0, 0, 0, true - ))).toArray(BulkItemResponse[]::new); + BulkItemResponse[] bulkItemResponses = Arrays.stream(bulkShardRequest.items()) + .map(item -> BulkItemResponse.success( + item.id(), + DocWriteRequest.OpType.INDEX, + new IndexResponse( + shardId, + "id", + 0, 0, 0, true) + ) + ).toArray(BulkItemResponse[]::new); listener.onResponse(new BulkShardResponse(shardId, bulkItemResponses)); return null; @@ -301,7 +306,8 @@ private void expectInferenceRequest(String modelId, String... 
inferenceTexts) { listener.onResponse( new InferenceAction.Response( new SparseEmbeddingResults( - request.getInput().stream() + request.getInput() + .stream() .map( text -> new SparseEmbeddingResults.Embedding( List.of(new SparseEmbeddingResults.WeightedToken(text.toString(), 1.0f)), From fa552988624a7874a1a9f6151ede83f08309ac5a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Jan 2024 17:28:52 +0100 Subject: [PATCH 041/106] Remove unused import --- .../java/org/elasticsearch/action/bulk/TransportBulkAction.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 75b45cecfadf2..11b44667556f8 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -87,7 +87,6 @@ import java.util.SortedMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.function.LongSupplier; import java.util.stream.Collectors; From 39061b4cdc02f644496e08c888f62bc402815881 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 14:07:56 +0100 Subject: [PATCH 042/106] First attempt on creating an IT test --- .../ml-inference-service-tests/build.gradle | 4 + .../ml/integration/SemanticTextFieldIT.java | 162 ++++++++++++++++++ .../qa/native-multi-node-tests/build.gradle | 1 + 3 files changed, 167 insertions(+) create mode 100644 x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle index 83226acb383c7..922492cde3e5d 100644 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle @@ -4,6 +4,10 @@ dependencies { javaRestTestImplementation(testArtifact(project(xpackModule('core')))) javaRestTestImplementation(testArtifact(project(xpackModule('ml')))) javaRestTestImplementation project(path: xpackModule('inference')) + javaRestTestImplementation(testArtifact(project(':x-pack:plugin:ml:qa:native-multi-node-tests'), "javaRestTest")) + javaRestTestImplementation(testArtifact(project(":qa:full-cluster-restart"), "javaRestTest")) + + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') } diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java new file mode 100644 index 0000000000000..b20614191b13d --- /dev/null +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.Assert; + +import java.io.IOException; +import java.util.Base64; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextFieldIT extends PyTorchModelRestTestCase { + + private static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA" + + "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG" + + "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29" + + "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl" + + "bW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWRT4+cMAzF7" + + "/spfASJomF3e0Ga3nrrn8vcELIyxAzRhAQlpjvbT19DWDrdquqBA/bvPT87nVUxwsm41xPd+PNtUi4a77" + + "KvXs+W8voBAHFSQY3EFCIiHKFp1+p57vs/ShyUccZdoIaz93aBTMR+thbPqru+qKBx8P4q/e8TyxRlmwVc" + + "tJp66H1YmCyS7WsZwD50A2L5V7pCBADGTTOj0bGGE7noQyqzv5JDfp0o9fZRCWqP37yjhE4+mqX5X3AdF" + + "ZHGM/2TzOHDpy1IvQWR+OWo3KwsRiKdpcqg4pBFDtm+QJ7nqwIPckrlnGfFJG0uNhOl38Sjut3pCqg26Qu" + + "Zy8BR9In7ScHHrKkKMW0TIucFrGQXCMpdaDO05O6DpOiy8e4kr0Ed/2YKOIhplW8gPr4ntygrd9ixpx3j9" + + "UZZVRagl2c6+imWUzBjuf5m+Ch7afphuvvW+r/0dsfn+2N9MZGb9+/SFtCYdhd83CMYp+mGy0LiKNs8y/e" + + "UuEA8B/d2z4dfUEsHCFSE3IaCAQAAIAMAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwApAHNpbXBsZ" + + "W1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCJQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlp" + + "aWlpaWlpaWlpaWlpahZHLbtNAFIZtp03rSVIuLRKXjdk5ojitKJsiFq24lem0KKSqpRIZt55gE9/GM+lNL" + + "Fgx4i1Ys2aHhIBXgAVICNggHgNm6rqJN2BZGv36/v/MOWeea/Z5RVHurLfRUsfZXOnccx522itrd53O0vL" + + "qbaKYtsAKUe1pcege7hm9JNtzM8+kOOzNApIX0A3xBXE6YE7g0UWjg2OaZAJXbKvALOnj2GEHKc496ykLkt" + + "gNt3Jz17hprCUxFqExe7YIpQkNpO1/kfHhPUdtUAdH2/gfmeYiIFW7IkM6IBP2wrDNbMe3Mjf2ksiK3Hjg" + + "hg7F2DN9l/omZZl5Mmez2QRk0q4WUUB0+1oh9nDwxGdUXJdXPMRZQs352eGaRPV9s2lcMeZFGWBfKJJiw0Y" + + "gbCMLBaRmXyy4flx6a667Fch55q05QOq2Jg2ANOyZwplhNsjiohVApo7aa21QnNGW5+4GXv8gxK1beBeHSR" + + "rhmLXWVh+0aBhErZ7bx1ejxMOhlR6QU4ycNqGyk8/yNGCWkwY7/RCD7UEQek4QszCgDJAzZtfErA0VqHBy9" + + "ugQP9pUfUmgCjVYgWNwHFbhBJyEOgSwBuuwARWZmoI6J9PwLfzEocpRpPrT8DP8wqHG0b4UX+E3DiscvRgl" + + "XIoi81KKPwioHI5x9EooNKWiy0KOc/T6WF4SssrRuzJ9L2VNRXUhJzj6UKYfS4W/q/5wuh/l4M9R9qsU+y2" + + "dpoo2hJzkaEET8r6KRONicnRdK9EbUi6raFVIwNGjsrlbpk6ZPi7TbS3fv3LyNjPiEKzG0aG0tvNb6xw90/" + + "whe6ONjnJcUxobHDUqQ8bIOW79BVBLBwhfSmPKdAIAAE4EAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAA" + + "BkABQBzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsRkIBAFqAAikuUEsHCG0vCVcEAAAABAAAAFBLAwQAAAgI" + + "AAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25GQjcAWlpaWlpaWlpaWlpaWlpaWlp" + + "aWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAA" + + "AACAgAAAAAAAAhOZuwWAAAAFgAAAAUAAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLA" + + "QIAABQACAgIAAAAAABUhNyGggEAACADAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19f" + + "dG9yY2hfXy5weVBLAQIAABQACAgIAAAAAABfSmPKdAIAAE4EAAAnAAAAAAAAAAAAAAAAAJICAABzaW1wbGVt" + + "b2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAAAAAAbS8JVwQAAAAEAAAAGQAA" + + "AAAAAAAAAAAAAACEBQAAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAA" + + "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA" + + "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA="; + + private static final 
long RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; + public static final List VOCABULARY = List.of( + "these", + "are", + "my", + "words", + "the", + "washing", + "machine", + "is", + "leaking", + "octopus", + "comforter", + "smells" + ); + + public void testSemanticTextInference() throws IOException { + String modelId = "semantic-text-model"; + + createTextExpansionModel(modelId); + putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE); + putVocabulary( + VOCABULARY, + modelId + ); + startDeployment(modelId); + + String indexName = modelId + "-index"; + createSemanticTextIndex(indexName); + + int numDocs = ESTestCase.randomIntBetween(1, 10); + bulkIndexDocs(indexName, numDocs); + + for (int i = 0; i < numDocs; i++) { + Request getRequest = new Request("GET", "/" + indexName + "/_doc/" + i); + Response response = ESRestTestCase.client().performRequest(getRequest); + Assert.assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); + } + } + + private void createTextExpansionModel(String modelId) throws IOException { + Request request = new Request("PUT", "/_ml/trained_models/" + modelId); + request.setJsonEntity(""" + { + "description": "a text expansion model", + "model_type": "pytorch", + "inference_config": { + "text_expansion": { + "tokenization": { + "bert": { + "with_special_tokens": false + } + } + } + } + }"""); + ESRestTestCase.client().performRequest(request); + } + + private void createSemanticTextIndex(String indexName) throws IOException { + Request createIndex = new Request("PUT", "/" + indexName); + createIndex.setJsonEntity(""" + { + "mappings": { + "properties": { + "text_field": { + "type": "text" + }, + "inference_field": { + "type": "semantic_text", + "model_id": "semantic-text-model" + } + } + } + }"""); + var response = ESRestTestCase.client().performRequest(createIndex); + assertOkWithErrorMessage(response); + } + + private void bulkIndexDocs(String indexName, int numDocs) throws IOException { + + StringBuilder bulkBuilder = new StringBuilder(); + + for (int i = 0; i < numDocs; i++) { + String createAction = "{\"create\": {\"_index\": \"" + indexName + "\" \"_id\":\"" + i + "\"}}\n"; + bulkBuilder.append(createAction); + bulkBuilder.append("{\"text_field\": \"").append(ESTestCase.randomAlphaOfLengthBetween(1, 100)).append("\","); + + bulkBuilder.append("{\"inference_field\": \""); + bulkBuilder.append(String.join(" ", ESTestCase.randomSubsetOf(ESTestCase.randomIntBetween(1, 10), VOCABULARY))); + bulkBuilder.append("\""); + + bulkBuilder.append("}}\n"); + } + + Request bulkRequest = new Request("POST", "/_bulk"); + + bulkRequest.setJsonEntity(bulkBuilder.toString()); + bulkRequest.addParameter("refresh", "true"); + var bulkResponse = ESRestTestCase.client().performRequest(bulkRequest); + assertOkWithErrorMessage(bulkResponse); + } + +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle index db53b9aec7f1f..5022354c641e6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle @@ -1,4 +1,5 @@ apply plugin: 'elasticsearch.legacy-java-rest-test' +apply plugin: 'elasticsearch.internal-test-artifact' dependencies { javaRestTestImplementation(testArtifact(project(xpackModule('core')))) From 91d77718fbb4f75fb5a28d728f70e44d7c1d3e0f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 14:13:54 +0100 Subject: [PATCH 043/106] Revert "Change inference result 
classes to server" This reverts commit 0e493ee4 --- server/src/main/java/module-info.java | 3 --- .../results/NlpInferenceResults.java | 19 +++++++++++++------ .../results/SparseEmbeddingResults.java | 13 +++++++++++-- .../results/TextEmbeddingResults.java | 3 ++- .../results/TextExpansionResults.java | 13 ++++++++++--- .../core/src/main/java/module-info.java | 1 + .../MlInferenceNamedXContentProvider.java | 4 ++-- .../ml/inference/results/FillMaskResults.java | 6 +++--- .../core/ml/inference/results/NerResults.java | 6 +++--- .../NlpClassificationInferenceResults.java | 6 +++--- .../results/PyTorchPassThroughResults.java | 6 +++--- .../QuestionAnsweringInferenceResults.java | 6 +++--- .../results/TextEmbeddingResults.java | 13 ++++++------- .../TextSimilarityInferenceResults.java | 6 +++--- .../action/InferModelActionResponseTests.java | 4 ++-- .../results/FillMaskResultsTests.java | 2 +- .../ml/inference/results/NerResultsTests.java | 2 +- .../PyTorchPassThroughResultsTests.java | 4 ++-- .../results/TextEmbeddingResultsTests.java | 5 ++--- .../results/TextEmbeddingResultsTests.java | 4 ++-- x-pack/plugin/ml/build.gradle | 9 --------- .../TransportCoordinatedInferenceAction.java | 2 +- .../inference/nlp/TextEmbeddingProcessor.java | 2 +- .../inference/nlp/TextExpansionProcessor.java | 2 +- .../ml/queries/TextExpansionQueryBuilder.java | 2 +- .../queries/WeightedTokensQueryBuilder.java | 2 +- .../TextEmbeddingQueryVectorBuilder.java | 2 +- .../nlp/TextExpansionProcessorTests.java | 2 +- .../TextExpansionQueryBuilderTests.java | 2 +- .../WeightedTokensQueryBuilderTests.java | 2 +- .../TextEmbeddingQueryVectorBuilderTests.java | 2 +- 31 files changed, 83 insertions(+), 72 deletions(-) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core}/ml/inference/results/TextEmbeddingResults.java (85%) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 0f58703af7750..e72cb6c53e8e5 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -135,9 +135,6 @@ exports org.elasticsearch.action.get; exports org.elasticsearch.action.index; exports org.elasticsearch.action.ingest; - exports org.elasticsearch.action.inference; - exports org.elasticsearch.action.inference.results; - exports org.elasticsearch.action.ml.inference.results; exports org.elasticsearch.action.resync; exports org.elasticsearch.action.search; exports org.elasticsearch.action.support; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java index 98f13b6711a06..fee00551bb528 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java @@ -6,6 +6,13 @@ * Side Public License, v 1. */ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -19,23 +26,23 @@ import java.util.Map; import java.util.Objects; -public abstract class NlpInferenceResults implements InferenceResults { +abstract class NlpInferenceResults implements InferenceResults { protected final boolean isTruncated; - public NlpInferenceResults(boolean isTruncated) { + NlpInferenceResults(boolean isTruncated) { this.isTruncated = isTruncated; } - public NlpInferenceResults(StreamInput in) throws IOException { + NlpInferenceResults(StreamInput in) throws IOException { this.isTruncated = in.readBoolean(); } - protected abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException; + abstract void doXContentBody(XContentBuilder builder, Params params) throws IOException; - protected abstract void doWriteTo(StreamOutput out) throws IOException; + abstract void doWriteTo(StreamOutput out) throws IOException; - protected abstract void addMapFields(Map map); + abstract void addMapFields(Map map); public boolean isTruncated() { return isTruncated; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java index 4dc9c8cb88a77..85fecdae6f39c 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java @@ -6,6 +6,13 @@ * Side Public License, v 1. */ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -18,6 +25,7 @@ import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; @@ -26,11 +34,12 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; + public record SparseEmbeddingResults(List embeddings) implements InferenceServiceResults { public static final String NAME = "sparse_embedding_results"; public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString(); - public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public SparseEmbeddingResults(StreamInput in) throws IOException { this(in.readCollectionAsList(Embedding::new)); @@ -90,7 +99,7 @@ public List transformToLegacyFormat() { return embeddings.stream() .map( embedding -> new TextExpansionResults( - DEFAULT_RESULTS_FIELD, + InferenceConfig.DEFAULT_RESULTS_FIELD, embedding.tokens() .stream() .map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token, weightedToken.weight)) diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java index 6597c8d36a92c..5e757f1401f74 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java @@ -8,6 +8,7 @@ package org.elasticsearch.action.inference.results; +import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -83,7 +84,7 @@ public String getWriteableName() { public List transformToCoordinationFormat() { return embeddings.stream() .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) - .map(values -> new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) .toList(); } diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java index f5d8f68d144e5..e54089187e121 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java +++ b/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java @@ -6,6 +6,13 @@ * Side Public License, v 1. */ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + package org.elasticsearch.action.inference.results; import org.elasticsearch.common.Strings; @@ -89,7 +96,7 @@ public Object predictedValue() { } @Override - protected void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { + void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(resultsField); for (var weightedToken : weightedTokens) { weightedToken.toXContent(builder, params); @@ -112,13 +119,13 @@ public int hashCode() { } @Override - protected void doWriteTo(StreamOutput out) throws IOException { + void doWriteTo(StreamOutput out) throws IOException { out.writeString(resultsField); out.writeCollection(weightedTokens); } @Override - public void addMapFields(Map map) { + void addMapFields(Map map) { map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); } diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index 38913adbd0f28..f747d07224454 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -74,6 +74,7 @@ exports org.elasticsearch.xpack.core.ilm; exports org.elasticsearch.xpack.core.indexing; exports org.elasticsearch.xpack.core.inference.action; + exports org.elasticsearch.xpack.core.inference.results; exports org.elasticsearch.xpack.core.inference; exports org.elasticsearch.xpack.core.logstash; exports org.elasticsearch.xpack.core.ml.action; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 18cadafb9a6a5..9f08667f85572 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -6,8 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.inference; -import org.elasticsearch.action.inference.results.TextExpansionResults; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.spi.NamedXContentProvider; @@ -30,6 +28,8 @@ import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java index ded62c3e12bfb..4fad9b535e4e1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java
@@ -40,7 +40,7 @@ public FillMaskResults(StreamInput in) throws IOException {
     }

     @Override
-    protected void doWriteTo(StreamOutput out) throws IOException {
+    public void doWriteTo(StreamOutput out) throws IOException {
         super.doWriteTo(out);
         out.writeString(predictedSequence);
     }
@@ -50,7 +50,7 @@ public String getPredictedSequence() {
     }

     @Override
-    protected void addMapFields(Map<String, Object> map) {
+    void addMapFields(Map<String, Object> map) {
         super.addMapFields(map);
         map.put(resultsField + "_sequence", predictedSequence);
     }
@@ -68,7 +68,7 @@ public String getWriteableName() {
     }

     @Override
-    protected void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+    public void doXContentBody(XContentBuilder builder, Params params) throws IOException {
         super.doXContentBody(builder, params);
         builder.field(resultsField + "_sequence", predictedSequence);
     }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java
index b95642ecda753..ba5ba3fcb7a5c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java
@@ -47,7 +47,7 @@ public NerResults(StreamInput in) throws IOException {
     }

     @Override
-    protected void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(resultsField, annotatedResult);
         builder.startArray("entities");
         for (EntityGroup entity : entityGroups) {
@@ -62,14 +62,14 @@ public String getWriteableName() {
     }

     @Override
-    protected void doWriteTo(StreamOutput out) throws IOException {
+    void doWriteTo(StreamOutput out) throws IOException {
         out.writeCollection(entityGroups);
         out.writeString(resultsField);
         out.writeString(annotatedResult);
     }

     @Override
-    protected void addMapFields(Map<String, Object> map) {
+    void addMapFields(Map<String, Object> map) {
         map.put(resultsField, annotatedResult);
         map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
     }
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java
index 8a7be15294548..cf6e29be1746c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java
@@ -60,7 +60,7 @@ public List<TopClassEntry> getTopClasses() {
     }

     @Override
-    protected void doWriteTo(StreamOutput out) throws IOException {
+    public void doWriteTo(StreamOutput out) throws IOException {
         out.writeString(classificationLabel);
         out.writeCollection(topClasses);
         out.writeString(resultsField);
@@ -99,7 +99,7 @@ public Object predictedValue() {
     }

     @Override
-    protected void addMapFields(Map<String, Object> map) {
+    void addMapFields(Map<String, Object> map) {
         map.put(resultsField, classificationLabel);
         if (topClasses.isEmpty() == false) {
             map.put(
@@ -118,7 +118,7 @@ public String getWriteableName() {
     }

     @Override
-    protected void doXContentBody(XContentBuilder builder, Params params) throws IOException {
+
public void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, classificationLabel); if (topClasses.size() > 0) { builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index 550e2e73ae0d2..83d4f204cd174 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -41,7 +41,7 @@ public double[][] getInference() { } @Override - protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { + void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, inference); } @@ -51,7 +51,7 @@ public String getWriteableName() { } @Override - protected void doWriteTo(StreamOutput out) throws IOException { + public void doWriteTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeDoubleArray, inference); out.writeString(resultsField); } @@ -62,7 +62,7 @@ public String getResultsField() { } @Override - protected void addMapFields(Map map) { + void addMapFields(Map map) { map.put(resultsField, inference); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index 87f781e41c8b0..9f2c8f2b77a70 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -77,7 +77,7 @@ public List getTopClasses() { } @Override - protected void doWriteTo(StreamOutput out) throws IOException { + public void doWriteTo(StreamOutput out) throws IOException { out.writeString(answer); out.writeVInt(startOffset); out.writeVInt(endOffset); @@ -120,7 +120,7 @@ public String predictedValue() { } @Override - protected void addMapFields(Map map) { + void addMapFields(Map map) { map.put(resultsField, answer); addSupportingFieldsToMap(map); } @@ -151,7 +151,7 @@ public String getWriteableName() { } @Override - protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, answer); builder.field(START_OFFSET.getPreferredName(), startOffset); builder.field(END_OFFSET.getPreferredName(), endOffset); diff --git a/server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java similarity index 85% rename from server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java index 632e03f69bcd7..a92c53378a5dc 100644 --- a/server/src/main/java/org/elasticsearch/action/ml/inference/results/TextEmbeddingResults.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java @@ -1,12 +1,11 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.action.ml.inference.results; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; @@ -54,7 +53,7 @@ public float[] getInferenceAsFloat() { } @Override - public void doXContentBody(XContentBuilder builder, Params params) throws IOException { + void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, inference); } @@ -64,13 +63,13 @@ public String getWriteableName() { } @Override - protected void doWriteTo(StreamOutput out) throws IOException { + void doWriteTo(StreamOutput out) throws IOException { out.writeDoubleArray(inference); out.writeString(resultsField); } @Override - protected void addMapFields(Map map) { + void addMapFields(Map map) { map.put(resultsField, inference); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java index c390bb6c19323..1fe69a16b41b3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java @@ -35,7 +35,7 @@ public TextSimilarityInferenceResults(StreamInput in) throws IOException { } @Override - protected void doWriteTo(StreamOutput out) throws IOException { + public void doWriteTo(StreamOutput out) throws IOException { out.writeString(resultsField); out.writeDouble(score); } @@ -65,7 +65,7 @@ public Double predictedValue() { } @Override - protected void addMapFields(Map map) { + void addMapFields(Map map) { map.put(resultsField, score); } @@ -82,7 +82,7 @@ public String getWriteableName() { } @Override - protected void doXContentBody(XContentBuilder builder, Params params) throws IOException { + public void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, score); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index ca661da89dcb8..0d85b173be27b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -6,8 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.action; -import org.elasticsearch.action.inference.results.TextExpansionResults; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.Strings; import 
org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; @@ -27,7 +25,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java index 7ca86dbaf195a..432e05b9cc680 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -65,7 +65,7 @@ public void testAsMap() { assertThat(asMap.get(PREDICTION_PROBABILITY), equalTo(testInstance.getPredictionProbability())); assertThat(asMap.get(DEFAULT_RESULTS_FIELD + "_sequence"), equalTo(testInstance.getPredictedSequence())); List> resultList = (List>) asMap.get(DEFAULT_TOP_CLASSES_RESULTS_FIELD); - if (testInstance.isTruncated()) { + if (testInstance.isTruncated) { assertThat(asMap.get("is_truncated"), is(true)); } else { assertThat(asMap, not(hasKey("is_truncated"))); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java index 252e0491314a8..4be49807d27b0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -68,7 +68,7 @@ public void testAsMap() { } assertThat(resultList, hasSize(testInstance.getEntityGroups().size())); assertThat(asMap.get(testInstance.getResultsField()), equalTo(testInstance.getAnnotatedResult())); - if (testInstance.isTruncated()) { + if (testInstance.isTruncated) { assertThat(asMap.get("is_truncated"), is(true)); } else { assertThat(asMap, not(hasKey("is_truncated"))); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java index 898f7a9f39916..e6b38a08a75ba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java @@ -49,10 +49,10 @@ protected PyTorchPassThroughResults mutateInstance(PyTorchPassThroughResults ins public void testAsMap() { PyTorchPassThroughResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); 
- int size = testInstance.isTruncated() ? 2 : 1; + int size = testInstance.isTruncated ? 2 : 1; assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[][]) asMap.get(DEFAULT_RESULTS_FIELD)); - if (testInstance.isTruncated()) { + if (testInstance.isTruncated) { assertThat(asMap.get("is_truncated"), is(true)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java index 4f8f433c1d2cd..fd3ac7f8c0d12 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.ingest.IngestDocument; @@ -47,10 +46,10 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) { public void testAsMap() { TextEmbeddingResults testInstance = createTestInstance(); Map asMap = testInstance.asMap(); - int size = testInstance.isTruncated() ? 2 : 1; + int size = testInstance.isTruncated ? 2 : 1; assertThat(asMap.keySet(), hasSize(size)); assertArrayEquals(testInstance.getInference(), (double[]) asMap.get(DEFAULT_RESULTS_FIELD), 1e-10); - if (testInstance.isTruncated()) { + if (testInstance.isTruncated) { assertThat(asMap.get("is_truncated"), is(true)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index acd8697de450b..fcbbc5f08a49e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -109,12 +109,12 @@ public void testTransformToCoordinationFormat() { results, is( List.of( - new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( TextEmbeddingResults.TEXT_EMBEDDING, new double[] { 0.1F, 0.2F }, false ), - new org.elasticsearch.action.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( TextEmbeddingResults.TEXT_EMBEDDING, new double[] { 0.3F, 0.4F }, false diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index 3e716f93ac949..22cdb752d1e8d 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -1,6 +1,5 @@ import org.elasticsearch.gradle.VersionProperties import org.elasticsearch.gradle.internal.dra.DraResolvePlugin -import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' @@ -74,8 +73,6 @@ esplugin.bundleSpec.exclude 'platform/licenses/**' } dependencies { - compileOnly project(":server") - testImplementation project(path: ':x-pack:plugin:inference') compileOnly project(':modules:lang-painless:spi') compileOnly project(path: xpackModule('core')) compileOnly project(path: 
xpackModule('autoscaling')) @@ -116,12 +113,6 @@ artifacts { archives tasks.named("jar") } -if (BuildParams.isSnapshotBuild() == false) { - tasks.named("test").configure { - systemProperty 'es.semantic_text_feature_flag_enabled', 'true' - } -} - tasks.register("extractNativeLicenses", Copy) { dependsOn configurations.nativeBundle into "${buildDir}/extractedNativeLicenses" diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 4562f664cbc1a..3c37eda58ba43 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.internal.Client; @@ -26,6 +25,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index 7b98be7974b87..453b689d59cc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 6dd9249d44be0..57709bdb33e42 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import 
org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 197132084b855..4c065d9195af5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -12,7 +12,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -28,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index ed57ef157978a..835fd611b5a7d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -17,7 +17,6 @@ import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -28,6 +27,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java index a3570bebd8055..bd0916065ec5f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilder.java @@ -10,7 +10,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -23,6 +22,7 @@ import 
org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java index 0380c5e37863b..22dab2e3801d2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 217475de84659..d7c1204249ff0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.compress.CompressedXContent; @@ -34,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java index 2576f3f802b98..f9cd283e93b1f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import 
org.elasticsearch.common.compress.CompressedXContent; @@ -34,6 +33,7 @@ import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java index 8bbc3351e3f33..8575c7e1f4bf3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/vectors/TextEmbeddingQueryVectorBuilderTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.SearchPlugin; @@ -18,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.ml.MachineLearningTests; import java.io.IOException; From 36650d3950c1436d94b9d89b57e5bf4df81e6735 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 14:15:10 +0100 Subject: [PATCH 044/106] Revert "Moved InferenceAction and result classes to server" This reverts commit 1b822614 --- .../action/bulk/TransportBulkAction.java | 355 ++++-------------- .../cluster/metadata/MappingMetadata.java | 24 +- .../inference/action}/InferenceAction.java | 8 +- .../results/LegacyTextEmbeddingResults.java | 2 +- .../results/SparseEmbeddingResults.java | 12 +- .../results/TextEmbeddingResults.java | 8 +- .../MlInferenceNamedXContentProvider.java | 2 +- .../core/ml/inference/results/NerResults.java | 1 - .../NlpClassificationInferenceResults.java | 1 - .../results/NlpInferenceResults.java | 10 +- .../results/PyTorchPassThroughResults.java | 1 - .../QuestionAnsweringInferenceResults.java | 1 - .../results/TextEmbeddingResults.java | 1 - .../results/TextExpansionResults.java | 13 +- .../TextSimilarityInferenceResults.java | 1 - .../action/InferModelActionResponseTests.java | 2 +- .../results/InferenceResultsTestCase.java | 1 - .../results/TextExpansionResultsTests.java | 1 - .../mock/TestInferenceServiceExtension.java | 2 +- .../InferenceNamedWriteablesProvider.java | 6 +- .../xpack/inference/InferencePlugin.java | 2 +- .../action/TransportInferenceAction.java | 2 +- .../HuggingFaceElserResponseEntity.java | 2 +- .../HuggingFaceEmbeddingsResponseEntity.java | 2 +- .../OpenAiEmbeddingsResponseEntity.java | 2 +- .../inference/rest/RestInferenceAction.java | 9 +- .../inference/services/ServiceUtils.java | 2 +- .../services/elser/ElserMlNodeService.java | 2 +- .../action/InferenceActionRequestTests.java | 2 +- .../action/InferenceActionResponseTests.java | 4 +- .../HuggingFaceElserResponseEntityTests.java | 2 +- ...gingFaceEmbeddingsResponseEntityTests.java | 2 +- .../OpenAiEmbeddingsResponseEntityTests.java | 2 +- 
.../LegacyTextEmbeddingResultsTests.java | 2 +- .../results/SparseEmbeddingResultsTests.java | 4 +- .../results/TextEmbeddingResultsTests.java | 2 +- .../TransportCoordinatedInferenceAction.java | 2 +- .../inference/nlp/TextExpansionProcessor.java | 2 +- .../ml/queries/TextExpansionQueryBuilder.java | 2 +- .../queries/WeightedTokensQueryBuilder.java | 2 +- .../nlp/TextExpansionProcessorTests.java | 2 +- .../TextExpansionQueryBuilderTests.java | 2 +- .../WeightedTokensQueryBuilderTests.java | 4 +- 43 files changed, 113 insertions(+), 398 deletions(-) rename {server/src/main/java/org/elasticsearch/action/inference => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action}/InferenceAction.java (97%) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core}/inference/results/LegacyTextEmbeddingResults.java (98%) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core}/inference/results/SparseEmbeddingResults.java (93%) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core}/inference/results/TextEmbeddingResults.java (93%) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml}/inference/results/NlpInferenceResults.java (86%) rename {server/src/main/java/org/elasticsearch/action => x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml}/inference/results/TextExpansionResults.java (87%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 11b44667556f8..be976af717cd3 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -25,15 +25,10 @@ import org.elasticsearch.action.admin.indices.create.AutoCreateAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; -import org.elasticsearch.action.admin.indices.rollover.RolloverAction; -import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; -import org.elasticsearch.action.admin.indices.rollover.RolloverResponse; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.ingest.IngestActionForwarder; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.WriteResponse; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.update.UpdateRequest; @@ -65,9 +60,6 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -76,7 +68,6 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; -import java.util.Arrays; import 
java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -85,8 +76,8 @@ import java.util.Objects; import java.util.Set; import java.util.SortedMap; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.function.LongSupplier; import java.util.stream.Collectors; @@ -115,10 +106,6 @@ public class TransportBulkAction extends HandledTransportAction indices = bulkRequest.requests.stream() // delete requests should not attempt to create the index (if the index does not - // exist), unless an external versioning is used + // exists), unless an external versioning is used .filter( request -> request.opType() != DocWriteRequest.OpType.DELETE || request.versionType() == VersionType.EXTERNAL @@ -378,28 +365,20 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec } } - // Step 3: Collect all the data streams that need to be rolled over before writing - Set dataStreamsToBeRolledOver = indices.keySet().stream().filter(target -> { - DataStream dataStream = state.metadata().dataStreams().get(target); - return dataStream != null && dataStream.rolloverOnWrite(); - }).collect(Collectors.toSet()); - - // Step 4: create all the indices that are missing, if there are any missing. start the bulk after all the creates come back. + // Step 3: create all the indices that are missing, if there are any missing. start the bulk after all the creates come back. createMissingIndicesAndIndexData( task, bulkRequest, executorName, listener, autoCreateIndices, - dataStreamsToBeRolledOver, indicesThatCannotBeCreated, startTime ); } /* - * This method is responsible for creating any missing indices, rolling over a data stream when needed and then - * indexing the data in the BulkRequest + * This method is responsible for creating any missing indices and indexing the data in the BulkRequest */ protected void createMissingIndicesAndIndexData( Task task, @@ -407,27 +386,22 @@ protected void createMissingIndicesAndIndexData( String executorName, ActionListener listener, Set autoCreateIndices, - Set dataStreamsToBeRolledOver, Map indicesThatCannotBeCreated, long startTime ) { final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); - // Optimizing when there are no prerequisite actions - if (autoCreateIndices.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { + if (autoCreateIndices.isEmpty()) { executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); - return; - } - Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { - @Override - protected void doRun() { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); - } - }); - try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { + } else { + final AtomicInteger counter = new AtomicInteger(autoCreateIndices.size()); for (String index : autoCreateIndices) { - createIndex(index, bulkRequest.timeout(), ActionListener.releaseAfter(new ActionListener<>() { + createIndex(index, bulkRequest.timeout(), new ActionListener<>() { @Override - public void onResponse(CreateIndexResponse createIndexResponse) {} + public void onResponse(CreateIndexResponse result) { + if (counter.decrementAndGet() == 0) { + forkExecuteBulk(listener); + } + } @Override public void onFailure(Exception e) { @@ 
-438,47 +412,30 @@ public void onFailure(Exception e) { } } else if ((cause instanceof ResourceAlreadyExistsException) == false) { // fail all requests involving this index, if create didn't work - failRequestsWhenPrerequisiteActionFailed(index, bulkRequest, responses, e); + for (int i = 0; i < bulkRequest.requests.size(); i++) { + DocWriteRequest request = bulkRequest.requests.get(i); + if (request != null && setResponseFailureIfIndexMatches(responses, i, request, index, e)) { + bulkRequest.requests.set(i, null); + } + } + } + if (counter.decrementAndGet() == 0) { + forkExecuteBulk(ActionListener.wrap(listener::onResponse, inner -> { + inner.addSuppressed(e); + listener.onFailure(inner); + })); } - } - }, refs.acquire())); - } - for (String dataStream : dataStreamsToBeRolledOver) { - rolloverDataStream(dataStream, bulkRequest.timeout(), ActionListener.releaseAfter(new ActionListener<>() { - - @Override - public void onResponse(RolloverResponse result) { - // A successful response has rolled_over false when in the following cases: - // - A request had the parameter lazy or dry_run enabled - // - A request had conditions that were not met - // Since none of the above apply, getting a response with rolled_over false is considered a bug - // that should be caught here and inform the developer. - assert result.isRolledOver() - : "An successful unconditional rollover should always result in a rolled over data stream"; } - @Override - public void onFailure(Exception e) { - failRequestsWhenPrerequisiteActionFailed(dataStream, bulkRequest, responses, e); + private void forkExecuteBulk(ActionListener finalListener) { + threadPool.executor(executorName).execute(new ActionRunnable<>(finalListener) { + @Override + protected void doRun() { + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + } + }); } - }, refs.acquire())); - } - } - } - - /** - * Fails all requests involving this index or data stream because the prerequisite action failed too. 
- */ - private static void failRequestsWhenPrerequisiteActionFailed( - String target, - BulkRequest bulkRequest, - AtomicArray responses, - Exception error - ) { - for (int i = 0; i < bulkRequest.requests.size(); i++) { - DocWriteRequest request = bulkRequest.requests.get(i); - if (request != null && setResponseFailureIfIndexMatches(responses, i, request, target, error)) { - bulkRequest.requests.set(i, null); + }); } } } @@ -581,12 +538,6 @@ void createIndex(String index, TimeValue timeout, ActionListener listener) { - RolloverRequest rolloverRequest = new RolloverRequest(dataStream, null); - rolloverRequest.masterNodeTimeout(timeout); - client.execute(RolloverAction.INSTANCE, rolloverRequest, listener); - } - private static boolean setResponseFailureIfIndexMatches( AtomicArray responses, int idx, @@ -682,7 +633,8 @@ private Map> groupRequestsByShards(ClusterState c if (ia.getParentDataStream() != null && // avoid valid cases when directly indexing into a backing index // (for example when directly indexing into .ds-logs-foobar-000001) - ia.getName().equals(docWriteRequest.index()) == false && docWriteRequest.opType() != OpType.CREATE) { + ia.getName().equals(docWriteRequest.index()) == false + && docWriteRequest.opType() != OpType.CREATE) { throw new IllegalArgumentException("only write ops with an op_type of create are allowed in data streams"); } @@ -722,219 +674,29 @@ private void executeBulkRequestsByShard(Map> requ return; } + final AtomicInteger counter = new AtomicInteger(requestsByShard.size()); String nodeId = clusterService.localNode().getId(); - Runnable onBulkItemsComplete = () -> { - listener.onResponse( - new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) - ); - // Allow memory for bulk shard request items to be reclaimed before all items have been completed - bulkRequest = null; - }; - - try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { - for (Map.Entry> entry : requestsByShard.entrySet()) { - final ShardId shardId = entry.getKey(); - final List requests = entry.getValue(); - - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); - } - - performInferenceAndExecute(bulkShardRequest, clusterState, bulkItemRequestCompleteRefCount.acquire()); - } - } - } - - private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) { - - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); - // No inference fields? 
Just execute the request - if (fieldsForModels.isEmpty()) { - executeBulkShardRequest(bulkShardRequest, releaseOnFinish); - return; - } - - Runnable onInferenceComplete = () -> { - // We need to remove items that have had an inference error, as the response will have been updated already - // and we don't need to process them further - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), + for (Map.Entry> entry : requestsByShard.entrySet()) { + final ShardId shardId = entry.getKey(); + final List requests = entry.getValue(); + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, bulkRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + requests.toArray(new BulkItemRequest[0]) ); - executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish); - }; - - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - for (BulkItemRequest request : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire()); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); } - } - } - private void performInferenceOnBulkItemRequest( - BulkShardRequest bulkShardRequest, - BulkItemRequest request, - Map> fieldsForModels, - Releasable releaseOnFinish - ) { - DocWriteRequest docWriteRequest = request.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; + executeBulkShardRequest(bulkShardRequest, requests, counter); } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - - @SuppressWarnings("unchecked") - Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_RESULT_FIELD, - k -> new HashMap() - ); - - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - - if (inferenceFieldNames.isEmpty()) { - continue; - } - - docRef.acquire(); - - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified - modelId, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - Map.of() - ); - - client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { - @Override - public void onResponse(InferenceAction.Response response) { - // Transform into two subfields, one with the actual text and other with the inference - InferenceServiceResults results = response.getResults(); - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() - ); - - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); - } - - docRef.close(); - } - - @Override - public void onFailure(Exception e) { - - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure( - indexName, - docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) - ); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; - - docRef.close(); - } - }); - } - } - } - - private static List getFieldNamesForInference( - Map.Entry> fieldModelsEntrySet, - Map docMap - ) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String fieldStringValue) { - - // Only do inference if the previous text value doesn't match the new one - String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, 
inferenceField, TEXT_FIELD); - if (fieldStringValue.equals(previousValue) == false) { - inferenceFieldNames.add(inferenceField); - } - } - } - return inferenceFieldNames; - } - - @SuppressWarnings("unchecked") - private static String findMapValue(Map map, String... path) { - Map currentMap = map; - for (int i = 0; i < path.length - 1; i++) { - Object value = currentMap.get(path[i]); - - if (value instanceof Map) { - currentMap = (Map) value; - } else { - // Invalid path or non-Map value encountered - return null; - } - } - - // Retrieve the final value in the map, if it's a String - Object finalValue = currentMap.get(path[path.length - 1]); - - return (finalValue instanceof String) ? (String) finalValue : null; + bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { - if (bulkShardRequest.items().length == 0) { - // No requests to execute due to previous errors, terminate early - releaseOnFinish.close(); - return; - } - + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -945,19 +707,30 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - releaseOnFinish.close(); + maybeFinishHim(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - for (BulkItemRequest request : bulkShardRequest.items()) { + for (BulkItemRequest request : requests) { final String indexName = request.index(); DocWriteRequest docWriteRequest = request.request(); BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - releaseOnFinish.close(); + maybeFinishHim(); + } + + private void maybeFinishHim() { + if (counter.decrementAndGet() == 0) { + listener.onResponse( + new BulkResponse( + responses.toArray(new BulkItemResponse[responses.length()]), + buildTookInMillis(startTimeNanos) + ) + ); + } } }); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 64a61f854b9da..b629ab5d5f710 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -18,13 +18,11 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MappingLookup; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -44,15 +42,10 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; - private final Map> fieldsForModels; - public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = 
docMapper.routingFieldMapper().required(); - - MappingLookup mappingLookup = docMapper.mappers(); - this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -64,7 +57,6 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -80,7 +72,6 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = Map.of(); } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -167,19 +158,12 @@ public String getSha256() { return source.getSha256(); } - public Map> getFieldsForModels() { - return fieldsForModels; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); - } } @Override @@ -192,25 +176,19 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; - if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired, fieldsForModels); + return Objects.hash(type, source, routingRequired); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); - } else { - fieldsForModels = Map.of(); - } } public static Diff readDiffFrom(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java similarity index 97% rename from server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index f8a0291cef204..ffb6567009b79 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -6,7 +6,7 @@ * Side Public License, v 1. 
*/ -package org.elasticsearch.action.inference; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersions; @@ -14,9 +14,6 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.InferenceResults; @@ -28,6 +25,9 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java similarity index 98% rename from server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index da3e34c79e9fd..73ba6544fc86e 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -6,7 +6,7 @@ * Side Public License, v 1. */ -package org.elasticsearch.action.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java similarity index 93% rename from server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index 85fecdae6f39c..910ea5cab214d 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -1,11 +1,3 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -13,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.action.inference.results; +package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -99,7 +91,7 @@ public List transformToLegacyFormat() { return embeddings.stream() .map( embedding -> new TextExpansionResults( - InferenceConfig.DEFAULT_RESULTS_FIELD, + DEFAULT_RESULTS_FIELD, embedding.tokens() .stream() .map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token, weightedToken.weight)) diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java similarity index 93% rename from server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 5e757f1401f74..ace5974866038 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -1,14 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.action.inference.results; +package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 9f08667f85572..00587936848f8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -29,7 +29,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java index ba5ba3fcb7a5c..b077c93c141a5 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NerResults.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java index cf6e29be1746c..a49e81e40a7a6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java similarity index 86% rename from server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java index fee00551bb528..4efb719137c65 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/NlpInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java @@ -1,11 +1,3 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -13,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.action.inference.results; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index 83d4f204cd174..de49fb2252ad0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index 9f2c8f2b77a70..e9e41ce963bec 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java index a92c53378a5dc..526c2ec7b7aaa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java similarity index 87% rename from server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java index e54089187e121..45aa4d51e0ad6 100644 --- a/server/src/main/java/org/elasticsearch/action/inference/results/TextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java @@ -1,11 +1,3 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -13,13 +5,12 @@ * 2.0. */ -package org.elasticsearch.action.inference.results; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; @@ -96,7 +87,7 @@ public Object predictedValue() { } @Override - void doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { + void doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.startObject(resultsField); for (var weightedToken : weightedTokens) { weightedToken.toXContent(builder, params); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java index 1fe69a16b41b3..b8b75e2bf7eb4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java @@ -6,7 +6,6 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 0d85b173be27b..4d8035864729a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java index 8db86bbb2e2af..bda9eed40659c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.inference.results.NlpInferenceResults; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.TestIngestDocument; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java index 84b14e4122d6e..82487960dfe8f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.ingest.IngestDocument; diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java index 0a56cb7d7d7b2..eee6f68c20ff7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -26,6 +25,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index a39f6378af7ce..c632c568fea16 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -7,9 +7,6 @@ package org.elasticsearch.xpack.inference; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; -import 
org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; @@ -17,6 +14,9 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index b3726abc48591..33d71c65ed643 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,7 +10,6 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -36,6 +35,7 @@ import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index cd60e3017c93c..db98aeccc556b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.inject.Inject; @@ -20,6 +19,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java index f3836d4d7528f..247537b9958d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index 69f156509b803..b74b03891034f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 70d0031730c48..4926ba3f0ef6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -7,13 +7,13 @@ package org.elasticsearch.xpack.inference.external.response.openai; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import 
org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index f04230e7697dc..beecf75da38ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.inference.rest; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.io.IOException; import java.util.List; @@ -33,9 +33,8 @@ public List routes() { protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String taskType = restRequest.param("task_type"); String modelId = restRequest.param("model_id"); - try (var parser = restRequest.contentParser()) { - var request = InferenceAction.Request.parseRequest(modelId, taskType, parser); - return channel -> client.execute(InferenceAction.INSTANCE, request, new RestToXContentListener<>(channel)); - } + var request = InferenceAction.Request.parseRequest(modelId, taskType, restRequest.contentParser()); + + return channel -> client.execute(InferenceAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 7dfaff5dc8a9c..1686cd32d4a6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -9,7 +9,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Strings; @@ -17,6 +16,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.net.URI; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 2da6669f05c4e..01fe828d723d2 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -11,7 +11,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; @@ -22,6 +21,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index 205ab2a9688dc..aa540694ba564 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -7,12 +7,12 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 3b6be0f5aa9b3..759411cec1212 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.TransportVersion; -import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; @@ -26,7 +26,7 @@ import static org.elasticsearch.TransportVersions.INFERENCE_SERVICE_RESULTS_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_OPENAI_ADDED; import static org.elasticsearch.TransportVersions.ML_INFERENCE_TASK_SETTINGS_OPTIONAL_ADDED; -import static 
org.elasticsearch.action.inference.InferenceAction.Response.transformToServiceResults; +import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Response.transformToServiceResults; public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index bcb4468d92d12..bdb8e38fa8228 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -8,11 +8,11 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; import org.apache.http.HttpResponse; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentEOFException; import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index 54a469cfdb517..2b6e11fdfafa7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.huggingface; import org.apache.http.HttpResponse; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 6bba5f37f4afd..010e990a3ce80 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.external.response.openai; import org.apache.http.HttpResponse; -import 
org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java index 1ad94598e2ef9..605411343533f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/LegacyTextEmbeddingResultsTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.results; -import org.elasticsearch.action.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -15,6 +14,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 0f725a5c91e8b..6f8fa0c453d09 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.inference.results; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index fcbbc5f08a49e..09d9894d98853 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.inference.results; -import org.elasticsearch.action.inference.results.TextEmbeddingResults; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import java.io.IOException; import 
java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 3c37eda58ba43..9c368c1a162a8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -25,7 +25,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.action.inference.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 57709bdb33e42..6483b9d9b3da9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 4c065d9195af5..a392996fbb448 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index 835fd611b5a7d..a09bcadaacfc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -27,7 +27,7 @@ import 
org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java index 22dab2e3801d2..c94775b1785c9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index d7c1204249ff0..13f12f3cdc1e1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -33,7 +33,7 @@ import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java index f9cd283e93b1f..59d6db2c2ea4f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilderTests.java @@ -33,7 +33,7 @@ import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; -import org.elasticsearch.action.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; @@ -41,7 +41,7 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.action.inference.results.TextExpansionResults.WeightedToken; +import static org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults.WeightedToken; import static 
org.elasticsearch.xpack.ml.queries.WeightedTokensQueryBuilder.TOKENS_FIELD; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; From 360be07350aaa36f780c80f7c9d3ed76c68165ef Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 18:34:50 +0100 Subject: [PATCH 045/106] Add a new InferenceProvider interface to avoid moving InferenceAction to server --- .../inference/InferenceProvider.java | 28 ++++++++ .../inference/InferenceProviderException.java | 18 +++++ .../inference/TestInferenceResults.java | 67 +++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceProvider.java create mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java create mode 100644 server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java new file mode 100644 index 0000000000000..37fc934104f18 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.action.ActionListener; + +import java.util.List; + +/** + * Provides NLP text inference results. Plugins can implement this interface to provide their own inference results. + */ +public interface InferenceProvider { + /** + * Returns nferenceResults for a given model ID and list of texts. + * + * @param modelId model identifier + * @param texts texts to perform inference on + * @param listener listener to be called when inference is complete + */ + void textInference(String modelId, List texts, ActionListener> listener) + throws InferenceProviderException; +} diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java b/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java new file mode 100644 index 0000000000000..0d82bc800f414 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.ElasticsearchException; + +public class InferenceProviderException extends ElasticsearchException { + + public InferenceProviderException(String msg, Throwable cause, Object... 
args) { + super(msg, cause, args); + } +} diff --git a/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java b/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java new file mode 100644 index 0000000000000..7b1445ae7ad5f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.util.HashMap; +import java.util.Map; + +public class TestInferenceResults implements InferenceResults { + + private final String resultField; + private final Map inferenceResults; + + public TestInferenceResults(String resultField, Map inferenceResults) { + this.resultField = resultField; + this.inferenceResults = inferenceResults; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnsupportedEncodingException(); + } + + @Override + public String getWriteableName() { + return "test_inference_results"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public String getResultsField() { + return resultField; + } + + @Override + public Map asMap() { + Map result = new HashMap<>(); + result.put(resultField, inferenceResults); + return result; + } + + @Override + public Map asMap(String outputField) { + Map result = new HashMap<>(); + result.put(outputField, inferenceResults); + return result; + } + + @Override + public Object predictedValue() { + throw new UnsupportedOperationException(); + } +} From 25e4fc47f6880814a2c96452ae6903d637d5d4bf Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 18:35:34 +0100 Subject: [PATCH 046/106] Use InferenceProvider in TransportBulkAction instead of directly invoking the InferenceAction --- .../action/bulk/TransportBulkAction.java | 372 ++++++++++++++---- .../bulk/TransportSimulateBulkAction.java | 3 +- .../cluster/metadata/IndexMetadata.java | 44 +-- ...ActionIndicesThatCannotBeCreatedTests.java | 3 +- .../TransportBulkActionInferenceTests.java | 63 +-- .../bulk/TransportBulkActionIngestTests.java | 3 +- .../action/bulk/TransportBulkActionTests.java | 3 +- .../bulk/TransportBulkActionTookTests.java | 4 +- .../snapshots/SnapshotResiliencyTests.java | 4 +- 9 files changed, 357 insertions(+), 142 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index be976af717cd3..c8ace03dfe334 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -25,10 +25,14 @@ import org.elasticsearch.action.admin.indices.create.AutoCreateAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import 
org.elasticsearch.action.admin.indices.rollover.RolloverAction; +import org.elasticsearch.action.admin.indices.rollover.RolloverRequest; +import org.elasticsearch.action.admin.indices.rollover.RolloverResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.IngestActionForwarder; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.WriteResponse; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.update.UpdateRequest; @@ -60,6 +64,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -68,6 +74,7 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -76,8 +83,8 @@ import java.util.Objects; import java.util.Set; import java.util.SortedMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.function.LongSupplier; import java.util.stream.Collectors; @@ -105,6 +112,11 @@ public class TransportBulkAction extends HandledTransportAction indices = bulkRequest.requests.stream() // delete requests should not attempt to create the index (if the index does not - // exists), unless an external versioning is used + // exist), unless an external versioning is used .filter( request -> request.opType() != DocWriteRequest.OpType.DELETE || request.versionType() == VersionType.EXTERNAL @@ -365,20 +382,28 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec } } - // Step 3: create all the indices that are missing, if there are any missing. start the bulk after all the creates come back. + // Step 3: Collect all the data streams that need to be rolled over before writing + Set dataStreamsToBeRolledOver = indices.keySet().stream().filter(target -> { + DataStream dataStream = state.metadata().dataStreams().get(target); + return dataStream != null && dataStream.rolloverOnWrite(); + }).collect(Collectors.toSet()); + + // Step 4: create all the indices that are missing, if there are any missing. start the bulk after all the creates come back. 
createMissingIndicesAndIndexData( task, bulkRequest, executorName, listener, autoCreateIndices, + dataStreamsToBeRolledOver, indicesThatCannotBeCreated, startTime ); } /* - * This method is responsible for creating any missing indices and indexing the data in the BulkRequest + * This method is responsible for creating any missing indices, rolling over a data stream when needed and then + * indexing the data in the BulkRequest */ protected void createMissingIndicesAndIndexData( Task task, @@ -386,22 +411,27 @@ protected void createMissingIndicesAndIndexData( String executorName, ActionListener listener, Set autoCreateIndices, + Set dataStreamsToBeRolledOver, Map indicesThatCannotBeCreated, long startTime ) { final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); - if (autoCreateIndices.isEmpty()) { + // Optimizing when there are no prerequisite actions + if (autoCreateIndices.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); - } else { - final AtomicInteger counter = new AtomicInteger(autoCreateIndices.size()); + return; + } + Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { + @Override + protected void doRun() { + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + } + }); + try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { for (String index : autoCreateIndices) { - createIndex(index, bulkRequest.timeout(), new ActionListener<>() { + createIndex(index, bulkRequest.timeout(), ActionListener.releaseAfter(new ActionListener<>() { @Override - public void onResponse(CreateIndexResponse result) { - if (counter.decrementAndGet() == 0) { - forkExecuteBulk(listener); - } - } + public void onResponse(CreateIndexResponse createIndexResponse) {} @Override public void onFailure(Exception e) { @@ -412,30 +442,47 @@ public void onFailure(Exception e) { } } else if ((cause instanceof ResourceAlreadyExistsException) == false) { // fail all requests involving this index, if create didn't work - for (int i = 0; i < bulkRequest.requests.size(); i++) { - DocWriteRequest request = bulkRequest.requests.get(i); - if (request != null && setResponseFailureIfIndexMatches(responses, i, request, index, e)) { - bulkRequest.requests.set(i, null); - } - } - } - if (counter.decrementAndGet() == 0) { - forkExecuteBulk(ActionListener.wrap(listener::onResponse, inner -> { - inner.addSuppressed(e); - listener.onFailure(inner); - })); + failRequestsWhenPrerequisiteActionFailed(index, bulkRequest, responses, e); } } + }, refs.acquire())); + } + for (String dataStream : dataStreamsToBeRolledOver) { + rolloverDataStream(dataStream, bulkRequest.timeout(), ActionListener.releaseAfter(new ActionListener<>() { - private void forkExecuteBulk(ActionListener finalListener) { - threadPool.executor(executorName).execute(new ActionRunnable<>(finalListener) { - @Override - protected void doRun() { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); - } - }); + @Override + public void onResponse(RolloverResponse result) { + // A successful response has rolled_over false when in the following cases: + // - A request had the parameter lazy or dry_run enabled + // - A request had conditions that were not met + // Since none of the above apply, getting a response with rolled_over false is considered a bug + // that 
should be caught here and reported to the developer. + assert result.isRolledOver() : "A successful unconditional rollover should always result in a rolled over data stream"; } - }); + + @Override + public void onFailure(Exception e) { + failRequestsWhenPrerequisiteActionFailed(dataStream, bulkRequest, responses, e); + } + }, refs.acquire())); + } + } + } + + /** + * Fails all requests involving this index or data stream because the prerequisite action failed too. + */ + private static void failRequestsWhenPrerequisiteActionFailed( + String target, + BulkRequest bulkRequest, + AtomicArray responses, + Exception error + ) { + for (int i = 0; i < bulkRequest.requests.size(); i++) { + DocWriteRequest request = bulkRequest.requests.get(i); + if (request != null && setResponseFailureIfIndexMatches(responses, i, request, target, error)) { + bulkRequest.requests.set(i, null); } } } @@ -538,6 +585,12 @@ void createIndex(String index, TimeValue timeout, ActionListener listener) { + RolloverRequest rolloverRequest = new RolloverRequest(dataStream, null); + rolloverRequest.masterNodeTimeout(timeout); + client.execute(RolloverAction.INSTANCE, rolloverRequest, listener); + } + private static boolean setResponseFailureIfIndexMatches( AtomicArray responses, int idx, @@ -633,8 +686,7 @@ private Map> groupRequestsByShards(ClusterState c if (ia.getParentDataStream() != null && // avoid valid cases when directly indexing into a backing index // (for example when directly indexing into .ds-logs-foobar-000001) - ia.getName().equals(docWriteRequest.index()) == false - && docWriteRequest.opType() != OpType.CREATE) { + ia.getName().equals(docWriteRequest.index()) == false && docWriteRequest.opType() != OpType.CREATE) { throw new IllegalArgumentException("only write ops with an op_type of create are allowed in data streams"); } @@ -674,29 +726,220 @@ private void executeBulkRequestsByShard(Map> requ return; } - final AtomicInteger counter = new AtomicInteger(requestsByShard.size()); String nodeId = clusterService.localNode().getId(); - for (Map.Entry> entry : requestsByShard.entrySet()) { - final ShardId shardId = entry.getKey(); - final List requests = entry.getValue(); - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, + Runnable onBulkItemsComplete = () -> { + listener.onResponse( + new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) + ); + // Allow memory for bulk shard request items to be reclaimed before all items have been completed + bulkRequest = null; + }; + + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { + for (Map.Entry> entry : requestsByShard.entrySet()) { + final ShardId shardId = entry.getKey(); + final List requests = entry.getValue(); + + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + + performInferenceAndExecute(bulkShardRequest, clusterState, bulkItemRequestCompleteRefCount.acquire()); + } + } + } + + private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) { + + Map> fieldsForModels =
clusterState.metadata() + .index(bulkShardRequest.shardId().getIndex()) + .getFieldsForModels(); + // No inference fields? Just execute the request + if (fieldsForModels.isEmpty()) { + executeBulkShardRequest(bulkShardRequest, releaseOnFinish); + return; + } + + Runnable onInferenceComplete = () -> { + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); + executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish); + }; + + try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { + for (BulkItemRequest request : bulkShardRequest.items()) { + performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire()); } + } + } - executeBulkShardRequest(bulkShardRequest, requests, counter); + private void performInferenceOnBulkItemRequest( + BulkShardRequest bulkShardRequest, + BulkItemRequest request, + Map> fieldsForModels, + Releasable releaseOnFinish + ) { + if (inferenceProvider == null) { + releaseOnFinish.close(); + return; + } + + DocWriteRequest docWriteRequest = request.request(); + Map sourceMap = null; + if (docWriteRequest instanceof IndexRequest indexRequest) { + sourceMap = indexRequest.sourceAsMap(); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); + } + if (sourceMap == null || sourceMap.isEmpty()) { + releaseOnFinish.close(); + return; + } + final Map docMap = new ConcurrentHashMap<>(sourceMap); + + // When a document completes processing, update the source with the inference + try (var docRef = new RefCountingRunnable(() -> { + if (docWriteRequest instanceof IndexRequest indexRequest) { + indexRequest.source(docMap); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + if (updateRequest.docAsUpsert()) { + updateRequest.upsertRequest().source(docMap); + } else { + updateRequest.doc().source(docMap); + } + } + releaseOnFinish.close(); + })) { + + for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { + String modelId = fieldModelsEntrySet.getKey(); + + @SuppressWarnings("unchecked") + Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( + ROOT_RESULT_FIELD, + k -> new HashMap() + ); + + List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); + + if (inferenceFieldNames.isEmpty()) { + continue; + } + + docRef.acquire(); + + inferenceProvider.textInference( + modelId, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + new ActionListener<>() { + + @Override + public void onResponse(List results) { + + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ); + } + + int i = 0; + for (InferenceResults inferenceResults : results) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() + ); + + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); + } + + docRef.close(); + } + + @Override + public void onFailure(Exception e) { + + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure( + indexName, + docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) + ); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; + + docRef.close(); + } + }); + } + } + } + + private static List getFieldNamesForInference( + Map.Entry> fieldModelsEntrySet, + Map docMap + ) { + List inferenceFieldNames = new ArrayList<>(); + for (String inferenceField : fieldModelsEntrySet.getValue()) { + Object fieldValue = docMap.get(inferenceField); + + // Perform inference on string, non-null values + if (fieldValue instanceof String fieldStringValue) { + + // Only do inference if the previous text value doesn't match the new one + String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, inferenceField, TEXT_FIELD); + if (fieldStringValue.equals(previousValue) == false) { + inferenceFieldNames.add(inferenceField); + } + } + } + return inferenceFieldNames; + } + + @SuppressWarnings("unchecked") + private static String findMapValue(Map map, String... 
path) { + Map currentMap = map; + for (int i = 0; i < path.length - 1; i++) { + Object value = currentMap.get(path[i]); + + if (value instanceof Map) { + currentMap = (Map) value; + } else { + // Invalid path or non-Map value encountered + return null; + } } - bulkRequest = null; // allow memory for bulk request items to be reclaimed before all items have been completed + + // Retrieve the final value in the map, if it's a String + Object finalValue = currentMap.get(path[path.length - 1]); + + return (finalValue instanceof String) ? (String) finalValue : null; } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, List requests, AtomicInteger counter) { + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { + if (bulkShardRequest.items().length == 0) { + // No requests to execute due to previous errors, terminate early + releaseOnFinish.close(); + return; + } + client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -707,30 +950,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - maybeFinishHim(); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - for (BulkItemRequest request : requests) { + for (BulkItemRequest request : bulkShardRequest.items()) { final String indexName = request.index(); DocWriteRequest docWriteRequest = request.request(); BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - maybeFinishHim(); - } - - private void maybeFinishHim() { - if (counter.decrementAndGet() == 0) { - listener.onResponse( - new BulkResponse( - responses.toArray(new BulkItemResponse[responses.length()]), - buildTookInMillis(startTimeNanos) - ) - ); - } + releaseOnFinish.close(); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index a44c8091aaa2e..1261b716869de 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -55,7 +55,8 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime + System::nanoTime, + null ); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index e67eca4c4c97b..7dc99d8bc527e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1418,9 +1418,6 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } - if (fieldsForModels.equals(that.fieldsForModels) == false) { - return false; - } if (isSystem != that.isSystem) { return false; } @@ -1441,7 +1438,6 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); - result = 31 * result + 
fieldsForModels.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1675,7 +1671,6 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); - builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); return builder.build(true); } } @@ -1743,11 +1738,6 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) - ); - } return builder.build(true); } @@ -1794,9 +1784,6 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalDouble(writeLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); - } } @Override @@ -1846,7 +1833,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private final ImmutableOpenMap.Builder> fieldsForModels; + private Map> fieldsForModels = Map.of(); public Builder(String index) { this.index = index; @@ -1854,7 +1841,6 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); - this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1879,7 +1865,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); + this.fieldsForModels = indexMetadata.fieldsForModels; } public Builder index(String index) { @@ -1961,10 +1947,6 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; - if (mappingMd != null) { - Map> fieldsForModels = mappingMd.getFieldsForModels(); - processFieldsForModels(this.fieldsForModels, fieldsForModels); - } return this; } @@ -2114,7 +2096,8 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { } public Builder fieldsForModels(Map> fieldsForModels) { - processFieldsForModels(this.fieldsForModels, fieldsForModels); + // TODO: How to handle null value? Clear this.fieldsForModels? + this.fieldsForModels = fieldsForModels; return this; } @@ -2313,7 +2296,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels.build() + fieldsForModels ); } @@ -2439,8 +2422,10 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - if (indexMetadata.fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); + // TODO: Need null check? 
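For readers following the fieldsForModels plumbing across these commits, a small sketch of the shape this map is expected to take when it is built and rendered under KEY_FIELDS_FOR_MODELS. The model ids and field names are invented for illustration and do not come from this change.

import java.util.Map;
import java.util.Set;

class FieldsForModelsSketch {
    // Keys are inference model ids, values are the names of the fields that use that model.
    static Map<String, Set<String>> example() {
        return Map.of(
            "my-elser-model", Set.of("title_inference_field", "body_inference_field"),
            "my-text-embedding-model", Set.of("abstract_inference_field")
        );
    }
}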
+ Map> fieldsForModels = indexMetadata.getFieldsForModels(); + if (fieldsForModels != null && fieldsForModels.isEmpty() == false) { + builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); } builder.endObject(); @@ -2729,17 +2714,6 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } - - private static void processFieldsForModels( - ImmutableOpenMap.Builder> builder, - Map> fieldsForModels - ) { - builder.clear(); - if (fieldsForModels != null) { - // Ensure that all field sets contained in the processed map are immutable - fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); - } - } } /** diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 1276f6c2db58b..49bd0f48d44b8 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -123,7 +123,8 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null ) { @Override void executeBulk( diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index c10687055c632..2f5fd465ce6e3 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -13,8 +13,6 @@ import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.inference.InferenceAction; -import org.elasticsearch.action.inference.results.SparseEmbeddingResults; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.AutoCreateIndex; @@ -30,11 +28,16 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; +import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.inference.InferenceProviderException; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TestInferenceResults; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; @@ -51,6 +54,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; @@ -76,6 +80,8 @@ public class TransportBulkActionInferenceTests extends 
ESTestCase { private NodeClient nodeClient; private TransportBulkAction transportBulkAction; + private InferenceProvider inferenceProvider; + @Before public void setup() { threadPool = new TestThreadPool(getClass().getName()); @@ -111,6 +117,8 @@ public void setup() { clusterService = ClusterServiceUtils.createClusterService(state, threadPool); + inferenceProvider = mock(InferenceProvider.class); + transportBulkAction = new TransportBulkAction( threadPool, mock(TransportService.class), @@ -120,16 +128,17 @@ public void setup() { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + inferenceProvider ); // Default answers to avoid hanging tests due to unexpected invocations doAnswer(invocation -> { @SuppressWarnings("unchecked") - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onFailure(new Exception("Unexpected invocation")); + var listener = (ActionListener>) invocation.getArguments()[2]; + listener.onFailure(new InferenceProviderException("Unexpected invocation", null)); return Void.TYPE; - }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), any(), any()); + }).when(inferenceProvider).textInference(any(), any(), any()); when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; @@ -268,7 +277,7 @@ public void testFailingInference() { } private void verifyInferenceExecuted(VerificationMode verificationMode) { - verify(nodeClient, verificationMode).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + verify(inferenceProvider, verificationMode).textInference(any(), any(), any()); } private void expectTransportShardBulkActionRequest(int requestSize) { @@ -301,40 +310,34 @@ private boolean matchBulkShardRequest(ActionRequest request, int requestSize) { @SuppressWarnings("unchecked") private void expectInferenceRequest(String modelId, String... inferenceTexts) { doAnswer(invocation -> { - InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; - var listener = (ActionListener) invocation.getArguments()[2]; + List texts = (List) invocation.getArguments()[1]; + var listener = (ActionListener>) invocation.getArguments()[2]; listener.onResponse( - new InferenceAction.Response( - new SparseEmbeddingResults( - request.getInput() + texts .stream() .map( - text -> new SparseEmbeddingResults.Embedding( - List.of(new SparseEmbeddingResults.WeightedToken(text.toString(), 1.0f)), - false + text -> new TestInferenceResults( + "test_field", + randomMap( + 1, + 10, + () -> new Tuple<>(randomAlphaOfLengthBetween(1, 10), randomFloat()) + ) ) - ) - .toList() - ) - ) - ); + ).collect(Collectors.toList())); return Void.TYPE; - }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); + }).when(inferenceProvider) + .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); } private void expectInferenceRequestFails(String modelId, String... 
inferenceTexts) { doAnswer(invocation -> { @SuppressWarnings("unchecked") - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onFailure(new Exception("Inference failed")); + var listener = (ActionListener>) invocation.getArguments()[2]; + listener.onFailure(new InferenceProviderException("Inference failed", null)); return Void.TYPE; - }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); + }).when(inferenceProvider) + .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); } - private boolean inferenceRequestMatches(ActionRequest request, String modelId, String[] inferenceTexts) { - if (request instanceof InferenceAction.Request inferenceRequest) { - return inferenceRequest.getModelId().equals(modelId) && inferenceRequest.getInput().containsAll(List.of(inferenceTexts)); - } - return false; - } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index f30bceada65d9..0cb030624013b 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -135,7 +135,8 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index a2e164f6a242c..15f716d2eabf9 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -87,7 +87,8 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index d4c5fc09e821f..ce8e762f4e755 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -252,8 +252,8 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider - ); + relativeTimeProvider, + null); } } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 1df74c787eec4..4c863ab6e7eb0 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1936,12 +1936,14 @@ protected void assertSnapshotOrGenericThread() { client, null, () -> DocumentParsingObserver.EMPTY_INSTANCE + ), client, actionFilters, indexNameExpressionResolver, new 
IndexingPressure(settings), - EmptySystemIndices.INSTANCE + EmptySystemIndices.INSTANCE, + null ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( From fac291305ae211f6dc82fbfdb825fbca46a53865 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 19:26:27 +0100 Subject: [PATCH 047/106] First version for adding InferenceProviderPlugin to the InferencePlugin --- .../elasticsearch/node/NodeConstruction.java | 6 +++ .../plugins/InferenceProviderPlugin.java | 25 +++++++++ .../InferenceActionInferenceProvider.java | 54 +++++++++++++++++++ .../xpack/inference/InferencePlugin.java | 16 +++++- 4 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 018abebdb7709..f673ea4e4e209 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -122,6 +122,7 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -140,6 +141,7 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; +import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1077,6 +1079,10 @@ record PluginServiceInstances( ); } + getSinglePlugin(InferenceProviderPlugin.class).ifPresent(plugin -> { + modules.add(b -> b.bind(InferenceProvider.class).toInstance(plugin.getInferenceProvider())); + }); + injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java new file mode 100644 index 0000000000000..ebd307d3d02c0 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.plugins; + +import org.elasticsearch.inference.InferenceProvider; + +/** + * An extension point for {@link Plugin} implementations to add inference plugins for use on document ingestion + */ +public interface InferenceProviderPlugin { + + /** + * Returns the inference provider added by this plugin. 
+ * + * @return InferenceProvider added by the plugin + */ + InferenceProvider getInferenceProvider(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java new file mode 100644 index 0000000000000..1305bab3b4a38 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.inference.InferenceProviderException; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; + +public class InferenceActionInferenceProvider implements InferenceProvider { + + private final Client client; + + public InferenceActionInferenceProvider(Client client) { + this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); + } + + @Override + public void textInference(String modelId, List texts, ActionListener> listener) + throws InferenceProviderException { + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified + modelId, + texts, + Map.of() + ); + + client.execute(InferenceAction.INSTANCE, inferenceRequest, listener.delegateFailure((l, response) -> { + InferenceServiceResults results = response.getResults(); + if (results == null) { + throw new IllegalArgumentException("No inference retrieved for model ID " + modelId); + } + + @SuppressWarnings("unchecked") + List result = (List) results.transformToLegacyFormat(); + l.onResponse(result); + })); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 33d71c65ed643..6537156eaa1c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,10 +21,12 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.SystemIndexDescriptor; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; +import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; 
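A possible caller-side sketch of the InferenceProvider contract introduced above, assuming the textInference signature shown in InferenceActionInferenceProvider. The model id, the input strings, and the result handling are placeholders rather than values taken from this change.

import java.util.List;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceProvider;
import org.elasticsearch.inference.InferenceResults;

class InferenceProviderUsageSketch {
    // Asks the provider for inference on a batch of texts and prints the per-text result maps.
    static void infer(InferenceProvider provider) {
        provider.textInference(
            "my-model-id",
            List.of("the quick brown fox"),
            ActionListener.wrap(
                results -> results.forEach(r -> System.out.println(r.asMap("output"))),
                e -> System.err.println("inference failed: " + e.getMessage())
            )
        );
    }
}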
@@ -66,7 +68,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceProviderPlugin { public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -77,6 +79,8 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); + + private final SetOnce inferenceProvider = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -141,7 +145,10 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); - return List.of(modelRegistry, registry); + var provider = new InferenceActionInferenceProvider(services.client()); + inferenceProvider.set(provider); + + return List.of(modelRegistry, registry, provider); } @Override @@ -235,4 +242,9 @@ public void close() { IOUtils.closeWhileHandlingException(httpManager.get(), throttlerToClose); } + + @Override + public InferenceProvider getInferenceProvider() { + return inferenceProvider.get(); + } } From fb02ae4415b9327aa2cf63ef7fa94e8e35ce90fe Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 19:36:13 +0100 Subject: [PATCH 048/106] Adding IndexMetadata support for fieldsForModels as it stands today --- .../cluster/metadata/IndexMetadata.java | 44 +++++++++++++++---- .../cluster/metadata/MappingMetadata.java | 24 +++++++++- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 7dc99d8bc527e..e67eca4c4c97b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -1418,6 +1418,9 @@ public boolean equals(Object o) { if (rolloverInfos.equals(that.rolloverInfos) == false) { return false; } + if (fieldsForModels.equals(that.fieldsForModels) == false) { + return false; + } if (isSystem != that.isSystem) { return false; } @@ -1438,6 +1441,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(primaryTerms); result = 31 * result + inSyncAllocationIds.hashCode(); result = 31 * result + rolloverInfos.hashCode(); + result = 31 * result + fieldsForModels.hashCode(); result = 31 * result + Boolean.hashCode(isSystem); return result; } @@ -1671,6 +1675,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); + builder.fieldsForModels(fieldsForModels.apply(part.fieldsForModels)); return builder.build(true); } } @@ -1738,6 +1743,11 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) + ); + } return builder.build(true); } @@ -1784,6 +1794,9 @@ public void writeTo(StreamOutput out, boolean mappingsAsHash) throws IOException out.writeOptionalDouble(writeLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if 
(out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -1833,7 +1846,7 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private Map> fieldsForModels = Map.of(); + private final ImmutableOpenMap.Builder> fieldsForModels; public Builder(String index) { this.index = index; @@ -1841,6 +1854,7 @@ public Builder(String index) { this.customMetadata = ImmutableOpenMap.builder(); this.inSyncAllocationIds = new HashMap<>(); this.rolloverInfos = ImmutableOpenMap.builder(); + this.fieldsForModels = ImmutableOpenMap.builder(); this.isSystem = false; } @@ -1865,7 +1879,7 @@ public Builder(IndexMetadata indexMetadata) { this.stats = indexMetadata.stats; this.indexWriteLoadForecast = indexMetadata.writeLoadForecast; this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast; - this.fieldsForModels = indexMetadata.fieldsForModels; + this.fieldsForModels = ImmutableOpenMap.builder(indexMetadata.fieldsForModels); } public Builder index(String index) { @@ -1947,6 +1961,10 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; + if (mappingMd != null) { + Map> fieldsForModels = mappingMd.getFieldsForModels(); + processFieldsForModels(this.fieldsForModels, fieldsForModels); + } return this; } @@ -2096,8 +2114,7 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { } public Builder fieldsForModels(Map> fieldsForModels) { - // TODO: How to handle null value? Clear this.fieldsForModels? - this.fieldsForModels = fieldsForModels; + processFieldsForModels(this.fieldsForModels, fieldsForModels); return this; } @@ -2296,7 +2313,7 @@ IndexMetadata build(boolean repair) { stats, indexWriteLoadForecast, shardSizeInBytesForecast, - fieldsForModels + fieldsForModels.build() ); } @@ -2422,10 +2439,8 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - // TODO: Need null check? 
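The write and read paths above share one transport-version gate: older nodes never see the new field, and newer nodes fall back to an empty map when talking to an older node. A condensed sketch of that pattern, using only the stream calls that appear in this change; the wrapper class name is a placeholder.

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

class VersionGatedWireSketch {
    // Only serialize the new field to nodes that understand it.
    static void write(StreamOutput out, Map<String, Set<String>> fieldsForModels) throws IOException {
        if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) {
            out.writeMap(fieldsForModels, StreamOutput::writeStringCollection);
        }
    }

    // Read the field when present, otherwise default to an empty map.
    static Map<String, Set<String>> read(StreamInput in) throws IOException {
        if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) {
            return in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString));
        }
        return Map.of();
    }
}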
- Map> fieldsForModels = indexMetadata.getFieldsForModels(); - if (fieldsForModels != null && fieldsForModels.isEmpty() == false) { - builder.field(KEY_FIELDS_FOR_MODELS, fieldsForModels); + if (indexMetadata.fieldsForModels.isEmpty() == false) { + builder.field(KEY_FIELDS_FOR_MODELS, indexMetadata.fieldsForModels); } builder.endObject(); @@ -2714,6 +2729,17 @@ private static void handleLegacyMapping(Builder builder, Map map builder.putMapping(new MappingMetadata(MapperService.SINGLE_MAPPING_NAME, mapping)); } } + + private static void processFieldsForModels( + ImmutableOpenMap.Builder> builder, + Map> fieldsForModels + ) { + builder.clear(); + if (fieldsForModels != null) { + // Ensure that all field sets contained in the processed map are immutable + fieldsForModels.forEach((k, v) -> builder.put(k, Set.copyOf(v))); + } + } } /** diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index b629ab5d5f710..64a61f854b9da 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -18,11 +18,13 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MappingLookup; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -42,10 +44,15 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; + private final Map> fieldsForModels; + public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); + + MappingLookup mappingLookup = docMapper.mappers(); + this.fieldsForModels = mappingLookup != null ? 
mappingLookup.getFieldsForModels() : Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -57,6 +64,7 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); + this.fieldsForModels = Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -72,6 +80,7 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); + this.fieldsForModels = Map.of(); } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -158,12 +167,19 @@ public String getSha256() { return source.getSha256(); } + public Map> getFieldsForModels() { + return fieldsForModels; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); + if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); + } } @Override @@ -176,19 +192,25 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; + if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired); + return Objects.hash(type, source, routingRequired, fieldsForModels); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { + fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); + } else { + fieldsForModels = Map.of(); + } } public static Diff readDiffFrom(StreamInput in) throws IOException { From 60345794c33f665ccfdfa8fff1fcac53cd75b055 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 19:39:04 +0100 Subject: [PATCH 049/106] Revert "First attempt on creating an IT test" This reverts commit 39061b4cdc02f644496e08c888f62bc402815881. 
--- .../ml-inference-service-tests/build.gradle | 4 - .../ml/integration/SemanticTextFieldIT.java | 162 ------------------ .../qa/native-multi-node-tests/build.gradle | 1 - 3 files changed, 167 deletions(-) delete mode 100644 x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle index 922492cde3e5d..83226acb383c7 100644 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle +++ b/x-pack/plugin/ml/qa/ml-inference-service-tests/build.gradle @@ -4,10 +4,6 @@ dependencies { javaRestTestImplementation(testArtifact(project(xpackModule('core')))) javaRestTestImplementation(testArtifact(project(xpackModule('ml')))) javaRestTestImplementation project(path: xpackModule('inference')) - javaRestTestImplementation(testArtifact(project(':x-pack:plugin:ml:qa:native-multi-node-tests'), "javaRestTest")) - javaRestTestImplementation(testArtifact(project(":qa:full-cluster-restart"), "javaRestTest")) - - clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') } diff --git a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java b/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java deleted file mode 100644 index b20614191b13d..0000000000000 --- a/x-pack/plugin/ml/qa/ml-inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticTextFieldIT.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.ml.integration; - -import org.elasticsearch.client.Request; -import org.elasticsearch.client.Response; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.rest.ESRestTestCase; -import org.junit.Assert; - -import java.io.IOException; -import java.util.Base64; -import java.util.List; - -import static org.hamcrest.Matchers.equalTo; - -public class SemanticTextFieldIT extends PyTorchModelRestTestCase { - - private static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA" - + "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG" - + "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29" - + "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl" - + "bW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWRT4+cMAzF7" - + "/spfASJomF3e0Ga3nrrn8vcELIyxAzRhAQlpjvbT19DWDrdquqBA/bvPT87nVUxwsm41xPd+PNtUi4a77" - + "KvXs+W8voBAHFSQY3EFCIiHKFp1+p57vs/ShyUccZdoIaz93aBTMR+thbPqru+qKBx8P4q/e8TyxRlmwVc" - + "tJp66H1YmCyS7WsZwD50A2L5V7pCBADGTTOj0bGGE7noQyqzv5JDfp0o9fZRCWqP37yjhE4+mqX5X3AdF" - + "ZHGM/2TzOHDpy1IvQWR+OWo3KwsRiKdpcqg4pBFDtm+QJ7nqwIPckrlnGfFJG0uNhOl38Sjut3pCqg26Qu" - + "Zy8BR9In7ScHHrKkKMW0TIucFrGQXCMpdaDO05O6DpOiy8e4kr0Ed/2YKOIhplW8gPr4ntygrd9ixpx3j9" - + "UZZVRagl2c6+imWUzBjuf5m+Ch7afphuvvW+r/0dsfn+2N9MZGb9+/SFtCYdhd83CMYp+mGy0LiKNs8y/e" - + "UuEA8B/d2z4dfUEsHCFSE3IaCAQAAIAMAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwApAHNpbXBsZ" - + "W1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCJQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlp" - + "aWlpaWlpaWlpaWlpahZHLbtNAFIZtp03rSVIuLRKXjdk5ojitKJsiFq24lem0KKSqpRIZt55gE9/GM+lNL" - + "Fgx4i1Ys2aHhIBXgAVICNggHgNm6rqJN2BZGv36/v/MOWeea/Z5RVHurLfRUsfZXOnccx522itrd53O0vL" - + "qbaKYtsAKUe1pcege7hm9JNtzM8+kOOzNApIX0A3xBXE6YE7g0UWjg2OaZAJXbKvALOnj2GEHKc496ykLkt" - + "gNt3Jz17hprCUxFqExe7YIpQkNpO1/kfHhPUdtUAdH2/gfmeYiIFW7IkM6IBP2wrDNbMe3Mjf2ksiK3Hjg" - + "hg7F2DN9l/omZZl5Mmez2QRk0q4WUUB0+1oh9nDwxGdUXJdXPMRZQs352eGaRPV9s2lcMeZFGWBfKJJiw0Y" - + "gbCMLBaRmXyy4flx6a667Fch55q05QOq2Jg2ANOyZwplhNsjiohVApo7aa21QnNGW5+4GXv8gxK1beBeHSR" - + "rhmLXWVh+0aBhErZ7bx1ejxMOhlR6QU4ycNqGyk8/yNGCWkwY7/RCD7UEQek4QszCgDJAzZtfErA0VqHBy9" - + "ugQP9pUfUmgCjVYgWNwHFbhBJyEOgSwBuuwARWZmoI6J9PwLfzEocpRpPrT8DP8wqHG0b4UX+E3DiscvRgl" - + "XIoi81KKPwioHI5x9EooNKWiy0KOc/T6WF4SssrRuzJ9L2VNRXUhJzj6UKYfS4W/q/5wuh/l4M9R9qsU+y2" - + "dpoo2hJzkaEET8r6KRONicnRdK9EbUi6raFVIwNGjsrlbpk6ZPi7TbS3fv3LyNjPiEKzG0aG0tvNb6xw90/" - + "whe6ONjnJcUxobHDUqQ8bIOW79BVBLBwhfSmPKdAIAAE4EAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAA" - + "BkABQBzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsRkIBAFqAAikuUEsHCG0vCVcEAAAABAAAAFBLAwQAAAgI" - + "AAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25GQjcAWlpaWlpaWlpaWlpaWlpaWlp" - + "aWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAA" - + "AACAgAAAAAAAAhOZuwWAAAAFgAAAAUAAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLA" - + "QIAABQACAgIAAAAAABUhNyGggEAACADAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19f" - + "dG9yY2hfXy5weVBLAQIAABQACAgIAAAAAABfSmPKdAIAAE4EAAAnAAAAAAAAAAAAAAAAAJICAABzaW1wbGVt" - + "b2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAAAAAAbS8JVwQAAAAEAAAAGQAA" - + "AAAAAAAAAAAAAACEBQAAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAA" - + "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA" - + "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA="; - - private static final 
long RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; - public static final List VOCABULARY = List.of( - "these", - "are", - "my", - "words", - "the", - "washing", - "machine", - "is", - "leaking", - "octopus", - "comforter", - "smells" - ); - - public void testSemanticTextInference() throws IOException { - String modelId = "semantic-text-model"; - - createTextExpansionModel(modelId); - putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE); - putVocabulary( - VOCABULARY, - modelId - ); - startDeployment(modelId); - - String indexName = modelId + "-index"; - createSemanticTextIndex(indexName); - - int numDocs = ESTestCase.randomIntBetween(1, 10); - bulkIndexDocs(indexName, numDocs); - - for (int i = 0; i < numDocs; i++) { - Request getRequest = new Request("GET", "/" + indexName + "/_doc/" + i); - Response response = ESRestTestCase.client().performRequest(getRequest); - Assert.assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); - } - } - - private void createTextExpansionModel(String modelId) throws IOException { - Request request = new Request("PUT", "/_ml/trained_models/" + modelId); - request.setJsonEntity(""" - { - "description": "a text expansion model", - "model_type": "pytorch", - "inference_config": { - "text_expansion": { - "tokenization": { - "bert": { - "with_special_tokens": false - } - } - } - } - }"""); - ESRestTestCase.client().performRequest(request); - } - - private void createSemanticTextIndex(String indexName) throws IOException { - Request createIndex = new Request("PUT", "/" + indexName); - createIndex.setJsonEntity(""" - { - "mappings": { - "properties": { - "text_field": { - "type": "text" - }, - "inference_field": { - "type": "semantic_text", - "model_id": "semantic-text-model" - } - } - } - }"""); - var response = ESRestTestCase.client().performRequest(createIndex); - assertOkWithErrorMessage(response); - } - - private void bulkIndexDocs(String indexName, int numDocs) throws IOException { - - StringBuilder bulkBuilder = new StringBuilder(); - - for (int i = 0; i < numDocs; i++) { - String createAction = "{\"create\": {\"_index\": \"" + indexName + "\" \"_id\":\"" + i + "\"}}\n"; - bulkBuilder.append(createAction); - bulkBuilder.append("{\"text_field\": \"").append(ESTestCase.randomAlphaOfLengthBetween(1, 100)).append("\","); - - bulkBuilder.append("{\"inference_field\": \""); - bulkBuilder.append(String.join(" ", ESTestCase.randomSubsetOf(ESTestCase.randomIntBetween(1, 10), VOCABULARY))); - bulkBuilder.append("\""); - - bulkBuilder.append("}}\n"); - } - - Request bulkRequest = new Request("POST", "/_bulk"); - - bulkRequest.setJsonEntity(bulkBuilder.toString()); - bulkRequest.addParameter("refresh", "true"); - var bulkResponse = ESRestTestCase.client().performRequest(bulkRequest); - assertOkWithErrorMessage(bulkResponse); - } - -} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle index 5022354c641e6..db53b9aec7f1f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/build.gradle @@ -1,5 +1,4 @@ apply plugin: 'elasticsearch.legacy-java-rest-test' -apply plugin: 'elasticsearch.internal-test-artifact' dependencies { javaRestTestImplementation(testArtifact(project(xpackModule('core')))) From 6ec089e9a59ba4cd8797d17f7b907e507c0cb867 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 19:41:56 +0100 Subject: [PATCH 050/106] spotless --- 
.../action/bulk/TransportBulkAction.java | 70 ++++++++++--------- .../TransportBulkActionInferenceTests.java | 21 +++--- .../bulk/TransportBulkActionTookTests.java | 3 +- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index c8ace03dfe334..e63fa16cd0d60 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -157,7 +157,8 @@ public TransportBulkAction( IndexingPressure indexingPressure, SystemIndices systemIndices, LongSupplier relativeTimeProvider, - InferenceProvider inferenceProvider) { + InferenceProvider inferenceProvider + ) { this( BulkAction.INSTANCE, BulkRequest::new, @@ -849,45 +850,46 @@ private void performInferenceOnBulkItemRequest( @Override public void onResponse(List results) { - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ); + } + + int i = 0; + for (InferenceResults inferenceResults : results) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() + ); + + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); + } + + docRef.close(); } - int i = 0; - for (InferenceResults inferenceResults : results) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() + @Override + public void onFailure(Exception e) { + + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure( + indexName, + docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) ); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); + docRef.close(); } - - docRef.close(); - } - - @Override - public void onFailure(Exception e) { - - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure( - indexName, - docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) - ); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; - - docRef.close(); } - }); + ); } } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 
2f5fd465ce6e3..3e881bb91594f 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -313,18 +313,15 @@ private void expectInferenceRequest(String modelId, String... inferenceTexts) { List texts = (List) invocation.getArguments()[1]; var listener = (ActionListener>) invocation.getArguments()[2]; listener.onResponse( - texts - .stream() - .map( - text -> new TestInferenceResults( - "test_field", - randomMap( - 1, - 10, - () -> new Tuple<>(randomAlphaOfLengthBetween(1, 10), randomFloat()) - ) - ) - ).collect(Collectors.toList())); + texts.stream() + .map( + text -> new TestInferenceResults( + "test_field", + randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLengthBetween(1, 10), randomFloat())) + ) + ) + .collect(Collectors.toList()) + ); return Void.TYPE; }).when(inferenceProvider) .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index ce8e762f4e755..27cf5b101e276 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -253,7 +253,8 @@ static class TestTransportBulkAction extends TransportBulkAction { new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, relativeTimeProvider, - null); + null + ); } } } From 85eeec06a1b13f86628d32ed9ee580ecbac7ea6e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 20:01:16 +0100 Subject: [PATCH 051/106] Remove changes from other branches --- .../cluster/ClusterStateDiffIT.java | 26 +---------- .../cluster/metadata/IndexMetadataTests.java | 37 +-------------- .../index/mapper/FieldTypeLookupTests.java | 22 --------- .../index/mapper/MappingLookupTests.java | 19 -------- .../mapper/MockInferenceModelFieldType.java | 45 ------------------- .../inference/action/InferenceAction.java | 5 +-- .../results/LegacyTextEmbeddingResults.java | 5 +-- .../inference/rest/RestInferenceAction.java | 7 +-- 8 files changed, 11 insertions(+), 155 deletions(-) delete mode 100644 test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index 433b4bdaf5d98..b869b3a90fbce 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -55,7 +55,6 @@ import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -572,7 +571,7 @@ public IndexMetadata randomCreate(String name) { @Override public IndexMetadata randomChange(IndexMetadata part) { IndexMetadata.Builder builder = IndexMetadata.builder(part); - switch (randomIntBetween(0, 3)) { + switch (randomIntBetween(0, 2)) { case 0: builder.settings(Settings.builder().put(part.getSettings()).put(randomSettings(Settings.EMPTY))); break; @@ -586,34 +585,11 @@ public IndexMetadata randomChange(IndexMetadata part) { case 2: 
builder.settings(Settings.builder().put(part.getSettings()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)); break; - case 3: - builder.fieldsForModels(randomFieldsForModels()); - break; default: throw new IllegalArgumentException("Shouldn't be here"); } return builder.build(); } - - /** - * Generates a random fieldsForModels map - */ - private Map> randomFieldsForModels() { - if (randomBoolean()) { - return null; - } - - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } - - return fieldsForModels; - } }); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 58b8adcf53538..b4c9f670f66b6 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -40,7 +40,6 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,8 +82,6 @@ public void testIndexMetadataSerialization() throws IOException { IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null; - Map> fieldsForModels = randomFieldsForModels(true); - IndexMetadata metadata = IndexMetadata.builder("foo") .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) .creationDate(randomLong()) @@ -108,7 +105,6 @@ public void testIndexMetadataSerialization() throws IOException { .stats(indexStats) .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) - .fieldsForModels(fieldsForModels) .build(); assertEquals(system, metadata.isSystem()); @@ -142,7 +138,6 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getStats(), fromXContentMeta.getStats()); assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), fromXContentMeta.getFieldsForModels()); final BytesStreamOutput out = new BytesStreamOutput(); metadata.writeTo(out); @@ -164,9 +159,8 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getCustomData(), deserialized.getCustomData()); assertEquals(metadata.isSystem(), deserialized.isSystem()); assertEquals(metadata.getStats(), deserialized.getStats()); - assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); - assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); - assertEquals(metadata.getFieldsForModels(), deserialized.getFieldsForModels()); + assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad()); + assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes()); } } @@ -550,37 +544,10 @@ public void 
testPartialIndexReceivesDataFrozenTierPreference() { } } - public void testFieldsForModels() { - Settings.Builder settings = indexSettings(IndexVersion.current(), randomIntBetween(1, 8), 0); - IndexMetadata idxMeta1 = IndexMetadata.builder("test").settings(settings).build(); - assertThat(idxMeta1.getFieldsForModels(), equalTo(Map.of())); - - Map> fieldsForModels = randomFieldsForModels(false); - IndexMetadata idxMeta2 = IndexMetadata.builder(idxMeta1).fieldsForModels(fieldsForModels).build(); - assertThat(idxMeta2.getFieldsForModels(), equalTo(fieldsForModels)); - } - private static Settings indexSettingsWithDataTier(String dataTier) { return indexSettings(IndexVersion.current(), 1, 0).put(DataTier.TIER_PREFERENCE, dataTier).build(); } - private static Map> randomFieldsForModels(boolean allowNull) { - if (allowNull && randomBoolean()) { - return null; - } - - Map> fieldsForModels = new HashMap<>(); - for (int i = 0; i < randomIntBetween(0, 5); i++) { - Set fields = new HashSet<>(); - for (int j = 0; j < randomIntBetween(1, 4); j++) { - fields.add(randomAlphaOfLengthBetween(4, 10)); - } - fieldsForModels.put(randomAlphaOfLengthBetween(4, 10), fields); - } - - return fieldsForModels; - } - private IndexMetadataStats randomIndexStats(int numberOfShards) { IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards); int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java index 8db9c09f0d098..3f50b9fdf6621 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/FieldTypeLookupTests.java @@ -16,7 +16,6 @@ import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Set; import static java.util.Collections.emptyList; @@ -36,10 +35,6 @@ public void testEmpty() { Collection names = lookup.getMatchingFieldNames("foo"); assertNotNull(names); assertThat(names, hasSize(0)); - - Map> fieldsForModels = lookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddNewField() { @@ -47,10 +42,6 @@ public void testAddNewField() { FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); assertNull(lookup.get("bar")); assertEquals(f.fieldType(), lookup.get("foo")); - - Map> fieldsForModels = lookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertTrue(fieldsForModels.isEmpty()); } public void testAddFieldAlias() { @@ -430,19 +421,6 @@ public void testRuntimeFieldNameOutsideContext() { } } - public void testInferenceModelFieldType() { - MockFieldMapper f = new MockFieldMapper(new MockInferenceModelFieldType("foo", "bar")); - FieldTypeLookup lookup = new FieldTypeLookup(Collections.singletonList(f), emptyList(), Collections.emptyList()); - assertEquals(f.fieldType(), lookup.get("foo")); - assertEquals(Collections.emptySet(), lookup.getFieldsForModel("baz")); - assertEquals(Collections.singleton("foo"), lookup.getFieldsForModel("bar")); - - Map> fieldsForModels = lookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(1, fieldsForModels.size()); - assertEquals(Collections.singleton("foo"), fieldsForModels.get("bar")); - } - private static FlattenedFieldMapper createFlattenedMapper(String fieldName) { 
return new FlattenedFieldMapper.Builder(fieldName).build(MapperBuilderContext.root(false, false)); } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java index f512f5d352a43..0308dac5fa216 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupTests.java @@ -26,7 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -122,8 +121,6 @@ public void testEmptyMappingLookup() { assertEquals(0, mappingLookup.getMapping().getMetadataMappersMap().size()); assertFalse(mappingLookup.fieldMappers().iterator().hasNext()); assertEquals(0, mappingLookup.getMatchingFieldNames("*").size()); - assertNotNull(mappingLookup.getFieldsForModels()); - assertTrue(mappingLookup.getFieldsForModels().isEmpty()); } public void testValidateDoesNotShadow() { @@ -191,22 +188,6 @@ public MetricType getMetricType() { ); } - public void testFieldsForModels() { - MockInferenceModelFieldType fieldType = new MockInferenceModelFieldType("test_field_name", "test_model_id"); - MappingLookup mappingLookup = createMappingLookup( - Collections.singletonList(new MockFieldMapper(fieldType)), - emptyList(), - emptyList() - ); - assertEquals(1, size(mappingLookup.fieldMappers())); - assertEquals(fieldType, mappingLookup.getFieldType("test_field_name")); - - Map> fieldsForModels = mappingLookup.getFieldsForModels(); - assertNotNull(fieldsForModels); - assertEquals(1, fieldsForModels.size()); - assertEquals(Collections.singleton("test_field_name"), fieldsForModels.get("test_model_id")); - } - private void assertAnalyzes(Analyzer analyzer, String field, String output) throws IOException { try (TokenStream tok = analyzer.tokenStream(field, new StringReader(""))) { CharTermAttribute term = tok.addAttribute(CharTermAttribute.class); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java deleted file mode 100644 index 854749d6308db..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MockInferenceModelFieldType.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.index.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.index.query.SearchExecutionContext; - -import java.util.Map; - -public class MockInferenceModelFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { - private static final String TYPE_NAME = "mock_inference_model_field_type"; - - private final String modelId; - - public MockInferenceModelFieldType(String name, String modelId) { - super(name, false, false, false, TextSearchInfo.NONE, Map.of()); - this.modelId = modelId; - } - - @Override - public String typeName() { - return TYPE_NAME; - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented"); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); - } - - @Override - public String getInferenceModel() { - return modelId; - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index ffb6567009b79..a1eabb682c98f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -1,9 +1,8 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ package org.elasticsearch.xpack.core.inference.action; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java index 73ba6544fc86e..8f03a75c61c11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/LegacyTextEmbeddingResults.java @@ -1,9 +1,8 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
*/ package org.elasticsearch.xpack.core.inference.results; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index beecf75da38ab..0286390a8a3ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -33,8 +33,9 @@ public List routes() { protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String taskType = restRequest.param("task_type"); String modelId = restRequest.param("model_id"); - var request = InferenceAction.Request.parseRequest(modelId, taskType, restRequest.contentParser()); - - return channel -> client.execute(InferenceAction.INSTANCE, request, new RestToXContentListener<>(channel)); + try (var parser = restRequest.contentParser()) { + var request = InferenceAction.Request.parseRequest(modelId, taskType, parser); + return channel -> client.execute(InferenceAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } } } From b90d6ad58efacaf12accd78293691827e191822a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 20:06:08 +0100 Subject: [PATCH 052/106] Add javadoc --- .../xpack/inference/InferenceActionInferenceProvider.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java index 1305bab3b4a38..20d12c8990507 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java @@ -22,6 +22,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +/** + * InferenceProvider implementation that uses the inference action to retrieve inference results. 
+ */ public class InferenceActionInferenceProvider implements InferenceProvider { private final Client client; From 130cc827c6854537c70d4a4722711f8d1c604992 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 10 Jan 2024 20:13:52 +0100 Subject: [PATCH 053/106] Remove changes from other branches --- x-pack/plugin/ml/build.gradle | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index 22cdb752d1e8d..74600a072ea0d 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -1,5 +1,6 @@ import org.elasticsearch.gradle.VersionProperties import org.elasticsearch.gradle.internal.dra.DraResolvePlugin +import org.elasticsearch.gradle.internal.info.BuildParams apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' @@ -73,6 +74,7 @@ esplugin.bundleSpec.exclude 'platform/licenses/**' } dependencies { + testImplementation project(path: ':x-pack:plugin:inference') compileOnly project(':modules:lang-painless:spi') compileOnly project(path: xpackModule('core')) compileOnly project(path: xpackModule('autoscaling')) @@ -113,6 +115,12 @@ artifacts { archives tasks.named("jar") } +if (BuildParams.isSnapshotBuild() == false) { + tasks.named("test").configure { + systemProperty 'es.semantic_text_feature_flag_enabled', 'true' + } +} + tasks.register("extractNativeLicenses", Copy) { dependsOn configurations.nativeBundle into "${buildDir}/extractedNativeLicenses" From d66951bac781b06a238432095c39f7e1d665c4fb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 11 Jan 2024 18:37:44 +0100 Subject: [PATCH 054/106] Makes InferenceProvider non null to deal with injection --- .../action/bulk/TransportBulkAction.java | 4 +++- .../bulk/TransportSimulateBulkAction.java | 3 ++- .../inference/InferenceProvider.java | 21 +++++++++++++++++++ .../elasticsearch/node/NodeConstruction.java | 11 +++++++--- ...ActionIndicesThatCannotBeCreatedTests.java | 3 ++- .../bulk/TransportBulkActionIngestTests.java | 3 ++- .../action/bulk/TransportBulkActionTests.java | 3 ++- .../inference/TestInferenceResults.java | 4 ++-- .../snapshots/SnapshotResiliencyTests.java | 3 ++- 9 files changed, 44 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index e63fa16cd0d60..af1332924a5a5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -54,6 +54,7 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; @@ -203,6 +204,7 @@ public TransportBulkAction( this.indexNameExpressionResolver = indexNameExpressionResolver; this.indexingPressure = indexingPressure; this.systemIndices = systemIndices; + Objects.requireNonNull(inferenceProvider); this.inferenceProvider = inferenceProvider; clusterService.addStateApplier(this.ingestForwarder); } @@ -793,7 +795,7 @@ private void performInferenceOnBulkItemRequest( Map> fieldsForModels, Releasable releaseOnFinish ) { - if (inferenceProvider == null) { + if (inferenceProvider.performsInference() == false) { releaseOnFinish.close(); return; } diff 
--git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index 1261b716869de..868d3babd3edc 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -21,6 +21,7 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.SimulateIngestService; import org.elasticsearch.tasks.Task; @@ -56,7 +57,7 @@ public TransportSimulateBulkAction( indexingPressure, systemIndices, System::nanoTime, - null + new InferenceProvider.NoopInferenceProvider() ); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java index 37fc934104f18..96bf33fa3c1d8 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java @@ -25,4 +25,25 @@ public interface InferenceProvider { */ void textInference(String modelId, List texts, ActionListener> listener) throws InferenceProviderException; + + /** + * Returns true if this inference provider can perform inference + * + * @return true if this inference provider can perform inference + */ + boolean performsInference(); + + class NoopInferenceProvider implements InferenceProvider { + + @Override + public void textInference(String modelId, List texts, ActionListener> listener) + throws InferenceProviderException { + throw new InferenceProviderException("No inference provider has been registered", null); + } + + @Override + public boolean performsInference() { + return false; + } + } } diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index f673ea4e4e209..a0a91712ed6af 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1079,9 +1079,14 @@ record PluginServiceInstances( ); } - getSinglePlugin(InferenceProviderPlugin.class).ifPresent(plugin -> { - modules.add(b -> b.bind(InferenceProvider.class).toInstance(plugin.getInferenceProvider())); - }); + InferenceProvider inferenceProvider = null; + Optional inferenceProviderPlugin = getSinglePlugin(InferenceProviderPlugin.class); + if (inferenceProviderPlugin.isPresent()) { + inferenceProvider = inferenceProviderPlugin.get().getInferenceProvider(); + } else { + inferenceProvider = new InferenceProvider.NoopInferenceProvider(); + } + modules.bindToInstance(InferenceProvider.class, inferenceProvider); injector = modules.createInjector(); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 49bd0f48d44b8..cad87f52829cb 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -29,6 +29,7 @@ import 
org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockUtils; @@ -124,7 +125,7 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - null + new InferenceProvider.NoopInferenceProvider() ) { @Override void executeBulk( diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 0cb030624013b..eae8554f3f394 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -41,6 +41,7 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -136,7 +137,7 @@ class TestTransportBulkAction extends TransportBulkAction { TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), EmptySystemIndices.INSTANCE, - null + new InferenceProvider.NoopInferenceProvider() ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 15f716d2eabf9..01fbbff173cd5 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -40,6 +40,7 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.SystemIndexDescriptorUtils; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.index.IndexVersionUtils; @@ -88,7 +89,7 @@ class TestTransportBulkAction extends TransportBulkAction { new Resolver(), new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - null + new InferenceProvider.NoopInferenceProvider() ); } diff --git a/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java b/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java index 7b1445ae7ad5f..f24997fd6a328 100644 --- a/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java +++ b/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java @@ -19,9 +19,9 @@ public class TestInferenceResults implements InferenceResults { private final String resultField; - private final Map inferenceResults; + private final Map inferenceResults; - public TestInferenceResults(String resultField, Map inferenceResults) { + public TestInferenceResults(String resultField, Map inferenceResults) { this.resultField = resultField; this.inferenceResults = inferenceResults; } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java 
b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 4c863ab6e7eb0..6be62be14fbfd 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -155,6 +155,7 @@ import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.indices.recovery.SnapshotFilesProvider; import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.StatusInfo; import org.elasticsearch.node.ResponseCollectorService; @@ -1943,7 +1944,7 @@ protected void assertSnapshotOrGenericThread() { indexNameExpressionResolver, new IndexingPressure(settings), EmptySystemIndices.INSTANCE, - null + new InferenceProvider.NoopInferenceProvider() ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( From d79839623d4a80f10799b58e931230c395fbfcfa Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 11 Jan 2024 18:39:00 +0100 Subject: [PATCH 055/106] Spotless --- .../java/org/elasticsearch/action/bulk/TransportBulkAction.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index af1332924a5a5..82080f00c8abd 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -54,7 +54,6 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Assertions; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; From 33600bf7114e3620caee441d08c75663ff80e7b5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 11 Jan 2024 18:47:13 +0100 Subject: [PATCH 056/106] Implement missing method --- .../xpack/inference/InferenceActionInferenceProvider.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java index 20d12c8990507..183032393ef1c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java @@ -54,4 +54,9 @@ public void textInference(String modelId, List texts, ActionListener Date: Thu, 11 Jan 2024 20:06:34 +0100 Subject: [PATCH 057/106] Fix tests and remove useless exception from interface --- .../inference/InferenceProvider.java | 8 +++----- .../inference/InferenceProviderException.java | 18 ------------------ .../TransportBulkActionInferenceTests.java | 1 + .../bulk/TransportBulkActionTookTests.java | 3 ++- .../InferenceActionInferenceProvider.java | 4 +--- 5 files changed, 7 insertions(+), 27 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java 
b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java index 96bf33fa3c1d8..39af88d3194a2 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java @@ -23,8 +23,7 @@ public interface InferenceProvider { * @param texts texts to perform inference on * @param listener listener to be called when inference is complete */ - void textInference(String modelId, List texts, ActionListener> listener) - throws InferenceProviderException; + void textInference(String modelId, List texts, ActionListener> listener); /** * Returns true if this inference provider can perform inference @@ -36,9 +35,8 @@ void textInference(String modelId, List texts, ActionListener texts, ActionListener> listener) - throws InferenceProviderException { - throw new InferenceProviderException("No inference provider has been registered", null); + public void textInference(String modelId, List texts, ActionListener> listener) { + throw new UnsupportedOperationException("No inference provider has been registered"); } @Override diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java b/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java deleted file mode 100644 index 0d82bc800f414..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProviderException.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.ElasticsearchException; - -public class InferenceProviderException extends ElasticsearchException { - - public InferenceProviderException(String msg, Throwable cause, Object... 
args) { - super(msg, cause, args); - } -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 3e881bb91594f..16eecfe3a4e69 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -118,6 +118,7 @@ public void setup() { clusterService = ClusterServiceUtils.createClusterService(state, threadPool); inferenceProvider = mock(InferenceProvider.class); + when(inferenceProvider.performsInference()).thenReturn(true); transportBulkAction = new TransportBulkAction( threadPool, diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index 27cf5b101e276..d94a1cb092bc4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; @@ -253,7 +254,7 @@ static class TestTransportBulkAction extends TransportBulkAction { new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, relativeTimeProvider, - null + new InferenceProvider.NoopInferenceProvider() ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java index 183032393ef1c..6d9740c407059 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java @@ -11,7 +11,6 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceProviderException; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; @@ -34,8 +33,7 @@ public InferenceActionInferenceProvider(Client client) { } @Override - public void textInference(String modelId, List texts, ActionListener> listener) - throws InferenceProviderException { + public void textInference(String modelId, List texts, ActionListener> listener) { InferenceAction.Request inferenceRequest = new InferenceAction.Request( TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified modelId, From 33b325b10fc24f6e963c467a346a862534b6ff46 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 11 Jan 2024 20:55:22 +0100 Subject: [PATCH 058/106] Remove references to the removed exception - I'm hopefully tired and not just stupid --- .../action/bulk/TransportBulkActionInferenceTests.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 16eecfe3a4e69..08948d76ed8a1 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -35,7 +35,6 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceProviderException; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.TestInferenceResults; import org.elasticsearch.ingest.IngestService; @@ -137,7 +136,7 @@ public void setup() { doAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new InferenceProviderException("Unexpected invocation", null)); + listener.onFailure(new Exception("Unexpected invocation")); return Void.TYPE; }).when(inferenceProvider).textInference(any(), any(), any()); when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { @@ -332,7 +331,7 @@ private void expectInferenceRequestFails(String modelId, String... inferenceText doAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new InferenceProviderException("Inference failed", null)); + listener.onFailure(new Exception("Inference failed")); return Void.TYPE; }).when(inferenceProvider) .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); From e14ef02e04c5a68159470e54b2298fe0af9c8021 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 12 Jan 2024 21:09:16 +0100 Subject: [PATCH 059/106] Add back TransportVersions for semantic text --- server/src/main/java/org/elasticsearch/TransportVersions.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5b2819f04ec24..fc5367d254b10 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -183,6 +183,7 @@ static TransportVersion def(int id) { public static final TransportVersion HOT_THREADS_AS_BYTES = def(8_571_00_0); public static final TransportVersion ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED = def(8_572_00_0); public static final TransportVersion ESQL_ENRICH_POLICY_CCQ_MODE = def(8_573_00_0); + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_574_00_0); /* * STOP! READ THIS FIRST! 
No, really, From d1bc78f24e7684a25ba6ec6c832d51d2f798d11a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 15 Jan 2024 10:04:54 +0100 Subject: [PATCH 060/106] Add inference service param needed --- .../xpack/inference/InferenceActionInferenceProvider.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java index 6d9740c407059..7886590c768cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; @@ -38,7 +39,8 @@ public void textInference(String modelId, List texts, ActionListener { From 10a0eda3fef6b8bf686b4b908b69742ee1b96e00 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 15 Jan 2024 19:18:03 +0100 Subject: [PATCH 061/106] Performs inference even if text value is the same as previous --- .../elasticsearch/action/bulk/TransportBulkAction.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 82080f00c8abd..a9c79708532df 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -904,13 +904,8 @@ private static List getFieldNamesForInference( Object fieldValue = docMap.get(inferenceField); // Perform inference on string, non-null values - if (fieldValue instanceof String fieldStringValue) { - - // Only do inference if the previous text value doesn't match the new one - String previousValue = findMapValue(docMap, ROOT_RESULT_FIELD, inferenceField, TEXT_FIELD); - if (fieldStringValue.equals(previousValue) == false) { - inferenceFieldNames.add(inferenceField); - } + if (fieldValue instanceof String) { + inferenceFieldNames.add(inferenceField); } } return inferenceFieldNames; From d39f2c965f7f5d7d93d4177cd11adc3f096ddc5b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 16 Jan 2024 11:54:38 +0100 Subject: [PATCH 062/106] Fix typo --- .../java/org/elasticsearch/inference/InferenceProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java index 39af88d3194a2..a0b282d327ae8 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java @@ -17,7 +17,7 @@ */ public interface InferenceProvider { /** - * Returns nferenceResults for a given model ID and list of texts. + * Returns InferenceResults for a given model ID and list of texts. 
* * @param modelId model identifier * @param texts texts to perform inference on From f3f008fa9921f77bf0604bf2de342201e941c8bc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 16 Jan 2024 11:59:48 +0100 Subject: [PATCH 063/106] Add warn when inference provider is not found --- .../src/main/java/org/elasticsearch/node/NodeConstruction.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 8a8c90b5110ee..c7dc21010357b 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -1085,6 +1085,7 @@ record PluginServiceInstances( if (inferenceProviderPlugin.isPresent()) { inferenceProvider = inferenceProviderPlugin.get().getInferenceProvider(); } else { + logger.warn("No inference provider found. Inference for semantic_text field types won't be available"); inferenceProvider = new InferenceProvider.NoopInferenceProvider(); } modules.bindToInstance(InferenceProvider.class, inferenceProvider); From b2aab09137dbabb1cbc940c62795214daca6bc18 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 18 Jan 2024 20:23:17 +0100 Subject: [PATCH 064/106] Removed changes from MappingMetadata --- .../cluster/metadata/MappingMetadata.java | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 64a61f854b9da..b629ab5d5f710 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -18,13 +18,11 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MappingLookup; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -44,15 +42,10 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; - private final Map> fieldsForModels; - public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); - - MappingLookup mappingLookup = docMapper.mappers(); - this.fieldsForModels = mappingLookup != null ? 
mappingLookup.getFieldsForModels() : Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -64,7 +57,6 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -80,7 +72,6 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = Map.of(); } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -167,19 +158,12 @@ public String getSha256() { return source.getSha256(); } - public Map> getFieldsForModels() { - return fieldsForModels; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); - } } @Override @@ -192,25 +176,19 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; - if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired, fieldsForModels); + return Objects.hash(type, source, routingRequired); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); - } else { - fieldsForModels = Map.of(); - } } public static Diff readDiffFrom(StreamInput in) throws IOException { From 4947b1a70e6896bc89091ab825a67a3feba27658 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 18 Jan 2024 20:49:45 +0100 Subject: [PATCH 065/106] Refactor inference into a separate class --- .../BulkShardRequestInferenceProvider.java | 215 ++++++++++++++++++ .../action/bulk/TransportBulkAction.java | 208 +++-------------- 2 files changed, 241 insertions(+), 182 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java new file mode 100644 index 0000000000000..7f10096b12828 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -0,0 +1,215 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.action.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.TriConsumer; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.inference.InferenceResults; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +public class BulkShardRequestInferenceProvider { + + public static final String ROOT_RESULT_FIELD = "_ml_inference"; + public static final String INFERENCE_FIELD = "result"; + public static final String TEXT_FIELD = "text"; + + private final InferenceProvider inferenceProvider; + + public BulkShardRequestInferenceProvider(InferenceProvider inferenceProvider) { + this.inferenceProvider = inferenceProvider; + } + + public void processBulkShardRequest( + BulkShardRequest bulkShardRequest, + ClusterState clusterState, + TriConsumer onBulkItemFailure, + Consumer nextAction + ) { + + Map> fieldsForModels = clusterState.metadata() + .index(bulkShardRequest.shardId().getIndex()) + .getFieldsForModels(); + // No inference fields? Just execute the request + if (fieldsForModels.isEmpty()) { + nextAction.accept(bulkShardRequest); + return; + } + + Runnable onInferenceComplete = () -> { + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + nextAction.accept(errorsFilteredShardRequest); + }; + + try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { + for (BulkItemRequest bulkItemRequest : bulkShardRequest.items()) { + performInferenceOnBulkItemRequest( + bulkShardRequest, + bulkItemRequest, + fieldsForModels, + onBulkItemFailure, + bulkItemReqRef.acquire() + ); + } + } + } + + private void performInferenceOnBulkItemRequest( + BulkShardRequest bulkShardRequest, + BulkItemRequest bulkItemRequest, + Map> fieldsForModels, + TriConsumer onBulkItemFailure, + Releasable releaseOnFinish + ) { + if (inferenceProvider.performsInference() == false) { + releaseOnFinish.close(); + return; + } + + DocWriteRequest docWriteRequest = bulkItemRequest.request(); + Map sourceMap = null; + if (docWriteRequest instanceof IndexRequest indexRequest) { + sourceMap = indexRequest.sourceAsMap(); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); + } + if (sourceMap == null || sourceMap.isEmpty()) { + releaseOnFinish.close(); + return; + } + final Map docMap = new ConcurrentHashMap<>(sourceMap); + + // When a document completes processing, update the source with the inference + try (var docRef = new RefCountingRunnable(() -> { + if (docWriteRequest instanceof IndexRequest indexRequest) { + indexRequest.source(docMap); + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + if (updateRequest.docAsUpsert()) { + updateRequest.upsertRequest().source(docMap); + } else { + updateRequest.doc().source(docMap); + } + } + releaseOnFinish.close(); + })) { + + for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { + String modelId = fieldModelsEntrySet.getKey(); + + @SuppressWarnings("unchecked") + Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( + ROOT_RESULT_FIELD, + k -> new HashMap() + ); + + List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); + + if (inferenceFieldNames.isEmpty()) { + continue; + } + + docRef.acquire(); + + inferenceProvider.textInference( + modelId, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + new ActionListener<>() { + + @Override + public void onResponse(List results) { + + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ); + } + + int i = 0; + for (InferenceResults inferenceResults : results) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() + ); + + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); + } + + docRef.close(); + } + + @Override + public void onFailure(Exception e) { + onBulkItemFailure.apply(bulkShardRequest, bulkItemRequest, e); + docRef.close(); + } + } + ); + } + } + } + + private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { + List inferenceFieldNames = new ArrayList<>(); + for (String inferenceField : fieldModelsEntrySet.getValue()) { + Object fieldValue = docMap.get(inferenceField); + + // Perform inference on string, non-null values + if (fieldValue instanceof String) { + inferenceFieldNames.add(inferenceField); + } + } + return inferenceFieldNames; + } + + @SuppressWarnings("unchecked") + private static String findMapValue(Map map, String... path) { + Map currentMap = map; + for (int i = 0; i < path.length - 1; i++) { + Object value = currentMap.get(path[i]); + + if (value instanceof Map) { + currentMap = (Map) value; + } else { + // Invalid path or non-Map value encountered + return null; + } + } + + // Retrieve the final value in the map, if it's a String + Object finalValue = currentMap.get(path[path.length - 1]); + + return (finalValue instanceof String) ? 
(String) finalValue : null; + } + +} diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index a9c79708532df..f3c144c0a0475 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -65,7 +65,6 @@ import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -74,7 +73,6 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -83,9 +81,9 @@ import java.util.Objects; import java.util.Set; import java.util.SortedMap; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerArray; +import java.util.function.Consumer; import java.util.function.LongSupplier; import java.util.stream.Collectors; @@ -112,11 +110,7 @@ public class TransportBulkAction extends HandledTransportAction> requ // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; - try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); @@ -754,187 +747,41 @@ private void executeBulkRequestsByShard(Map> requ bulkShardRequest.setParentTask(nodeId, task.getId()); } - performInferenceAndExecute(bulkShardRequest, clusterState, bulkItemRequestCompleteRefCount.acquire()); - } - } - } - - private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) { - - Map> fieldsForModels = clusterState.metadata() - .index(bulkShardRequest.shardId().getIndex()) - .getFieldsForModels(); - // No inference fields? 
Just execute the request - if (fieldsForModels.isEmpty()) { - executeBulkShardRequest(bulkShardRequest, releaseOnFinish); - return; - } - - Runnable onInferenceComplete = () -> { - // We need to remove items that have had an inference error, as the response will have been updated already - // and we don't need to process them further - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) - ); - executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish); - }; + Releasable ref = bulkItemRequestCompleteRefCount.acquire(); + bulkShardRequestInferenceProvider.processBulkShardRequest( + bulkShardRequest, + clusterState, + (bulkShardRequest1, request, e) -> onBulkItemInferenceFailure(bulkShardRequest1, request, e, ref), + bsr -> executeBulkShardRequest(bsr, b -> ref.close()) + ); - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - for (BulkItemRequest request : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire()); } } } - private void performInferenceOnBulkItemRequest( + private void onBulkItemInferenceFailure( BulkShardRequest bulkShardRequest, BulkItemRequest request, - Map> fieldsForModels, - Releasable releaseOnFinish + Exception e, + Releasable refCount ) { - if (inferenceProvider.performsInference() == false) { - releaseOnFinish.close(); - return; - } - - DocWriteRequest docWriteRequest = request.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; - } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - - @SuppressWarnings("unchecked") - Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_RESULT_FIELD, - k -> new HashMap() - ); - - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - - if (inferenceFieldNames.isEmpty()) { - continue; - } - - docRef.acquire(); - - inferenceProvider.textInference( - modelId, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - new ActionListener<>() { - - @Override - public void onResponse(List results) { - - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() - ); - - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); - } - - docRef.close(); - } - - @Override - public void onFailure(Exception e) { - - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure( - indexName, - docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) - ); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; - - docRef.close(); - } - } - ); - } - } - } - - private static List getFieldNamesForInference( - Map.Entry> fieldModelsEntrySet, - Map docMap - ) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String) { - inferenceFieldNames.add(inferenceField); - } - } - return inferenceFieldNames; + markBulkItemRequestFailed(request, new IllegalArgumentException("Inference failed: " + e.getMessage(), e)); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; + refCount.close(); } - @SuppressWarnings("unchecked") - private static String findMapValue(Map map, String... 
path) { - Map currentMap = map; - for (int i = 0; i < path.length - 1; i++) { - Object value = currentMap.get(path[i]); + private void markBulkItemRequestFailed(BulkItemRequest request, Exception e) { + final String indexName = request.index(); - if (value instanceof Map) { - currentMap = (Map) value; - } else { - // Invalid path or non-Map value encountered - return null; - } - } - - // Retrieve the final value in the map, if it's a String - Object finalValue = currentMap.get(path[path.length - 1]); - - return (finalValue instanceof String) ? (String) finalValue : null; + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Consumer onComplete) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early - releaseOnFinish.close(); return; } @@ -948,19 +795,16 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - releaseOnFinish.close(); + onComplete.accept(bulkShardRequest); } @Override public void onFailure(Exception e) { // create failures for all relevant requests for (BulkItemRequest request : bulkShardRequest.items()) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + markBulkItemRequestFailed(request, e); } - releaseOnFinish.close(); + onComplete.accept(bulkShardRequest); } }); } From daf0bfc0a44d5a868928abd87009a824c94f2275 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 30 Jan 2024 15:15:12 +0100 Subject: [PATCH 066/106] Move ModelRegistry as an interface to server --- .../inference/ModelRegistry.java | 32 +++++++++++ .../integration/ModelRegistryIT.java | 16 +++--- .../xpack/inference/InferencePlugin.java | 5 +- .../TransportDeleteInferenceModelAction.java | 5 +- .../TransportGetInferenceModelAction.java | 5 +- .../action/TransportInferenceAction.java | 8 +-- .../TransportPutInferenceModelAction.java | 2 +- ...elRegistry.java => ModelRegistryImpl.java} | 57 ++++++++----------- .../registry/ModelRegistryTests.java | 32 +++++------ 9 files changed, 95 insertions(+), 67 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/inference/ModelRegistry.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/{ModelRegistry.java => ModelRegistryImpl.java} (90%) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java new file mode 100644 index 0000000000000..dc7ee7cabf079 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.action.ActionListener; + +import java.util.List; +import java.util.Map; + +public interface ModelRegistry { + void getModel(String modelId, ActionListener listener); + + void getModelsByTaskType(TaskType taskType, ActionListener> listener); + + void getAllModels(ActionListener> listener); + + void storeModel(Model model, ActionListener listener); + + void deleteModel(String modelId, ActionListener listener); + + /** + * Semi parsed model where model id, task type and service + * are known but the settings are not parsed. + */ + record UnparsedModel(String modelId, TaskType taskType, String service, Map settings, Map secrets) {} +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 50647ca328b23..3e46643b90c1c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -25,7 +25,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests; @@ -56,11 +56,11 @@ public class ModelRegistryIT extends ESSingleNodeTestCase { - private ModelRegistry modelRegistry; + private ModelRegistryImpl modelRegistry; @Before public void createComponents() { - modelRegistry = new ModelRegistry(client()); + modelRegistry = new ModelRegistryImpl(client()); } @Override @@ -109,7 +109,7 @@ public void testGetModel() throws Exception { assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(modelId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -165,7 +165,7 @@ public void testDeleteModel() throws Exception { // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); @@ -191,7 +191,7 @@ public void testGetModelsByTaskType() throws InterruptedException { } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> 
modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() @@ -232,7 +232,7 @@ public void testGetAllModels() throws InterruptedException { assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -261,7 +261,7 @@ public void testGetModelWithSecrets() throws InterruptedException { assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(modelId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3c99b9caac221..5ff82cf438d8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -50,7 +50,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -133,7 +134,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelRegistry = new ModelRegistry(services.client()); + ModelRegistry modelRegistry = new ModelRegistryImpl(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index cb728120d2f0b..b9b833dcec8f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -28,7 +28,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; public 
class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { @@ -68,7 +69,7 @@ protected void masterOperation( ClusterState state, ActionListener listener ) { - SubscribableListener.newForked(modelConfigListener -> { + SubscribableListener.newForked(modelConfigListener -> { modelRegistry.getModel(request.getModelId(), modelConfigListener); }).andThen((l1, unparsedModel) -> { var service = serviceRegistry.getService(unparsedModel.service()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index a7f5fb6c6c9a0..b28e4c8f5102e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -24,7 +24,8 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import java.util.ArrayList; import java.util.List; @@ -120,7 +121,7 @@ private void getModelsByTaskType(TaskType taskType, ActionListener unparsedModels) { + private GetInferenceModelAction.Response parseModels(List unparsedModels) { var parsedModels = new ArrayList(); for (var unparsedModel : unparsedModels) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index db98aeccc556b..038c62fb76cf7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -20,18 +20,18 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; public class TransportInferenceAction extends HandledTransportAction { - private final ModelRegistry modelRegistry; + private final ModelRegistryImpl modelRegistry; private final InferenceServiceRegistry serviceRegistry; @Inject public TransportInferenceAction( TransportService transportService, ActionFilters actionFilters, - ModelRegistry modelRegistry, + ModelRegistryImpl modelRegistry, InferenceServiceRegistry serviceRegistry ) { super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); @@ -42,7 +42,7 @@ public TransportInferenceAction( @Override protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { + ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { 
delegate.onFailure( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 8bcc07a6322bc..f7a5341c8218b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -42,7 +42,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.inference.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java similarity index 90% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java index 3cc83f2f4ddc5..ef33175152b67 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; @@ -55,41 +56,16 @@ import static org.elasticsearch.core.Strings.format; -public class ModelRegistry { +public class ModelRegistryImpl implements ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} - /** - * Semi parsed model where model id, task type and service - * are known but the settings are not parsed. 
- */ - public record UnparsedModel( - String modelId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) { - - public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); - } - } - private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; - private static final Logger logger = LogManager.getLogger(ModelRegistry.class); + private static final Logger logger = LogManager.getLogger(ModelRegistryImpl.class); private final OriginSettingClient client; - public ModelRegistry(Client client) { + public ModelRegistryImpl(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } @@ -106,7 +82,7 @@ public void getModelWithSecrets(String modelId, ActionListener li return; } - delegate.onResponse(UnparsedModel.unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), modelId))); + delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), modelId))); }); QueryBuilder queryBuilder = documentIdQuery(modelId); @@ -124,6 +100,7 @@ public void getModelWithSecrets(String modelId, ActionListener li * @param modelId Model to get * @param listener Model listener */ + @Override public void getModel(String modelId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -132,7 +109,7 @@ public void getModel(String modelId, ActionListener listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -153,6 +130,7 @@ public void getModel(String modelId, ActionListener listener) { * @param taskType The task type * @param listener Models listener */ + @Override public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -161,7 +139,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -190,7 +169,7 @@ public void getAllModels(ActionListener> listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); delegate.onResponse(modelConfigs); }); @@ -252,6 +231,7 @@ 
private ModelConfigMap createModelConfigMap(SearchHits hits, String modelId) { ); } + @Override public void storeModel(Model model, ActionListener listener) { ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); @@ -348,6 +328,7 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } + @Override public void deleteModel(String modelId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); @@ -372,4 +353,16 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String modelId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(modelId))); } + + private static UnparsedModel unparsedModelFromMap(ModelRegistryImpl.ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 67e7f3fa68b8c..71a89865fd6c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -65,9 +65,9 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var client = mockClient(); mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY)); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT)); @@ -79,9 +79,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT)); @@ -96,9 +96,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit })); - var registry = new 
ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -113,9 +113,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT)); @@ -147,9 +147,9 @@ public void testGetModelWithSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); var modelConfig = listener.actionGet(TIMEOUT); @@ -176,9 +176,9 @@ public void testGetModelNoSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); - var listener = new PlainActionFuture(); + var listener = new PlainActionFuture(); registry.getModel("1", listener); registry.getModel("1", listener); @@ -201,7 +201,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -218,7 +218,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -249,7 +249,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -272,7 +272,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); From c4e66cf99a2ace90ffc1657112aa904a10973651 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 30 Jan 2024 20:50:56 +0100 Subject: [PATCH 067/106] Take back InferenceProviderPlugin changes --- .../elasticsearch/node/NodeConstruction.java | 12 ---- .../plugins/InferenceProviderPlugin.java | 25 -------- .../InferenceActionInferenceProvider.java | 62 ------------------- 
.../xpack/inference/InferencePlugin.java | 19 ++---- .../inference/registry/ModelRegistryImpl.java | 2 + 5 files changed, 7 insertions(+), 113 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 70aa196aa77ac..1dae328752bdc 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -123,7 +123,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -142,7 +141,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1089,16 +1087,6 @@ record PluginServiceInstances( ); } - InferenceProvider inferenceProvider = null; - Optional inferenceProviderPlugin = getSinglePlugin(InferenceProviderPlugin.class); - if (inferenceProviderPlugin.isPresent()) { - inferenceProvider = inferenceProviderPlugin.get().getInferenceProvider(); - } else { - logger.warn("No inference provider found. Inference for semantic_text field types won't be available"); - inferenceProvider = new InferenceProvider.NoopInferenceProvider(); - } - modules.bindToInstance(InferenceProvider.class, inferenceProvider); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java deleted file mode 100644 index ebd307d3d02c0..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceProvider; - -/** - * An extension point for {@link Plugin} implementations to add inference plugins for use on document ingestion - */ -public interface InferenceProviderPlugin { - - /** - * Returns the inference provider added by this plugin. 
- * - * @return InferenceProvider added by the plugin - */ - InferenceProvider getInferenceProvider(); - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java deleted file mode 100644 index 7886590c768cb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.client.internal.OriginSettingClient; -import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; - -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; - -/** - * InferenceProvider implementation that uses the inference action to retrieve inference results. - */ -public class InferenceActionInferenceProvider implements InferenceProvider { - - private final Client client; - - public InferenceActionInferenceProvider(Client client) { - this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); - } - - @Override - public void textInference(String modelId, List texts, ActionListener> listener) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified - modelId, - texts, - Map.of(), - InputType.INGEST - ); - - client.execute(InferenceAction.INSTANCE, inferenceRequest, listener.delegateFailure((l, response) -> { - InferenceServiceResults results = response.getResults(); - if (results == null) { - throw new IllegalArgumentException("No inference retrieved for model ID " + modelId); - } - - @SuppressWarnings("unchecked") - List result = (List) results.transformToLegacyFormat(); - l.onResponse(result); - })); - } - - @Override - public boolean performsInference() { - return true; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 5ff82cf438d8f..2a05d9f3c7fb3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,12 +21,12 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.SystemIndexDescriptor; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.node.PluginComponentBinding; import 
org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -50,7 +50,6 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -69,7 +68,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceProviderPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin { public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -81,7 +80,6 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce inferenceProvider = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -147,15 +145,13 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); - var provider = new InferenceActionInferenceProvider(services.client()); - inferenceProvider.set(provider); - - return List.of(modelRegistry, registry, provider); + return List.of(new PluginComponentBinding<>(ModelRegistry.class, modelRegistry), registry); } @Override public void loadExtensions(ExtensionLoader loader) { inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class); + loader.loadExtensions(ModelRegistry.class); } public List getInferenceServiceFactories() { @@ -244,9 +240,4 @@ public void close() { IOUtils.closeWhileHandlingException(httpManager.get(), throttlerToClose); } - - @Override - public InferenceProvider getInferenceProvider() { - return inferenceProvider.get(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java index ef33175152b67..7b7f3848b0abf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java @@ -24,6 +24,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -65,6 +66,7 @@ public record ModelConfigMap(Map config, Map sec private final OriginSettingClient client; + @Inject public ModelRegistryImpl(Client client) { this.client = new OriginSettingClient(client, 
ClientHelper.INFERENCE_ORIGIN); } From 5397da7ca085a23d567d6ca911d46eb1fe6a2fd5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 30 Jan 2024 20:51:57 +0100 Subject: [PATCH 068/106] First version with ModelRegistry / InferenceServiceRegistry --- .../BulkShardRequestInferenceProvider.java | 110 ++++++++++++------ .../action/bulk/TransportBulkAction.java | 46 ++++++-- .../bulk/TransportSimulateBulkAction.java | 4 +- .../inference/InferenceProvider.java | 47 -------- 4 files changed, 113 insertions(+), 94 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceProvider.java diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 7f10096b12828..9bbdfb54f0be4 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,14 +14,21 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; -import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -36,10 +43,60 @@ public class BulkShardRequestInferenceProvider { public static final String INFERENCE_FIELD = "result"; public static final String TEXT_FIELD = "text"; - private final InferenceProvider inferenceProvider; + private final Map inferenceProvidersMap; - public BulkShardRequestInferenceProvider(InferenceProvider inferenceProvider) { - this.inferenceProvider = inferenceProvider; + private record InferenceProvider (Model model, InferenceService service) { + private InferenceProvider { + Objects.requireNonNull(model); + Objects.requireNonNull(service); + } + } + + private BulkShardRequestInferenceProvider(Map inferenceProvidersMap) { + this.inferenceProvidersMap = inferenceProvidersMap; + } + + public static void executeWithInferenceProvider( + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry, + Metadata clusterMetadata, + Set shardIds, + Consumer action + ) { + Set inferenceIds = new HashSet<>(); + shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { + var fieldsForModels = clusterMetadata.index(index).getFieldsForModels(); + inferenceIds.addAll(fieldsForModels.keySet()); + }); + final Map inferenceProviderMap = new ConcurrentHashMap<>(); + Runnable onModelLoadingComplete = () -> action.accept(new BulkShardRequestInferenceProvider(inferenceProviderMap)); + try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { + for (var inferenceId : inferenceIds) { + ActionListener modelLoadingListener = new ActionListener<>() { + @Override + public void 
onResponse(ModelRegistry.UnparsedModel unparsedModel) { + var service = inferenceServiceRegistry.getService(unparsedModel.service()); + if (service.isEmpty() == false) { + InferenceProvider inferenceProvider = new InferenceProvider( + service.get().parsePersistedConfig( + inferenceId, + unparsedModel.taskType(), + unparsedModel.settings()), + service.get() + ); + inferenceProviderMap.put(inferenceId, inferenceProvider); + } + } + + @Override + public void onFailure(Exception e) { + // Do nothing - let it fail afterwards + } + }; + + modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); + } + } } public void processBulkShardRequest( @@ -89,10 +146,6 @@ private void performInferenceOnBulkItemRequest( TriConsumer onBulkItemFailure, Releasable releaseOnFinish ) { - if (inferenceProvider.performsInference() == false) { - releaseOnFinish.close(); - return; - } DocWriteRequest docWriteRequest = bulkItemRequest.request(); Map sourceMap = null; @@ -138,13 +191,25 @@ private void performInferenceOnBulkItemRequest( docRef.acquire(); - inferenceProvider.textInference( - modelId, + InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); + if (inferenceProvider == null) { + onBulkItemFailure.apply( + bulkShardRequest, + bulkItemRequest, + new IllegalArgumentException("No inference provider found for model ID " + modelId) + ); + docRef.close(); + continue; + } + inferenceProvider.service().infer( + inferenceProvider.model, inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + // TODO check for additional settings needed + Map.of(), new ActionListener<>() { @Override - public void onResponse(List results) { + public void onResponse(InferenceServiceResults results) { if (results == null) { throw new IllegalArgumentException( @@ -153,7 +218,7 @@ public void onResponse(List results) { } int i = 0; - for (InferenceResults inferenceResults : results) { + for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { String fieldName = inferenceFieldNames.get(i++); @SuppressWarnings("unchecked") Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( @@ -191,25 +256,4 @@ private static List getFieldNamesForInference(Map.Entry map, String... path) { - Map currentMap = map; - for (int i = 0; i < path.length - 1; i++) { - Object value = currentMap.get(path[i]); - - if (value instanceof Map) { - currentMap = (Map) value; - } else { - // Invalid path or non-Map value encountered - return null; - } - } - - // Retrieve the final value in the map, if it's a String - Object finalValue = currentMap.get(path[path.length - 1]); - - return (finalValue instanceof String) ? 
(String) finalValue : null; - } - } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 86d0b96754dfa..21896afca09a3 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -55,6 +55,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; @@ -65,7 +66,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -111,7 +113,8 @@ public class TransportBulkAction extends HandledTransportAction> requ return; } - String nodeId = clusterService.localNode().getId(); + BulkShardRequestInferenceProvider.executeWithInferenceProvider(inferenceServiceRegistry, modelRegistry, clusterState.metadata(), + requestsByShard.keySet(), + bulkShardRequestInferenceProvider -> { + processBulkItemRequests(requestsByShard, clusterState, bulkShardRequestInferenceProvider); + }); + + } + + private void processBulkItemRequests(Map> requestsByShard, ClusterState clusterState, + BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -762,7 +785,7 @@ private void executeBulkRequestsByShard(Map> requ bulkShardRequest.timeout(bulkRequest.timeout()); bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); + bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); } Releasable ref = bulkItemRequestCompleteRefCount.acquire(); @@ -772,7 +795,6 @@ private void executeBulkRequestsByShard(Map> requ (bulkShardRequest1, request, e) -> onBulkItemInferenceFailure(bulkShardRequest1, request, e, ref), bsr -> executeBulkShardRequest(bsr, b -> ref.close()) ); - } } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index bc9b1556e36e8..bddd12b7d9238 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -21,7 +21,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.SimulateIngestService; import org.elasticsearch.tasks.Task; @@ -57,7 +56,8 @@ public TransportSimulateBulkAction( indexingPressure, systemIndices, System::nanoTime, - new 
InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java deleted file mode 100644 index a0b282d327ae8..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; - -/** - * Provides NLP text inference results. Plugins can implement this interface to provide their own inference results. - */ -public interface InferenceProvider { - /** - * Returns InferenceResults for a given model ID and list of texts. - * - * @param modelId model identifier - * @param texts texts to perform inference on - * @param listener listener to be called when inference is complete - */ - void textInference(String modelId, List texts, ActionListener> listener); - - /** - * Returns true if this inference provider can perform inference - * - * @return true if this inference provider can perform inference - */ - boolean performsInference(); - - class NoopInferenceProvider implements InferenceProvider { - - @Override - public void textInference(String modelId, List texts, ActionListener> listener) { - throw new UnsupportedOperationException("No inference provider has been registered"); - } - - @Override - public boolean performsInference() { - return false; - } - } -} From 106e8b79ac3ff9e92bf9567809270024de8f55a9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 31 Jan 2024 10:21:43 +0100 Subject: [PATCH 069/106] Move required string constants to server, adjust inference results indexing --- .../BulkShardRequestInferenceProvider.java | 77 +++++++++---------- ...emanticTextInferenceResultFieldMapper.java | 9 +-- ...icTextInferenceResultFieldMapperTests.java | 40 +++------- 3 files changed, 53 insertions(+), 73 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 9bbdfb54f0be4..7418d56f47b9e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -39,9 +39,9 @@ public class BulkShardRequestInferenceProvider { - public static final String ROOT_RESULT_FIELD = "_ml_inference"; - public static final String INFERENCE_FIELD = "result"; - public static final String TEXT_FIELD = "text"; + public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; + public static final String TEXT_SUBFIELD_NAME = "text"; + public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; private final Map inferenceProvidersMap; @@ -179,7 +179,7 @@ private void performInferenceOnBulkItemRequest( @SuppressWarnings("unchecked") Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_RESULT_FIELD, + ROOT_INFERENCE_FIELD, k -> new HashMap() ); @@ -189,8 +189,6 @@ private void performInferenceOnBulkItemRequest( 
continue; } - docRef.acquire(); - InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); if (inferenceProvider == null) { onBulkItemFailure.apply( @@ -198,48 +196,47 @@ private void performInferenceOnBulkItemRequest( bulkItemRequest, new IllegalArgumentException("No inference provider found for model ID " + modelId) ); - docRef.close(); continue; } - inferenceProvider.service().infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - new ActionListener<>() { - - @Override - public void onResponse(InferenceServiceResults results) { - - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); - } + ActionListener inferenceResultsListener = new ActionListener<>() { + @Override + public void onResponse(InferenceServiceResults results) { - int i = 0; - for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() - ); + if (results == null) { + throw new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ); + } - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); - } + int i = 0; + for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + List> inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new ArrayList<>() + ); - docRef.close(); + // TODO Check inference result type to change subfield name + var inferenceFieldMap = Map.of( + SPARSE_VECTOR_SUBFIELD_NAME, inferenceResults.asMap("output").get("output"), + TEXT_SUBFIELD_NAME, docMap.get(fieldName) + ); + inferenceFieldResultList.add(inferenceFieldMap); } + } - @Override - public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkShardRequest, bulkItemRequest, e); - docRef.close(); - } + @Override + public void onFailure(Exception e) { + onBulkItemFailure.apply(bulkShardRequest, bulkItemRequest, e); } - ); + }; + inferenceProvider.service().infer( + inferenceProvider.model, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + // TODO check for additional settings needed + Map.of(), + ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire())); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java index ff224522034bf..7f338f50f38be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java @@ -42,6 +42,9 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; +import static 
org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME;
+
 /**
  * A mapper for the {@code _semantic_text_inference} field.
  *
@@ -107,15 +110,11 @@ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { public static final String CONTENT_TYPE = "_semantic_text_inference"; public static final String NAME = "_semantic_text_inference"; - public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; - public static final String TEXT_SUBFIELD_NAME = "text"; public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); private static final Map, Set> REQUIRED_SUBFIELDS_MAP = Map.of( List.of(), - Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME), - List.of(SPARSE_VECTOR_SUBFIELD_NAME), - Set.of(SparseEmbeddingResults.Embedding.EMBEDDING) + Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME) ); private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java index bde6da7fe8277..0c405ee4e821a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -51,6 +51,8 @@ import java.util.Set; import java.util.function.Consumer; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; import static org.hamcrest.Matchers.containsString; public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { @@ -214,7 +216,7 @@ public void testMissingSubfields() throws IOException { ); assertThat( ex.getMessage(), - containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + "]") + containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]") ); } { @@ -235,7 +237,7 @@ public void testMissingSubfields() throws IOException { ); assertThat( ex.getMessage(), - containsString("Missing required subfields: [" + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME + "]") + containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]") ); } { @@ -258,31 +260,13 @@ public void testMissingSubfields() throws IOException { ex.getMessage(), containsString( "Missing required subfields: [" - + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + SPARSE_VECTOR_SUBFIELD_NAME + ", " - + SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME + + TEXT_SUBFIELD_NAME + "]" ) ); } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), - new SparseVectorSubfieldOptions(true, false, false), - false, - null - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SparseEmbeddingResults.Embedding.EMBEDDING + "]")); - } } public void testExtraSubfields() throws IOException { @@ -460,10 +444,10 @@ private static void addSemanticTextInferenceResults( if (sparseVectorSubfieldOptions.includeEmbedding() == false) { 
embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING); } - subfieldMap.put(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); + subfieldMap.put(SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap); } if (includeTextSubfield) { - subfieldMap.put(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME, text); + subfieldMap.put(TEXT_SUBFIELD_NAME, text); } if (extraSubfields != null) { subfieldMap.putAll(extraSubfields); @@ -482,14 +466,14 @@ private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuil mappingBuilder.startObject(semanticTextFieldName); mappingBuilder.field("type", "nested"); mappingBuilder.startObject("properties"); - mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME); + mappingBuilder.startObject(SPARSE_VECTOR_SUBFIELD_NAME); mappingBuilder.startObject("properties"); mappingBuilder.startObject(SparseEmbeddingResults.Embedding.EMBEDDING); mappingBuilder.field("type", "sparse_vector"); mappingBuilder.endObject(); mappingBuilder.endObject(); mappingBuilder.endObject(); - mappingBuilder.startObject(SemanticTextInferenceResultFieldMapper.TEXT_SUBFIELD_NAME); + mappingBuilder.startObject(TEXT_SUBFIELD_NAME); mappingBuilder.field("type", "text"); mappingBuilder.field("index", false); mappingBuilder.endObject(); @@ -510,7 +494,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook new Term( path + "." - + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token @@ -537,7 +521,7 @@ private static void assertValidChildDoc( childDoc.getFields( childDoc.getPath() + "." - + SemanticTextInferenceResultFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME + + SPARSE_VECTOR_SUBFIELD_NAME + "." 
+ SparseEmbeddingResults.Embedding.EMBEDDING ).size() From e19c4df06085d73b89e52a80d3f428a0476a5b41 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 31 Jan 2024 19:15:03 +0100 Subject: [PATCH 070/106] Replace consumers with listener constructs --- .../action/bulk/TransportBulkAction.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 21896afca09a3..f015856d2fcf2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.DocWriteRequest; @@ -85,7 +86,7 @@ import java.util.SortedMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerArray; -import java.util.function.Consumer; +import java.util.function.BiConsumer; import java.util.function.LongSupplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -793,7 +794,11 @@ private void processBulkItemRequests(Map> request bulkShardRequest, clusterState, (bulkShardRequest1, request, e) -> onBulkItemInferenceFailure(bulkShardRequest1, request, e, ref), - bsr -> executeBulkShardRequest(bsr, b -> ref.close()) + bsr -> executeBulkShardRequest(bsr, ActionListener.releaseAfter(ActionListener.noop(), ref), + (request, e) -> { + markBulkItemRequestFailed(request, e); + ref.close(); + }) ); } } @@ -819,7 +824,8 @@ private void markBulkItemRequestFailed(BulkItemRequest request, Exception e) { responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Consumer onComplete) { + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, ActionListener listener, + BiConsumer bulkItemErrorListener) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early return; @@ -835,16 +841,16 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - onComplete.accept(bulkShardRequest); + listener.onResponse(bulkShardRequest); } @Override public void onFailure(Exception e) { // create failures for all relevant requests for (BulkItemRequest request : bulkShardRequest.items()) { - markBulkItemRequestFailed(request, e); + bulkItemErrorListener.accept(request, e); } - onComplete.accept(bulkShardRequest); + listener.onResponse(bulkShardRequest); } }); } From e54265510cc4c606a93d0e137c5969e484a46d4a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 31 Jan 2024 19:55:40 +0100 Subject: [PATCH 071/106] Baby steps for replacing custom listeners with ActionListeners --- .../action/bulk/TransportBulkAction.java | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index f015856d2fcf2..8d70774d36f4e 100644 --- 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -793,10 +793,14 @@ private void processBulkItemRequests(Map> request bulkShardRequestInferenceProvider.processBulkShardRequest( bulkShardRequest, clusterState, - (bulkShardRequest1, request, e) -> onBulkItemInferenceFailure(bulkShardRequest1, request, e, ref), - bsr -> executeBulkShardRequest(bsr, ActionListener.releaseAfter(ActionListener.noop(), ref), - (request, e) -> { - markBulkItemRequestFailed(request, e); + (bsr, request, e) -> { + markBulkItemRequestFailed(bsr, request, e); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; + }, + shardReq -> executeBulkShardRequest(shardReq, ActionListener.releaseAfter(ActionListener.noop(), ref), + (itemReq, e) -> { + markBulkItemRequestFailed(bulkShardRequest, itemReq, e); ref.close(); }) ); @@ -804,30 +808,22 @@ private void processBulkItemRequests(Map> request } } - private void onBulkItemInferenceFailure( - BulkShardRequest bulkShardRequest, - BulkItemRequest request, - Exception e, - Releasable refCount - ) { - markBulkItemRequestFailed(request, new IllegalArgumentException("Inference failed: " + e.getMessage(), e)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; - refCount.close(); - } - - private void markBulkItemRequestFailed(BulkItemRequest request, Exception e) { - final String indexName = request.index(); + private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, Exception e) { + final String indexName = itemRequest.index(); - DocWriteRequest docWriteRequest = request.request(); + DocWriteRequest docWriteRequest = itemRequest.request(); BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); + + // make sure the request gets never processed again + shardRequest.items()[itemRequest.id()] = null; } private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, ActionListener listener, BiConsumer bulkItemErrorListener) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early + listener.onResponse(bulkShardRequest); return; } @@ -850,7 +846,7 @@ public void onFailure(Exception e) { for (BulkItemRequest request : bulkShardRequest.items()) { bulkItemErrorListener.accept(request, e); } - listener.onResponse(bulkShardRequest); + listener.onFailure(e); } }); } From 3293be418782a573fd2ba809ac1b142f725e4d22 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 31 Jan 2024 20:00:14 +0100 Subject: [PATCH 072/106] Get more similar interfaces for item processors --- .../bulk/BulkShardRequestInferenceProvider.java | 17 +++++++---------- .../action/bulk/TransportBulkAction.java | 7 +++---- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 7418d56f47b9e..d08e321e4f35b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ 
b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.InferenceResults; @@ -34,6 +33,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -102,7 +102,7 @@ public void onFailure(Exception e) { public void processBulkShardRequest( BulkShardRequest bulkShardRequest, ClusterState clusterState, - TriConsumer onBulkItemFailure, + BiConsumer onBulkItemFailure, Consumer nextAction ) { @@ -143,7 +143,7 @@ private void performInferenceOnBulkItemRequest( BulkShardRequest bulkShardRequest, BulkItemRequest bulkItemRequest, Map> fieldsForModels, - TriConsumer onBulkItemFailure, + BiConsumer onBulkItemFailure, Releasable releaseOnFinish ) { @@ -191,8 +191,7 @@ private void performInferenceOnBulkItemRequest( InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); if (inferenceProvider == null) { - onBulkItemFailure.apply( - bulkShardRequest, + onBulkItemFailure.accept( bulkItemRequest, new IllegalArgumentException("No inference provider found for model ID " + modelId) ); @@ -212,10 +211,8 @@ public void onResponse(InferenceServiceResults results) { for (InferenceResults inferenceResults : results.transformToLegacyFormat()) { String fieldName = inferenceFieldNames.get(i++); @SuppressWarnings("unchecked") - List> inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new ArrayList<>() - ); + List> inferenceFieldResultList = (List>) rootInferenceFieldMap + .computeIfAbsent(fieldName, k -> new ArrayList<>()); // TODO Check inference result type to change subfield name var inferenceFieldMap = Map.of( @@ -228,7 +225,7 @@ public void onResponse(InferenceServiceResults results) { @Override public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkShardRequest, bulkItemRequest, e); + onBulkItemFailure.accept(bulkItemRequest, e); } }; inferenceProvider.service().infer( diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 8d70774d36f4e..b7b5bfc8b8a3f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.DocWriteRequest; @@ -793,10 +792,10 @@ private void processBulkItemRequests(Map> request bulkShardRequestInferenceProvider.processBulkShardRequest( bulkShardRequest, clusterState, - (bsr, request, e) -> { - markBulkItemRequestFailed(bsr, request, e); + (itemReq, e) -> { + markBulkItemRequestFailed(bulkShardRequest, itemReq, e); // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; + 
bulkShardRequest.items()[itemReq.id()] = null; }, shardReq -> executeBulkShardRequest(shardReq, ActionListener.releaseAfter(ActionListener.noop(), ref), (itemReq, e) -> { From d12e31ec813311a1810f60831fbdfe6dc202e42e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 31 Jan 2024 20:10:53 +0100 Subject: [PATCH 073/106] More changes to listeners --- .../BulkShardRequestInferenceProvider.java | 12 ++++------ .../action/bulk/TransportBulkAction.java | 24 +++++++++++++------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index d08e321e4f35b..43aa00c49649b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -102,16 +102,16 @@ public void onFailure(Exception e) { public void processBulkShardRequest( BulkShardRequest bulkShardRequest, ClusterState clusterState, - BiConsumer onBulkItemFailure, - Consumer nextAction + ActionListener listener, + BiConsumer onBulkItemFailure ) { Map> fieldsForModels = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) .getFieldsForModels(); - // No inference fields? Just execute the request + // No inference fields? Terminate early if (fieldsForModels.isEmpty()) { - nextAction.accept(bulkShardRequest); + listener.onResponse(bulkShardRequest); return; } @@ -123,13 +123,12 @@ public void processBulkShardRequest( bulkShardRequest.getRefreshPolicy(), Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) ); - nextAction.accept(errorsFilteredShardRequest); + listener.onResponse(errorsFilteredShardRequest); }; try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { for (BulkItemRequest bulkItemRequest : bulkShardRequest.items()) { performInferenceOnBulkItemRequest( - bulkShardRequest, bulkItemRequest, fieldsForModels, onBulkItemFailure, @@ -140,7 +139,6 @@ public void processBulkShardRequest( } private void performInferenceOnBulkItemRequest( - BulkShardRequest bulkShardRequest, BulkItemRequest bulkItemRequest, Map> fieldsForModels, BiConsumer onBulkItemFailure, diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index b7b5bfc8b8a3f..93c7f2e6f0e44 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -792,16 +792,26 @@ private void processBulkItemRequests(Map> request bulkShardRequestInferenceProvider.processBulkShardRequest( bulkShardRequest, clusterState, - (itemReq, e) -> { + new ActionListener() { + @Override + public void onResponse(BulkShardRequest bulkShardRequest) { + executeBulkShardRequest(bulkShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), + (itemReq, e) -> { + markBulkItemRequestFailed(bulkShardRequest, itemReq, e); + ref.close(); + }); + } + + @Override + public void onFailure(Exception e) { + ref.close(); + } + }, + (itemReq, e) -> { markBulkItemRequestFailed(bulkShardRequest, itemReq, e); // make sure the request gets never processed again bulkShardRequest.items()[itemReq.id()] = null; - }, - shardReq -> executeBulkShardRequest(shardReq, 
ActionListener.releaseAfter(ActionListener.noop(), ref), - (itemReq, e) -> { - markBulkItemRequestFailed(bulkShardRequest, itemReq, e); - ref.close(); - }) + } ); } } From cedca0738fee15850bd0abc97aa32177807f422f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 10:27:34 +0100 Subject: [PATCH 074/106] Refactorings to create the inference provider --- .../BulkShardRequestInferenceProvider.java | 21 +++++----- .../action/bulk/TransportBulkAction.java | 40 +++++++++++-------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 43aa00c49649b..a494d17b339cd 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.InferenceResults; @@ -34,7 +33,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiConsumer; -import java.util.function.Consumer; import java.util.stream.Collectors; public class BulkShardRequestInferenceProvider { @@ -43,6 +41,7 @@ public class BulkShardRequestInferenceProvider { public static final String TEXT_SUBFIELD_NAME = "text"; public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; + private final ClusterState clusterState; private final Map inferenceProvidersMap; private record InferenceProvider (Model model, InferenceService service) { @@ -52,24 +51,27 @@ private record InferenceProvider (Model model, InferenceService service) { } } - private BulkShardRequestInferenceProvider(Map inferenceProvidersMap) { + private BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { + this.clusterState = clusterState; this.inferenceProvidersMap = inferenceProvidersMap; } - public static void executeWithInferenceProvider( + public static void getInferenceProvider( InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, - Metadata clusterMetadata, + ClusterState clusterState, Set shardIds, - Consumer action + ActionListener listener ) { Set inferenceIds = new HashSet<>(); shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { - var fieldsForModels = clusterMetadata.index(index).getFieldsForModels(); + var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); inferenceIds.addAll(fieldsForModels.keySet()); }); final Map inferenceProviderMap = new ConcurrentHashMap<>(); - Runnable onModelLoadingComplete = () -> action.accept(new BulkShardRequestInferenceProvider(inferenceProviderMap)); + Runnable onModelLoadingComplete = () -> listener.onResponse( + new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) + ); try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { for (var inferenceId : inferenceIds) { ActionListener modelLoadingListener = new ActionListener<>() { @@ -90,7 +92,7 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { @Override public void 
onFailure(Exception e) { - // Do nothing - let it fail afterwards + // Do nothing - let it fail afterwards when model is retrieved } }; @@ -101,7 +103,6 @@ public void onFailure(Exception e) { public void processBulkShardRequest( BulkShardRequest bulkShardRequest, - ClusterState clusterState, ActionListener listener, BiConsumer onBulkItemFailure ) { diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 93c7f2e6f0e44..9938031117597 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SparseFixedBitSet; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceAlreadyExistsException; @@ -754,15 +755,23 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.executeWithInferenceProvider(inferenceServiceRegistry, modelRegistry, clusterState.metadata(), + BulkShardRequestInferenceProvider.getInferenceProvider(inferenceServiceRegistry, modelRegistry, clusterState, requestsByShard.keySet(), - bulkShardRequestInferenceProvider -> { - processBulkItemRequests(requestsByShard, clusterState, bulkShardRequestInferenceProvider); - }); + new ActionListener() { + @Override + public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { + processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); + } + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error loading inference models", e); + } + } + ); } - private void processBulkItemRequests(Map> requestsByShard, ClusterState clusterState, + private void processRequestsByShards(Map> requestsByShard, ClusterState clusterState, BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { Runnable onBulkItemsComplete = () -> { listener.onResponse( @@ -789,29 +798,26 @@ private void processBulkItemRequests(Map> request } Releasable ref = bulkItemRequestCompleteRefCount.acquire(); + final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed( + bulkShardRequest, + itemReq, + e + ); bulkShardRequestInferenceProvider.processBulkShardRequest( bulkShardRequest, - clusterState, - new ActionListener() { + new ActionListener<>() { @Override public void onResponse(BulkShardRequest bulkShardRequest) { executeBulkShardRequest(bulkShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), - (itemReq, e) -> { - markBulkItemRequestFailed(bulkShardRequest, itemReq, e); - ref.close(); - }); + bulkItemFailedListener); } @Override public void onFailure(Exception e) { - ref.close(); + throw new ElasticsearchException("Error performing inference", e); } }, - (itemReq, e) -> { - markBulkItemRequestFailed(bulkShardRequest, itemReq, e); - // make sure the request gets never processed again - bulkShardRequest.items()[itemReq.id()] = null; - } + bulkItemFailedListener ); } } From 134dd001bebe70501d42521a5a2327937b5fef58 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 10:41:59 +0100 Subject: [PATCH 075/106] Minor refactorings --- .../BulkShardRequestInferenceProvider.java | 21 ++++++++++--------- 
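The listener plumbing in these commits follows one pattern: every asynchronous call (model loading, per-item inference) acquires a reference from a RefCountingRunnable, and ActionListener.releaseAfter releases that reference whether the call succeeds or fails, so the completion runnable runs exactly once when all outstanding work has finished. A minimal sketch of the pattern, assuming helper methods textsToInfer and mergeResultsIntoSource that are not part of this change:

// Sketch only. Assumed imports:
//   org.elasticsearch.action.ActionListener, org.elasticsearch.action.support.RefCountingRunnable,
//   org.elasticsearch.inference.InferenceService / InferenceServiceResults / Model,
//   java.util.Map, java.util.function.BiConsumer
static void inferAllItems(
    InferenceService service,
    Model model,
    BulkShardRequest bulkShardRequest,
    ActionListener<BulkShardRequest> listener,
    BiConsumer<BulkItemRequest, Exception> onBulkItemFailure
) {
    // The completion runnable fires exactly once, after every acquired ref has been released.
    try (RefCountingRunnable refs = new RefCountingRunnable(() -> listener.onResponse(bulkShardRequest))) {
        for (BulkItemRequest item : bulkShardRequest.items()) {
            ActionListener<InferenceServiceResults> perItem = ActionListener.wrap(
                results -> mergeResultsIntoSource(item, results),   // placeholder for the source-update step
                e -> onBulkItemFailure.accept(item, e)               // record the failure for this item only
            );
            // releaseAfter closes the acquired ref on success and on failure alike.
            service.infer(model, textsToInfer(item), Map.of(), ActionListener.releaseAfter(perItem, refs.acquire()));
        }
    }
}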
.../action/bulk/TransportBulkAction.java | 15 ++++++++++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index a494d17b339cd..5a59921672b3d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -35,10 +35,19 @@ import java.util.function.BiConsumer; import java.util.stream.Collectors; +/** + * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. + */ public class BulkShardRequestInferenceProvider { + + // Root field name for storing inference results public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; + + // Contains the original text for the field public static final String TEXT_SUBFIELD_NAME = "text"; + + // Contains the inference result when it's a sparse vector public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; private final ClusterState clusterState; @@ -56,7 +65,7 @@ private BulkShardRequestInferenceProvider(ClusterState clusterState, Map { - // We need to remove items that have had an inference error, as the response will have been updated already - // and we don't need to process them further - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) - ); - listener.onResponse(errorsFilteredShardRequest); + listener.onResponse(bulkShardRequest); }; - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { for (BulkItemRequest bulkItemRequest : bulkShardRequest.items()) { performInferenceOnBulkItemRequest( diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 9938031117597..b86bf631e6c9a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -77,6 +77,7 @@ import org.elasticsearch.transport.TransportService; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -755,7 +756,7 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInferenceProvider(inferenceServiceRegistry, modelRegistry, clusterState, + BulkShardRequestInferenceProvider.getInstance(inferenceServiceRegistry, modelRegistry, clusterState, requestsByShard.keySet(), new ActionListener() { @Override @@ -808,7 +809,14 @@ private void processRequestsByShards(Map> request new ActionListener<>() { @Override public void onResponse(BulkShardRequest bulkShardRequest) { - executeBulkShardRequest(bulkShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + 
Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + executeBulkShardRequest(errorsFilteredShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), bulkItemFailedListener); } @@ -823,6 +831,7 @@ public void onFailure(Exception e) { } } + // When an item fails, store the failure in the responses array private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, Exception e) { final String indexName = itemRequest.index(); @@ -830,7 +839,7 @@ private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRe BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again + // make sure the request gets never processed again, removing the item from the shard request shardRequest.items()[itemRequest.id()] = null; } From 3d0b537072098d41b6985236b8cb7ea40c0215ac Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 11:09:39 +0100 Subject: [PATCH 076/106] Minor refactorings --- .../BulkShardRequestInferenceProvider.java | 48 ++++---- .../action/bulk/TransportBulkAction.java | 108 +++++++++--------- 2 files changed, 77 insertions(+), 79 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 5a59921672b3d..95e3f66e95213 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -24,7 +24,6 @@ import org.elasticsearch.inference.ModelRegistry; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -40,27 +39,26 @@ */ public class BulkShardRequestInferenceProvider { - - // Root field name for storing inference results + // Root field name for storing inference results public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; - // Contains the original text for the field + // Contains the original text for the field public static final String TEXT_SUBFIELD_NAME = "text"; - // Contains the inference result when it's a sparse vector + // Contains the inference result when it's a sparse vector public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding"; private final ClusterState clusterState; private final Map inferenceProvidersMap; - private record InferenceProvider (Model model, InferenceService service) { + private record InferenceProvider(Model model, InferenceService service) { private InferenceProvider { Objects.requireNonNull(model); Objects.requireNonNull(service); } } - private BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { + BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { this.clusterState = clusterState; this.inferenceProvidersMap = inferenceProvidersMap; } @@ -89,10 +87,7 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { InferenceProvider inferenceProvider = new InferenceProvider( - service.get().parsePersistedConfig( - inferenceId, - 
unparsedModel.taskType(), - unparsedModel.settings()), + service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()), service.get() ); inferenceProviderMap.put(inferenceId, inferenceProvider); @@ -125,17 +120,10 @@ public void processBulkShardRequest( return; } - Runnable onInferenceComplete = () -> { - listener.onResponse(bulkShardRequest); - }; + Runnable onInferenceComplete = () -> { listener.onResponse(bulkShardRequest); }; try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { for (BulkItemRequest bulkItemRequest : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest( - bulkItemRequest, - fieldsForModels, - onBulkItemFailure, - bulkItemReqRef.acquire() - ); + performInferenceOnBulkItemRequest(bulkItemRequest, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire()); } } } @@ -216,8 +204,10 @@ public void onResponse(InferenceServiceResults results) { // TODO Check inference result type to change subfield name var inferenceFieldMap = Map.of( - SPARSE_VECTOR_SUBFIELD_NAME, inferenceResults.asMap("output").get("output"), - TEXT_SUBFIELD_NAME, docMap.get(fieldName) + SPARSE_VECTOR_SUBFIELD_NAME, + inferenceResults.asMap("output").get("output"), + TEXT_SUBFIELD_NAME, + docMap.get(fieldName) ); inferenceFieldResultList.add(inferenceFieldMap); } @@ -228,12 +218,14 @@ public void onFailure(Exception e) { onBulkItemFailure.accept(bulkItemRequest, e); } }; - inferenceProvider.service().infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire())); + inferenceProvider.service() + .infer( + inferenceProvider.model, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + // TODO check for additional settings needed + Map.of(), + ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) + ); } } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index b86bf631e6c9a..15f157de0a945 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -129,10 +129,8 @@ public TransportBulkAction( IndexNameExpressionResolver indexNameExpressionResolver, IndexingPressure indexingPressure, SystemIndices systemIndices, - @Nullable - InferenceServiceRegistry inferenceServiceRegistry, - @Nullable - ModelRegistry modelRegistry + @Nullable InferenceServiceRegistry inferenceServiceRegistry, + @Nullable ModelRegistry modelRegistry ) { this( threadPool, @@ -161,10 +159,8 @@ public TransportBulkAction( IndexingPressure indexingPressure, SystemIndices systemIndices, LongSupplier relativeTimeProvider, - @Nullable - InferenceServiceRegistry inferenceServiceRegistry, - @Nullable - ModelRegistry modelRegistry + @Nullable InferenceServiceRegistry inferenceServiceRegistry, + @Nullable ModelRegistry modelRegistry ) { this( BulkAction.INSTANCE, @@ -197,10 +193,8 @@ public TransportBulkAction( IndexingPressure indexingPressure, SystemIndices systemIndices, LongSupplier relativeTimeProvider, - @Nullable - InferenceServiceRegistry inferenceServiceRegistry, - @Nullable - ModelRegistry modelRegistry + @Nullable InferenceServiceRegistry inferenceServiceRegistry, + @Nullable 
ModelRegistry modelRegistry ) { super(bulkAction.name(), transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); Objects.requireNonNull(relativeTimeProvider); @@ -756,7 +750,10 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInstance(inferenceServiceRegistry, modelRegistry, clusterState, + BulkShardRequestInferenceProvider.getInstance( + inferenceServiceRegistry, + modelRegistry, + clusterState, requestsByShard.keySet(), new ActionListener() { @Override @@ -772,8 +769,11 @@ public void onFailure(Exception e) { ); } - private void processRequestsByShards(Map> requestsByShard, ClusterState clusterState, - BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { + private void processRequestsByShards( + Map> requestsByShard, + ClusterState clusterState, + BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider + ) { Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -785,18 +785,7 @@ private void processRequestsByShards(Map> request for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); - - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } + BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); Releasable ref = bulkItemRequestCompleteRefCount.acquire(); final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed( @@ -804,33 +793,47 @@ private void processRequestsByShards(Map> request itemReq, e ); - bulkShardRequestInferenceProvider.processBulkShardRequest( - bulkShardRequest, - new ActionListener<>() { - @Override - public void onResponse(BulkShardRequest bulkShardRequest) { - // We need to remove items that have had an inference error, as the response will have been updated already - // and we don't need to process them further - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) - ); - executeBulkShardRequest(errorsFilteredShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), - bulkItemFailedListener); - } + bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { + @Override + public void onResponse(BulkShardRequest bulkShardRequest) { + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + executeBulkShardRequest( + errorsFilteredShardRequest, + ActionListener.releaseAfter(ActionListener.noop(), ref), + bulkItemFailedListener + ); + } - 
@Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error performing inference", e); - } - }, - bulkItemFailedListener - ); + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error performing inference", e); + } + }, bulkItemFailedListener); } } } + private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } + return bulkShardRequest; + } + // When an item fails, store the failure in the responses array private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, Exception e) { final String indexName = itemRequest.index(); @@ -843,8 +846,11 @@ private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRe shardRequest.items()[itemRequest.id()] = null; } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, ActionListener listener, - BiConsumer bulkItemErrorListener) { + private void executeBulkShardRequest( + BulkShardRequest bulkShardRequest, + ActionListener listener, + BiConsumer bulkItemErrorListener + ) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early listener.onResponse(bulkShardRequest); From dd0f2e9baa8cb25a79056212351df82fc3f50ee8 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 11:10:04 +0100 Subject: [PATCH 077/106] Remove references from InferenceProvider from tests, remove current inference tests to be redone --- ...ActionIndicesThatCannotBeCreatedTests.java | 4 +- .../TransportBulkActionInferenceTests.java | 340 ------------------ .../bulk/TransportBulkActionIngestTests.java | 4 +- .../action/bulk/TransportBulkActionTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 4 +- .../snapshots/SnapshotResiliencyTests.java | 4 +- 6 files changed, 10 insertions(+), 350 deletions(-) delete mode 100644 server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 5a0bc1e422ddb..5581086a94187 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.EmptySystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockUtils; @@ -125,7 +124,8 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + 
null, + null ) { @Override void executeBulk( diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java deleted file mode 100644 index 08948d76ed8a1..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ /dev/null @@ -1,340 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.ActionTestUtils; -import org.elasticsearch.action.support.AutoCreateIndex; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.node.DiscoveryNodes; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.UUIDs; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexingPressure; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.indices.EmptySystemIndices; -import org.elasticsearch.indices.TestIndexNameExpressionResolver; -import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.TestInferenceResults; -import org.elasticsearch.ingest.IngestService; -import org.elasticsearch.test.ClusterServiceUtils; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.TransportService; -import org.junit.After; -import org.junit.Before; -import org.mockito.verification.VerificationMode; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - -import static org.hamcrest.Matchers.equalTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class TransportBulkActionInferenceTests extends ESTestCase { - - public static final String INDEX_NAME = "index"; - 
public static final String INFERENCE_FIELD_1_MODEL_A = "inference_field_1_model_a"; - public static final String MODEL_A_ID = "model_a_id"; - private static final String INFERENCE_FIELD_2_MODEL_A = "inference_field_2_model_a"; - public static final String MODEL_B_ID = "model_b_id"; - private static final String INFERENCE_FIELD_MODEL_B = "inference_field_model_b"; - private ClusterService clusterService; - private ThreadPool threadPool; - private NodeClient nodeClient; - private TransportBulkAction transportBulkAction; - - private InferenceProvider inferenceProvider; - - @Before - public void setup() { - threadPool = new TestThreadPool(getClass().getName()); - nodeClient = mock(NodeClient.class); - - // Contains the fields for models for the index - Metadata metadata = Metadata.builder() - .indices( - Map.of( - INDEX_NAME, - IndexMetadata.builder(INDEX_NAME) - .settings(settings(IndexVersion.current())) - .fieldsForModels( - Map.of( - MODEL_A_ID, - Set.of(INFERENCE_FIELD_1_MODEL_A, INFERENCE_FIELD_2_MODEL_A), - MODEL_B_ID, - Set.of(INFERENCE_FIELD_MODEL_B) - ) - ) - .numberOfShards(1) - .numberOfReplicas(1) - .build() - ) - ) - .build(); - - DiscoveryNode masterNode = DiscoveryNodeUtils.create(UUIDs.base64UUID()); - ClusterState state = ClusterState.builder(ClusterName.DEFAULT) - .metadata(metadata) - .nodes(DiscoveryNodes.builder().add(masterNode).localNodeId(masterNode.getId()).masterNodeId(masterNode.getId())) - .build(); - - clusterService = ClusterServiceUtils.createClusterService(state, threadPool); - - inferenceProvider = mock(InferenceProvider.class); - when(inferenceProvider.performsInference()).thenReturn(true); - - transportBulkAction = new TransportBulkAction( - threadPool, - mock(TransportService.class), - clusterService, - mock(IngestService.class), - nodeClient, - new ActionFilters(Collections.emptySet()), - TestIndexNameExpressionResolver.newInstance(), - new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), - EmptySystemIndices.INSTANCE, - inferenceProvider - ); - - // Default answers to avoid hanging tests due to unexpected invocations - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new Exception("Unexpected invocation")); - return Void.TYPE; - }).when(inferenceProvider).textInference(any(), any(), any()); - when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onFailure(new Exception("Unexpected invocation")); - return null; - }); - } - - @After - public void tearDown() throws Exception { - ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - threadPool = null; - clusterService.close(); - super.tearDown(); - } - - public void testBulkRequestWithoutInference() { - BulkRequest bulkRequest = new BulkRequest(); - IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); - indexRequest.source("non_inference_field", "text", "another_non_inference_field", "other text"); - bulkRequest.add(indexRequest); - - expectTransportShardBulkActionRequest(1); - - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - - assertThat(response.getItems().length, equalTo(1)); - assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() 
== false)); - verifyInferenceExecuted(never()); - } - - public void testBulkRequestWithInference() { - BulkRequest bulkRequest = new BulkRequest(); - IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); - String inferenceFieldText = "some text"; - indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldText, "non_inference_field", "other text"); - bulkRequest.add(indexRequest); - - expectInferenceRequest(MODEL_A_ID, inferenceFieldText); - - expectTransportShardBulkActionRequest(1); - - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - - assertThat(response.getItems().length, equalTo(1)); - assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); - verifyInferenceExecuted(times(1)); - } - - public void testBulkRequestWithMultipleFieldsInference() { - BulkRequest bulkRequest = new BulkRequest(); - IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id"); - String inferenceField1Text = "some text"; - String inferenceField2Text = "some other text"; - String inferenceField3Text = "more inference text"; - indexRequest.source( - INFERENCE_FIELD_1_MODEL_A, - inferenceField1Text, - INFERENCE_FIELD_2_MODEL_A, - inferenceField2Text, - INFERENCE_FIELD_MODEL_B, - inferenceField3Text, - "non_inference_field", - "other text" - ); - bulkRequest.add(indexRequest); - - expectInferenceRequest(MODEL_A_ID, inferenceField1Text, inferenceField2Text); - expectInferenceRequest(MODEL_B_ID, inferenceField3Text); - - expectTransportShardBulkActionRequest(1); - - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - - assertThat(response.getItems().length, equalTo(1)); - assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); - verifyInferenceExecuted(times(2)); - } - - public void testBulkRequestWithMultipleDocs() { - BulkRequest bulkRequest = new BulkRequest(); - IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); - String inferenceFieldTextDoc1 = "some text"; - bulkRequest.add(indexRequest); - indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "other text"); - indexRequest = new IndexRequest(INDEX_NAME).id("id2"); - String inferenceFieldTextDoc2 = "some other text"; - indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc2, "non_inference_field", "more text"); - bulkRequest.add(indexRequest); - - expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc1); - expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc2); - - expectTransportShardBulkActionRequest(2); - - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - - assertThat(response.getItems().length, equalTo(2)); - assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); - verifyInferenceExecuted(times(2)); - } - - public void testFailingInference() { - BulkRequest bulkRequest = new BulkRequest(); - IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); - String inferenceFieldTextDoc1 = "some text"; - indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "more text"); - bulkRequest.add(indexRequest); - indexRequest = new IndexRequest(INDEX_NAME).id("id1"); 
- String inferenceFieldTextDoc2 = "some text"; - indexRequest.source(INFERENCE_FIELD_MODEL_B, inferenceFieldTextDoc2, "non_inference_field", "more text"); - bulkRequest.add(indexRequest); - - expectInferenceRequestFails(MODEL_A_ID, inferenceFieldTextDoc1); - expectInferenceRequest(MODEL_B_ID, inferenceFieldTextDoc2); - - // Only non-failing inference requests will be executed - expectTransportShardBulkActionRequest(1); - - PlainActionFuture future = new PlainActionFuture<>(); - ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); - BulkResponse response = future.actionGet(); - - assertThat(response.getItems().length, equalTo(2)); - assertTrue(response.getItems()[0].isFailed()); - assertFalse(response.getItems()[1].isFailed()); - verifyInferenceExecuted(times(2)); - } - - private void verifyInferenceExecuted(VerificationMode verificationMode) { - verify(inferenceProvider, verificationMode).textInference(any(), any(), any()); - } - - private void expectTransportShardBulkActionRequest(int requestSize) { - when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), argThat(r -> matchBulkShardRequest(r, requestSize)), any())) - .thenAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener) invocation.getArguments()[2]; - var bulkShardRequest = (BulkShardRequest) invocation.getArguments()[1]; - ShardId shardId = new ShardId(INDEX_NAME, "UUID", 0); - BulkItemResponse[] bulkItemResponses = Arrays.stream(bulkShardRequest.items()) - .map(item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse( - shardId, - "id", - 0, 0, 0, true) - ) - ).toArray(BulkItemResponse[]::new); - - listener.onResponse(new BulkShardResponse(shardId, bulkItemResponses)); - return null; - }); - } - - private boolean matchBulkShardRequest(ActionRequest request, int requestSize) { - return (request instanceof BulkShardRequest) && ((BulkShardRequest) request).items().length == requestSize; - } - - @SuppressWarnings("unchecked") - private void expectInferenceRequest(String modelId, String... inferenceTexts) { - doAnswer(invocation -> { - List texts = (List) invocation.getArguments()[1]; - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onResponse( - texts.stream() - .map( - text -> new TestInferenceResults( - "test_field", - randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLengthBetween(1, 10), randomFloat())) - ) - ) - .collect(Collectors.toList()) - ); - return Void.TYPE; - }).when(inferenceProvider) - .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); - } - - private void expectInferenceRequestFails(String modelId, String... 
inferenceTexts) { - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new Exception("Inference failed")); - return Void.TYPE; - }).when(inferenceProvider) - .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); - } - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index efd75f3c4d82e..9a04e41957272 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -41,7 +41,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -137,7 +136,8 @@ class TestTransportBulkAction extends TransportBulkAction { TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 44973831817ed..1af1ef32aa8b1 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -40,7 +40,6 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.SystemIndexDescriptorUtils; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.index.IndexVersionUtils; @@ -89,7 +88,8 @@ class TestTransportBulkAction extends TransportBulkAction { new Resolver(), new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index d94a1cb092bc4..d577a7dcc0313 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -32,7 +32,6 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; @@ -254,7 +253,8 @@ static class TestTransportBulkAction extends TransportBulkAction { new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, relativeTimeProvider, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java 
b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index f8fdcbd09ce78..1ecc2782ca858 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -155,7 +155,6 @@ import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.indices.recovery.SnapshotFilesProvider; import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.StatusInfo; import org.elasticsearch.node.ResponseCollectorService; @@ -1944,7 +1943,8 @@ protected void assertSnapshotOrGenericThread() { indexNameExpressionResolver, new IndexingPressure(settings), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( From 2386a3b3ec9d89a435cbb5023144ea29768a89ff Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 12:42:33 +0100 Subject: [PATCH 078/106] Fix existing semantic text fields --- .../SemanticTextClusterMetadataTests.java | 12 +++++++ ...icTextInferenceResultFieldMapperTests.java | 33 +++---------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 47cae14003c70..2a47fc46311de 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -13,14 +13,26 @@ import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; public class SemanticTextClusterMetadataTests extends MlSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + Collection> plugins = new ArrayList<>(super.getPlugins()); + plugins.add(InferencePlugin.class); + return Collections.unmodifiableCollection(plugins); + } + public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java index 0c405ee4e821a..f44be37906c6d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -214,10 +214,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]") - ); + assertThat(ex.getMessage(), 
containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]")); } { DocumentParsingException ex = expectThrows( @@ -235,10 +232,7 @@ public void testMissingSubfields() throws IOException { ) ) ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]") - ); + assertThat(ex.getMessage(), containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]")); } { DocumentParsingException ex = expectThrows( @@ -258,13 +252,7 @@ public void testMissingSubfields() throws IOException { ); assertThat( ex.getMessage(), - containsString( - "Missing required subfields: [" - + SPARSE_VECTOR_SUBFIELD_NAME - + ", " - + TEXT_SUBFIELD_NAME - + "]" - ) + containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + ", " + TEXT_SUBFIELD_NAME + "]") ); } } @@ -491,14 +479,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook queryBuilder.add( new BooleanClause( new TermQuery( - new Term( - path - + "." - + SPARSE_VECTOR_SUBFIELD_NAME - + "." - + SparseEmbeddingResults.Embedding.EMBEDDING, - token - ) + new Term(path + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token) ), BooleanClause.Occur.MUST ) @@ -519,11 +500,7 @@ private static void assertValidChildDoc( new VisitedChildDocInfo( childDoc.getPath(), childDoc.getFields( - childDoc.getPath() - + "." - + SPARSE_VECTOR_SUBFIELD_NAME - + "." - + SparseEmbeddingResults.Embedding.EMBEDDING + childDoc.getPath() + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING ).size() ) ); From 3c9c32fc28645f5c0f29cb2cb47be765c15c6f36 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 12:56:01 +0100 Subject: [PATCH 079/106] Include changes from main in ModelRegistry --- .../inference/ModelRegistry.java | 10 +++- ...gistryIT.java => ModelRegistryImplIT.java} | 52 +++++++++---------- ...Tests.java => ModelRegistryImplTests.java} | 2 +- 3 files changed, 35 insertions(+), 29 deletions(-) rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryIT.java => ModelRegistryImplIT.java} (86%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryTests.java => ModelRegistryImplTests.java} (99%) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java index dc7ee7cabf079..4d0f3a8f2ea9b 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java @@ -25,8 +25,14 @@ public interface ModelRegistry { void deleteModel(String modelId, ActionListener listener); /** - * Semi parsed model where model id, task type and service + * Semi parsed model where inference entity id, task type and service * are known but the settings are not parsed. 
*/ - record UnparsedModel(String modelId, TaskType taskType, String service, Map settings, Map secrets) {} + record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) {} } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java similarity index 86% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java index d6d0eb0bbbf21..614ebee99ae4f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java @@ -25,7 +25,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests; @@ -54,13 +54,13 @@ import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; -public class ModelRegistryIT extends ESSingleNodeTestCase { +public class ModelRegistryImplIT extends ESSingleNodeTestCase { - private ModelRegistry modelRegistry; + private ModelRegistryImpl ModelRegistryImpl; @Before public void createComponents() { - modelRegistry = new ModelRegistry(client()); + ModelRegistryImpl = new ModelRegistryImpl(client()); } @Override @@ -74,7 +74,7 @@ public void testStoreModel() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); assertThat(storeModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); @@ -86,7 +86,7 @@ public void testStoreModelWithUnknownFields() throws Exception { AtomicReference storeModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder); assertNull(storeModelHolder.get()); assertNotNull(exceptionHolder.get()); @@ -105,12 +105,12 @@ public void testGetModel() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); // now get 
the model - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -132,13 +132,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); putModelHolder.set(false); // an model with the same id exists - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(false)); assertThat(exceptionHolder.get(), not(nullValue())); assertThat( @@ -153,20 +153,20 @@ public void testDeleteModel() throws Exception { Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING); AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference deleteResponseHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertTrue(deleteResponseHolder.get()); // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); @@ -186,13 +186,13 @@ public void testGetModelsByTaskType() throws InterruptedException { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + AtomicReference> 
modelHolder = new AtomicReference<>(); + blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) @@ -203,7 +203,7 @@ public void testGetModelsByTaskType() throws InterruptedException { assertThat(m.secrets().keySet(), empty()); }); - blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(2)); var denseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) @@ -227,13 +227,13 @@ public void testGetAllModels() throws InterruptedException { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -257,18 +257,18 @@ public void testGetModelWithSecrets() throws InterruptedException { AtomicReference exceptionHolder = new AtomicReference<>(); var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); // get model without secrets - blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); + blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java similarity index 99% rename from 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java index c7810667bff34..fd6a203450c12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ModelRegistryTests extends ESTestCase { +public class ModelRegistryImplTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); From fb9b9e53efeca35846351986d240c328c6c0a66a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 12:59:54 +0100 Subject: [PATCH 080/106] Fix merge from main --- .../action/bulk/BulkShardRequestInferenceProvider.java | 2 ++ .../inference/action/TransportGetInferenceModelAction.java | 2 +- .../inference/action/TransportPutInferenceModelAction.java | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 95e3f66e95213..7ccf3d6baaac2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelRegistry; @@ -224,6 +225,7 @@ public void onFailure(Exception e) { inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), // TODO check for additional settings needed Map.of(), + InputType.INGEST, ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 58f406a8fe554..ac43b256b0770 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -24,7 +25,6 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import java.util.ArrayList; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 44a024b306f57..f94da64558132 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -42,7 +43,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.inference.ModelRegistry; import java.io.IOException; import java.util.Map; From 7a47fd73b21da0a71389f0085741c6146ce64b62 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 1 Feb 2024 13:09:01 +0100 Subject: [PATCH 081/106] Fix merge from main --- .../inference/action/TransportDeleteInferenceModelAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index 6a5a3f5a137e1..9b110f7b8e7a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { From 71002539a4dfd76bc93207a4d0aec6ae592e6d22 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 2 Feb 2024 10:13:58 +0100 Subject: [PATCH 082/106] Fix merge from main --- .../action/bulk/BulkOperation.java | 125 +++++++++++++++--- .../action/bulk/TransportBulkAction.java | 10 +- ...ActionIndicesThatCannotBeCreatedTests.java | 4 +- .../bulk/TransportBulkActionIngestTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 12 +- 5 files changed, 121 insertions(+), 34 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 1d95f430d5c7e..66f9a5a8bebff 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,6 +10,7 
@@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -35,15 +36,20 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -69,6 +75,8 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -82,6 +90,8 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { super(listener); @@ -97,6 +107,8 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -189,7 +201,30 @@ private void executeBulkRequestsByShard(Map> requ return; } - String nodeId = clusterService.localNode().getId(); + BulkShardRequestInferenceProvider.getInstance( + inferenceServiceRegistry, + modelRegistry, + clusterState, + requestsByShard.keySet(), + new ActionListener() { + @Override + public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { + processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); + } + + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error loading inference models", e); + } + } + ); + } + + private void processRequestsByShards( + Map> requestsByShard, + ClusterState clusterState, + BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider + ) { Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -197,29 +232,82 @@ private void executeBulkRequestsByShard(Map> requ // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; - try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = 
entry.getKey(); final List requests = entry.getValue(); + BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) + Releasable ref = bulkItemRequestCompleteRefCount.acquire(); + final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed( + bulkShardRequest, + itemReq, + e ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); - } - executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); + bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { + @Override + public void onResponse(BulkShardRequest bulkShardRequest) { + // We need to remove items that have had an inference error, as the response will have been updated already + // and we don't need to process them further + BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) + ); + executeBulkShardRequest( + errorsFilteredShardRequest, + ActionListener.releaseAfter(ActionListener.noop(), ref), + bulkItemFailedListener + ); + } + + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException("Error performing inference", e); + } + }, bulkItemFailedListener); } } } - private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { + private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } + return bulkShardRequest; + } + + // When an item fails, store the failure in the responses array + private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, Exception e) { + final String indexName = itemRequest.index(); + + DocWriteRequest docWriteRequest = itemRequest.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); + + // make sure the request gets never processed again, removing the item from the shard request + shardRequest.items()[itemRequest.id()] = null; + } + + private void executeBulkShardRequest( + BulkShardRequest bulkShardRequest, + ActionListener listener, + BiConsumer bulkItemErrorListener + ) { + if (bulkShardRequest.items().length == 0) { + // No requests to execute due to previous errors, terminate early + listener.onResponse(bulkShardRequest); + return; + } + client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void 
onResponse(BulkShardResponse bulkShardResponse) { @@ -230,19 +318,16 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - releaseOnFinish.close(); + listener.onResponse(bulkShardRequest); } @Override public void onFailure(Exception e) { // create failures for all relevant requests for (BulkItemRequest request : bulkShardRequest.items()) { - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + bulkItemErrorListener.accept(request, e); } - releaseOnFinish.close(); + listener.onFailure(e); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index d74301f6efdbe..ddec60a6c3fa8 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -405,13 +405,13 @@ protected void createMissingIndicesAndIndexData( final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -612,10 +612,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { new BulkOperation( task, @@ -629,6 +629,8 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, + modelRegistry, + inferenceServiceRegistry, listener ).run(); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 5581086a94187..f90289c26e3a2 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -132,10 +132,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java 
b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 9a04e41957272..43eadbc873012 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -146,10 +146,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index d577a7dcc0313..db2e5ca02c0ae 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); + super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); + super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); } }; } From ecd9cf6d6db6b608d43fe55c644c046ad2481259 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 2 Feb 2024 10:15:38 +0100 Subject: [PATCH 083/106] Add InferencePlugin changes for providing ModelRegistry and InferenceServiceRegistry to server --- .../inference/InferenceServiceRegistry.java | 62 ++++++++---------- .../InferenceServiceRegistryImpl.java | 64 +++++++++++++++++++ .../inference/ModelRegistry.java | 63 +++++++++++++++++- .../elasticsearch/node/NodeConstruction.java | 18 ++++++ .../xpack/inference/InferencePlugin.java | 8 ++- 5 files changed, 177 insertions(+), 38 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index d5973807d9d78..ce6f1b21b734c 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -13,49 +13,41 @@ import java.io.Closeable; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class InferenceServiceRegistry implements Closeable { - 
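
The hunks that follow replace the concrete registry with an interface plus a no-op fallback that the server binds when the inference plugin is not installed. A rough usage sketch of that pattern, where pluginProvidedRegistry and client are hypothetical placeholders:

    // Fall back to the no-op registry when no plugin supplied an implementation.
    InferenceServiceRegistry registry = pluginProvidedRegistry.orElseGet(
        InferenceServiceRegistry.NoopInferenceServiceRegistry::new
    );
    registry.init(client);
    // The no-op variant returns an empty Optional for every service name,
    // so callers can hold a registry unconditionally.
    registry.getService("elser").ifPresent(service -> {
        // dispatch inference requests through the plugin-provided service
    });
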
- private final Map services; - private final List namedWriteables = new ArrayList<>(); - - public InferenceServiceRegistry( - List inferenceServicePlugins, - InferenceServiceExtension.InferenceServiceFactoryContext factoryContext - ) { - // TODO check names are unique - services = inferenceServicePlugins.stream() - .flatMap(r -> r.getInferenceServiceFactories().stream()) - .map(factory -> factory.create(factoryContext)) - .collect(Collectors.toMap(InferenceService::name, Function.identity())); - } - public void init(Client client) { - services.values().forEach(s -> s.init(client)); - } +public interface InferenceServiceRegistry extends Closeable { + void init(Client client); - public Map getServices() { - return services; - } + Map getServices(); - public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); - } + Optional getService(String serviceName); - public List getNamedWriteables() { - return namedWriteables; - } + List getNamedWriteables(); + + class NoopInferenceServiceRegistry implements InferenceServiceRegistry { + public NoopInferenceServiceRegistry() {} - @Override - public void close() throws IOException { - for (var service : services.values()) { - service.close(); + @Override + public void init(Client client) {} + + @Override + public Map getServices() { + return Map.of(); + } + + @Override + public Optional getService(String serviceName) { + return Optional.empty(); } + + @Override + public List getNamedWriteables() { + return List.of(); + } + + @Override + public void close() throws IOException {} } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java new file mode 100644 index 0000000000000..f0a990ded98ce --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.inference; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistryImpl implements InferenceServiceRegistry { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistryImpl( + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } + + @Override + public void init(Client client) { + services.values().forEach(s -> s.init(client)); + } + + @Override + public Map getServices() { + return services; + } + + @Override + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } + + @Override + public List getNamedWriteables() { + return namedWriteables; + } + + @Override + public void close() throws IOException { + for (var service : services.values()) { + service.close(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java index 4d0f3a8f2ea9b..fa90d5ba6f756 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java @@ -14,10 +14,35 @@ import java.util.Map; public interface ModelRegistry { - void getModel(String modelId, ActionListener listener); + /** + * Get a model. + * Secret settings are not included + * @param inferenceEntityId Model to get + * @param listener Model listener + */ + void getModel(String inferenceEntityId, ActionListener listener); + + /** + * Get a model with its secret settings + * @param inferenceEntityId Model to get + * @param listener Model listener + */ + void getModelWithSecrets(String inferenceEntityId, ActionListener listener); + + /** + * Get all models of a particular task type. + * Secret settings are not included + * @param taskType The task type + * @param listener Models listener + */ void getModelsByTaskType(TaskType taskType, ActionListener> listener); + /** + * Get all models. 
+ * Secret settings are not included + * @param listener Models listener + */ void getAllModels(ActionListener> listener); void storeModel(Model model, ActionListener listener); @@ -35,4 +60,40 @@ record UnparsedModel( Map settings, Map secrets ) {} + + class NoopModelRegistry implements ModelRegistry { + @Override + public void getModel(String modelId, ActionListener listener) { + fail(listener); + } + + @Override + public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { + listener.onResponse(List.of()); + } + + @Override + public void getAllModels(ActionListener> listener) { + listener.onResponse(List.of()); + } + + @Override + public void storeModel(Model model, ActionListener listener) { + fail(listener); + } + + @Override + public void deleteModel(String modelId, ActionListener listener) { + fail(listener); + } + + @Override + public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + fail(listener); + } + + private static void fail(ActionListener listener) { + listener.onFailure(new IllegalArgumentException("No model registry configured")); + } + } } diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 0795fef891f91..0b52a4eeeab2d 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -123,6 +123,8 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -1087,11 +1089,27 @@ record PluginServiceInstances( ); } + // Register noop versions of inference services if Inference plugin is not available + if (isPluginComponentDefined(pluginComponents, InferenceServiceRegistry.class) == false) { + logger.warn("Inference service is not available"); + modules.bindToInstance(InferenceServiceRegistry.class, new InferenceServiceRegistry.NoopInferenceServiceRegistry()); + } + if (isPluginComponentDefined(pluginComponents, ModelRegistry.class) == false) { + logger.warn("Model registry is not available"); + modules.bindToInstance(ModelRegistry.class, new ModelRegistry.NoopModelRegistry()); + } + injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); } + private static boolean isPluginComponentDefined(Collection pluginComponents, Class clazz) { + return pluginComponents.stream() + .map(p -> p instanceof PluginComponentBinding ? 
((PluginComponentBinding) p).impl() : p) + .anyMatch(p -> clazz.isAssignableFrom(clazz)); + } + private ClusterService createClusterService(SettingsModule settingsModule, ThreadPool threadPool, TaskManager taskManager) { ClusterService clusterService = new ClusterService( settingsModule.getSettings(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7ac19f8f6bad6..f9e5a59032ddb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -23,6 +23,7 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceRegistryImpl; import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.PluginComponentBinding; import org.elasticsearch.plugins.ActionPlugin; @@ -138,11 +139,14 @@ public Collection createComponents(PluginServices services) { inferenceServices.add(this::getInferenceServiceFactories); var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); - var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + var registry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext); registry.init(services.client()); inferenceServiceRegistry.set(registry); - return List.of(new PluginComponentBinding<>(ModelRegistry.class, modelRegistry), registry); + return List.of( + new PluginComponentBinding<>(ModelRegistry.class, modelRegistry), + new PluginComponentBinding<>(InferenceServiceRegistry.class, registry) + ); } @Override From 878611e87da7b453e021f5ca0be90ac89e43c1e6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 2 Feb 2024 16:48:56 +0100 Subject: [PATCH 084/106] Fix index version --- .../ml/mapper/SemanticTextInferenceResultFieldMapperTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java index f44be37906c6d..1a28ee4594f36 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -80,7 +80,7 @@ protected boolean isConfigurable() { @Override protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_13); // TODO: Switch to ES_VERSION_8_14 when available + return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available } @Override From e20bba5e0a47fa45d16ce3f17b383832f7d6bbfd Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 5 Feb 2024 10:00:39 +0100 Subject: [PATCH 085/106] Fix error when marking bulk items as null --- .../action/bulk/BulkOperation.java | 34 +++++++------------ .../BulkShardRequestInferenceProvider.java | 21 ++++++++---- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java 
b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 66f9a5a8bebff..b0e5129c8439b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -29,6 +29,7 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.routing.IndexRouting; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.TriConsumer; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; @@ -43,13 +44,10 @@ import org.elasticsearch.threadpool.ThreadPool; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -220,7 +218,7 @@ public void onFailure(Exception e) { ); } - private void processRequestsByShards( + void processRequestsByShards( Map> requestsByShard, ClusterState clusterState, BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider @@ -239,23 +237,15 @@ private void processRequestsByShards( BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed( - bulkShardRequest, + final TriConsumer bulkItemFailedListener = ( itemReq, - e - ); + itemIndex, + e) -> markBulkItemRequestFailed(bulkShardRequest, itemReq, itemIndex, e); bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardRequest bulkShardRequest) { - // We need to remove items that have had an inference error, as the response will have been updated already - // and we don't need to process them further - BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new) - ); executeBulkShardRequest( - errorsFilteredShardRequest, + bulkShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), bulkItemFailedListener ); @@ -286,7 +276,7 @@ private BulkShardRequest createBulkShardRequest(ClusterState clusterState, Shard } // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, Exception e) { + private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, int bulkItemIndex, Exception e) { final String indexName = itemRequest.index(); DocWriteRequest docWriteRequest = itemRequest.request(); @@ -294,13 +284,13 @@ private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRe responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); // make sure the request gets never processed again, removing the item from the shard request - shardRequest.items()[itemRequest.id()] = null; + shardRequest.items()[bulkItemIndex] = null; } private void executeBulkShardRequest( BulkShardRequest bulkShardRequest, ActionListener 
listener, - BiConsumer bulkItemErrorListener + TriConsumer bulkItemErrorListener ) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early @@ -324,8 +314,10 @@ public void onResponse(BulkShardResponse bulkShardResponse) { @Override public void onFailure(Exception e) { // create failures for all relevant requests - for (BulkItemRequest request : bulkShardRequest.items()) { - bulkItemErrorListener.accept(request, e); + BulkItemRequest[] items = bulkShardRequest.items(); + for (int i = 0; i < items.length; i++) { + BulkItemRequest request = items[i]; + bulkItemErrorListener.apply(request, i, e); } listener.onFailure(e); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 7ccf3d6baaac2..8a2847ddcb842 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.InferenceResults; @@ -32,7 +33,6 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; import java.util.stream.Collectors; /** @@ -109,7 +109,7 @@ public void onFailure(Exception e) { public void processBulkShardRequest( BulkShardRequest bulkShardRequest, ActionListener listener, - BiConsumer onBulkItemFailure + TriConsumer onBulkItemFailure ) { Map> fieldsForModels = clusterState.metadata() @@ -123,16 +123,22 @@ public void processBulkShardRequest( Runnable onInferenceComplete = () -> { listener.onResponse(bulkShardRequest); }; try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - for (BulkItemRequest bulkItemRequest : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest(bulkItemRequest, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire()); + BulkItemRequest[] items = bulkShardRequest.items(); + for (int i = 0; i < items.length; i++) { + BulkItemRequest bulkItemRequest = items[i]; + // Bulk item might be null because of previous errors, skip in that case + if (bulkItemRequest != null) { + performInferenceOnBulkItemRequest(bulkItemRequest, i, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire()); + } } } } private void performInferenceOnBulkItemRequest( BulkItemRequest bulkItemRequest, + int bulkItemIndex, Map> fieldsForModels, - BiConsumer onBulkItemFailure, + TriConsumer onBulkItemFailure, Releasable releaseOnFinish ) { @@ -180,8 +186,9 @@ private void performInferenceOnBulkItemRequest( InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); if (inferenceProvider == null) { - onBulkItemFailure.accept( + onBulkItemFailure.apply( bulkItemRequest, + bulkItemIndex, new IllegalArgumentException("No inference provider found for model ID " + modelId) ); continue; @@ -216,7 +223,7 @@ public void onResponse(InferenceServiceResults results) { @Override public void onFailure(Exception e) { - onBulkItemFailure.accept(bulkItemRequest, e); + onBulkItemFailure.apply(bulkItemRequest, bulkItemIndex, e); } }; 
inferenceProvider.service() From df5f799195bb945a0f1435c996d857fd2809e471 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 5 Feb 2024 17:57:59 +0100 Subject: [PATCH 086/106] First test version --- .../action/bulk/BulkOperationTests.java | 257 ++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java new file mode 100644 index 0000000000000..5a6280dbeda40 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexAbstraction; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.service.ClusterApplierService; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Collections.emptyMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class BulkOperationTests extends ESTestCase { + + public 
static final String INFERENCE_SERVICE_ID = "inferenece_service_id";
+    public static final String INDEX_NAME = "test-index";
+    public static final String INFERENCE_FIELD = "inference_field";
+    public static final String SERVICE_ID = "elser_v2";
+    private static TestThreadPool threadPool;
+
+    @SuppressWarnings("unchecked")
+    public void testInference() {
+
+        Map<String, Set<String>> fieldsForModels = Map.of(INFERENCE_SERVICE_ID, Set.of(INFERENCE_FIELD));
+
+        ModelRegistry modelRegistry = createModelRegistry();
+
+        Model model = mock(Model.class);
+        InferenceService inferenceService = createInferenceService(model);
+        InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(inferenceService);
+
+        String inferenceText = "test";
+        Map<String, Object> originalSource = Map.of(INFERENCE_FIELD, inferenceText);
+
+        BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels , modelRegistry, inferenceServiceRegistry);
+
+        verifyInferenceDone(modelRegistry, inferenceService, model, inferenceText);
+
+        BulkItemRequest[] items = bulkShardRequest.items();
+        assertThat(items.length, equalTo(1));
+        Map<String, Object> docSource = ((IndexRequest) items[0].request()).sourceAsMap();
+        Map<String, Object> inferenceRootResultField = (Map<String, Object>) docSource.get(BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD);
+        List<Map<String, Object>> inferenceFieldResults = (List<Map<String, Object>>) inferenceRootResultField.get(INFERENCE_FIELD);
+        assertNotNull(inferenceFieldResults);
+        assertThat(inferenceFieldResults.size(), equalTo(1));
+        Map<String, Object> inferenceResultElement = inferenceFieldResults.get(0);
+        assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME));
+        assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(inferenceText));
+    }
+
+    private static void verifyInferenceDone(ModelRegistry modelRegistry, InferenceService inferenceService, Model model, String inferenceText) {
+        verify(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any());
+        verify(inferenceService).parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap());
+        verify(inferenceService).infer(eq(model), eq(List.of(inferenceText)), anyMap(), eq(InputType.INGEST), any());
+    }
+
+    private static BulkShardRequest runBulkOperation(Map<String, Object> docSource, Map<String, Set<String>> fieldsForModels, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry) {
+
+
+        Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build();
+        IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME)
+            .fieldsForModels(fieldsForModels)
+            .settings(settings)
+            .numberOfShards(1)
+            .numberOfReplicas(0)
+            .build();
+        ClusterService clusterService = createClusterService(indexMetadata);
+
+        IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class);
+        when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata));
+
+        BulkRequest bulkRequest = new BulkRequest();
+        bulkRequest.add(new IndexRequest(INDEX_NAME).source(docSource));
+
+        NodeClient client = mock(NodeClient.class);
+
+        ArgumentCaptor<BulkShardRequest> bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class);
+        doAnswer(invocation -> {
+            BulkShardRequest request = invocation.getArgument(1);
+            ActionListener<BulkShardResponse> bulkShardResponseListener = invocation.getArgument(2);
+
+            BulkShardResponse bulkShardResponse = new BulkShardResponse(
+                request.shardId(),
+                Arrays.stream(request.items())
+                    .map(
+                        item -> BulkItemResponse.success(
item.id(), + DocWriteRequest.OpType.INDEX, + new IndexResponse( + request.shardId(), + randomIdentifier(), + randomLong(), + randomLong(), + randomLong(), + randomBoolean() + ) + ) + ) + .toArray(BulkItemResponse[]::new) + ); + bulkShardResponseListener.onResponse(bulkShardResponse); + return null; + } + ).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); + + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); + BulkOperation bulkOperation = new BulkOperation( + task, + threadPool, + ThreadPool.Names.WRITE, + clusterService, + bulkRequest, + client, + new AtomicArray<>(bulkRequest.requests.size()), + new HashMap<>(), + indexResolver, + () -> System.nanoTime(), + System.nanoTime(), + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + + bulkOperation.doRun(); + verify(bulkOperationListener).onResponse(any()); + verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); + + return bulkShardRequestCaptor.getValue(); + } + + private static InferenceService createInferenceService(Model model) { + InferenceService inferenceService = mock(InferenceService.class); + when(inferenceService.parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); + InferenceResults inferenceResults = mock(InferenceResults.class); + when(inferenceResults.asMap(any())).then( + invocation -> Map.of( + (String) invocation.getArguments()[0], + Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat()))) + ) + ); + doReturn(List.of(inferenceResults)).when(inferenceServiceResults).transformToLegacyFormat(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(inferenceServiceResults); + return null; + }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); + return inferenceService; + } + + private static InferenceServiceRegistry createInferenceServiceRegistry(InferenceService inferenceService) { + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(SERVICE_ID)).thenReturn(Optional.of(inferenceService)); + return inferenceServiceRegistry; + } + + private static ModelRegistry createModelRegistry() { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( + INFERENCE_SERVICE_ID, + TaskType.SPARSE_EMBEDDING, + SERVICE_ID, + emptyMap(), + emptyMap() + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any()); + return modelRegistry; + } + + private static ClusterService createClusterService(IndexMetadata indexMetadata) { + Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); + + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); + + ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); + 
when(clusterService.state()).thenReturn(clusterState); + + ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); + when(clusterApplierService.state()).thenReturn(clusterState); + when(clusterApplierService.threadPool()).thenReturn(threadPool); + when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); + return clusterService; + } + + + @BeforeClass + public static void createThreadPool() { + threadPool = new TestThreadPool(getTestClass().getName()); + } + + @AfterClass + public static void stopThreadPool() { + if (threadPool != null) { + threadPool.shutdownNow(); + threadPool = null; + } + } + +} From 04651188baf7c44ce51809bf4ca9581ba2da0518 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 5 Feb 2024 19:05:16 +0100 Subject: [PATCH 087/106] Add multiple fields to test --- .../action/bulk/BulkOperationTests.java | 193 +++++++++++++----- 1 file changed, 141 insertions(+), 52 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 5a6280dbeda40..f8db0758db9db 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -41,19 +41,24 @@ import org.junit.AfterClass; import org.junit.BeforeClass; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static java.util.Collections.emptyMap; import static org.hamcrest.CoreMatchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -63,51 +68,125 @@ public class BulkOperationTests extends ESTestCase { - public static final String INFERENCE_SERVICE_ID = "inferenece_service_id"; public static final String INDEX_NAME = "test-index"; - public static final String INFERENCE_FIELD = "inference_field"; - public static final String SERVICE_ID = "elser_v2"; + public static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; + public static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; + public static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; + public static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; + public static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; + public static final String SERVICE_1_ID = "elser_v2"; + public static final String SERVICE_2_ID = "e5"; private static TestThreadPool threadPool; @SuppressWarnings("unchecked") public void testInference() { - Map> fieldsForModels = Map.of(INFERENCE_SERVICE_ID, Set.of(INFERENCE_FIELD)); - - ModelRegistry modelRegistry = createModelRegistry(); + Map> fieldsForModels = Map.of( + INFERENCE_SERVICE_1_ID, + Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), + INFERENCE_SERVICE_2_ID, + Set.of(INFERENCE_FIELD_SERVICE_2) + ); - Model model = mock(Model.class); - InferenceService inferenceService = 
createInferenceService(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(inferenceService); + ModelRegistry modelRegistry = createModelRegistry(Map.of( + INFERENCE_SERVICE_1_ID, SERVICE_1_ID, + INFERENCE_SERVICE_2_ID, SERVICE_2_ID) + ); - String inferenceText = "test"; - Map originalSource = Map.of(INFERENCE_FIELD, inferenceText); + Model model1 = mock(Model.class); + InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID, 2); + Model model2 = mock(Model.class); + InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID, 1); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( + Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) + ); - BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels , modelRegistry, inferenceServiceRegistry); + String firstInferenceTextService1 = "firstInferenceTextService1"; + String secondInferenceTextService1 = "secondInferenceTextService1"; + String inferenceTextService2 = "inferenceTextService2"; + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + firstInferenceTextService1, + SECOND_INFERENCE_FIELD_SERVICE_1, + secondInferenceTextService1, + INFERENCE_FIELD_SERVICE_2, + inferenceTextService2, + "other_field", + "other_value", + "yet_another_field", + "yet_another_value" + ); - verifyInferenceDone(modelRegistry, inferenceService, model, inferenceText); + BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(1)); - Map docSource = ((IndexRequest) items[0].request()).sourceAsMap(); - Map inferenceRootResultField = (Map) docSource.get(BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD); - List> inferenceFieldResults = (List>) inferenceRootResultField.get(INFERENCE_FIELD); - assertNotNull(inferenceFieldResults); - assertThat(inferenceFieldResults.size(), equalTo(1)); - Map inferenceResultElement = inferenceFieldResults.get(0); + + Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); + // Original doc source is preserved + assertTrue(writtenDocSource.keySet().containsAll(originalSource.keySet())); + assertTrue(writtenDocSource.values().containsAll(originalSource.values())); + + // Check inference results + verifyInferenceServiceInvoked( + modelRegistry, + INFERENCE_SERVICE_1_ID, + inferenceService1, + model1, + List.of(firstInferenceTextService1, secondInferenceTextService1) + ); + verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); + Map inferenceRootResultField = (Map) writtenDocSource.get( + BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD + ); + + checkInferenceResult(inferenceRootResultField, FIRST_INFERENCE_FIELD_SERVICE_1, firstInferenceTextService1); + checkInferenceResult(inferenceRootResultField, SECOND_INFERENCE_FIELD_SERVICE_1, secondInferenceTextService1); + checkInferenceResult(inferenceRootResultField, INFERENCE_FIELD_SERVICE_2, inferenceTextService2); + } + + private static void checkInferenceResult(Map inferenceRootResultField, String fieldName, String expectedText) { + @SuppressWarnings("unchecked") + List> inferenceService1FieldResults = (List>) inferenceRootResultField.get(fieldName); + assertNotNull(inferenceService1FieldResults); + 
assertThat(inferenceService1FieldResults.size(), equalTo(1)); + Map inferenceResultElement = inferenceService1FieldResults.get(0); assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); - assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(inferenceText)); + assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(expectedText)); } - private static void verifyInferenceDone(ModelRegistry modelRegistry, InferenceService inferenceService, Model model, String inferenceText) { - verify(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any()); - verify(inferenceService).parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap()); - verify(inferenceService).infer(eq(model), eq(List.of(inferenceText)), anyMap(), eq(InputType.INGEST), any()); + private static void verifyInferenceServiceInvoked( + ModelRegistry modelRegistry, + String inferenceService1Id, + InferenceService inferenceService, + Model model, + List inferenceTexts + ) { + verify(modelRegistry).getModel(eq(inferenceService1Id), any()); + verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); + verify(inferenceService).infer(eq(model), argThat(containsAll(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); } - private static BulkShardRequest runBulkOperation(Map docSource, Map> fieldsForModels, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry) { + private static ArgumentMatcher> containsAll(List expected) { + return new ArgumentMatcher<>() { + @Override + public boolean matches(List argument) { + return argument.containsAll(expected) && argument.size() == expected.size(); + } + @Override + public String toString() { + return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; + } + }; + } + private static BulkShardRequest runBulkOperation( + Map docSource, + Map> fieldsForModels, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry + ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) .fieldsForModels(fieldsForModels) @@ -133,6 +212,7 @@ private static BulkShardRequest runBulkOperation(Map docSource, BulkShardResponse bulkShardResponse = new BulkShardResponse( request.shardId(), Arrays.stream(request.items()) + .filter(Objects::nonNull) .map( item -> BulkItemResponse.success( item.id(), @@ -151,8 +231,7 @@ private static BulkShardRequest runBulkOperation(Map docSource, ); bulkShardResponseListener.onResponse(bulkShardResponse); return null; - } - ).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); + }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); @@ -181,10 +260,24 @@ private static BulkShardRequest runBulkOperation(Map docSource, return bulkShardRequestCaptor.getValue(); } - private static InferenceService createInferenceService(Model model) { + private static InferenceService createInferenceService(Model model, String inferenceServiceId, int numResults) { InferenceService inferenceService = mock(InferenceService.class); - 
when(inferenceService.parsePersistedConfig(eq(INFERENCE_SERVICE_ID), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); + List inferenceResults = new ArrayList<>(); + for (int i = 0; i < numResults; i++) { + inferenceResults.add(createInferenceResults()); + } + doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(inferenceServiceResults); + return null; + }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); + return inferenceService; + } + + private static InferenceResults createInferenceResults() { InferenceResults inferenceResults = mock(InferenceResults.class); when(inferenceResults.asMap(any())).then( invocation -> Map.of( @@ -192,35 +285,32 @@ private static InferenceService createInferenceService(Model model) { Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat()))) ) ); - doReturn(List.of(inferenceResults)).when(inferenceServiceResults).transformToLegacyFormat(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - listener.onResponse(inferenceServiceResults); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; + return inferenceResults; } - private static InferenceServiceRegistry createInferenceServiceRegistry(InferenceService inferenceService) { + private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - when(inferenceServiceRegistry.getService(SERVICE_ID)).thenReturn(Optional.of(inferenceService)); + inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); return inferenceServiceRegistry; } - private static ModelRegistry createModelRegistry() { + private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { ModelRegistry modelRegistry = mock(ModelRegistry.class); - ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( - INFERENCE_SERVICE_ID, - TaskType.SPARSE_EMBEDDING, - SERVICE_ID, - emptyMap(), - emptyMap() - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModel(eq(INFERENCE_SERVICE_ID), any()); + inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { + ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceId, + emptyMap(), + emptyMap() + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModel(eq(inferenceId), any()); + }); + return modelRegistry; } @@ -240,7 +330,6 @@ private static ClusterService createClusterService(IndexMetadata indexMetadata) return clusterService; } - @BeforeClass public static void createThreadPool() { threadPool = new TestThreadPool(getTestClass().getName()); From d97a0436d7598c16dc2c951b0491d2b944724f39 Mon Sep 17 00:00:00 2001 From: carlosdelest 
Date: Mon, 5 Feb 2024 19:39:03 +0100 Subject: [PATCH 088/106] Add failing inference test --- .../action/bulk/BulkOperationTests.java | 73 +++++++++++++++++-- 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index f8db0758db9db..b6862fdb18ea5 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -64,6 +64,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class BulkOperationTests extends ESTestCase { @@ -88,9 +89,8 @@ public void testInference() { Set.of(INFERENCE_FIELD_SERVICE_2) ); - ModelRegistry modelRegistry = createModelRegistry(Map.of( - INFERENCE_SERVICE_1_ID, SERVICE_1_ID, - INFERENCE_SERVICE_2_ID, SERVICE_2_ID) + ModelRegistry modelRegistry = createModelRegistry( + Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) ); Model model1 = mock(Model.class); @@ -117,7 +117,15 @@ public void testInference() { "yet_another_value" ); - BulkShardRequest bulkShardRequest = runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry); + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + verify(bulkOperationListener).onResponse(any()); BulkItemRequest[] items = bulkShardRequest.items(); assertThat(items.length, equalTo(1)); @@ -145,6 +153,45 @@ public void testInference() { checkInferenceResult(inferenceRootResultField, INFERENCE_FIELD_SERVICE_2, inferenceTextService2); } + public void testFailedInference() { + + Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1)); + + ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + String firstInferenceTextService1 = "firstInferenceTextService1"; + Map originalSource = Map.of( + FIRST_INFERENCE_FIELD_SERVICE_1, + firstInferenceTextService1, + "other_field", + "other_value" + ); + + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(1)); + assertNull(items[0]); + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + + verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); + + } + private static void checkInferenceResult(Map 
inferenceRootResultField, String fieldName, String expectedText) { @SuppressWarnings("unchecked") List> inferenceService1FieldResults = (List>) inferenceRootResultField.get(fieldName); @@ -165,6 +212,7 @@ private static void verifyInferenceServiceInvoked( verify(modelRegistry).getModel(eq(inferenceService1Id), any()); verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); verify(inferenceService).infer(eq(model), argThat(containsAll(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); + verifyNoMoreInteractions(inferenceService); } private static ArgumentMatcher> containsAll(List expected) { @@ -185,7 +233,8 @@ private static BulkShardRequest runBulkOperation( Map docSource, Map> fieldsForModels, ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry + InferenceServiceRegistry inferenceServiceRegistry, + ActionListener bulkOperationListener ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) @@ -233,8 +282,6 @@ private static BulkShardRequest runBulkOperation( return null; }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); BulkOperation bulkOperation = new BulkOperation( task, @@ -254,7 +301,6 @@ private static BulkShardRequest runBulkOperation( ); bulkOperation.doRun(); - verify(bulkOperationListener).onResponse(any()); verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); return bulkShardRequestCaptor.getValue(); @@ -277,6 +323,17 @@ private static InferenceService createInferenceService(Model model, String infer return inferenceService; } + private static InferenceService createInferenceServiceThatFails(Model model, String inferenceServiceId) { + InferenceService inferenceService = mock(InferenceService.class); + when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onFailure(new IllegalArgumentException("Inference failed")); + return null; + }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); + return inferenceService; + } + private static InferenceResults createInferenceResults() { InferenceResults inferenceResults = mock(InferenceResults.class); when(inferenceResults.asMap(any())).then( From 9c8cd3716755b7f7c8c4937a9864d55adc698703 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 5 Feb 2024 19:55:49 +0100 Subject: [PATCH 089/106] Add test for inference id not found --- .../action/bulk/BulkOperationTests.java | 70 ++++++++++++++++--- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index b6862fdb18ea5..f4222d4bb8253 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -69,14 +69,15 @@ public class BulkOperationTests extends ESTestCase { - public static final String INDEX_NAME = "test-index"; - public static final 
String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; - public static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; - public static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; - public static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; - public static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; - public static final String SERVICE_1_ID = "elser_v2"; - public static final String SERVICE_2_ID = "e5"; + private static final String INDEX_NAME = "test-index"; + private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; + private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; + private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; + private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; + private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; + private static final String SERVICE_1_ID = "elser_v2"; + private static final String SERVICE_2_ID = "e5"; + private static final String INFERENCE_FAILED_MSG = "Inference failed"; private static TestThreadPool threadPool; @SuppressWarnings("unchecked") @@ -187,11 +188,56 @@ public void testFailedInference() { verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); } + public void testInferenceIdNotFound() { + + Map> fieldsForModels = Map.of( + INFERENCE_SERVICE_1_ID, + Set.of(FIRST_INFERENCE_FIELD_SERVICE_1, SECOND_INFERENCE_FIELD_SERVICE_1), + INFERENCE_SERVICE_2_ID, + Set.of(INFERENCE_FIELD_SERVICE_2) + ); + + ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); + + Model model = mock(Model.class); + InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID, 1); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); + + String firstInferenceTextService1 = "firstInferenceTextService1"; + Map originalSource = Map.of(INFERENCE_FIELD_SERVICE_2, "text_for_service_2", "other_field", "other_value"); + + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(1)); + assertNull(items[0]); + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + assertTrue(bulkResponse.hasFailures()); + BulkItemResponse item = bulkResponse.getItems()[0]; + assertTrue(item.isFailed()); + assertThat( + item.getFailure().getCause().getMessage(), + equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) + ); + } + private static void checkInferenceResult(Map 
inferenceRootResultField, String fieldName, String expectedText) { @SuppressWarnings("unchecked") List> inferenceService1FieldResults = (List>) inferenceRootResultField.get(fieldName); @@ -328,7 +374,7 @@ private static InferenceService createInferenceServiceThatFails(Model model, Str when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); - listener.onFailure(new IllegalArgumentException("Inference failed")); + listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); return null; }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); return inferenceService; @@ -353,6 +399,12 @@ private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceIdsToServiceIds) { ModelRegistry modelRegistry = mock(ModelRegistry.class); + // Fails for unknown inference ids + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IllegalArgumentException("Model not found")); + return null; + }).when(modelRegistry).getModel(any(), any()); inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( inferenceId, From 37989445d19751a94bb306803df57602b817fdcb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 6 Feb 2024 09:36:32 +0100 Subject: [PATCH 090/106] Tests improvements --- .../action/bulk/BulkOperationTests.java | 76 ++++++++++++------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index f4222d4bb8253..21a81076d9893 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -45,6 +45,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -95,9 +96,9 @@ public void testInference() { ); Model model1 = mock(Model.class); - InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID, 2); + InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); Model model2 = mock(Model.class); - InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID, 1); + InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); @@ -133,8 +134,7 @@ public void testInference() { Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); // Original doc source is preserved - assertTrue(writtenDocSource.keySet().containsAll(originalSource.keySet())); - assertTrue(writtenDocSource.values().containsAll(originalSource.values())); + originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); // Check inference results verifyInferenceServiceInvoked( @@ -145,13 +145,13 @@ public void testInference() { List.of(firstInferenceTextService1, secondInferenceTextService1) ); verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, 
List.of(inferenceTextService2)); - Map inferenceRootResultField = (Map) writtenDocSource.get( - BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD + checkInferenceResult( + originalSource, + writtenDocSource, + FIRST_INFERENCE_FIELD_SERVICE_1, + SECOND_INFERENCE_FIELD_SERVICE_1, + INFERENCE_FIELD_SERVICE_2 ); - - checkInferenceResult(inferenceRootResultField, FIRST_INFERENCE_FIELD_SERVICE_1, firstInferenceTextService1); - checkInferenceResult(inferenceRootResultField, SECOND_INFERENCE_FIELD_SERVICE_1, secondInferenceTextService1); - checkInferenceResult(inferenceRootResultField, INFERENCE_FIELD_SERVICE_2, inferenceTextService2); } public void testFailedInference() { @@ -208,7 +208,7 @@ public void testInferenceIdNotFound() { ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); Model model = mock(Model.class); - InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID, 1); + InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); String firstInferenceTextService1 = "firstInferenceTextService1"; @@ -238,14 +238,30 @@ public void testInferenceIdNotFound() { ); } - private static void checkInferenceResult(Map inferenceRootResultField, String fieldName, String expectedText) { - @SuppressWarnings("unchecked") - List> inferenceService1FieldResults = (List>) inferenceRootResultField.get(fieldName); - assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(1)); - Map inferenceResultElement = inferenceService1FieldResults.get(0); - assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); - assertThat(inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), equalTo(expectedText)); + @SuppressWarnings("unchecked") + private static void checkInferenceResult( + Map docSource, + Map writtenDocSource, + String... 
inferenceFieldNames + ) { + + Map inferenceRootResultField = (Map) writtenDocSource.get( + BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD + ); + + for (String inferenceFieldName : inferenceFieldNames) { + List> inferenceService1FieldResults = (List>) inferenceRootResultField.get( + inferenceFieldName + ); + assertNotNull(inferenceService1FieldResults); + assertThat(inferenceService1FieldResults.size(), equalTo(1)); + Map inferenceResultElement = inferenceService1FieldResults.get(0); + assertNotNull(inferenceResultElement.get(BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME)); + assertThat( + inferenceResultElement.get(BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME), + equalTo(docSource.get(inferenceFieldName)) + ); + } } private static void verifyInferenceServiceInvoked( @@ -253,15 +269,15 @@ private static void verifyInferenceServiceInvoked( String inferenceService1Id, InferenceService inferenceService, Model model, - List inferenceTexts + Collection inferenceTexts ) { verify(modelRegistry).getModel(eq(inferenceService1Id), any()); verify(inferenceService).parsePersistedConfig(eq(inferenceService1Id), eq(TaskType.SPARSE_EMBEDDING), anyMap()); - verify(inferenceService).infer(eq(model), argThat(containsAll(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); + verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); verifyNoMoreInteractions(inferenceService); } - private static ArgumentMatcher> containsAll(List expected) { + private static ArgumentMatcher> containsInAnyOrder(Collection expected) { return new ArgumentMatcher<>() { @Override public boolean matches(List argument) { @@ -352,17 +368,19 @@ private static BulkShardRequest runBulkOperation( return bulkShardRequestCaptor.getValue(); } - private static InferenceService createInferenceService(Model model, String inferenceServiceId, int numResults) { + private static InferenceService createInferenceService(Model model, String inferenceServiceId) { InferenceService inferenceService = mock(InferenceService.class); when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); - InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); - List inferenceResults = new ArrayList<>(); - for (int i = 0; i < numResults; i++) { - inferenceResults.add(createInferenceResults()); - } - doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(4); + InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); + List texts = invocation.getArgument(1); + List inferenceResults = new ArrayList<>(); + for (int i = 0; i < texts.size(); i++) { + inferenceResults.add(createInferenceResults()); + } + doReturn(inferenceResults).when(inferenceServiceResults).transformToLegacyFormat(); + listener.onResponse(inferenceServiceResults); return null; }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); From 773846015446bfa1f8aa6079d586814bfd950434 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 6 Feb 2024 10:19:09 +0100 Subject: [PATCH 091/106] Add bulk shard failure test --- .../action/bulk/BulkOperationTests.java | 181 ++++++++++++++---- 1 file changed, 145 insertions(+), 36 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java 
b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 21a81076d9893..c1500abc28abe 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -52,6 +52,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import static java.util.Collections.emptyMap; @@ -81,6 +82,97 @@ public class BulkOperationTests extends ESTestCase { private static final String INFERENCE_FAILED_MSG = "Inference failed"; private static TestThreadPool threadPool; + public void testNoInference() { + + Map> fieldsForModels = Map.of(); + ModelRegistry modelRegistry = createModelRegistry( + Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) + ); + + Model model1 = mock(Model.class); + InferenceService inferenceService1 = createInferenceService(model1, INFERENCE_SERVICE_1_ID); + Model model2 = mock(Model.class); + InferenceService inferenceService2 = createInferenceService(model2, INFERENCE_SERVICE_2_ID); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( + Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) + ); + + Map originalSource = Map.of( + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + BulkShardRequest bulkShardRequest = runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener + ); + verify(bulkOperationListener).onResponse(any()); + + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(1)); + + Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); + // Original doc source is preserved + originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); + + // Check inference not invoked + verifyNoMoreInteractions(modelRegistry); + verifyNoMoreInteractions(inferenceServiceRegistry); + } + + public void testFailedBulkShardRequest() { + + Map> fieldsForModels = Map.of(); + ModelRegistry modelRegistry = createModelRegistry(Map.of()); + InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); + + Map originalSource = Map.of( + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); + + @SuppressWarnings("unchecked") + ActionListener bulkOperationListener = mock(ActionListener.class); + ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); + doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + + runBulkOperation( + originalSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener, + request -> new BulkShardResponse( + request.shardId(), + new BulkItemResponse[] { + BulkItemResponse.failure( + 0, + DocWriteRequest.OpType.INDEX, + new BulkItemResponse.Failure( + INDEX_NAME, + randomIdentifier(), + new IllegalArgumentException("Error on bulk shard request") + ) + ) } + ) + ); + verify(bulkOperationListener).onResponse(any()); + + BulkResponse bulkResponse = bulkResponseCaptor.getValue(); + 
assertTrue(bulkResponse.hasFailures()); + BulkItemResponse[] items = bulkResponse.getItems(); + assertTrue(items[0].isFailed()); + } + @SuppressWarnings("unchecked") public void testInference() { @@ -103,9 +195,9 @@ public void testInference() { Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) ); - String firstInferenceTextService1 = "firstInferenceTextService1"; - String secondInferenceTextService1 = "secondInferenceTextService1"; - String inferenceTextService2 = "inferenceTextService2"; + String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); + String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); + String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); Map originalSource = Map.of( FIRST_INFERENCE_FIELD_SERVICE_1, firstInferenceTextService1, @@ -113,10 +205,10 @@ public void testInference() { secondInferenceTextService1, INFERENCE_FIELD_SERVICE_2, inferenceTextService2, - "other_field", - "other_value", - "yet_another_field", - "yet_another_value" + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) ); ActionListener bulkOperationListener = mock(ActionListener.class); @@ -145,7 +237,7 @@ public void testInference() { List.of(firstInferenceTextService1, secondInferenceTextService1) ); verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); - checkInferenceResult( + checkInferenceResults( originalSource, writtenDocSource, FIRST_INFERENCE_FIELD_SERVICE_1, @@ -164,12 +256,12 @@ public void testFailedInference() { InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - String firstInferenceTextService1 = "firstInferenceTextService1"; + String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); Map originalSource = Map.of( FIRST_INFERENCE_FIELD_SERVICE_1, firstInferenceTextService1, - "other_field", - "other_value" + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) ); ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @@ -211,8 +303,12 @@ public void testInferenceIdNotFound() { InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID); InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - String firstInferenceTextService1 = "firstInferenceTextService1"; - Map originalSource = Map.of(INFERENCE_FIELD_SERVICE_2, "text_for_service_2", "other_field", "other_value"); + Map originalSource = Map.of( + INFERENCE_FIELD_SERVICE_2, + randomAlphaOfLengthBetween(1, 100), + randomAlphaOfLengthBetween(1, 20), + randomAlphaOfLengthBetween(1, 100) + ); ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") @@ -239,7 +335,7 @@ public void testInferenceIdNotFound() { } @SuppressWarnings("unchecked") - private static void checkInferenceResult( + private static void checkInferenceResults( Map docSource, Map writtenDocSource, String... 
inferenceFieldNames @@ -297,6 +393,24 @@ private static BulkShardRequest runBulkOperation( ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, ActionListener bulkOperationListener + ) { + return runBulkOperation( + docSource, + fieldsForModels, + modelRegistry, + inferenceServiceRegistry, + bulkOperationListener, + successfulBulkShardResponse + ); + } + + private static BulkShardRequest runBulkOperation( + Map docSource, + Map> fieldsForModels, + ModelRegistry modelRegistry, + InferenceServiceRegistry inferenceServiceRegistry, + ActionListener bulkOperationListener, + Function bulkShardResponseSupplier ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) @@ -319,28 +433,7 @@ private static BulkShardRequest runBulkOperation( doAnswer(invocation -> { BulkShardRequest request = invocation.getArgument(1); ActionListener bulkShardResponseListener = invocation.getArgument(2); - - BulkShardResponse bulkShardResponse = new BulkShardResponse( - request.shardId(), - Arrays.stream(request.items()) - .filter(Objects::nonNull) - .map( - item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse( - request.shardId(), - randomIdentifier(), - randomLong(), - randomLong(), - randomLong(), - randomBoolean() - ) - ) - ) - .toArray(BulkItemResponse[]::new) - ); - bulkShardResponseListener.onResponse(bulkShardResponse); + bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); return null; }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); @@ -368,6 +461,22 @@ private static BulkShardRequest runBulkOperation( return bulkShardRequestCaptor.getValue(); } + private static final Function successfulBulkShardResponse = (request) -> { + return new BulkShardResponse( + request.shardId(), + Arrays.stream(request.items()) + .filter(Objects::nonNull) + .map( + item -> BulkItemResponse.success( + item.id(), + DocWriteRequest.OpType.INDEX, + new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) + ) + ) + .toArray(BulkItemResponse[]::new) + ); + }; + private static InferenceService createInferenceService(Model model, String inferenceServiceId) { InferenceService inferenceService = mock(InferenceService.class); when(inferenceService.parsePersistedConfig(eq(inferenceServiceId), eq(TaskType.SPARSE_EMBEDDING), anyMap())).thenReturn(model); From c4154b992c2e73fb3b9c556ba6ebd84fa78a281c Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 6 Feb 2024 13:04:21 +0100 Subject: [PATCH 092/106] Avoid removing bulk items from request on failure, fix tests --- .../action/bulk/BulkOperation.java | 23 +++----- .../BulkShardRequestInferenceProvider.java | 58 ++++++++++++++++--- .../action/bulk/BulkOperationTests.java | 39 ++++++------- 3 files changed, 74 insertions(+), 46 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index b0e5129c8439b..2b84ec8746cd2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -29,7 +29,6 @@ import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.routing.IndexRouting; import 
org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.TriConsumer; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; @@ -48,6 +47,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -237,15 +237,12 @@ void processRequestsByShards( BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final TriConsumer bulkItemFailedListener = ( - itemReq, - itemIndex, - e) -> markBulkItemRequestFailed(bulkShardRequest, itemReq, itemIndex, e); + final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { @Override - public void onResponse(BulkShardRequest bulkShardRequest) { + public void onResponse(BulkShardRequest inferenceBulkShardRequest) { executeBulkShardRequest( - bulkShardRequest, + inferenceBulkShardRequest, ActionListener.releaseAfter(ActionListener.noop(), ref), bulkItemFailedListener ); @@ -276,21 +273,18 @@ private BulkShardRequest createBulkShardRequest(ClusterState clusterState, Shard } // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkShardRequest shardRequest, BulkItemRequest itemRequest, int bulkItemIndex, Exception e) { + private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { final String indexName = itemRequest.index(); DocWriteRequest docWriteRequest = itemRequest.request(); BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - - // make sure the request gets never processed again, removing the item from the shard request - shardRequest.items()[bulkItemIndex] = null; } private void executeBulkShardRequest( BulkShardRequest bulkShardRequest, ActionListener listener, - TriConsumer bulkItemErrorListener + BiConsumer bulkItemErrorListener ) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early @@ -315,9 +309,8 @@ public void onResponse(BulkShardResponse bulkShardResponse) { public void onFailure(Exception e) { // create failures for all relevant requests BulkItemRequest[] items = bulkShardRequest.items(); - for (int i = 0; i < items.length; i++) { - BulkItemRequest request = items[i]; - bulkItemErrorListener.apply(request, i, e); + for (BulkItemRequest item : items) { + bulkItemErrorListener.accept(item, e); } listener.onFailure(e); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 8a2847ddcb842..7acc93be13a46 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import 
org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.common.TriConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.InferenceResults; @@ -26,6 +25,7 @@ import org.elasticsearch.inference.ModelRegistry; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -33,6 +33,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; import java.util.stream.Collectors; /** @@ -106,10 +107,20 @@ public void onFailure(Exception e) { } } + /** + * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from + * the original request will be modified with the inference results, to avoid copying the entire requests from + * the original bulk request. + * + * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. + * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, + * which may have fewer items than the original because of inference failures + * @param onBulkItemFailure invoked when a bulk item fails inference + */ public void processBulkShardRequest( BulkShardRequest bulkShardRequest, ActionListener listener, - TriConsumer onBulkItemFailure + BiConsumer onBulkItemFailure ) { Map> fieldsForModels = clusterState.metadata() @@ -121,14 +132,41 @@ public void processBulkShardRequest( return; } - Runnable onInferenceComplete = () -> { listener.onResponse(bulkShardRequest); }; + Set failedItems = Collections.synchronizedSet(new HashSet<>()); + Runnable onInferenceComplete = () -> { + if (failedItems.isEmpty()) { + listener.onResponse(bulkShardRequest); + return; + } + // Remove failed items from the original bulk shard request + BulkItemRequest[] originalItems = bulkShardRequest.items(); + BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; + for (int i = 0, j = 0; i < originalItems.length; i++) { + if (failedItems.contains(i) == false) { + newItems[j++] = originalItems[i]; + } + } + BulkShardRequest newBulkShardRequest = new BulkShardRequest( + bulkShardRequest.shardId(), + bulkShardRequest.getRefreshPolicy(), + newItems + ); + listener.onResponse(newBulkShardRequest); + }; try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { BulkItemRequest[] items = bulkShardRequest.items(); for (int i = 0; i < items.length; i++) { BulkItemRequest bulkItemRequest = items[i]; // Bulk item might be null because of previous errors, skip in that case if (bulkItemRequest != null) { - performInferenceOnBulkItemRequest(bulkItemRequest, i, fieldsForModels, onBulkItemFailure, bulkItemReqRef.acquire()); + performInferenceOnBulkItemRequest( + bulkItemRequest, + fieldsForModels, + i, + onBulkItemFailure, + failedItems, + bulkItemReqRef.acquire() + ); } } } @@ -136,9 +174,10 @@ public void processBulkShardRequest( private void performInferenceOnBulkItemRequest( BulkItemRequest bulkItemRequest, - int bulkItemIndex, Map> fieldsForModels, - TriConsumer onBulkItemFailure, + Integer itemIndex, + BiConsumer onBulkItemFailure, + Set failedItems, Releasable releaseOnFinish ) { @@ -186,9 +225,9 @@ private void performInferenceOnBulkItemRequest( InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); if (inferenceProvider == null) { - onBulkItemFailure.apply( + 
failedItems.add(itemIndex); + onBulkItemFailure.accept( bulkItemRequest, - bulkItemIndex, new IllegalArgumentException("No inference provider found for model ID " + modelId) ); continue; @@ -223,7 +262,8 @@ public void onResponse(InferenceServiceResults results) { @Override public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkItemRequest, bulkItemIndex, e); + failedItems.add(itemIndex); + onBulkItemFailure.accept(bulkItemRequest, e); } }; inferenceProvider.service() diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index c1500abc28abe..a688df5d797a2 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -111,6 +111,7 @@ public void testNoInference() { fieldsForModels, modelRegistry, inferenceServiceRegistry, + true, bulkOperationListener ); verify(bulkOperationListener).onResponse(any()); @@ -151,6 +152,7 @@ public void testFailedBulkShardRequest() { modelRegistry, inferenceServiceRegistry, bulkOperationListener, + true, request -> new BulkShardResponse( request.shardId(), new BulkItemResponse[] { @@ -217,6 +219,7 @@ public void testInference() { fieldsForModels, modelRegistry, inferenceServiceRegistry, + true, bulkOperationListener ); verify(bulkOperationListener).onResponse(any()); @@ -267,16 +270,8 @@ public void testFailedInference() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - assertNull(items[0]); + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); assertTrue(bulkResponse.hasFailures()); @@ -313,16 +308,10 @@ public void testInferenceIdNotFound() { ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); @SuppressWarnings("unchecked") ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldsForModels, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - assertNull(items[0]); + doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); + + runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); + verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); BulkResponse bulkResponse = bulkResponseCaptor.getValue(); assertTrue(bulkResponse.hasFailures()); @@ -392,6 +381,7 @@ private static BulkShardRequest runBulkOperation( Map> fieldsForModels, ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, + boolean expectTransportShardBulkActionToExecute, ActionListener bulkOperationListener ) { return runBulkOperation( @@ -400,6 +390,7 @@ private static BulkShardRequest runBulkOperation( 
modelRegistry, inferenceServiceRegistry, bulkOperationListener, + expectTransportShardBulkActionToExecute, successfulBulkShardResponse ); } @@ -410,6 +401,7 @@ private static BulkShardRequest runBulkOperation( ModelRegistry modelRegistry, InferenceServiceRegistry inferenceServiceRegistry, ActionListener bulkOperationListener, + boolean expectTransportShardBulkActionToExecute, Function bulkShardResponseSupplier ) { Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); @@ -456,9 +448,12 @@ private static BulkShardRequest runBulkOperation( ); bulkOperation.doRun(); - verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); + if (expectTransportShardBulkActionToExecute) { + verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); + return bulkShardRequestCaptor.getValue(); + } - return bulkShardRequestCaptor.getValue(); + return null; } private static final Function successfulBulkShardResponse = (request) -> { From 175051b171ab29bf208ce5a805144bcc6dbd9258 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 6 Feb 2024 14:05:22 +0100 Subject: [PATCH 093/106] Move semantic_text field mappers to inference plugin --- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferencePlugin.java | 27 ++++++++++++++++++- .../xpack/inference}/SemanticTextFeature.java | 2 +- .../mapper/SemanticTextFieldMapper.java | 2 +- ...emanticTextInferenceResultFieldMapper.java | 2 +- .../SemanticTextClusterMetadataTests.java | 13 +++++++-- .../mapper/SemanticTextFieldMapperTests.java | 6 ++--- ...icTextInferenceResultFieldMapperTests.java | 8 +++--- .../xpack/ml/MachineLearning.java | 21 +-------------- .../xpack/ml/LocalStateMachineLearning.java | 6 ----- 10 files changed, 49 insertions(+), 39 deletions(-) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/SemanticTextFeature.java (93%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextFieldMapper.java (98%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml => inference/src/main/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextInferenceResultFieldMapper.java (99%) rename x-pack/plugin/{ml => inference}/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java (87%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml => inference/src/test/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextFieldMapperTests.java (96%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml => inference/src/test/java/org/elasticsearch/xpack/inference}/mapper/SemanticTextInferenceResultFieldMapperTests.java (98%) diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 2d25a48117778..ddd56c758d67c 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -17,6 +17,7 @@ requires org.apache.httpcomponents.httpasyncclient; requires org.apache.httpcomponents.httpcore.nio; requires org.apache.lucene.core; + requires org.elasticsearch.logging; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 4e44929e7ba9b..a83f8bb5f9b5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,6 +21,8 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -29,6 +31,7 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.InferenceRegistryPlugin; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -52,6 +55,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -67,12 +72,19 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceRegistryPlugin { +public class InferencePlugin extends Plugin + implements + ActionPlugin, + ExtensiblePlugin, + SystemIndexPlugin, + InferenceRegistryPlugin, + MapperPlugin { public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -254,4 +266,17 @@ public InferenceServiceRegistry getInferenceServiceRegistry() { public ModelRegistry getModelRegistry() { return modelRegistry.get(); } + + @Override + public Map getMappers() { + if (SemanticTextFeature.isEnabled()) { + return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + } + return Map.of(); + } + + @Override + public Map getMetadataMappers() { + return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java similarity index 93% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java index f861760803e56..4f2c5c564bcb8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml; +package org.elasticsearch.xpack.inference; import org.elasticsearch.common.util.FeatureFlag; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 9546bc4ba9add..027b85a9a9f45 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java similarity index 99% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index ff224522034bf..5dda6ae3781ab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java similarity index 87% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 47cae14003c70..69fa64ffa6d1c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -13,14 +13,23 @@ import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexService; -import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.xpack.inference.InferencePlugin; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; -public class SemanticTextClusterMetadataTests extends MlSingleNodeTestCase { +public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return List.of(InferencePlugin.class); + } + public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java similarity index 96% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index ccb8f106e4945..a3a705c9cc902 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.index.IndexableField; import org.elasticsearch.common.Strings; @@ -18,7 +18,7 @@ import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.AssumptionViolatedException; import java.io.IOException; @@ -74,7 +74,7 @@ public void testUpdatesToModelIdNotSupported() throws IOException { @Override protected Collection getPlugins() { - return singletonList(new MachineLearning(Settings.EMPTY)); + return singletonList(new InferencePlugin(Settings.EMPTY)); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java similarity index 98% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java index bde6da7fe8277..7f13d34986482 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.ml.mapper; +package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; @@ -37,7 +37,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; import java.util.ArrayList; @@ -78,7 +78,7 @@ protected boolean isConfigurable() { @Override protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_13); // TODO: Switch to ES_VERSION_8_14 when available + return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available } @Override @@ -88,7 +88,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { @Override protected Collection getPlugins() { - return List.of(new MachineLearning(Settings.EMPTY)); + return List.of(new InferencePlugin(Settings.EMPTY)); } public void testSuccessfulParse() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 10b2ed089d632..70a3b9bab49f1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,8 +49,6 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MetadataFieldMapper; import 
org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -70,7 +68,6 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.IngestPlugin; -import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Platforms; import org.elasticsearch.plugins.Plugin; @@ -366,8 +363,6 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; -import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -487,8 +482,7 @@ public class MachineLearning extends Plugin PersistentTaskPlugin, SearchPlugin, ShutdownAwarePlugin, - ExtensiblePlugin, - MapperPlugin { + ExtensiblePlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; // Endpoints that were deprecated in 7.x can still be called in 8.x using the REST compatibility layer @@ -2298,17 +2292,4 @@ public void signalShutdown(Collection shutdownNodeIds) { mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds); } } - - @Override - public Map getMappers() { - if (SemanticTextFeature.isEnabled()) { - return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); - } - return Map.of(); - } - - @Override - public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); - } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 5af3fd527e31e..2d7832d747de4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; -import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; @@ -103,11 +102,6 @@ public Map> getTokeniz return mlPlugin.getTokenizers(); } - @Override - public Map getMappers() { - return mlPlugin.getMappers(); - } - /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. * The MachineLearningLicenseTests attempt to create a datafeed referencing this LocalStateMachineLearning object. 
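With the mappers now registered by the inference plugin above, a semantic_text field is declared in an index mapping as in the sketch below; the field name and inference ID are placeholders, and the shape simply mirrors the mappings used by the cluster-metadata and YAML REST tests added in this series.

    {
      "mappings": {
        "properties": {
          "my_semantic_text_field": {
            "type": "semantic_text",
            "model_id": "my-inference-id"
          }
        }
      }
    }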
From cb1f27074f22b385afdef437afc6eea0806bedd5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 6 Feb 2024 17:05:49 +0100 Subject: [PATCH 094/106] Remove @Nullable annotations for registries --- .../action/bulk/TransportBulkAction.java | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index ddec60a6c3fa8..a8a2c5047d3ed 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -46,7 +46,6 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Assertions; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; @@ -106,8 +105,8 @@ public TransportBulkAction( IndexNameExpressionResolver indexNameExpressionResolver, IndexingPressure indexingPressure, SystemIndices systemIndices, - @Nullable InferenceServiceRegistry inferenceServiceRegistry, - @Nullable ModelRegistry modelRegistry + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry ) { this( threadPool, @@ -136,8 +135,8 @@ public TransportBulkAction( IndexingPressure indexingPressure, SystemIndices systemIndices, LongSupplier relativeTimeProvider, - @Nullable InferenceServiceRegistry inferenceServiceRegistry, - @Nullable ModelRegistry modelRegistry + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry ) { this( BulkAction.INSTANCE, @@ -170,8 +169,8 @@ public TransportBulkAction( IndexingPressure indexingPressure, SystemIndices systemIndices, LongSupplier relativeTimeProvider, - @Nullable InferenceServiceRegistry inferenceServiceRegistry, - @Nullable ModelRegistry modelRegistry + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry ) { super(bulkAction.name(), transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); Objects.requireNonNull(relativeTimeProvider); From 79962440589c9e008935a9694964968c09f02b98 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 11:48:21 +0100 Subject: [PATCH 095/106] Add YAML REST test scaffolding --- x-pack/plugin/inference/build.gradle | 2 + x-pack/plugin/inference/qa/rest/build.gradle | 15 +++++++ .../xpack/inference/InferenceRestIT.java | 43 +++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 x-pack/plugin/inference/qa/rest/build.gradle create mode 100644 x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index e4f4de0027073..7000225f14e57 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -25,3 +25,5 @@ dependencies { testImplementation(testArtifact(project(xpackModule('core')))) testImplementation project(':modules:reindex') } + +addQaCheckDependencies(project) diff --git a/x-pack/plugin/inference/qa/rest/build.gradle b/x-pack/plugin/inference/qa/rest/build.gradle new file mode 100644 index 0000000000000..c5c6c7ff4a0f8 --- /dev/null +++ b/x-pack/plugin/inference/qa/rest/build.gradle @@ -0,0 +1,15 @@ +apply plugin: 'elasticsearch.internal-yaml-rest-test' + 
+dependencies { + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') +} + +restResources { + restApi { + include '_common', 'bulk', 'indices', 'inference', 'index', 'get' + } +} + +tasks.named('yamlRestTest') { + usesDefaultDistribution() +} diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java new file mode 100644 index 0000000000000..79fb9a6cbf083 --- /dev/null +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; + +public class InferenceRestIT extends ESClientYamlSuiteTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .setting("xpack.security.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") + .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin") + .distribution(DistributionType.DEFAULT) + .build(); + + public InferenceRestIT(final ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return ESClientYamlSuiteTestCase.createParameters(); + } +} From fbce1d4f0594e74a29ae1f602b27adf3797241d7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 11:48:32 +0100 Subject: [PATCH 096/106] First test version --- .../inference/10_semantic_text_inference.yml | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml new file mode 100644 index 0000000000000..fe765968ad26c --- /dev/null +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -0,0 +1,54 @@ +setup: + - do: + inference.put_model: + task_type: sparse_embedding + model_id: test-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-index + body: + settings: + index: + number_of_shards: 2 + number_of_replicas: 0 + mappings: + properties: + inference_field: + type: semantic_text + model_id: test-inference-id + another_inference_field: + type: semantic_text + model_id: test-inference-id + 
non_inference_field: + type: text + +--- +"Create index with semantic_text field type": + - do: + index: + index: test-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } From b12ea91f16b1918a11832a675ac90c919eedddce Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 12:05:37 +0100 Subject: [PATCH 097/106] First test version --- .../inference/10_semantic_text_inference.yml | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index fe765968ad26c..11ced6964e0aa 100644 --- a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -1,4 +1,8 @@ setup: + - skip: + version: " - 8.12.99" + reason: semantic_text introduced in 8.13.0 # TODO change when 8.13.0 is released + - do: inference.put_model: task_type: sparse_embedding @@ -34,7 +38,7 @@ setup: type: text --- -"Create index with semantic_text field type": +"Calculates embeddings for new documents": - do: index: index: test-index @@ -52,3 +56,42 @@ setup: - match: { _source.inference_field: "inference test" } - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } + + - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + + - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + + - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding + - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding + +--- +"Fails for non-existent model": + - do: + indices.create: + index: incorrect-test-index + body: + mappings: + properties: + inference_field: + type: semantic_text + model_id: non-existing-inference-id + non_inference_field: + type: text + - do: + catch: /No inference provider found .* non-existing-inference-id/ + index: + index: incorrect-test-index + id: doc_1 + body: + inference_field: "inference test" + non_inference_field: "non inference test" + + # Succeeds when semantic_text field is not used + - do: + index: + index: incorrect-test-index + id: doc_1 + body: + non_inference_field: "non inference test" From b17a4cca7e402d6783274c02feb356b9a23d0d73 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 15:43:14 +0100 Subject: [PATCH 098/106] Add tests --- x-pack/plugin/inference/qa/rest/build.gradle | 2 +- .../xpack/inference/InferenceRestIT.java | 2 - .../inference/10_semantic_text_inference.yml | 150 +++++++++++++++++- 3 files changed, 144 insertions(+), 10 deletions(-) diff --git 
a/x-pack/plugin/inference/qa/rest/build.gradle b/x-pack/plugin/inference/qa/rest/build.gradle index c5c6c7ff4a0f8..977a97704a4da 100644 --- a/x-pack/plugin/inference/qa/rest/build.gradle +++ b/x-pack/plugin/inference/qa/rest/build.gradle @@ -6,7 +6,7 @@ dependencies { restResources { restApi { - include '_common', 'bulk', 'indices', 'inference', 'index', 'get' + include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex' } } diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 79fb9a6cbf083..933e696d29d83 100644 --- a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -29,8 +29,6 @@ public InferenceRestIT(final ClientYamlTestCandidate testCandidate) { super(testCandidate); } - - @Override protected String getTestRestCluster() { return cluster.getHttpAddresses(); diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 11ced6964e0aa..c6d95b6151c11 100644 --- a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -22,10 +22,6 @@ setup: indices.create: index: test-index body: - settings: - index: - number_of_shards: 2 - number_of_replicas: 0 mappings: properties: inference_field: @@ -60,11 +56,148 @@ setup: - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding + - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding + +--- +"Updating non semantic_text fields does not recalculate embeddings": + - do: + index: + index: test-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-index + id: doc_1 + + - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } + - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + + - do: + update: + index: test-index + id: doc_1 + body: + doc: + non_inference_field: "another non inference test" + + - do: + get: + index: test-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "another non inference test" } + - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } - - exists: _source._semantic_text_inference.inference_field.0.sparse_embedding - - exists: _source._semantic_text_inference.another_inference_field.0.sparse_embedding + - 
match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } + +--- +"Updating semantic_text fields recalculates embeddings": + - do: + index: + index: test-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-index + id: doc_1 + + - do: + update: + index: test-index + id: doc_1 + body: + doc: + inference_field: "updated inference test" + another_inference_field: "another updated inference test" + + - do: + get: + index: test-index + id: doc_1 + + - match: { _source.inference_field: "updated inference test" } + - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - match: { _source._semantic_text_inference.inference_field.0.text: "updated inference test" } + - match: { _source._semantic_text_inference.another_inference_field.0.text: "another updated inference test" } + + +--- +"Reindex works for semantic_text fields": + - do: + index: + index: test-index + id: doc_1 + body: + inference_field: "inference test" + another_inference_field: "another inference test" + non_inference_field: "non inference test" + + - do: + get: + index: test-index + id: doc_1 + + - set: { _source._semantic_text_inference.inference_field.0.sparse_embedding: inference_field_embedding } + - set: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: another_inference_field_embedding } + + - do: + indices.refresh: { } + + - do: + indices.create: + index: destination-index + body: + mappings: + properties: + inference_field: + type: semantic_text + model_id: test-inference-id + another_inference_field: + type: semantic_text + model_id: test-inference-id + non_inference_field: + type: text + + - do: + reindex: + wait_for_completion: true + body: + source: + index: test-index + dest: + index: destination-index + - do: + get: + index: destination-index + id: doc_1 + + - match: { _source.inference_field: "inference test" } + - match: { _source.another_inference_field: "another inference test" } + - match: { _source.non_inference_field: "non inference test" } + + - match: { _source._semantic_text_inference.inference_field.0.text: "inference test" } + - match: { _source._semantic_text_inference.another_inference_field.0.text: "another inference test" } + + - match: { _source._semantic_text_inference.inference_field.0.sparse_embedding: $inference_field_embedding } + - match: { _source._semantic_text_inference.another_inference_field.0.sparse_embedding: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -79,8 +212,9 @@ setup: model_id: non-existing-inference-id non_inference_field: type: text + - do: - catch: /No inference provider found .* non-existing-inference-id/ + catch: bad_request index: index: incorrect-test-index id: doc_1 @@ -88,6 +222,8 @@ setup: inference_field: "inference test" non_inference_field: "non inference test" + - match: { error.reason: "No inference provider found for model ID non-existing-inference-id" } + # Succeeds when semantic_text field is not used - do: index: From fb7f9d3ebb494da4114a68c2785001508d14fefd Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 15:43:34 +0100 Subject: [PATCH 099/106] Fix bug for re-calculating inference results 
--- .../action/bulk/BulkShardRequestInferenceProvider.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index 7acc93be13a46..fde668de3c42a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -248,6 +248,8 @@ public void onResponse(InferenceServiceResults results) { @SuppressWarnings("unchecked") List> inferenceFieldResultList = (List>) rootInferenceFieldMap .computeIfAbsent(fieldName, k -> new ArrayList<>()); + // Remove previous inference results if any + inferenceFieldResultList.clear(); // TODO Check inference result type to change subfield name var inferenceFieldMap = Map.of( From e5ee95675c7b4ec9ede34d5c387164ff8af84cbd Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 17:35:05 +0100 Subject: [PATCH 100/106] Fix merge --- .../action/bulk/TransportBulkAction.java | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 32566b559410d..86fc511251553 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,6 +57,8 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -98,6 +100,8 @@ public class TransportBulkAction extends HandledTransportAction responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); + executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -614,10 +630,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, - ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated + Map indicesThatCannotBeCreated, + ActionListener listener ) { new BulkOperation( task, @@ -631,6 +647,8 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, + modelRegistry, + inferenceServiceRegistry, listener ).run(); } From f818bd0e3173fbeb28fdc6d96f5b2af6da272b2b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 18:04:54 +0100 Subject: [PATCH 101/106] Fix 
javadoc for SemanticTextInferenceResultFieldMapper --- ...emanticTextInferenceResultFieldMapper.java | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index d3fb3ab1684b3..9e6c1eb0a6586 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -61,15 +61,12 @@ * "my_semantic_text_field": [ * { * "sparse_embedding": { - * "is_truncated": false, - * "embedding": { - * "lucas": 0.05212344, - * "ty": 0.041213956, - * "dragon": 0.50991, - * "type": 0.23241979, - * "dr": 1.9312073, - * "##o": 0.2797593 - * } + * "lucas": 0.05212344, + * "ty": 0.041213956, + * "dragon": 0.50991, + * "type": 0.23241979, + * "dr": 1.9312073, + * "##o": 0.2797593 * }, * "text": "these are not the droids you're looking for" * } @@ -90,11 +87,7 @@ * "type": "nested", * "properties": { * "sparse_embedding": { - * "properties": { - * "embedding": { - * "type": "sparse_vector" - * } - * } + * "type": "sparse_vector" * }, * "text": { * "type": "text", From 4fdd65e4beb4ea51e9448cfde225041532fbbc67 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 18:05:01 +0100 Subject: [PATCH 102/106] Remove unnecessary class --- .../inference/TestInferenceResults.java | 67 ------------------- 1 file changed, 67 deletions(-) delete mode 100644 server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java diff --git a/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java b/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java deleted file mode 100644 index f24997fd6a328..0000000000000 --- a/server/src/test/java/org/elasticsearch/inference/TestInferenceResults.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.util.HashMap; -import java.util.Map; - -public class TestInferenceResults implements InferenceResults { - - private final String resultField; - private final Map inferenceResults; - - public TestInferenceResults(String resultField, Map inferenceResults) { - this.resultField = resultField; - this.inferenceResults = inferenceResults; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - throw new UnsupportedEncodingException(); - } - - @Override - public String getWriteableName() { - return "test_inference_results"; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public String getResultsField() { - return resultField; - } - - @Override - public Map asMap() { - Map result = new HashMap<>(); - result.put(resultField, inferenceResults); - return result; - } - - @Override - public Map asMap(String outputField) { - Map result = new HashMap<>(); - result.put(outputField, inferenceResults); - return result; - } - - @Override - public Object predictedValue() { - throw new UnsupportedOperationException(); - } -} From 75cbe3dda9c8a2fb34c4f60f9d47ca1caa1768b9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 7 Feb 2024 19:03:52 +0100 Subject: [PATCH 103/106] Fix merge with main --- .../rest-api-spec/test/inference/10_semantic_text_inference.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index c6d95b6151c11..0e1b33252153b 100644 --- a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -6,7 +6,7 @@ setup: - do: inference.put_model: task_type: sparse_embedding - model_id: test-inference-id + inference_id: test-inference-id body: > { "service": "test_service", From 28e64d838f03d8d7e894cbe8e337a9049d567d17 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 8 Feb 2024 22:22:11 +0100 Subject: [PATCH 104/106] Add comments on failure to load method --- .../action/bulk/BulkShardRequestInferenceProvider.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java index fde668de3c42a..a305aaa4937ec 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java @@ -98,7 +98,9 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { @Override public void onFailure(Exception e) { - // Do nothing - let it fail afterwards when model is retrieved + // Failure on loading a model should not prevent the rest from being loaded and used. 
+                // When the model is actually retrieved via the inference ID in the inference process, it will fail
+                // and the user will get the details on the inference failure.
            }
        };

From 792f3de96349b2798a6a88e2dc61b57c1f8926ff Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Thu, 8 Feb 2024 23:11:44 +0100
Subject: [PATCH 105/106] Moved yamlRestTest directory from qa to inference

---
 x-pack/plugin/inference/build.gradle                   | 12 +++++++++++-
 x-pack/plugin/inference/qa/rest/build.gradle           | 15 ---------------
 .../xpack/inference/InferenceRestIT.java               |  0
 .../test/inference/10_semantic_text_inference.yml      |  0
 4 files changed, 11 insertions(+), 16 deletions(-)
 delete mode 100644 x-pack/plugin/inference/qa/rest/build.gradle
 rename x-pack/plugin/inference/{qa/rest => }/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java (100%)
 rename x-pack/plugin/inference/{qa/rest => }/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml (100%)

diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index 7000225f14e57..781261c330e78 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -6,6 +6,13 @@
  */
 apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
+apply plugin: 'elasticsearch.internal-yaml-rest-test'
+
+restResources {
+  restApi {
+    include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex'
+  }
+}
 
 esplugin {
   name 'x-pack-inference'
@@ -24,6 +31,9 @@ dependencies {
   compileOnly project(path: xpackModule('core'))
   testImplementation(testArtifact(project(xpackModule('core'))))
   testImplementation project(':modules:reindex')
+  clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
 }
 
-addQaCheckDependencies(project)
+tasks.named('yamlRestTest') {
+  usesDefaultDistribution()
+}

diff --git a/x-pack/plugin/inference/qa/rest/build.gradle b/x-pack/plugin/inference/qa/rest/build.gradle
deleted file mode 100644
index 977a97704a4da..0000000000000
--- a/x-pack/plugin/inference/qa/rest/build.gradle
+++ /dev/null
@@ -1,15 +0,0 @@
-apply plugin: 'elasticsearch.internal-yaml-rest-test'
-
-dependencies {
-  clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
-}
-
-restResources {
-  restApi {
-    include '_common', 'bulk', 'indices', 'inference', 'index', 'get', 'update', 'reindex'
-  }
-}
-
-tasks.named('yamlRestTest') {
-  usesDefaultDistribution()
-}

diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
similarity index 100%
rename from x-pack/plugin/inference/qa/rest/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
rename to x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java

diff --git a/x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
similarity index 100%
rename from x-pack/plugin/inference/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
rename to x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml

From 2b944b1b9aa9f47178f1f1a51cbc8e5fde5db200 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Fri, 9 Feb 2024 13:38:03 +0100
Subject: [PATCH 106/106] Add cast checks and refactor error handling a bit

---
 .../BulkShardRequestInferenceProvider.java    | 68 +++++++++++-----
 .../action/bulk/BulkOperationTests.java       | 81 +++++++++++++++++--
 2 files changed, 120 insertions(+), 29 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
index a305aaa4937ec..02f905f7cd87a 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.InferenceResults;
@@ -155,6 +156,10 @@ public void processBulkShardRequest(
             );
             listener.onResponse(newBulkShardRequest);
         };
+        TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> {
+            failedItems.add(i);
+            onBulkItemFailure.accept(bulkItemRequest, e);
+        };
         try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) {
             BulkItemRequest[] items = bulkShardRequest.items();
             for (int i = 0; i < items.length; i++) {
@@ -165,8 +170,7 @@ public void processBulkShardRequest(
                         bulkItemRequest,
                         fieldsForModels,
                         i,
-                        onBulkItemFailure,
-                        failedItems,
+                        onBulkItemFailureWithIndex,
                         bulkItemReqRef.acquire()
                     );
                 }
@@ -174,12 +178,12 @@ public void processBulkShardRequest(
         }
     }
 
+    @SuppressWarnings("unchecked")
     private void performInferenceOnBulkItemRequest(
         BulkItemRequest bulkItemRequest,
         Map> fieldsForModels,
         Integer itemIndex,
-        BiConsumer onBulkItemFailure,
-        Set failedItems,
+        TriConsumer onBulkItemFailure,
         Releasable releaseOnFinish
     ) {
@@ -210,46 +214,69 @@ private void performInferenceOnBulkItemRequest(
                 releaseOnFinish.close();
             })) {
 
-            for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) {
-                String modelId = fieldModelsEntrySet.getKey();
-
-                @SuppressWarnings("unchecked")
-                Map rootInferenceFieldMap = (Map) docMap.computeIfAbsent(
+            Map rootInferenceFieldMap;
+            try {
+                rootInferenceFieldMap = (Map) docMap.computeIfAbsent(
                     ROOT_INFERENCE_FIELD,
                     k -> new HashMap()
                 );
+            } catch (ClassCastException e) {
+                onBulkItemFailure.apply(
+                    bulkItemRequest,
+                    itemIndex,
+                    new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object")
+                );
+                return;
+            }
 
+            for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) {
+                String modelId = fieldModelsEntrySet.getKey();
                 List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap);
-
                 if (inferenceFieldNames.isEmpty()) {
                     continue;
                 }
 
                 InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId);
                 if (inferenceProvider == null) {
-                    failedItems.add(itemIndex);
-                    onBulkItemFailure.accept(
+                    onBulkItemFailure.apply(
                         bulkItemRequest,
+                        itemIndex,
                         new IllegalArgumentException("No inference provider found for model ID " + modelId)
                     );
-                    continue;
+                    return;
                 }
                 ActionListener inferenceResultsListener = new ActionListener<>() {
                     @Override
                     public void onResponse(InferenceServiceResults results) {
-                        if (results == null) {
-                            throw new IllegalArgumentException(
-                                "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id()
+                            onBulkItemFailure.apply(
+                                bulkItemRequest,
+                                itemIndex,
+                                new IllegalArgumentException(
+                                    "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id()
+                                )
+                            );
                         }
 
                         int i = 0;
                         for (InferenceResults inferenceResults : results.transformToLegacyFormat()) {
                             String fieldName = inferenceFieldNames.get(i++);
-                            @SuppressWarnings("unchecked")
-                            List> inferenceFieldResultList = (List>) rootInferenceFieldMap
-                                .computeIfAbsent(fieldName, k -> new ArrayList<>());
+                            List> inferenceFieldResultList;
+                            try {
+                                inferenceFieldResultList = (List>) rootInferenceFieldMap.computeIfAbsent(
+                                    fieldName,
+                                    k -> new ArrayList<>()
+                                );
+                            } catch (ClassCastException e) {
+                                onBulkItemFailure.apply(
+                                    bulkItemRequest,
+                                    itemIndex,
+                                    new IllegalArgumentException(
+                                        "Inference result field [" + ROOT_INFERENCE_FIELD + "." + fieldName + "] is not an object"
+                                    )
+                                );
+                                return;
+                            }
 
                             // Remove previous inference results if any
                             inferenceFieldResultList.clear();
@@ -266,8 +293,7 @@ public void onResponse(InferenceServiceResults results) {
 
                     @Override
                     public void onFailure(Exception e) {
-                        failedItems.add(itemIndex);
-                        onBulkItemFailure.accept(bulkItemRequest, e);
+                        onBulkItemFailure.apply(bulkItemRequest, itemIndex, e);
                     }
                 };
                 inferenceProvider.service()

diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
index a688df5d797a2..f8ed331d358b2 100644
--- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
+++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java
@@ -56,6 +56,8 @@
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyMap;
+import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD;
+import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyList;
@@ -97,7 +99,7 @@ public void testNoInference() {
             Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2)
         );
 
-        Map originalSource = Map.of(
+        Map originalSource = Map.of(
             randomAlphaOfLengthBetween(1, 20),
             randomAlphaOfLengthBetween(1, 100),
             randomAlphaOfLengthBetween(1, 20),
@@ -134,7 +136,7 @@ public void testFailedBulkShardRequest() {
         ModelRegistry modelRegistry = createModelRegistry(Map.of());
         InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of());
 
-        Map originalSource = Map.of(
+        Map originalSource = Map.of(
             randomAlphaOfLengthBetween(1, 20),
             randomAlphaOfLengthBetween(1, 100),
             randomAlphaOfLengthBetween(1, 20),
@@ -200,7 +202,7 @@ public void testInference() {
         String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100);
         String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100);
         String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100);
-        Map originalSource = Map.of(
+        Map originalSource = Map.of(
             FIRST_INFERENCE_FIELD_SERVICE_1,
             firstInferenceTextService1,
             SECOND_INFERENCE_FIELD_SERVICE_1,
@@ -260,7 +262,7 @@ public void testFailedInference() {
         InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));
 
         String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100);
-        Map originalSource = Map.of(
+        Map originalSource = Map.of(
            FIRST_INFERENCE_FIELD_SERVICE_1,
            firstInferenceTextService1,
            randomAlphaOfLengthBetween(1, 20),
@@ -283,6 +285,69 @@ public void testFailedInference() {
 
     }
 
+    public void testInferenceFailsForIncorrectRootObject() {
+
+        Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1));
+
+        ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));
+
+        Model model = mock(Model.class);
+        InferenceService inferenceService = createInferenceServiceThatFails(model, INFERENCE_SERVICE_1_ID);
+        InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));
+
+        Map originalSource = Map.of(
+            FIRST_INFERENCE_FIELD_SERVICE_1,
+            randomAlphaOfLengthBetween(1, 100),
+            ROOT_INFERENCE_FIELD,
+            "incorrect_root_object"
+        );
+
+        ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
+        @SuppressWarnings("unchecked")
+        ActionListener bulkOperationListener = mock(ActionListener.class);
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+
+        verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
+        BulkResponse bulkResponse = bulkResponseCaptor.getValue();
+        assertTrue(bulkResponse.hasFailures());
+        BulkItemResponse item = bulkResponse.getItems()[0];
+        assertTrue(item.isFailed());
+        assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object"));
+    }
+
+    public void testInferenceFailsForIncorrectInferenceFieldObject() {
+
+        Map> fieldsForModels = Map.of(INFERENCE_SERVICE_1_ID, Set.of(FIRST_INFERENCE_FIELD_SERVICE_1));
+
+        ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID));
+
+        Model model = mock(Model.class);
+        InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID);
+        InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));
+
+        Map originalSource = Map.of(
+            FIRST_INFERENCE_FIELD_SERVICE_1,
+            randomAlphaOfLengthBetween(1, 100),
+            ROOT_INFERENCE_FIELD,
+            Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, "incorrect_inference_field_value")
+        );
+
+        ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class);
+        @SuppressWarnings("unchecked")
+        ActionListener bulkOperationListener = mock(ActionListener.class);
+        runBulkOperation(originalSource, fieldsForModels, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener);
+
+        verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture());
+        BulkResponse bulkResponse = bulkResponseCaptor.getValue();
+        assertTrue(bulkResponse.hasFailures());
+        BulkItemResponse item = bulkResponse.getItems()[0];
+        assertTrue(item.isFailed());
+        assertThat(
+            item.getFailure().getCause().getMessage(),
+            containsString("Inference result field [_semantic_text_inference.first_inference_field_service_1] is not an object")
+        );
+    }
+
     public void testInferenceIdNotFound() {
 
         Map> fieldsForModels = Map.of(
@@ -298,7 +363,7 @@ public void testInferenceIdNotFound() {
         Model model = mock(Model.class);
         InferenceService inferenceService = createInferenceService(model, INFERENCE_SERVICE_1_ID);
         InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService));
-        Map originalSource = Map.of(
+        Map originalSource = Map.of(
            INFERENCE_FIELD_SERVICE_2,
            randomAlphaOfLengthBetween(1, 100),
            randomAlphaOfLengthBetween(1, 20),
@@ -325,7 +390,7 @@ public void testInferenceIdNotFound() {
 
     @SuppressWarnings("unchecked")
     private static void checkInferenceResults(
-        Map docSource,
+        Map docSource,
         Map writtenDocSource,
         String... inferenceFieldNames
     ) {
@@ -377,7 +442,7 @@ public String toString() {
     }
 
     private static BulkShardRequest runBulkOperation(
-        Map docSource,
+        Map docSource,
         Map> fieldsForModels,
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,
@@ -396,7 +461,7 @@ private static BulkShardRequest runBulkOperation(
-        Map docSource,
+        Map docSource,
         Map> fieldsForModels,
         ModelRegistry modelRegistry,
         InferenceServiceRegistry inferenceServiceRegistry,