Skip to content

Commit

Permalink
Write results to output field
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Oct 3, 2023
1 parent fc99bf5 commit 13e2264
Show file tree
Hide file tree
Showing 17 changed files with 468 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xcontent.ToXContentFragment;

Expand All @@ -24,16 +25,58 @@ static void writeResult(InferenceResults results, IngestDocument ingestDocument,
Objects.requireNonNull(resultField, "resultField");
Map<String, Object> resultMap = results.asMap();
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
if (ingestDocument.hasField(resultField)) {
ingestDocument.appendFieldValue(resultField, resultMap);
setOrAppendValue(resultField, resultMap, ingestDocument);
}

static void writeResultToField(
InferenceResults results,
IngestDocument ingestDocument,
@Nullable String basePath,
String outputField,
String modelId,
boolean includeModelId
) {
Objects.requireNonNull(results, "results");
Objects.requireNonNull(ingestDocument, "ingestDocument");
Objects.requireNonNull(outputField, "outputField");
Map<String, Object> resultMap = results.asMap(outputField);
if (includeModelId) {
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
}
if (basePath == null) {
// insert the results into the root of the document
for (var entry : resultMap.entrySet()) {
setOrAppendValue(entry.getKey(), entry.getValue(), ingestDocument);
}
} else {
ingestDocument.setFieldValue(resultField, resultMap);
for (var entry : resultMap.entrySet()) {
setOrAppendValue(basePath + "." + entry.getKey(), entry.getValue(), ingestDocument);
}
}
}

private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) {
if (ingestDocument.hasField(path)) {
ingestDocument.appendFieldValue(path, value);
} else {
ingestDocument.setFieldValue(path, value);
}
}

String getResultsField();

/**
* Convert to a map
* @return Map representation of the InferenceResult
*/
Map<String, Object> asMap();

/**
* Convert to a map placing the inference result in {@code outputField}
* @param outputField Write the inference result to this field
* @return Map representation of the InferenceResult
*/
Map<String, Object> asMap(String outputField);

Object predictedValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,19 @@ public String getResultsField() {
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
addSupportingFieldsToMap(map);
return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
Map<String, Object> map = new LinkedHashMap<>();
map.put(outputField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
addSupportingFieldsToMap(map);
return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
if (topClasses.isEmpty() == false) {
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
Expand All @@ -235,7 +248,6 @@ public Map<String, Object> asMap() {
featureImportance.stream().map(ClassificationFeatureImportance::toMap).collect(Collectors.toList())
);
}
return map;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public Map<String, Object> asMap() {
return asMap;
}

@Override
public Map<String, Object> asMap(String outputField) {
// errors do not have a result
return asMap();
}

@Override
public String toString() {
return Strings.toString(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField + "_sequence", predictedSequence);
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField + "_sequence", predictedSequence);
return map;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,21 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)
public final Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
addMapFields(map);
addSupportingFieldsToMap(map);
return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
Map<String, Object> map = new LinkedHashMap<>();
addSupportingFieldsToMap(map);
return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
if (isTruncated) {
map.put("is_truncated", isTruncated);
}
return map;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, inference);
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, inference);
return map;
}

@Override
public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public String predictedValue() {
@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, answer);
addSupportingFieldsToMap(map);
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, answer);
addSupportingFieldsToMap(map);
return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
map.put(START_OFFSET.getPreferredName(), startOffset);
map.put(END_OFFSET.getPreferredName(), endOffset);
if (topClasses.isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public Map<String, Object> asMap() {
throw new UnsupportedOperationException("[raw] does not support map conversion");
}

@Override
public Map<String, Object> asMap(String outputField) {
throw new UnsupportedOperationException("[raw] does not support map conversion");
}

@Override
public Object predictedValue() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,23 @@ public String getResultsField() {
@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
addSupportingFieldsToMap(map);
map.put(resultsField, value());
return map;
}

@Override
public Map<String, Object> asMap(String outputField) {
Map<String, Object> map = new LinkedHashMap<>();
addSupportingFieldsToMap(map);
map.put(outputField, value());
return map;
}

private void addSupportingFieldsToMap(Map<String, Object> map) {
if (featureImportance.isEmpty() == false) {
map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
}
return map;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, inference);
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, inference);
return map;
}

@Override
public Object predictedValue() {
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,11 @@ void doWriteTo(StreamOutput out) throws IOException {
void addMapFields(Map<String, Object> map) {
map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
return map;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField, score);
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, score);
return map;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public Map<String, Object> asMap() {
return asMap;
}

@Override
public Map<String, Object> asMap(String outputField) {
// warnings do not have a result
return asMap();
}

@Override
public Object predictedValue() {
return null;
Expand Down
Loading

0 comments on commit 13e2264

Please sign in to comment.