Skip to content

Commit

Permalink
another test
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 29, 2023
1 parent dc2c4bd commit 00c7a4b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package org.elasticsearch.inference;

import org.elasticsearch.client.internal.Client;
import org.elasticsearch.inference.InferenceService;

import java.util.List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.ml.integration;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -27,6 +26,7 @@
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;

public class CoordinatedInferenceIngestIT extends ESRestTestCase {
Expand Down Expand Up @@ -85,23 +85,25 @@ public void testIngestWithMultipleModelTypes() throws IOException {
{
var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinition(inferenceServiceModelId), docs);
var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
System.out.println("IS DOCS " + simulatedDocs);
assertThat(simulatedDocs, hasSize(2));
assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0)));
assertEquals("bar", MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0)));
var sparseEmbedding = (Map<String, Double>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0));
assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1"));
assertEquals(inferenceServiceModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1)));
assertEquals("bar", MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1)));
sparseEmbedding = (Map<String, Double>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1));
assertEquals(Double.valueOf(1.0), sparseEmbedding.get("1"));
}

{
var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinition(pyTorchModelId), docs);
var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
System.out.println("PT DOCS " + simulatedDocs);
assertThat(simulatedDocs, hasSize(2));
assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(0)));
List<List<Double>> results = (List<List<Double>>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(0));
assertThat(results.get(0), contains(1.0, 1.0));
assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.model_id", simulatedDocs.get(1)));
results = (List<List<Double>>) MapHelper.dig("doc._source.ml.body", simulatedDocs.get(1));
assertThat(results.get(0), contains(1.0, 1.0));
}

String boostedTreeDocs = Strings.format("""
Expand All @@ -120,7 +122,6 @@ public void testIngestWithMultipleModelTypes() throws IOException {
boostedTreeDocs
);
var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
System.out.println("DFA DOCS " + simulatedDocs);
assertThat(simulatedDocs, hasSize(2));
assertEquals(boostedTreeModelId, MapHelper.dig("doc._source.ml.regression.model_id", simulatedDocs.get(0)));
assertNotNull(MapHelper.dig("doc._source.ml.regression.predicted_value", simulatedDocs.get(0)));
Expand Down Expand Up @@ -161,18 +162,25 @@ public void testPipelineConfiguredWithFieldMap() throws IOException {
{
var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinitionWithFieldMap(pyTorchModelId), docs);
var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
System.out.println("DOCS pt" + simulatedDocs);
assertThat(simulatedDocs, hasSize(2));
// assertEquals(boostedTreeModelId, MapHelper.dig("doc._source.ml.regression.model_id", simulatedDocs.get(0)));
assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.inference.model_id", simulatedDocs.get(0)));
List<List<Double>> results = (List<List<Double>>) MapHelper.dig(
"doc._source.ml.inference.predicted_value",
simulatedDocs.get(0)
);
assertThat(results.get(0), contains(1.0, 1.0));
assertEquals(pyTorchModelId, MapHelper.dig("doc._source.ml.inference.model_id", simulatedDocs.get(1)));
results = (List<List<Double>>) MapHelper.dig("doc._source.ml.inference.predicted_value", simulatedDocs.get(1));
assertThat(results.get(0), contains(1.0, 1.0));
}

{
// Inference service models cannot be configured with the field map
var responseMap = simulatePipeline(ExampleModels.nlpModelPipelineDefinitionWithFieldMap(inferenceServiceModelId), docs);
var simulatedDocs = (List<Map<String, Object>>) responseMap.get("docs");
System.out.println("DOCS is" + simulatedDocs);
var errorMsg = (String) MapHelper.dig("error.reason", simulatedDocs.get(0));
assertThat(errorMsg, containsString("[is_model] is configured for the _inference API and does not accept documents as input"));
assertThat(simulatedDocs, hasSize(2));
// "[" + inferenceServiceModelId + "] is configured for the _inference API and does not accept documents as input"
// assertEquals(boostedTreeModelId, MapHelper.dig("doc._source.ml.regression.model_id", simulatedDocs.get(0)));
}

}
Expand Down Expand Up @@ -216,8 +224,6 @@ protected Map<String, Object> simulatePipeline(String pipelineDef, String docs)

Request request = new Request("POST", "_ingest/pipeline/_simulate?error_trace=true");
request.setJsonEntity(simulate);
var response = client().performRequest(request);
System.out.println(EntityUtils.toString(response.getEntity()));
return entityAsMap(client().performRequest(request));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ public static String nlpModelPipelineDefinitionWithFieldMap(String modelId) {
"inference": {
"model_id": "%s",
"field_map": {
"text_field": "body"
"body": "input"
}
}
}
Expand Down

0 comments on commit 00c7a4b

Please sign in to comment.