Commit
Fix issues in field inference when multiple fields use inference
carlosdelest committed Oct 27, 2023
1 parent 7a38395 commit 7e52a49
Showing 3 changed files with 65 additions and 32 deletions.
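In short: when several fields of the same document need inference, each inference response previously rebuilt the IngestDocument and rewrote the IndexRequest source on its own, so concurrent responses could overwrite one another. The change below collects all results in a ConcurrentHashMap-backed copy of the source and writes it back exactly once, when a RefCountingRunnable spanning the in-flight inference calls completes. A minimal standalone sketch of that completion pattern, assuming only the acquire()/close() contract of org.elasticsearch.action.support.RefCountingRunnable; the class name, field names, and fake results are invented:

import org.elasticsearch.action.support.RefCountingRunnable;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class InferenceFanOutSketch {
    public static void main(String[] args) {
        // Concurrent inference callbacks write their results into a shared concurrent map.
        Map<String, Object> sourceMap = new ConcurrentHashMap<>(Map.of("infer_a", "text a", "infer_b", "text b"));

        // The callback runs exactly once, after the last acquired ref is released.
        try (var refs = new RefCountingRunnable(() -> System.out.println("write source back once: " + sourceMap))) {
            for (String field : Map.copyOf(sourceMap).keySet()) {
                var ref = refs.acquire(); // one ref per in-flight inference call
                new Thread(() -> {
                    // Simulated inference response: replace the raw string with a result object.
                    sourceMap.put(field, Map.of("sparse_vector", Map.of("token", 1.0f)));
                    ref.close(); // releasing the last ref triggers the callback
                }).start();
            }
        } // the try-with-resources close releases the initial ref held since construction
    }
}

The try-with-resources block holds that implicit initial ref, so the completion callback cannot fire before the loop has acquired a ref for every field.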
AbstractBulkRequestPreprocessor.java

@@ -15,6 +15,7 @@
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.plugins.internal.DocumentParsingObserver;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.function.BiConsumer;
import java.util.function.IntConsumer;
@@ -97,4 +98,8 @@ protected IngestDocument newIngestDocument(final IndexRequest request) {
request.sourceAsMap(documentParsingObserverSupplier.get())
);
}

protected IngestDocument newIngestDocument(final IndexRequest request, Map<String, Object> sourceMap) {
return new IngestDocument(request.index(), request.id(), request.version(), request.routing(), request.versionType(), sourceMap);
}
}
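The new overload mirrors the existing one but takes a pre-built source map instead of re-parsing the request source, so callers can hand over a map they have already mutated. A hedged sketch of how a subclass might use it; the method, field name, and values are invented:

// Inside a hypothetical AbstractBulkRequestPreprocessor subclass:
void wrapEnrichedSource(IndexRequest request) {
    Map<String, Object> enriched = new HashMap<>(request.sourceAsMap());
    enriched.put("infer_title", Map.of("sparse_vector", Map.of("token", 0.42f))); // invented inference result
    IngestDocument doc = newIngestDocument(request, enriched); // same request metadata, new source
    // ... continue processing with `doc` ...
}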
FieldInferenceBulkRequestPreprocessor.java

@@ -15,13 +15,18 @@
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.mapper.SemanticTextFieldMapper;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.internal.DocumentParsingObserver;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.IntConsumer;
import java.util.function.Supplier;
@@ -30,11 +35,25 @@ public class FieldInferenceBulkRequestPreprocessor extends AbstractBulkRequestPreprocessor {

public static final String SEMANTIC_TEXT_ORIGIN = "semantic_text";

private final OriginSettingClient client;
private final IndicesService indicesService;

private final ClusterService clusterService;

public FieldInferenceBulkRequestPreprocessor(Supplier<DocumentParsingObserver> documentParsingObserver, Client client) {
private final OriginSettingClient client;
private final IndexNameExpressionResolver indexNameExpressionResolver;

public FieldInferenceBulkRequestPreprocessor(
Supplier<DocumentParsingObserver> documentParsingObserver,
ClusterService clusterService,
IndicesService indicesService,
IndexNameExpressionResolver indexNameExpressionResolver,
Client client
) {
super(documentParsingObserver);
this.indicesService = indicesService;
this.clusterService = clusterService;
this.client = new OriginSettingClient(client, SEMANTIC_TEXT_ORIGIN);
this.indexNameExpressionResolver = indexNameExpressionResolver;
}

protected void processIndexRequest(
@@ -46,11 +65,22 @@ protected void processIndexRequest(
) {
assert indexRequest.isFieldInferenceDone() == false;

String index = indexRequest.index();
Map<String, Object> sourceMap = indexRequest.sourceAsMap();
sourceMap.entrySet().stream().filter(entry -> fieldNeedsInference(index, entry.getKey(), entry.getValue())).forEach(entry -> {
runInferenceForField(indexRequest, entry.getKey(), refs, slot, onFailure);
});
refs.acquire();
// Inference responses can update the fields concurrently
final Map<String, Object> sourceMap = new ConcurrentHashMap<>(indexRequest.sourceAsMap());
try (var inferenceRefs = new RefCountingRunnable(() -> onInferenceComplete(refs, indexRequest, sourceMap))) {
sourceMap.entrySet()
.stream()
.filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
.forEach(entry -> {
runInferenceForField(indexRequest, entry.getKey(), inferenceRefs, slot, sourceMap, onFailure);
});
}
}

private void onInferenceComplete(RefCountingRunnable refs, IndexRequest indexRequest, Map<String, Object> sourceMap) {
updateIndexRequestSource(indexRequest, newIngestDocument(indexRequest, sourceMap));
refs.close();
}

@Override
@@ -59,7 +89,7 @@ public boolean needsProcessing(DocWriteRequest<?> docWriteRequest, IndexRequest
&& indexRequest.sourceAsMap()
.entrySet()
.stream()
.anyMatch(entry -> fieldNeedsInference(indexRequest.index(), entry.getKey(), entry.getValue()));
.anyMatch(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()));
}

@Override
@@ -72,32 +102,43 @@ public boolean shouldExecuteOnIngestNode() {
return false;
}

private boolean fieldNeedsInference(String index, String fieldName, Object fieldValue) {
// TODO actual mapping check here
return fieldName.startsWith("infer_")
// We want to perform inference when we haven't already calculated it
&& (fieldValue instanceof String);
private boolean fieldNeedsInference(IndexRequest indexRequest, String fieldName, Object fieldValue) {

if (fieldValue instanceof String == false) {
return false;
}

return getModelForField(indexRequest, fieldName) != null;
}

private String getModelForField(IndexRequest indexRequest, String fieldName) {
IndexService indexService = indicesService.indexService(
indexNameExpressionResolver.concreteSingleIndex(clusterService.state(), indexRequest)
);
return indexService.mapperService().mappingLookup().modelForField(fieldName);
}

private void runInferenceForField(
IndexRequest indexRequest,
String fieldName,
RefCountingRunnable refs,
int position,
final Map<String, Object> sourceAsMap,
BiConsumer<Integer, Exception> onFailure
) {
var ingestDocument = newIngestDocument(indexRequest);
if (ingestDocument.hasField(fieldName) == false) {
final String fieldValue = (String) sourceAsMap.get(fieldName);
if (fieldValue == null) {
return;
}

refs.acquire();
String modelForField = getModelForField(indexRequest, fieldName);
assert modelForField != null : "Field " + fieldName + " has no model associated in mappings";

// TODO Hardcoding model ID and task type
final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class);
// TODO Hardcoding task type, how to get that from model ID?
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
TaskType.SPARSE_EMBEDDING,
"my-elser-model",
modelForField,
fieldValue,
Map.of()
);
@@ -114,9 +155,7 @@ public void onResponse(InferenceAction.Response response) {
SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME,
response.getResult().asMap(fieldName).get(fieldName)
);
ingestDocument.setFieldValue(fieldName, newFieldValue);

updateIndexRequestSource(indexRequest, ingestDocument);
sourceAsMap.put(fieldName, newFieldValue);
}

@Override
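With this flow, each inference response mutates only the shared source map; the request source itself is rewritten once, in onInferenceComplete. A hedged sketch of the shape a field's value takes after inference, assuming "sparse_vector" is the value of SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME; the field name, tokens, and weights are invented:

// before: sourceAsMap.get("infer_field") is the raw String "the quick brown fox"
// after:  the field holds a nested map keyed by the sparse-vector subfield
sourceAsMap.put("infer_field", Map.of(
    "sparse_vector",                       // assumed subfield name
    Map.of("quick", 1.2f, "fox", 0.9f)     // invented token -> weight pairs
));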
11 changes: 0 additions & 11 deletions server/src/main/java/org/elasticsearch/ingest/IngestService.java
@@ -1038,17 +1038,6 @@ private static void updateIndexRequestMetadata(final IndexRequest request, final
}
}

/**
* Updates an index request based on the source of an ingest document, guarding against self-references if necessary.
*/
protected static void updateIndexRequestSource(final IndexRequest request, final IngestDocument document) {
boolean ensureNoSelfReferences = document.doNoSelfReferencesCheck();
// we already check for self references elsewhere (and clear the bit), so this should always be false,
// keeping the check and assert as a guard against extraordinarily surprising circumstances
assert ensureNoSelfReferences == false;
request.source(document.getSource(), request.getContentType(), ensureNoSelfReferences);
}

/**
* Grab the @timestamp and store it on the index request so that TSDB can use it without needing to parse
* the source for this document.
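The guard in the helper removed here exists because an ingest document's source can, in pathological cases, contain a reference to itself, and serializing such a map would never terminate. A minimal standalone illustration of that condition, in plain Java with invented names:

import java.util.HashMap;
import java.util.Map;

class SelfReferenceSketch {
    public static void main(String[] args) {
        Map<String, Object> source = new HashMap<>();
        source.put("self", source); // the map now contains itself
        // Serializing this cycle to JSON would recurse forever;
        // doNoSelfReferencesCheck() is what catches it first.
        System.out.println(source.containsValue(source)); // prints: true
    }
}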
