remove processor validator
Signed-off-by: yuye-aws <[email protected]>
yuye-aws committed Mar 6, 2024
1 parent 4d69364 commit e3b3ff4
Showing 10 changed files with 205 additions and 141 deletions.
src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -29,7 +29,6 @@
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
-import org.opensearch.neuralsearch.processor.ProcessorInputValidator;
 import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.processor.DocumentChunkingProcessor;
@@ -110,22 +109,19 @@ public List<QuerySpec<?>> getQueries() {
     @Override
     public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
         clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
-        ProcessorInputValidator processorInputValidator = new ProcessorInputValidator();
         return Map.of(
             TextEmbeddingProcessor.TYPE,
-            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, processorInputValidator),
+            new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
             SparseEncodingProcessor.TYPE,
-            new SparseEncodingProcessorFactory(clientAccessor, parameters.env, processorInputValidator),
+            new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
             TextImageEmbeddingProcessor.TYPE,
             new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
             DocumentChunkingProcessor.TYPE,
             new DocumentChunkingProcessor.Factory(
                 parameters.env.settings(),
                 parameters.ingestService.getClusterService(),
                 parameters.indicesService,
-                parameters.analysisRegistry,
-                parameters.env,
-                processorInputValidator
+                parameters.analysisRegistry
             )
         );
     }
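For orientation, here is a minimal sketch (not part of this commit) of how the slimmed-down four-argument Factory can be wired up in a test. The Mockito mocks stand in for real node services, and this sketch assumes a mock maker that can handle final classes (Mockito 5's default inline maker); everything outside the diff is an assumption.

import static org.mockito.Mockito.mock;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.indices.IndicesService;
import org.opensearch.neuralsearch.processor.DocumentChunkingProcessor;

public class FactoryWiringSketch {
    public static void main(String[] args) {
        Settings settings = Settings.EMPTY;
        ClusterService clusterService = mock(ClusterService.class);
        IndicesService indicesService = mock(IndicesService.class);
        AnalysisRegistry analysisRegistry = mock(AnalysisRegistry.class);

        // After this commit the factory no longer takes an Environment
        // or a ProcessorInputValidator:
        DocumentChunkingProcessor.Factory factory = new DocumentChunkingProcessor.Factory(
            settings,
            clusterService,
            indicesService,
            analysisRegistry
        );
        System.out.println(factory.getClass().getSimpleName());
    }
}

Dropping Environment and ProcessorInputValidator from the signature leaves the factory depending only on what the chunking processor actually uses at create time.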
src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java
@@ -9,17 +9,20 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
+import java.util.LinkedHashMap;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import java.util.function.BiConsumer;
 
+import com.google.common.annotations.VisibleForTesting;
+import lombok.extern.log4j.Log4j2;
 import org.opensearch.cluster.metadata.IndexMetadata;
 import org.opensearch.index.IndexService;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.env.Environment;
 import org.opensearch.index.analysis.AnalysisRegistry;
 import org.opensearch.indices.IndicesService;
 import org.opensearch.index.IndexSettings;
+import org.opensearch.ingest.AbstractProcessor;
 import org.opensearch.ingest.IngestDocument;
 import org.opensearch.ingest.Processor;
 import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
@@ -31,7 +34,8 @@
 import static org.opensearch.neuralsearch.processor.chunker.ChunkerFactory.DELIMITER_ALGORITHM;
 import static org.opensearch.neuralsearch.processor.chunker.ChunkerFactory.FIXED_LENGTH_ALGORITHM;
 
-public final class DocumentChunkingProcessor extends InferenceProcessor {
+@Log4j2
+public final class DocumentChunkingProcessor extends AbstractProcessor {
 
     public static final String TYPE = "chunking";
 
@@ -47,6 +51,8 @@ public final class DocumentChunkingProcessor extends InferenceProcessor {

     private Map<String, Object> chunkerParameters;
 
+    private final Map<String, Object> fieldMap;
+
     private final ClusterService clusterService;
 
     private final IndicesService indicesService;
@@ -61,12 +67,11 @@ public DocumentChunkingProcessor(
         Settings settings,
         ClusterService clusterService,
         IndicesService indicesService,
-        AnalysisRegistry analysisRegistry,
-        Environment environment,
-        ProcessorInputValidator processorInputValidator
+        AnalysisRegistry analysisRegistry
     ) {
-        super(tag, description, TYPE, null, null, fieldMap, null, environment, processorInputValidator);
+        super(tag, description);
         validateAndParseAlgorithmMap(algorithmMap);
+        this.fieldMap = fieldMap;
         this.settings = settings;
         this.clusterService = clusterService;
         this.indicesService = indicesService;
@@ -116,7 +121,40 @@ private void validateAndParseAlgorithmMap(Map<String, Object> algorithmMap) {
     }
 
     @Override
-    protected List<?> buildResultForListType(List<Object> sourceValue, List<?> results, IndexWrapper indexWrapper) {
+    public IngestDocument execute(IngestDocument ingestDocument) {
+        Map<String, Object> processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
+        List<String> inferenceList = createInferenceList(processMap);
+        if (inferenceList.isEmpty()) {
+            return ingestDocument;
+        } else {
+            return doExecute(ingestDocument, processMap, inferenceList);
+        }
+    }
+
+    public IngestDocument doExecute(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList) {
+        if (Objects.equals(chunkerType, FIXED_LENGTH_ALGORITHM)) {
+            // add maxTokenCount setting from index metadata to chunker parameters
+            Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
+            String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
+            int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings);
+            IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
+            if (indexMetadata != null) {
+                // if the index exists, read maxTokenCount from the index setting
+                IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex());
+                maxTokenCount = indexService.getIndexSettings().getMaxTokenCount();
+            }
+            chunkerParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount);
+        }
+
+        List<List<String>> chunkedResults = new ArrayList<>();
+        for (String inferenceString : inferenceList) {
+            chunkedResults.add(chunk(inferenceString));
+        }
+        setTargetFieldsToDocument(ingestDocument, processMap, chunkedResults);
+        return ingestDocument;
+    }
+
+    private List<?> buildResultForListType(List<Object> sourceValue, List<?> results, InferenceProcessor.IndexWrapper indexWrapper) {
         Object peek = sourceValue.get(0);
         if (peek instanceof String) {
             List<Object> keyToResult = new ArrayList<>();
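The new synchronous execute() path above replaces the async doExecute(..., handler) flow: it builds a processor-key map from field_map, flattens it into a list of strings, chunks each string, and writes the chunked lists back to the target fields. A self-contained sketch of that data flow, using a naive whitespace chunker and invented field names rather than the plugin's FixedTokenLengthChunker:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class ChunkingFlowSketch {

    // stand-in for FixedTokenLengthChunker: split on whitespace, group maxTokens per chunk
    static List<String> chunk(String text, int maxTokens) {
        String[] tokens = text.split("\\s+");
        List<String> chunks = new ArrayList<>();
        for (int i = 0; i < tokens.length; i += maxTokens) {
            chunks.add(String.join(" ", Arrays.copyOfRange(tokens, i, Math.min(i + maxTokens, tokens.length))));
        }
        return chunks;
    }

    public static void main(String[] args) {
        // field_map analogue: source field -> target field (names invented)
        Map<String, String> fieldMap = Map.of("body", "body_chunk");

        Map<String, Object> document = new LinkedHashMap<>();
        document.put("body", "one two three four five six seven");

        // phase 1: gather the strings to chunk; phase 2: chunk and write back
        for (Map.Entry<String, String> entry : fieldMap.entrySet()) {
            Object source = document.get(entry.getKey());
            if (source instanceof String) {
                document.put(entry.getValue(), chunk((String) source, 3));
            }
        }
        System.out.println(document);
        // {body=one two three four five six seven, body_chunk=[one two three, four five six, seven]}
    }
}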
@@ -133,37 +171,151 @@ protected List<?> buildResultForListType(List<Object> sourceValue, List<?> results) {
         }
     }
 
-    @Override
-    public void doExecute(
-        IngestDocument ingestDocument,
-        Map<String, Object> ProcessMap,
-        List<String> inferenceList,
-        BiConsumer<IngestDocument, Exception> handler
-    ) {
-        try {
-            processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, true);
-            if (Objects.equals(chunkerType, FIXED_LENGTH_ALGORITHM)) {
-                // add maxTokenCount setting from index metadata to chunker parameters
-                Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
-                String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
-                int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings);
-                IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
-                if (indexMetadata != null) {
-                    // if the index exists, read maxTokenCount from the index setting
-                    IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex());
-                    maxTokenCount = indexService.getIndexSettings().getMaxTokenCount();
-                }
-                chunkerParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount);
-            }
-
-            List<List<String>> chunkedResults = new ArrayList<>();
-            for (String inferenceString : inferenceList) {
-                chunkedResults.add(chunk(inferenceString));
-            }
-            setTargetFieldsToDocument(ingestDocument, ProcessMap, chunkedResults);
-            handler.accept(ingestDocument, null);
-        } catch (Exception e) {
-            handler.accept(null, e);
-        }
-    }
+    private Map<String, Object> buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) {
+        Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
+        Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
+        for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
+            String originalKey = fieldMapEntry.getKey();
+            Object targetKey = fieldMapEntry.getValue();
+            if (targetKey instanceof Map) {
+                Map<String, Object> treeRes = new LinkedHashMap<>();
+                buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
+                mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
+            } else {
+                mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
+            }
+        }
+        return mapWithProcessorKeys;
+    }
+
+    private void buildMapWithProcessorKeyAndOriginalValueForMapType(
+        String parentKey,
+        Object processorKey,
+        Map<String, Object> sourceAndMetadataMap,
+        Map<String, Object> treeRes
+    ) {
+        if (processorKey == null || sourceAndMetadataMap == null) return;
+        if (processorKey instanceof Map) {
+            Map<String, Object> next = new LinkedHashMap<>();
+            if (sourceAndMetadataMap.get(parentKey) instanceof Map) {
+                for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
+                    buildMapWithProcessorKeyAndOriginalValueForMapType(
+                        nestedFieldMapEntry.getKey(),
+                        nestedFieldMapEntry.getValue(),
+                        (Map<String, Object>) sourceAndMetadataMap.get(parentKey),
+                        next
+                    );
+                }
+            } else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
+                for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
+                    List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
+                    List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
+                    Map<String, Object> map = new LinkedHashMap<>();
+                    map.put(nestedFieldMapEntry.getKey(), listOfStrings);
+                    buildMapWithProcessorKeyAndOriginalValueForMapType(
+                        nestedFieldMapEntry.getKey(),
+                        nestedFieldMapEntry.getValue(),
+                        map,
+                        next
+                    );
+                }
+            }
+            treeRes.put(parentKey, next);
+        } else {
+            String key = String.valueOf(processorKey);
+            treeRes.put(key, sourceAndMetadataMap.get(parentKey));
+        }
+    }
+
+    @SuppressWarnings({ "unchecked" })
+    private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
+        List<String> texts = new ArrayList<>();
+        knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
+            Object sourceValue = knnMapEntry.getValue();
+            if (sourceValue instanceof List) {
+                for (Object nestedValue : (List<Object>) sourceValue) {
+                    if (nestedValue instanceof String) {
+                        texts.add((String) nestedValue);
+                    } else {
+                        texts.addAll((List<String>) nestedValue);
+                    }
+                }
+            } else if (sourceValue instanceof Map) {
+                createInferenceListForMapTypeInput(sourceValue, texts);
+            } else {
+                texts.add(sourceValue.toString());
+            }
+        });
+        return texts;
+    }
+
+    @SuppressWarnings("unchecked")
+    private void createInferenceListForMapTypeInput(Object sourceValue, List<String> texts) {
+        if (sourceValue instanceof Map) {
+            ((Map<String, Object>) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts));
+        } else if (sourceValue instanceof List) {
+            texts.addAll(((List<String>) sourceValue));
+        } else {
+            if (sourceValue == null) return;
+            texts.add(sourceValue.toString());
+        }
+    }
+
+    private void setTargetFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
+        Objects.requireNonNull(results, "embedding failed, inference returns null result!");
+        log.debug("Model inference result fetched, starting build vector output!");
+        Map<String, Object> result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata());
+        result.forEach(ingestDocument::setFieldValue);
+    }
+
+    @VisibleForTesting
+    Map<String, Object> buildResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
+        InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0);
+        Map<String, Object> result = new LinkedHashMap<>();
+        for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
+            String knnKey = knnMapEntry.getKey();
+            Object sourceValue = knnMapEntry.getValue();
+            if (sourceValue instanceof String) {
+                result.put(knnKey, results.get(indexWrapper.index++));
+            } else if (sourceValue instanceof List) {
+                result.put(knnKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
+            } else if (sourceValue instanceof Map) {
+                putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
+            }
+        }
+        return result;
+    }
+
+    @SuppressWarnings({ "unchecked" })
+    private void putResultToSourceMapForMapType(
+        String processorKey,
+        Object sourceValue,
+        List<?> results,
+        InferenceProcessor.IndexWrapper indexWrapper,
+        Map<String, Object> sourceAndMetadataMap
+    ) {
+        if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
+        if (sourceValue instanceof Map) {
+            for (Map.Entry<String, Object> inputNestedMapEntry : ((Map<String, Object>) sourceValue).entrySet()) {
+                if (sourceAndMetadataMap.get(processorKey) instanceof List) {
+                    // build output for list of nested objects
+                    for (Map<String, Object> nestedElement : (List<Map<String, Object>>) sourceAndMetadataMap.get(processorKey)) {
+                        nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
+                    }
+                } else {
+                    putResultToSourceMapForMapType(
+                        inputNestedMapEntry.getKey(),
+                        inputNestedMapEntry.getValue(),
+                        results,
+                        indexWrapper,
+                        (Map<String, Object>) sourceAndMetadataMap.get(processorKey)
+                    );
+                }
+            }
+        } else if (sourceValue instanceof String) {
+            sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
+        } else if (sourceValue instanceof List) {
+            sourceAndMetadataMap.put(processorKey, buildResultForListType((List<Object>) sourceValue, results, indexWrapper));
+        }
+    }
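The helpers above pair flat inference results back with (possibly nested) source fields by walking the processor-key map in the same order the inference list was built, advancing a shared cursor (InferenceProcessor.IndexWrapper). A small illustration of that shared-cursor idea; the example data is invented:

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class IndexWrapperSketch {

    // analogous to InferenceProcessor.IndexWrapper: a boxed, mutable int
    static class IndexWrapper {
        int index;
        IndexWrapper(int index) { this.index = index; }
    }

    public static void main(String[] args) {
        // inputs were flattened in this order...
        Map<String, Object> processMap = new LinkedHashMap<>();
        processMap.put("title_chunk", "a title");
        processMap.put("body_chunk", "a body");

        // ...so results arrive in the same order
        List<List<String>> results = List.of(List.of("a", "title"), List.of("a", "body"));

        IndexWrapper cursor = new IndexWrapper(0);
        Map<String, Object> output = new LinkedHashMap<>();
        for (String key : processMap.keySet()) {
            output.put(key, results.get(cursor.index++));
        }
        System.out.println(output); // {title_chunk=[a, title], body_chunk=[a, body]}
    }
}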

@@ -175,26 +327,13 @@ public static class Factory implements Processor.Factory {

         private final IndicesService indicesService;
 
-        private final Environment environment;
-
         private final AnalysisRegistry analysisRegistry;
 
-        private final ProcessorInputValidator processorInputValidator;
-
-        public Factory(
-            Settings settings,
-            ClusterService clusterService,
-            IndicesService indicesService,
-            AnalysisRegistry analysisRegistry,
-            Environment environment,
-            ProcessorInputValidator processorInputValidator
-        ) {
+        public Factory(Settings settings, ClusterService clusterService, IndicesService indicesService, AnalysisRegistry analysisRegistry) {
             this.settings = settings;
             this.clusterService = clusterService;
             this.indicesService = indicesService;
             this.analysisRegistry = analysisRegistry;
-            this.environment = environment;
-            this.processorInputValidator = processorInputValidator;
         }
 
         @Override
@@ -214,9 +353,7 @@ public DocumentChunkingProcessor create(
                 settings,
                 clusterService,
                 indicesService,
-                analysisRegistry,
-                environment,
-                processorInputValidator
+                analysisRegistry
             );
         }
 
src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
@@ -49,8 +49,6 @@ public abstract class InferenceProcessor extends AbstractProcessor {

     protected final Environment environment;
 
-    protected final ProcessorInputValidator processorInputValidator;
-
     public InferenceProcessor(
         String tag,
         String description,
@@ -59,19 +57,17 @@ public InferenceProcessor(
         String modelId,
         Map<String, Object> fieldMap,
         MLCommonsClientAccessor clientAccessor,
-        Environment environment,
-        ProcessorInputValidator processorInputValidator
+        Environment environment
     ) {
         super(tag, description);
         this.type = type;
         validateEmbeddingConfiguration(fieldMap);
 
         if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it");
         this.listTypeNestedMapKey = listTypeNestedMapKey;
         this.modelId = modelId;
         this.fieldMap = fieldMap;
         this.mlCommonsClientAccessor = clientAccessor;
         this.environment = environment;
-        this.processorInputValidator = processorInputValidator;
     }
 
     private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
