Skip to content

Commit

Permalink
Updated SemanticTextFieldMapperTests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikep86 committed Sep 25, 2024
1 parent 932fe8f commit 57cccee
Showing 1 changed file with 75 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.SEARCH_INFERENCE_ID_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_INFERENCE_ID;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_SEARCH_INFERENCE_ID;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -92,7 +95,17 @@ protected Collection<? extends Plugin> getPlugins() {

@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "semantic_text").field("inference_id", "test_model");
b.field("type", "semantic_text");
}

private void minimalMappingWithDefaults(XContentBuilder b) throws IOException {
minimalMapping(b);
b.field(INFERENCE_ID_FIELD, DEFAULT_INFERENCE_ID).field(SEARCH_INFERENCE_ID_FIELD, DEFAULT_SEARCH_INFERENCE_ID);
}

@Override
protected void metaMapping(XContentBuilder b) throws IOException {
minimalMappingWithDefaults(b);
}

@Override
Expand Down Expand Up @@ -156,7 +169,7 @@ protected void assertSearchable(MappedFieldType fieldType) {

public void testDefaults() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());
assertEquals(Strings.toString(fieldMapping(this::minimalMappingWithDefaults)), mapper.mappingSource().toString());

ParsedDocument doc1 = mapper.parse(source(this::writeField));
List<IndexableField> fields = doc1.rootDoc().getFields("field");
Expand All @@ -172,12 +185,40 @@ public void testFieldHasValue() {
assertTrue(fieldType.fieldHasValue(fieldInfos));
}

public void testInferenceIdNotPresent() {
public void testSetInferenceEndpoints() throws IOException {
final String fieldName = "field";
final String inferenceId = "foo";
final String searchInferenceId = "bar";

{
MapperService mapperService = createMapperService(
fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, inferenceId))
);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
}
{
MapperService mapperService = createMapperService(
fieldMapping(
b -> b.field("type", "semantic_text")
.field(INFERENCE_ID_FIELD, inferenceId)
.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
)
);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId);
}
}

public void testInferenceIdRequired() {
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "foo")))
);
assertThat(
e.getMessage(),
containsString("[inference_id] must be specified for semantic_text field [field] when [search_inference_id] is specified")
);
assertThat(e.getMessage(), containsString("field [inference_id] must be specified"));
}

public void testCannotBeUsedInMultiFields() {
Expand Down Expand Up @@ -221,7 +262,7 @@ public void testDynamicUpdate() throws IOException {
new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, inferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
}

{
Expand All @@ -232,7 +273,7 @@ public void testDynamicUpdate() throws IOException {
new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, searchInferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId);
}
}

Expand Down Expand Up @@ -331,39 +372,39 @@ public void testUpdateSearchInferenceId() throws IOException {
String fieldName = randomFieldName(depth);
MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null));
assertSemanticTextField(mapperService, fieldName, false);
assertSearchInferenceId(mapperService, fieldName, inferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);

merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1));
assertSemanticTextField(mapperService, fieldName, false);
assertSearchInferenceId(mapperService, fieldName, searchInferenceId1);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1);

merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2));
assertSemanticTextField(mapperService, fieldName, false);
assertSearchInferenceId(mapperService, fieldName, searchInferenceId2);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2);

merge(mapperService, buildMapping.apply(fieldName, null));
assertSemanticTextField(mapperService, fieldName, false);
assertSearchInferenceId(mapperService, fieldName, inferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);

mapperService = mapperServiceForFieldWithModelSettings(
fieldName,
inferenceId,
new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
);
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, inferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);

merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1));
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, searchInferenceId1);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1);

merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2));
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, searchInferenceId2);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2);

merge(mapperService, buildMapping.apply(fieldName, null));
assertSemanticTextField(mapperService, fieldName, true);
assertSearchInferenceId(mapperService, fieldName, inferenceId);
assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
}
}

Expand Down Expand Up @@ -409,11 +450,17 @@ private static void assertSemanticTextField(MapperService mapperService, String
}
}

private static void assertSearchInferenceId(MapperService mapperService, String fieldName, String expectedSearchInferenceId) {
private static void assertInferenceEndpoints(
MapperService mapperService,
String fieldName,
String expectedInferenceId,
String expectedSearchInferenceId
) {
var fieldType = mapperService.fieldType(fieldName);
assertNotNull(fieldType);
assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class));
SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType;
assertEquals(expectedInferenceId, semanticTextFieldType.getInferenceId());
assertEquals(expectedSearchInferenceId, semanticTextFieldType.getSearchInferenceId());
}

Expand All @@ -433,9 +480,19 @@ public void testSuccessfulParse() throws IOException {

MapperService mapperService = createMapperService(mapping);
assertSemanticTextField(mapperService, fieldName1, false);
assertSearchInferenceId(mapperService, fieldName1, setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId());
assertInferenceEndpoints(
mapperService,
fieldName1,
model1.getInferenceEntityId(),
setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId()
);
assertSemanticTextField(mapperService, fieldName2, false);
assertSearchInferenceId(mapperService, fieldName2, setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId());
assertInferenceEndpoints(
mapperService,
fieldName2,
model2.getInferenceEntityId(),
setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId()
);

DocumentMapper documentMapper = mapperService.documentMapper();
ParsedDocument doc = documentMapper.parse(
Expand Down

0 comments on commit 57cccee

Please sign in to comment.