Skip to content

Commit

Permalink
Test inference service provides different inference results for diffe…
Browse files Browse the repository at this point in the history
…rent inputs
  • Loading branch information
carlosdelest committed May 7, 2024
1 parent 93d0a87 commit 356ac9f
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@ public void testMockService() throws IOException {
assertEquals("text_embedding_test_service", modelMap.get("service"));
}

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
List<String> input = List.of(randomAlphaOfLength(10));
var inference = inferOnMockService(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING);
// Same input should return the same result
assertEquals(inference, inferOnMockService(inferenceEntityId, input));
// Different input values should not
assertNotEquals(
inference,
inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
);
}

public void testMockServiceWithMultipleInputs() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ public void testMockService() throws IOException {
assertEquals("test_service", modelMap.get("service"));
}

// The response is randomly generated, the input can be anything
var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
List<String> input = List.of(randomAlphaOfLength(10));
var inference = inferOnMockService(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
// Same input should return the same result
assertEquals(inference, inferOnMockService(inferenceEntityId, input));
// Different input values should not
assertNotEquals(
inference,
inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
);
}

public void testMockServiceWithMultipleInputs() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@

public abstract class AbstractTestInferenceService implements InferenceService {

protected static int stringWeight(String input, int position) {
int hashCode = input.hashCode();
if (hashCode < 0) {
hashCode = -hashCode;
}
return hashCode + position;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
Expand Down Expand Up @@ -198,4 +206,8 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
}
}

protected static float stringAsFloat(String value) {
return value.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ private TextEmbeddingResults makeResults(List<String> input, int dimensions) {
for (int i = 0; i < input.size(); i++) {
List<Float> values = new ArrayList<>();
for (int j = 0; j < dimensions; j++) {
values.add((float) j);
values.add((float) stringWeight(input.get(i), j));
}
embeddings.add(new TextEmbeddingResults.Embedding(values));
}
Expand All @@ -136,7 +136,7 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
for (int i = 0; i < input.size(); i++) {
double[] values = new double[dimensions];
for (int j = 0; j < 5; j++) {
values[j] = j;
values[j] = stringWeight(input.get(i), j);
}
results.add(
new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private SparseEmbeddingResults makeResults(List<String> input) {
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<SparseEmbeddingResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, j + 1.0F));
tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
}
embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false));
}
Expand All @@ -133,7 +133,7 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F));
tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
}
results.add(
new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)))
Expand All @@ -145,7 +145,6 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
}

}

public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ public void testIngestWithMultipleModelTypes() throws IOException {
assertThat(simulatedDocs, hasSize(2));
assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0)));
var sparseEmbedding = (Map<String, Double>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0));
assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1"));
assertNotNull(sparseEmbedding.get("feature_1"));
assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1)));
sparseEmbedding = (Map<String, Double>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1));
assertEquals(Double.valueOf(2.0), sparseEmbedding.get("feature_1"));
assertNotNull(sparseEmbedding.get("feature_1"));
}

{
Expand Down

0 comments on commit 356ac9f

Please sign in to comment.