diff --git a/CHANGELOG.md b/CHANGELOG.md index fb02a26d9..9097dd26c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements - Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731)) +- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java index c59478de6..4338139d9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java @@ -30,7 +30,8 @@ import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT; import static org.opensearch.neuralsearch.processor.chunker.Chunker.DISABLED_MAX_CHUNK_LIMIT; import static org.opensearch.neuralsearch.processor.chunker.Chunker.CHUNK_STRING_COUNT_FIELD; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerWithDefault; /** * This processor is used for text chunking. @@ -115,8 +116,8 @@ private void parseAlgorithmMap(final Map algorithmMap) { } Map chunkerParameters = (Map) algorithmValue; // parse processor level max chunk limit - this.maxChunkLimit = parseIntegerParameter(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); - if (maxChunkLimit < 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) { + this.maxChunkLimit = parseIntegerWithDefault(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + if (maxChunkLimit <= 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) { throw new IllegalArgumentException( String.format( Locale.ROOT, @@ -309,10 +310,10 @@ private List chunkString(final String content, final Map } List contentResult = chunker.chunk(content, runTimeParameters); // update chunk_string_count for each string - int chunkStringCount = parseIntegerParameter(runTimeParameters, CHUNK_STRING_COUNT_FIELD, 1); + int chunkStringCount = parseInteger(runTimeParameters, CHUNK_STRING_COUNT_FIELD); runTimeParameters.put(CHUNK_STRING_COUNT_FIELD, chunkStringCount - 1); // update runtime max_chunk_limit if not disabled - int runtimeMaxChunkLimit = parseIntegerParameter(runTimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + int runtimeMaxChunkLimit = parseInteger(runTimeParameters, MAX_CHUNK_LIMIT_FIELD); if (runtimeMaxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) { runTimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit - contentResult.size()); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java index 56a61a26f..52d8eef00 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java @@ -22,11 +22,7 @@ private ChunkerParameterParser() {} // no instance of this util class * Parse String type parameter. * Throw IllegalArgumentException if parameter is not a string or an empty string. */ - public static String parseStringParameter(final Map parameters, final String fieldName, final String defaultValue) { - if (!parameters.containsKey(fieldName)) { - // all string parameters are optional - return defaultValue; - } + public static String parseString(final Map parameters, final String fieldName) { Object fieldValue = parameters.get(fieldName); if (!(fieldValue instanceof String)) { throw new IllegalArgumentException( @@ -40,14 +36,23 @@ public static String parseStringParameter(final Map parameters, } /** - * Parse integer type parameter. - * Throw IllegalArgumentException if parameter is not an integer. + * Parse String type parameter. + * Return default value if the parameter is missing. + * Throw IllegalArgumentException if parameter is not a string or an empty string. */ - public static int parseIntegerParameter(final Map parameters, final String fieldName, final int defaultValue) { + public static String parseStringWithDefault(final Map parameters, final String fieldName, final String defaultValue) { if (!parameters.containsKey(fieldName)) { - // all integer parameters are optional + // all string parameters are optional return defaultValue; } + return parseString(parameters, fieldName); + } + + /** + * Parse integer type parameter with default value. + * Throw IllegalArgumentException if the parameter is not an integer. + */ + public static int parseInteger(final Map parameters, final String fieldName) { String fieldValueString = parameters.get(fieldName).toString(); try { return NumberUtils.createInteger(fieldValueString); @@ -58,12 +63,26 @@ public static int parseIntegerParameter(final Map parameters, fi } } + /** + * Parse integer type parameter with default value. + * Return default value if the parameter is missing. + * Throw IllegalArgumentException if the parameter is not an integer. + */ + public static int parseIntegerWithDefault(final Map parameters, final String fieldName, final int defaultValue) { + if (!parameters.containsKey(fieldName)) { + // return the default value when parameter is missing + return defaultValue; + } + return parseInteger(parameters, fieldName); + } + /** * Parse integer type parameter with positive value. - * Throw IllegalArgumentException if parameter is not a positive integer. + * Return default value if the parameter is missing. + * Throw IllegalArgumentException if the parameter is not a positive integer. */ - public static int parsePositiveIntegerParameter(final Map parameters, final String fieldName, final int defaultValue) { - int fieldValueInt = parseIntegerParameter(parameters, fieldName, defaultValue); + public static int parsePositiveInteger(final Map parameters, final String fieldName) { + int fieldValueInt = parseInteger(parameters, fieldName); if (fieldValueInt <= 0) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName)); } @@ -71,14 +90,27 @@ public static int parsePositiveIntegerParameter(final Map parame } /** - * Parse double type parameter. - * Throw IllegalArgumentException if parameter is not a double. + * Parse integer type parameter with positive value. + * Return default value if the parameter is missing. + * Throw IllegalArgumentException if the parameter is not a positive integer. */ - public static double parseDoubleParameter(final Map parameters, final String fieldName, final double defaultValue) { + public static int parsePositiveIntegerWithDefault( + final Map parameters, + final String fieldName, + final Integer defaultValue + ) { if (!parameters.containsKey(fieldName)) { // all double parameters are optional return defaultValue; } + return parsePositiveInteger(parameters, fieldName); + } + + /** + * Parse double type parameter. + * Throw IllegalArgumentException if parameter is not a double. + */ + public static double parseDouble(final Map parameters, final String fieldName) { String fieldValueString = parameters.get(fieldName).toString(); try { return NumberUtils.createDouble(fieldValueString); @@ -88,4 +120,17 @@ public static double parseDoubleParameter(final Map parameters, ); } } + + /** + * Parse double type parameter. + * Return default value if the parameter is missing. + * Throw IllegalArgumentException if parameter is not a double. + */ + public static double parseDoubleWithDefault(final Map parameters, final String fieldName, final double defaultValue) { + if (!parameters.containsKey(fieldName)) { + // all double parameters are optional + return defaultValue; + } + return parseDouble(parameters, fieldName); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java index fe2418ee8..0f3d66c55 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java @@ -8,8 +8,8 @@ import java.util.List; import java.util.ArrayList; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault; /** * The implementation {@link Chunker} for delimiter algorithm @@ -23,7 +23,6 @@ public final class DelimiterChunker implements Chunker { public static final String DEFAULT_DELIMITER = "\n\n"; private String delimiter; - private int maxChunkLimit; public DelimiterChunker(final Map parameters) { parseParameters(parameters); @@ -39,8 +38,7 @@ public DelimiterChunker(final Map parameters) { */ @Override public void parseParameters(Map parameters) { - this.delimiter = parseStringParameter(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER); - this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + this.delimiter = parseStringWithDefault(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER); } /** @@ -53,8 +51,8 @@ public void parseParameters(Map parameters) { */ @Override public List chunk(final String content, final Map runtimeParameters) { - int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); - int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1); + int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD); + int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD); List chunkResult = new ArrayList<>(); int start = 0, end; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java index 276e41ac7..614ea33f9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java @@ -14,10 +14,10 @@ import org.opensearch.action.admin.indices.analyze.AnalyzeAction; import org.opensearch.action.admin.indices.analyze.AnalyzeAction.AnalyzeToken; import static org.opensearch.action.admin.indices.analyze.TransportAnalyzeAction.analyze; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleParameter; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; -import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleWithDefault; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerWithDefault; /** * The implementation {@link Chunker} for fixed token length algorithm. @@ -33,10 +33,9 @@ public final class FixedTokenLengthChunker implements Chunker { public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count"; public static final String TOKENIZER_FIELD = "tokenizer"; - // default values for each parameter + // default values for each non-runtime parameter private static final int DEFAULT_TOKEN_LIMIT = 384; private static final double DEFAULT_OVERLAP_RATE = 0.0; - private static final int DEFAULT_MAX_TOKEN_COUNT = 10000; private static final String DEFAULT_TOKENIZER = "standard"; // parameter restrictions @@ -54,7 +53,6 @@ public final class FixedTokenLengthChunker implements Chunker { // parameter value private int tokenLimit; - private int maxChunkLimit; private String tokenizer; private double overlapRate; private final AnalysisRegistry analysisRegistry; @@ -81,10 +79,9 @@ public FixedTokenLengthChunker(final Map parameters) { */ @Override public void parseParameters(Map parameters) { - this.tokenLimit = parsePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT); - this.overlapRate = parseDoubleParameter(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE); - this.tokenizer = parseStringParameter(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER); - this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + this.tokenLimit = parsePositiveIntegerWithDefault(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT); + this.overlapRate = parseDoubleWithDefault(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE); + this.tokenizer = parseStringWithDefault(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER); if (overlapRate < OVERLAP_RATE_LOWER_BOUND || overlapRate > OVERLAP_RATE_UPPER_BOUND) { throw new IllegalArgumentException( String.format( @@ -121,9 +118,9 @@ public void parseParameters(Map parameters) { */ @Override public List chunk(final String content, final Map runtimeParameters) { - int maxTokenCount = parsePositiveIntegerParameter(runtimeParameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT); - int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, this.maxChunkLimit); - int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1); + int maxTokenCount = parseInteger(runtimeParameters, MAX_TOKEN_COUNT_FIELD); + int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD); + int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD); List tokens = tokenize(content, tokenizer, maxTokenCount); List chunkResult = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParserTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParserTests.java new file mode 100644 index 000000000..8453b1f6d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParserTests.java @@ -0,0 +1,300 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import org.junit.Assert; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Locale; +import java.util.Map; + +public class ChunkerParameterParserTests extends OpenSearchTestCase { + + private static final String fieldName = "parameter"; + private static final String defaultString = "default_string"; + private static final Integer defaultInteger = 0; + private static final Integer defaultPositiveInteger = 100; + private static final Double defaultDouble = 0.0; + + public void testParseString_withFieldValueNotString_thenFail() { + Map parameters = Map.of(fieldName, 1); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseString(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, String.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseString_withFieldValueEmptyString_thenFail() { + Map parameters = Map.of(fieldName, ""); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseString(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] should not be empty.", fieldName), + illegalArgumentException.getMessage() + ); + } + + public void testParseString_withFieldValueValidString_thenSucceed() { + String parameterValue = "string_parameter_value"; + Map parameters = Map.of(fieldName, parameterValue); + String parsedStringValue = ChunkerParameterParser.parseString(parameters, fieldName); + assertEquals(parameterValue, parsedStringValue); + } + + public void testParseStringWithDefault_withFieldValueNotString_thenFail() { + Map parameters = Map.of(fieldName, 1); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseStringWithDefault(parameters, fieldName, defaultString) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, String.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseStringWithDefault_withFieldValueEmptyString_thenFail() { + Map parameters = Map.of(fieldName, ""); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseStringWithDefault(parameters, fieldName, defaultString) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] should not be empty.", fieldName), + illegalArgumentException.getMessage() + ); + } + + public void testParseStringWithDefault_withFieldValueValidString_thenSucceed() { + String parameterValue = "string_parameter_value"; + Map parameters = Map.of(fieldName, parameterValue); + String parsedStringValue = ChunkerParameterParser.parseStringWithDefault(parameters, fieldName, defaultString); + assertEquals(parameterValue, parsedStringValue); + } + + public void testParseStringWithDefault_withFieldValueMissing_thenSucceed() { + Map parameters = Map.of(); + String parsedStringValue = ChunkerParameterParser.parseStringWithDefault(parameters, fieldName, defaultString); + assertEquals(defaultString, parsedStringValue); + } + + public void testParseInteger_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseInteger(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseInteger_withFieldValueDouble_thenFail() { + Map parameters = Map.of(fieldName, "1.0"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseInteger(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseInteger_withFieldValueValidInteger_thenSucceed() { + String parameterValue = "1"; + Integer expectedIntegerValue = 1; + Map parameters = Map.of(fieldName, parameterValue); + Integer parsedIntegerValue = ChunkerParameterParser.parseInteger(parameters, fieldName); + assertEquals(expectedIntegerValue, parsedIntegerValue); + } + + public void testParseIntegerWithDefault_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseIntegerWithDefault(parameters, fieldName, defaultInteger) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseIntegerWithDefault_withFieldValueDouble_thenFail() { + Map parameters = Map.of(fieldName, "1.0"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseIntegerWithDefault(parameters, fieldName, defaultInteger) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseIntegerWithDefault_withFieldValueValidInteger_thenSucceed() { + String parameterValue = "1"; + Integer expectedIntegerValue = 1; + Map parameters = Map.of(fieldName, parameterValue); + Integer parsedIntegerValue = ChunkerParameterParser.parseIntegerWithDefault(parameters, fieldName, defaultInteger); + assertEquals(expectedIntegerValue, parsedIntegerValue); + } + + public void testParseIntegerWithDefault_withFieldValueMissing_thenSucceed() { + Map parameters = Map.of(); + Integer parsedIntegerValue = ChunkerParameterParser.parseIntegerWithDefault(parameters, fieldName, defaultInteger); + assertEquals(defaultInteger, parsedIntegerValue); + } + + public void testParsePositiveInteger_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveInteger(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveInteger_withFieldValueDouble_thenFail() { + Map parameters = Map.of(fieldName, "1.0"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveInteger(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveInteger_withFieldValueNegativeInteger_thenFail() { + String parameterValue = "-1"; + Map parameters = Map.of(fieldName, parameterValue); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveInteger(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveInteger_withFieldValueValidInteger_thenSucceed() { + String parameterValue = "1"; + Integer expectedIntegerValue = 1; + Map parameters = Map.of(fieldName, parameterValue); + Integer parsedIntegerValue = ChunkerParameterParser.parsePositiveInteger(parameters, fieldName); + assertEquals(expectedIntegerValue, parsedIntegerValue); + } + + public void testParsePositiveIntegerWithDefault_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveIntegerWithDefault(parameters, fieldName, defaultPositiveInteger) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveIntegerWithDefault_withFieldValueDouble_thenFail() { + Map parameters = Map.of(fieldName, "1.0"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveIntegerWithDefault(parameters, fieldName, defaultPositiveInteger) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveIntegerWithDefault_withFieldValueNegativeInteger_thenFail() { + String parameterValue = "-1"; + Map parameters = Map.of(fieldName, parameterValue); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parsePositiveIntegerWithDefault(parameters, fieldName, defaultPositiveInteger) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName), + illegalArgumentException.getMessage() + ); + } + + public void testParsePositiveIntegerWithDefault_withFieldValueValidInteger_thenSucceed() { + String parameterValue = "1"; + Integer expectedIntegerValue = 1; + Map parameters = Map.of(fieldName, parameterValue); + Integer parsedIntegerValue = ChunkerParameterParser.parsePositiveIntegerWithDefault(parameters, fieldName, defaultPositiveInteger); + assertEquals(expectedIntegerValue, parsedIntegerValue); + } + + public void testParsePositiveIntegerWithDefault_withFieldValueMissing_thenSucceed() { + Map parameters = Map.of(); + Integer parsedIntegerValue = ChunkerParameterParser.parsePositiveIntegerWithDefault(parameters, fieldName, defaultPositiveInteger); + assertEquals(defaultPositiveInteger, parsedIntegerValue); + } + + public void testParseDouble_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseDouble(parameters, fieldName) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Double.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseDouble_withFieldValueValidDouble_thenSucceed() { + String parameterValue = "1"; + Double expectedDoubleValue = 1.0; + Map parameters = Map.of(fieldName, parameterValue); + Double parsedDoubleValue = ChunkerParameterParser.parseDouble(parameters, fieldName); + assertEquals(expectedDoubleValue, parsedDoubleValue); + } + + public void testParseDoubleWithDefault_withFieldValueString_thenFail() { + Map parameters = Map.of(fieldName, "a"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> ChunkerParameterParser.parseDoubleWithDefault(parameters, fieldName, defaultDouble) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Double.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseDoubleWithDefault_withFieldValueValidDouble_thenSucceed() { + String parameterValue = "1"; + Double expectedDoubleValue = 1.0; + Map parameters = Map.of(fieldName, parameterValue); + Double parsedDoubleValue = ChunkerParameterParser.parseDoubleWithDefault(parameters, fieldName, defaultDouble); + assertEquals(expectedDoubleValue, parsedDoubleValue); + } + + public void testParseDoubleWithDefault_withFieldValueMissing_thenSucceed() { + Map parameters = Map.of(); + Double parsedDoubleValue = ChunkerParameterParser.parseDoubleWithDefault(parameters, fieldName, defaultDouble); + assertEquals(defaultDouble, parsedDoubleValue); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java index e4c2a5c05..cb7efe712 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java @@ -18,6 +18,8 @@ public class DelimiterChunkerTests extends OpenSearchTestCase { + private final Map runtimeParameters = Map.of(MAX_CHUNK_LIMIT_FIELD, 100, CHUNK_STRING_COUNT_FIELD, 1); + public void testCreate_withDelimiterFieldInvalidType_thenFail() { Exception exception = assertThrows( IllegalArgumentException.class, @@ -37,7 +39,7 @@ public void testCreate_withDelimiterFieldEmptyString_thenFail() { public void testChunk_withNewlineDelimiter_thenSucceed() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "a\nb\nc\nd"; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("a\n", "b\n", "c\n", "d"), chunkResult); } @@ -45,72 +47,59 @@ public void testChunk_withDefaultDelimiter_thenSucceed() { // default delimiter is \n\n DelimiterChunker chunker = new DelimiterChunker(Map.of()); String content = "a.b\n\nc.d"; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("a.b\n\n", "c.d"), chunkResult); } public void testChunk_withOnlyDelimiterContent_thenSucceed() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "\n"; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("\n"), chunkResult); } public void testChunk_WithAllDelimiterContent_thenSucceed() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); String content = "\n\n\n"; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("\n", "\n", "\n"), chunkResult); } public void testChunk_WithPeriodDelimiters_thenSucceed() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, ".")); String content = "a.b.cc.d."; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("a.", "b.", "cc.", "d."), chunkResult); } public void testChunk_withDoubleNewlineDelimiter_thenSucceed() { DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); String content = "\n\na\n\n\n"; - List chunkResult = chunker.chunk(content, Map.of()); + List chunkResult = chunker.chunk(content, runtimeParameters); assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult); } - public void testChunk_whenExceedMaxChunkLimit_thenLastPassageGetConcatenated() { - int maxChunkLimit = 2; - DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); - String content = "\n\na\n\n\n"; - List passages = chunker.chunk(content, Map.of()); - List expectedPassages = new ArrayList<>(); - expectedPassages.add("\n\n"); - expectedPassages.add("a\n\n\n"); - assertEquals(expectedPassages, passages); - } - public void testChunk_whenWithinMaxChunkLimit_thenSucceed() { - int maxChunkLimit = 3; - DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); String content = "\n\na\n\n\n"; - List chunkResult = chunker.chunk(content, Map.of()); + int runtimeMaxChunkLimit = 3; + List chunkResult = chunker.chunk(content, Map.of(CHUNK_STRING_COUNT_FIELD, 1, MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit)); assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult); } - public void testChunk_whenExceedRuntimeMaxChunkLimit_thenLastPassageGetConcatenated() { - int maxChunkLimit = 3; - DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + public void testChunk_whenExceedMaxChunkLimit_thenLastPassageGetConcatenated() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); String content = "\n\na\n\n\n"; int runtimeMaxChunkLimit = 2; - List passages = chunker.chunk(content, Map.of(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit)); + List passages = chunker.chunk(content, Map.of(CHUNK_STRING_COUNT_FIELD, 1, MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit)); List expectedPassages = new ArrayList<>(); expectedPassages.add("\n\n"); expectedPassages.add("a\n\n\n"); assertEquals(expectedPassages, passages); } - public void testChunk_whenExceedRuntimeMaxChunkLimit_withTwoStringsTobeChunked_thenLastPassageGetConcatenated() { - int maxChunkLimit = 3; - DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + public void testChunk_whenExceedMaxChunkLimit_withTwoStringsTobeChunked_thenLastPassageGetConcatenated() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); String content = "\n\na\n\n\n"; int runtimeMaxChunkLimit = 2, chunkStringCount = 2; List passages = chunker.chunk( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java index d2a607a5b..3ba589174 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java @@ -36,6 +36,15 @@ public class FixedTokenLengthChunkerTests extends OpenSearchTestCase { private FixedTokenLengthChunker fixedTokenLengthChunker; + private final Map runtimeParameters = Map.of( + MAX_CHUNK_LIMIT_FIELD, + 100, + CHUNK_STRING_COUNT_FIELD, + 1, + MAX_TOKEN_COUNT_FIELD, + 10000 + ); + @Before public void setup() { fixedTokenLengthChunker = createFixedTokenLengthChunker(Map.of()); @@ -166,7 +175,7 @@ public void testChunk_whenTokenizationException_thenFail() { "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; IllegalStateException illegalStateException = assertThrows( IllegalStateException.class, - () -> fixedTokenLengthChunker.chunk(content, parameters) + () -> fixedTokenLengthChunker.chunk(content, runtimeParameters) ); assert (illegalStateException.getMessage() .contains(String.format(Locale.ROOT, "analyzer %s throws exception", lowercaseTokenizer))); @@ -177,8 +186,6 @@ public void testChunk_withEmptyInput_thenSucceed() { parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); String content = ""; List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); assert (passages.isEmpty()); @@ -189,8 +196,6 @@ public void testChunk_withTokenLimit10_thenSucceed() { parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); @@ -206,8 +211,7 @@ public void testChunk_withTokenLimit20_thenSucceed() { parameters.put(TOKEN_LIMIT_FIELD, 20); parameters.put(TOKENIZER_FIELD, "standard"); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + Map runtimeParameters = new HashMap<>(this.runtimeParameters); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); @@ -226,7 +230,7 @@ public void testChunk_withOverlapRateHalf_thenSucceed() { FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; - List passages = fixedTokenLengthChunker.chunk(content, Map.of()); + List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); List expectedPassages = new ArrayList<>(); expectedPassages.add("This is an example document to be chunked. The document "); expectedPassages.add("to be chunked. The document contains a single paragraph, two "); @@ -236,14 +240,13 @@ public void testChunk_withOverlapRateHalf_thenSucceed() { } public void testChunk_whenExceedMaxChunkLimit_thenLastPassageGetConcatenated() { - int maxChunkLimit = 2; Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); - parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + int runtimeMaxChunkLimit = 2; + Map runtimeParameters = new HashMap<>(this.runtimeParameters); + runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); @@ -254,14 +257,13 @@ public void testChunk_whenExceedMaxChunkLimit_thenLastPassageGetConcatenated() { } public void testChunk_whenWithinMaxChunkLimit_thenSucceed() { - int maxChunkLimit = 3; Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); - parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + int runtimeMaxChunkLimit = 3; + Map runtimeParameters = new HashMap<>(this.runtimeParameters); + runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); @@ -273,14 +275,12 @@ public void testChunk_whenWithinMaxChunkLimit_thenSucceed() { } public void testChunk_whenExceedRuntimeMaxChunkLimit_thenLastPassageGetConcatenated() { - int maxChunkLimit = 3, runtimeMaxChunkLimit = 2; Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); - parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + int runtimeMaxChunkLimit = 2; + Map runtimeParameters = new HashMap<>(this.runtimeParameters); runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit); String content = "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; @@ -292,14 +292,12 @@ public void testChunk_whenExceedRuntimeMaxChunkLimit_thenLastPassageGetConcatena } public void testChunk_whenExceedRuntimeMaxChunkLimit_withOneStringTobeChunked_thenLastPassageGetConcatenated() { - int maxChunkLimit = 3, runtimeMaxChunkLimit = 2, chunkStringCount = 1; Map parameters = new HashMap<>(); parameters.put(TOKEN_LIMIT_FIELD, 10); parameters.put(TOKENIZER_FIELD, "standard"); - parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); - Map runtimeParameters = new HashMap<>(); - runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + int runtimeMaxChunkLimit = 2, chunkStringCount = 1; + Map runtimeParameters = new HashMap<>(this.runtimeParameters); runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit); runtimeParameters.put(CHUNK_STRING_COUNT_FIELD, chunkStringCount); String content = diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e790ffb77..ff9616637 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -24,7 +24,6 @@ import java.util.LinkedList; import java.util.Map; import java.util.Set; -import java.util.concurrent.ExecutorService; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field;