From ab97838a3f74c1a7df6f419b3ff95240ce665484 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 21 Nov 2023 13:05:37 +0100 Subject: [PATCH] Add diff support for model for fields, changed implementation to Set --- .../cluster/metadata/IndexMetadata.java | 53 +++++++++++++++---- .../cluster/metadata/MappingMetadata.java | 8 +-- .../elasticsearch/index/IndexSettings.java | 9 ---- .../index/mapper/FieldTypeLookup.java | 12 ++--- .../index/mapper/MappingLookup.java | 6 +-- .../elasticsearch/ingest/IngestService.java | 2 +- .../xpack/ml/MachineLearning.java | 5 +- .../SemanticTextInferenceProcessor.java | 7 +-- 8 files changed, 61 insertions(+), 41 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 5404788d401cc..c163899bbdb0b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -638,7 +638,7 @@ public Iterator> settings() { private final Double writeLoadForecast; @Nullable private final Long shardSizeInBytesForecast; - private final Map> inferenceModelsForFields; + private final Map> inferenceModelsForFields; private IndexMetadata( final Index index, @@ -685,7 +685,7 @@ private IndexMetadata( @Nullable final IndexMetadataStats stats, @Nullable final Double writeLoadForecast, @Nullable Long shardSizeInBytesForecast, - final Map> inferenceModelsForFields + final Map> inferenceModelsForFields ) { this.index = index; this.version = version; @@ -1224,7 +1224,7 @@ public OptionalLong getForecastedShardSizeInBytes() { return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast); } - public Map> getInferenceModelsForFields() { + public Map> getInferenceModelsForFields() { return inferenceModelsForFields; } @@ -1492,6 +1492,7 @@ private static class IndexMetadataDiff implements Diff { private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; + private final Diff>> modelsForFields; IndexMetadataDiff(IndexMetadata before, IndexMetadata after) { index = after.index.getName(); @@ -1528,6 +1529,12 @@ private static class IndexMetadataDiff implements Diff { stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; + modelsForFields = DiffableUtils.diff( + before.inferenceModelsForFields, + after.inferenceModelsForFields, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); } private static final DiffableUtils.DiffableValueReader ALIAS_METADATA_DIFF_VALUE_READER = @@ -1587,6 +1594,15 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } + if (in.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD)) { + modelsForFields = DiffableUtils.readJdkMapDiff( + in, + DiffableUtils.getStringKeySerializer(), + DiffableUtils.StringSetValueSerializer.getInstance() + ); + } else { + modelsForFields = DiffableUtils.emptyDiff(); + } } @Override @@ -1622,6 +1638,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD)) { + modelsForFields.writeTo(out); + } } @Override @@ -1651,6 +1670,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); + builder.inferenceModelsForFields(modelsForFields.apply(part.inferenceModelsForFields)); return builder.build(true); } } @@ -1719,7 +1739,9 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function i.readCollectionAsImmutableSet(StreamInput::readString)) + ); } return builder.build(true); } @@ -1819,7 +1841,8 @@ public static class Builder { private IndexMetadataStats stats = null; private Double indexWriteLoadForecast = null; private Long shardSizeInBytesForecast = null; - private Map> inferenceModelsForFields = Map.of(); + + private Map> inferenceModelsForFields = Map.of(); public Builder(String index) { this.index = index; @@ -1933,7 +1956,7 @@ public Builder putMapping(String source) { public Builder putMapping(MappingMetadata mappingMd) { mapping = mappingMd; - Map> fieldsForModels = mappingMd.getFieldsForModels(); + Map> fieldsForModels = mappingMd.getFieldsForModels(); if (fieldsForModels != null) { inferenceModelsForFields = fieldsForModels; } @@ -2085,11 +2108,16 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) { return this; } - public Builder inferenceModelsForfields(Map> inferenceModelsForfields) { + public Builder inferenceModelsForfields(Map> inferenceModelsForfields) { this.inferenceModelsForFields = inferenceModelsForfields; return this; } + public Builder inferenceModelsForFields(Map> inferenceModelsForFields) { + this.inferenceModelsForFields = inferenceModelsForFields; + return this; + } + public IndexMetadata build() { return build(false); } @@ -2411,7 +2439,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast); } - Map> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields(); + Map> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields(); if ((inferenceModelsForFields != null) && (inferenceModelsForFields.isEmpty() == false)) { builder.field(INFERENCE_MODELS_FIELDS, indexMetadata.getInferenceModelsForFields()); } @@ -2494,10 +2522,15 @@ public static IndexMetadata fromXContent(XContentParser parser, Map> inferenceModels = parser.map(HashMap::new, XContentParser::list) + Map> inferenceModels = parser.map(HashMap::new, XContentParser::list) .entrySet() .stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().stream().map(Object::toString).toList())); + .collect( + Collectors.toMap( + Map.Entry::getKey, + e -> e.getValue().stream().map(Object::toString).collect(Collectors.toSet()) + ) + ); builder.inferenceModelsForfields(inferenceModels); break; default: diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index a5695b61104d0..bd1ceb36862cf 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -21,9 +21,9 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -43,7 +43,7 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; - private final Map> fieldsForModels; + private final Map> fieldsForModels; public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); @@ -127,7 +127,7 @@ public CompressedXContent source() { return this.source; } - public Map> getFieldsForModels() { + public Map> getFieldsForModels() { return fieldsForModels; } @@ -205,7 +205,7 @@ public MappingMetadata(StreamInput in) throws IOException { source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD)) { - fieldsForModels = in.readMapOfLists(StreamInput::readString); + fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); } else { fieldsForModels = Map.of(); } diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 664fd2aec0f21..83a6d9319c75a 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -523,15 +523,6 @@ public Iterator> settings() { Property.ServerlessPublic ); - public static final Setting INFERENCE_PIPELINE = new Setting<>( - "index.inference_pipeline", - IngestService.NOOP_PIPELINE_NAME, - Function.identity(), - Property.PrivateIndex, - Property.IndexScope, - Property.ServerlessPublic - ); - /** * Marks an index to be searched throttled. This means that never more than one shard of such an index will be searched concurrently */ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 25c20e9797177..81ac639703694 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -10,12 +10,10 @@ import org.elasticsearch.common.regex.Regex; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -38,7 +36,7 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; - private final Map> fieldsForModel; + private final Map> fieldsForModel; private final int maxParentPathDots; @@ -52,7 +50,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); - final Map> fieldsForModel = new HashMap<>(); + final Map> fieldsForModel = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -71,7 +69,7 @@ final class FieldTypeLookup { fieldToCopiedFields.get(targetField).add(fieldName); } if (fieldType.hasInferenceModel()) { - Collection fields = fieldsForModel.computeIfAbsent(fieldType.getInferenceModel(), v -> new ArrayList<>()); + Collection fields = fieldsForModel.computeIfAbsent(fieldType.getInferenceModel(), v -> new HashSet<>()); fields.add(fieldName); } } @@ -119,11 +117,11 @@ public static int dotCount(String path) { return dotCount; } - List fieldsForModel(String modelName) { + Set fieldsForModel(String modelName) { return this.fieldsForModel.get(modelName); } - Map> fieldsForModel() { + Map> fieldsForModel() { return this.fieldsForModel; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index ee389340e5d33..f483c47e200f6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -492,11 +492,7 @@ public void validateDoesNotShadow(String name) { } } - public List fieldsForModel(String modelName) { - return fieldTypeLookup.fieldsForModel(modelName); - } - - public Map> fieldsForModels() { + public Map> fieldsForModels() { return fieldTypeLookup.fieldsForModel(); } } diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java index c584eb444bc6a..d191a2ba0852e 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java @@ -1429,7 +1429,7 @@ private static Optional resolvePipelinesFromIndexTemplates(IndexReque defaultPipeline = Objects.requireNonNullElse(defaultPipeline, NOOP_PIPELINE_NAME); finalPipeline = Objects.requireNonNullElse(finalPipeline, NOOP_PIPELINE_NAME); - return Optional.of(new Pipelines(defaultPipeline, finalPipeline, null)); + return Optional.of(new Pipelines(defaultPipeline, finalPipeline, NOOP_PIPELINE_NAME)); } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index abb174816d8fb..9c31599bd7c4f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -465,6 +465,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; import java.util.function.UnaryOperator; @@ -2287,13 +2288,13 @@ public Optional getIngestPipeline(IndexMetadata indexMetadata, Process return Optional.empty(); } - Map> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields(); + Map> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields(); if (inferenceModelsForFields.isEmpty()) { return Optional.empty(); } Collection inferenceProcessors = new ArrayList<>(); - for (Map.Entry> modelsForFieldsEntry : inferenceModelsForFields.entrySet()) { + for (Map.Entry> modelsForFieldsEntry : inferenceModelsForFields.entrySet()) { Map inferenceConfig = new HashMap<>(); String modelId = modelsForFieldsEntry.getKey(); inferenceConfig.put("model_id", modelId); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java index aea92c83c2a4b..f34846bb5b62c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.BiConsumer; public class SemanticTextInferenceProcessor extends AbstractProcessor implements WrappingProcessor { @@ -25,7 +26,7 @@ public class SemanticTextInferenceProcessor extends AbstractProcessor implements public static final String TYPE = "semanticTextInference"; public static final String TAG = "semantic_text"; - private final Map> fieldsForModels; + private final Map> fieldsForModels; private final Processor wrappedProcessor; @@ -36,7 +37,7 @@ public SemanticTextInferenceProcessor( Client client, InferenceAuditor inferenceAuditor, String description, - Map> fieldsForModels + Map> fieldsForModels ) { super(TAG, description); this.client = client; @@ -54,7 +55,7 @@ private Processor createWrappedProcessor() { return new CompoundProcessor(inferenceProcessors); } - private InferenceProcessor createInferenceProcessor(String modelId, List fields) { + private InferenceProcessor createInferenceProcessor(String modelId, Set fields) { List inputConfigs = fields.stream() .map(f -> new InferenceProcessor.Factory.InputConfig(f, "ml.inference", f, Map.of())) .toList();