diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index d36ca9e0b25c1..33815d9240d3a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -210,6 +210,18 @@ protected Parameter<?>[] getParameters() { return new Parameter<?>[] { elementType, dims, indexed, similarity, indexOptions, meta }; } + /** Sets this builder's {@code similarity} parameter and returns the builder for chaining. */ + public Builder similarity(VectorSimilarity vectorSimilarity) { + similarity.setValue(vectorSimilarity); + return this; + } + + /** Sets this builder's {@code dims} parameter and returns the builder for chaining. */ + public Builder dimensions(int dimensions) { + this.dims.setValue(dimensions); + return this; + } + @Override public DenseVectorFieldMapper build(MapperBuilderContext context) { return new DenseVectorFieldMapper( @@ -724,7 +736,7 @@ static Function<StringBuilder, StringBuilder> errorByteElementsAppender(byte[] v ElementType.FLOAT ); - enum VectorSimilarity { + public enum VectorSimilarity { L2_NORM { @Override float score(float similarity, ElementType elementType, int dim) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java index 9e6c1eb0a6586..cc0257adb0a68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; -import 
org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; @@ -25,25 +25,25 @@ import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ModelSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; /** * A mapper for the {@code _semantic_text_inference} field. 
@@ -102,16 +102,13 @@ */ public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = "_semantic_text_inference"; + public static final String NAME = ROOT_INFERENCE_FIELD; public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); - private static final Map<List<String>, Set<String>> REQUIRED_SUBFIELDS_MAP = Map.of( - List.of(), - Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME) - ); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Set<String> REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + static class SemanticTextInferenceFieldType extends MappedFieldType { private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); @@ -142,75 +139,86 @@ private SemanticTextInferenceResultFieldMapper() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parseInferenceResults(context); + parseAllFields(context); } - private static void parseInferenceResults(DocumentParserContext context) throws IOException { + private static void parseAllFields(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new 
DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + token); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResults(context, mapperBuilderContext); + parseSingleField(context, mapperBuilderContext); } } - private static void parseFieldInferenceResults(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) - throws IOException { + private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - String fieldName = context.parser().currentName(); + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); Mapper mapper = context.getMapper(fieldName); if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { throw new DocumentParsingException( - context.parser().getTokenLocation(), + parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ); } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); + parser.nextToken(); + ModelSettings modelSettings = ModelSettings.parse(parser); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - parseFieldInferenceResultsArray(context, mapperBuilderContext, fieldName); + String currentName = parser.currentName(); + if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( + context, + mapperBuilderContext, + fieldName, + modelSettings + ); + parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); + } else { + logger.debug("Skipping unrecognized field name [{}]", currentName); + advancePastCurrentFieldName(parser); + 
} + } } - private static void parseFieldInferenceResultsArray( + private static void parseFieldInferenceChunks( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings, + NestedObjectMapper nestedObjectMapper ) throws IOException { XContentParser parser = context.parser(); - NestedObjectMapper nestedObjectMapper = createNestedObjectMapper(context, mapperBuilderContext, fieldName); - if (parser.nextToken() != XContentParser.Token.START_ARRAY) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_ARRAY, got " + parser.currentToken()); - } + parser.nextToken(); + failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceResultElement(nestedContext, nestedObjectMapper, new LinkedList<>()); + parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); } } - private static void parseFieldInferenceResultElement( + private static void parseFieldInferenceChunkElement( DocumentParserContext context, ObjectMapper objectMapper, - LinkedList<String> subfieldPath + ModelSettings modelSettings ) throws IOException { XContentParser parser = context.parser(); DocumentParserContext childContext = context.createChildContext(objectMapper); - if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); Set<String> visitedSubfields = new HashSet<>(); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new 
DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + parser.currentToken()); - } + failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); visitedSubfields.add(currentName); @@ -222,14 +230,9 @@ private static void parseFieldInferenceResultElement( continue; } - if (childMapper instanceof FieldMapper) { + if (childMapper instanceof FieldMapper fieldMapper) { parser.nextToken(); - ((FieldMapper) childMapper).parse(childContext); - } else if (childMapper instanceof ObjectMapper) { - parser.nextToken(); - subfieldPath.push(currentName); - parseFieldInferenceResultElement(childContext, (ObjectMapper) childMapper, subfieldPath); - subfieldPath.pop(); + fieldMapper.parse(childContext); } else { // This should never happen, but fail parsing if it does so that it's not a silent failure throw new DocumentParsingException( @@ -239,29 +242,52 @@ private static void parseFieldInferenceResultElement( } } } - Set<String> requiredSubfields = REQUIRED_SUBFIELDS_MAP.get(subfieldPath); - if (requiredSubfields != null && visitedSubfields.containsAll(requiredSubfields) == false) { - Set<String> missingSubfields = requiredSubfields.stream() + /* Fail only when a required subfield was never visited: visited must be a superset of required. */ + if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { + Set<String> missingSubfields = REQUIRED_SUBFIELDS.stream() .filter(s -> visitedSubfields.contains(s) == false) .collect(Collectors.toSet()); throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); } } - private static NestedObjectMapper createNestedObjectMapper( + private static NestedObjectMapper createInferenceResultsObjectMapper( DocumentParserContext context, MapperBuilderContext mapperBuilderContext, - String fieldName + String fieldName, + ModelSettings modelSettings ) { IndexVersion indexVersionCreated = 
context.indexSettings().getIndexVersionCreated(); - ObjectMapper.Builder sparseVectorMapperBuilder = new ObjectMapper.Builder( - SPARSE_VECTOR_SUBFIELD_NAME, - 
ObjectMapper.Defaults.SUBOBJECTS - ).add( - new BooleanFieldMapper.Builder(SparseEmbeddingResults.Embedding.IS_TRUNCATED, ScriptCompiler.NONE, false, indexVersionCreated) - ).add(new SparseVectorFieldMapper.Builder(SparseEmbeddingResults.Embedding.EMBEDDING)); + FieldMapper.Builder resultsBuilder; + if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { + resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + Integer dimensions = modelSettings.dimensions(); + if (dimensions == null) { + throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); + } + denseVectorMapperBuilder.dimensions(dimensions); + resultsBuilder = denseVectorMapperBuilder; + } else { + throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); + } + TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - TEXT_SUBFIELD_NAME, + INFERENCE_CHUNKS_TEXT, indexVersionCreated, context.indexAnalyzers() ).index(false).store(false); @@ -270,7 +296,7 @@ private static NestedObjectMapper createNestedObjectMapper( fieldName, context.indexSettings().getIndexVersionCreated() ); - nestedBuilder.add(sparseVectorMapperBuilder).add(textMapperBuilder); + 
nestedBuilder.add(resultsBuilder).add(textMapperBuilder); return nestedBuilder.build(mapperBuilderContext); } @@ -286,6 +312,16 @@ private static void advancePastCurrentFieldName(XContentParser parser) throws IO } } + /** Throws a {@link DocumentParsingException} when the parser's current token is not {@code expected}. */ + private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException( + parser.getTokenLocation(), + "Expected a " + expected + ", got " + parser.currentToken() + ); + } + } + @Override protected String contentType() { return CONTENT_TYPE;