fix map type validation issue in processors #687

Merged: 12 commits, Jun 5, 2024
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
### Bug Fixes
- Total hit count fix in Hybrid Query ([#756](https://github.com/opensearch-project/neural-search/pull/756))
- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
### Infrastructure
### Documentation
### Maintenance
@@ -112,9 +112,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextImageEmbeddingProcessor.TYPE,
new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextChunkingProcessor.TYPE,
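Both embedding factories now receive a ClusterService obtained from parameters.ingestService.getClusterService(), matching what TextImageEmbeddingProcessorFactory already did. The factory classes themselves are not part of this excerpt; the sketch below shows the implied change, assuming the factory simply stores and forwards the new dependency (the package paths, config keys, and create signature are assumptions based on standard OpenSearch ingest plumbing, not the PR's actual source):

```java
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;       // assumed package path
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; // assumed package path

import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;

// Hypothetical sketch, not the PR's actual factory source.
public class TextEmbeddingProcessorFactory implements Processor.Factory {

    private final MLCommonsClientAccessor clientAccessor;
    private final Environment environment;
    private final ClusterService clusterService; // newly threaded through

    public TextEmbeddingProcessorFactory(
        final MLCommonsClientAccessor clientAccessor,
        final Environment environment,
        final ClusterService clusterService
    ) {
        this.clientAccessor = clientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
    }

    @Override
    public TextEmbeddingProcessor create(
        final Map<String, Processor.Factory> registry,
        final String tag,
        final String description,
        final Map<String, Object> config
    ) throws Exception {
        final String modelId = readStringProperty(TextEmbeddingProcessor.TYPE, tag, config, "model_id");
        final Map<String, Object> fieldMap = readMap(TextEmbeddingProcessor.TYPE, tag, config, "field_map");
        // Forward ClusterService so the processor can resolve index-scoped
        // settings (e.g. mapping depth limits) when validating documents.
        return new TextEmbeddingProcessor(tag, description, modelId, fieldMap, clientAccessor, environment, clusterService);
    }
}
```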
InferenceProcessor.java
@@ -15,7 +15,6 @@
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -24,8 +23,9 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
@@ -35,6 +35,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
* The abstract class for text processing use cases. Users provide a field name map and a model id.
@@ -60,6 +61,7 @@ public abstract class InferenceProcessor extends AbstractProcessor {
protected final MLCommonsClientAccessor mlCommonsClientAccessor;

private final Environment environment;
private final ClusterService clusterService;

public InferenceProcessor(
String tag,
@@ -69,18 +71,19 @@ public InferenceProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description);
this.type = type;
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
validateEmbeddingConfiguration(fieldMap);

this.listTypeNestedMapKey = listTypeNestedMapKey;
this.modelId = modelId;
this.fieldMap = fieldMap;
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
}

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
@@ -117,12 +120,12 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
List<String> inferenceList = createInferenceList(ProcessMap);
Map<String, Object> processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument);
Contributor: 👍, I also wanted to change it

List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
} else {
doExecute(ingestDocument, ProcessMap, inferenceList, handler);
doExecute(ingestDocument, processMap, inferenceList, handler);
}
} catch (Exception e) {
handler.accept(null, e);
@@ -225,7 +228,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
List<String> inferenceList = null;
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
@@ -273,7 +276,7 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
}

@VisibleForTesting
Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
@@ -331,54 +334,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(

private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (sourceValue != null) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1);
} else if (value == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(value instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (StringUtils.isBlank(value.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
indexName,
clusterService,
environment,
false
);
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
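The recursive checks removed above (validateNestedTypeValue and validateListTypeValue) now live in a shared ProcessorDocumentUtils.validateMapTypeValue, whose source is not included in this diff. Below is a hedged reconstruction built from the removed logic and the parameters visible at the call site, assuming the last flag toggles whether blank strings are allowed; the real utility may differ in structure and messages:

```java
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

// Sketch only: reconstructed behavior, not the PR's actual implementation.
public final class ProcessorDocumentUtils {

    private ProcessorDocumentUtils() {}

    public static void validateMapTypeValue(
        final String fieldName, // config key ("field_map"); assumed to serve only as error context
        final Map<String, Object> sourceAndMetadataMap,
        final Map<String, Object> fieldMap,
        final String indexName,
        final ClusterService clusterService,
        final Environment environment,
        final boolean allowEmpty
    ) {
        // Depth limit defaults to node settings; when the target index already
        // exists, its own setting wins -- presumably why ClusterService and the
        // index name are now threaded into every processor.
        long maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings());
        final IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
        if (indexMetadata != null) {
            maxDepth = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexMetadata.getSettings());
        }
        for (final Map.Entry<String, Object> entry : fieldMap.entrySet()) {
            final Object sourceValue = sourceAndMetadataMap.get(entry.getKey());
            if (sourceValue != null) {
                validateValue(entry.getKey(), sourceValue, 1L, maxDepth, allowEmpty);
            }
        }
    }

    private static void validateValue(final String key, final Object value, final long depth, final long maxDepth, final boolean allowEmpty) {
        if (depth > maxDepth) {
            throw new IllegalArgumentException("map type field [" + key + "] reached max depth limit, cannot process it");
        }
        if (value instanceof Map) {
            ((Map<?, ?>) value).values().stream().filter(Objects::nonNull).forEach(v -> validateValue(key, v, depth + 1, maxDepth, allowEmpty));
        } else if (value instanceof List) {
            for (final Object item : (List<?>) value) {
                if (item == null) {
                    throw new IllegalArgumentException("list type field [" + key + "] has null, cannot process it");
                }
                validateValue(key, item, depth + 1, maxDepth, allowEmpty);
            }
        } else if (!(value instanceof String)) {
            throw new IllegalArgumentException("field [" + key + "] is neither string nor nested type, cannot process it");
        } else if (!allowEmpty && StringUtils.isBlank((String) value)) {
            throw new IllegalArgumentException("field [" + key + "] has empty string value, cannot process it");
        }
    }
}
```

Relative to the removed per-processor copies, a shared helper like this can honor the target index's own index.mapping.depth.limit rather than only the node default, which appears to be the point of passing the index name and ClusterService at the call site.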
SparseEncodingProcessor.java
@@ -9,6 +9,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
@@ -33,9 +34,10 @@ public SparseEncodingProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override
TextChunkingProcessor.java
@@ -17,14 +17,14 @@
import org.opensearch.env.Environment;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.IndexSettings;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.processor.chunker.Chunker;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT;
@@ -164,7 +164,16 @@ private int getMaxTokenCount(final Map<String, Object> sourceAndMetadataMap) {
@Override
public IngestDocument execute(final IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
validateFieldsValue(sourceAndMetadataMap);
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
indexName,
clusterService,
environment,
true
);
// fixed token length algorithm needs runtime parameter max_token_count for tokenization
Map<String, Object> runtimeParameters = new HashMap<>();
int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap);
@@ -176,59 +185,6 @@ public IngestDocument execute(final IngestDocument ingestDocument) {
return ingestDocument;
}
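Note the final boolean argument: this processor passes true where InferenceProcessor passes false. Presumably the flag controls whether blank leaf values are tolerated, which lines up with the chunkLeafType change further below that maps null and blank strings to an empty result instead of throwing.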

private void validateFieldsValue(final Map<String, Object> sourceAndMetadataMap) {
for (Map.Entry<String, Object> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
if (Objects.nonNull(sourceValue)) {
String sourceKey = embeddingFieldsEntry.getKey();
if (sourceValue instanceof List || sourceValue instanceof Map) {
validateNestedTypeValue(sourceKey, sourceValue, 1);
} else if (!(sourceValue instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", sourceKey)
);
}
}
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", sourceKey)
);
} else if (sourceValue instanceof List) {
validateListTypeValue(sourceKey, sourceValue, maxDepth);
} else if (sourceValue instanceof Map) {
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1));
} else if (!(sourceValue instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", sourceKey)
);
}
}

@SuppressWarnings({ "rawtypes" })
private void validateListTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) {
for (Object value : (List) sourceValue) {
if (value instanceof Map) {
validateNestedTypeValue(sourceKey, value, maxDepth + 1);
} else if (value == null) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey)
);
} else if (!(value instanceof String)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", sourceKey)
);
}
}
}

@SuppressWarnings("unchecked")
private int getChunkStringCountFromMap(Map<String, Object> sourceAndMetadataMap, final Map<String, Object> fieldMap) {
int chunkStringCount = 0;
@@ -334,7 +290,13 @@ private List<String> chunkLeafType(final Object value, final Map<String, Object> runTimeParameters) {
// leaf type means null, String or List<String>
// the result should be an empty list when the input is null
List<String> result = new ArrayList<>();
if (value == null) {
return result;
}
if (value instanceof String) {
if (StringUtils.isBlank(String.valueOf(value))) {
return result;
}
result = chunkString(value.toString(), runTimeParameters);
} else if (isListOfString(value)) {
result = chunkList((List<String>) value, runTimeParameters);
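With the guards added above, chunkLeafType turns null and blank leaf values into an empty chunk list instead of letting them reach chunkString. A hypothetical test fragment illustrating the intended behavior (not from the PR; it assumes test-level access to the private method, an already-built processor instance, and the max_token_count runtime parameter named in the diff's comment):

```java
// Hypothetical fragment: `processor` is an assumed TextChunkingProcessor
// instance, and chunkLeafType is assumed accessible for illustration.
Map<String, Object> runtimeParameters = new HashMap<>();
runtimeParameters.put("max_token_count", 10000);

List<String> fromNull = processor.chunkLeafType(null, runtimeParameters);
List<String> fromBlank = processor.chunkLeafType("   ", runtimeParameters);
assert fromNull.isEmpty();   // null now short-circuits to an empty list
assert fromBlank.isEmpty();  // blank strings now short-circuit instead of being chunked
```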
TextEmbeddingProcessor.java
@@ -9,6 +9,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
@@ -32,9 +33,10 @@ public TextEmbeddingProcessor(
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor,
Environment environment
Environment environment,
ClusterService clusterService
) {
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment);
super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
}

@Override