Skip to content

Commit

Permalink
fix map type validation issue in processors (#687)
Browse files Browse the repository at this point in the history
* fix map type validation issue in processors

Signed-off-by: zane-neo <[email protected]>

* fix test failures on main branch

Signed-off-by: zane-neo <[email protected]>

* Fix potential NPE issue in chunking processor; add changee log

Signed-off-by: zane-neo <[email protected]>

* Fix failure tests

Signed-off-by: zane-neo <[email protected]>

* Address comments and add one more UT to cover uncovered line

Signed-off-by: zane-neo <[email protected]>

* Address comments

Signed-off-by: zane-neo <[email protected]>

* Add more UTs

Signed-off-by: zane-neo <[email protected]>

* fix failure ITs

Signed-off-by: zane-neo <[email protected]>

* Add public method with default depth parameter value

Signed-off-by: zane-neo <[email protected]>

* rebase latest code

Signed-off-by: zane-neo <[email protected]>

* address comments

Signed-off-by: zane-neo <[email protected]>

* address comment

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
(cherry picked from commit 54ac672)
  • Loading branch information
zane-neo authored and github-actions[bot] committed Jun 5, 2024
1 parent 1f03a20 commit 25f57ad
Show file tree
Hide file tree
Showing 18 changed files with 648 additions and 240 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Optimize max score calculation in the Query Phase of the Hybrid Search ([765](https://github.com/opensearch-project/neural-search/pull/765))
### Bug Fixes
- Total hit count fix in Hybrid Query ([756](https://github.com/opensearch-project/neural-search/pull/756))
- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
### Infrastructure
- Disable memory circuit breaker for integ tests ([#770](https://github.com/opensearch-project/neural-search/pull/770))
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextChunkingProcessor.TYPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -24,8 +23,9 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
Expand All @@ -35,6 +35,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
* The abstract class for text processing use cases. Users provide a field name map and a model id.
Expand All @@ -60,6 +61,7 @@ public abstract class InferenceProcessor extends AbstractProcessor {
protected final MLCommonsClientAccessor mlCommonsClientAccessor;

private final Environment environment;
private final ClusterService clusterService;

public InferenceProcessor(
String tag,
Expand All @@ -69,18 +71,19 @@ public InferenceProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);

this.listTypeNestedMapKey = listTypeNestedMapKey;
this.modelId = modelId;
this.fieldMap = fieldMap;
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
Expand Down Expand Up @@ -117,12 +120,12 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(ProcessMap);
Map<String, Object> processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
} else {
doExecute(ingestDocument, ProcessMap, inferenceList, handler);
doExecute(ingestDocument, processMap, inferenceList, handler);
}
} catch (Exception e) {
handler.accept(null, e);
Expand Down Expand Up @@ -225,7 +228,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
List<String> inferenceList = null;
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
Expand Down Expand Up @@ -273,7 +276,7 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
}

@VisibleForTesting
Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
Expand Down Expand Up @@ -331,54 +334,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(

private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (sourceValue != null) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1);
} else if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
indexName,
clusterService,
environment,
false
);
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
Expand All @@ -33,9 +34,10 @@ public SparseEncodingProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import org.opensearch.env.Environment;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.IndexSettings;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.processor.chunker.Chunker;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT;
Expand Down Expand Up @@ -164,7 +164,16 @@ private int getMaxTokenCount(final Map<String, Object> sourceAndMetadataMap) {
@Override
public IngestDocument execute(final IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
validateFieldsValue(sourceAndMetadataMap);
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
indexName,
clusterService,
environment,
true
);
// fixed token length algorithm needs runtime parameter max_token_count for tokenization
Map<String, Object> runtimeParameters = new HashMap<>();
int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap);
Expand All @@ -176,59 +185,6 @@ public IngestDocument execute(final IngestDocument ingestDocument) {
return ingestDocument;
}

private void validateFieldsValue(final Map<String, Object> sourceAndMetadataMap) {
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (Objects.nonNull(sourceValue)) {
String sourceKey = embeddingFieldsEntry.getKey();
if (sourceValue instanceof List || sourceValue instanceof Map) {
validateNestedTypeValue(sourceKey, sourceValue, 1);
} else if (!(sourceValue instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", sourceKey)
);
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", sourceKey)
);
} else if (sourceValue instanceof List) {
validateListTypeValue(sourceKey, sourceValue, maxDepth);
} else if (sourceValue instanceof Map) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1));
} else if (!(sourceValue instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", sourceKey)
);
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, maxDepth + 1);
} else if (value == null) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey)
);
} else if (!(value instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", sourceKey)
);
}
}
}

@SuppressWarnings("unchecked")
private int getChunkStringCountFromMap(Map<String, Object> sourceAndMetadataMap, final Map<String, Object> fieldMap) {
int chunkStringCount = 0;
Expand Down Expand Up @@ -334,7 +290,13 @@ private List<String> chunkLeafType(final Object value, final Map<String, Object>
// leaf type means null, String or List<String>
// the result should be an empty list when the input is null
List<String> result = new ArrayList<>();
if (value == null) {
return result;
}
if (value instanceof String) {
if (StringUtils.isBlank(String.valueOf(value))) {
return result;
}
result = chunkString(value.toString(), runTimeParameters);
} else if (isListOfString(value)) {
result = chunkList((List<String>) value, runTimeParameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
Expand All @@ -32,9 +33,10 @@ public TextEmbeddingProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
Expand Down
Loading

0 comments on commit 25f57ad

Please sign in to comment.