Skip to content

Commit

Permalink
Fix tests of SemanticTextInferenceResultFieldMapper
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 14, 2024
1 parent fbefa0b commit 59194d7
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.elasticsearch.inference;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
Expand All @@ -18,13 +17,17 @@
import java.util.Map;
import java.util.Objects;

public record ModelSettings(TaskType taskType, String inferenceId, @Nullable Integer dimensions, @Nullable SimilarityMeasure similarity) {
public class ModelSettings {

public static final String NAME = "model_settings";
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
public static final ParseField SIMILARITY_FIELD = new ParseField("similarity");
private final TaskType taskType;
private final String inferenceId;
private final Integer dimensions;
private final SimilarityMeasure similarity;

public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
Objects.requireNonNull(taskType, "task type must not be null");
Expand Down Expand Up @@ -74,4 +77,20 @@ public Map<String, Object> asMap() {
}
return Map.of(NAME, attrsMap);
}

/** Returns the inference task type these model settings were created for. */
public TaskType taskType() {
return taskType;
}

/** Returns the id of the inference endpoint/model these settings belong to. */
public String inferenceId() {
return inferenceId;
}

/**
 * Returns the embedding dimension count.
 * NOTE(review): may be {@code null} — the visible constructor only null-checks the task type; confirm callers handle absence.
 */
public Integer dimensions() {
return dimensions;
}

/**
 * Returns the similarity measure configured for the model.
 * NOTE(review): may be {@code null} — not enforced in the visible constructor; confirm callers handle absence.
 */
public SimilarityMeasure similarity() {
return similarity;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ private static void parseFieldInferenceChunkElement(
}
}

if (REQUIRED_SUBFIELDS.containsAll(visitedSubfields) == false) {
if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) {
Set<String> missingSubfields = REQUIRED_SUBFIELDS.stream()
.filter(s -> visitedSubfields.contains(s) == false)
.collect(Collectors.toSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.LeafNestedDocuments;
import org.elasticsearch.search.NestedDocuments;
Expand All @@ -51,8 +53,9 @@
import java.util.Set;
import java.util.function.Consumer;

import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.SPARSE_VECTOR_SUBFIELD_NAME;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.TEXT_SUBFIELD_NAME;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT;
import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS;
import static org.hamcrest.Matchers.containsString;

public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase {
Expand Down Expand Up @@ -214,7 +217,7 @@ public void testMissingSubfields() throws IOException {
)
)
);
assertThat(ex.getMessage(), containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + "]"));
assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]"));
}
{
DocumentParsingException ex = expectThrows(
Expand All @@ -232,7 +235,7 @@ public void testMissingSubfields() throws IOException {
)
)
);
assertThat(ex.getMessage(), containsString("Missing required subfields: [" + TEXT_SUBFIELD_NAME + "]"));
assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]"));
}
{
DocumentParsingException ex = expectThrows(
Expand All @@ -252,7 +255,7 @@ public void testMissingSubfields() throws IOException {
);
assertThat(
ex.getMessage(),
containsString("Missing required subfields: [" + SPARSE_VECTOR_SUBFIELD_NAME + ", " + TEXT_SUBFIELD_NAME + "]")
containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]")
);
}
}
Expand Down Expand Up @@ -411,8 +414,10 @@ private static void addSemanticTextInferenceResults(
Map<String, Object> extraSubfields
) throws IOException {

Map<String, List<Map<String, Object>>> inferenceResultsMap = new HashMap<>();
Map<String, Map<String, Object>> inferenceResultsMap = new HashMap<>();
for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) {
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put(ModelSettings.NAME, modelSettingsMap());
List<Map<String, Object>> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size());

Iterator<SparseEmbeddingResults.Embedding> embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults()
Expand All @@ -425,17 +430,10 @@ private static void addSemanticTextInferenceResults(

Map<String, Object> subfieldMap = new HashMap<>();
if (sparseVectorSubfieldOptions.include()) {
Map<String, Object> embeddingMap = embedding.asMap();
if (sparseVectorSubfieldOptions.includeIsTruncated() == false) {
embeddingMap.remove(SparseEmbeddingResults.Embedding.IS_TRUNCATED);
}
if (sparseVectorSubfieldOptions.includeEmbedding() == false) {
embeddingMap.remove(SparseEmbeddingResults.Embedding.EMBEDDING);
}
subfieldMap.put(SPARSE_VECTOR_SUBFIELD_NAME, embeddingMap);
subfieldMap.put(INFERENCE_CHUNKS_RESULTS, embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING));
}
if (includeTextSubfield) {
subfieldMap.put(TEXT_SUBFIELD_NAME, text);
subfieldMap.put(INFERENCE_CHUNKS_TEXT, text);
}
if (extraSubfields != null) {
subfieldMap.putAll(extraSubfields);
Expand All @@ -444,28 +442,40 @@ private static void addSemanticTextInferenceResults(
parsedInferenceResults.add(subfieldMap);
}

inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), parsedInferenceResults);
fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults);
inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap);
}

sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap);
}

/**
 * Builds a minimal {@code model_settings} map for a sparse-embedding model,
 * using a randomly generated inference id.
 */
private static Map<String, Object> modelSettingsMap() {
    Map<String, Object> settings = new HashMap<>();
    settings.put(ModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING.toString());
    settings.put(ModelSettings.INFERENCE_ID_FIELD.getPreferredName(), randomAlphaOfLength(8));
    // Map.copyOf yields the same unmodifiable two-entry map Map.of would produce.
    return Map.copyOf(settings);
}

/**
 * Adds a nested-object mapping for {@code semanticTextFieldName} containing the two
 * inference chunk subfields: a sparse_vector results field and a non-indexed text field.
 * The builder-call sequence is order-sensitive and mirrors the generated JSON structure.
 */
private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException {
    mappingBuilder.startObject(semanticTextFieldName);
    mappingBuilder.field("type", "nested");
    mappingBuilder.startObject("properties");
    // chunk results: sparse vector embedding
    mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS);
    mappingBuilder.field("type", "sparse_vector");
    mappingBuilder.endObject();
    // chunk text: stored for retrieval but not indexed
    mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT);
    mappingBuilder.field("type", "text");
    mappingBuilder.field("index", false);
    mappingBuilder.endObject();
    mappingBuilder.endObject(); // properties
    mappingBuilder.endObject(); // semanticTextFieldName
}

Expand All @@ -478,9 +488,7 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook
for (String token : tokens) {
queryBuilder.add(
new BooleanClause(
new TermQuery(
new Term(path + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING, token)
),
new TermQuery(new Term(path + "." + INFERENCE_CHUNKS_RESULTS, token)),
BooleanClause.Occur.MUST
)
);
Expand All @@ -499,9 +507,8 @@ private static void assertValidChildDoc(
visitedChildDocs.add(
new VisitedChildDocInfo(
childDoc.getPath(),
childDoc.getFields(
childDoc.getPath() + "." + SPARSE_VECTOR_SUBFIELD_NAME + "." + SparseEmbeddingResults.Embedding.EMBEDDING
).size()
childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS)
.size()
)
);
}
Expand Down

0 comments on commit 59194d7

Please sign in to comment.