diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index a9bc74579c165..19b3976ecc6e7 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -49,6 +49,7 @@
 import org.elasticsearch.index.analysis.CharFilterFactory;
 import org.elasticsearch.index.analysis.TokenizerFactory;
 import org.elasticsearch.index.mapper.Mapper;
+import org.elasticsearch.index.mapper.MetadataFieldMapper;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.indices.AssociatedIndexDescriptor;
 import org.elasticsearch.indices.SystemIndexDescriptor;
@@ -364,6 +365,7 @@
 import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
 import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
 import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;
+import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
 import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
@@ -2284,13 +2286,18 @@ public Map getMappers() {
         );
     }
 
-//    @Override
-//    public Map getMetadataMappers() {
-//        return Map.of(
-//            SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
-//            SemanticTextInferenceResultFieldMapper.PARSER
-//        );
-//    }
+    /**
+     * Exposes the {@link SemanticTextInferenceResultFieldMapper} parser as a metadata mapper, keyed by its content type.
+     *
+     * @return a one-entry map from {@code CONTENT_TYPE} to {@code PARSER}
+     */
+    @Override
+    public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
+        return Map.of(
+            SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
+            SemanticTextInferenceResultFieldMapper.PARSER
+        );
+    }
 
     @Override
     public Optional getIngestPipeline(IndexMetadata indexMetadata, Processor.Parameters parameters) {
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java
index 30c253e46c690..9ac64ed0e7bcf 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java
@@ -15,19 +15,21 @@
 import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
 import org.elasticsearch.index.mapper.MetadataFieldMapper;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.SourceValueFetcher;
+import org.elasticsearch.index.mapper.TextFieldMapper;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
 import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
-import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.Map;
+import java.util.HashSet;
+import java.util.Set;
 
 public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper {
 
@@ -47,7 +49,6 @@ private static SemanticTextInferenceResultFieldMapper toType(FieldMapper in) {
     public static class SemanticTextInferenceFieldType extends MappedFieldType {
 
         public static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType();
-        private SparseVectorFieldType sparseVectorFieldType;
 
         public SemanticTextInferenceFieldType() {
             super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap());
@@ -65,7 +66,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format)
 
         @Override
         public Query termQuery(Object value, SearchExecutionContext context) {
-            return sparseVectorFieldType.termQuery(value, context);
+            return null; // NOTE(review): no Query is built now that the sparse-vector delegate is gone; confirm callers accept null
         }
     }
 
@@ -78,40 +79,86 @@ public void parse(DocumentParserContext context) throws IOException {
 
         if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
             throw new IllegalArgumentException(
-                "[_semantic_text_inference] fields must be a json object, expected a START_OBJECT but got: "
+                "[_semantic_text] produced inference must be a json object, expected a START_OBJECT but got: "
                     + context.parser().currentToken()
             );
         }
 
         MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false).createChildContext(NAME);
 
-        // TODO Can we validate that semantic text fields have actual text values?
         for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
             .nextToken()) {
             if (token != XContentParser.Token.FIELD_NAME) {
-                throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token);
+                throw new IllegalArgumentException("[semantic_text] produced inference expect an object with field names, found " + token);
             }
 
             String fieldName = context.parser().currentName();
 
             Mapper mapper = context.getMapper(fieldName);
-            if (mapper == null) {
-                // Not a field we have mapped? Must be model output, skip it
-                context.parser().nextToken();
-                context.path().setWithinLeafObject(true);
-                Map fieldMap = context.parser().map();
-                context.path().setWithinLeafObject(false);
-                continue;
-            }
-            if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
+            if ((mapper == null) || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
                 throw new IllegalArgumentException(
                     "Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type"
                 );
             }
 
-            context.parser().nextToken();
-            SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext);
-            sparseVectorFieldMapper.parse(context);
+            // Nested mapper named after the field, with a sparse vector "inference" child and a non-indexed, non-stored "text" child
+            NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(
+                fieldName,
+                context.indexSettings().getIndexVersionCreated()
+            );
+            SparseVectorFieldMapper.Builder sparseVectorFieldMapperBuilder = new SparseVectorFieldMapper.Builder(
+                "inference"
+            );
+            nestedBuilder.add(sparseVectorFieldMapperBuilder);
+            TextFieldMapper.Builder textFieldMapperBuilder = new TextFieldMapper.Builder("text", context.indexAnalyzers()).index(false)
+                .store(false);
+            nestedBuilder.add(textFieldMapperBuilder);
+            NestedObjectMapper nestedObjectMapper = nestedBuilder.build(mapperBuilderContext);
+
+            if (context.parser().nextToken() != XContentParser.Token.START_ARRAY) {
+                throw new IllegalArgumentException(
+                    "[_semantic_text] produced inference must be an array of objects, expected a START_ARRAY but got: "
+                        + context.parser().currentToken()
+                );
+            }
+            // The loops below reuse the outer loop's token variable; the outer update clause reassigns it before its next check
+            for (token = context.parser().nextToken(); token != XContentParser.Token.END_ARRAY; token = context.parser()
+                .nextToken()) {
+                // Every array element is parsed in its own nested context
+                DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper);
+
+                if (token != XContentParser.Token.START_OBJECT) {
+                    throw new IllegalArgumentException(
+                        "each [_semantic_text] produced inference must be an object, expected a START_OBJECT but got: "
+                            + context.parser().currentToken()
+                    );
+                }
+
+                Set<String> visitedFields = new HashSet<>();
+                for (token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
+                    .nextToken()) {
+
+                    if (token != XContentParser.Token.FIELD_NAME) {
+                        throw new IllegalArgumentException(
+                            "each [semantic_text] produced objects fields expect an object with field names, found " + token
+                        );
+                    }
+
+                    String inferenceField = context.parser().currentName();
+                    FieldMapper childNestedMapper = (FieldMapper) nestedObjectMapper.getMapper(inferenceField);
+                    if (childNestedMapper == null) {
+                        throw new IllegalArgumentException("unexpected inference result field name: " + inferenceField);
+                    }
+                    context.parser().nextToken();
+                    childNestedMapper.parse(nestedContext);
+                    visitedFields.add(inferenceField);
+                }
+                // Reject elements that did not supply every child of the nested mapper
+                if (visitedFields.size() != nestedObjectMapper.getChildren().size()) {
+                    throw new IllegalArgumentException("unexpected inference fields: " + visitedFields);
+                }
+            }
+
         }
     }
 