Skip to content

Commit

Permalink
PR review - add validations for ModelSettings
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Apr 18, 2024
1 parent b3246a6 commit 6a32de2
Showing 1 changed file with 37 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
Expand All @@ -30,6 +32,7 @@

import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.toSemanticTextFieldChunks;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> {
Expand Down Expand Up @@ -95,6 +98,40 @@ protected boolean supportsUnknownFields() {
return true;
}

public void testModelSettingsValidation() {
NullPointerException npe = expectThrows(NullPointerException.class, () -> {
new SemanticTextField.ModelSettings(null, 10, SimilarityMeasure.COSINE);
});
assertThat(npe.getMessage(), equalTo("task type must not be null"));

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> {
new SemanticTextField.ModelSettings(TaskType.COMPLETION, 10, SimilarityMeasure.COSINE);
});
assertThat(ex.getMessage(), containsString("Wrong [task_type]"));

ex = expectThrows(
IllegalArgumentException.class,
() -> { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, 10, null); }
);
assertThat(ex.getMessage(), containsString("[dimensions] is not allowed"));

ex = expectThrows(IllegalArgumentException.class, () -> {
new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE);
});
assertThat(ex.getMessage(), containsString("[similarity] is not allowed"));

ex = expectThrows(IllegalArgumentException.class, () -> {
new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, null, SimilarityMeasure.COSINE);
});
assertThat(ex.getMessage(), containsString("required [dimensions] field is missing"));

ex = expectThrows(
IllegalArgumentException.class,
() -> { new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, null); }
);
assertThat(ex.getMessage(), containsString("required [similarity] field is missing"));
}

public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List<String> inputs) {
List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks = new ArrayList<>();
for (String input : inputs) {
Expand Down

0 comments on commit 6a32de2

Please sign in to comment.