From 03423e3b2f2954d6f2da0a42385da21a96f31c91 Mon Sep 17 00:00:00 2001
From: carlosdelest <carlos.delgado@elastic.co>
Date: Tue, 5 Dec 2023 12:27:16 +0100
Subject: [PATCH] Added tests

---
 .../ml/mapper/SemanticTextFieldMapper.java    |  40 ++++--
 .../mapper/SemanticTextFieldMapperTests.java  | 126 ++++++++++++++++++
 2 files changed, 152 insertions(+), 14 deletions(-)
 create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java

diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java
index df1447c5368bb..b3bc399c04a4c 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java
@@ -9,6 +9,8 @@
 
 import org.apache.lucene.search.Query;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.index.fielddata.FieldDataContext;
+import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.DocumentParserContext;
 import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
@@ -17,7 +19,6 @@
 import org.elasticsearch.index.mapper.SourceValueFetcher;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
-import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.query.SearchExecutionContext;
 
 import java.io.IOException;
@@ -32,13 +33,18 @@ private static SemanticTextFieldMapper toType(FieldMapper in) {
         return (SemanticTextFieldMapper) in;
     }
 
+    private static Builder builder(FieldMapper in) {
+        return ((SemanticTextFieldMapper) in).builder;
+    }
+
     public static class Builder extends FieldMapper.Builder {
 
-        final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> toType(m).modelId, null).addValidator(v -> {
-            if (Strings.isEmpty(v)) {
-                throw new IllegalArgumentException("field [model_id] must be specified");
-            }
-        });
+        private final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> builder(m).modelId.get(), null)
+            .addValidator(v -> {
+                if (Strings.isEmpty(v)) {
+                    throw new IllegalArgumentException("field [model_id] must be specified");
+                }
+            });
 
         private final Parameter<Map<String, String>> meta = Parameter.metaParam();
 
@@ -62,7 +68,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
                 name(),
                 new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()),
                 modelId.getValue(),
-                copyTo
+                copyTo,
+                this
             );
         }
     }
@@ -71,13 +78,10 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
 
     public static class SemanticTextFieldType extends SimpleMappedFieldType {
 
-        private final SparseVectorFieldMapper.SparseVectorFieldType sparseVectorFieldType;
-
         private final String modelId;
 
         public SemanticTextFieldType(String name, String modelId, Map<String, String> meta) {
-            super(name, true, false, false, TextSearchInfo.NONE, meta);
-            this.sparseVectorFieldType = new SparseVectorFieldMapper.SparseVectorFieldType(name + "." + "inference", meta);
+            super(name, false, false, false, TextSearchInfo.NONE, meta);
             this.modelId = modelId;
         }
 
@@ -93,19 +97,27 @@ public String getInferenceModel() {
 
         @Override
         public Query termQuery(Object value, SearchExecutionContext context) {
-            return null;
+            throw new IllegalArgumentException("termQuery not implemented yet");
         }
 
         @Override
         public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
-            return SourceValueFetcher.identity(name(), context, format);
+            return SourceValueFetcher.toString(name(), context, format);
+        }
+
+        @Override
+        public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) {
+            throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating");
         }
     }
 
     private final String modelId;
 
-    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo) {
+    private final Builder builder;
+
+    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo, Builder builder) {
         super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
+        this.builder = builder;
         this.modelId = modelId;
     }
 
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java
new file mode 100644
index 0000000000000..0f08abfd6fa59
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.mapper;
+
+import org.apache.lucene.index.IndexableField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.mapper.DocumentMapper;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.MapperParsingException;
+import org.elasticsearch.index.mapper.MapperService;
+import org.elasticsearch.index.mapper.MapperTestCase;
+import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.AssumptionViolatedException;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import static java.util.Collections.singletonList;
+import static org.hamcrest.Matchers.containsString;
+
+public class SemanticTextFieldMapperTests extends MapperTestCase {
+
+    public void testDefaults() throws Exception {
+        DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
+        assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());
+
+        ParsedDocument doc1 = mapper.parse(source(this::writeField));
+        List<IndexableField> fields = doc1.rootDoc().getFields("field");
+
+        // No indexable fields
+        assertTrue(fields.isEmpty());
+    }
+
+    public void testModelIdNotPresent() throws IOException {
+        Exception e = expectThrows(
+            MapperParsingException.class,
+            () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
+        );
+        assertThat(e.getMessage(), containsString("field [model_id] must be specified"));
+    }
+
+    public void testCannotBeUsedInMultiFields() {
+        Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
+            b.field("type", "text");
+            b.startObject("fields");
+            b.startObject("semantic");
+            b.field("type", "semantic_text");
+            b.endObject();
+            b.endObject();
+        })));
+        assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
+    }
+
+    public void testUpdatesToModelIdNotSupported() throws IOException {
+        MapperService mapperService = createMapperService(
+            fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model"))
+        );
+        Exception e = expectThrows(
+            IllegalArgumentException.class,
+            () ->  merge(
+                mapperService,
+                fieldMapping(
+                    b -> b.field("type", "semantic_text")
+                        .field("model_id", "another_model")
+                )
+            )
+        );
+        assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]"));
+    }
+
+    @Override
+    protected Collection<? extends Plugin> getPlugins() {
+        return singletonList(new MachineLearning(Settings.EMPTY));
+    }
+
+
+    @Override
+    protected void minimalMapping(XContentBuilder b) throws IOException {
+        b.field("type", "semantic_text").field("model_id", "test_model");
+    }
+
+    @Override
+    protected Object getSampleValueForDocument() {
+        return "value";
+    }
+
+    @Override
+    protected boolean supportsIgnoreMalformed() {
+        return false;
+    }
+
+    @Override
+    protected boolean supportsStoredFields() {
+        return false;
+    }
+
+    @Override
+    protected void registerParameters(ParameterChecker checker) throws IOException {
+    }
+
+    @Override
+    protected Object generateRandomInputValue(MappedFieldType ft) {
+        assumeFalse("doc_values are not supported in semantic_text", true);
+        return null;
+    }
+
+    @Override
+    protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
+        throw new AssumptionViolatedException("not supported");
+    }
+
+    @Override
+    protected IngestScriptSupport ingestScriptSupport() {
+        throw new AssumptionViolatedException("not supported");
+    }
+}