diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java index 805931550ad62..f527b4cd8d684 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java @@ -555,7 +555,7 @@ public static Map nodeMapValue(Object node, String desc) { if (node instanceof Map) { return (Map) node; } else { - throw new ElasticsearchParseException(desc + " should be a hash but was of type: " + node.getClass()); + throw new ElasticsearchParseException(desc + " should be a map but was of type: " + node.getClass()); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java index 71fd9edd49903..f9354025cab49 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java @@ -1176,7 +1176,7 @@ public static final class Conflicts { private final String mapperName; private final List conflicts = new ArrayList<>(); - Conflicts(String mapperName) { + public Conflicts(String mapperName) { this.mapperName = mapperName; } @@ -1188,7 +1188,11 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - void check() { + public boolean hasConflicts() { + return conflicts.isEmpty() == false; + } + + public void check() { if (conflicts.isEmpty()) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java index 903e4e5da5b29..da184d6f7a45e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java @@ -76,7 +76,7 @@ public CompressedXContent toCompressedXContent() { /** * Returns the root object for the current mapping */ - RootObjectMapper getRoot() { + public RootObjectMapper getRoot() { return root; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 6532abed19044..58286d34dada1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException { } String feature = null; + boolean origIsWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); + if (context.path().isWithinLeafObject() == false) { + context.path().setWithinLeafObject(true); + } for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) { if (token == Token.FIELD_NAME) { feature = context.parser().currentName(); @@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException { context.addToFieldNames(fieldType().name()); } } finally { - context.path().setWithinLeafObject(false); + context.path().setWithinLeafObject(origIsWithLeafObject); } } diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 54fe6e01946b4..586850eb948d3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -166,7 +166,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2a9c300e12c13..3fcd9049ae803 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,7 +55,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -284,7 +284,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fbf84762eb314..00dc195313a61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -38,7 +38,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +50,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. 
The results are then consumed by the {@link InferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements MappedActionFilter { @@ -267,10 +267,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons Map newDocMap = indexRequest.sourceAsMap(); Map inferenceMap = new LinkedHashMap<>(); // ignore the existing inference map if any - newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); for (FieldInferenceResponse fieldResponse : response.responses()) { try { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceMap, fieldResponse.field(), fieldResponse.model(), @@ -295,6 +295,7 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); + boolean hasInput = false; for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { String field = entry.getKey(); String inferenceId = entry.getValue().inferenceId(); @@ -315,6 +316,7 @@ private Map> createFieldInferenceRequests(Bu if (value instanceof String valueStr) { List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + hasInput = true; } else { inferenceResults.get(item.id()).failures.add( new ElasticsearchStatusException( @@ -326,6 +328,12 @@ private Map> createFieldInferenceRequests(Bu ); } } + if (hasInput == false) { + // remove the existing _inference field (if present) since none of the content require inference. + if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) { + indexRequest.source(docMap); + } + } } return fieldRequestsMap; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java new file mode 100644 index 0000000000000..9eeb7a5407bc4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -0,0 +1,449 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A mapper for the {@code _inference} field. + *
+ *
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. + * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: + *
+ *
+ *
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_inference": {
+ *             "my_semantic_text_field": {
+ *                  "inference_id": "my_inference_id",
+ *                  "model_settings": {
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "chunks": [
+ *                      {
+ *                          "inference": {
+ *                              "lucas": 0.05212344,
+ *                              "ty": 0.041213956,
+ *                              "dragon": 0.50991,
+ *                              "type": 0.23241979,
+ *                              "dr": 1.9312073,
+ *                              "##o": 0.2797593
+ *                          },
+ *                          "text": "these are not the droids you're looking for"
+ *                      }
+ *                  ]
+ *              }
+ *          }
+ *      }
+ * }
+ * 
+ * + * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so: + *
+ *
+ *
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_text_field": {
+ *                 "chunks": {
+ *                      "type": "nested",
+ *                      "properties": {
+ *                          "inference": {
+ *                              "type": "sparse_vector|dense_vector"
+ *                          },
+ *                          "text": {
+ *                              "type": "keyword",
+ *                              "index": false,
+ *                              "doc_values": false
+ *                          }
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * 
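+ *
+ * The {@code _inference} entry shown above is assembled by the upstream inference filter. A minimal sketch using
+ * the {@code applyFieldInference} helper exposed by this class, assuming {@code model} and {@code results} are
+ * supplied by the inference service and {@code docMap} is the parsed document source:
+ *
+ * Map<String, Object> inferenceMap = new LinkedHashMap<>();
+ * InferenceMetadataFieldMapper.applyFieldInference(inferenceMap, "my_semantic_text_field", model, results);
+ * docMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap);
+ * indexRequest.source(docMap);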
+ */ +public class InferenceMetadataFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String INFERENCE_ID = "inference_id"; + public static final String CHUNKS = "chunks"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); + + private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); + + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + + static class SemanticTextInferenceFieldType extends MappedFieldType { + private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + + SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return null; + } + } + + public InferenceMetadataFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + boolean origWithLeafObject = context.path().isWithinLeafObject(); + try { + // make sure that we don't expand dots in field names while parsing + context.path().setWithinLeafObject(true); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); + parseSingleField(context); + } + } finally { + context.path().setWithinLeafObject(origWithLeafObject); + } + } + + private NestedObjectMapper updateSemanticTextFieldMapper( + DocumentParserContext docContext, + SemanticTextMapperContext semanticFieldContext, + String newInferenceId, + SemanticTextModelSettings newModelSettings, + XContentLocation xContentLocation + ) { + final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); + final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + if (newInferenceId.equals(inferenceId) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID, + inferenceId, + fullFieldName, + INFERENCE_ID, + newInferenceId + ) + ); + } + if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { + throw new DocumentParsingException( + xContentLocation, + "Model settings for field [" + fullFieldName + "] must contain dimensions" + ); + } + if (semanticFieldContext.mapper.getModelSettings() == null) { + SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( + semanticFieldContext.mapper.simpleName(), + docContext.indexSettings().getIndexVersionCreated() + 
).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); + docContext.addDynamicMapper(newMapper); + return newMapper.getSubMappers(); + } else { + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); + } + } + return semanticFieldContext.mapper.getSubMappers(); + } + + private void parseSingleField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); + SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); + if (builderContext == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + ); + } + parser.nextToken(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + + // record the location of the inference field in the original source + XContentLocation xContentLocation = parser.getTokenLocation(); + // parse eagerly to extract the inference id and the model settings first + Map map = parser.mapOrdered(); + + // inference_id + Object inferenceIdObj = map.remove(INFERENCE_ID); + final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); + if (inferenceId == null) { + throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); + } + + // model_settings + Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); + if (modelSettingsObj == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format( + "Missing required [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + } + final SemanticTextModelSettings modelSettings; + try { + modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); + } catch (Exception exc) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "Error parsing [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ), + exc + ); + } + + var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); + + // we know the model settings, so we can (re) parse the results array now + XContentParser subParser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + DocumentParserContext mapContext = context.switchParser(subParser); + parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper); + } + + private void parseFieldInference( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + switch (parser.currentName()) { + case CHUNKS -> parseChunks(xContentLocation, parser, context, 
nestedMapper); + default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); + } + } + } + + private void parseChunks( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { + DocumentParserContext subContext = context.createNestedContext(nestedMapper); + parseResultsObject(xContentLocation, parser, subContext, nestedMapper); + } + } + + private void parseResultsObject( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + Set visited = new HashSet<>(); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); + visited.add(parser.currentName()); + FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); + if (fieldMapper == null) { + if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { + throw new DocumentParsingException( + xContentLocation, + "Missing sub-fields definition for [" + parser.currentName() + "]" + ); + } else { + logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); + advancePastCurrentFieldName(xContentLocation, parser); + continue; + } + } + parser.nextToken(); + fieldMapper.parse(context); + } + if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() + .filter(s -> visited.contains(s) == false) + .collect(Collectors.toSet()); + throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); + } + } + + private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); + } + } + + private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME; + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { + throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) 
{ + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(CHUNKS, chunks); + inferenceMap.put(field, fieldMap); + } + + record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} + + /** + * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} + * and the {@link MapperBuilderContext} that was used to build it. + * If the field is not found or is of the wrong type, this method returns {@code null}. + */ + static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { + ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); + return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\.")); + } + + static SemanticTextMapperContext createSemanticFieldContext( + MapperBuilderContext mapperContext, + ObjectMapper objectMapper, + String[] paths + ) { + Mapper mapper = objectMapper.getMapper(paths[0]); + if (mapper instanceof ObjectMapper newObjectMapper) { + mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); + return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length)); + } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } else { + if (mapper == null || paths.length == 1) { + return null; + } + // check if the semantic field is defined within a multi-field + Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); + if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } + } + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index 2ede5419ab74e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * 
- * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: - *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 83272a10f98d4..2445d5c8751a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,30 +9,50 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; 
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.DocumentParserContext;
 import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.InferenceModelFieldType;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.SimpleMappedFieldType;
+import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.SourceValueFetcher;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xcontent.XContentBuilder;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT;
 /**
- * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference
- * at ingestion and query time.
- * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well.
+ * A {@link FieldMapper} for semantic text fields.
+ * These fields have an inference id reference that is used for performing inference at ingestion and query time.
  * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will
- * be indexed using {@link InferenceResultFieldMapper}.
+ * be indexed using {@link InferenceMetadataFieldMapper}.
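+ *
+ * A field of this type is declared in the index mappings like so (a sketch; the endpoint name is illustrative):
+ *
+ * "my_semantic_text_field": {
+ *     "type": "semantic_text",
+ *     "inference_id": "my_inference_id"
+ * }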
*/ public class SemanticTextFieldMapper extends FieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -40,15 +60,39 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c.indexVersionCreated()), + notInMultiFields(CONTENT_TYPE) + ); + + private final IndexVersion indexVersionCreated; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + CopyTo copyTo, + IndexVersion indexVersionCreated, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers + ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.indexVersionCreated = indexVersionCreated; + this.modelSettings = modelSettings; + this.subMappers = subMappers; + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(subMappers); + return subIterators.iterator(); } @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName()).init(this); + return new Builder(simpleName(), indexVersionCreated).init(this); } @Override @@ -67,39 +111,100 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; + } + public static class Builder extends FieldMapper.Builder { + private final IndexVersion indexVersionCreated; - private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) - .addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); + private final Parameter inferenceId = Parameter.stringParam( + "inference_id", + false, + m -> toType(m).fieldType().inferenceId, + null + ).addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [inference_id] must be specified"); + } + }); + private final Parameter modelSettings = new Parameter<>( + "model_settings", + true, + () -> null, + (n, c, o) -> SemanticTextModelSettings.fromMap(o), + mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, + XContentBuilder::field, + (m) -> m == null ? 
"null" : Strings.toString(m) + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - public Builder(String name) { + public Builder(String name, IndexVersion indexVersionCreated) { super(name); + this.indexVersionCreated = indexVersionCreated; + } + + public Builder setInferenceId(String id) { + this.inferenceId.setValue(id); + return this; + } + + public Builder setModelSettings(SemanticTextModelSettings value) { + this.modelSettings.setValue(value); + return this; } @Override protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta }; + return new Parameter[] { inferenceId, modelSettings, meta }; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + final String fullName = context.buildFullName(name()); + NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); + nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) + .indexed(false) + .docValues(false); + if (modelSettings.get() != null) { + nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + } + nestedBuilder.add(textMapperBuilder); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + var subMappers = nestedBuilder.build(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), + copyTo, + indexVersionCreated, + modelSettings.getValue(), + subMappers + ); } } public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private final String inferenceId; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private final String modelId; - - public SemanticTextFieldType(String name, String modelId, Map meta) { + public SemanticTextFieldType( + String name, + String modelId, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers, + Map meta + ) { super(name, false, false, false, TextSearchInfo.NONE, meta); - this.modelId = modelId; + this.inferenceId = modelId; + this.modelSettings = modelSettings; + this.subMappers = subMappers; } @Override @@ -109,7 +214,15 @@ public String typeName() { @Override public String getInferenceId() { - return modelId; + return inferenceId; + } + + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; } @Override @@ -127,4 +240,59 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return super.syntheticFieldLoader(); + } + + private static Mapper.Builder createInferenceMapperBuilder( + String fieldName, + SemanticTextModelSettings modelSettings, + IndexVersion indexVersionCreated + ) { + return switch (modelSettings.taskType()) { + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + 
case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); + yield denseVectorMapperBuilder; + } + default -> throw new IllegalArgumentException( + "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() + ); + }; + } + + static boolean canMergeModelSettings( + SemanticTextModelSettings previous, + SemanticTextModelSettings current, + FieldMapper.Conflicts conflicts + ) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null) { + return true; + } + if (current == null) { + conflicts.addConflict("model_settings", ""); + return false; + } + conflicts.addConflict("model_settings", ""); + return false; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 1b6bb22c0d6b5..b1d0511008db8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -7,73 +7,100 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; + /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
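+ *
+ * The serialized form mirrors {@code toXContent} (a sketch; the values are illustrative):
+ *
+ * "model_settings": {
+ *     "task_type": "TEXT_EMBEDDING",
+ *     "dimensions": 384,
+ *     "similarity": "cosine"
+ * }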
*/ -public class SemanticTextModelSettings { +public class SemanticTextModelSettings implements ToXContentObject { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); private final TaskType taskType; - private final String inferenceId; private final Integer dimensions; private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public SemanticTextModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; - this.inferenceId = inferenceId; this.dimensions = dimensions; this.similarity = similarity; - } - - public SemanticTextModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); + validate(); } public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); - }); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new SemanticTextModelSettings(taskType, dimensions, similarity); + } + ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } + public static SemanticTextModelSettings fromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, NAME); + if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { + throw new IllegalArgumentException( + "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" + ); + } + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return SemanticTextModelSettings.parse(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + public Map asMap() { Map attrsMap = new HashMap<>(); attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); if (dimensions != null) { attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); } @@ -87,10 +114,6 @@ public TaskType taskType() { return taskType; } - public String inferenceId() { - return inferenceId; - } - public Integer dimensions() { return dimensions; } @@ -98,4 +121,61 @@ public Integer dimensions() { public SimilarityMeasure similarity() { return similarity; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + public void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticTextModelSettings that = (SemanticTextModelSettings) o; + return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, dimensions, similarity); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
index a7d3fcce26116..bf3cc6334433a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -31,7 +31,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
     public void testCreateIndexWithSemanticTextField() {
         final IndexService indexService = createIndex(
             "test",
-            client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model")
+            client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model")
         );
         assertEquals(
             indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(),
@@ -46,7 +46,7 @@ public void testAddSemanticTextField() throws Exception {
         final ClusterService clusterService = getInstanceFromNode(ClusterService.class);
 
         final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest("""
-            { "properties": { "field": { "type": "semantic_text", "model_id": "test_model" }}}""");
+            { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}""");
         request.indices(new Index[] { indexService.index() });
         final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful(
             clusterService.state(),
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
index 4a1825303b5a7..8b18cf74236a0 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
@@ -32,7 +32,7 @@
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
+import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper;
 import org.elasticsearch.xpack.inference.model.TestModel;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.junit.After;
@@ -51,8 +51,8 @@
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings;
-import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.any;
@@ -285,7 +285,7 @@ private static BulkItemRequest[] randomBulkItemRequest(
             final ChunkedInferenceServiceResults results;
             switch (taskType) {
                 case TEXT_EMBEDDING:
-                    results = randomTextEmbeddings(chunks);
+                    results = randomTextEmbeddings(model, chunks);
                     break;
 
                 case SPARSE_EMBEDDING:
@@ -296,10 +296,10 @@ private static BulkItemRequest[] randomBulkItemRequest(
                     throw new AssertionError("Unknown task type " + taskType.name());
             }
             model.putResult(text, results);
-            InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
+            InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
         }
         Map<String, Object> expectedDocMap = new LinkedHashMap<>(docMap);
-        expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap);
+        expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap);
         return new BulkItemRequest[] {
             new BulkItemRequest(id, new IndexRequest("index").source(docMap)),
             new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) };
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java
similarity index 57%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java
index b5d75b528c6ab..37e4e5e774bec 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java
@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.mapper;
 
+import org.apache.lucene.document.FeatureField;
+import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
@@ -51,26 +53,28 @@
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Consumer;
 
-import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS;
-import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT;
-import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 
-public class InferenceResultFieldMapperTests extends MetadataMapperTestCase {
-    private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List<String> text) {}
+public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase {
+    private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List<String> text) {}
 
-    private record VisitedChildDocInfo(String path, int numChunks) {}
+    private record VisitedChildDocInfo(String path) {}
 
     private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {}
 
     @Override
     protected String fieldName() {
-        return InferenceResultFieldMapper.NAME;
+        return InferenceMetadataFieldMapper.NAME;
     }
 
     @Override
@@ -94,109 +98,129 @@ protected Collection<? extends Plugin> getPlugins() {
     }
 
     public void testSuccessfulParse() throws IOException {
-        final String fieldName1 = randomAlphaOfLengthBetween(5, 15);
-        final String fieldName2 = randomAlphaOfLengthBetween(5, 15);
-
-        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> {
-            addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8));
-            addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8));
-        }));
-        ParsedDocument doc = documentMapper.parse(
-            source(
-                b -> addSemanticTextInferenceResults(
-                    b,
-                    List.of(
-                        randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")),
-                        randomSemanticTextInferenceResults(fieldName2, List.of("d e f"))
+        for (int depth = 1; depth < 4; depth++) {
+            final String fieldName1 = randomFieldName(depth);
+            final String fieldName2 = randomFieldName(depth + 1);
+
+            Model model1 = randomModel(TaskType.SPARSE_EMBEDDING);
+            Model model2 = randomModel(TaskType.SPARSE_EMBEDDING);
+            XContentBuilder mapping = mapping(b -> {
+                addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId());
+                addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId());
+            });
+
+            MapperService mapperService = createMapperService(mapping);
+            SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false);
+            SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false);
+            DocumentMapper documentMapper = mapperService.documentMapper();
+            ParsedDocument doc = documentMapper.parse(
+                source(
+                    b -> addSemanticTextInferenceResults(
+                        b,
+                        List.of(
+                            randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")),
+                            randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f"))
+                        )
                     )
                 )
-            )
-        );
-
-        Set<VisitedChildDocInfo> visitedChildDocs = new HashSet<>();
-        Set<VisitedChildDocInfo> expectedVisitedChildDocs = Set.of(
-            new VisitedChildDocInfo(fieldName1, 2),
-            new VisitedChildDocInfo(fieldName1, 1),
-            new VisitedChildDocInfo(fieldName2, 3)
-        );
-
-        List<LuceneDocument> luceneDocs = doc.docs();
-        assertEquals(4, luceneDocs.size());
-        assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs);
-        assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs);
-        assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs);
-        assertEquals(doc.rootDoc(), luceneDocs.get(3));
-        assertNull(luceneDocs.get(3).getParent());
-        assertEquals(expectedVisitedChildDocs, visitedChildDocs);
-
-        MapperService nestedMapperService = createMapperService(mapping(b -> {
-            addInferenceResultsNestedMapping(b, fieldName1);
-            addInferenceResultsNestedMapping(b, fieldName2);
-        }));
-        withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> {
-            NestedDocuments nested = new NestedDocuments(
-                nestedMapperService.mappingLookup(),
-                QueryBitSetProducer::new,
-                IndexVersion.current()
-            );
-            LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0));
-
-            Set<SearchHit.NestedIdentity> visitedNestedIdentities = new HashSet<>();
-            Set<SearchHit.NestedIdentity> expectedVisitedNestedIdentities = Set.of(
-                new SearchHit.NestedIdentity(fieldName1, 0, null),
-                new SearchHit.NestedIdentity(fieldName1, 1, null),
-                new SearchHit.NestedIdentity(fieldName2, 0, null)
             );
-            assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities);
-            assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities);
-            assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities);
-            assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities);
-
-            assertNull(leaf.advance(3));
-            assertEquals(3, leaf.doc());
-            assertEquals(3, leaf.rootDoc());
-            assertNull(leaf.nestedIdentity());
-
-            IndexSearcher searcher = newSearcher(reader);
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")),
-                    10
-                );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
-            }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")),
-                    10
-                );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
+            List<LuceneDocument> luceneDocs = doc.docs();
+            assertEquals(4, luceneDocs.size());
+            for (int i = 0; i < 3; i++) {
+                assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent());
             }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")),
-                    10
+            // nested docs are in reversed order
+            assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2);
+            assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1);
+            assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3);
+            assertEquals(doc.rootDoc(), luceneDocs.get(3));
+            assertNull(luceneDocs.get(3).getParent());
+
+            withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> {
+                NestedDocuments nested = new NestedDocuments(
+                    mapperService.mappingLookup(),
+                    QueryBitSetProducer::new,
+                    IndexVersion.current()
                 );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
-            }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")),
-                    10
+                LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0));
+
+                Set<SearchHit.NestedIdentity> visitedNestedIdentities = new HashSet<>();
+                Set<SearchHit.NestedIdentity> expectedVisitedNestedIdentities = Set.of(
+                    new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null),
+                    new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null),
+                    new SearchHit.NestedIdentity(fieldName2 + "." + CHUNKS, 0, null)
                 );
-                assertEquals(0, topDocs.totalHits.value);
-            }
-        });
+
+                assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities);
+                assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities);
+                assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities);
+                assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities);
+
+                assertNull(leaf.advance(3));
+                assertEquals(3, leaf.doc());
+                assertEquals(3, leaf.rootDoc());
+                assertNull(leaf.nestedIdentity());
+
+                IndexSearcher searcher = newSearcher(reader);
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName1 + "." + CHUNKS,
+                            List.of("a")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName1 + "." + CHUNKS,
+                            List.of("a", "b")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName2 + "." + CHUNKS,
+                            List.of("d")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName2 + "." + CHUNKS,
+                            List.of("z")
+                        ),
+                        10
+                    );
+                    assertEquals(0, topDocs.totalHits.value);
+                }
+            });
+        }
     }
 
     public void testMissingSubfields() throws IOException {
         final String fieldName = randomAlphaOfLengthBetween(5, 15);
+        final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING);
 
-        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8))));
+        DocumentMapper documentMapper = createDocumentMapper(
+            mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId()))
+        );
 
         {
             DocumentParsingException ex = expectThrows(
@@ -206,7 +230,7 @@ public void testMissingSubfields() throws IOException {
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                         new SparseVectorSubfieldOptions(false, true, true),
                         true,
                         Map.of()
@@ -224,7 +248,7 @@
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                         new SparseVectorSubfieldOptions(true, true, true),
                         false,
                         Map.of()
@@ -242,7 +266,7 @@
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                        new SparseVectorSubfieldOptions(false, true, true),
                         false,
                         Map.of()
@@ -259,15 +283,18 @@
 
     public void testExtraSubfields() throws IOException {
         final String fieldName = randomAlphaOfLengthBetween(5, 15);
+        final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING);
         final List<SemanticTextInferenceResults> semanticTextInferenceResultsList = List.of(
-            randomSemanticTextInferenceResults(fieldName, List.of("a b"))
+            randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))
         );
 
-        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8))));
+        DocumentMapper documentMapper = createDocumentMapper(
+            mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId()))
+        );
 
         Consumer<ParsedDocument> checkParsedDocument = d -> {
             Set<VisitedChildDocInfo> visitedChildDocs = new HashSet<>();
-            Set<VisitedChildDocInfo> expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2));
+            Set<VisitedChildDocInfo> expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + CHUNKS));
 
             List<LuceneDocument> luceneDocs = d.docs();
             assertEquals(2, luceneDocs.size());
@@ -358,28 +385,97 @@ public void testMissingSemanticTextMapping() throws IOException {
             DocumentParsingException.class,
             DocumentParsingException.class,
             () -> documentMapper.parse(
-                source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b")))))
+                source(
+                    b -> addSemanticTextInferenceResults(
+                        b,
+                        List.of(
+                            randomSemanticTextInferenceResults(
+                                fieldName,
+                                randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)),
+                                List.of("a b")
+                            )
+                        )
+                    )
+                )
             )
         );
         assertThat(
             ex.getMessage(),
             containsString(
-                Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
+                Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
             )
         );
     }
 
+    public void testMissingInferenceId() throws IOException {
+        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id")));
+        IllegalArgumentException ex = expectThrows(
+            DocumentParsingException.class,
+            IllegalArgumentException.class,
+            () -> documentMapper.parse(
+                source(
+                    b -> b.startObject(InferenceMetadataFieldMapper.NAME)
+                        .startObject("field")
+                        .startObject(SemanticTextModelSettings.NAME)
+                        .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING)
+                        .endObject()
+                        .endObject()
+                        .endObject()
+                )
+            )
+        );
+        assertThat(ex.getMessage(), containsString("required [inference_id] is missing"));
+    }
+
+    public void testMissingModelSettings() throws IOException {
+        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id")));
+        DocumentParsingException ex = expectThrows(
+            DocumentParsingException.class,
+            DocumentParsingException.class,
+            () -> documentMapper.parse(
+                source(
+                    b -> b.startObject(InferenceMetadataFieldMapper.NAME)
+                        .startObject("field")
+                        .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id")
+                        .endObject()
+                        .endObject()
+                )
+            )
+        );
+        assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]"));
+    }
+
+    public void testMissingTaskType() throws IOException {
+        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id")));
+        DocumentParsingException ex = expectThrows(
+            DocumentParsingException.class,
+            DocumentParsingException.class,
+            () -> documentMapper.parse(
+                source(
+                    b -> b.startObject(InferenceMetadataFieldMapper.NAME)
+                        .startObject("field")
+                        .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id")
+                        .startObject(SemanticTextModelSettings.NAME)
+                        .endObject()
+                        .endObject()
+                        .endObject()
+                )
+            )
+        );
+        assertThat(ex.getCause().getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing"));
+    }
+
     private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException {
         mappingBuilder.startObject(fieldName);
         mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE);
-        mappingBuilder.field("model_id", modelId);
+        mappingBuilder.field("inference_id", modelId);
         mappingBuilder.endObject();
     }
 
-    public static ChunkedTextEmbeddingResults randomTextEmbeddings(List<String> inputs) {
+    public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List<String> inputs) {
         List chunks = new ArrayList<>();
         for (String input : inputs) {
-            double[] values = new double[5];
+            double[] values = new double[model.getServiceSettings().dimensions()];
             for (int j = 0; j < values.length; j++) {
                 values[j] = randomDouble();
             }
@@ -400,8 +496,17 @@ public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List<String> inputs) {
         return new ChunkedSparseEmbeddingResults(chunks);
     }
 
-    private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List<String> chunks) {
-        return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks);
+    private static SemanticTextInferenceResults randomSemanticTextInferenceResults(
+        String semanticTextFieldName,
+        Model model,
+        List<String> chunks
+    ) {
+        ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) {
+            case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks);
+            case SPARSE_EMBEDDING -> randomSparseEmbeddings(chunks);
+            default -> throw new AssertionError("unknown task type: " + model.getTaskType().name());
+        };
+        return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks);
     }
 
     private static void addSemanticTextInferenceResults(
@@ -425,16 +530,16 @@ private static void addSemanticTextInferenceResults(
         boolean includeTextSubfield,
         Map<String, Object> extraSubfields
     ) throws IOException {
-        Map<String, Object> inferenceResultsMap = new HashMap<>();
+        Map<String, Object> inferenceResultsMap = new LinkedHashMap<>();
         for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) {
-            InferenceResultFieldMapper.applyFieldInference(
+            InferenceMetadataFieldMapper.applyFieldInference(
                 inferenceResultsMap,
                 semanticTextInferenceResult.fieldName,
-                randomModel(),
+                semanticTextInferenceResult.model,
                 semanticTextInferenceResult.results
             );
             Map<String, Object> optionsMap = (Map<String, Object>) inferenceResultsMap.get(semanticTextInferenceResult.fieldName);
-            List<Map<String, Object>> fieldResultList = (List<Map<String, Object>>) optionsMap.get(RESULTS);
+            List<Map<String, Object>> fieldResultList = (List<Map<String, Object>>) optionsMap.get(CHUNKS);
             for (var entry : fieldResultList) {
                 if (includeTextSubfield == false) {
                     entry.remove(INFERENCE_CHUNKS_TEXT);
@@ -445,15 +550,26 @@ private static void addSemanticTextInferenceResults(
                 entry.putAll(extraSubfields);
             }
         }
-        sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap);
+        sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap);
+    }
+
+    static String randomFieldName(int numLevel) {
+        StringBuilder builder = new StringBuilder();
+        for (int i = 0; i < numLevel; i++) {
+            if (i > 0) {
+                builder.append('.');
+            }
+            builder.append(randomAlphaOfLengthBetween(5, 15));
+        }
+        return builder.toString();
     }
 
-    private static Model randomModel() {
+    private static Model randomModel(TaskType taskType) {
         String serviceName = randomAlphaOfLengthBetween(5, 10);
         String inferenceId = randomAlphaOfLengthBetween(5, 10);
         return new TestModel(
             inferenceId,
-            TaskType.SPARSE_EMBEDDING,
+            taskType,
             serviceName,
             new TestModel.TestServiceSettings("my-model"),
             new TestModel.TestTaskSettings(randomIntBetween(1, 100)),
@@ -461,29 +577,6 @@
         );
     }
 
-    private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException {
-        mappingBuilder.startObject(semanticTextFieldName);
-        {
-            mappingBuilder.field("type", "nested");
-            mappingBuilder.startObject("properties");
-            {
-                mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS);
-                {
-                    mappingBuilder.field("type", "sparse_vector");
-                }
-                mappingBuilder.endObject();
-                mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT);
-                {
-                    mappingBuilder.field("type", "text");
-                    mappingBuilder.field("index", false);
-                }
-                mappingBuilder.endObject();
-            }
-            mappingBuilder.endObject();
-        }
-        mappingBuilder.endObject();
-    }
-
     private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List<String> tokens) {
         NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path);
         assertNotNull(mapper);
@@ -503,12 +596,10 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook
     private static void assertValidChildDoc(
         LuceneDocument childDoc,
         LuceneDocument expectedParent,
-        Set<VisitedChildDocInfo> visitedChildDocs
+        Collection<VisitedChildDocInfo> visitedChildDocs
     ) {
         assertEquals(expectedParent, childDoc.getParent());
-        visitedChildDocs.add(
-            new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size())
-        );
+        visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath()));
     }
 
     private static void assertChildLeafNestedDocument(
@@ -524,4 +615,15 @@
         assertNotNull(leaf.nestedIdentity());
         visitedNestedIdentities.add(leaf.nestedIdentity());
     }
+
+    private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) {
+        int count = 0;
+        for (IndexableField field : doc.getFields()) {
+            if (field instanceof FeatureField featureField) {
+                assertThat(featureField.name(), equalTo(fieldName));
+                ++count;
+            }
+        }
+        assertThat(count, equalTo(expectedCount));
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index a3a705c9cc902..1b5311ac9effb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -11,11 +11,17 @@
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.DocumentMapper;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
+import org.elasticsearch.index.mapper.MapperBuilderContext;
 import org.elasticsearch.index.mapper.MapperParsingException;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperTestCase;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -26,52 +32,12 @@
 import java.util.List;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class SemanticTextFieldMapperTests extends MapperTestCase {
-
-    public void testDefaults() throws Exception {
-        DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
-        assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());
-
-        ParsedDocument doc1 = mapper.parse(source(this::writeField));
-        List<IndexableField> fields = doc1.rootDoc().getFields("field");
-
-        // No indexable fields
-        assertTrue(fields.isEmpty());
-    }
-
-    public void testModelIdNotPresent() throws IOException {
-        Exception e = expectThrows(
-            MapperParsingException.class,
-            () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
-        );
-        assertThat(e.getMessage(), containsString("field [model_id] must be specified"));
-    }
-
-    public void testCannotBeUsedInMultiFields() {
-        Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
-            b.field("type", "text");
-            b.startObject("fields");
-            b.startObject("semantic");
-            b.field("type", "semantic_text");
-            b.endObject();
-            b.endObject();
-        })));
-        assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
-    }
-
-    public void testUpdatesToModelIdNotSupported() throws IOException {
-        MapperService mapperService = createMapperService(
-            fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model"))
-        );
-        Exception e = expectThrows(
-            IllegalArgumentException.class,
-            () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model")))
-        );
-        assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]"));
-    }
-
     @Override
     protected Collection<? extends Plugin> getPlugins() {
         return singletonList(new InferencePlugin(Settings.EMPTY));
@@ -79,7 +45,12 @@ protected Collection<? extends Plugin> getPlugins() {
 
     @Override
     protected void minimalMapping(XContentBuilder b) throws IOException {
-        b.field("type", "semantic_text").field("model_id", "test_model");
+        b.field("type", "semantic_text").field("inference_id", "test_model");
+    }
+
+    @Override
+    protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) {
+        return "cannot have nested fields when index is in [index.mode=time_series]";
     }
 
     @Override
@@ -115,4 +86,180 @@ protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed)
     protected IngestScriptSupport ingestScriptSupport() {
         throw new AssumptionViolatedException("not supported");
     }
+
+    public void testDefaults() throws Exception {
+        DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
+        assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());
+
+        ParsedDocument doc1 = mapper.parse(source(this::writeField));
+        List<IndexableField> fields = doc1.rootDoc().getFields("field");
+
+        // No indexable fields
+        assertTrue(fields.isEmpty());
+    }
+
+    public void testInferenceIdNotPresent() throws IOException {
+        Exception e = expectThrows(
+            MapperParsingException.class,
+            () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
+        );
+        assertThat(e.getMessage(), containsString("field [inference_id] must be specified"));
+    }
+
+    public void testCannotBeUsedInMultiFields() {
+        Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
+            b.field("type", "text");
+            b.startObject("fields");
+            b.startObject("semantic");
+            b.field("type", "semantic_text");
+            b.endObject();
+            b.endObject();
+        })));
+        assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
+    }
+
+    public void testUpdatesToInferenceIdNotSupported() throws IOException {
+        String fieldName = randomAlphaOfLengthBetween(5, 15);
+        MapperService mapperService = createMapperService(
+            mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject())
+        );
+        assertSemanticTextField(mapperService, fieldName, false);
+        Exception e = expectThrows(
+            IllegalArgumentException.class,
+            () -> merge(
+                mapperService,
+                mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject())
+            )
+        );
+        assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]"));
+    }
+
+    public void testUpdateModelSettings() throws IOException {
+        for (int depth = 1; depth < 5; depth++) {
+            String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth);
+            MapperService mapperService = createMapperService(
+                mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject())
+            );
+            assertSemanticTextField(mapperService, fieldName, false);
+            {
+                Exception exc = expectThrows(
+                    MapperParsingException.class,
+                    () -> merge(
+                        mapperService,
+                        mapping(
+                            b -> b.startObject(fieldName)
+                                .field("type", "semantic_text")
+                                .field("inference_id", "test_model")
+                                .startObject("model_settings")
+                                .field("inference_id", "test_model")
+                                .endObject()
+                                .endObject()
+                        )
+                    )
+                );
+                assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing"));
+            }
+            {
+                merge(
+                    mapperService,
+                    mapping(
+                        b -> b.startObject(fieldName)
+                            .field("type", "semantic_text")
+                            .field("inference_id", "test_model")
+                            .startObject("model_settings")
+                            .field("task_type", "sparse_embedding")
+                            .endObject()
+                            .endObject()
+                    )
+                );
+                assertSemanticTextField(mapperService, fieldName, true);
+            }
+            {
+                Exception exc = expectThrows(
+                    IllegalArgumentException.class,
+                    () -> merge(
+                        mapperService,
+                        mapping(
+                            b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()
+                        )
+                    )
+                );
+                assertThat(
+                    exc.getMessage(),
+                    containsString("Cannot update parameter [model_settings] from [{\"task_type\":\"sparse_embedding\"}] to [null]")
+                );
+            }
+            {
+                Exception exc = expectThrows(
+                    IllegalArgumentException.class,
+                    () -> merge(
+                        mapperService,
+                        mapping(
+                            b -> b.startObject(fieldName)
+                                .field("type", "semantic_text")
+                                .field("inference_id", "test_model")
+                                .startObject("model_settings")
+                                .field("task_type", "text_embedding")
+                                .field("dimensions", 10)
+                                .field("similarity", "cosine")
+                                .endObject()
+                                .endObject()
+                        )
+                    )
+                );
+                assertThat(
+                    exc.getMessage(),
+                    containsString(
+                        "Cannot update parameter [model_settings] "
+                            + "from [{\"task_type\":\"sparse_embedding\"}] "
+                            + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]"
+                    )
+                );
+            }
+        }
+    }
+
+    static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) {
+        InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext(
+            MapperBuilderContext.root(false, false),
+            mapperService.mappingLookup().getMapping().getRoot(),
+            fieldName.split("\\.")
+        );
+        Mapper mapper = res.mapper();
+        assertNotNull(mapper);
+        assertThat(mapper, instanceOf(SemanticTextFieldMapper.class));
+        SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper;
+
+        var fieldType = mapperService.fieldType(fieldName);
+        assertNotNull(fieldType);
+        assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class));
+        SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType;
+        assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType);
+        assertTrue(semanticFieldMapper.getSubMappers() == semanticTextFieldType.getSubMappers());
+        assertTrue(semanticFieldMapper.getModelSettings() == semanticTextFieldType.getModelSettings());
+
+        NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup()
+            .nestedLookup()
+            .getNestedMappers()
+            .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS);
+        assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers()));
+        Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT);
+        assertNotNull(textMapper);
+        assertThat(textMapper, instanceOf(KeywordFieldMapper.class));
+        KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper;
+        assertFalse(textFieldMapper.fieldType().isIndexed());
+        assertFalse(textFieldMapper.fieldType().hasDocValues());
+        if (expectedModelSettings) {
+            assertNotNull(semanticFieldMapper.getModelSettings());
+            Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS);
+            assertNotNull(inferenceMapper);
+            switch (semanticFieldMapper.getModelSettings().taskType()) {
+                case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class));
+                case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class));
+                default -> throw new AssertionError("Invalid task type");
+            }
+        } else {
+            assertNull(semanticFieldMapper.getModelSettings());
+        }
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
index 75e7ca12c1d56..b64485a3d3fb2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java
@@ -16,6 +16,7 @@
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SecretSettings;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -121,6 +122,16 @@ public void writeTo(StreamOutput out) throws IOException {
         public ToXContentObject getFilteredXContentObject() {
             return this;
         }
+
+        @Override
+        public SimilarityMeasure similarity() {
+            return SimilarityMeasure.COSINE;
+        }
+
+        @Override
+        public Integer dimensions() {
+            return 100;
+        }
     }
 
     public record TestTaskSettings(Integer temperature) implements TaskSettings {
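Note: `TestServiceSettings` now advertises a fixed similarity and dimension count. That is what lets `new SemanticTextModelSettings(model)` pass the eager `validate()` for dense models in the tests above. A rough sketch of the interaction, assuming a TEXT_EMBEDDING test model such as the one built by `randomModel(TaskType.TEXT_EMBEDDING)` in InferenceMetadataFieldMapperTests:

    Model model = randomModel(TaskType.TEXT_EMBEDDING);
    // TestServiceSettings.dimensions() == 100 and similarity() == COSINE are both
    // non-null, so the TEXT_EMBEDDING branch of validate() succeeds.
    SemanticTextModelSettings settings = new SemanticTextModelSettings(model);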
"similarity": "cosine", "api_key": "abc64" }, "task_settings": { @@ -41,10 +42,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -56,10 +57,10 @@ setup: properties: inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id another_inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -83,11 +84,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +121,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- @@ -154,8 +155,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +175,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - 
match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +215,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +234,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -247,10 +248,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -271,11 +272,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -287,7 +288,7 @@ setup: properties: inference_field: type: semantic_text - model_id: non-existing-inference-id + inference_id: non-existing-inference-id non_inference_field: type: text diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 2c69f49218091..27f233436b925 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -27,7 +27,8 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, - "api_key": "abc64" + "api_key": "abc64", + "similarity": "cosine" }, "task_settings": { } @@ -41,10 
+42,10 @@ setup: properties: sparse_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id dense_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -55,25 +56,7 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 - - text: "another inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 + sparse_field: "you know, for testing" --- "Dense vector results format": @@ -82,72 +65,4 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - dimensions: 5 - similarity: cosine - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] - ---- -"Model settings inference id not included": - - do: - catch: /Required \[inference_id\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings task type not included": - - do: - catch: /Required \[task_type\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings dense vector dimensions not included": - - do: - catch: /Model settings for field \[dense_field\] must contain dimensions/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + dense_field: "you know, for testing"
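Note: the YAML changes above track two renames end to end: mappings declare `inference_id` instead of `model_id`, and the `_inference` metadata nests per-chunk results under `chunks` instead of `results`. A minimal sketch of building the renamed mapping programmatically, in the style of the Java test helpers earlier in this diff (the field name and inference id are placeholders):

    XContentBuilder mapping = JsonXContent.contentBuilder();
    mapping.startObject();
    mapping.startObject("properties");
    mapping.startObject("inference_field");
    mapping.field("type", "semantic_text");
    mapping.field("inference_id", "sparse-inference-id"); // was: model_id
    mapping.endObject();
    mapping.endObject();
    mapping.endObject();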