Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Optimize parameter parsing in text chunking processor #754

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DISABLED_MAX_CHUNK_LIMIT;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.CHUNK_STRING_COUNT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerWithDefault;

/**
* This processor is used for text chunking.
Expand Down Expand Up @@ -115,8 +116,8 @@ private void parseAlgorithmMap(final Map<String, Object> algorithmMap) {
}
Map<String, Object> chunkerParameters = (Map<String, Object>) algorithmValue;
// parse processor level max chunk limit
this.maxChunkLimit = parseIntegerParameter(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
if (maxChunkLimit < 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
this.maxChunkLimit = parseIntegerWithDefault(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
if (maxChunkLimit <= 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
Expand Down Expand Up @@ -309,10 +310,10 @@ private List<String> chunkString(final String content, final Map<String, Object>
}
List<String> contentResult = chunker.chunk(content, runTimeParameters);
// update chunk_string_count for each string
int chunkStringCount = parseIntegerParameter(runTimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int chunkStringCount = parseInteger(runTimeParameters, CHUNK_STRING_COUNT_FIELD);
runTimeParameters.put(CHUNK_STRING_COUNT_FIELD, chunkStringCount - 1);
// update runtime max_chunk_limit if not disabled
int runtimeMaxChunkLimit = parseIntegerParameter(runTimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit);
int runtimeMaxChunkLimit = parseInteger(runTimeParameters, MAX_CHUNK_LIMIT_FIELD);
if (runtimeMaxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
runTimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit - contentResult.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ private ChunkerParameterParser() {} // no instance of this util class
* Parse String type parameter.
* Throw IllegalArgumentException if parameter is not a string or an empty string.
*/
public static String parseStringParameter(final Map<String, Object> parameters, final String fieldName, final String defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all string parameters are optional
return defaultValue;
}
public static String parseString(final Map<String, Object> parameters, final String fieldName) {
Object fieldValue = parameters.get(fieldName);
if (!(fieldValue instanceof String)) {
throw new IllegalArgumentException(
Expand All @@ -40,14 +36,23 @@ public static String parseStringParameter(final Map<String, Object> parameters,
}

/**
* Parse integer type parameter.
* Throw IllegalArgumentException if parameter is not an integer.
* Parse String type parameter.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if parameter is not a string or an empty string.
*/
public static int parseIntegerParameter(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
public static String parseStringWithDefault(final Map<String, Object> parameters, final String fieldName, final String defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all integer parameters are optional
// all string parameters are optional
return defaultValue;
}
return parseString(parameters, fieldName);
}

/**
* Parse integer type parameter with default value.
* Throw IllegalArgumentException if the parameter is not an integer.
*/
public static int parseInteger(final Map<String, Object> parameters, final String fieldName) {
String fieldValueString = parameters.get(fieldName).toString();
try {
return NumberUtils.createInteger(fieldValueString);
Expand All @@ -58,27 +63,54 @@ public static int parseIntegerParameter(final Map<String, Object> parameters, fi
}
}

/**
* Parse integer type parameter with default value.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not an integer.
*/
public static int parseIntegerWithDefault(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
if (!parameters.containsKey(fieldName)) {
// return the default value when parameter is missing
return defaultValue;
}
return parseInteger(parameters, fieldName);
}

/**
* Parse integer type parameter with positive value.
* Throw IllegalArgumentException if parameter is not a positive integer.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not a positive integer.
*/
public static int parsePositiveIntegerParameter(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
int fieldValueInt = parseIntegerParameter(parameters, fieldName, defaultValue);
public static int parsePositiveInteger(final Map<String, Object> parameters, final String fieldName) {
int fieldValueInt = parseInteger(parameters, fieldName);
if (fieldValueInt <= 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName));
}
return fieldValueInt;
}

/**
* Parse double type parameter.
* Throw IllegalArgumentException if parameter is not a double.
* Parse integer type parameter with positive value.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not a positive integer.
*/
public static double parseDoubleParameter(final Map<String, Object> parameters, final String fieldName, final double defaultValue) {
public static int parsePositiveIntegerWithDefault(
final Map<String, Object> parameters,
final String fieldName,
final Integer defaultValue
) {
if (!parameters.containsKey(fieldName)) {
// all double parameters are optional
return defaultValue;
}
return parsePositiveInteger(parameters, fieldName);
}

/**
* Parse double type parameter.
* Throw IllegalArgumentException if parameter is not a double.
*/
public static double parseDouble(final Map<String, Object> parameters, final String fieldName) {
String fieldValueString = parameters.get(fieldName).toString();
try {
return NumberUtils.createDouble(fieldValueString);
Expand All @@ -88,4 +120,17 @@ public static double parseDoubleParameter(final Map<String, Object> parameters,
);
}
}

/**
* Parse double type parameter.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if parameter is not a double.
*/
public static double parseDoubleWithDefault(final Map<String, Object> parameters, final String fieldName, final double defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all double parameters are optional
return defaultValue;
}
return parseDouble(parameters, fieldName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import java.util.List;
import java.util.ArrayList;

import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault;

/**
* The implementation {@link Chunker} for delimiter algorithm
Expand All @@ -23,7 +23,6 @@ public final class DelimiterChunker implements Chunker {
public static final String DEFAULT_DELIMITER = "\n\n";

private String delimiter;
private int maxChunkLimit;

public DelimiterChunker(final Map<String, Object> parameters) {
parseParameters(parameters);
Expand All @@ -39,8 +38,7 @@ public DelimiterChunker(final Map<String, Object> parameters) {
*/
@Override
public void parseParameters(Map<String, Object> parameters) {
this.delimiter = parseStringParameter(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER);
this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
this.delimiter = parseStringWithDefault(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER);
}

/**
Expand All @@ -53,8 +51,8 @@ public void parseParameters(Map<String, Object> parameters) {
*/
@Override
public List<String> chunk(final String content, final Map<String, Object> runtimeParameters) {
int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit);
int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD);
int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD);

List<String> chunkResult = new ArrayList<>();
int start = 0, end;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import org.opensearch.action.admin.indices.analyze.AnalyzeAction;
import org.opensearch.action.admin.indices.analyze.AnalyzeAction.AnalyzeToken;
import static org.opensearch.action.admin.indices.analyze.TransportAnalyzeAction.analyze;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleWithDefault;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerWithDefault;

/**
* The implementation {@link Chunker} for fixed token length algorithm.
Expand All @@ -33,10 +33,9 @@ public final class FixedTokenLengthChunker implements Chunker {
public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count";
public static final String TOKENIZER_FIELD = "tokenizer";

// default values for each parameter
// default values for each non-runtime parameter
private static final int DEFAULT_TOKEN_LIMIT = 384;
private static final double DEFAULT_OVERLAP_RATE = 0.0;
private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
private static final String DEFAULT_TOKENIZER = "standard";

// parameter restrictions
Expand All @@ -54,7 +53,6 @@ public final class FixedTokenLengthChunker implements Chunker {

// parameter value
private int tokenLimit;
private int maxChunkLimit;
private String tokenizer;
private double overlapRate;
private final AnalysisRegistry analysisRegistry;
Expand All @@ -81,10 +79,9 @@ public FixedTokenLengthChunker(final Map<String, Object> parameters) {
*/
@Override
public void parseParameters(Map<String, Object> parameters) {
this.tokenLimit = parsePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT);
this.overlapRate = parseDoubleParameter(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE);
this.tokenizer = parseStringParameter(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER);
this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
this.tokenLimit = parsePositiveIntegerWithDefault(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT);
this.overlapRate = parseDoubleWithDefault(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE);
this.tokenizer = parseStringWithDefault(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER);
if (overlapRate < OVERLAP_RATE_LOWER_BOUND || overlapRate > OVERLAP_RATE_UPPER_BOUND) {
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -121,9 +118,9 @@ public void parseParameters(Map<String, Object> parameters) {
*/
@Override
public List<String> chunk(final String content, final Map<String, Object> runtimeParameters) {
int maxTokenCount = parsePositiveIntegerParameter(runtimeParameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT);
int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, this.maxChunkLimit);
int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int maxTokenCount = parseInteger(runtimeParameters, MAX_TOKEN_COUNT_FIELD);
int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD);
int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD);

List<AnalyzeToken> tokens = tokenize(content, tokenizer, maxTokenCount);
List<String> chunkResult = new ArrayList<>();
Expand Down
Loading
Loading