Field mapper uses new inference results structure
carlosdelest committed Feb 14, 2024
1 parent 896ec49 commit af763f0
Showing 2 changed files with 108 additions and 64 deletions.
DenseVectorFieldMapper.java
@@ -210,6 +210,16 @@ protected Parameter<?>[] getParameters() {
return new Parameter<?>[] { elementType, dims, indexed, similarity, indexOptions, meta };
}

public Builder similarity(VectorSimilarity vectorSimilarity) {
similarity.setValue(vectorSimilarity);
return this;
}

public Builder dimensions(int dimensions) {
this.dims.setValue(dimensions);
return this;
}

@Override
public DenseVectorFieldMapper build(MapperBuilderContext context) {
return new DenseVectorFieldMapper(
@@ -724,7 +734,7 @@ static Function<StringBuilder, StringBuilder> errorByteElementsAppender(byte[] v
ElementType.FLOAT
);

enum VectorSimilarity {
public enum VectorSimilarity {
L2_NORM {
@Override
float score(float similarity, ElementType elementType, int dim) {
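For context, here is a minimal sketch of how the newly public Builder setters and VectorSimilarity enum can be driven from outside the package. The field name and parameter values are illustrative assumptions, not part of this commit:

import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

// Hypothetical caller wiring a dense_vector sub-field through the new fluent setters.
DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder("embeddings", IndexVersion.current())
    .dimensions(384)                                              // must match the model's embedding size
    .similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE);  // reachable now that the enum is public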
SemanticTextInferenceResultFieldMapper.java
@@ -8,9 +8,9 @@
package org.elasticsearch.xpack.inference.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.BooleanFieldMapper;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.DocumentParsingException;
import org.elasticsearch.index.mapper.FieldMapper;
@@ -25,25 +25,25 @@
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.script.ScriptCompiler;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD;

/**
* A mapper for the {@code _semantic_text_inference} field.
@@ -102,16 +102,13 @@
*/
public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper {
public static final String CONTENT_TYPE = "_semantic_text_inference";
public static final String NAME = "_semantic_text_inference";
public static final String NAME = ROOT_INFERENCE_FIELD;
public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper());

private static final Map<List<String>, Set<String>> REQUIRED_SUBFIELDS_MAP = Map.of(
List.of(),
Set.of(SPARSE_VECTOR_SUBFIELD_NAME, TEXT_SUBFIELD_NAME)
);

private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class);

private static final Set<String> REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS);

static class SemanticTextInferenceFieldType extends MappedFieldType {
private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType();

@@ -142,75 +139,86 @@ private SemanticTextInferenceResultFieldMapper() {
@Override
protected void parseCreateField(DocumentParserContext context) throws IOException {
XContentParser parser = context.parser();
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken());
}
failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT);

parseInferenceResults(context);
parseAllFields(context);
}

private static void parseInferenceResults(DocumentParserContext context) throws IOException {
private static void parseAllFields(DocumentParserContext context) throws IOException {
XContentParser parser = context.parser();
MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false);
for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + token);
}
failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME);

parseFieldInferenceResults(context, mapperBuilderContext);
parseSingleField(context, mapperBuilderContext);
}
}

private static void parseFieldInferenceResults(DocumentParserContext context, MapperBuilderContext mapperBuilderContext)
throws IOException {
private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException {

String fieldName = context.parser().currentName();
XContentParser parser = context.parser();
String fieldName = parser.currentName();
Mapper mapper = context.getMapper(fieldName);
if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
throw new DocumentParsingException(
context.parser().getTokenLocation(),
parser.getTokenLocation(),
Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
);
}
parser.nextToken();
failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT);
parser.nextToken();
ModelSettings modelSettings = ModelSettings.parse(parser);
for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME);

parseFieldInferenceResultsArray(context, mapperBuilderContext, fieldName);
String currentName = parser.currentName();
if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) {
NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper(
context,
mapperBuilderContext,
fieldName,
modelSettings
);
parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper);
} else {
logger.debug("Skipping unrecognized field name [" + currentName + "]");
advancePastCurrentFieldName(parser);
}
}
}

private static void parseFieldInferenceResultsArray(
private static void parseFieldInferenceChunks(
DocumentParserContext context,
MapperBuilderContext mapperBuilderContext,
String fieldName
String fieldName,
ModelSettings modelSettings,
NestedObjectMapper nestedObjectMapper
) throws IOException {
XContentParser parser = context.parser();
NestedObjectMapper nestedObjectMapper = createNestedObjectMapper(context, mapperBuilderContext, fieldName);

if (parser.nextToken() != XContentParser.Token.START_ARRAY) {
throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_ARRAY, got " + parser.currentToken());
}
parser.nextToken();
failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY);

for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) {
DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper);
parseFieldInferenceResultElement(nestedContext, nestedObjectMapper, new LinkedList<>());
parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings);
}
}

private static void parseFieldInferenceResultElement(
private static void parseFieldInferenceChunkElement(
DocumentParserContext context,
ObjectMapper objectMapper,
LinkedList<String> subfieldPath
ModelSettings modelSettings
) throws IOException {
XContentParser parser = context.parser();
DocumentParserContext childContext = context.createChildContext(objectMapper);

if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_OBJECT, got " + parser.currentToken());
}
failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT);

Set<String> visitedSubfields = new HashSet<>();
for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
throw new DocumentParsingException(parser.getTokenLocation(), "Expected a FIELD_NAME, got " + parser.currentToken());
}
failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME);

String currentName = parser.currentName();
visitedSubfields.add(currentName);
@@ -222,14 +230,9 @@ private static void parseFieldInferenceResultElement(
continue;
}

if (childMapper instanceof FieldMapper) {
if (childMapper instanceof FieldMapper fieldMapper) {
parser.nextToken();
((FieldMapper) childMapper).parse(childContext);
} else if (childMapper instanceof ObjectMapper) {
parser.nextToken();
subfieldPath.push(currentName);
parseFieldInferenceResultElement(childContext, (ObjectMapper) childMapper, subfieldPath);
subfieldPath.pop();
fieldMapper.parse(childContext);
} else {
// This should never happen, but fail parsing if it does so that it's not a silent failure
throw new DocumentParsingException(
@@ -239,29 +242,51 @@
}
}

Set<String> requiredSubfields = REQUIRED_SUBFIELDS_MAP.get(subfieldPath);
if (requiredSubfields != null && visitedSubfields.containsAll(requiredSubfields) == false) {
Set<String> missingSubfields = requiredSubfields.stream()
if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) {
Set<String> missingSubfields = REQUIRED_SUBFIELDS.stream()
.filter(s -> visitedSubfields.contains(s) == false)
.collect(Collectors.toSet());
throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields);
}
}

private static NestedObjectMapper createNestedObjectMapper(
private static NestedObjectMapper createInferenceResultsObjectMapper(
DocumentParserContext context,
MapperBuilderContext mapperBuilderContext,
String fieldName
String fieldName,
ModelSettings modelSettings
) {
IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated();
ObjectMapper.Builder sparseVectorMapperBuilder = new ObjectMapper.Builder(
SPARSE_VECTOR_SUBFIELD_NAME,
ObjectMapper.Defaults.SUBOBJECTS
).add(
new BooleanFieldMapper.Builder(SparseEmbeddingResults.Embedding.IS_TRUNCATED, ScriptCompiler.NONE, false, indexVersionCreated)
).add(new SparseVectorFieldMapper.Builder(SparseEmbeddingResults.Embedding.EMBEDDING));
FieldMapper.Builder resultsBuilder;
if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) {
resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS);
} else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) {
DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder(
INFERENCE_CHUNKS_RESULTS,
indexVersionCreated
);
SimilarityMeasure similarity = modelSettings.similarity();
if (similarity != null) {
switch (similarity) {
case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE);
case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT);
default -> throw new IllegalArgumentException(
"Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity
);
}
}
Integer dimensions = modelSettings.dimensions();
if (dimensions == null) {
throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions");
}
denseVectorMapperBuilder.dimensions(dimensions);
resultsBuilder = denseVectorMapperBuilder;
} else {
throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType());
}

TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder(
TEXT_SUBFIELD_NAME,
INFERENCE_CHUNKS_TEXT,
indexVersionCreated,
context.indexAnalyzers()
).index(false).store(false);
Expand All @@ -270,7 +295,7 @@ private static NestedObjectMapper createNestedObjectMapper(
fieldName,
context.indexSettings().getIndexVersionCreated()
);
nestedBuilder.add(sparseVectorMapperBuilder).add(textMapperBuilder);
nestedBuilder.add(resultsBuilder).add(textMapperBuilder);

return nestedBuilder.build(mapperBuilderContext);
}
@@ -286,6 +311,15 @@ private static void advancePastCurrentFieldName(XContentParser parser) throws IO
}
}

private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) {
if (parser.currentToken() != expected) {
throw new DocumentParsingException(
parser.getTokenLocation(),
"Expected a " + expected.toString() + ", got " + parser.currentToken()
);
}
}

@Override
protected String contentType() {
return CONTENT_TYPE;
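For reference, a sketch of the per-field source structure that parseCreateField above appears to expect, reconstructed from the parsing logic in this diff. The field name, chunk text, and embedding value are illustrative, and the exact keys consumed by ModelSettings.parse are not visible here:

import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS;

import java.util.Map;

import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

// Hypothetical source for one semantic_text field, as nested under the root inference field.
XContentBuilder source = XContentFactory.jsonBuilder()
    .startObject()
    .startObject("my_semantic_field") // must be mapped as a semantic_text field
    // model settings object comes first and is consumed by ModelSettings.parse
    // (task type, dimensions, similarity); its serialized keys are not shown in this diff
    .startArray(INFERENCE_RESULTS)
    .startObject()
    .field(INFERENCE_CHUNKS_TEXT, "a chunk of the original text")
    .field(INFERENCE_CHUNKS_RESULTS, Map.of("some_token", 0.5f)) // sparse-embedding style results
    .endObject()
    .endArray()
    .endObject()
    .endObject();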
