diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java
index 4e28483d24200..ed41ddbfd86e9 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java
@@ -275,7 +275,7 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String executorName, ActionListener<BulkResponse> listener) {
                 IndexRequest indexRequest = getIndexWriteRequest(actionRequest);
                 if (indexRequest != null) {
                     ingestService.resolvePipelinesAndUpdateIndexRequest(actionRequest, indexRequest, metadata);
-                    hasIndexRequestsWithPipelines |= IngestService.hasPipeline(indexRequest);
+                    hasIndexRequestsWithPipelines |= ingestService.hasPipeline(indexRequest);
                 }

                 if (actionRequest instanceof IndexRequest ir) {
diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java
index d883049816368..335d5e5666792 100644
--- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java
+++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java
@@ -102,7 +102,7 @@ public class IndexRequest extends ReplicatedWriteRequest<IndexRequest> implements DocWriteRequest<IndexRequest>, CompositeIndicesRequest {

     private String pipeline;
     private String finalPipeline;
-    private String inferencePipeline;
+    private String pluginsPipeline;

     private boolean isPipelineResolved;

@@ -191,7 +191,7 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOException {
             }
         }
         if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD)) {
-            this.inferencePipeline = in.readOptionalString();
+            this.pluginsPipeline = in.readOptionalString();
         }
     }

@@ -274,10 +274,6 @@ public ActionRequestValidationException validate() {
             validationException = addValidationError("final pipeline cannot be an empty string", validationException);
         }

-        if (inferencePipeline != null && inferencePipeline.isEmpty()) {
-            validationException = addValidationError("inference pipeline cannot be an empty string", validationException);
-        }
-
         return validationException;
     }

@@ -363,12 +359,13 @@ public String getFinalPipeline() {
         return this.finalPipeline;
     }

-    public String getInferencePipeline() {
-        return inferencePipeline;
+    public String getPluginsPipeline() {
+        return pluginsPipeline;
     }

-    public void setInferencePipeline(String inferencePipeline) {
-        this.inferencePipeline = inferencePipeline;
+    public IndexRequest setPluginsPipeline(String pluginsPipeline) {
+        this.pluginsPipeline = pluginsPipeline;
+        return this;
     }

     /**
@@ -751,7 +748,7 @@ private void writeBody(StreamOutput out) throws IOException {
             }
         }
         if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD)) {
-            out.writeOptionalString(inferencePipeline);
+            out.writeOptionalString(pluginsPipeline);
         }
     }
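For context, a minimal sketch of how the renamed slot flows through bulk handling; `ingestService`, `request`, and `metadata` are assumed to be in scope, and the sentinel is the service's `NOOP_PIPELINE_NAME` ("_none"):

    // After resolution every slot is non-null; absent pipelines hold the NOOP sentinel,
    // so the (now instance-level) hasPipeline only needs three sentinel comparisons.
    ingestService.resolvePipelinesAndUpdateIndexRequest(request, request, metadata);
    boolean runsIngest = ingestService.hasPipeline(request); // true if any slot != "_none"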
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
index 742b52365c8d7..5b62c3eb17adf 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
@@ -80,6 +80,7 @@
 import java.util.Set;
 import java.util.function.Function;

+import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_FIELD;
 import static org.elasticsearch.cluster.metadata.Metadata.CONTEXT_MODE_PARAM;
 import static org.elasticsearch.cluster.metadata.Metadata.DEDUPLICATED_MAPPINGS_PARAM;
 import static org.elasticsearch.cluster.node.DiscoveryNodeFilters.OpType.AND;
@@ -635,6 +636,7 @@ public Iterator<Setting<?>> settings() {
     private final Double writeLoadForecast;
     @Nullable
     private final Long shardSizeInBytesForecast;
+    private final Map<String, List<String>> inferenceModelsForFields;

     private IndexMetadata(
         final Index index,
@@ -680,7 +682,8 @@ private IndexMetadata(
         final IndexVersion indexCompatibilityVersion,
         @Nullable final IndexMetadataStats stats,
         @Nullable final Double writeLoadForecast,
-        @Nullable Long shardSizeInBytesForecast
+        @Nullable Long shardSizeInBytesForecast,
+        final Map<String, List<String>> inferenceModelsForFields
     ) {
         this.index = index;
         this.version = version;
@@ -736,6 +739,8 @@ private IndexMetadata(
         this.writeLoadForecast = writeLoadForecast;
         this.shardSizeInBytesForecast = shardSizeInBytesForecast;
         assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards;
+        this.inferenceModelsForFields = inferenceModelsForFields;
+        assert inferenceModelsForFields != null;
     }

     IndexMetadata withMappingMetadata(MappingMetadata mapping) {
@@ -786,7 +791,8 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) {
             this.indexCompatibilityVersion,
             this.stats,
             this.writeLoadForecast,
-            this.shardSizeInBytesForecast
+            this.shardSizeInBytesForecast,
+            this.inferenceModelsForFields
         );
     }

@@ -844,7 +850,8 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set<String> inSyncSet) {
             this.indexCompatibilityVersion,
             this.stats,
             this.writeLoadForecast,
-            this.shardSizeInBytesForecast
+            this.shardSizeInBytesForecast,
+            this.inferenceModelsForFields
         );
     }

@@ -900,7 +907,8 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) {
             this.indexCompatibilityVersion,
             this.stats,
             this.writeLoadForecast,
-            this.shardSizeInBytesForecast
+            this.shardSizeInBytesForecast,
+            this.inferenceModelsForFields
         );
     }

@@ -956,7 +964,8 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) {
             this.indexCompatibilityVersion,
             this.stats,
             this.writeLoadForecast,
-            this.shardSizeInBytesForecast
+            this.shardSizeInBytesForecast,
+            this.inferenceModelsForFields
         );
     }

@@ -1008,7 +1017,8 @@ public IndexMetadata withIncrementedVersion() {
             this.indexCompatibilityVersion,
             this.stats,
             this.writeLoadForecast,
-            this.shardSizeInBytesForecast
+            this.shardSizeInBytesForecast,
+            this.inferenceModelsForFields
         );
     }

@@ -1212,6 +1222,10 @@ public OptionalLong getForecastedShardSizeInBytes() {
         return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast);
     }

+    public Map<String, List<String>> getInferenceModelsForFields() {
+        return inferenceModelsForFields;
+    }
+
     public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid";
     public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name";
     public static final Setting<String> INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY);
@@ -1702,6 +1716,9 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function<String, MappingMetadata> mappingLookup) throws IOException {
+        private Map<String, List<String>> inferenceModelsForFields = Map.of();

         public Builder(String index) {
             this.index = index;
@@ -1828,6 +1849,7 @@ public Builder(IndexMetadata indexMetadata) {
             this.stats = indexMetadata.stats;
             this.indexWriteLoadForecast = indexMetadata.writeLoadForecast;
             this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast;
+            this.inferenceModelsForFields = indexMetadata.inferenceModelsForFields;
         }

         public Builder index(String index) {
@@ -2057,6 +2079,11 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) {
             return this;
         }

+        public Builder inferenceModelsForFields(Map<String, List<String>> inferenceModelsForFields) {
+            this.inferenceModelsForFields = inferenceModelsForFields;
+            return this;
+        }
+
         public IndexMetadata build() {
             return build(false);
         }

@@ -2251,7 +2278,8 @@ IndexMetadata build(boolean repair) {
             SETTING_INDEX_VERSION_COMPATIBILITY.get(settings),
             stats,
             indexWriteLoadForecast,
-            shardSizeInBytesForecast
+            shardSizeInBytesForecast,
+            inferenceModelsForFields
         );
     }
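To make the new metadata concrete, a hypothetical example: if an index maps `title` and `body` to the same model, the per-index view exposed here would look like:

    // Illustrative values only; keys are model IDs, values the fields they serve.
    Map<String, List<String>> modelsToFields = indexMetadata.getInferenceModelsForFields();
    // e.g. {".elser_model_2" -> ["title", "body"]}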
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
index da24f0b9d0dc5..0f0773f538163 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
@@ -1267,6 +1267,7 @@ static IndexMetadata buildIndexMetadata(
         if (mapper != null) {
             MappingMetadata mappingMd = new MappingMetadata(mapper);
             mappingsMetadata.put(mapper.type(), mappingMd);
+            indexMetadataBuilder.inferenceModelsForFields(mapper.mappers().fieldsForModels());
         }

         for (MappingMetadata mappingMd : mappingsMetadata.values()) {
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java
index 7a2d20d042f84..b04a30d68387d 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java
@@ -199,6 +199,7 @@ private static ClusterState applyRequest(
         DocumentMapper mapper = mapperService.documentMapper();
         if (mapper != null) {
             indexMetadataBuilder.putMapping(new MappingMetadata(mapper));
+            indexMetadataBuilder.inferenceModelsForFields(mapper.mappers().fieldsForModels());
         }
         if (updatedMapping) {
             indexMetadataBuilder.mappingVersion(1 + indexMetadataBuilder.mappingVersion());
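Taken together, these two call sites keep `inferenceModelsForFields` in step with the actual mappings: the create-index path seeds it when the index is first built, and the put-mapping path refreshes it on every mapping update. `IngestService` (below) then only has to diff `IndexMetadata` between cluster states to know when plugin pipelines need rebuilding.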
diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java
index a91be798b5ae0..f6970b01c017c 100644
--- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java
+++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java
@@ -41,7 +41,6 @@
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
 import org.elasticsearch.common.Priority;
-import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.settings.Settings;
@@ -55,8 +54,6 @@
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.gateway.GatewayService;
 import org.elasticsearch.grok.MatcherWatchdog;
-import org.elasticsearch.index.Index;
-import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.VersionType;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -68,7 +65,6 @@
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentType;

 import java.io.IOException;
 import java.util.ArrayList;
@@ -113,20 +109,19 @@ public class IngestService implements ClusterStateApplier, ReportingService<IngestInfo> {
     private final Map<String, Processor.Factory> processorFactories;
     private final IndexNameExpressionResolver indexNameExpressionResolver;
     private final IndicesService indicesService;
+    private final List<IngestPlugin> ingestPlugins;
     // Ideally this should be in IngestMetadata class, but we don't have the processor factories around there.
     // We know of all the processor factories when a node with all its plugin have been initialized. Also some
     // processor factories rely on other node services. Custom metadata is statically registered when classes
     // are loaded, so in the cluster state we just save the pipeline config and here we keep the actual pipelines around.
     private volatile Map<String, PipelineHolder> pipelines = Map.of();
-
-    private volatile Map<String, PipelineHolder> inferencePipelines = new HashMap<>();
-
-    private volatile Map<Index, String> inferencePipelinesForIndices = new HashMap<>();
+    private volatile Map<String, List<Pipeline>> pluginPipelines = Map.of();
     private final ThreadPool threadPool;
     private final IngestMetric totalMetrics = new IngestMetric();
     private final List<Consumer<ClusterState>> ingestClusterStateListeners = new CopyOnWriteArrayList<>();
     private volatile ClusterState state;
+    private final Processor.Parameters processorParameters;

     private static BiFunction<Long, Runnable, Scheduler.ScheduledCancellable> createScheduler(ThreadPool threadPool) {
         return (delay, command) -> threadPool.schedule(command, TimeValue.timeValueMillis(delay), threadPool.generic());
@@ -200,20 +195,23 @@ public IngestService(
         this.clusterService = clusterService;
         this.scriptService = scriptService;
         this.documentParsingObserverSupplier = documentParsingObserverSupplier;
+        this.ingestPlugins = ingestPlugins;
+        this.processorParameters = new Processor.Parameters(
+            env,
+            scriptService,
+            analysisRegistry,
+            threadPool.getThreadContext(),
+            threadPool::relativeTimeInMillis,
+            createScheduler(threadPool),
+            this,
+            client,
+            threadPool.generic()::execute,
+            matcherWatchdog
+        );
+
         this.processorFactories = processorFactories(
             ingestPlugins,
-            new Processor.Parameters(
-                env,
-                scriptService,
-                analysisRegistry,
-                threadPool.getThreadContext(),
-                threadPool::relativeTimeInMillis,
-                createScheduler(threadPool),
-                this,
-                client,
-                threadPool.generic()::execute,
-                matcherWatchdog
-            )
+            processorParameters
         );
         this.threadPool = threadPool;
         this.taskQueue = clusterService.createTaskQueue("ingest-pipelines", Priority.NORMAL, PIPELINE_TASK_EXECUTOR);
@@ -279,7 +277,7 @@ void resolvePipelinesAndUpdateIndexRequest(
             indexRequest.setPipeline(pipelines.defaultPipeline);
         }
         indexRequest.setFinalPipeline(pipelines.finalPipeline);
-        indexRequest.setInferencePipeline(pipelines.inferencePipeline);
+        indexRequest.setPluginsPipeline(pipelines.pluginsPipeline);
         indexRequest.isPipelineResolved(true);
     }

@@ -494,17 +492,12 @@ public Pipeline getPipeline(String id) {
         }
     }

-    private Pipeline getInferencePipeline(String id) {
+    private List<Pipeline> getPluginsPipelines(String id) {
         if (id == null) {
             return null;
         }
-        PipelineHolder holder = inferencePipelines.get(id);
-        if (holder != null) {
-            return holder.pipeline;
-        } else {
-            return null;
-        }
+        return this.pluginPipelines.get(id);
     }

     public Map<String, Processor.Factory> getProcessorFactories() {
@@ -771,10 +764,10 @@ private PipelineIterator getAndResetPipelines(IndexRequest indexRequest) {
         indexRequest.setPipeline(NOOP_PIPELINE_NAME);
         final String finalPipelineId = indexRequest.getFinalPipeline();
         indexRequest.setFinalPipeline(NOOP_PIPELINE_NAME);
-        final String inferencePipelineId = indexRequest.getInferencePipeline();
-        indexRequest.setInferencePipeline(NOOP_PIPELINE_NAME);
+        final String pluginPipelineId = indexRequest.getPluginsPipeline();
+        indexRequest.setPluginsPipeline(NOOP_PIPELINE_NAME);

-        return new PipelineIterator(pipelineId, finalPipelineId, inferencePipelineId);
+        return new PipelineIterator(pipelineId, finalPipelineId, pluginPipelines.get(pluginPipelineId));
     }

     /**
@@ -794,19 +787,19 @@ private class PipelineIterator implements Iterator<PipelineSlot> {

         private final String defaultPipeline;
         private final String finalPipeline;
-        private final String inferencePipeline;
+        private final List<Pipeline> pluginsPipelines;
         private final Iterator<PipelineSlot> pipelineSlotIterator;

-        private PipelineIterator(String defaultPipeline, String finalPipeline, String inferencePipeline) {
+        private PipelineIterator(String defaultPipeline, String finalPipeline, List<Pipeline> pluginsPipelines) {
             this.defaultPipeline = NOOP_PIPELINE_NAME.equals(defaultPipeline) ? null : defaultPipeline;
             this.finalPipeline = NOOP_PIPELINE_NAME.equals(finalPipeline) ? null : finalPipeline;
-            this.inferencePipeline = NOOP_PIPELINE_NAME.equals(inferencePipeline) ? null : inferencePipeline;
+            this.pluginsPipelines = pluginsPipelines;
             this.pipelineSlotIterator = iterator();
         }

         public PipelineIterator withoutDefaultPipeline() {
-            return new PipelineIterator(null, finalPipeline, inferencePipeline);
+            return new PipelineIterator(null, finalPipeline, pluginsPipelines);
         }

         private Iterator<PipelineSlot> iterator() {
@@ -817,8 +810,8 @@ private Iterator<PipelineSlot> iterator() {
             if (finalPipeline != null) {
                 slotList.add(new PipelineSlot(finalPipeline, getPipeline(finalPipeline), true));
             }
-            if (inferencePipeline != null) {
-                slotList.add(new PipelineSlot("inference", getInferencePipeline(inferencePipeline), false));
+            if (pluginsPipelines != null) {
+                slotList.addAll(pluginsPipelines.stream().map(pipeline -> new PipelineSlot("plugins", pipeline, false)).toList());
             }
             return slotList.iterator();
         }
@@ -1120,18 +1113,41 @@ public void applyClusterState(final ClusterChangedEvent event) {
         ingestClusterStateListeners.forEach(consumer -> consumer.accept(state));

         IngestMetadata newIngestMetadata = state.getMetadata().custom(IngestMetadata.TYPE);
-        if (newIngestMetadata == null) {
-            return;
+        if (newIngestMetadata != null) {
+            try {
+                innerUpdatePipelines(newIngestMetadata, event.state().metadata());
+            } catch (ElasticsearchParseException e) {
+                logger.warn("failed to update ingest pipelines", e);
+            }
         }

-        try {
-            innerUpdatePipelines(newIngestMetadata, event.state().metadata());
-        } catch (ElasticsearchParseException e) {
-            logger.warn("failed to update ingest pipelines", e);
-        }
+        updatePluginPipelines(event);
+    }

-        inferencePipelines.clear();
-        inferencePipelinesForIndices.clear();
+    private synchronized void updatePluginPipelines(ClusterChangedEvent event) {
+        // TODO Update just for changed indices
+        Map<String, IndexMetadata> currentIndexMetadata = state.metadata().indices();
+        Map<String, IndexMetadata> previousIndexMetadata = event.previousState().metadata().indices();
+        if (currentIndexMetadata.equals(previousIndexMetadata) == false) {
+            Map<String, List<Pipeline>> updatedPluginPipelines = new HashMap<>();
+            HashSet<String> indicesNames = new HashSet<>(currentIndexMetadata.keySet());
+            indicesNames.addAll(previousIndexMetadata.keySet());
+            for (String indexName : indicesNames) {
+                List<Pipeline> pipelineList = ingestPlugins.stream()
+                    .map(
+                        plugin -> plugin.getIngestPipeline(
+                            currentIndexMetadata.get(indexName),
+                            previousIndexMetadata.get(indexName),
+                            processorParameters
+                        )
+                    )
+                    .flatMap(Optional::stream)
+                    .toList();
+                if (pipelineList.isEmpty() == false) {
+                    updatedPluginPipelines.put(indexName, pipelineList);
+                }
+            }
+            pluginPipelines = Map.copyOf(updatedPluginPipelines);
+        }
     }

     synchronized void innerUpdatePipelines(IngestMetadata newIngestMetadata, Metadata metadata) {
@@ -1324,8 +1340,9 @@ private Optional<Pipelines> resolvePipelinesFromMetadata(
                 .get(IndexNameExpressionResolver.resolveDateMathExpression(originalRequest.index(), epochMillis));
         }
         // check the alias for the index request (this is how normal index requests are modeled)
-        if (indexMetadata == null && indexRequest.index() != null) {
-            IndexAbstraction indexAbstraction = metadata.getIndicesLookup().get(indexRequest.index());
+        String indexName = indexRequest.index();
+        if (indexMetadata == null && indexName != null) {
+            IndexAbstraction indexAbstraction = metadata.getIndicesLookup().get(indexName);
             if (indexAbstraction != null && indexAbstraction.getWriteIndex() != null) {
                 indexMetadata = metadata.index(indexAbstraction.getWriteIndex());
             }
@@ -1343,78 +1360,16 @@ private Optional<Pipelines> resolvePipelinesFromMetadata(
         }

         final Settings settings = indexMetadata.getSettings();
+        List<Pipeline> pluginsPipelines = getPluginsPipelines(indexName);
         return Optional.of(
             new Pipelines(
                 IndexSettings.DEFAULT_PIPELINE.get(settings),
                 IndexSettings.FINAL_PIPELINE.get(settings),
-                getOrCreateInferencePipeline(indexMetadata)
+                pluginsPipelines == null ? NOOP_PIPELINE_NAME : indexName
             )
         );
     }

-    private String getOrCreateInferencePipeline(IndexMetadata indexMetadata) {
-        Index index = indexMetadata.getIndex();
-        String inferencePipelineName = inferencePipelinesForIndices.get(index);
-        if (inferencePipelineName != null) {
-            return inferencePipelineName;
-        }
-
-        IndexService indexService = indicesService.indexService(index);
-        Map<String, List<String>> fieldsForModels = null;
-//        if (indexService != null) {
-//            fieldsForModels = indexService.mapperService().mappingLookup().fieldsForModels();
-//        } else {
-        fieldsForModels = indexMetadata.mapping().getFieldsForModels();
-//        }
-        if (fieldsForModels.isEmpty()) {
-            inferencePipelineName = NOOP_PIPELINE_NAME;
-            inferencePipelinesForIndices.put(index, inferencePipelineName);
-            return inferencePipelineName;
-        }
-
-        Collection<Processor> inferenceProcessors = new ArrayList<>();
-        for (Map.Entry<String, List<String>> modelsForFieldsEntry : fieldsForModels.entrySet()) {
-            Map<String, Object> inferenceConfig = new HashMap<>();
-            inferenceConfig.put("model_id", modelsForFieldsEntry.getKey());
-            Collection<Map<String, Object>> inputOutputConfigs = new ArrayList<>();
-            for (String field : modelsForFieldsEntry.getValue()) {
-                Map<String, Object> params = new HashMap<>();
-                params.put("input_field", field);
-                params.put("output_field", "ml.inference." + field);
-                inputOutputConfigs.add(params);
-            }
-            inferenceConfig.put("input_output", inputOutputConfigs);
-
-            try {
-                inferenceProcessors.add(processorFactories.get("inference").create(processorFactories, null, null, inferenceConfig));
-            } catch (Exception e) {
-                logger.error(
-                    "Cannot create inference processor for model ["
-                        + modelsForFieldsEntry.getKey()
-                        + "] with fields "
-                        + modelsForFieldsEntry.getValue(),
-                    e
-                );
-            }
-        }
-
-        inferencePipelineName = "_inference_" + index.getName();
-        Pipeline inferencePipeline = new Pipeline(
-            inferencePipelineName,
-            null,
-            null,
-            null,
-            new CompoundProcessor(inferenceProcessors.toArray(new Processor[] {}))
-        );
-        inferencePipelines.put(
-            inferencePipelineName,
-            new PipelineHolder(new PipelineConfiguration(inferencePipelineName, new BytesArray(""), XContentType.JSON), inferencePipeline)
-        );
-        inferencePipelinesForIndices.put(index, inferencePipelineName);
-
-        return inferencePipelineName;
-    }
-
     private static Optional<Pipelines> resolvePipelinesFromIndexTemplates(IndexRequest indexRequest, Metadata metadata) {
         if (indexRequest.index() == null) {
             return Optional.empty();
@@ -1466,24 +1421,24 @@
     /**
      *
     
 * This method assumes that the pipelines have been resolved beforehand.
      */
-    public static boolean hasPipeline(IndexRequest indexRequest) {
+    public boolean hasPipeline(IndexRequest indexRequest) {
         assert indexRequest.isPipelineResolved();
         assert indexRequest.getPipeline() != null;
         assert indexRequest.getFinalPipeline() != null;
-        assert indexRequest.getInferencePipeline() != null;
+        assert indexRequest.getPluginsPipeline() != null;
         return NOOP_PIPELINE_NAME.equals(indexRequest.getPipeline()) == false
             || NOOP_PIPELINE_NAME.equals(indexRequest.getFinalPipeline()) == false
-            || NOOP_PIPELINE_NAME.equals(indexRequest.getInferencePipeline()) == false;
+            || NOOP_PIPELINE_NAME.equals(indexRequest.getPluginsPipeline()) == false;
     }

-    private record Pipelines(String defaultPipeline, String finalPipeline, String inferencePipeline) {
+    private record Pipelines(String defaultPipeline, String finalPipeline, String pluginsPipeline) {

         private static final Pipelines NO_PIPELINES_DEFINED = new Pipelines(NOOP_PIPELINE_NAME, NOOP_PIPELINE_NAME, NOOP_PIPELINE_NAME);

         public Pipelines {
             Objects.requireNonNull(defaultPipeline);
             Objects.requireNonNull(finalPipeline);
-            Objects.requireNonNull(inferencePipeline);
+            Objects.requireNonNull(pluginsPipeline);
         }
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/plugins/IngestPlugin.java b/server/src/main/java/org/elasticsearch/plugins/IngestPlugin.java
index 75b58908ef599..9100fbf6f59db 100644
--- a/server/src/main/java/org/elasticsearch/plugins/IngestPlugin.java
+++ b/server/src/main/java/org/elasticsearch/plugins/IngestPlugin.java
@@ -8,9 +8,12 @@

 package org.elasticsearch.plugins;

+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.ingest.Pipeline;
 import org.elasticsearch.ingest.Processor;

 import java.util.Map;
+import java.util.Optional;

 /**
  * An extension point for {@link Plugin} implementations to add custom ingest processors
@@ -27,4 +30,8 @@ public interface IngestPlugin {
     default Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
         return Map.of();
     }
+
+    default Optional<Pipeline> getIngestPipeline(IndexMetadata current, IndexMetadata previous, Processor.Parameters parameters) {
+        return Optional.empty();
+    }
 }
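A minimal sketch of a plugin using the new extension point (hypothetical class; the pipeline contents are placeholders):

    public class MyIngestPlugin extends Plugin implements IngestPlugin {
        @Override
        public Optional<Pipeline> getIngestPipeline(IndexMetadata current, IndexMetadata previous, Processor.Parameters parameters) {
            if (current == null) {
                return Optional.empty(); // index was deleted; contribute nothing
            }
            String id = "my_pipeline_" + current.getIndex().getName();
            return Optional.of(new Pipeline(id, "per-index pipeline", null, null, new CompoundProcessor()));
        }
    }

Returning `Optional.empty()` keeps the index out of the plugin-pipelines map, so requests for that index resolve to the NOOP sentinel.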
diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
index f767a36ec0dc1..9b6edd5f7ced4 100644
--- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
+++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
@@ -1824,7 +1824,8 @@ protected void assertSnapshotOrGenericThread() {
                 actionFilters
             )
         );
-        final MetadataMappingService metadataMappingService = new MetadataMappingService(clusterService, indicesService);
+
+        final MetadataMappingService metadataMappingService = new MetadataMappingService(clusterService, indicesService, List.of());

         peerRecoverySourceService = new PeerRecoverySourceService(
             transportService,
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 18419a5094e92..e7e47d5b86584 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -22,6 +22,7 @@
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.NamedDiff;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
@@ -53,6 +54,8 @@
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
 import org.elasticsearch.indices.breaker.BreakerSettings;
+import org.elasticsearch.ingest.CompoundProcessor;
+import org.elasticsearch.ingest.Pipeline;
 import org.elasticsearch.ingest.Processor;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicenseUtils;
@@ -324,6 +327,7 @@
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
+import org.elasticsearch.xpack.ml.inference.ingest.SemanticTextInferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -457,8 +461,9 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
@@ -753,8 +759,8 @@ public void loadExtensions(ExtensionLoader loader) {
     private final SetOnce<MlAutoscalingDeciderService> mlAutoscalingDeciderService = new SetOnce<>();
     private final SetOnce<DeploymentManager> deploymentManager = new SetOnce<>();
     private final SetOnce<TrainedModelAssignmentClusterService> trainedModelAllocationClusterServiceSetOnce = new SetOnce<>();
-
     private final SetOnce<MachineLearningExtension> machineLearningExtension = new SetOnce<>();
+    private final SetOnce<InferenceAuditor> inferenceAuditorSetOnce = new SetOnce<>();

     public MachineLearning(Settings settings) {
         this.settings = settings;
@@ -923,7 +929,7 @@ public Collection<?> createComponents(PluginServices services) {
             clusterService,
             machineLearningExtension.get().includeNodeInfo()
         );
-        InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService, machineLearningExtension.get().includeNodeInfo());
+        inferenceAuditorSetOnce.set(new InferenceAuditor(client, clusterService, machineLearningExtension.get().includeNodeInfo()));
         SystemAuditor systemAuditor = new SystemAuditor(client, clusterService);

         this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor);
@@ -1095,7 +1101,7 @@ public Collection<?> createComponents(PluginServices services) {
         final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
         final ModelLoadingService modelLoadingService = new ModelLoadingService(
             trainedModelProvider,
-            inferenceAuditor,
+            inferenceAuditorSetOnce.get(),
             threadPool,
             clusterService,
             trainedModelStatsService,
@@ -1244,7 +1250,7 @@ public Collection<?> createComponents(PluginServices services) {
             datafeedManager,
             anomalyDetectionAuditor,
             dataFrameAnalyticsAuditor,
-            inferenceAuditor,
+            inferenceAuditorSetOnce.get(),
             systemAuditor,
             mlAssignmentNotifier,
             mlAutoUpdateService,
@@ -2273,4 +2279,46 @@ public void signalShutdown(Collection<String> shutdownNodeIds) {
     public Map<String, Mapper.TypeParser> getMappers() {
         return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
     }
+
+    @Override
+    public Optional<Pipeline> getIngestPipeline(IndexMetadata current, IndexMetadata previous, Processor.Parameters parameters) {
+        if (current == null) {
+            return Optional.empty();
+        }
+
+        Map<String, List<String>> inferenceModelsForFields = current.getInferenceModelsForFields();
+        if (inferenceModelsForFields.isEmpty()) {
+            return Optional.empty();
+        }
+
+        String indexName = current.getIndex().getName();
+        Processor inferenceProcessor;
+        try {
+            // A single processor covers every model: SemanticTextInferenceProcessor
+            // fans out into one InferenceProcessor per model internally.
+            inferenceProcessor = new SemanticTextInferenceProcessor(
+                parameters.client,
+                inferenceAuditorSetOnce.get(),
+                "semantic text processor for index " + indexName + ", models " + inferenceModelsForFields.keySet(),
+                inferenceModelsForFields
+            );
+        } catch (Exception e) {
+            logger.error(
+                "Cannot create inference processors for models " + inferenceModelsForFields.keySet() + " on index [" + indexName + "]",
+                e
+            );
+            return Optional.empty();
+        }
+
+        String inferencePipelineName = "_semantic_text_inference_" + indexName;
+        Pipeline inferencePipeline = new Pipeline(
+            inferencePipelineName,
+            "semantic text pipeline for index " + indexName,
+            null,
+            null,
+            new CompoundProcessor(inferenceProcessor)
+        );
+
+        return Optional.of(inferencePipeline);
+    }
 }
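For a sense of the result: with a hypothetical field `body` bound to model `.elser_model_2`, each input field is written back under the `ml.inference.` prefix (per the `InputConfig` wiring in the processor below), so an indexed document would end up shaped roughly like:

    // Illustrative only; the exact inference payload depends on the model type.
    // { "body": "...", "ml": { "inference": { "body": { /* text_expansion output */ } } } }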
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java
new file mode 100644
index 0000000000000..f795a727fea67
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.ingest;
+
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.ingest.AbstractProcessor;
+import org.elasticsearch.ingest.CompoundProcessor;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.ingest.Processor;
+import org.elasticsearch.ingest.WrappingProcessor;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
+
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiConsumer;
+
+public class SemanticTextInferenceProcessor extends AbstractProcessor implements WrappingProcessor {
+
+    public static final String TYPE = "semanticTextInference";
+    public static final String TAG = "semantic_text";
+
+    private final Map<String, List<String>> fieldsForModels;
+
+    private final Processor wrappedProcessor;
+
+    private final Client client;
+    private final InferenceAuditor inferenceAuditor;
+
+    public SemanticTextInferenceProcessor(
+        Client client,
+        InferenceAuditor inferenceAuditor,
+        String description,
+        Map<String, List<String>> fieldsForModels
+    ) {
+        super(TAG, description);
+        this.client = client;
+        this.inferenceAuditor = inferenceAuditor;
+        this.fieldsForModels = fieldsForModels;
+        this.wrappedProcessor = createWrappedProcessor();
+    }
+
+    private Processor createWrappedProcessor() {
+        // One InferenceProcessor per model, wrapped in a single CompoundProcessor.
+        InferenceProcessor[] inferenceProcessors = fieldsForModels.entrySet()
+            .stream()
+            .map(e -> createInferenceProcessor(e.getKey(), e.getValue()))
+            .toArray(InferenceProcessor[]::new);
+        return new CompoundProcessor(inferenceProcessors);
+    }
+
+    private InferenceProcessor createInferenceProcessor(String modelId, List<String> fields) {
+        List<InferenceProcessor.Factory.InputConfig> inputConfigs = fields.stream()
+            .map(f -> new InferenceProcessor.Factory.InputConfig(f, "ml.inference", f, Map.of()))
+            .toList();
+
+        return InferenceProcessor.fromInputFieldConfiguration(
+            client,
+            inferenceAuditor,
+            tag,
+            "inference processor for semantic text",
+            modelId,
+            TextExpansionConfigUpdate.EMPTY_UPDATE,
+            inputConfigs,
+            false
+        );
+    }
+
+    @Override
+    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
+        getInnerProcessor().execute(ingestDocument, handler);
+    }
+
+    @Override
+    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
+        return getInnerProcessor().execute(ingestDocument);
+    }
+
+    @Override
+    public Processor getInnerProcessor() {
+        return wrappedProcessor;
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+}
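A hypothetical construction, e.g. from a unit test, showing the fan-out (the client and auditor would be mocks; the model ID and fields are placeholders):

    SemanticTextInferenceProcessor processor = new SemanticTextInferenceProcessor(
        client,
        inferenceAuditor,
        "semantic text processor for index my-index",
        Map.of(".elser_model_2", List.of("title", "body"))
    );
    // Wraps one InferenceProcessor per model; async execution delegates to the compound.
    processor.execute(ingestDocument, (doc, e) -> { /* handle completion */ });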