Commit 914ce1c: PR feedback

carlosdelest committed Apr 16, 2024
1 parent 92cdfa8 commit 914ce1c
Showing 4 changed files with 93 additions and 81 deletions.
SemanticTextField.java

@@ -8,35 +8,26 @@
package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.support.MapXContentParser;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -100,6 +91,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder.endObject();
}

@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append("task_type=").append(taskType);
if (dimensions != null) {
sb.append(", dimensions=").append(dimensions);
}
if (similarity != null) {
sb.append(", similarity=").append(similarity);
}
return sb.toString();
}
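
For reference, this renders model settings as plain key=value pairs, matching the expected strings in the updated mapper tests below. A minimal sketch of the result (the ModelSettings constructor shape is an assumption here, as it sits outside this diff):

// Hypothetical usage; the constructor signature is assumed, not shown in this diff.
var settings = new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE);
assert settings.toString().equals("task_type=text_embedding, dimensions=10, similarity=cosine");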

private void validate() {
switch (taskType) {
case TEXT_EMBEDDING:
@@ -257,71 +261,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD);
MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD);
}
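
For context, these declarations let MODEL_SETTINGS_PARSER read a map of the following shape (a hypothetical document; the field names follow the ParseField constants used above and the test expectations below):

{
  "task_type": "text_embedding",
  "dimensions": 10,
  "similarity": "cosine"
}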

/**
* Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}.
*/
public static List<Chunk> toSemanticTextFieldChunks(
String field,
String inferenceId,
List<ChunkedInferenceServiceResults> results,
XContentType contentType
) {
List<Chunk> chunks = new ArrayList<>();
for (var result : results) {
if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
for (var chunk : textExpansionResults.getChunkedResults()) {
chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens())));
}
} else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
for (var chunk : textEmbeddingResults.getChunks()) {
chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding())));
}
} else {
throw new ElasticsearchStatusException(
"Invalid inference results format for field [{}] with inference id [{}], got {}",
RestStatus.BAD_REQUEST,
field,
inferenceId,
result.getWriteableName()
);
}
}
return chunks;
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, double[] value) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
for (double v : value) {
b.value(v);
}
b.endArray();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

/**
* Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent},
* into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, List<TextExpansionResults.WeightedToken> tokens) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
}
b.endObject();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}
}
SemanticTextFieldMapper.java

@@ -90,7 +90,7 @@ public static class Builder extends FieldMapper.Builder
(n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o),
mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings,
XContentBuilder::field,
- (m) -> m == null ? "null" : Strings.toString(m)
+ Objects::toString
).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings);
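
This swap is safe for the null case: java.util.Objects.toString(Object) returns the string "null" for a null argument and otherwise delegates to the argument's own toString(), which is now the key=value rendering added above. A minimal check:

// Objects.toString is null-safe and delegates to toString() for non-null values.
assert Objects.toString(null).equals("null");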

private final Parameter<Map<String, String>> meta = Parameter.metaParam();
SemanticTextFieldMapperTests.java

@@ -264,7 +264,7 @@ public void testUpdateModelSettings() throws IOException {
);
assertThat(
exc.getMessage(),
containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]")
containsString("Cannot update parameter [model_settings] " + "from [task_type=sparse_embedding] to [null]")
);
}
{
@@ -289,8 +289,8 @@ public void testUpdateModelSettings() throws IOException {
exc.getMessage(),
containsString(
"Cannot update parameter [model_settings] "
+ "from [{\"task_type\":\"sparse_embedding\"}] "
+ "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]"
+ "from [task_type=sparse_embedding] "
+ "to [task_type=text_embedding, dimensions=10, similarity=cosine]"
)
);
}
SemanticTextFieldTests.java

@@ -7,13 +7,18 @@

package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
@@ -30,7 +35,6 @@
import java.util.function.Predicate;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
import static org.hamcrest.Matchers.equalTo;

public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> {
@@ -143,6 +147,77 @@ public static SemanticTextField randomSemanticText(String fieldName, Model model
);
}

/**
* Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link SemanticTextField.Chunk}.
*/
private static List<SemanticTextField.Chunk> toSemanticTextFieldChunks(
String field,
String inferenceId,
List<ChunkedInferenceServiceResults> results,
XContentType contentType
) {
List<SemanticTextField.Chunk> chunks = new ArrayList<>();
for (var result : results) {
if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
for (var chunk : textExpansionResults.getChunkedResults()) {
chunks.add(
new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))
);
}
} else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
for (var chunk : textEmbeddingResults.getChunks()) {
chunks.add(
new SemanticTextField.Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))
);
}
} else {
throw new ElasticsearchStatusException(
"Invalid inference results format for field [{}] with inference id [{}], got {}",
RestStatus.BAD_REQUEST,
field,
inferenceId,
result.getWriteableName()
);
}
}
return chunks;
}
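
A hypothetical call site, with the field name and inference id invented for illustration:

// Hypothetical example values; results is a List<ChunkedInferenceServiceResults>.
List<SemanticTextField.Chunk> chunks = toSemanticTextFieldChunks("body", "my-inference-id", results, XContentType.JSON);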

/**
* Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent},
* into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, List<TextExpansionResults.WeightedToken> tokens) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
}
b.endObject();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}

/**
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
*/
private static BytesReference toBytesReference(XContent xContent, double[] value) {
try {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
for (double v : value) {
b.value(v);
}
b.endArray();
return BytesReference.bytes(b);
} catch (IOException exc) {
throw new RuntimeException(exc);
}
}
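
As a sanity check, a minimal sketch of what the array overload emits, assuming JSON as the content type:

// Hypothetical input; a dense vector serialises to a compact JSON array.
BytesReference bytes = toBytesReference(XContentType.JSON.xContent(), new double[] { 0.1, 0.2 });
// bytes.utf8ToString() then yields "[0.1,0.2]".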

public static Model randomModel(TaskType taskType) {
String serviceName = randomAlphaOfLengthBetween(5, 10);
String inferenceId = randomAlphaOfLengthBetween(5, 10);
