Skip to content

Commit

Permalink
Semantic text - Clear inference results on explicit nulls (elastic#11…
Browse files Browse the repository at this point in the history
…9463)

Fix a bug where setting a semantic_text source field explicitly to null in an update request to clear inference results did not actually clear the inference results for that field. This bug only affects the new _inference_fields format.
  • Loading branch information
Mikep86 committed Jan 3, 2025
1 parent 0ed520b commit 354d0a3
Show file tree
Hide file tree
Showing 9 changed files with 461 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
Expand Down Expand Up @@ -46,7 +47,8 @@ public Set<NodeFeature> getTestFeatures() {
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
SEMANTIC_TEXT_HIGHLIGHTER,
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
Expand All @@ -50,6 +51,7 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -67,6 +69,8 @@
*/
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
protected static final int DEFAULT_BATCH_SIZE = 512;
private static final Object EXPLICIT_NULL = new Object();
private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();

private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
Expand Down Expand Up @@ -393,11 +397,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
for (var entry : response.responses.entrySet()) {
var fieldName = entry.getKey();
var responses = entry.getValue();
var model = responses.get(0).model();
Model model = null;

InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
if (inferenceFieldMetadata == null) {
throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
}

// ensure that the order in the original field is consistent in case of multiple inputs
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
for (var resp : responses) {
// Get the first non-null model from the response list
if (model == null) {
model = resp.model;
}

var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
lst.addAll(
SemanticTextField.toSemanticTextFieldChunks(
Expand All @@ -409,21 +424,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
)
);
}

List<String> inputs = responses.stream()
.filter(r -> r.sourceField().equals(fieldName))
.map(r -> r.input)
.collect(Collectors.toList());

// The model can be null if we are only processing update requests that clear inference results. This is ok because we will
// merge in the field's existing model settings on the data node.
var result = new SemanticTextField(
useLegacyFormat,
fieldName,
useLegacyFormat ? inputs : null,
new SemanticTextField.InferenceResult(
model.getInferenceEntityId(),
new SemanticTextField.ModelSettings(model),
inferenceFieldMetadata.getInferenceId(),
model != null ? new SemanticTextField.ModelSettings(model) : null,
chunkMap
),
indexRequest.getContentType()
);

if (useLegacyFormat) {
SemanticTextUtils.insertValue(fieldName, newDocMap, result);
} else {
Expand Down Expand Up @@ -490,7 +510,8 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
} else {
var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
InferenceMetadataFieldsMapper.NAME + "." + field,
docMap
docMap,
EXPLICIT_NULL
);
if (inferenceMetadataFieldsValue != null) {
// Inference has already been computed
Expand All @@ -500,9 +521,22 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu

int order = 0;
for (var sourceField : entry.getSourceFields()) {
// TODO: Detect when the field is provided with an explicit null value
var valueObj = XContentMapValues.extractValue(sourceField, docMap);
if (valueObj == null) {
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
/**
* It's an update request, and the source field is explicitly set to null,
* so we need to propagate this information to the inference fields metadata
* to overwrite any inference previously computed on the field.
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/
var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
);
continue;
}
if (valueObj == null || valueObj == EXPLICIT_NULL) {
if (isUpdateRequest && useLegacyFormat) {
addInferenceResponseFailure(
item.id(),
Expand Down Expand Up @@ -552,4 +586,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
return null;
}
}

private static class EmptyChunkedInference implements ChunkedInference {
@Override
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
return Collections.emptyIterator();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.ContentPath;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
Expand All @@ -38,6 +39,8 @@
public class SemanticInferenceMetadataFieldsMapper extends InferenceMetadataFieldsMapper {
private static final SemanticInferenceMetadataFieldsMapper INSTANCE = new SemanticInferenceMetadataFieldsMapper();

public static final NodeFeature EXPLICIT_NULL_FIXES = new NodeFeature("semantic_text.inference_metadata_fields.explicit_null_fixes");

public static final TypeParser PARSER = new FixedTypeParser(
c -> InferenceMetadataFieldsMapper.isEnabled(c.getSettings()) ? INSTANCE : null
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

static {
SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD));
SEMANTIC_TEXT_FIELD_PARSER.declareObject(
constructorArg(),
(p, c) -> INFERENCE_RESULT_PARSER.parse(p, c),
new ParseField(INFERENCE_FIELD)
);
SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD));

INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
INFERENCE_RESULT_PARSER.declareObject(
INFERENCE_RESULT_PARSER.declareObjectOrNull(
constructorArg(),
(p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
null,
new ParseField(MODEL_SETTINGS_FIELD)
);
INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,17 @@ void parseCreateFieldFromContext(DocumentParserContext context, SemanticTextFiel
mapper = this;
}

if (mapper.fieldType().getModelSettings() == null) {
for (var chunkList : field.inference().chunks().values()) {
if (chunkList.isEmpty() == false) {
throw new DocumentParsingException(
xContentLocation,
"[" + MODEL_SETTINGS_FIELD + "] must be set for field [" + fullFieldName + "] when chunks are provided"
);
}
}
}

var chunksField = mapper.fieldType().getChunksField();
var embeddingsField = mapper.fieldType().getEmbeddingsField();
var offsetsField = mapper.fieldType().getOffsetsField();
Expand Down Expand Up @@ -895,7 +906,7 @@ private static boolean canMergeModelSettings(
if (Objects.equals(previous, current)) {
return true;
}
if (previous == null) {
if (previous == null || current == null) {
return true;
}
conflicts.addConflict("model_settings", "");
Expand Down
Loading

0 comments on commit 354d0a3

Please sign in to comment.