Skip to content

Commit

Permalink
Semantic text - field mapper (elastic#102971)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest authored Dec 12, 2023
1 parent e1835c9 commit 8ee4721
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.index.mapper;

/**
* Field type that uses an inference model.
*/
public interface InferenceModelFieldType {
/**
* Retrieve inference model used by the field type.
*
* @return model id used by the field type
*/
String getInferenceModel();
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,10 @@ public void onIndexModule(IndexModule indexModule) {

@Override
public Function<String, Predicate<String>> getFieldFilter() {
List<Function<String, Predicate<String>>> items = filterPlugins(MapperPlugin.class).stream().map(p -> p.getFieldFilter()).toList();
List<Function<String, Predicate<String>>> items = filterPlugins(MapperPlugin.class).stream()
.map(p -> p.getFieldFilter())
.filter(p -> p.equals(NOOP_FIELD_FILTER) == false)
.toList();
if (items.size() > 1) {
throw new UnsupportedOperationException("Only the security MapperPlugin should override this");
} else if (items.size() == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.elasticsearch.env.Environment;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.indices.AssociatedIndexDescriptor;
import org.elasticsearch.indices.SystemIndexDescriptor;
Expand All @@ -65,6 +66,7 @@
import org.elasticsearch.plugins.CircuitBreakerPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
Expand Down Expand Up @@ -359,6 +361,7 @@
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down Expand Up @@ -476,7 +479,8 @@ public class MachineLearning extends Plugin
PersistentTaskPlugin,
SearchPlugin,
ShutdownAwarePlugin,
ExtensiblePlugin {
ExtensiblePlugin,
MapperPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
// Endpoints that were deprecated in 7.x can still be called in 8.x using the REST compatibility layer
Expand Down Expand Up @@ -2288,4 +2292,12 @@ public void signalShutdown(Collection<String> shutdownNodeIds) {
mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds);
}
}

@Override
public Map<String, Mapper.TypeParser> getMappers() {
if (SemanticTextFeature.isEnabled()) {
return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
}
return Map.of();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml;

import org.elasticsearch.common.util.FeatureFlag;

/**
* semantic_text feature flag. When the feature is complete, this flag will be removed.
*/
public class SemanticTextFeature {

private SemanticTextFeature() {}

private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("semantic_text");

public static boolean isEnabled() {
return FEATURE_FLAG.isEnabled();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.InferenceModelFieldType;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.query.SearchExecutionContext;

import java.io.IOException;
import java.util.Map;

/**
* A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference
* at ingestion and query time.
* For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well.
* This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will
* be indexed using a different field mapper.
*/
public class SemanticTextFieldMapper extends FieldMapper {

public static final String CONTENT_TYPE = "semantic_text";

private static SemanticTextFieldMapper toType(FieldMapper in) {
return (SemanticTextFieldMapper) in;
}

public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE));

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) {
super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
}

@Override
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName()).init(this);
}

@Override
protected void parseCreateField(DocumentParserContext context) throws IOException {
// Just parses text - no indexing is performed
context.parser().textOrNull();
}

@Override
protected String contentType() {
return CONTENT_TYPE;
}

@Override
public SemanticTextFieldType fieldType() {
return (SemanticTextFieldType) super.fieldType();
}

public static class Builder extends FieldMapper.Builder {

private final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null)
.addValidator(v -> {
if (Strings.isEmpty(v)) {
throw new IllegalArgumentException("field [model_id] must be specified");
}
});

private final Parameter<Map<String, String>> meta = Parameter.metaParam();

public Builder(String name) {
super(name);
}

@Override
protected Parameter<?>[] getParameters() {
return new Parameter<?>[] { modelId, meta };
}

@Override
public SemanticTextFieldMapper build(MapperBuilderContext context) {
return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo);
}
}

public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType {

private final String modelId;

public SemanticTextFieldType(String name, String modelId, Map<String, String> meta) {
super(name, false, false, false, TextSearchInfo.NONE, meta);
this.modelId = modelId;
}

@Override
public String typeName() {
return CONTENT_TYPE;
}

@Override
public String getInferenceModel() {
return modelId;
}

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException("termQuery not implemented yet");
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.toString(name(), context, format);
}

@Override
public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) {
throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.mapper;

import org.apache.lucene.index.IndexableField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MapperTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.junit.AssumptionViolatedException;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import static java.util.Collections.singletonList;
import static org.hamcrest.Matchers.containsString;

public class SemanticTextFieldMapperTests extends MapperTestCase {

public void testDefaults() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());

ParsedDocument doc1 = mapper.parse(source(this::writeField));
List<IndexableField> fields = doc1.rootDoc().getFields("field");

// No indexable fields
assertTrue(fields.isEmpty());
}

public void testModelIdNotPresent() throws IOException {
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
);
assertThat(e.getMessage(), containsString("field [model_id] must be specified"));
}

public void testCannotBeUsedInMultiFields() {
Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
b.field("type", "text");
b.startObject("fields");
b.startObject("semantic");
b.field("type", "semantic_text");
b.endObject();
b.endObject();
})));
assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
}

public void testUpdatesToModelIdNotSupported() throws IOException {
MapperService mapperService = createMapperService(
fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model"))
);
Exception e = expectThrows(
IllegalArgumentException.class,
() -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model")))
);
assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]"));
}

@Override
protected Collection<? extends Plugin> getPlugins() {
return singletonList(new MachineLearning(Settings.EMPTY));
}

@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "semantic_text").field("model_id", "test_model");
}

@Override
protected Object getSampleValueForDocument() {
return "value";
}

@Override
protected boolean supportsIgnoreMalformed() {
return false;
}

@Override
protected boolean supportsStoredFields() {
return false;
}

@Override
protected void registerParameters(ParameterChecker checker) throws IOException {}

@Override
protected Object generateRandomInputValue(MappedFieldType ft) {
assumeFalse("doc_values are not supported in semantic_text", true);
return null;
}

@Override
protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
throw new AssumptionViolatedException("not supported");
}

@Override
protected IngestScriptSupport ingestScriptSupport() {
throw new AssumptionViolatedException("not supported");
}
}

0 comments on commit 8ee4721

Please sign in to comment.