diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java new file mode 100644 index 0000000000000..490d7f36219cf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceModelFieldType.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +/** + * Field type that uses an inference model. + */ +public interface InferenceModelFieldType { + /** + * Retrieve inference model used by the field type. + * + * @return model id used by the field type + */ + String getInferenceModel(); +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java index 7747461a6f93a..a383004c12878 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java @@ -448,7 +448,10 @@ public void onIndexModule(IndexModule indexModule) { @Override public Function> getFieldFilter() { - List>> items = filterPlugins(MapperPlugin.class).stream().map(p -> p.getFieldFilter()).toList(); + List>> items = filterPlugins(MapperPlugin.class).stream() + .map(p -> p.getFieldFilter()) + .filter(p -> p.equals(NOOP_FIELD_FILTER) == false) + .toList(); if (items.size() > 1) { throw new UnsupportedOperationException("Only the security MapperPlugin should override this"); } else if (items.size() == 1) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 749a31de51b07..1031d45facf85 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -46,6 +46,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -65,6 +66,7 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; @@ -359,6 +361,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; +import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -476,7 +479,8 @@ public class MachineLearning extends Plugin PersistentTaskPlugin, SearchPlugin, ShutdownAwarePlugin, - ExtensiblePlugin { + ExtensiblePlugin, + MapperPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; // Endpoints that were deprecated in 7.x can still be called in 8.x using the REST compatibility layer @@ -2288,4 +2292,12 @@ public void signalShutdown(Collection shutdownNodeIds) { mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds); } } + + @Override + public Map getMappers() { + if (SemanticTextFeature.isEnabled()) { + return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + } + return Map.of(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java new file mode 100644 index 0000000000000..f861760803e56 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/SemanticTextFeature.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * semantic_text feature flag. When the feature is complete, this flag will be removed. + */ +public class SemanticTextFeature { + + private SemanticTextFeature() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("semantic_text"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java new file mode 100644 index 0000000000000..cf713546a071a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.fielddata.FieldDataContext; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.InferenceModelFieldType; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.SimpleMappedFieldType; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; + +import java.io.IOException; +import java.util.Map; + +/** + * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference + * at ingestion and query time. + * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. + * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will + * be indexed using a different field mapper. + */ +public class SemanticTextFieldMapper extends FieldMapper { + + public static final String CONTENT_TYPE = "semantic_text"; + + private static SemanticTextFieldMapper toType(FieldMapper in) { + return (SemanticTextFieldMapper) in; + } + + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + + private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(simpleName()).init(this); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + // Just parses text - no indexing is performed + context.parser().textOrNull(); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextFieldType fieldType() { + return (SemanticTextFieldType) super.fieldType(); + } + + public static class Builder extends FieldMapper.Builder { + + private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) + .addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [model_id] must be specified"); + } + }); + + private final Parameter> meta = Parameter.metaParam(); + + public Builder(String name) { + super(name); + } + + @Override + protected Parameter[] getParameters() { + return new Parameter[] { modelId, meta }; + } + + @Override + public SemanticTextFieldMapper build(MapperBuilderContext context) { + return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + } + } + + public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + + private final String modelId; + + public SemanticTextFieldType(String name, String modelId, Map meta) { + super(name, false, false, false, TextSearchInfo.NONE, meta); + this.modelId = modelId; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public String getInferenceModel() { + return modelId; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException("termQuery not implemented yet"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.toString(name(), context, format); + } + + @Override + public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { + throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java new file mode 100644 index 0000000000000..ccb8f106e4945 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.index.IndexableField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperParsingException; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.junit.AssumptionViolatedException; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.hamcrest.Matchers.containsString; + +public class SemanticTextFieldMapperTests extends MapperTestCase { + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testModelIdNotPresent() throws IOException { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), containsString("field [model_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToModelIdNotSupported() throws IOException { + MapperService mapperService = createMapperService( + fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) + ); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model"))) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); + } + + @Override + protected Collection getPlugins() { + return singletonList(new MachineLearning(Settings.EMPTY)); + } + + @Override + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "semantic_text").field("model_id", "test_model"); + } + + @Override + protected Object getSampleValueForDocument() { + return "value"; + } + + @Override + protected boolean supportsIgnoreMalformed() { + return false; + } + + @Override + protected boolean supportsStoredFields() { + return false; + } + + @Override + protected void registerParameters(ParameterChecker checker) throws IOException {} + + @Override + protected Object generateRandomInputValue(MappedFieldType ft) { + assumeFalse("doc_values are not supported in semantic_text", true); + return null; + } + + @Override + protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { + throw new AssumptionViolatedException("not supported"); + } + + @Override + protected IngestScriptSupport ingestScriptSupport() { + throw new AssumptionViolatedException("not supported"); + } +}