From 57ccceef58411f191bf4170d8066ac8c5f89042a Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 25 Sep 2024 14:29:33 -0400 Subject: [PATCH] Updated SemanticTextFieldMapperTests --- .../mapper/SemanticTextFieldMapperTests.java | 93 +++++++++++++++---- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 1697b33fedd92..e74ca584c8bee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -77,8 +77,11 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.SEARCH_INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_INFERENCE_ID; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_SEARCH_INFERENCE_ID; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -92,7 +95,17 @@ protected Collection getPlugins() { @Override protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text").field("inference_id", "test_model"); + b.field("type", "semantic_text"); + } + + private void minimalMappingWithDefaults(XContentBuilder b) throws IOException { + minimalMapping(b); + b.field(INFERENCE_ID_FIELD, DEFAULT_INFERENCE_ID).field(SEARCH_INFERENCE_ID_FIELD, DEFAULT_SEARCH_INFERENCE_ID); + } + + @Override + protected void metaMapping(XContentBuilder b) throws IOException { + minimalMappingWithDefaults(b); } @Override @@ -156,7 +169,7 @@ protected void assertSearchable(MappedFieldType fieldType) { public void testDefaults() throws Exception { DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + assertEquals(Strings.toString(fieldMapping(this::minimalMappingWithDefaults)), mapper.mappingSource().toString()); ParsedDocument doc1 = mapper.parse(source(this::writeField)); List fields = doc1.rootDoc().getFields("field"); @@ -172,12 +185,40 @@ public void testFieldHasValue() { assertTrue(fieldType.fieldHasValue(fieldInfos)); } - public void testInferenceIdNotPresent() { + public void testSetInferenceEndpoints() throws IOException { + final String fieldName = "field"; + final String inferenceId = "foo"; + final String searchInferenceId = "bar"; + + { + MapperService mapperService = createMapperService( + fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, inferenceId)) + ); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); + } + { + MapperService mapperService = createMapperService( + fieldMapping( + b -> b.field("type", "semantic_text") + .field(INFERENCE_ID_FIELD, inferenceId) + .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) + ) + ); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); + } + } + + public void testInferenceIdRequired() { Exception e = expectThrows( MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "foo"))) + ); + assertThat( + e.getMessage(), + containsString("[inference_id] must be specified for semantic_text field [field] when [search_inference_id] is specified") ); - assertThat(e.getMessage(), containsString("field [inference_id] must be specified")); } public void testCannotBeUsedInMultiFields() { @@ -221,7 +262,7 @@ public void testDynamicUpdate() throws IOException { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, inferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); } { @@ -232,7 +273,7 @@ public void testDynamicUpdate() throws IOException { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, searchInferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); } } @@ -331,19 +372,19 @@ public void testUpdateSearchInferenceId() throws IOException { String fieldName = randomFieldName(depth); MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null)); assertSemanticTextField(mapperService, fieldName, false); - assertSearchInferenceId(mapperService, fieldName, inferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); assertSemanticTextField(mapperService, fieldName, false); - assertSearchInferenceId(mapperService, fieldName, searchInferenceId1); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); assertSemanticTextField(mapperService, fieldName, false); - assertSearchInferenceId(mapperService, fieldName, searchInferenceId2); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); merge(mapperService, buildMapping.apply(fieldName, null)); assertSemanticTextField(mapperService, fieldName, false); - assertSearchInferenceId(mapperService, fieldName, inferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); mapperService = mapperServiceForFieldWithModelSettings( fieldName, @@ -351,19 +392,19 @@ public void testUpdateSearchInferenceId() throws IOException { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, inferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, searchInferenceId1); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, searchInferenceId2); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); merge(mapperService, buildMapping.apply(fieldName, null)); assertSemanticTextField(mapperService, fieldName, true); - assertSearchInferenceId(mapperService, fieldName, inferenceId); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); } } @@ -409,11 +450,17 @@ private static void assertSemanticTextField(MapperService mapperService, String } } - private static void assertSearchInferenceId(MapperService mapperService, String fieldName, String expectedSearchInferenceId) { + private static void assertInferenceEndpoints( + MapperService mapperService, + String fieldName, + String expectedInferenceId, + String expectedSearchInferenceId + ) { var fieldType = mapperService.fieldType(fieldName); assertNotNull(fieldType); assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertEquals(expectedInferenceId, semanticTextFieldType.getInferenceId()); assertEquals(expectedSearchInferenceId, semanticTextFieldType.getSearchInferenceId()); } @@ -433,9 +480,19 @@ public void testSuccessfulParse() throws IOException { MapperService mapperService = createMapperService(mapping); assertSemanticTextField(mapperService, fieldName1, false); - assertSearchInferenceId(mapperService, fieldName1, setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId()); + assertInferenceEndpoints( + mapperService, + fieldName1, + model1.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId() + ); assertSemanticTextField(mapperService, fieldName2, false); - assertSearchInferenceId(mapperService, fieldName2, setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId()); + assertInferenceEndpoints( + mapperService, + fieldName2, + model2.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId() + ); DocumentMapper documentMapper = mapperService.documentMapper(); ParsedDocument doc = documentMapper.parse(