Skip to content

Commit

Permalink
Change inference results structure
Browse files: browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Feb 14, 2024
1 parent bd4e19e commit 896ec49
Showing 1 changed file with 26 additions and 31 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,11 +24,13 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.ModelSettings;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -46,10 +48,10 @@ public class BulkShardRequestInferenceProvider {
public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference";

// Contains the original text for the field
public static final String TEXT_SUBFIELD_NAME = "text";

// Contains the inference result when it's a sparse vector
public static final String SPARSE_VECTOR_SUBFIELD_NAME = "sparse_embedding";
public static final String INFERENCE_RESULTS = "inference_results";
public static final String INFERENCE_CHUNKS_RESULTS = "inference";
public static final String INFERENCE_CHUNKS_TEXT = "text";

private final ClusterState clusterState;
private final Map<String, InferenceProvider> inferenceProvidersMap;
Expand Down Expand Up @@ -90,7 +92,13 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
var service = inferenceServiceRegistry.getService(unparsedModel.service());
if (service.isEmpty() == false) {
InferenceProvider inferenceProvider = new InferenceProvider(
service.get().parsePersistedConfig(inferenceId, unparsedModel.taskType(), unparsedModel.settings()),
service.get()
.parsePersistedConfigWithSecrets(
inferenceId,
unparsedModel.taskType(),
unparsedModel.settings(),
unparsedModel.secrets()
),
service.get()
);
inferenceProviderMap.put(inferenceId, inferenceProvider);
Expand All @@ -105,7 +113,7 @@ public void onFailure(Exception e) {
}
};

modelRegistry.getModel(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire()));
modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire()));
}
}
}
Expand Down Expand Up @@ -260,34 +268,21 @@ public void onResponse(InferenceServiceResults results) {

int i = 0;
for (InferenceResults inferenceResults : results.transformToLegacyFormat()) {
String fieldName = inferenceFieldNames.get(i++);
List<Map<String, Object>> inferenceFieldResultList;
try {
inferenceFieldResultList = (List<Map<String, Object>>) rootInferenceFieldMap.computeIfAbsent(
fieldName,
k -> new ArrayList<>()
);
} catch (ClassCastException e) {
onBulkItemFailure.apply(
bulkItemRequest,
itemIndex,
new IllegalArgumentException(
"Inference result field [" + ROOT_INFERENCE_FIELD + "." + fieldName + "] is not an object"
String inferenceFieldName = inferenceFieldNames.get(i++);
Map<String, Object> inferenceFieldResult = new LinkedHashMap<>();
inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap());
inferenceFieldResult.put(
INFERENCE_RESULTS,
List.of(
Map.of(
INFERENCE_CHUNKS_RESULTS,
inferenceResults.asMap("output").get("output"),
INFERENCE_CHUNKS_TEXT,
docMap.get(inferenceFieldName)
)
);
return;
}
// Remove previous inference results if any
inferenceFieldResultList.clear();

// TODO Check inference result type to change subfield name
var inferenceFieldMap = Map.of(
SPARSE_VECTOR_SUBFIELD_NAME,
inferenceResults.asMap("output").get("output"),
TEXT_SUBFIELD_NAME,
docMap.get(fieldName)
)
);
inferenceFieldResultList.add(inferenceFieldMap);
rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult);
}
}

Expand Down

0 comments on commit 896ec49

Please sign in to comment.