toSemanticTextFieldChunks() method belongs to SemanticTextField - fixing
carlosdelest committed Apr 30, 2024
1 parent b382b3b commit 60f5e72
Showing 3 changed files with 78 additions and 82 deletions.
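The net effect of the commit: toSemanticTextFieldChunks() and its toBytesReference() helpers move from SemanticTextFieldMapper into SemanticTextField, so callers now reach the method through the field class itself (see the one-line import change in the third file below). A minimal call-site sketch, not part of the diff: the class name, field name, and inference id are made up for illustration, and an empty results list is used so the call simply returns an empty chunk list.

package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.xcontent.XContentType;

import java.util.List;

class ChunksCallSiteSketch {
    // Hypothetical caller: with no inference results there is nothing to convert,
    // so the relocated SemanticTextField.toSemanticTextFieldChunks returns an empty list.
    static List<SemanticTextField.Chunk> emptyExample() {
        List<ChunkedInferenceServiceResults> results = List.of();
        return SemanticTextField.toSemanticTextFieldChunks("my_semantic_field", "my-inference-id", results, XContentType.JSON);
    }
}
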
@@ -8,26 +8,35 @@
package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.support.MapXContentParser;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -271,4 +280,72 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
}

/**
* Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}.
*/
public static List<Chunk> toSemanticTextFieldChunks(
String field,
String inferenceId,
List<ChunkedInferenceServiceResults> results,
XContentType contentType
) {
List<Chunk> chunks = new ArrayList<>();
for (var result : results) {
if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
for (var chunk : textExpansionResults.getChunkedResults()) {
chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens())));
}
} else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
for (var chunk : textEmbeddingResults.getChunks()) {
chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding())));
}
} else {
throw new ElasticsearchStatusException(
"Invalid inference results format for field [{}] with inference id [{}], got {}",
RestStatus.BAD_REQUEST,
field,
inferenceId,
result.getWriteableName()
);
}
}
return chunks;
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, double[] value) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
for (double v : value) {
b.value(v);
}
b.endArray();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

/**
* Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent},
* into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, List<TextExpansionResults.WeightedToken> tokens) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
}
b.endObject();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

}
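For reference, the two toBytesReference overloads above simply stream their input through an XContentBuilder: the double[] overload writes a bare array of embedding values, while the WeightedToken overload writes a single object to which each token appends its own content via toXContent. A standalone sketch of the array case (not part of the diff), assuming JsonXContent purely so the output is readable; the production code uses whichever XContentType the caller passes in.

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;

import java.io.IOException;

class EmbeddingBytesSketch {
    public static void main(String[] args) throws IOException {
        double[] embedding = { 0.12, -0.34, 0.56 };
        // Same pattern as toBytesReference(XContent, double[]) above: open an array,
        // emit each value, close it, and capture the builder's bytes.
        XContentBuilder b = XContentBuilder.builder(JsonXContent.jsonXContent);
        b.startArray();
        for (double v : embedding) {
            b.value(v);
        }
        b.endArray();
        System.out.println(BytesReference.bytes(b).utf8ToString()); // e.g. [0.12,-0.34,0.56]
    }
}
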
@@ -8,11 +8,9 @@
package org.elasticsearch.xpack.inference.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.Explicit;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
@@ -37,19 +35,11 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentLocation;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.ArrayList;
@@ -275,77 +265,6 @@ public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields);
}

/**
* Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link SemanticTextField.Chunk}.
*/
static List<SemanticTextField.Chunk> toSemanticTextFieldChunks(
String field,
String inferenceId,
List<ChunkedInferenceServiceResults> results,
XContentType contentType
) {
List<SemanticTextField.Chunk> chunks = new ArrayList<>();
for (var result : results) {
if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
for (var chunk : textExpansionResults.getChunkedResults()) {
chunks.add(
new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))
);
}
} else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
for (var chunk : textEmbeddingResults.getChunks()) {
chunks.add(
new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))
);
}
} else {
throw new ElasticsearchStatusException(
"Invalid inference results format for field [{}] with inference id [{}], got {}",
RestStatus.BAD_REQUEST,
field,
inferenceId,
result.getWriteableName()
);
}
}
return chunks;
}

/**
* Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent},
* into a {@link BytesReference}.
*/
static BytesReference toBytesReference(XContent xContent, List<TextExpansionResults.WeightedToken> tokens) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
}
b.endObject();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, double[] value) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
for (double v : value) {
b.value(v);
}
b.endArray();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

public static class SemanticTextFieldType extends SimpleMappedFieldType {
private final String inferenceId;
private final SemanticTextField.ModelSettings modelSettings;
@@ -31,7 +31,7 @@
import java.util.function.Predicate;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
- import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.toSemanticTextFieldChunks;
+ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
