Skip to content

Commit

Permalink
[ML] More checks and tests for parsing Inference processor config (el…
Browse files Browse the repository at this point in the history
…astic#100335)

Following on from elastic#100205 this PR adds more tests and checks 
for corner cases when parsing the configuration.
  • Loading branch information
davidkyle committed Oct 6, 2023
1 parent 89e8c65 commit 46d9594
Show file tree
Hide file tree
Showing 16 changed files with 401 additions and 69 deletions.
3 changes: 2 additions & 1 deletion docs/reference/ingest/processors/inference.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ ingested in the pipeline.
|======
| Name | Required | Default | Description
| `model_id` . | yes | - | (String) The ID or alias for the trained model, or the ID of the deployment.
| `input_output` | no | (List) Input fields for inference and output (destination) fields for the inference results. This options is incompatible with the `target_field` and `field_map` options.
| `input_output` | no | - | (List) Input fields for inference and output (destination) fields for the inference results. This options is incompatible with the `target_field` and `field_map` options.
| `target_field` | no | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
| `field_map` | no | If defined the model's default field map | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
| `inference_config` | no | The default settings defined in the model | (Object) Contains the inference type and its options.
| `ignore_missing` | no | `false` | (Boolean) If `true` and any of the input fields defined in `input_ouput` are missing then those missing fields are quietly ignored, otherwise a missing field causes a failure. Only applies when using `input_output` configurations to explicitly list the input fields.
include::common-options.asciidoc[]
|======

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,13 @@ public void testToXContent() throws IOException {
}

@Override
void assertFieldValues(ClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) {
String path = resultsField + "." + createdInstance.getResultsField();
void assertFieldValues(
ClassificationInferenceResults createdInstance,
IngestDocument document,
String parentField,
String resultsField
) {
String path = parentField + resultsField;
switch (createdInstance.getPredictionFieldType()) {
case NUMBER -> assertThat(document.getFieldValue(path, Double.class), equalTo(createdInstance.predictedValue()));
case STRING -> assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ protected ErrorInferenceResults mutateInstance(ErrorInferenceResults instance) t
}

@Override
void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String resultsField) {
assertThat(document.getFieldValue(resultsField + ".error", String.class), equalTo(createdInstance.getException().getMessage()));
void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertThat(document.getFieldValue(parentField + "error", String.class), equalTo(createdInstance.getException().getMessage()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import java.io.IOException;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

abstract class InferenceResultsTestCase<T extends InferenceResults> extends AbstractWireSerializingTestCase<T> {

public void testWriteToIngestDoc() throws IOException {
Expand All @@ -34,11 +36,57 @@ public void testWriteToIngestDoc() throws IOException {
document.setFieldValue(parentField, Map.of());
}
InferenceResults.writeResult(inferenceResult, document, parentField, modelId);
assertFieldValues(inferenceResult, document, alreadyHasResult ? parentField + ".1" : parentField);

String expectedOutputPath = alreadyHasResult ? parentField + ".1." : parentField + ".";

assertThat(
document.getFieldValue(expectedOutputPath + InferenceResults.MODEL_ID_RESULTS_FIELD, String.class),
equalTo(modelId)
);
if (inferenceResult instanceof NlpInferenceResults nlpInferenceResults && nlpInferenceResults.isTruncated()) {
assertTrue(document.getFieldValue(expectedOutputPath + "is_truncated", Boolean.class));
}

assertFieldValues(inferenceResult, document, expectedOutputPath, inferenceResult.getResultsField());
}
}

private void testWriteToIngestDocField() throws IOException {
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
T inferenceResult = createTestInstance();
if (randomBoolean()) {
inferenceResult = copyInstance(inferenceResult, TransportVersion.current());
}
IngestDocument document = TestIngestDocument.emptyIngestDocument();
String outputField = randomAlphaOfLength(10);
String modelId = randomAlphaOfLength(10);
String parentField = randomBoolean() ? null : randomAlphaOfLength(10);
boolean writeModelId = randomBoolean();

boolean alreadyHasResult = randomBoolean();
if (alreadyHasResult && parentField != null) {
document.setFieldValue(parentField, Map.of());
}
InferenceResults.writeResultToField(inferenceResult, document, parentField, outputField, modelId, writeModelId);

String expectedOutputPath = parentField == null ? "" : parentField + ".";
if (alreadyHasResult && parentField != null) {
expectedOutputPath = expectedOutputPath + "1.";
}

if (writeModelId) {
String modelIdPath = expectedOutputPath + InferenceResults.MODEL_ID_RESULTS_FIELD;
assertThat(document.getFieldValue(modelIdPath, String.class), equalTo(modelId));
}
if (inferenceResult instanceof NlpInferenceResults nlpInferenceResults && nlpInferenceResults.isTruncated()) {
assertTrue(document.getFieldValue(expectedOutputPath + "is_truncated", Boolean.class));
}

assertFieldValues(inferenceResult, document, expectedOutputPath, outputField);
}
}

abstract void assertFieldValues(T createdInstance, IngestDocument document, String resultsField);
abstract void assertFieldValues(T createdInstance, IngestDocument document, String parentField, String resultsField);

public void testWriteToDocAndSerialize() throws IOException {
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,12 @@ public void testAsMap() {

@Override
@SuppressWarnings("unchecked")
void assertFieldValues(NerResults createdInstance, IngestDocument document, String resultsField) {
assertThat(
document.getFieldValue(resultsField + "." + createdInstance.getResultsField(), String.class),
equalTo(createdInstance.getAnnotatedResult())
);
void assertFieldValues(NerResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertThat(document.getFieldValue(parentField + resultsField, String.class), equalTo(createdInstance.getAnnotatedResult()));

if (createdInstance.getEntityGroups().size() > 0) {
List<Map<String, Object>> resultList = (List<Map<String, Object>>) document.getFieldValue(
resultsField + "." + ENTITY_FIELD,
parentField + ENTITY_FIELD,
List.class
);
assertThat(resultList.size(), equalTo(createdInstance.getEntityGroups().size()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ protected Writeable.Reader<NlpClassificationInferenceResults> instanceReader() {
}

@Override
void assertFieldValues(NlpClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) {
String path = resultsField + "." + createdInstance.getResultsField();
void assertFieldValues(
NlpClassificationInferenceResults createdInstance,
IngestDocument document,
String parentField,
String resultsField
) {
String path = parentField + resultsField;
assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ public void testAsMap() {
}

@Override
void assertFieldValues(PyTorchPassThroughResults createdInstance, IngestDocument document, String resultsField) {
assertArrayEquals(
createdInstance.getInference(),
document.getFieldValue(resultsField + "." + createdInstance.getResultsField(), double[][].class)
);
void assertFieldValues(PyTorchPassThroughResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertArrayEquals(createdInstance.getInference(), document.getFieldValue(parentField + resultsField, double[][].class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ protected Writeable.Reader<QuestionAnsweringInferenceResults> instanceReader() {
}

@Override
void assertFieldValues(QuestionAnsweringInferenceResults createdInstance, IngestDocument document, String resultsField) {
String path = resultsField + "." + createdInstance.getResultsField();
void assertFieldValues(
QuestionAnsweringInferenceResults createdInstance,
IngestDocument document,
String parentField,
String resultsField
) {
String path = parentField + resultsField;
assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ public void testToXContent() {
}

@Override
void assertFieldValues(RegressionInferenceResults createdInstance, IngestDocument document, String resultsField) {
assertThat(
document.getFieldValue(resultsField + "." + createdInstance.getResultsField(), Double.class),
closeTo(createdInstance.value(), 1e-10)
);
void assertFieldValues(RegressionInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertThat(document.getFieldValue(parentField + resultsField, Double.class), closeTo(createdInstance.value(), 1e-10));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ public void testAsMap() {
}

@Override
void assertFieldValues(TextEmbeddingResults createdInstance, IngestDocument document, String resultsField) {
assertArrayEquals(
document.getFieldValue(resultsField + "." + createdInstance.getResultsField(), double[].class),
createdInstance.getInference(),
1e-10
);
void assertFieldValues(TextEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertArrayEquals(document.getFieldValue(parentField + resultsField, double[].class), createdInstance.getInference(), 1e-10);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,8 @@ protected TextExpansionResults mutateInstance(TextExpansionResults instance) {

@Override
@SuppressWarnings("unchecked")
void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String resultsField) {
var ingestedTokens = (Map<String, Object>) document.getFieldValue(
resultsField + '.' + createdInstance.getResultsField(),
Map.class
);
void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String parentField, String resultsField) {
var ingestedTokens = (Map<String, Object>) document.getFieldValue(parentField + resultsField, Map.class);
var tokenMap = createdInstance.getWeightedTokens()
.stream()
.collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ protected Writeable.Reader<TextSimilarityInferenceResults> instanceReader() {
}

@Override
void assertFieldValues(TextSimilarityInferenceResults createdInstance, IngestDocument document, String resultsField) {
String path = resultsField + "." + createdInstance.getResultsField();
void assertFieldValues(
TextSimilarityInferenceResults createdInstance,
IngestDocument document,
String parentField,
String resultsField
) {
String path = parentField + resultsField;
assertThat(document.getFieldValue(path, Double.class), equalTo(createdInstance.predictedValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ protected Writeable.Reader<WarningInferenceResults> instanceReader() {
}

@Override
void assertFieldValues(WarningInferenceResults createdInstance, IngestDocument document, String resultsField) {
assertThat(document.getFieldValue(resultsField + ".warning", String.class), equalTo(createdInstance.getWarning()));
void assertFieldValues(WarningInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) {
assertThat(document.getFieldValue(parentField + "warning", String.class), equalTo(createdInstance.getWarning()));
}
}
Loading

0 comments on commit 46d9594

Please sign in to comment.